1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
(* Short alias for [Printf.sprintf], used to build the error messages
   below. *)
let strf fmt = Printf.sprintf fmt

(* Message passed to [invalid_arg] when [create] receives no environments. *)
let err_empty = "Vec_env.create: env list must not be empty"

(* [err_action_len expected got]: message for an action array whose length
   [got] differs from the number of environments [expected]. *)
let err_action_len expected got =
  strf "Vec_env.step: expected %d actions, got %d" expected got

(* [err_space space_kind]: message for a space mismatch between
   environments; [space_kind] is e.g. "observation" or "action". *)
let err_space space_kind =
  strf "Vec_env.create: all environments must have the same %s space"
    space_kind
(** Batched outcome of one call to [step]. Every array has one entry per
    environment; index [i] refers to environment [i] of the batch. *)
type 'obs step = {
observations : 'obs array;
(** Observation per environment. For an environment whose episode just
    ended, [step] stores the first observation after its automatic reset,
    not the terminal one. *)
rewards : float array;
(** Reward reported by each environment for the step just taken. *)
terminated : bool array;
(** [terminated] flag reported by each environment for the step just
    taken. *)
truncated : bool array;
(** [truncated] flag reported by each environment for the step just
    taken. *)
infos : Info.t array;
(** Info per environment. For an environment whose episode ended, [step]
    stores the step info extended with ["final_observation"] and
    ["final_info"] entries, merged with the info returned by its reset. *)
}
(** A fixed batch of environments driven together by [reset] and [step]. *)
type ('obs, 'act, 'render) t = {
envs : ('obs, 'act, 'render) Env.t array;
(** The wrapped environments; index [i] is environment [i] of the batch. *)
observation_space : 'obs Space.t;
(** Observation space of the first environment. [create] checks that all
    other environments have an equal observation space spec. *)
action_space : 'act Space.t;
(** Action space of the first environment. [create] checks that all other
    environments have an equal action space spec. *)
}
(** [ensure_compatible envs] checks that every environment in [envs] has the
    same observation and action space spec as [envs.(0)].

    For each environment, the observation space is checked before the action
    space, and the first mismatch aborts the check.

    @raise Invalid_argument with an [err_space] message on a mismatch. [envs]
    must be non-empty: on an empty array the access [envs.(0)] itself raises
    [Invalid_argument]. *)
let ensure_compatible envs =
  let reference = envs.(0) in
  let expected_obs = Space.spec (Env.observation_space reference) in
  let expected_act = Space.spec (Env.action_space reference) in
  let require kind expected actual =
    if not (Space.equal_spec expected actual) then invalid_arg (err_space kind)
  in
  Array.iteri
    (fun i env ->
      (* Index 0 is the reference itself; skip it. *)
      if i > 0 then begin
        require "observation" expected_obs
          (Space.spec (Env.observation_space env));
        require "action" expected_act (Space.spec (Env.action_space env))
      end)
    envs
(** [create envs] bundles [envs] into a vectorized environment whose
    observation and action spaces are those of the first environment.

    @raise Invalid_argument if [envs] is empty, or if any environment's
    observation or action space spec differs from the first one's. *)
let create = function
  | [] -> invalid_arg err_empty
  | head :: _ as env_list ->
      let envs = Array.of_list env_list in
      ensure_compatible envs;
      {
        envs;
        observation_space = Env.observation_space head;
        action_space = Env.action_space head;
      }
(** Number of environments in the batch. *)
let num_envs { envs; _ } = Array.length envs

(** Observation space stored in the batch (set by [create] from the first
    environment). *)
let observation_space { observation_space = space; _ } = space

(** Action space stored in the batch (set by [create] from the first
    environment). *)
let action_space { action_space = space; _ } = space
(** [reset t ()] resets every environment and returns two arrays aligned
    with the batch: the observations and the infos produced by each
    environment's [Env.reset]. *)
let reset t () =
  let outcomes = Array.map (fun env -> Env.reset env ()) t.envs in
  let observations = Array.map (fun (obs, _) -> obs) outcomes in
  let infos = Array.map (fun (_, info) -> info) outcomes in
  (observations, infos)
(** [step t actions] advances environment [i] with [actions.(i)] for every
    [i], then automatically resets each environment whose episode ended.

    For an environment that terminated or truncated, the returned observation
    is the first observation after its reset. Its info is the step info
    extended with ["final_observation"] (the terminal observation packed with
    the batch's observation space) and ["final_info"] (the step info converted
    with [Info.to_value]), merged with the reset info. Rewards and the
    terminated/truncated flags always describe the step just taken.

    @raise Invalid_argument if [Array.length actions] differs from the number
    of environments. *)
let step t actions =
  let count = Array.length t.envs in
  let supplied = Array.length actions in
  if supplied <> count then invalid_arg (err_action_len count supplied);
  (* Step every environment before any automatic reset happens. *)
  let results = Array.init count (fun i -> Env.step t.envs.(i) actions.(i)) in
  let rewards = Array.init count (fun i -> results.(i).reward) in
  let terminated = Array.init count (fun i -> results.(i).terminated) in
  let truncated = Array.init count (fun i -> results.(i).truncated) in
  (* Start from the raw step outputs; finished slots are overwritten below. *)
  let observations = Array.init count (fun i -> results.(i).observation) in
  let infos = Array.init count (fun i -> results.(i).info) in
  for i = 0 to count - 1 do
    let r = results.(i) in
    if r.terminated || r.truncated then begin
      let terminal_obs = Space.pack t.observation_space r.observation in
      let info = Info.set "final_observation" terminal_obs r.info in
      let info = Info.set "final_info" (Info.to_value r.info) info in
      let obs, reset_info = Env.reset t.envs.(i) () in
      observations.(i) <- obs;
      infos.(i) <- Info.merge info reset_info
    end
  done;
  { observations; rewards; terminated; truncated; infos }
(** [close t] closes every environment in [t], in index order.

    Every environment gets a chance to close even if an earlier [Env.close]
    raises; otherwise later environments would be left open. Once all
    environments have been visited, the first exception raised (if any) is
    re-raised with its original backtrace. *)
let close t =
  let first_failure = ref None in
  Array.iter
    (fun env ->
      match Env.close env with
      | () -> ()
      | exception exn ->
          (* Capture the backtrace immediately, before other code runs. *)
          let bt = Printexc.get_raw_backtrace () in
          (match !first_failure with
           | None -> first_failure := Some (exn, bt)
           | Some _ -> ()))
    t.envs;
  match !first_failure with
  | None -> ()
  | Some (exn, bt) -> Printexc.raise_with_backtrace exn bt