(* See http://colah.github.io/posts/2015-08-Understanding-LSTMs
   for a simple description of LSTM networks.
*)
open Base
open Tensorflow

(* [lstm_ ~type_ ~size_c ~size_x] allocates the variables of one LSTM cell and
   returns a staged step function [fun ~h ~x ~c -> `h new_h, `c new_c].

   - [type_]  : element type handed to the variable constructors.
   - [size_c] : second dimension of every weight matrix and length of every
                bias, i.e. the size of the cell state.
   - [size_x] : each weight matrix has [size_c + size_x] rows, so [h] and [x]
                concatenated must have that many columns
                (assumes [h] has [size_c] columns and [x] has [size_x] -
                TODO confirm against callers).

   The variables are created once, when [lstm_] is applied; the staged
   function only builds the ops of a single time step and can therefore be
   applied repeatedly while sharing the same weights. *)
let lstm_ ~type_ ~size_c ~size_x =
  (* One (weight, bias) pair: the weight is created by [Var.normal] with
     stddev 0.1, the bias by [Var.f_or_d] with value 0. *)
  let create_vars () =
    ( Var.normal ~type_ [ size_c + size_x; size_c ] ~stddev:0.1
    , Var.f_or_d [ size_c ] 0. ~type_ )
  in
  (* Forget gate. *)
  let wf, bf = create_vars () in
  (* Input gate. *)
  let wi, bi = create_vars () in
  (* Candidate cell state. *)
  let wC, bC = create_vars () in
  (* Output gate. *)
  let wo, bo = create_vars () in
  Staged.stage (fun ~h ~x ~c ->
      let open Ops in
      let h_and_x = concat one32 [ h; x ] in
      let forget_gate = sigmoid ((h_and_x *^ wf) + bf) in
      let input_gate = sigmoid ((h_and_x *^ wi) + bi) in
      (* The candidate is tanh of the affine map, with no extra sigmoid:
         squashing through a sigmoid first would keep the candidate strictly
         positive, so the cell state could never be decreased by an update. *)
      let candidate = tanh ((h_and_x *^ wC) + bC) in
      let c = (forget_gate * c) + (input_gate * candidate) in
      let h = sigmoid ((h_and_x *^ wo) + bo) * tanh c in
      `h h, `c c)

(* Single-precision LSTM cell, see [lstm_]. *)
let lstm ~size_c ~size_x = lstm_ ~type_:Float ~size_c ~size_x

(* Double-precision LSTM cell, see [lstm_]. *)
let lstm_d ~size_c ~size_x = lstm_ ~type_:Double ~size_c ~size_x

(* [gru_ ~type_ ~size_h ~size_x] allocates the variables of one GRU cell and
   returns a staged step function [fun ~h ~x -> new_h].

   Every weight matrix has shape [size_h + size_x; size_h] and every bias has
   length [size_h]. As for [lstm_], the variables are created once and shared
   by every application of the staged function. *)
let gru_ ~type_ ~size_h ~size_x =
  let create_vars () =
    ( Var.normal ~type_ [ size_h + size_x; size_h ] ~stddev:0.1
    , Var.f_or_d ~type_ [ size_h ] 0. )
  in
  (* The reset parameters. *)
  let wr, br = create_vars () in
  (* The mixing parameters. *)
  let wz, bz = create_vars () in
  (* The contribution of x and of the reset old state. *)
  let wH, bH = create_vars () in
  Staged.stage (fun ~h ~x ->
      let open Ops in
      let h_and_x = concat one32 [ h; x ] in
      (* The old state h, partly reset by the reset gate. *)
      let rh = sigmoid ((h_and_x *^ wr) + br) * h in
      let rh_and_x = concat one32 [ rh; x ] in
      (* The new candidate value of h. *)
      let nh = tanh ((rh_and_x *^ wH) + bH) in
      (* How much of the new h is mixed with the old h. *)
      let z = sigmoid ((h_and_x *^ wz) + bz) in
      (* Mix the new h and the old h. *)
      (z * nh) + ((f_or_d ~type_ 1.0 - z) * h))

(* Single-precision GRU cell, see [gru_]. *)
let gru ~size_h ~size_x = gru_ ~type_:Float ~size_h ~size_x

(* Double-precision GRU cell, see [gru_]. *)
let gru_d ~size_h ~size_x = gru_ ~type_:Double ~size_h ~size_x

module Unfold = struct
  (* [unfold_gen ~xs ~seq_len ~input_dim ~output_shape ~init ~f] unrolls the
     step function [f] over the time steps of [xs].

     xs should be a tensor of dimension:
       (batch_size, seq_len, input_dim)
     It is split along the seq_len dimension to unroll the rnn; each slice is
     reshaped to [-1; input_dim] before being given to [f].

     [f ~x ~mem] must return [(y, `mem new_mem)]; [init] is the memory used
     for the first step. Each [y] is reshaped to [output_shape].

     The outputs are returned in REVERSE time order (last step first),
     because they are accumulated by consing during the fold. *)
  let unfold_gen ~xs ~seq_len ~input_dim ~output_shape ~init ~f =
    let xs =
      let shape = Ops.const_int ~type_:Int32 [ -1; input_dim ] in
      Ops.split Ops.one32 xs ~num_split:seq_len
      |> List.map ~f:(fun n -> Ops.reshape n shape)
    in
    let y_bars, _output_mem =
      let shape = Ops.const_int ~type_:Int32 output_shape in
      List.fold xs ~init:([], init) ~f:(fun (y_bars, mem) x ->
          let y_bar, `mem mem = f ~x ~mem in
          (Ops.reshape y_bar shape :: y_bars), mem)
    in
    y_bars

  (* [unfold] runs [f] over every time step and returns all the outputs
     concatenated, in time order, along the [Ops.one32] axis. Each step
     output is reshaped to [-1; 1; dim], so input and output share the same
     last dimension [dim]. *)
  let unfold ~xs ~seq_len ~dim ~init ~f =
    let y_bars =
      unfold_gen ~xs ~seq_len ~input_dim:dim ~output_shape:[ -1; 1; dim ] ~init ~f
    in
    (* [unfold_gen] returns the latest step first: restore time order. *)
    Ops.concat Ops.one32 (List.rev y_bars)

  (* [unfold_last] runs [f] over every time step and keeps only the output of
     the last one, reshaped to [-1; output_dim]. It relies on [unfold_gen]
     returning the latest step at the head of the list.

     NOTE(review): with [Base] opened, [List.hd] returns an option, so this
     evaluates to [None] when no step output was produced. Confirm whether
     the interface expects a bare node ([List.hd_exn]) instead. *)
  let unfold_last ~xs ~seq_len ~input_dim ~output_dim ~init ~f =
    unfold_gen ~xs ~seq_len ~input_dim ~output_shape:[ -1; output_dim ] ~init ~f
    |> List.hd
end