(* Source file initializers.ml *)
(* An initializer. [f seed shape dtype] returns a float tensor of the given
   [shape] and [dtype]. The random initializers in this file forward [seed]
   to Rune's generators, while [constant] ignores it. The field is
   polymorphic in ['layout], so one initializer value serves every float
   dtype. NOTE(review): ['dev] is quantified here but does not occur in the
   field's type. It looks like a leftover; confirm before removing. *)
type t = {
f :
'layout 'dev.
int -> int array -> (float, 'layout) Rune.dtype -> (float, 'layout) Rune.t;
}
(* [compute_fans shape in_axis out_axis] returns [(fan_in, fan_out)] for a
   weight tensor of the given [shape].

   Negative axes count from the end. A scalar shape yields [(1, 1)]. For a
   rank-1 shape, or when either axis is out of range, both fans are the
   total number of elements. Otherwise each fan is the size of its own axis
   times the product of every dimension that is neither [in_axis] nor
   [out_axis] (the receptive-field size). *)
let compute_fans shape in_axis out_axis =
  match Array.length shape with
  | 0 -> (1, 1)
  | rank ->
      let resolve axis = if axis < 0 then axis + rank else axis in
      let in_range axis = axis >= 0 && axis < rank in
      let in_axis = resolve in_axis and out_axis = resolve out_axis in
      if rank > 1 && in_range in_axis && in_range out_axis then
        (* Product of all dimensions other than the two fan axes. *)
        let rec receptive i acc =
          if i = rank then acc
          else if i = in_axis || i = out_axis then receptive (i + 1) acc
          else receptive (i + 1) (acc * shape.(i))
        in
        let field = receptive 0 1 in
        (shape.(in_axis) * field, shape.(out_axis) * field)
      else
        let size = Array.fold_left ( * ) 1 shape in
        (size, size)
(* [truncated_normal_impl ~mean ~stddev ~lower ~upper seed shape dtype]
   samples [mean +. stddev *. z], with [z] drawn by [Rune.randn], restricted
   to the closed interval [[lower, upper]]. [lower] and [upper] bound the
   output values themselves; they are not in units of [stddev].

   Rejection is element-wise. Entries outside the interval are redrawn
   with a fresh seed ([seed + 1], [seed + 2], ...), while entries already
   inside are kept. This repeats until nothing is rejected or
   [max_attempts] redraws have been made. Any entries still out of range
   after that are clamped to the nearest bound as a last resort; this is
   only plausible for a very narrow or empty interval. *)
let truncated_normal_impl ~mean ~stddev ~lower ~upper seed shape dtype =
  let max_attempts = 100 in
  let lower_t = Rune.scalar dtype lower in
  let upper_t = Rune.scalar dtype upper in
  let stddev_t = Rune.scalar dtype stddev in
  let mean_t = Rune.scalar dtype mean in
  let ones = Rune.full dtype shape 1.0 in
  (* Each draw needs its own seed. Reusing [seed] reproduces exactly the
     same samples, so a retry could never change anything. *)
  let draw attempt =
    let z = Rune.randn dtype ~seed:(seed + attempt) shape in
    Rune.add (Rune.mul z stddev_t) mean_t
  in
  (* 1.0 where the element lies in [lower, upper], 0.0 elsewhere. *)
  let accept_mask x =
    Rune.cast dtype
      (Rune.logical_and
         (Rune.greater_equal x lower_t)
         (Rune.less_equal x upper_t))
  in
  let clamp x = Rune.minimum (Rune.maximum x lower_t) upper_t in
  let rec refine attempt current =
    let keep = accept_mask current in
    let reject = Rune.sub ones keep in
    (* Count rejections rather than acceptances: the sum is then exactly
       0.0 once every element is accepted, even for very large tensors. *)
    let num_rejected = (Rune.to_array (Rune.sum reject)).(0) in
    if num_rejected = 0.0 then current
    else if attempt > max_attempts then clamp current
    else
      let fresh = draw attempt in
      (* The masks are exact 0/1 values, so kept entries pass through
         unchanged and only rejected entries take the fresh draw. *)
      let merged = Rune.add (Rune.mul keep current) (Rune.mul reject fresh) in
      refine (attempt + 1) merged
  in
  refine 1 (draw 0)

(* [constant value] fills every element with [value]; the seed is unused. *)
let constant value : t =
  {
    f =
      (fun seed shape dtype ->
        ignore seed;
        Rune.full dtype shape value);
  }

(* Initializers filling every element with 0.0 and 1.0 respectively. *)
let zeros () = constant 0.0
let ones () = constant 1.0

(* Each element is [u *. scale], with [u] drawn by [Rune.rand]
   (presumably uniform on [0, 1); confirm against Rune's docs). *)
let uniform ?(scale = 0.01) () =
  {
    f =
      (fun seed shape dtype ->
        let u01 = Rune.rand dtype ~seed shape in
        Rune.mul u01 (Rune.scalar dtype scale));
  }

(* Each element is [z *. stddev], with [z] drawn by [Rune.randn]. *)
let normal ?(stddev = 0.01) () =
  {
    f =
      (fun seed shape dtype ->
        let z = Rune.randn dtype ~seed shape in
        Rune.mul z (Rune.scalar dtype stddev));
  }

(* Zero-mean normal with standard deviation [stddev], truncated to
   [[lower *. stddev, upper *. stddev]]. [lower] and [upper] are in units
   of [stddev] (default: +/-2 standard deviations), matching the
   [jax.nn.initializers.truncated_normal] signature this mirrors. Applying
   them to the scaled values instead would make the default +/-2.0 a no-op
   for the default [stddev = 0.01]. *)
let truncated_normal ?(stddev = 0.01) ?(lower = -2.0) ?(upper = 2.0) () =
  let a = lower *. stddev and b = upper *. stddev in
  (* A negative [stddev] flips the interval; keep [lo <= hi]. *)
  let lo, hi = if a <= b then (a, b) else (b, a) in
  {
    f =
      (fun seed shape dtype ->
        truncated_normal_impl ~mean:0.0 ~stddev ~lower:lo ~upper:hi seed shape
          dtype);
  }

(* [variance_scaling ~scale ~mode ~distribution ~in_axis ~out_axis ()]
   draws values with variance [scale /. n]. [n] is the fan-in, the fan-out,
   or their average (per [mode]), as computed by [compute_fans] from
   [in_axis] and [out_axis].
   - [`Normal]: [z *. stddev] with [stddev = sqrt (scale /. n)].
   - [`Truncated_normal]: a normal truncated to +/-2 of its own standard
     deviation. That deviation is inflated by 1 /. 0.8796... so the
     truncated samples still have variance [scale /. n].
   - [`Uniform]: [u *. 2 limit -. limit] with
     [limit = sqrt (3 *. scale /. n)]. Assuming [Rune.rand] is uniform on
     [0, 1), this is uniform on [-limit, limit). *)
let variance_scaling ~scale ~mode ~distribution ~in_axis ~out_axis () =
  {
    f =
      (fun seed shape dtype ->
        let fan_in, fan_out = compute_fans shape in_axis out_axis in
        let n =
          match mode with
          | `Fan_in -> float_of_int fan_in
          | `Fan_out -> float_of_int fan_out
          | `Fan_avg -> float_of_int (fan_in + fan_out) /. 2.0
        in
        let variance = scale /. n in
        let stddev = sqrt variance in
        match distribution with
        | `Normal ->
            let z = Rune.randn dtype ~seed shape in
            Rune.mul z (Rune.scalar dtype stddev)
        | `Truncated_normal ->
            (* Standard deviation of a standard normal truncated to
               [-2, 2]; dividing by it restores the target variance. *)
            let truncation_correction = 0.87962566103423978 in
            let stddev = stddev /. truncation_correction in
            let bound = 2.0 *. stddev in
            truncated_normal_impl ~mean:0.0 ~stddev ~lower:(-.bound)
              ~upper:bound seed shape dtype
        | `Uniform ->
            let limit = sqrt (3.0 *. variance) in
            let u01 = Rune.rand dtype ~seed shape in
            let scale_t = Rune.scalar dtype (2.0 *. limit) in
            let shift = Rune.scalar dtype limit in
            Rune.sub (Rune.mul u01 scale_t) shift);
  }
(* Glorot (Xavier) initializers: [variance_scaling] with [scale = 1.0]
   over the average of fan-in and fan-out. [in_axis] and [out_axis]
   (default -2 and -1) choose the axes used to compute the fans. The
   uniform variant samples uniformly; the normal variant uses
   [variance_scaling]'s truncated-normal branch. *)
let glorot_uniform ?(in_axis = -2) ?(out_axis = -1) () =
  variance_scaling ~in_axis ~out_axis ~mode:`Fan_avg ~distribution:`Uniform
    ~scale:1.0 ()

let glorot_normal ?(in_axis = -2) ?(out_axis = -1) () =
  variance_scaling ~in_axis ~out_axis ~mode:`Fan_avg
    ~distribution:`Truncated_normal ~scale:1.0 ()

(* Same initializers under the Xavier name. *)
let xavier_uniform = glorot_uniform
let xavier_normal = glorot_normal
(* LeCun initializers: [variance_scaling] with [scale = 1.0] in [`Fan_in]
   mode, i.e. target variance [1 /. fan_in]. [in_axis] and [out_axis]
   (default -2 and -1) choose the axes used to compute the fans. *)
let lecun_uniform ?(in_axis = -2) ?(out_axis = -1) () =
  variance_scaling ~in_axis ~out_axis ~mode:`Fan_in ~distribution:`Uniform
    ~scale:1.0 ()

let lecun_normal ?(in_axis = -2) ?(out_axis = -1) () =
  variance_scaling ~in_axis ~out_axis ~mode:`Fan_in
    ~distribution:`Truncated_normal ~scale:1.0 ()
(* He (Kaiming) initializers: [variance_scaling] with [scale = 2.0] in
   [`Fan_in] mode, i.e. target variance [2 /. fan_in]. [in_axis] and
   [out_axis] (default -2 and -1) choose the axes used to compute the
   fans. *)
let he_uniform ?(in_axis = -2) ?(out_axis = -1) () =
  variance_scaling ~in_axis ~out_axis ~mode:`Fan_in ~distribution:`Uniform
    ~scale:2.0 ()

let he_normal ?(in_axis = -2) ?(out_axis = -1) () =
  variance_scaling ~in_axis ~out_axis ~mode:`Fan_in
    ~distribution:`Truncated_normal ~scale:2.0 ()

(* Same initializers under the Kaiming name. *)
let kaiming_uniform = he_uniform
let kaiming_normal = he_normal
(* [orthogonal ?scale ?column_axis ()] initializes a tensor as a scaled
   (semi-)orthogonal matrix.

   The tensor is viewed as a [rows x cols] matrix. [cols] is
   [shape.(column_axis)], and [rows] is the product of the remaining
   dimensions, kept in their original order; a scalar is treated as a 1x1
   matrix. A Gaussian matrix from [Rune.randn] is orthonormalized with
   modified Gram-Schmidt (equivalent to taking Q from a QR factorization
   with a positive-diagonal R). This gives orthonormal columns when
   [rows >= cols] and orthonormal rows otherwise. Every entry is then
   multiplied by [scale].

   The column index is laid out along [column_axis] (negative values count
   from the end), so the default [-1] is a plain row-major reshape.

   The orthonormalization runs in OCaml on the host and costs
   O(min(rows, cols)^2 * max(rows, cols)) floating-point operations.

   @raise Invalid_argument if [column_axis] is out of range for a
   non-scalar [shape]. *)
let orthogonal ?(scale = 1.0) ?(column_axis = -1) () =
  {
    f =
      (fun seed shape dtype ->
        let rank = Array.length shape in
        let axis =
          if column_axis < 0 then rank + column_axis else column_axis
        in
        if rank > 0 && (axis < 0 || axis >= rank) then
          invalid_arg "orthogonal: column_axis out of range";
        let cols = if rank = 0 then 1 else shape.(axis) in
        let rows = ref 1 in
        (* [inner] is the product of the dims after [axis], i.e. the
           row-major stride of [axis]. *)
        let inner = ref 1 in
        Array.iteri
          (fun i d ->
            if i <> axis then rows := !rows * d;
            if i > axis then inner := !inner * d)
          shape;
        let rows = !rows and inner = !inner in
        if rows = 0 || cols = 0 then Rune.zeros dtype shape
        else
          (* Orthonormalize [n_vec] vectors of length [len], stored as the
             rows of [v]. These are the matrix columns when [rows >= cols]
             and the matrix rows otherwise. *)
          let n_vec = min rows cols and len = max rows cols in
          let v =
            Array.copy (Rune.to_array (Rune.randn dtype ~seed [| n_vec; len |]))
          in
          let dot i j =
            let acc = ref 0.0 in
            for k = 0 to len - 1 do
              acc := !acc +. (v.((i * len) + k) *. v.((j * len) + k))
            done;
            !acc
          in
          for i = 0 to n_vec - 1 do
            let base_i = i * len in
            (* Remove the components along the already-orthonormal
               vectors 0 .. i-1, then normalize. *)
            for j = 0 to i - 1 do
              let base_j = j * len in
              let d = dot i j in
              for k = 0 to len - 1 do
                v.(base_i + k) <- v.(base_i + k) -. (d *. v.(base_j + k))
              done
            done;
            let norm = sqrt (dot i i) in
            if norm > 0.0 then
              for k = 0 to len - 1 do
                v.(base_i + k) <- v.(base_i + k) /. norm
              done
          done;
          (* Entry (r, c) of the [rows x cols] orthogonal matrix. *)
          let entry r c =
            if rows >= cols then v.((c * len) + r) else v.((r * len) + c)
          in
          let block = cols * inner in
          (* Decode each output position: [c] is its index along
             [column_axis], and [r] is the linear index over the other axes
             in order. This is a move of the matrix's last axis to
             [column_axis]. *)
          let data =
            Array.init (rows * cols) (fun idx ->
                let c = (idx / inner) mod cols in
                let r = ((idx / block) * inner) + (idx mod inner) in
                scale *. entry r c)
          in
          Rune.create dtype shape data);
  }
(* [delta_orthogonal ?scale ?column_axis ()] initializes a 3-D, 4-D or 5-D
   convolution kernel as a "delta". The kernel is all zeros except at the
   centre spatial position, which holds an [in_channels x out_channels]
   matrix produced by [orthogonal ~scale].

   The kernel layout assumed here is
   [[| in_channels; spatial...; out_channels |]]: the first and last axes
   are channels, and the middle axes are spatial and must all have the same
   size [k]. The centre index along each spatial axis is [k / 2].
   NOTE(review): confirm this layout matches the convolution layers that
   consume it.

   [column_axis] is interpreted against the full kernel shape (negative
   values count from the end). It must select a channel axis: the last axis
   (the default) or axis 0.

   @raise Failure if the rank is not 3, 4 or 5, or the spatial dims differ.
   @raise Invalid_argument if [column_axis] does not select a channel
   axis. *)
let delta_orthogonal ?(scale = 1.0) ?(column_axis = -1) () =
  {
    f =
      (fun seed shape dtype ->
        let rank = Array.length shape in
        if rank < 3 || rank > 5 then
          failwith "delta_orthogonal requires 3D, 4D, or 5D shape";
        let column_axis =
          if column_axis < 0 then rank + column_axis else column_axis
        in
        let spatial_dims = Array.sub shape 1 (rank - 2) in
        let is_square =
          Array.for_all (fun d -> d = spatial_dims.(0)) spatial_dims
        in
        if not is_square then
          failwith "delta_orthogonal requires square spatial dimensions";
        let in_channels = shape.(0) in
        let out_channels = shape.(rank - 1) in
        (* Translate the full-kernel axis to the 2-D [in; out] matrix.
           Passing the full-rank index straight through would be out of
           range for a 2-D shape. *)
        let matrix_column_axis =
          if column_axis = rank - 1 then 1
          else if column_axis = 0 then 0
          else
            invalid_arg
              "delta_orthogonal: column_axis must select a channel axis"
        in
        let orth =
          (orthogonal ~scale ~column_axis:matrix_column_axis ()).f seed
            [| in_channels; out_channels |]
            dtype
        in
        let orth = Rune.to_array orth in
        let spatial_size = Array.fold_left ( * ) 1 spatial_dims in
        (* Row-major linear index of (k/2, ..., k/2) within the spatial
           block. *)
        let center =
          Array.fold_left (fun acc d -> (acc * d) + (d / 2)) 0 spatial_dims
        in
        let data =
          Array.make (in_channels * spatial_size * out_channels) 0.0
        in
        if spatial_size > 0 then
          for i = 0 to in_channels - 1 do
            for o = 0 to out_channels - 1 do
              data.((((i * spatial_size) + center) * out_channels) + o) <-
                orth.((i * out_channels) + o)
            done
          done;
        Rune.create dtype shape data);
  }
(* [uniform_range ~low ~high ()] maps each [u] drawn by [Rune.rand] to
   [u *. (high -. low) +. low]. Assuming [Rune.rand] is uniform on [0, 1),
   the result is uniform on [low, high). *)
let uniform_range ~low ~high () =
  let sample seed shape dtype =
    let width = Rune.scalar dtype (high -. low) in
    let offset = Rune.scalar dtype low in
    let u = Rune.rand dtype ~seed shape in
    Rune.add (Rune.mul u width) offset
  in
  { f = sample }
(* [normal_range ~mean ~stddev ()] maps each [z] drawn by [Rune.randn] to
   [z *. stddev +. mean], element-wise and without truncation. *)
let normal_range ~mean ~stddev () =
  let sample seed shape dtype =
    let spread = Rune.scalar dtype stddev in
    let centre = Rune.scalar dtype mean in
    let z = Rune.randn dtype ~seed shape in
    Rune.add (Rune.mul z spread) centre
  in
  { f = sample }