open Base
open Tensorflow_core
open! Tensorflow

(* Phantom types used to track the rank of a network value at the type level. *)
type _1d
type _2d
type _3d

(** Static shapes, indexed by rank. *)
module Shape = struct
  type 'a t =
    | D1 : int -> _1d t
    | D2 : int * int -> _2d t
    | D3 : int * int * int -> _3d t

  (** The dimensions of a shape as a plain list, outermost first. *)
  let dim_list (type a) (t : a t) =
    match t with
    | D1 d -> [ d ]
    | D2 (d, d') -> [ d; d' ]
    | D3 (d, d', d'') -> [ d; d'; d'' ]

  (** The product of all the dimensions of a shape. *)
  let total_dim (type a) (t : a t) =
    match t with
    | D1 d -> d
    | D2 (d, d') -> d * d'
    | D3 (d, d', d'') -> d * d' * d''
end

(** Raised when two shapes that have to agree do not. Carries both dimension
    lists and the name of the operation that detected the mismatch. *)
exception Shape_mismatch of int list * int list * string

(* Register a human readable printer for [Shape_mismatch]. *)
let () =
  Caml.Printexc.register_printer (function
    | Shape_mismatch (dims, dims', str) ->
      let dims = List.map dims ~f:Int.to_string |> String.concat ~sep:", " in
      let dims' = List.map dims' ~f:Int.to_string |> String.concat ~sep:", " in
      Some (Printf.sprintf "Shape mismatch %s: %s <> %s" str dims dims')
    | _ -> None)

(** Raise [Shape_mismatch] for the two given shapes (which may have different
    ranks) on behalf of operation [op_name]. *)
let shape_mismatch shape1 shape2 ~op_name =
  let shape1 = Shape.dim_list shape1 in
  let shape2 = Shape.dim_list shape2 in
  raise (Shape_mismatch (shape1, shape2, op_name))

(** Integer identifiers for network values. *)
module Id = struct
  include Int

  (** Return a fresh identifier. The counter is pre-incremented so the first
      identifier handed out is 1. *)
  let create =
    let cnt = ref 0 in
    fun () ->
      incr cnt;
      !cnt
end

module Input_id = struct
  type t = Id.t
end

(** Element-wise / reducing operations taking a single argument. *)
module Unary = struct
  type t =
    | Sigmoid
    | Tanh
    | Relu
    | Softmax
    | Reduce_sum
    | Square
    | Neg

  (** Apply the matching [Ops] function to [node]. *)
  let apply t node =
    match t with
    | Sigmoid -> Ops.sigmoid node
    | Tanh -> Ops.tanh node
    | Relu -> Ops.relu node
    | Softmax -> Ops.softmax node
    | Reduce_sum -> Ops.reduce_sum node
    | Square -> Ops.square node
    | Neg -> Ops.neg node
end

(** Operations taking two arguments of identical shape. *)
module Binary = struct
  type t =
    | Plus
    | Minus
    | Times

  (** Name used in [Shape_mismatch] error messages. *)
  let op_name = function
    | Plus -> "plus"
    | Minus -> "minus"
    | Times -> "times"

  (** Apply the matching [Ops] operator to the two nodes. *)
  let apply t node1 node2 =
    match t with
    | Plus -> Ops.(node1 + node2)
    | Minus -> Ops.(node1 - node2)
    | Times -> Ops.(node1 * node2)
end

(** How a variable gets its initial value; see [create_var]. *)
type init =
  [ `const of float
  | `normal of float
  | `truncated_normal of float
  ]

type pool =
  { filter : int * int
  ; strides : int * int
  ; padding : [ `same | `valid ]
  ; avg_or_max : [ `avg | `max ]
  }

type conv2d =
  { filter : int * int
  ; strides : int * int
  ; padding : [ `same | `valid ]
  ; in_channels : int
  ; out_channels : int
  ; w_init : init
  ; b_init : init
  ; name : string option
  }

(** The operation that produced a network value, indexed by the rank of its
    result. *)
type 'a op =
  | Input : 'a op
  | Const : float -> 'a op
  | Unary : Unary.t * 'a t -> 'a op
  | Binary : Binary.t * 'a t * 'a t -> 'a op
  | Dense : init * init * _1d t * string option -> _1d op
  | Pool : pool * _3d t -> _3d op
  | Conv2d : conv2d * _3d t -> _3d op
  | Reshape : 'a Shape.t * 'b t -> 'a op
  | Concat : _1d t list -> _2d op
  | Split : _2d t * int * int -> _1d op
  | Var : 'a t -> 'a op

(** A network value: its static shape, the operation producing it, and an
    identifier used as a cache key by [build_node]. *)
and 'a t =
  { shape : 'a Shape.t
  ; op : 'a op
  ; id : Id.t
  }

(** A network value with its rank hidden. *)
type p = P : _ t -> p

let shape t = t.shape
let id t = t.id

(** Create an input of the given shape. Returns the value together with the
    identifier to use when feeding it (see [Model.predict], [Model.fit]). *)
let input ~shape =
  let id = Id.create () in
  { shape; op = Input; id }, id

(** A value of the given shape built from the float [f]. *)
let const f ~shape = { shape; op = Const f; id = Id.create () }

(** Apply a unary operation; the result has the same static shape as [t]. *)
let unary unary t = { shape = shape t; op = Unary (unary, t); id = Id.create () }

let sigmoid t = unary Sigmoid t
let tanh t = unary Tanh t
let relu t = unary Relu t
let softmax t = unary Softmax t
let reduce_sum t = unary Reduce_sum t
let square t = unary Square t
let neg t = unary Neg t

(** Apply a binary operation.
    @raise Shape_mismatch if the two operands do not have equal shapes. *)
let binary binary t1 t2 =
  if Caml.( <> ) t1.shape t2.shape
  then shape_mismatch t1.shape t2.shape ~op_name:(Binary.op_name binary);
  { shape = shape t1; op = Binary (binary, t1, t2); id = Id.create () }

(** Reshape [t] to [shape]. The element counts are only compared when the
    graph is built (see [build_node]). *)
let reshape t ~shape = { shape; op = Reshape (shape, t); id = Id.create () }

(** Reshape [t] to a single dimension holding all of its elements. *)
let flatten t = reshape t ~shape:(D1 (Shape.total_dim t.shape))

(** Split a [D2 (n, d)] value into [n] values of shape [D1 d].
    All the pieces deliberately share one identifier: it is the key under
    which [build_node] caches the underlying split operation, so the split is
    only built once. *)
let split t =
  let id = Id.create () in
  let (Shape.D2 (num_split, d)) = t.shape in
  List.init num_split ~f:(fun idx -> { shape = D1 d; op = Split (t, idx, num_split); id })

(** Concatenate a non-empty list of [D1 d] values into a [D2 (length, d)]
    value.
    @raise Failure on an empty list.
    @raise Shape_mismatch if the elements do not all have the same size. *)
let concat = function
  | [] -> failwith "concat called on an empty list"
  | hd :: _ as l ->
    let shape { shape = Shape.D1 shape; _ } = shape in
    let hd_shape = shape hd in
    List.iter l ~f:(fun t ->
        if hd_shape <> shape t
        then raise (Shape_mismatch ([ hd_shape ], [ shape t ], "concat")));
    { shape = D2 (List.length l, hd_shape); op = Concat l; id = Id.create () }

(** Wrap [t] as an explicit variable; [build_node] creates the variable with
    [t]'s node as its initial value. *)
let var t = { shape = t.shape; op = Var t; id = Id.create () }

(** Staged dense layer with [dim] outputs. The identifier is created once,
    before staging, and [build_node] caches the weights and biases under that
    identifier, so every application of the staged function reuses the same
    variables. *)
let dense' ?(w_init = `const 0.) ?(b_init = `const 0.) ?name dim =
  let id = Id.create () in
  Staged.stage (fun t -> { shape = D1 dim; op = Dense (w_init, b_init, t, name); id })

(** Non-staged version of [dense']: each call creates a fresh layer. *)
let dense ?w_init ?b_init ?name dim = Staged.unstage (dense' ?w_init ?b_init ?name dim)

(** Output height and width of a convolution or pooling window.
    With [`valid] padding the filter extent is first removed from the input
    size; the result is then [(size - 1) / stride + 1] in each direction. *)
let conv_sizes
    ~input_height
    ~input_width
    ~filter_height
    ~filter_width
    ~stride_height
    ~stride_width
    ~padding
  =
  let input_height, input_width =
    match padding with
    | `same -> input_height, input_width
    | `valid -> input_height - filter_height + 1, input_width - filter_width + 1
  in
  ((input_height - 1) / stride_height) + 1, ((input_width - 1) / stride_width) + 1

(** String form of the padding mode handed to the [Ops] functions. *)
let padding_str = function
  | `same -> "SAME"
  | `valid -> "VALID"

(** Pooling layer over a [D3 (height, width, channels)] value; the channel
    count is preserved. *)
let pool ~avg_or_max t ~filter ~strides ~padding =
  let input_height, input_width, input_channels =
    match t.shape with
    | Shape.D3 (d, d', d'') -> d, d', d''
  in
  let filter_height, filter_width = filter in
  let stride_height, stride_width = strides in
  let output_height, output_width =
    conv_sizes
      ~input_height
      ~input_width
      ~filter_height
      ~filter_width
      ~stride_height
      ~stride_width
      ~padding
  in
  let pool = { filter; strides; padding; avg_or_max } in
  { shape = D3 (output_height, output_width, input_channels)
  ; op = Pool (pool, t)
  ; id = Id.create ()
  }

let max_pool = pool ~avg_or_max:`max
let avg_pool = pool ~avg_or_max:`avg

(** Staged 2d convolution layer. As for [dense'], the identifier is created
    once so that all applications of the staged function share the same
    filter and bias variables. The input channel count is read from the
    shape of the value the layer is applied to. *)
let conv2d'
    ?(w_init = `const 0.)
    ?(b_init = `const 0.)
    ?name
    ~filter
    ~out_channels
    ~strides
    ~padding
    ()
  =
  let id = Id.create () in
  Staged.stage (fun t ->
      let input_height, input_width, input_channels =
        match t.shape with
        | Shape.D3 (d, d', d'') -> d, d', d''
      in
      let conv2d =
        { filter
        ; strides
        ; padding
        ; in_channels = input_channels
        ; out_channels
        ; w_init
        ; b_init
        ; name
        }
      in
      let filter_height, filter_width = filter in
      let stride_height, stride_width = strides in
      let output_height, output_width =
        conv_sizes
          ~input_height
          ~input_width
          ~filter_height
          ~filter_width
          ~stride_height
          ~stride_width
          ~padding
      in
      { shape = D3 (output_height, output_width, out_channels); op = Conv2d (conv2d, t); id })

(** Non-staged version of [conv2d']: each call creates a fresh layer. *)
let conv2d ?w_init ?b_init ?name ~filter ~out_channels ~strides ~padding () =
  Staged.unstage (conv2d' ?w_init ?b_init ?name ~filter ~out_channels ~strides ~padding ())

(** Create a variable with dimensions [dims] initialised according to
    [init]. *)
let create_var dims ~init ~type_ =
  match init with
  | `const f -> Var.f_or_d dims f ~type_
  | `normal stddev -> Var.normal dims ~stddev ~type_
  | `truncated_normal stddev -> Var.truncated_normal dims ~stddev ~type_

(** Translate a network description into graph nodes of element type [type_].

    Returns [(node, inputs, var_names, explicit_vars, all_nodes)] where
    - [node] is the node for [t] itself,
    - [inputs] maps input identifiers to their placeholders,
    - [var_names] maps the node ids of named layer variables to their names,
    - [explicit_vars] maps the identifiers of [Var] values to their nodes,
    - [all_nodes] maps every visited identifier to its node.

    Sub-values are re-walked each time they are referenced; only inputs,
    variables and split operations are cached by identifier. *)
let build_node t ~type_ =
  let inputs = Hashtbl.create (module Id) in
  let explicit_vars = Hashtbl.create (module Id) in
  let dense_vars = Hashtbl.create (module Id) in
  let conv_vars = Hashtbl.create (module Id) in
  let splits = Hashtbl.create (module Id) in
  let var_names = Hashtbl.create (module Node.Id) in
  let all_nodes = Hashtbl.create (module Id) in
  let rec walk (P t) =
    let node =
      match t.op with
      | Unary (unary, t) -> Unary.apply unary (walk (P t))
      | Binary (binary, t1, t2) -> Binary.apply binary (walk (P t1)) (walk (P t2))
      | Const f -> Ops.f_or_d ~shape:(Shape.dim_list t.shape) ~type_ f
      | Dense (w_init, b_init, input, name_opt) ->
        let (Shape.D1 input_shape) = input.shape in
        let (Shape.D1 shape) = t.shape in
        let w, b =
          Hashtbl.find_or_add dense_vars t.id ~default:(fun () ->
              let w = create_var ~type_ ~init:w_init [ input_shape; shape ] in
              let b = create_var ~type_ ~init:b_init [ shape ] in
              (* Only named layers get their variables registered for save/load. *)
              Option.iter name_opt ~f:(fun name ->
                  Hashtbl.set var_names ~key:(Node.id w) ~data:(name ^ "/" ^ name ^ "_weights");
                  Hashtbl.set var_names ~key:(Node.id b) ~data:(name ^ "/" ^ name ^ "_biases"));
              w, b)
        in
        Ops.((walk (P input) *^ w) + b)
      | Input ->
        (* The placeholder gets a leading -1 dimension in front of the static
           shape (presumably an unconstrained batch dimension). *)
        Hashtbl.find_or_add inputs t.id ~default:(fun () ->
            Ops.placeholder ~type_ (-1 :: Shape.dim_list t.shape))
        |> Ops.Placeholder.to_node
      | Pool (pool, t) ->
        let filter_height, filter_width = pool.filter in
        let stride_height, stride_width = pool.strides in
        let pool_ops =
          match pool.avg_or_max with
          | `max -> Ops.maxPool
          | `avg -> Ops.avgPool
        in
        (* [...Pool] only exists for float and not for double so cast to float. *)
        pool_ops
          (walk (P t) |> Ops.cast ~type_:Float)
          ~ksize:[ 1; filter_height; filter_width; 1 ]
          ~strides:[ 1; stride_height; stride_width; 1 ]
          ~padding:(padding_str pool.padding)
        |> Ops.cast ~type_
      | Conv2d (conv2d, u) ->
        let filter_height, filter_width = conv2d.filter in
        let out_channels = conv2d.out_channels in
        let w, b, in_channels =
          Hashtbl.find_or_add conv_vars t.id ~default:(fun () ->
              let in_channels = conv2d.in_channels in
              let w =
                create_var
                  ~type_
                  ~init:conv2d.w_init
                  [ filter_height; filter_width; in_channels; out_channels ]
              in
              let b = create_var ~type_ ~init:conv2d.b_init [ out_channels ] in
              Option.iter conv2d.name ~f:(fun name ->
                  Hashtbl.set var_names ~key:(Node.id w) ~data:(name ^ "/" ^ name ^ "_filters");
                  Hashtbl.set var_names ~key:(Node.id b) ~data:(name ^ "/" ^ name ^ "_biases"));
              w, b, in_channels)
        in
        (* A staged layer can be applied to several inputs; the cached filter
           was sized for the first one, so later ones must agree. *)
        if in_channels <> conv2d.in_channels
        then
          shape_mismatch
            (D1 in_channels)
            (D1 conv2d.in_channels)
            ~op_name:"conv2d in-channels";
        let stride_height, stride_width = conv2d.strides in
        let strides = [ 1; stride_height; stride_width; 1 ] in
        Ops.(conv2D ~strides ~padding:(padding_str conv2d.padding) (walk (P u)) w + b)
      | Reshape (shape, u) ->
        let dim_list = Shape.dim_list shape in
        let total_dim_output = Shape.total_dim shape in
        let total_dim_input = Shape.total_dim u.shape in
        if total_dim_output <> total_dim_input
        then shape_mismatch shape u.shape ~op_name:"reshape";
        Ops.reshape (walk (P u)) (Ops.const_int ~type_:Int32 (-1 :: dim_list))
      | Concat list -> List.map list ~f:(fun t -> walk (P t)) |> Ops.(concat one32)
      | Split (u, idx, num_split) ->
        (* Every piece of a split carries the same id, so the split op is
           built once and each piece just picks its own output. *)
        let list =
          Hashtbl.find_or_add splits t.id ~default:(fun () ->
              Ops.(split ~num_split one32 (walk (P u))))
        in
        List.nth_exn list idx
      | Var u ->
        Hashtbl.find_or_add explicit_vars t.id ~default:(fun () ->
            let dim_list = Shape.dim_list t.shape in
            Var.create dim_list ~type_ ~init:(walk (P u)))
    in
    Hashtbl.set all_nodes ~key:t.id ~data:node;
    node
  in
  walk t, inputs, var_names, explicit_vars, all_nodes

(** Descriptions of the supported optimizers. *)
module Optimizer = struct
  (* Inline records would make these constructors more readable. *)
  type t =
    | Gradient_descent of float
    | Adam of float * float option * float option * float option
    | Momentum of float * float

  let gradient_descent ~learning_rate = Gradient_descent learning_rate

  let adam ~learning_rate ?beta1 ?beta2 ?epsilon () =
    Adam (learning_rate, beta1, beta2, epsilon)

  let momentum ~learning_rate ~momentum = Momentum (learning_rate, momentum)

  (** Build the minimizer targets for [loss]. [varsf] / [varsd], when given,
      are forwarded to the underlying [Optimizers] function. *)
  let get ?varsf ?varsd t ~loss =
    match t with
    | Gradient_descent learning_rate ->
      Optimizers.gradient_descent_minimizer
        ?varsf
        ?varsd
        ~learning_rate:(Ops.f learning_rate)
        loss
    | Adam (learning_rate, beta1, beta2, epsilon) ->
      Optimizers.adam_minimizer
        loss
        ?varsf
        ?varsd
        ~learning_rate:(Ops.f learning_rate)
        ?beta1:(Option.map beta1 ~f:Ops.f)
        ?beta2:(Option.map beta2 ~f:Ops.f)
        ?epsilon:(Option.map epsilon ~f:Ops.f)
    | Momentum (learning_rate, momentum) ->
      Optimizers.momentum_minimizer
        loss
        ?varsf
        ?varsd
        ~learning_rate:(Ops.f learning_rate)
        ~momentum:(Ops.f momentum)
end

(** Descriptions of the supported losses. *)
module Loss = struct
  type t =
    | Cross_entropy of [ `sum | `mean ]
    | L2 of [ `sum | `mean ]

  let cross_entropy sum_mean = Cross_entropy sum_mean
  let l2 sum_mean = L2 sum_mean

  (** Build the loss node comparing the expected [sample_ys] with the
      [model_ys] produced by the network, reduced by sum or mean. *)
  let get t ~sample_ys ~model_ys =
    let reduce = function
      | `sum -> Ops.reduce_sum
      | `mean -> Ops.reduce_mean
    in
    match t with
    | Cross_entropy sum_mean -> Ops.(neg (reduce sum_mean (sample_ys * log model_ys)))
    | L2 sum_mean -> Ops.(reduce sum_mean (square (sample_ys - model_ys)))
end

(** A network compiled to graph nodes, ready to be evaluated, trained, saved
    and loaded. *)
module Model = struct
  type 'a fnn = 'a t

  type ('a, 'b, 'c) t =
    { node : 'b Node.t
    ; placeholder : 'b Ops.Placeholder.t
    ; inputs : (Id.t, 'b Ops.Placeholder.t) Hashtbl.t
    ; save_nodes : (string, [ `unit ] Node.t) Hashtbl.t
    ; load_and_assign_nodes : (string, Node.p list) Hashtbl.t
    ; var_names : (Node.Id.t, string) Hashtbl.t
    ; explicit_vars : (Id.t, 'b Node.t) Hashtbl.t
    ; all_nodes : (Id.t, 'b Node.t) Hashtbl.t
    ; eq : ('c * 'b) Tensor.eq
    }

  (** Compile [fnn]. The [placeholder] field is a fresh placeholder with the
      shape of the network output; [fit] feeds the expected outputs to it.
      @raise Failure when [eq] is [Tensor.Double]. *)
  let create (type a) (type b) (eq : (b * a) Tensor.eq) fnn =
    let create eq ~type_ =
      let node, inputs, var_names, explicit_vars, all_nodes = build_node (P fnn) ~type_ in
      let placeholder = Ops.placeholder ~type_ (Shape.dim_list fnn.shape) in
      { node
      ; placeholder
      ; inputs
      ; save_nodes = Hashtbl.create (module String)
      ; load_and_assign_nodes = Hashtbl.create (module String)
      ; var_names
      ; explicit_vars
      ; all_nodes
      ; eq
      }
    in
    match eq with
    | Tensor.Float -> (create Float ~type_:Float : (_, a, b) t)
    | Tensor.Double -> failwith "The Double type is not supported."

  (** Evaluate the model on [inputs], a list of [(input id, tensor)] pairs.

      When [output_id] is given, the node registered under that identifier is
      evaluated instead of the final output of the network.

      @raise Failure if an input id is unknown ("missing input") or if
      [output_id] does not correspond to any node of the model. *)
  let predict
      (type a b)
      (t : (_, a, b) t)
      ?output_id
      (inputs : (Input_id.t * (float, b) Tensor.t) list)
    =
    (* Resolve the requested node first so that it is the one that gets run.
       Previously it was only used to pick the element type and the session
       always evaluated [t.node], silently ignoring [output_id]. *)
    let output_node =
      match output_id with
      | None -> t.node
      | Some id ->
        (match Hashtbl.find t.all_nodes id with
        | None -> failwith "Cannot find any node with the proper id"
        | Some node -> node)
    in
    let predict f_or_d_input f_or_d_output =
      let inputs =
        List.map inputs ~f:(fun (id, tensor) ->
            match Hashtbl.find t.inputs id with
            | None -> failwith "missing input"
            | Some placeholder -> f_or_d_input placeholder tensor)
      in
      Session.run ~inputs (f_or_d_output output_node)
    in
    match Node.output_type output_node, t.eq with
    | Node.Type.Float, Tensor.Float ->
      (predict Session.Input.float Session.Output.float : (float, b) Tensor.t)
    | Node.Type.Double, Tensor.Double ->
      (predict Session.Input.double Session.Output.double : (float, b) Tensor.t)
    | _ -> .

  (** Train the model for [epochs] steps, printing the loss after each one.

      - [xs] is fed to the input identified by [input_id], [ys] to the model's
        output placeholder; [addn_inputs] are fed on every step.
      - [explicit_vars], when given, restricts the optimizer to those
        variables (they must have been created with [var]).
      - [batch_size], when given and smaller than the number of samples (the
        first dimension of [xs]), makes each step use a contiguous slice of
        that many samples whose offset advances with the epoch. A batch size
        greater than or equal to the number of samples trains on the whole
        data set.

      @raise Failure if an input id is unknown ("missing input"). *)
  let fit
      (type a b)
      ?(addn_inputs : (Input_id.t * (float, b) Tensor.t) list option)
      ?batch_size
      ?explicit_vars
      (t : (_, a, b) t)
      ~loss
      ~optimizer
      ~epochs
      ~input_id
      ~xs
      ~ys
    =
    let fit placeholder node f_or_d scalar_f_or_d =
      let loss =
        Loss.get loss ~sample_ys:(Ops.Placeholder.to_node placeholder) ~model_ys:node
      in
      let optimizer =
        let varsf, varsd =
          match explicit_vars with
          | None -> None, None
          | Some explicit_vars ->
            (match t.eq with
            | Tensor.Float ->
              let nodes =
                List.map explicit_vars ~f:(fun ev -> Hashtbl.find_exn t.explicit_vars ev.id)
              in
              Some (nodes : [ `float ] Node.t list), None
            | Tensor.Double ->
              let nodes =
                List.map explicit_vars ~f:(fun ev -> Hashtbl.find_exn t.explicit_vars ev.id)
              in
              None, Some (nodes : [ `double ] Node.t list))
        in
        Optimizer.get optimizer ~loss ?varsf ?varsd
      in
      let samples = (Tensor.dims xs).(0) in
      let batch_size =
        match batch_size with
        | None -> None
        (* [>=] rather than [>]: with [batch_size = samples] the offset below
           would be computed modulo zero, which raises. A batch covering all
           the samples is just full-batch training. *)
        | Some batch_size when batch_size >= samples -> None
        | Some _ as s -> s
      in
      let find_input id =
        match Hashtbl.find t.inputs id with
        | None -> failwith "missing input"
        | Some placeholder -> placeholder
      in
      let addn_inputs =
        Option.value_map
          addn_inputs
          ~default:[]
          ~f:(List.map ~f:(fun (id, tensor) -> f_or_d (find_input id) tensor))
      in
      let xs_placeholder = find_input input_id in
      let inputs ~xs ~ys =
        f_or_d xs_placeholder xs :: f_or_d t.placeholder ys :: addn_inputs
      in
      for epoch = 1 to epochs do
        let inputs =
          match batch_size with
          | None -> inputs ~xs ~ys
          | Some batch_size ->
            (* [samples - batch_size] is strictly positive here. *)
            let offset = ((epoch - 1) * batch_size) % (samples - batch_size) in
            let xs = Tensor.sub_left xs offset batch_size in
            let ys = Tensor.sub_left ys offset batch_size in
            inputs ~xs ~ys
        in
        let err = Session.run ~inputs ~targets:optimizer (scalar_f_or_d loss) in
        Stdio.printf "Epoch: %6d/%-6d Loss: %.2f\n%!" epoch epochs err
      done
    in
    match Node.output_type t.node, t.eq with
    | Node.Type.Float, Tensor.Float ->
      fit t.placeholder t.node Session.Input.float Session.Output.scalar_float
    | Node.Type.Double, Tensor.Double ->
      fit t.placeholder t.node Session.Input.double Session.Output.scalar_double
    | _ -> .

  (** The variables reachable from the model output that have a registered
      name, paired with that name. Unnamed variables are left out. *)
  let all_vars_with_names t =
    Var.get_all_vars [ Node.P t.node ]
    |> List.filter_map ~f:(fun var ->
           let name =
             let node_id =
               match var with
               | Node.P v -> Node.id v
             in
             Hashtbl.find t.var_names node_id
           in
           Option.map name ~f:(fun name -> name, var))

  (** Turn [(input id, tensor)] pairs into session inputs.
      @raise Failure if an input id is unknown ("missing input"). *)
  let input_list
      (type a b)
      (t : (_, a, b) t)
      (inputs : (Input_id.t * (float, b) Tensor.t) list)
    =
    List.map inputs ~f:(fun (id, tensor) ->
        match Hashtbl.find t.inputs id with
        | None -> failwith "missing input"
        | Some placeholder ->
          (match Node.output_type t.node, t.eq with
          | Node.Type.Float, Tensor.Float -> Session.Input.float placeholder tensor
          | Node.Type.Double, Tensor.Double -> Session.Input.double placeholder tensor
          | _ -> .))

  (** Save the named variables of the model to [filename]. The save node is
      built once per filename and then reused.
      @raise Failure if the model has no named variable. *)
  let save ?(inputs = []) t ~filename =
    let save_node =
      Hashtbl.find_or_add t.save_nodes filename ~default:(fun () ->
          let all_vars_with_names = all_vars_with_names t in
          if List.is_empty all_vars_with_names
          then failwith "No variable to save can be found (only named variables are saved)";
          Ops.save ~filename all_vars_with_names)
    in
    Session.run
      ~inputs:(input_list t inputs)
      ~targets:[ Node.P save_node ]
      Session.Output.empty

  (** Restore the named variables of the model from [filename]. The restore
      and assign nodes are built once per filename and then reused. *)
  let load ?(inputs = []) t ~filename =
    let load_and_assign_nodes =
      Hashtbl.find_or_add t.load_and_assign_nodes filename ~default:(fun () ->
          let filename = Ops.const_string0 filename in
          List.map (all_vars_with_names t) ~f:(fun (var_name, Node.P var) ->
              Ops.restore
                ~type_:(Node.output_type var)
                filename
                (Ops.const_string0 var_name)
              |> Ops.assign var
              |> fun node -> Node.P node))
    in
    Session.run
      ~inputs:(input_list t inputs)
      ~targets:load_and_assign_nodes
      Session.Output.empty
end

(* Infix shorthands for the shape-checked binary operations. They are defined
   last so that the integer arithmetic used above is not shadowed. *)
let ( + ) t1 t2 = binary Plus t1 t2
let ( - ) t1 t2 = binary Minus t1 t2
let ( * ) t1 t2 = binary Times t1 t2