Source file owl_neural_parallel.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# 1 "src/owl/neural/owl_neural_parallel.ml"
(** Neural network: interface of parallel engine *)
open Owl_algodiff.S
open Owl_optimise.S
(* Interface the distributed engine must provide to [Make]: a key-value store
   ([get]/[set]), the worker count, an entry point ([start]), and hooks for
   registering the barrier/schedule/pull/push/stop callbacks. *)
module type EngineSig = sig

(** Engine-specific context, handed by reference to the barrier and stop
    callbacks. *)
type param_context

(** Barrier control methods accepted by [start].
    NOTE(review): presumably asynchronous, bulk-synchronous, stale-synchronous
    and probabilistic synchronous parallel respectively — confirm against the
    engine implementation. *)
type barrier =
| ASP
| BSP
| SSP
| PSP

(** [get key] returns the value stored under [key] paired with an [int].
    [Make] relies on it raising [Not_found] for a key that has not been [set].
    NOTE(review): the meaning of the [int] component is engine-specific
    (possibly a version/step counter) — confirm. *)
val get : 'a -> 'b * int

(** [set key value] stores [value] under [key]. *)
val set : 'a -> 'b -> unit

(** Number of workers; [Make]'s pull callback asserts it is at least 1. *)
val worker_num : unit -> int

(** [start ?barrier a b] starts the engine. [Make] calls it as
    [start ~barrier:ASP jid url]. *)
val start : ?barrier:barrier -> string -> string -> unit

(** Registers a barrier callback. It receives the engine context and returns
    an [int] and a list of strings (presumably a step and the workers allowed
    to proceed — confirm). *)
val register_barrier : (param_context ref -> int * string list) -> unit

(** Registers the scheduling callback. Given a list of items (workers, in
    [Make]), it returns the [(key, value)] pairs to send to each one. *)
val register_schedule : ('a list -> ('a * ('b * 'c) list) list) -> unit

(** Registers the pull callback, a function from one [(key, value)] list to
    another. [Make]'s pull merges each received update into the stored
    model. *)
val register_pull : (('a * 'b) list -> ('a * 'c) list) -> unit

(** Registers the push callback. It takes an extra argument plus a list of
    [(key, value)] pairs and returns a list of the same type. [Make]'s push
    trains locally and returns parameter deltas. *)
val register_push : ('a -> ('b * 'c) list -> ('b * 'c) list) -> unit

(** Registers a stopping predicate over the engine context. [Make]'s stop
    callback always returns [false]. *)
val register_stop : (param_context ref -> bool) -> unit
end
(* Interface the neural-network model must provide to [Make]: parameter
   extraction and write-back, initialisation, copying, and a local training
   routine. *)
module type ModelSig = sig

(** The network type being trained. *)
type network

(** [mkpar nn] returns the parameters of [nn] as an array of arrays. [Make]
    combines two such values element-wise with [Owl_utils.aarr_map2]. *)
val mkpar : network -> t array array

(** [init nn] initialises [nn]. [Make] calls it when the engine has no model
    stored yet. *)
val init : network -> unit

(** [update nn ps] is used by [Make] to write combined parameter arrays (same
    layout as [mkpar]) back into [nn]. *)
val update : network -> t array array -> unit

(** [copy nn] returns a copy of [nn]. [Make] uses it to snapshot a model
    before local training and to return a model that is detached from [nn]. *)
val copy : network -> network

(** [train_generic ?state ?params ?init_model nn x y] runs local training of
    [nn] on [x]/[y] (presumably inputs and targets — confirm) and returns a
    checkpoint state. [Make] always passes [~init_model:false], feeds the
    previously returned [state] back in when one exists, and clears the
    returned state's [stop] flag. *)
val train_generic
: ?state:Checkpoint.state
-> ?params:Params.typ
-> ?init_model:bool
-> network
-> t
-> t
-> Checkpoint.state
end
(* Functor building a data-parallel trainer from a model [M] and an engine [E].
   It registers schedule/pull/push/stop callbacks with [E]. Workers train local
   copies of the model and return parameter deltas, which are added into the
   model stored in the engine. *)
module Make (M : ModelSig) (E : EngineSig) = struct
  (** State of one training process: the engine key [id], the optimiser
      [params], the local data [data_x]/[data_y], the most recent model, and
      the checkpoint [state] carried between successive pushes. *)
  type task =
    { mutable id : int
    ; mutable state : Checkpoint.state option
    ; mutable params : Params.typ
    ; mutable model : M.network
    ; mutable data_x : t
    ; mutable data_y : t
    }

  (** [make_task id params model data_x data_y] creates a task with no
      checkpoint state yet. *)
  let make_task id params model data_x data_y =
    { id; state = None; params; model; data_x; data_y }

  (* [combine_params f dst src] extracts the parameters of [dst] and [src]
     with [M.mkpar], combines them element-wise with [f], and writes the
     result back into [dst] with [M.update]. Shared by [delta_model] and
     [pull] so both apply parameters the same way. *)
  let combine_params f dst src =
    let par_dst = M.mkpar dst in
    let par_src = M.mkpar src in
    Owl_utils.aarr_map2 f par_dst par_src |> M.update dst

  (** [delta_model model0 model1] writes the element-wise parameter difference
      [model0 - model1] into [model0] (via [M.update]). *)
  let delta_model model0 model1 =
    combine_params (fun a0 a1 -> Maths.(a0 - a1)) model0 model1

  (** [local_model task] returns the model stored in the engine under
      [task.id]. When [E.get] raises [Not_found] (nothing stored yet), it
      initialises [task.model] with [M.init], stores it with [E.set], and
      returns what [E.get] then yields. *)
  let local_model task =
    match E.get task.id with
    | model, _ -> model
    | exception Not_found ->
      Owl_log.warn "set up first model";
      M.init task.model;
      E.set task.id task.model;
      E.get task.id |> fst

  (** [schedule task workers] pairs every worker with the single-element list
      [[(task.id, model)]], where [model] is the current local model. The model
      is fetched once and the same value is sent to every worker. *)
  let schedule task workers =
    let model = local_model task in
    List.map (fun worker -> worker, [ task.id, model ]) workers

  (** [pull task vars] adds each received [(key, delta)] into the local model.
      For each pair it fetches the stored model, adds [delta]'s parameters
      element-wise, saves the result as [task.model], and writes it back with
      [E.set]. Returns [(key, merged_model)] for every input pair. *)
  let pull task vars =
    (* Queried before the assertion so the engine is consulted even when
       assertions are compiled out. *)
    let n_workers = E.worker_num () in
    assert (n_workers >= 1);
    List.map
      (fun (k, delta) ->
        let model0 = local_model task in
        combine_params (fun a0 a1 -> Maths.(a0 + a1)) model0 delta;
        task.model <- model0;
        E.set task.id task.model;
        k, model0)
      vars

  (** [push task _id vars] trains each received model on the task's local
      data. Training resumes from [task.state] when one exists. Returns
      [(key, delta)] pairs, where [delta] is the trained model minus the
      snapshot taken before training. The second argument is ignored.
      NOTE(review): the delta is only meaningful if [M.train_generic] updates
      [model] in place — confirm. *)
  let push task _id vars =
    List.map
      (fun (k, model) ->
        (* Snapshot the incoming model so the result can be sent as a delta. *)
        task.model <- M.copy model;
        let { params; data_x = x; data_y = y; _ } = task in
        let state =
          match task.state with
          | Some state ->
            M.train_generic ~state ~params ~init_model:false model x y
          | None -> M.train_generic ~params ~init_model:false model x y
        in
        (* Clear the stop flag on the saved state, presumably so the next
           round's training is not halted immediately. *)
        Checkpoint.(state.stop <- false);
        task.state <- Some state;
        delta_model model task.model;
        k, M.copy model)
      vars

  (** [stop _task _context] never asks the engine to stop: it always returns
      [false]. *)
  let stop _task _context = false

  (** [train_generic ?params nn x y jid url] creates a task for [nn] on
      [x]/[y] with an identifier drawn from [Owl_stats.uniform_int_rvs]. It
      registers the schedule, pull, push and stop callbacks, then calls
      [E.start ~barrier:E.ASP jid url]. [params] defaults to
      [Params.default ()], which is only evaluated when [params] is absent. *)
  let train_generic ?params nn x y jid url =
    let params =
      match params with
      | Some p -> p
      | None -> Params.default ()
    in
    let id = Owl_stats.uniform_int_rvs ~a:0 ~b:max_int in
    let task = make_task id params nn x y in
    E.register_schedule (schedule task);
    E.register_pull (pull task);
    E.register_push (push task);
    E.register_stop (stop task);
    E.start ~barrier:E.ASP jid url

  (** [train ?params nn x y jid url] is [train_generic] with [x] and [y]
      wrapped as [Arr] values. *)
  let train ?params nn x y jid url =
    train_generic ?params nn (Arr x) (Arr y) jid url
end