# 1 "src/opt/adam/adam.ml"
(** Adam-style first-order optimiser, parameterised over an algorithmic
    differentiation module [AD] (float elements) and a parameter container [P].

    Every parameter carries running first and second moment estimates of its
    gradient; [optimise] repeatedly differentiates the objective in reverse
    mode and applies a bias-corrected update until the stopping criterion
    holds. *)
module Make (AD : Owl_algodiff_generic_sig.Sig with type A.elt = float) (P : Prms.PT) =
struct
  (** Value returned by the objective function. *)
  type fv = AD.t

  (** A single parameter. *)
  type prm = AD.t

  (** The full parameter set, stored in the container [P]. *)
  type prms = prm P.t

  (** Objective function, from a parameter set to a value. *)
  type f = prms -> fv

  (** Per-parameter optimiser slot: the parameter value [p] together with the
      first moment [m] and second moment [v] accumulators. *)
  type x =
    { p : AD.t
    ; m : AD.t
    ; v : AD.t
    }

  type xs = x P.t

  (** Optimiser state: parameter slots [xs], objective [f], the most recently
      evaluated objective value [fv] and the number of completed steps [k]. *)
  type state =
    { xs : xs
    ; f : f
    ; fv : float
    ; k : int
    }

  (** Stopping criterion; optimisation ends as soon as it returns [true]. *)
  type stop = state -> bool

  (** [iter s] is the number of steps taken so far. *)
  let iter s = s.k

  (** [prms s] extracts the current parameter values from the state. *)
  let prms s = P.map s.xs ~f:(fun x -> x.p)

  (** [f s] is the objective function stored in the state. *)
  let f s = s.f

  (** [fv s] is the objective value recorded in the state. *)
  let fv s = s.fv

  (** [init ~prms0 ~f ()] evaluates [f] once at [prms0] and builds the initial
      state with step counter 0. Both moment accumulators of every parameter
      are created as a copy of the parameter that is then reset in place. *)
  let init ~prms0 ~f () =
    let fv = AD.unpack_flt (f prms0) in
    (* NOTE(review): presumably the reset zeroes the copy, giving a zero array
       shaped like the parameter; confirm against the AD.Mat implementation *)
    let blank_like p =
      let acc = AD.copy_primal' p in
      AD.Mat.reset acc;
      acc
    in
    let xs =
      P.map prms0 ~f:(fun p ->
          (* m is created before v, as in the original statement order *)
          let m = blank_like p in
          let v = blank_like p in
          { p; m; v })
    in
    { xs; fv; f; k = 0 }


  (** Descent update: [x - lr * m / (sqrt v + eps)]. *)
  let min_update lr x m v eps = AD.Maths.(x - (lr * m / (sqrt v + eps)))

  (** Ascent update: [x + lr * m / (sqrt v + eps)]. *)
  let max_update lr x m v eps = AD.Maths.(x + (lr * m / (sqrt v + eps)))

  (** Default stopping criterion. Prints the step and loss on every tenth step
      (overwriting the current terminal line and flushing) and returns [true]
      once the recorded loss falls below [1E-3]. *)
  let stop s =
    if s.k mod 10 = 0 then Printf.printf "\rstep: %i | loss: %4.9f%!" s.k s.fv;
    s.fv < 1E-3


  (** [optimise update ?stop ?beta1 ?beta2 ?eps ~lr s] iterates from state [s]
      until [stop] returns [true] and returns the final state.

      - [update lr p m v eps] maps the current parameter and the bias-corrected
        moments to the new parameter (see [min_update] and [max_update]).
      - [stop] is checked before every step, including the first one.
      - [beta1] and [beta2] are the decay rates of the first and second moment.
      - [eps] is forwarded unchanged to [update].
      - [lr] is either [Lr.Fix r], a constant rate, or [Lr.Ada h], where
        [h k] gives the rate at step [k].

      The [fv] stored by a step is the objective evaluated at the parameters
      that step started from, that is before the update was applied. *)
  let optimise update ?(stop = stop) ?(beta1 = 0.9) ?(beta2 = 0.999) ?(eps = 1E-8) ~lr s =
    (* Seed the bias-correction products with beta ** k so that an optimisation
       resumed from a state with k > 0, whose moments are already warm, does
       not restart the correction at its step-one value. For k = 0 both powers
       are exactly 1., which is what a fresh run has always used. *)
    let steps_done = float_of_int s.k in
    let b1_start = AD.F (beta1 ** steps_done) in
    let b2_start = AD.F (beta2 ** steps_done) in
    let beta1 = AD.(F beta1) in
    let beta1_ = AD.(Maths.(F 1. - beta1)) in
    let beta2 = AD.(F beta2) in
    let beta2_ = AD.(Maths.(F 1. - beta2)) in
    let eps = AD.(F eps) in
    (* [b1] and [b2] hold the running products of the decay rates *)
    let rec run s b1 b2 =
      if stop s
      then s
      else (
        let b1 = AD.Maths.(b1 * beta1) in
        let b2 = AD.Maths.(b2 * beta2) in
        (* wrap every parameter as a reverse-mode node under a fresh tag *)
        let t = AD.tag () in
        let xs =
          P.map
            ~f:(fun x ->
              let p = AD.make_reverse x.p t in
              { x with p })
            s.xs
        in
        (* forward evaluation followed by the backward sweep *)
        let l = s.f (P.map ~f:(fun x -> x.p) xs) in
        AD.(reverse_prop (F 1.) l);
        let fv = AD.unpack_flt l in
        (* the rate depends only on the step counter, so it is computed once
           per step instead of once per parameter *)
        let rate =
          match lr with
          | Lr.Fix r -> AD.pack_flt r
          | Lr.Ada h -> AD.pack_flt (h s.k)
        in
        let xs =
          P.map
            ~f:(fun x ->
              let p = AD.primal x.p in
              let g = AD.adjval x.p in
              (* first moment *)
              let m = AD.Maths.((beta1 * x.m) + (beta1_ * g)) in
              (* bias-corrected first moment *)
              let m_ = AD.Maths.(m / (F 1. - b1)) in
              (* second moment *)
              let v = AD.Maths.((beta2 * x.v) + (beta2_ * sqr g)) in
              (* bias-corrected second moment *)
              let v_ = AD.(Maths.(v / (F 1. - b2))) in
              let p = update rate p m_ v_ eps |> AD.primal in
              (* the raw moments are stored, not the bias-corrected ones *)
              { p; m; v })
            xs
        in
        let s = { s with xs; k = succ s.k; fv } in
        run s b1 b2)
    in
    run s b1_start b2_start


  (** Minimise the objective; takes the arguments of [optimise] after [update]. *)
  let min = optimise min_update

  (** Maximise the objective; takes the arguments of [optimise] after [update]. *)
  let max = optimise max_update
end