1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
(* Messages passed to [invalid_arg] by the argument checks in this module. *)

(** Message used by [create] when the requested capacity is not positive. *)
let err_capacity = "Buffer.create: capacity must be positive"
(** Message used by [sample_indices] when the buffer holds no transitions. *)
let err_empty = "Buffer.sample: buffer is empty"
(** Message used by [sample_indices] when [batch_size] is not positive. *)
let err_batch_size = "Buffer.sample: batch_size must be positive"
(** One record stored in, and returned from, the buffer. ['obs] is the
    observation type and ['act] the action type.

    NOTE(review): the per-field meanings below are inferred from the field
    names (usual reinforcement-learning convention) — confirm with callers. *)
type ('obs, 'act) transition = {
observation : 'obs; (** Presumably the observation the action was taken from. *)
action : 'act; (** Presumably the action that was taken. *)
reward : float; (** Scalar reward attached to this record. *)
next_observation : 'obs; (** Presumably the observation after the action. *)
terminated : bool; (** Presumably set when the episode ended naturally. *)
truncated : bool; (** Presumably set when the episode was cut short. *)
}
(** Mutable fixed-capacity buffer of transitions, stored field-by-field in
    parallel arrays (slot [i] of every array belongs to the same transition).

    The three polymorphic arrays are mutable because [create] cannot build
    them without a value of type ['obs] / ['act]: they start as [[||]] and
    are allocated by [ensure_init] on the first [add]. *)
type ('obs, 'act) t = {
capacity : int; (** Number of slots; fixed at creation. *)
mutable size : int; (** Number of filled slots, between [0] and [capacity]. *)
mutable pos : int; (** Slot the next [add] writes to; wraps modulo [capacity]. *)
mutable observations : 'obs array; (** Empty until the first [add]. *)
mutable actions : 'act array; (** Empty until the first [add]. *)
rewards : float array; (** Allocated with [capacity] slots by [create]. *)
mutable next_observations : 'obs array; (** Empty until the first [add]. *)
terminateds : bool array; (** Allocated with [capacity] slots by [create]. *)
truncateds : bool array; (** Allocated with [capacity] slots by [create]. *)
}
(** [create ~capacity] returns an empty buffer with [capacity] slots.

    The reward and flag arrays are allocated immediately. The observation
    and action arrays start empty because no element of those types is
    available yet; they are allocated on the first insertion.

    @raise Invalid_argument if [capacity] is zero or negative. *)
let create ~capacity =
  if capacity < 1 then invalid_arg err_capacity
  else
    let rewards = Array.make capacity 0.0 in
    let terminateds = Array.make capacity false in
    let truncateds = Array.make capacity false in
    {
      capacity;
      size = 0;
      pos = 0;
      observations = [||];
      actions = [||];
      next_observations = [||];
      rewards;
      terminateds;
      truncateds;
    }
(** [ensure_init buf tr] allocates the observation, action and
    next-observation arrays of [buf] if they are still empty, filling every
    slot with the corresponding field of [tr]. It does nothing once the
    observation array is non-empty. *)
let ensure_init buf (tr : _ transition) =
  match buf.observations with
  | [||] ->
      (* An empty observation array marks a buffer that has never been
         written to; [tr] supplies the placeholder values. *)
      let slots = buf.capacity in
      buf.observations <- Array.make slots tr.observation;
      buf.actions <- Array.make slots tr.action;
      buf.next_observations <- Array.make slots tr.next_observation
  | _ -> ()
(** [add buf tr] stores [tr] in the slot at the current write position,
    advances that position (wrapping to [0] after the last slot), and
    increments the element count until it reaches the capacity. Once the
    buffer is full, each call overwrites the slot written longest ago. *)
let add buf tr =
  ensure_init buf tr;
  let { observation; action; reward; next_observation; terminated; truncated }
      =
    tr
  in
  let slot = buf.pos in
  buf.observations.(slot) <- observation;
  buf.actions.(slot) <- action;
  buf.rewards.(slot) <- reward;
  buf.next_observations.(slot) <- next_observation;
  buf.terminateds.(slot) <- terminated;
  buf.truncateds.(slot) <- truncated;
  buf.pos <- (slot + 1) mod buf.capacity;
  if buf.size < buf.capacity then buf.size <- buf.size + 1
(** [clear buf] resets the write position and the element count to zero.
    The backing arrays are left as they are; their old contents simply
    become unreachable through sampling and are overwritten by later
    insertions. *)
let clear buf =
  buf.pos <- 0;
  buf.size <- 0
(** [sample_indices buf ~batch_size] returns [(indices, n)] where
    [n = min batch_size (size buf)] and [indices] is the array obtained from
    an [Nx.randint] draw of shape [[| n |]] with [~high] set to the number of
    stored transitions.

    NOTE(review): presumably the draw is uniform over [0, size) with
    replacement and the trailing [0] is the lower bound — confirm against
    the [Nx.randint] documentation.

    @raise Invalid_argument if the buffer is empty, or if [batch_size] is
    zero or negative (the emptiness check runs first). *)
let sample_indices buf ~batch_size =
  let stored = buf.size in
  if stored = 0 then invalid_arg err_empty
  else if batch_size <= 0 then invalid_arg err_batch_size
  else
    (* Never request more indices than there are stored transitions. *)
    let count = if batch_size < stored then batch_size else stored in
    let drawn = Nx.randint Nx.int32 ~high:stored [| count |] 0 in
    let picks : Int32.t array = Nx.to_array drawn in
    (picks, count)
(** [sample buf ~batch_size] returns an array of transitions rebuilt from
    the slots chosen by [sample_indices]. The result has
    [min batch_size (size buf)] elements.

    @raise Invalid_argument under the same conditions as [sample_indices]. *)
let sample buf ~batch_size =
  let picks, count = sample_indices buf ~batch_size in
  (* Reassemble the [k]-th sampled transition from the parallel arrays. *)
  let transition_at k =
    let slot = Int32.to_int picks.(k) in
    {
      observation = buf.observations.(slot);
      action = buf.actions.(slot);
      reward = buf.rewards.(slot);
      next_observation = buf.next_observations.(slot);
      terminated = buf.terminateds.(slot);
      truncated = buf.truncateds.(slot);
    }
  in
  Array.init count transition_at
(** [sample_arrays buf ~batch_size] is like [sample] but returns the batch
    as six parallel arrays
    [(observations, actions, rewards, next_observations, terminated,
      truncated)], each of length [min batch_size (size buf)]. Position [k]
    of every array comes from the same sampled slot.

    @raise Invalid_argument under the same conditions as [sample_indices]. *)
let sample_arrays buf ~batch_size =
  let picks, count = sample_indices buf ~batch_size in
  (* Convert the sampled indices to native ints once, then reuse them for
     every field array. *)
  let slots = Array.init count (fun k -> Int32.to_int picks.(k)) in
  let gather src = Array.map (fun slot -> src.(slot)) slots in
  let observations = gather buf.observations in
  let actions = gather buf.actions in
  let rewards = gather buf.rewards in
  let next_observations = gather buf.next_observations in
  let terminated = gather buf.terminateds in
  let truncated = gather buf.truncateds in
  (observations, actions, rewards, next_observations, terminated, truncated)
(** [size buf] is the number of transitions currently stored in [buf]. *)
let size { size; _ } = size
(** [is_full buf] is [true] exactly when the number of stored transitions
    equals the capacity of [buf]. *)
let is_full { size; capacity; _ } = size = capacity
(** [capacity buf] is the number of slots [buf] was created with. *)
let capacity { capacity; _ } = capacity