12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073# 1 "src/base/dense/owl_base_dense_ndarray_generic.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

[@@@warning "-32"]

open Bigarray
open Owl_types

type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t

type ('a, 'b) kind = ('a, 'b) Bigarray.kind

module Scalar = Owl_base_maths

(* Prepend an array with ones to the given length.
   [dims] is returned unchanged when it is already at least [desired_len]
   long. *)
let _prepend_dims dims desired_len =
  let dims_len = Array.length dims in
  if dims_len >= desired_len then dims
  else Array.append (Array.make (desired_len - dims_len) 1) dims

(* Compute the broadcast shape of two shapes.
   Both shapes are first left-padded with ones to a common rank; along each
   axis the two sizes must be equal or one of them must be 1.
   Returns (padded dims_a, padded dims_b, broadcast shape).
   Raises [Invalid_argument] when the shapes are incompatible. *)
let _get_broadcasted_dims dims_a dims_b =
  let len_c = Stdlib.max (Array.length dims_a) (Array.length dims_b) in
  let ext_dims_a = _prepend_dims dims_a len_c in
  let ext_dims_b = _prepend_dims dims_b len_c in
  let dims_c = Array.make len_c 0 in
  for i = 0 to len_c - 1 do
    let val_a = ext_dims_a.(i) in
    let val_b = ext_dims_b.(i) in
    if val_a = val_b then dims_c.(i) <- val_a
    (* structural inequality; [!=] would be physical inequality *)
    else if val_a <> 1 && val_b <> 1 then
      raise (Invalid_argument "The arrays cannot be broadcast into the same shape")
    else dims_c.(i) <- Stdlib.max val_a val_b
  done;
  (ext_dims_a, ext_dims_b, dims_c)

(* Increment the index array, with respect to the dimensions array.
   [ind] is advanced in place, last axis fastest. Returns [true] after a
   successful step, and [false] once the index wraps past the last position,
   in which case every entry of [ind] has been reset to 0. *)
let _next_index ind dims =
  let num_dims = Array.length ind in
  let p = ref (num_dims - 1) in
  let ok = ref false in
  while !p >= 0 && not !ok do
    if ind.(!p) + 1 < dims.(!p) then begin
      ind.(!p) <- ind.(!p) + 1;
      ok := true
    end
    else begin
      ind.(!p) <- 0;
      p := !p - 1
    end
  done;
  !ok

(* Map an index of a broadcast result back to an index into an operand of
   shape [dims]: a coordinate beyond an axis of size 1 is pinned to 0.
   Raises [Invalid_argument] when a coordinate is out of range on an axis
   whose size is not 1. *)
let _get_broadcasted_index ind dims =
  let num_dims = Array.length dims in
  let calc_fun i =
    let max_ind = dims.(i) in
    let ind_val = ind.(i) in
    if ind_val < max_ind then ind_val
    else if max_ind = 1 then 0
    else raise (Invalid_argument "not broadcasted correctly")
  in
  Array.init num_dims calc_fun

(* Permute [arr]: element [i] of the result is [arr.(perm.(i))]. *)
let _apply_perm arr perm =
  Array.init (Array.length arr) (fun i -> arr.(perm.(i)))

(* Draw [count] integers from [0, range), with or without replacement,
   using a freshly self-initialised random state on every call.
   Raises [Invalid_argument] when drawing without replacement and
   [count > range]. *)
let _draw_int_samples replacement range count =
  if not replacement && count > range then
    raise (Invalid_argument "cannot draw that many samples from the given range, without replacement")
  else begin
    let pop_cnt = ref range in
    let pop = Array.init !pop_cnt (fun i -> i) in
    let rand_gen = Random.State.make_self_init () in
    let draw_fun _ =
      let index = Random.State.int rand_gen !pop_cnt in
      let sample = pop.(index) in
      if replacement then sample
      else begin
        pop_cnt := !pop_cnt - 1;
        (* eliminate sample by swapping with last element *)
        pop.(index) <- pop.(!pop_cnt);
        sample
      end
    in
    Array.init count draw_fun
  end

(* Enumerate the indices described by a start/stop/step slice definition on
   an axis of length [dim]. Negative [start] and [stop] count from the end,
   [stop] is inclusive, and the default step is 1 or -1 depending on the
   direction from [start] to [stop]. *)
let _enumerate_slice_def dim ?step start stop =
  let start = if start < 0 then dim + start else start in
  let stop = if stop < 0 then dim + stop else stop in
  let step =
    match step with
    | Some x -> x
    | None -> if start <= stop then 1 else -1
  in
  assert ((start <= stop && step > 0) || (start > stop && step < 0));
  let step_abs = Stdlib.abs step in
  let len = (Stdlib.abs (stop - start) + step_abs) / step_abs in
  Array.init len (fun i -> start + i * step)

(* Rewrite the indices s.t. for each dimension they are a list of explicit indices.
   Per axis, an index definition is interpreted as:
     []                  -> every index of the axis
     [start]             -> that single index
     [start; stop]       -> the inclusive range from start to stop
     [start; stop; step] -> the same range with the given step
     longer lists        -> the listed indices verbatim
   A list of two or three explicit indices is therefore read as a range
   definition, not as explicit indices. Axes beyond [index_list] are taken
   in full. *)
let _expand_slice_indices index_list dims =
  let rank = Array.length dims in
  (* the number of dimensions this slice specifies *)
  let sdef_len = List.length index_list in
  let _expand_slice_index i ind =
    match ind with
    | [] -> Array.init dims.(i) (fun i -> i)
    | [start] -> _enumerate_slice_def dims.(i) start start
    | [start; stop] -> _enumerate_slice_def dims.(i) start stop
    | [start; stop; step] -> _enumerate_slice_def dims.(i) ~step start stop
    | x -> Array.of_list x
  in
  Array.append
    (* for the axis where the index was specified *)
    (Array.of_list (List.mapi _expand_slice_index index_list))
    (* the rest of the axis is just all of them *)
    (Array.init (rank - sdef_len)
       (fun p -> Array.init dims.(p + sdef_len) (fun i -> i)))

(* Fill [x] with the zero of its element kind, in place. *)
let reset x =
  let _kind = Genarray.kind x in
  Genarray.fill x (Owl_const.zero _kind)

(* Allocate a C-layout array; Genarray.create leaves elements uninitialised. *)
let empty kind dims = Genarray.create kind c_layout dims

(* Allocate a C-layout array with every element set to [value]. *)
let create kind dims value =
  let x = empty kind dims in
  Genarray.fill x value;
  x

let create_ ~out a = Genarray.fill out a

let zeros kind dims = create kind dims (Owl_const.zero kind)

let zeros_ ~out = reset out

let ones kind dims = create kind dims (Owl_const.one kind)

let ones_ ~out = Genarray.(fill out (Owl_const.one (kind out)))

let shape x = Genarray.dims x

let nth_dim x i = Genarray.nth_dim x i

let num_dims x = Array.length (shape x)

let numel x = Owl_utils.numel x

let kind x = Genarray.kind x

let get x index = Genarray.get x index

let set x index value = Genarray.set x index value

(* [n] by [n] identity matrix of the given kind. *)
let eye kind n =
  let m = zeros kind [|n; n|] in
  for i = 0 to n - 1 do
    set m [|i; i|] (Owl_const.one kind)
  done;
  m

(*TODO: optimise, test
*)
(* Copy the elements of [varr] selected by [index_list] (expanded by
   [_expand_slice_indices]) into a fresh array whose shape is the number of
   selected indices on each axis. *)
let get_slice index_list varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let index_array = _expand_slice_indices index_list dims in
  let slice_dims = Array.map Array.length index_array in
  let slice_varr = empty (kind varr) slice_dims in
  (* An empty selection has nothing to copy; walking it would read entry 0
     of an empty index array and raise. *)
  if Array.exists (fun d -> d = 0) slice_dims then slice_varr
  else begin
    let slice_ind = Array.make rank 0 in
    let original_ind = Array.make rank 0 in
    let should_stop = ref false in
    while not !should_stop do
      for i = 0 to rank - 1 do
        original_ind.(i) <- index_array.(i).(slice_ind.(i))
      done;
      Genarray.set slice_varr slice_ind (Genarray.get varr original_ind);
      if not (_next_index slice_ind slice_dims) then should_stop := true
    done;
    slice_varr
  end

(*TODO: optimise, test *)
(* Write the elements of [slice_varr] into the positions of [varr] selected
   by [index_list], in place. [slice_varr] must hold exactly as many elements
   as the selection. *)
let set_slice index_list varr slice_varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let index_array = _expand_slice_indices index_list dims in
  let slice_dims = Array.map Array.length index_array in
  (* [reshape] here is Bigarray.reshape (the module-level [reshape] below is
     not yet defined); it raises when the element counts differ. *)
  let slice_varr = reshape slice_varr slice_dims in
  (* nothing to write for an empty selection *)
  if not (Array.exists (fun d -> d = 0) slice_dims) then begin
    let slice_ind = Array.make rank 0 in
    let original_ind = Array.make rank 0 in
    let should_stop = ref false in
    while not !should_stop do
      for i = 0 to rank - 1 do
        original_ind.(i) <- index_array.(i).(slice_ind.(i))
      done;
      Genarray.set varr original_ind (Genarray.get slice_varr slice_ind);
      if not (_next_index slice_ind slice_dims) then should_stop := true
    done
  end

(* The result shares the underlying buffer with original, not a copy.
   At most one entry of [d] may be -1; it is replaced by whatever size makes
   the element count match. *)
let reshape x d =
  let minus_one = Owl_utils.Array.count d (-1) in
  assert (minus_one <= 1);
  if minus_one = 0 then reshape x d
  else begin
    let n = numel x in
    (* starting the product at -1 cancels the single -1 entry, leaving the
       product of the explicitly given dimensions *)
    let m = Array.fold_right ( * ) d (-1) in
    let e = Array.map (fun a -> if a = -1 then n / m else a) d in
    reshape x e
  end

(* Return the array as a contiguous block, without copying
*)
let flatten x = reshape x [|numel x|]

(* Fill every element of [x] with [a], in place. *)
let fill x a = Genarray.fill x a

(* Fresh array with the same kind, shape and contents as [x]. *)
let copy x =
  let y = empty (kind x) (shape x) in
  Genarray.blit x y;
  y

(* Copy the contents of [x] into [out]. Both are flattened first, so only
   their element counts have to agree, not their shapes. *)
let copy_ ~out x =
  let src = flatten x in
  let dst = flatten out in
  Genarray.blit src dst

(* Copy the data of [x] into [out], keeping the shape of [out]; a no-op when
   [out] is [x] itself. *)
let reshape_ ~out x =
  if not (x == out) then copy_ ~out x

(* Write the elements of [x], in reverse flattened order, into [out].
   Elements are exchanged pairwise from both ends, reading both before
   writing either, so [out] may be [x] itself (in-place reversal). A plain
   front-to-back copy would overwrite the second half with the mirror of the
   first half before reading it when [out] is [x]. *)
let reverse_ ~out x =
  let n = numel x in
  let y_flat = reshape out [|n|] in
  let x_flat = reshape x [|n|] in
  for i = 0 to (n / 2) - 1 do
    let j = n - 1 - i in
    let front = get x_flat [|i|] in
    let back = get x_flat [|j|] in
    set y_flat [|i|] back;
    set y_flat [|j|] front
  done;
  (* the middle element of an odd-length array maps onto itself *)
  if n mod 2 = 1 then set y_flat [|n / 2|] (get x_flat [|n / 2|])

(* Fresh array holding the elements of [x] in reverse flattened order. *)
let reverse x =
  let y = empty (kind x) (shape x) in
  reverse_ ~out:y x;
  y

(* Apply [f] to every element of [x], in place, in flattened order. *)
let map_ f x =
  let y = flatten x |> array1_of_genarray in
  let length = numel x in
  for i = 0 to length - 1 do
    Array1.unsafe_set y i (f (Array1.unsafe_get y i))
  done

(* Like [map_], but [f] also receives the flattened index. *)
let mapi_ f x =
  let y = flatten x |> array1_of_genarray in
  let length = numel x in
  for i = 0 to length - 1 do
    Array1.unsafe_set y i (f i (Array1.unsafe_get y i))
  done

(* New array whose element at flattened index [i] is [f i]. *)
let init kind dims f =
  let varr = empty kind dims in
  let varr_flat = flatten varr |> array1_of_genarray in
  let n = numel varr in
  for i = 0 to n - 1 do
    Array1.unsafe_set varr_flat i (f i)
  done;
  varr

(* New array whose elements are computed from their n-d index.
   The same index array [j] is passed to [f] on every call and refilled by
   Owl_utils.index_1d_nd, so [f] must copy it if it wants to keep it. *)
let init_nd k d f =
  let x = empty k d in
  let y = array1_of_genarray (flatten x) in
  let n = numel x in
  let s = Owl_utils.calc_stride d in
  let j = Array.copy s in
  for i = 0 to n - 1 do
    Owl_utils.index_1d_nd i j s;
    Array1.unsafe_set y i (f j)
  done;
  x

(* Map a NDarray from elements x -> f(x), by copying the array *)
let map f x =
  let y = copy x in
  map_ f y;
  y

(* Like [map], but [f] also receives the flattened index. *)
let mapi f x =
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim y' - 1 do
    let a = Array1.unsafe_get y' i in
    Array1.unsafe_set y' i (f i a)
  done;
  y

let strides x = x |> shape |> Owl_utils.calc_stride

let slice_size x = x |> shape |> Owl_utils.calc_slice

(* TODO: performance can be optimised by removing embedded loops *)

(* generic fold function
*)
(* [f k acc elt] receives a running element counter [k].
   With [~axis], the reduced shape and loop bounds come from
   Owl_utils.reduce_params and every output cell starts from [a].
   Without [~axis], all elements are folded in flattened order and the result
   is a one-element array. *)
let foldi ?axis f a x =
  let x' = flatten x |> array1_of_genarray in
  match axis with
  | Some axis -> (
      let m, n, o, s = Owl_utils.reduce_params axis x in
      let start_x = ref 0 in
      let start_y = ref 0 in
      let incy = ref 0 in
      let k = ref 0 in
      let y = create (kind x) s a in
      let y' = flatten y |> array1_of_genarray in
      for _i = 0 to m - 1 do
        for j = 0 to n - 1 do
          let b = Array1.unsafe_get y' (!start_y + !incy) in
          let c = Array1.unsafe_get x' (!start_x + j) in
          Array1.unsafe_set y' (!start_y + !incy) (f !k b c);
          (* cycle through the [o] output cells of the current block *)
          if !incy + 1 = o then incy := 0 else incy := !incy + 1;
          k := !k + 1
        done;
        start_x := !start_x + n;
        start_y := !start_y + o
      done;
      y )
  | None -> (
      let b = ref a in
      for i = 0 to numel x - 1 do
        let c = Array1.unsafe_get x' i in
        b := f i !b c
      done;
      create (kind x) [|1|] !b )

let fold ?axis f a x = foldi ?axis (fun _ b c -> f b c) a x

(* generic scan function
   Cumulative fold along [axis] (default: the last axis). Each element after
   the first along the axis is replaced by [f k prev cur], where [prev] is
   the already-updated element one step back along the axis. Works on a
   copy of [x]. *)
let scani ?axis f x =
  let d = num_dims x in
  let a =
    match axis with
    | Some a -> a
    | None -> d - 1
  in
  assert (0 <= a && a < d);
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx = _slicez.(a) in
  let incy = _slicez.(a) in
  let start_x = ref 0 in
  let start_y = ref _stride.(a) in
  let k = ref 0 in
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let b = Array1.unsafe_get y' (!start_x + j) in
      let c = Array1.unsafe_get y' (!start_y + j) in
      Array1.unsafe_set y' (!start_y + j) (f !k b c);
      k := !k + 1
    done;
    start_x := !start_x + incx;
    start_y := !start_y + incy
  done;
  y

let scan ?axis f x = scani ?axis (fun _ a b -> f a b) x

(* Call [f i elt] on every element, in flattened order. *)
let iteri f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    f i a
  done

let iter f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    f a
  done

(* Flattened indices of the elements satisfying [f i elt]. *)
let filteri f x =
  let s = Owl_utils.Stack.make () in
  iteri (fun i y -> if f i y then Owl_utils.Stack.push s i) x;
  Owl_utils.Stack.to_array s

let filter f x = filteri (fun _ y -> f y) x

(* Fill [out] with a + i * step for flattened index i; [a] defaults to zero
   and [step] to one of the element kind. *)
let sequential_ ?a ?step ~out =
  let k = kind out in
  let a =
    match a with
    | Some a -> a
    | None -> Owl_const.zero k
  in
  let step =
    match step with
    | Some step -> step
    | None -> Owl_const.one k
  in
  let _add = Owl_base_dense_common._add_elt k in
  let _mul = Owl_base_dense_common._mul_elt k in
  let _flt = Owl_base_dense_common._float_typ_elt k in
  mapi_ (fun i _ -> _add a (_mul (_flt (float_of_int i)) step)) out

let sequential k ?a ?step dimension =
  let x = empty k dimension in
  sequential_ ?a ?step ~out:x;
  x

(* Build an array of shape [dims] from the flat OCaml array [arr], in C
   (row-major) order.
   NOTE(review): an [arr] shorter than the element count raises part-way
   through the copy, and extra trailing elements are silently ignored;
   confirm callers rely on neither. *)
let of_array kind arr dims =
  let varr = empty kind dims in
  let flat_varr = flatten varr |> array1_of_genarray in
  let n = numel varr in
  for i = 0 to n - 1 do
    Array1.unsafe_set flat_varr i arr.(i)
  done;
  varr

(* Uniform samples; bounds default to zero and one of the element kind. *)
let uniform k ?a ?b dims =
  let a = match a with Some a -> a | None -> Owl_const.zero k in
  let b = match b with Some b -> b | None -> Owl_const.one k in
  let uniform_fun = Owl_base_dense_common._uniform_elt k a b in
  let x = empty k dims in
  map_ uniform_fun x;
  x

let uniform_ ?a ?b ~out =
  let k = kind out in
  let a = match a with Some a -> a | None -> Owl_const.zero k in
  let b = match b with Some b -> b | None -> Owl_const.one k in
  let uniform_fun = Owl_base_dense_common._uniform_elt k a b in
  map_ uniform_fun out

(* Bernoulli samples with success probability [p] (default 0.5). *)
let bernoulli k ?(p = 0.5) dims =
  let bernoulli_fun _ =
    let a = Owl_base_stats.bernoulli_rvs ~p in
    Owl_base_dense_common._float_typ_elt k a
  in
  let x = empty k dims in
  map_ bernoulli_fun x;
  x

let bernoulli_ ?(p = 0.5) ~out =
  let k = kind out in
  let bernoulli_fun _ =
    let a = Owl_base_stats.bernoulli_rvs ~p in
    Owl_base_dense_common._float_typ_elt k a
  in
  map_ bernoulli_fun out

(* Gaussian samples; [mu] and [sigma] default to zero and one of the kind. *)
let gaussian k ?mu ?sigma dims =
  let mu = match mu with Some a -> a | None -> Owl_const.zero k in
  let sigma = match sigma with Some a -> a | None -> Owl_const.one k in
  let gaussian_fun = Owl_base_dense_common._gaussian_elt k mu sigma in
  let x = empty k dims in
  map_ gaussian_fun x;
  x

let gaussian_ ?mu ?sigma ~out =
  let k = kind out in
  let mu = match mu with Some a -> a | None -> Owl_const.zero k in
  let sigma = match sigma with Some a -> a | None -> Owl_const.one k in
  let gaussian_fun = Owl_base_dense_common._gaussian_elt k mu sigma in
  map_ gaussian_fun out

(* Pretty-print [x]; by default every row and column is shown.
   NOTE(review): a rank-0 array fails on [dims.(rank - 1)] and an empty last
   axis divides by zero here - confirm whether these need handling. *)
let print ?max_row ?max_col ?header ?fmt x =
  let dims = shape x in
  let rank = Array.length dims in
  let n = dims.(rank - 1) in
  let max_row =
    match max_row with
    | Some a -> Some a
    | None -> Some (numel x / n)
  in
  let max_col =
    match max_col with
    | Some a -> Some a
    | None -> Some n
  in
  Owl_pretty.print_dsnda ?max_row ?max_col ?header ?elt_to_str_fun:fmt x

(* TODO: optimise *)
(* Repeat [varr] [reps.(i)] times along each axis i; the shorter of the
   shape and [reps] is left-padded with ones first. *)
let tile varr reps =
  (* First ensure len(reps) = num_dims(varr) *)
  let dims = shape varr in
  let result_rank = Stdlib.max (Array.length dims) (Array.length reps) in
  let dims = _prepend_dims dims result_rank in
  let reps = _prepend_dims reps result_rank in
  let varr = reshape varr dims in
  (* now len(reps) = num_dims(varr) *)
  let result_dims = Array.map2 (fun a b -> a * b) dims reps in
  let result_varr = empty (kind varr) result_dims in
  (* A zero-length result axis (zero repetitions or an empty input axis)
     means there is nothing to copy; the loop below would otherwise read
     from an empty array or reduce modulo zero. *)
  if Array.exists (fun d -> d = 0) result_dims then result_varr
  else begin
    let result_ind = Array.make result_rank 0 in
    let original_ind = Array.make result_rank 0 in
    let should_stop = ref false in
    while not !should_stop do
      for i = 0 to result_rank - 1 do
        original_ind.(i) <- result_ind.(i) mod dims.(i)
      done;
      Genarray.set result_varr result_ind (Genarray.get varr original_ind);
      if not (_next_index result_ind result_dims) then should_stop := true
    done;
    result_varr
  end

(* TODO: optimise *)
(* Split [varr] along [axis] into consecutive pieces whose lengths are given
   by [parts].
   NOTE(review): a zero entry in [parts] becomes the definition [p; p - 1],
   which the slice syntax reads as a descending two-element range, or as the
   whole axis when p is 0; callers presumably pass positive sizes. *)
let split ?(axis = 0) parts varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let pos = ref 0 in
  (* relies on Array.map visiting [parts] in index order so that [pos]
     holds the running offset *)
  let axis_indices =
    Array.map
      (fun d ->
        pos := !pos + d;
        [!pos - d; !pos - 1])
      parts
  in
  let slices_defs =
    Array.map
      (fun ind ->
        Array.to_list (Array.init rank (fun i -> if i = axis then ind else [])))
      axis_indices
  in
  Array.map (fun def -> get_slice def varr) slices_defs

(* Drop axes of size 1; when [axis] is given, only those listed axes are
   candidates. The result shares data with [x]. *)
let squeeze ?(axis = [||]) x =
  let a =
    match Array.length axis with
    | 0 -> Array.init (num_dims x) (fun i -> i)
    | _ -> axis
  in
  let s = Owl_utils.Array.filteri (fun i v -> not (v = 1 && Array.mem i a)) (shape x) in
  reshape x s

(* Raise [x] to rank [d] by padding its shape with ones on the left, or on
   the right when [hi] is true; [x] is returned as is when its rank is
   already at least [d]. *)
let expand ?(hi = false) x d =
  let d0 = d - num_dims x in
  if d0 > 0 then begin
    if hi then (Owl_utils.Array.pad `Right 1 d0 (shape x) |> reshape x)
    else (Owl_utils.Array.pad `Left 1 d0 (shape x) |> reshape x)
  end
  else x

(* TODO : ensure this is desired behaviour *)
(* Similar to draw rows for matrices
*)
(* Draw [count] distinct positions along [axis] (without replacement) and
   return the corresponding sub-array together with the drawn positions.
   Raises [Invalid_argument] when [count] exceeds the axis length. *)
let draw ?(axis = 0) varr count =
  let dims = shape varr in
  let rank = Array.length dims in
  let indices = _draw_int_samples false dims.(axis) count in
  (* The rows are gathered directly rather than via [get_slice]: its index
     syntax reads an empty list as the whole axis and a two- or three-element
     list as a range definition, so [count] of 0, 2 or 3 selected the wrong
     rows. *)
  let result_dims = Array.copy dims in
  result_dims.(axis) <- count;
  let result = empty (kind varr) result_dims in
  if Array.for_all (fun d -> d > 0) result_dims then begin
    let dst_ind = Array.make rank 0 in
    let src_ind = Array.make rank 0 in
    let more = ref true in
    while !more do
      Array.blit dst_ind 0 src_ind 0 rank;
      src_ind.(axis) <- indices.(dst_ind.(axis));
      Genarray.set result dst_ind (Genarray.get varr src_ind);
      more := _next_index dst_ind result_dims
    done
  end;
  (result, indices)

(* Normalise a padding spec to one [|before; after|] pair per axis of shape
   [s]: missing trailing axes get [|0; 0|], an empty entry means no padding,
   and a single value [x] pads both sides by [x]. *)
let _expand_padding_index d s =
  let ls = Array.length s in
  let ld = Array.length d in
  let d = Owl_utils.Array.pad `Right [|0; 0|] (ls - ld) d in
  Array.map
    (function
      | [||] -> [|0; 0|]
      | [|x|] -> [|x; x|]
      | x -> x)
    d

(* Recursive helper for [pad]. Walks every index combination of the axes
   before [d1] and, at depth [d1], blits one run of [ls.(d1)] elements from
   the flat source [x0] to the flat destination [x1]. [i0] and [i1] hold the
   current source and destination n-d indices (destination shifted by the
   leading padding [p1.(axis).(0)]) and are restored after each iteration;
   [l0] and [l1] are the source and destination strides. *)
let rec _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x0 x1 =
  if d0 < d1 then begin
    for i = 0 to s0.(d0) - 1 do
      i0.(d0) <- i;
      i1.(d0) <- i + p1.(d0).(0);
      _copy_to_padding p1 ls l0 l1 i0 i1 (d0 + 1) d1 s0 s1 x0 x1;
      i0.(d0) <- 0;
      i1.(d0) <- p1.(d0).(0)
    done
  end
  else begin
    let j0 = Owl_utils.index_nd_1d i0 l0 in
    let j1 = Owl_utils.index_nd_1d i1 l1 in
    let subx = Genarray.sub_left x0 j0 ls.(d0) in
    let suby = Genarray.sub_left x1 j1 ls.(d0) in
    Genarray.blit subx suby
  end

(* Highest axis whose padding is not [|0; 0|]; 0 when no axis is padded and
   -1 for an empty spec. A plain search replaces the former failwith plus
   catch-all exception handler with identical results. *)
let _highest_padding_dimension p =
  let rec search i =
    if i <= 0 then i
    else if p.(i) <> [|0; 0|] then i
    else search (i - 1)
  in
  search (Array.length p - 1)

(* Pad [x] with value [v] (default: zero of its kind). [d] lists the padding
   per axis and is converted with Owl_utils.llss2aarr before normalisation;
   the result is a fresh array. *)
let pad ?v d x =
  let k = kind x in
  let v =
    match v with
    | Some v -> v
    | None -> Owl_const.zero k
  in
  let s0 = shape x in
  let x' = flatten x in
  let p1 = _expand_padding_index (Owl_utils.llss2aarr d) s0 in
  let s1 = Array.map2 (fun m n -> m + n.(0) + n.(1)) s0 p1 in
  let s' = Owl_utils_array.fold_right ( * ) s1 1 in
  let y' = create k [|s'|] v in
  let ls = Owl_utils.calc_slice s0 in
  let l0 = Owl_utils.calc_stride s0 in
  let l1 = Owl_utils.calc_stride s1 in
  let i0 = Array.make (num_dims x) 0 in
  let i1 = Array.map (fun a -> a.(0)) p1 in
  let d0 = 0 in
  let d1 = _highest_padding_dimension p1 in
  _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x' y';
  reshape y' s1

(* TODO: optimise?
*)
(* [concatenate ~axis varrs] joins the arrays of [varrs] along [axis]
   (default 0; a negative value counts from the last dimension, as resolved
   by [Owl_utils.adjust_index]). Every array must have the rank of the first
   one and agree with it on every dimension except [axis].
   Raises [Owl_exception.INVALID_ARGUMENT] when [varrs] is empty and
   [Owl_exception.DIFFERENT_SHAPE] when a shape is incompatible. *)
let concatenate ?(axis = 0) varrs =
  let varrs_num = Array.length varrs in
  Owl_exception.verify (varrs_num > 0) (fun () ->
      Owl_exception.INVALID_ARGUMENT "concatenate: no arrays to concatenate");
  (* dimensions of all NDarrays *)
  let all_dims = Array.map shape varrs in
  let dims0 = all_dims.(0) in
  let rank = Array.length dims0 in
  let axis = Owl_utils.adjust_index axis rank in
  (* Validate up front: without this, inputs whose element counts happen to
     match after the reshape below would be silently interleaved. *)
  let compatible d =
    Array.length d = rank
    && begin
      let ok = ref true in
      Array.iteri (fun i n -> if i <> axis && n <> dims0.(i) then ok := false) d;
      !ok
    end
  in
  Array.iter
    (fun d ->
      Owl_exception.check (compatible d) (Owl_exception.DIFFERENT_SHAPE (dims0, d)))
    all_dims;
  (* the dimensions before and after the axis *)
  let prefix_dims = Array.sub dims0 0 axis in
  let suffix_dims = Array.sub dims0 (axis + 1) (rank - axis - 1) in
  (* the sum of the dimensions of each NDarray along the axis *)
  let sum_axis_dims = Array.fold_left (fun acc d -> acc + d.(axis)) 0 all_dims in
  let result_dims = Array.concat [ prefix_dims; [| sum_axis_dims |]; suffix_dims ] in
  let result_varr = empty (kind varrs.(0)) result_dims in
  let prefix_dims_product = Array.fold_left ( * ) 1 prefix_dims in
  let suffix_dims_product = Array.fold_left ( * ) 1 suffix_dims in
  (* View an array as a [prefix_dims_product; rest] matrix: each row then
     holds one contiguous block per input, laid out side by side. *)
  let as_matrix varr =
    let old_shape = shape varr in
    reshape varr [| prefix_dims_product; old_shape.(axis) * suffix_dims_product |]
  in
  let reshaped_result = as_matrix result_varr in
  let reshaped_varrs = Array.map as_matrix varrs in
  for i = 0 to prefix_dims_product - 1 do
    let result_slice = Genarray.slice_left reshaped_result [| i |] in
    let start_index = ref 0 in
    for j = 0 to varrs_num - 1 do
      let src_slice = Genarray.slice_left reshaped_varrs.(j) [| i |] in
      let block_len = all_dims.(j).(axis) * suffix_dims_product in
      let result_sub = Genarray.sub_left result_slice !start_index block_len in
      Genarray.blit src_slice result_sub;
      start_index := !start_index + block_len
    done
  done;
  result_varr

(* TODO: is there a more efficient way to do copy?
*)
(* [repeat x reps] repeats every element of [x] [reps.(i)] times along each
   dimension [i], placing the copies next to each other, so the result has
   shape [x_shape.(i) * reps.(i)]. [reps] must have one entry per dimension
   of [x], each at least 1; when all entries are 1 a copy of [x] is returned.
   Works on flat views: a first pass writes the elements of [x] into their
   places in [y'], then a second pass replicates already-written blocks of
   [y'], from dimension [hd - 1] down to dimension 0. *)
let repeat x reps =
  (* check the validity of reps *)
  if Array.exists (fun r -> r < 1) reps then
    failwith "repeat: repetition must be >= 1";
  let x_dims = num_dims x in
  assert (Array.length reps = x_dims);
  if Array.for_all (fun r -> r = 1) reps then copy x
  else begin
    let _kind = kind x in
    let x' = flatten x in
    let x_shape = shape x in
    let y_shape = Array.map2 ( * ) x_shape reps in
    let num = Owl_utils_array.fold_right ( * ) y_shape 1 in
    let y' = empty _kind [| num |] in
    if x_dims = 1 then begin
      (* 1-d case: write each element reps.(0) times in a row *)
      let ofsy = ref 0 in
      for i = 0 to numel x - 1 do
        let elemx = get x' [| i |] in
        for _j = 0 to reps.(0) - 1 do
          set y' [| !ofsy |] elemx;
          ofsy := !ofsy + 1
        done
      done
    end
    else begin
      let highest_dim = x_dims - 1 in
      let slice_x = Owl_utils.calc_slice x_shape in
      let stride_y = Owl_utils.calc_stride y_shape in
      (* [hd]: trailing dimensions whose factor is 1 are skipped (they are
         copied as one contiguous run); [hd] stays within [1, highest_dim] *)
      let hd = ref (highest_dim + 1) in
      while !hd > 1 && reps.(!hd - 1) = 1 do
        hd := !hd - 1
      done;
      let hd = if !hd = highest_dim + 1 then highest_dim else !hd in
      (* Copy the HD dimension from x to y *)
      let block_num = Array.make hd 0 in
      for i = 0 to hd - 1 do
        block_num.(i) <- slice_x.(i) / slice_x.(hd)
      done;
      let counter = Array.make hd 0 in
      let ofsx = ref 0 in
      let ofsy = ref 0 in
      let block_sz = reps.(hd) in
      for _i = 0 to block_num.(0) - 1 do
        let ofsy_sub = ref !ofsy in
        if block_sz = 1 then begin
          let subx = Genarray.sub_left x' !ofsx slice_x.(hd) in
          let suby = Genarray.sub_left y' !ofsy_sub slice_x.(hd) in
          Genarray.blit subx suby
        end
        else
          for j = 0 to slice_x.(hd) - 1 do
            let elemx = get x' [| !ofsx + j |] in
            for k = 0 to block_sz - 1 do
              set y' [| !ofsy_sub + k |] elemx
            done;
            ofsy_sub := !ofsy_sub + block_sz
          done;
        ofsx := !ofsx + slice_x.(hd);
        ofsy := !ofsy + (stride_y.(hd - 1) * reps.(hd - 1));
        (* odometer over the outer dimensions: when one wraps, skip the room
           reserved for the copies the second pass will fill in *)
        for j = hd - 1 downto 1 do
          let c = counter.(j) in
          if c + 1 = block_num.(j) then
            ofsy := !ofsy + (stride_y.(j - 1) * (reps.(j - 1) - 1));
          counter.(j) <- (if c + 1 = block_num.(j) then 0 else c + 1)
        done
      done;
      (* Copy the lower dimensions within y *)
      for d = hd - 1 downto 0 do
        let block_num = Array.make (d + 1) 0 in
        for i = 0 to d do
          block_num.(i) <- slice_x.(i) / slice_x.(d + 1)
        done;
        let ofsy = ref 0 in
        let block_sz = stride_y.(d) in
        let counter = Array.make hd 0 in
        for _i = 0 to block_num.(0) - 1 do
          let ofsy_sub = ref (!ofsy + block_sz) in
          for _j = 1 to reps.(d) - 1 do
            let subx = Genarray.sub_left y' !ofsy block_sz in
            let suby = Genarray.sub_left y' !ofsy_sub block_sz in
            Genarray.blit subx suby;
            ofsy_sub := !ofsy_sub + block_sz
          done;
          ofsy := !ofsy + (stride_y.(d) * reps.(d));
          for j = d - 1 downto 0 do
            let c = counter.(j) in
            if c + 1 = block_num.(j + 1) then
              ofsy := !ofsy + (stride_y.(j) * (reps.(j) - 1));
            counter.(j) <- (if c + 1 = block_num.(j + 1) then 0 else c + 1)
          done
        done
      done
    end;
    reshape y' y_shape
  end

(* mathematical functions *)

(* Shared body of the element-wise unary operators below. [f] maps an
   element kind to an element function, like the
   [Owl_base_dense_common._*_elt] helpers. Returns a fresh array. *)
let _map_elt_op f x = map (f (kind x)) x

(* In-place counterpart of [_map_elt_op]: applies [f (kind x)] in place to
   [out], which defaults to [x].
   NOTE(review): when [out] is given, [x] only supplies the element kind and
   its contents are not read; [out] is transformed as it stands. This keeps
   the behaviour of the previous per-function code. Confirm whether copying
   [x] into [out] first was intended. *)
let _map_elt_op_ ?out f x =
  let out = match out with Some o -> o | None -> x in
  map_ (f (kind x)) out

let abs x = _map_elt_op Owl_base_dense_common._abs_elt x

let abs_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._abs_elt x

let conj x = _map_elt_op Owl_base_dense_common._conj_elt x

(* previously used [map]: it built a fresh array and never updated [out] *)
let conj_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._conj_elt x

let neg x = _map_elt_op Owl_base_dense_common._neg_elt x

let neg_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._neg_elt x

let reci x = _map_elt_op Owl_base_dense_common._inv_elt x

let reci_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._inv_elt x

let floor x = _map_elt_op Owl_base_dense_common._floor_elt x

let floor_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._floor_elt x

let ceil x = _map_elt_op Owl_base_dense_common._ceil_elt x

let ceil_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._ceil_elt x

let round x = _map_elt_op Owl_base_dense_common._round_elt x

let round_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._round_elt x

let trunc x = _map_elt_op Owl_base_dense_common._trunc_elt x

let trunc_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._trunc_elt x

let fix x = _map_elt_op Owl_base_dense_common._fix_elt x

let fix_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._fix_elt x

let erf _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erf")

let erf_ ?_out _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erf_")

let erfc _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erfc")

let erfc_ ?_out _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erfc_")

let sqr x = _map_elt_op Owl_base_dense_common._sqr_elt x

let sqr_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._sqr_elt x

let sqrt x = _map_elt_op Owl_base_dense_common._sqrt_elt x

let sqrt_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._sqrt_elt x

(* Cube root computed as [a ** (1/3)] through [_pow_elt].
   NOTE(review): if [_pow_elt] follows the C [pow] convention for real kinds,
   negative inputs give nan rather than a negative root -- confirm. *)
let _cbrt_via_pow k =
  let b = Owl_base_dense_common._float_typ_elt k (1. /. 3.) in
  fun a -> Owl_base_dense_common._pow_elt k a b

let cbrt x = _map_elt_op _cbrt_via_pow x

let cbrt_ ?out x = _map_elt_op_ ?out _cbrt_via_pow x

let log x = _map_elt_op Owl_base_dense_common._log_elt x

(* previously applied the float-only [Scalar.log] instead of [_log_elt] *)
let log_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._log_elt x

let log2 x = _map_elt_op Owl_base_dense_common._log2_elt x

let log2_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._log2_elt x

let log10 x = _map_elt_op Owl_base_dense_common._log10_elt x

let log10_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._log10_elt x

let log1p x = _map_elt_op Owl_base_dense_common._log1p_elt x

let log1p_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._log1p_elt x

let exp x = _map_elt_op Owl_base_dense_common._exp_elt x

let exp_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._exp_elt x

let exp2 x = _map_elt_op Owl_base_dense_common._exp2_elt x

let exp2_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._exp2_elt x

let exp10 x = _map_elt_op Owl_base_dense_common._exp10_elt x

let exp10_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._exp10_elt x

let expm1 x = _map_elt_op Owl_base_dense_common._expm1_elt x

let expm1_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._expm1_elt x

let sin x = _map_elt_op Owl_base_dense_common._sin_elt x

let sin_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._sin_elt x

let cos x = _map_elt_op Owl_base_dense_common._cos_elt x

let cos_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._cos_elt x

let tan x = _map_elt_op Owl_base_dense_common._tan_elt x

let tan_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._tan_elt x

let sinh x = _map_elt_op Owl_base_dense_common._sinh_elt x

let sinh_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._sinh_elt x

let cosh x = _map_elt_op Owl_base_dense_common._cosh_elt x

let cosh_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._cosh_elt x

let tanh x = _map_elt_op Owl_base_dense_common._tanh_elt x

let tanh_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._tanh_elt x

let asin x = _map_elt_op Owl_base_dense_common._asin_elt x

let asin_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._asin_elt x

let acos x = _map_elt_op Owl_base_dense_common._acos_elt x

let acos_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._acos_elt x

let atan x = _map_elt_op Owl_base_dense_common._atan_elt x

let atan_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._atan_elt x

let asinh x = _map_elt_op Owl_base_dense_common._asinh_elt x

let asinh_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._asinh_elt x

let acosh x = _map_elt_op Owl_base_dense_common._acosh_elt x

let acosh_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._acosh_elt x

let atanh x = _map_elt_op Owl_base_dense_common._atanh_elt x

let atanh_ ?out x = _map_elt_op_ ?out Owl_base_dense_common._atanh_elt x

(* [sum_slices ~axis varr] views [varr] as a matrix whose rows run over
   dimensions [0 .. axis] and whose columns run over the remaining ones.
   It sums each column and returns an array of shape
   [dims.(axis + 1) .. dims.(rank - 1)]. Accumulation uses the addition of
   the element kind, so any kind is accepted, not only floats. *)
let sum_slices ?(axis = 0) varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let _kind = kind varr in
  let add = Owl_base_dense_common._add_elt _kind in
  let zero = Owl_const.zero _kind in
  (* reshape into 2d matrix *)
  let num_rows = Array.fold_left ( * ) 1 (Array.sub dims 0 (axis + 1)) in
  let num_cols = numel varr / num_rows in
  let varr_mat = reshape varr [| num_rows; num_cols |] in
  let result_vec = empty _kind [| num_cols |] in
  (* [result_varr] is returned after only [result_vec] is written, i.e. the
     code relies on [reshape] sharing the underlying data *)
  let result_varr = reshape result_vec (Array.sub dims (axis + 1) (rank - axis - 1)) in
  for j = 0 to num_cols - 1 do
    let acc = ref zero in
    for i = 0 to num_rows - 1 do
      acc := add !acc (Genarray.get varr_mat [| i; j |])
    done;
    Genarray.set result_vec [| j |] !acc
  done;
  result_varr

(* -1. for negative numbers, 0 or (-0) for 0,
1 for positive numbers, nan for nan *)
let signum x = map Scalar.signum x

(* in place on [out] when given, otherwise on [x] *)
let signum_ ?out x =
  match out with
  | Some o -> map_ Scalar.signum o
  | None -> map_ Scalar.signum x

(* Apply 1 / (1 + exp (-x)) for each element x *)
let sigmoid x = map Scalar.sigmoid x

let sigmoid_ ?out x =
  match out with
  | Some o -> map_ Scalar.sigmoid o
  | None -> map_ Scalar.sigmoid x

let relu x = map Scalar.relu x

let relu_ ?out x =
  match out with
  | Some o -> map_ Scalar.relu o
  | None -> map_ Scalar.relu x

let softsign x = map Scalar.softsign x

let softsign_ ?out x =
  match out with
  | Some o -> map_ Scalar.softsign o
  | None -> map_ Scalar.softsign x

let softplus x = map Scalar.softplus x

let softplus_ ?out x =
  match out with
  | Some o -> map_ Scalar.softplus o
  | None -> map_ Scalar.softplus x

(* Left fold of [f] over the elements of [varr] in flattened order,
   starting from the accumulator [a]. *)
let _fold_left f a varr =
  let flat = array1_of_genarray (flatten varr) in
  let n = numel varr in
  let rec go acc i =
    if i >= n then acc else go (f acc (Array1.unsafe_get flat i)) (i + 1)
  in
  go a 0

(* Min of all elements in the NDarray; starts from the kind maximum *)
let min' x =
  let k = kind x in
  _fold_left (Owl_base_dense_common._min_elt k) (Owl_base_dense_common._max_val_elt k) x

(* Max of all elements in the NDarray; starts from the kind minimum *)
let max' x =
  let k = kind x in
  _fold_left (Owl_base_dense_common._max_elt k) (Owl_base_dense_common._min_val_elt k) x

(* Sum of all elements *)
let sum' x =
  let k = kind x in
  _fold_left (Owl_base_dense_common._add_elt k) (Owl_const.zero k) x

(* Folding along a specified axis, aka reduction. The
f: function of type 'a -> 'a -> 'a.
m: number of slices.
n: x's slice size.
o: x's strides, also y's slice size.
x: source; ys: shape of destination; nelem: initial value of every destination element (only used when out is not given). Note that o <= n.
*)
let _fold_along ?out f m n o x ys nelem =
  let x = flatten x in
  (* Accumulate into [out] if given; the callers in this file fill it with
     the neutral element first. Otherwise use a fresh array filled with
     [nelem]. *)
  let y =
    match out with
    | Some dst -> flatten dst
    | None -> flatten (create (kind x) ys nelem)
  in
  let idx = ref 0 in
  let idy = ref 0 in
  let incy = ref 0 in
  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let addon = Genarray.get x [| !idx + j |] in
      let pos = !idy + !incy in
      Genarray.set y [| pos |] (f (Genarray.get y [| pos |]) addon);
      (* cycle over the [o] destination slots of the current slice *)
      incy := if !incy + 1 = o then 0 else !incy + 1
    done;
    idx := !idx + n;
    idy := !idy + o
  done;
  reshape y ys

(* Sum along [axis] (shape of the result given by [Owl_utils.reduce_params]),
   or of all elements as a one-element array when [axis] is omitted. *)
let sum ?axis x =
  let _kind = kind x in
  let zero = Owl_const.zero _kind in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    _fold_along (Owl_base_dense_common._add_elt _kind) m n o x s zero
  | None -> create _kind [| 1 |] (sum' x)

(* In-place [sum]: [out] is reset to zero, then receives the result. *)
let sum_ ~out ~axis x =
  let _kind = kind x in
  let zero = Owl_const.zero _kind in
  Genarray.fill out zero;
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    _fold_along ~out (Owl_base_dense_common._add_elt _kind) m n o x s zero |> ignore
  | None -> set (flatten out) [| 0 |] (sum' x)

(* Sum over every axis listed in [axis] (all axes when omitted); the summed
   dimensions are kept with size 1. Runs of consecutive dimensions are merged
   by [Owl_utils.squeeze_continuous_dims] and reduced alternately. *)
let sum_reduce ?axis x =
  let _kind = kind x in
  let rank = num_dims x in
  let zero = Owl_const.zero _kind in
  match axis with
  | Some a ->
    let x_shape = shape x in
    let dims' = Owl_utils.squeeze_continuous_dims x_shape a in
    if Array.length dims' = 1 then create _kind (Array.make rank 1) (sum' x)
    else begin
      let y = ref (reshape x dims') in
      (* merged dimensions alternate between reduced and kept, starting with
         reduced exactly when axis 0 is in [a] *)
      let flag = ref (Array.mem 0 a) in
      for i = 0 to Array.length dims' - 1 do
        if !flag then begin
          let m, n, o, s = Owl_utils.reduce_params i !y in
          y := _fold_along (Owl_base_dense_common._add_elt _kind) m n o !y s zero
        end;
        flag := not !flag
      done;
      let y_shape = Array.copy x_shape in
      Array.iter (fun j -> y_shape.(j) <- 1) a;
      reshape !y y_shape
    end
  | None -> create _kind (Array.make rank 1) (sum' x)

let min ?axis x =
  let _kind = kind x in
  let max_val = Owl_base_dense_common._max_val_elt _kind in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    _fold_along (Owl_base_dense_common._min_elt _kind) m n o x s max_val
  | None -> create _kind [| 1 |] (min' x)

let min_ ~out ~axis x =
  let _kind = kind x in
  let max_val = Owl_base_dense_common._max_val_elt _kind in
  Genarray.fill out max_val;
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    _fold_along ~out (Owl_base_dense_common._min_elt _kind) m n o x s max_val |> ignore
  | None -> set (flatten out) [| 0 |] (min' x)

let max ?axis x =
  let _kind = kind x in
  let min_val = Owl_base_dense_common._min_val_elt _kind in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    _fold_along (Owl_base_dense_common._max_elt _kind) m n o x s min_val
  | None -> create _kind [| 1 |] (max' x)

let max_ ~out ~axis x =
  let _kind = kind x in
  let min_val = Owl_base_dense_common._min_val_elt _kind in
  Genarray.fill out min_val;
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    _fold_along ~out (Owl_base_dense_common._max_elt _kind) m n o x s min_val |> ignore
  | None -> set (flatten out) [| 0 |] (max' x)

let l1norm' varr =
  _fold_left (fun aggregate elem -> aggregate +. Scalar.abs elem) 0. varr

let l2norm_sqr' varr =
  _fold_left (fun aggregate elem -> aggregate +. (elem *. elem)) 0. varr

let l2norm' varr = Scalar.sqrt (l2norm_sqr' varr)

(* Broadcasting element-wise binary operation: writes
   [op_fun a_elt b_elt] for every index of the broadcast shape into [out]
   (or into a fresh array of the kind of [varr_a]) and returns it. *)
let _broadcasted_op ?out varr_a varr_b op_fun =
  let dims_a, dims_b, dims_c = _get_broadcasted_dims (shape varr_a) (shape varr_b) in
  let _kind = kind varr_a in
  let varr_a = reshape varr_a dims_a in
  let varr_b = reshape varr_b dims_b in
  let varr_c = match out with Some out -> out | None -> empty _kind dims_c in
  (* A zero-length dimension means there is nothing to visit; without this
     guard the first [Genarray.set] below would be out of bounds. *)
  if Array.fold_left ( * ) 1 dims_c > 0 then begin
    let ind = Array.make (Array.length dims_c) 0 in
    let should_stop = ref false in
    while not !should_stop do
      let ind_a = _get_broadcasted_index ind dims_a in
      let ind_b = _get_broadcasted_index ind dims_b in
      Genarray.set varr_c ind
        (op_fun (Genarray.get varr_a ind_a) (Genarray.get varr_b ind_b));
      if not (_next_index ind dims_c) then should_stop := true
    done
  end;
  varr_c

(* Shared body of the in-place broadcasting binary operators: checks that
   [out] (default: [x]) has exactly the broadcast shape of [x] and [y]
   (raising [Owl_exception.DIFFERENT_SHAPE] otherwise), then writes the
   result into it. *)
let _broadcasted_op_ ?out x y op_fun =
  let out = match out with Some o -> o | None -> x in
  let so = Owl_utils_infer_shape.broadcast1 (shape x) (shape y) in
  let st = shape out in
  Owl_exception.check (so = st) (Owl_exception.DIFFERENT_SHAPE (so, st));
  _broadcasted_op ~out x y op_fun |> ignore

let add x y = _broadcasted_op x y (Owl_base_dense_common._add_elt (kind x))

let add_ ?out x y = _broadcasted_op_ ?out x y (Owl_base_dense_common._add_elt (kind x))

let sub x y = _broadcasted_op x y (Owl_base_dense_common._sub_elt (kind x))

let sub_ ?out x y = _broadcasted_op_ ?out x y (Owl_base_dense_common._sub_elt (kind x))

let mul x y = _broadcasted_op x y (Owl_base_dense_common._mul_elt (kind x))

let mul_ ?out x y = _broadcasted_op_ ?out x y (Owl_base_dense_common._mul_elt (kind x))

let div x y = _broadcasted_op x y (Owl_base_dense_common._div_elt (kind x))

let div_ ?out x y = _broadcasted_op_ ?out x y (Owl_base_dense_common._div_elt (kind x))

let atan2 x y = _broadcasted_op x y Scalar.atan2

(* previously computed into a fresh array and discarded it ([~out] was not
   forwarded), leaving [out] untouched; the same held for hypot_ and fmod_ *)
let atan2_ ?out x y = _broadcasted_op_ ?out x y Scalar.atan2

let hypot x y = _broadcasted_op x y Scalar.hypot

let hypot_ ?out x y = _broadcasted_op_ ?out x y Scalar.hypot

let pow x y = _broadcasted_op x y (Owl_base_dense_common._pow_elt (kind x))

let pow_ ?out x y = _broadcasted_op_ ?out x y (Owl_base_dense_common._pow_elt (kind x))

let fmod x y = _broadcasted_op x y Scalar.fmod

let fmod_ ?out x y = _broadcasted_op_ ?out x y Scalar.fmod

let min2 x y = _broadcasted_op x y (Owl_base_dense_common._min_elt (kind x))

let min2_ ?out x y = _broadcasted_op_ ?out x y (Owl_base_dense_common._min_elt (kind x))

let max2 x y = _broadcasted_op x y (Owl_base_dense_common._max_elt (kind x))

let max2_ ?out x y = _broadcasted_op_ ?out x y (Owl_base_dense_common._max_elt (kind x))

(* Array-scalar operators. The in-place variants update [out] (default: [x])
   in place.
   NOTE(review): as with the unary in-place ops, when [out] is given its
   current contents are used and [x] only supplies the element kind. *)

let add_scalar x a =
  let op = Owl_base_dense_common._add_elt (kind x) in
  map (fun y -> op y a) x

let add_scalar_ ?out x a =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._add_elt (kind x) in
  map_ (fun y -> op y a) out

let sub_scalar x a =
  let op = Owl_base_dense_common._sub_elt (kind x) in
  map (fun y -> op y a) x

let sub_scalar_ ?out x a =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._sub_elt (kind x) in
  map_ (fun y -> op y a) out

let mul_scalar x a =
  let op = Owl_base_dense_common._mul_elt (kind x) in
  map (fun y -> op y a) x

let mul_scalar_ ?out x a =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._mul_elt (kind x) in
  map_ (fun y -> op y a) out

let div_scalar x a =
  let op = Owl_base_dense_common._div_elt (kind x) in
  map (fun y -> op y a) x

let div_scalar_ ?out x a =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._div_elt (kind x) in
  map_ (fun y -> op y a) out

let pow_scalar x a =
  let op = Owl_base_dense_common._pow_elt (kind x) in
  map (fun y -> op y a) x

let pow_scalar_ ?out x a =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._pow_elt (kind x) in
  map_ (fun y -> op y a) out

let atan2_scalar x a = map (fun y -> Scalar.atan2 y a) x

let atan2_scalar_ ?out x a =
  let out = match out with Some o -> o | None -> x in
  map_ (fun y -> Scalar.atan2 y a) out

let fmod_scalar x a = map (fun y -> Scalar.fmod y a) x

let fmod_scalar_ ?out x a =
  let out = match out with Some o -> o | None -> x in
  map_ (fun y -> Scalar.fmod y a) out

(* TODO *)
let fma _x _y _z = failwith "Owl_base_dense_ndarray_generic:fma: not implemented"

(* Scalar-array operators: the scalar is the left operand. *)

let scalar_add a x =
  let op = Owl_base_dense_common._add_elt (kind x) in
  map (fun y -> op a y) x

let scalar_add_ ?out a x =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._add_elt (kind x) in
  map_ (fun y -> op a y) out

let scalar_sub a x =
  let op = Owl_base_dense_common._sub_elt (kind x) in
  map (fun y -> op a y) x

let scalar_sub_ ?out a x =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._sub_elt (kind x) in
  map_ (fun y -> op a y) out

let scalar_mul a x =
  let op = Owl_base_dense_common._mul_elt (kind x) in
  map (fun y -> op a y) x

let scalar_mul_ ?out a x =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._mul_elt (kind x) in
  map_ (fun y -> op a y) out

let scalar_div a x =
  let op = Owl_base_dense_common._div_elt (kind x) in
  map (fun y -> op a y) x

let scalar_div_ ?out a x =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._div_elt (kind x) in
  map_ (fun y -> op a y) out

let scalar_pow a x =
  let op = Owl_base_dense_common._pow_elt (kind x) in
  map (fun y -> op a y) x

let scalar_pow_ ?out a x =
  let out = match out with Some o -> o | None -> x in
  let op = Owl_base_dense_common._pow_elt (kind x) in
  map_ (fun y -> op a y) out

let scalar_atan2 a x = map (fun y -> Scalar.atan2 a y) x

let scalar_atan2_ ?out a x =
  let out = match out with Some o -> o | None -> x in
  map_ (fun y -> Scalar.atan2 a y) out

let scalar_fmod a x = map (fun y -> Scalar.fmod a y) x

let scalar_fmod_ ?out a x =
  let out = match out with Some o -> o | None -> x in
  map_ (fun y -> Scalar.fmod a y) out

(* Clamps every element to [amin, amax].
   NOTE(review): the default [Stdlib.min_float] is the smallest POSITIVE
   normalised float, so by default every negative element is raised to about
   2.2e-308. [-. Stdlib.max_float] (or [neg_infinity]) looks intended, but
   changing a default alters the interface -- confirm before changing. *)
let clip_by_value ?(amin = Stdlib.min_float) ?(amax = Stdlib.max_float) x =
  map (fun y -> Stdlib.min amax (Stdlib.max amin y)) x

(* Rescales [x] so that its L2 norm is at most [clip_norm]. *)
let clip_by_l2norm clip_norm x =
  let l2norm_val = l2norm' x in
  if l2norm_val > clip_norm then mul_scalar x (clip_norm /. l2norm_val) else x

(* Softmax along [axis] (default: the last dimension) on a copy of [x]. *)
let softmax ?(axis = (-1)) x =
  let x = copy x in
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  (* subtracting the max keeps [exp] from overflowing *)
  sub_ ~out:x x (max ~axis x);
  exp_ ~out:x x;
  let a = sum ~axis x in
  div_ ~out:x x a;
  x
(* In-place softmax along [axis] (default: the last dimension; negative
   values count from the end). The result goes to [out], which defaults to
   [x] and must have the shape of [x]; [x] itself is only read. *)
let softmax_ ?out ?(axis = (-1)) x =
  let out = match out with Some o -> o | None -> x in
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  (* out <- x - max x along [axis] (broadcast); keeps [exp] from overflowing *)
  sub_ ~out x (max ~axis x);
  (* Every later step works on [out]. The previous version summed and
     divided [x] again, so whenever [out != x] it returned x / sum(exp x)
     instead of the softmax. *)
  exp_ ~out out;
  let a = sum ~axis out in
  div_ ~out out a

(* Comparison functions *)

(** Return true if for all elements comp_fun (xa, xb) == true, false otherwise.
Returns false as soon as it finds a counterexample. (NOT broadcasted)
    NOTE(review): only the element counts are compared, not the shapes, so two
    arrays of different shape but equal size are compared in flattened order. *)
let _compare_util_shortcircuit varr_a varr_b comp_fun =
  let n = numel varr_a in
  if n <> numel varr_b then false
  else begin
    let flat_a = array1_of_genarray (flatten varr_a) in
    let flat_b = array1_of_genarray (flatten varr_b) in
    (* [holds_from i]: comp_fun holds for every position from i to n - 1 *)
    let rec holds_from i =
      i >= n
      || (comp_fun (Array1.unsafe_get flat_a i) (Array1.unsafe_get flat_b i)
          && holds_from (i + 1))
    in
    holds_from 0
  end

(* Element-wise |a - b| < eps for all elements; eps defaults to the Float32
   epsilon from [Owl_utils.eps]. *)
let approx_equal ?eps varr_a varr_b =
  let tolerance = match eps with Some e -> e | None -> Owl_utils.eps Float32 in
  _compare_util_shortcircuit varr_a varr_b (fun a b ->
      Scalar.abs (Scalar.sub a b) < tolerance)

let equal x y = _compare_util_shortcircuit x y Stdlib.( = )

let not_equal x y = _compare_util_shortcircuit x y Stdlib.( <> )

let less x y = _compare_util_shortcircuit x y Stdlib.( < )

let greater x y = _compare_util_shortcircuit x y Stdlib.( > )

let less_equal x y = _compare_util_shortcircuit x y Stdlib.( <= )

let greater_equal x y = _compare_util_shortcircuit x y Stdlib.( >= )

(** Return true if for all elements of a comp_fun (xa, b) == true, false otherwise.
Returns false as soon as it finds a counterexample. (NOT broadcasted) *)
let _compare_util_shortcircuit_scalar varr_a b comp_fun =
  let n = numel varr_a in
  let flat = array1_of_genarray (flatten varr_a) in
  (* [holds_from i]: comp_fun holds against [b] from position i onwards *)
  let rec holds_from i =
    i >= n || (comp_fun (Array1.unsafe_get flat i) b && holds_from (i + 1))
  in
  holds_from 0

(* |elt - b| < eps for every element; eps defaults to the Float32 epsilon
   from [Owl_utils.eps]. *)
let approx_equal_scalar ?eps varr_a b =
  let tolerance = match eps with Some e -> e | None -> Owl_utils.eps Float32 in
  _compare_util_shortcircuit_scalar varr_a b (fun x y ->
      Scalar.abs (Scalar.sub x y) < tolerance)

let equal_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( = )

let not_equal_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( <> )

let less_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( < )

let greater_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( > )

let less_equal_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( <= )

let greater_equal_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( >= )

(* Broadcasted operation, return an array with values of 1
if (cmp_fun elem_from_a elem_from_b) == true, 0 otherwise *)
let _make_elt_compare_fun kind cmp_fun =
  let c0 = Owl_const.zero kind in
  let c1 = Owl_const.one kind in
  fun a b -> if cmp_fun a b then c1 else c0

(* Shared body of the in-place element-wise comparisons. [out] defaults to
   [x] and must have exactly the broadcast shape of [x] and [y], the same
   check the in-place arithmetic operators do. A mismatch raises
   [Owl_exception.DIFFERENT_SHAPE] instead of failing halfway through (or
   silently filling only part of a larger [out]). Writes the 0/1 result
   into [out] and returns it. *)
let _broadcasted_cmp_ ?out x y cmp_fun =
  let out = match out with Some o -> o | None -> x in
  let so = Owl_utils_infer_shape.broadcast1 (shape x) (shape y) in
  let st = shape out in
  Owl_exception.check (so = st) (Owl_exception.DIFFERENT_SHAPE (so, st));
  _broadcasted_op ~out x y (_make_elt_compare_fun (kind x) cmp_fun)

let elt_equal x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( = ))

let elt_equal_ ?out x y = _broadcasted_cmp_ ?out x y Stdlib.( = )

(* 1 where |x - y| < eps (default: Float32 epsilon), 0 elsewhere *)
let approx_elt_equal ?eps x y =
  let eps = match eps with Some eps -> eps | None -> Owl_utils.eps Float32 in
  let approx_equal_fun a b = Scalar.abs (Scalar.sub a b) < eps in
  _broadcasted_op x y (_make_elt_compare_fun (kind x) approx_equal_fun)

let elt_not_equal x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( <> ))

let elt_not_equal_ ?out x y = _broadcasted_cmp_ ?out x y Stdlib.( <> )

let elt_less x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( < ))

let elt_less_ ?out x y = _broadcasted_cmp_ ?out x y Stdlib.( < )

let elt_greater x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( > ))

let elt_greater_ ?out x y = _broadcasted_cmp_ ?out x y Stdlib.( > )

let elt_less_equal x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( <= ))

let elt_less_equal_ ?out x y = _broadcasted_cmp_ ?out x y Stdlib.( <= )

let elt_greater_equal x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( >= ))

let elt_greater_equal_ ?out x y = _broadcasted_cmp_ ?out x y Stdlib.( >= )

(* Util function, return an array with values of 1
if (cmp_fun elem_from_a b) == true, 0 otherwise *)
let _make_elt_compare_scalar x cmp_fun =
  let _kind = kind x in
  let c0 = Owl_const.zero _kind in
  let c1 = Owl_const.one _kind in
  fun a -> if cmp_fun a then c1 else c0

(* Shared body of the in-place scalar comparisons: maps [out] (default: [x])
   in place to the 0/1 value of [cmp_fun].
   NOTE(review): when [out] is given its current contents are compared, not
   those of [x]; this is the same convention as the in-place unary ops. *)
let _map_scalar_cmp_ ?out x cmp_fun =
  let out = match out with Some o -> o | None -> x in
  map_ (_make_elt_compare_scalar x cmp_fun) out

let elt_equal_scalar x a = map (_make_elt_compare_scalar x (fun y -> y = a)) x

let elt_equal_scalar_ ?out x a = _map_scalar_cmp_ ?out x (fun y -> y = a)

(* 1 where |elt - a| < eps (default: Float32 epsilon), 0 elsewhere *)
let approx_elt_equal_scalar ?eps x a =
  let eps = match eps with Some eps -> eps | None -> Owl_utils.eps Float32 in
  let cmp_fun y = Scalar.abs (Scalar.sub y a) < eps in
  map (_make_elt_compare_scalar x cmp_fun) x

let elt_not_equal_scalar x a = map (_make_elt_compare_scalar x (fun y -> y <> a)) x

let elt_not_equal_scalar_ ?out x a = _map_scalar_cmp_ ?out x (fun y -> y <> a)

let elt_less_scalar x a = map (_make_elt_compare_scalar x (fun y -> y < a)) x

let elt_less_scalar_ ?out x a = _map_scalar_cmp_ ?out x (fun y -> y < a)

let elt_greater_scalar x a = map (_make_elt_compare_scalar x (fun y -> y > a)) x

let elt_greater_scalar_ ?out x a = _map_scalar_cmp_ ?out x (fun y -> y > a)

let elt_less_equal_scalar x a = map (_make_elt_compare_scalar x (fun y -> y <= a)) x

let elt_less_equal_scalar_ ?out x a = _map_scalar_cmp_ ?out x (fun y -> y <= a)

let elt_greater_equal_scalar x a = map (_make_elt_compare_scalar x (fun y -> y >= a)) x

(* previously used [map], so it returned a fresh array and [out] was never
   updated, unlike every other in-place comparison *)
let elt_greater_equal_scalar_ ?out x a = _map_scalar_cmp_ ?out x (fun y -> y >= a)

(* true as soon as [f] holds for some element (scanned in flattened order) *)
let exists f x =
  let n = numel x in
  let flat = array1_of_genarray (flatten x) in
  let rec search i = i < n && (f (Array1.unsafe_get flat i) || search (i + 1)) in
  search 0

let not_exists f varr = not (exists f varr)

let for_all f varr = not_exists (fun x -> not (f x)) varr

let is_zero varr =
  let c0 = Owl_const.zero (kind varr) in
  not_exists (fun x -> x <> c0) varr

let is_positive varr =
  let c0 = Owl_const.zero (kind varr) in
  not_exists (fun x -> x <= c0) varr

let is_negative varr =
  let c0 = Owl_const.zero (kind varr) in
  not_exists (fun x -> x >= c0) varr

let is_nonpositive varr =
  let c0 = Owl_const.zero (kind varr) in
  not_exists (fun x -> x > c0) varr

let is_nonnegative varr =
  let c0 = Owl_const.zero (kind varr) in
  not_exists (fun x -> x < c0) varr

let is_normal x = for_all (Owl_base_dense_common._is_normal_elt (kind x)) x

let not_nan x = not_exists (Owl_base_dense_common._is_nan_elt (kind x)) x

let not_inf x = not_exists (Owl_base_dense_common._is_inf_elt (kind x)) x

(* Neural network related functions *)

(*TODO: optimise *)
(* conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)
let conv2d ?(padding = SAME) input kernel stride =
  (* ranks of input/kernel and length of stride must match the layout above *)
  Owl_exception.verify
    (num_dims input = 4 && num_dims kernel = 4 && Array.length stride = 2)
    (fun () ->
      let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
      let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
      let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "conv2d: %s; %s; %s." s0 s1 s2));
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let batches = input_shp.(0)
  and input_cols = input_shp.(1)
  and input_rows = input_shp.(2)
  and in_channel = input_shp.(3) in
  let kernel_cols = kernel_shp.(0)
  and kernel_rows = kernel_shp.(1)
  and out_channel = kernel_shp.(3) in
  (* input channels must agree between input and kernel *)
  Owl_exception.verify (in_channel = kernel_shp.(2)) (fun () ->
      let s1 =
        Printf.sprintf "input shape is [%s]" (Owl_utils_array.to_string string_of_int input_shp)
      in
      let s3 =
        Printf.sprintf "kernel shape is [%s]" (Owl_utils_array.to_string string_of_int kernel_shp)
      in
      let s4 =
        "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"
      in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "conv2d: %s, %s, %s." s1 s3 s4));
  let col_stride = stride.(0)
  and row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding input_cols input_rows
      kernel_cols kernel_rows row_stride col_stride
  in
  let output = empty (kind input) [| batches; output_cols; output_rows; out_channel |] in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  (* direct (naive) convolution: one dot product per output element *)
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for k = 0 to out_channel - 1 do
          let acc = ref 0. in
          for di = 0 to kernel_cols - 1 do
            let in_col = (i * col_stride) + di - pad_left in
            for dj = 0 to kernel_rows - 1 do
              let in_row = (j * row_stride) + dj - pad_top in
              let inside =
                0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              in
              for q = 0 to in_channel - 1 do
                (* padded taps read as 0. but are still multiplied in *)
                let in_val = if inside then get input [| b; in_col; in_row; q |] else 0. in
                acc := !acc +. (in_val *. get kernel [| di; dj; q; k |])
              done
            done
          done;
          set output [| b; i; j; k |] !acc
        done
      done
    done
  done;
  output

(* conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
(* 1D convolution computed by conv2d on a single-column view of the data.
   The input is reshaped to [batch; 1; input_column; in_channel], the kernel
   to [1; kernel_column; in_channel; out_channel] and the stride to
   [1; column_stride]. The 4d result is reshaped back to
   [batch; output_column; out_channel].
   Raises Owl_exception.INVALID_ARGUMENT when the ranks are not 3/3/1 or when
   input dimension 2 differs from kernel dimension 1. *)
let conv1d ?(padding=SAME) input kernel stride =
  Owl_exception.verify
    (num_dims input = 3 && num_dims kernel = 3 && Array.length stride = 1)
    (fun () ->
      let why_input = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
      let why_kernel = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
      let why_stride = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "conv1d: %s; %s; %s." why_input why_kernel why_stride));
  let in_shape = shape input in
  let k_shape = shape kernel in
  let n_batch = in_shape.(0) in
  let in_cols = in_shape.(1) in
  let in_chan = in_shape.(2) in
  let k_cols = k_shape.(0) in
  let out_chan = k_shape.(2) in
  Owl_exception.verify (in_chan = k_shape.(1)) (fun () ->
      let in_str = Owl_utils_array.to_string string_of_int in_shape in
      let k_str = Owl_utils_array.to_string string_of_int k_shape in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "conv1d: %s, %s, %s."
           (Printf.sprintf "input shape is [%s]" in_str)
           (Printf.sprintf "kernel shape is [%s]" k_str)
           "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"));
  let as_2d = reshape input [|n_batch; 1; in_cols; in_chan|] in
  let k_2d = reshape kernel [|1; k_cols; in_chan; out_chan|] in
  let y = conv2d ~padding as_2d k_2d [|1; stride.(0)|] in
  reshape y [|n_batch; (shape y).(2); out_chan|]
(* TODO: optimise *)
(* conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)
(* Naive direct 3D convolution, same scheme as conv2d with an extra depth
   axis. Each output cell sums input value times kernel value over the kernel
   window and the input channels; window positions outside the input read
   as 0.
   Raises Owl_exception.INVALID_ARGUMENT when the ranks are not 5/5/3 or when
   input dimension 4 differs from kernel dimension 3. *)
let conv3d ?(padding=SAME) input kernel stride =
  Owl_exception.verify
    (num_dims input = 5 && num_dims kernel = 5 && Array.length stride = 3)
    (fun () ->
      let why_input = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
      let why_kernel = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
      let why_stride = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "conv3d: %s; %s; %s." why_input why_kernel why_stride));
  let in_shape = shape input in
  let k_shape = shape kernel in
  let n_batch = in_shape.(0) in
  let in_cols = in_shape.(1) in
  let in_rows = in_shape.(2) in
  let in_dpts = in_shape.(3) in
  let in_chan = in_shape.(4) in
  let k_cols = k_shape.(0) in
  let k_rows = k_shape.(1) in
  let k_dpts = k_shape.(2) in
  let out_chan = k_shape.(4) in
  Owl_exception.verify (in_chan = k_shape.(3)) (fun () ->
      let in_str = Owl_utils_array.to_string string_of_int in_shape in
      let k_str = Owl_utils_array.to_string string_of_int k_shape in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "conv3d: %s, %s, %s."
           (Printf.sprintf "input shape is [%s]" in_str)
           (Printf.sprintf "kernel shape is [%s]" k_str)
           "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"));
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let out_cols, out_rows, out_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding in_cols in_rows in_dpts
      k_cols k_rows k_dpts row_stride col_stride dpt_stride
  in
  let result =
    empty (kind input) [|n_batch; out_cols; out_rows; out_dpts; out_chan|]
  in
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding in_cols in_rows in_dpts
      k_cols k_rows k_dpts out_cols out_rows out_dpts
      row_stride col_stride dpt_stride
  in
  let inside c r d =
    0 <= c && c < in_cols && 0 <= r && r < in_rows && 0 <= d && d < in_dpts
  in
  for b = 0 to n_batch - 1 do
    for oc = 0 to out_cols - 1 do
      for orow = 0 to out_rows - 1 do
        for od = 0 to out_dpts - 1 do
          for k = 0 to out_chan - 1 do
            let acc = ref 0. in
            for dc = 0 to k_cols - 1 do
              for dr = 0 to k_rows - 1 do
                for dd = 0 to k_dpts - 1 do
                  let c = (oc * col_stride) + dc - pad_left in
                  let r = (orow * row_stride) + dr - pad_top in
                  let d = (od * dpt_stride) + dd - pad_shallow in
                  let in_bounds = inside c r d in
                  for q = 0 to in_chan - 1 do
                    let x = if in_bounds then get input [|b; c; r; d; q|] else 0. in
                    acc := !acc +. (x *. get kernel [|dc; dr; dd; q; k|])
                  done
                done
              done
            done;
            set result [|b; oc; orow; od; k|] !acc
          done
        done
      done
    done
  done;
  result

(* General function for avg_pool2d and max_pool2d.
   Slides a [kernel_column; kernel_row] window over an input of shape
   [batch; input_column; input_row; channel] and writes one value per
   (batch, window position, channel); the channel count is preserved.
   For every output cell the callbacks are driven as follows:
   init_pool_fun () once, then add_val_pool_fun v for every window value
   lying inside the input (padded positions are skipped, not passed as
   zeros), then end_pool_fun (), whose result is stored.
   Raises Owl_exception.INVALID_ARGUMENT when the input is not 4d or when
   kernel or stride does not have length 2. *)
let _pool2d ?(padding=SAME) input kernel stride init_pool_fun add_val_pool_fun end_pool_fun =
  Owl_exception.verify
    (num_dims input = 4 && Array.length kernel = 2 && Array.length stride = 2)
    (fun () ->
      let why_input = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
      let why_kernel = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
      let why_stride = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "_pool2d: %s; %s; %s." why_input why_kernel why_stride));
  let in_shape = shape input in
  let n_batch = in_shape.(0) in
  let in_cols = in_shape.(1) in
  let in_rows = in_shape.(2) in
  let n_chan = in_shape.(3) in
  let k_cols = kernel.(0) in
  let k_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let out_cols, out_rows =
    Owl_utils_infer_shape.calc_conv2d_output_shape padding in_cols in_rows
      k_cols k_rows row_stride col_stride
  in
  let result = empty (kind input) [|n_batch; out_cols; out_rows; n_chan|] in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding in_cols in_rows k_cols k_rows
      out_cols out_rows row_stride col_stride
  in
  for b = 0 to n_batch - 1 do
    for oc = 0 to out_cols - 1 do
      for orow = 0 to out_rows - 1 do
        for k = 0 to n_chan - 1 do
          init_pool_fun ();
          for dc = 0 to k_cols - 1 do
            let c = (oc * col_stride) + dc - pad_left in
            (* a column in the padding contributes nothing for any row *)
            if 0 <= c && c < in_cols then
              for dr = 0 to k_rows - 1 do
                let r = (orow * row_stride) + dr - pad_top in
                if 0 <= r && r < in_rows then add_val_pool_fun (get input [|b; c; r; k|])
              done
          done;
          set result [|b; oc; orow; k|] (end_pool_fun ())
        done
      done
    done
  done;
  result

(* General function for avg_pool3d and max_pool3d; 3D analogue of _pool2d.
   Input is [batch; input_column; input_row; input_depth; channel], kernel and
   stride have length 3. The callbacks are driven exactly as in _pool2d, and
   values outside the input are skipped.
   Raises Owl_exception.INVALID_ARGUMENT when the input is not 5d or when
   kernel or stride does not have length 3. *)
let _pool3d ?(padding=SAME) input kernel stride init_pool_fun add_val_pool_fun end_pool_fun =
  Owl_exception.verify
    (num_dims input = 5 && Array.length kernel = 3 && Array.length stride = 3)
    (fun () ->
      let why_input = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
      let why_kernel = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
      let why_stride = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "_pool3d: %s; %s; %s." why_input why_kernel why_stride));
  let in_shape = shape input in
  let n_batch = in_shape.(0) in
  let in_cols = in_shape.(1) in
  let in_rows = in_shape.(2) in
  let in_dpts = in_shape.(3) in
  let n_chan = in_shape.(4) in
  let k_cols = kernel.(0) in
  let k_rows = kernel.(1) in
  let k_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let out_cols, out_rows, out_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding in_cols in_rows in_dpts
      k_cols k_rows k_dpts row_stride col_stride dpt_stride
  in
  let result =
    empty (kind input) [|n_batch; out_cols; out_rows; out_dpts; n_chan|]
  in
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding in_cols in_rows in_dpts
      k_cols k_rows k_dpts out_cols out_rows out_dpts
      row_stride col_stride dpt_stride
  in
  for b = 0 to n_batch - 1 do
    for oc = 0 to out_cols - 1 do
      for orow = 0 to out_rows - 1 do
        for od = 0 to out_dpts - 1 do
          for k = 0 to n_chan - 1 do
            init_pool_fun ();
            for dc = 0 to k_cols - 1 do
              let c = (oc * col_stride) + dc - pad_left in
              if 0 <= c && c < in_cols then
                for dr = 0 to k_rows - 1 do
                  let r = (orow * row_stride) + dr - pad_top in
                  if 0 <= r && r < in_rows then
                    for dd = 0 to k_dpts - 1 do
                      let d = (od * dpt_stride) + dd - pad_shallow in
                      if 0 <= d && d < in_dpts then
                        add_val_pool_fun (get input [|b; c; r; d; k|])
                    done
                done
            done;
            set result [|b; oc; orow; od; k|] (end_pool_fun ())
          done
        done
      done
    done
  done;
  result
(* max_pool2d: 4d input and 2d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; input_channel]
*)
(* Max pooling over 2D windows, built on _pool2d, which feeds only in-bounds
   values to the accumulator.
   The running maximum is reset to neg_infinity, the identity element of max.
   It used to start from Stdlib.min_float, which is the smallest POSITIVE
   normal float, not the most negative one. Any window containing only
   negative values therefore pooled to about 2.2e-308 instead of its true
   maximum. *)
let max_pool2d ?(padding=SAME) input kernel stride =
  let max_pool = ref Stdlib.neg_infinity in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  _pool2d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun
(* max_pool1d: 3d input and 1d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column]
stride: [column_stride]
output: [batch; output_column; input_channel]
*)
(* 1D max pooling computed by max_pool2d on a single-column view of the data.
   The input becomes [batch; 1; input_column; channel], the kernel
   [1; kernel_column] and the stride [1; column_stride]. The pooled result is
   reshaped back to [batch; output_column; channel].
   Raises Owl_exception.INVALID_ARGUMENT when the input is not 3d or when
   kernel or stride does not have length 1. *)
let max_pool1d ?(padding=SAME) input kernel stride =
  Owl_exception.verify
    (num_dims input = 3 && Array.length kernel = 1 && Array.length stride = 1)
    (fun () ->
      let why_input = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
      let why_kernel = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
      let why_stride = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "max_pool1d: %s; %s; %s." why_input why_kernel why_stride));
  let in_shape = shape input in
  let n_batch = in_shape.(0) in
  let n_chan = in_shape.(2) in
  let as_2d = reshape input [|n_batch; 1; in_shape.(1); n_chan|] in
  let pooled = max_pool2d ~padding as_2d [|1; kernel.(0)|] [|1; stride.(0)|] in
  reshape pooled [|n_batch; (shape pooled).(2); n_chan|]
(* max_pool3d: 5d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; input_channel]
*)
(* Max pooling over 3D windows, built on _pool3d, which feeds only in-bounds
   values to the accumulator.
   The running maximum is reset to neg_infinity, the identity element of max.
   It used to start from Stdlib.min_float, the smallest POSITIVE normal float,
   so all-negative windows pooled to about 2.2e-308 instead of their maximum. *)
let max_pool3d ?(padding=SAME) input kernel stride =
  let max_pool = ref Stdlib.neg_infinity in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  _pool3d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun

(* similar to max_pool2d: average pooling over 2D windows via _pool2d.
   The sum is divided by the number of values actually added, and _pool2d
   adds only in-bounds values. Padded positions therefore affect neither the
   sum nor the count. *)
let avg_pool2d ?(padding=SAME) input kernel stride =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () = (sum_pool := 0.; cnt := 0.) in
  let add_val_pool_fun v = (sum_pool := !sum_pool +. v; cnt := !cnt +. 1.) in
  let end_pool_fun () = !sum_pool /. !cnt in
  _pool2d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun

(* similar to max_pool1d: 1D average pooling computed by avg_pool2d on a
   single-column view. The input becomes [batch; 1; input_column; channel],
   the kernel [1; kernel_column] and the stride [1; column_stride]; the result
   is reshaped back to [batch; output_column; channel].
   Raises Owl_exception.INVALID_ARGUMENT when the input is not 3d or when
   kernel or stride does not have length 1. *)
let avg_pool1d ?(padding=SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in
  let kernel = [|1; kernel.(0)|] in
  let stride = [|1; stride.(0)|] in
  let output = avg_pool2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [|batches; output_cols; in_channel|]

(* similar to max_pool3d: average pooling over 3D windows via _pool3d.
   Only in-bounds values are summed and counted. *)
let avg_pool3d ?(padding=SAME) input kernel stride =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () = (sum_pool := 0.; cnt := 0.) in
  let add_val_pool_fun v = (sum_pool := !sum_pool +. v; cnt := !cnt +. 1.) in
  let end_pool_fun () = !sum_pool /. !cnt in
  _pool3d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun
(* TODO: optimise *)
(* gradient
of conv2d w.r.t the input (conv2d_backward_input, below) *)

(* Shared argument validation for conv2d_backward_input and
   conv2d_backward_kernel; [fname] prefixes every error message.
   The checks run in this order:
   1. ranks 4/4/4 and stride length 2;
   2. input dimension 3 equals kernel dimension 2;
   3. output' dimension 0 equals input dimension 0, and output' dimension 3
      equals kernel dimension 3.
   Raises Owl_exception.INVALID_ARGUMENT for the first failing group. *)
let _check_conv2d_backward_args fname input kernel stride output' =
  let ranks_ok =
    num_dims input = 4 && num_dims kernel = 4 && num_dims output' = 4
    && Array.length stride = 2
  in
  Owl_exception.verify ranks_ok (fun () ->
      let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
      let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
      let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
      let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "%s: %s; %s; %s; %s." fname s0 s1 s2 s3));
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let input_str () = Owl_utils_array.to_string string_of_int input_shp in
  let kernel_str () = Owl_utils_array.to_string string_of_int kernel_shp in
  Owl_exception.verify (input_shp.(3) = kernel_shp.(2)) (fun () ->
      let s1 = Printf.sprintf "input shape is [%s]" (input_str ()) in
      let s3 = Printf.sprintf "kernel shape is [%s]" (kernel_str ()) in
      let s4 = "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape" in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "%s: %s, %s, %s." fname s1 s3 s4));
  Owl_exception.verify
    (input_shp.(0) = output_shp.(0) && kernel_shp.(3) = output_shp.(3))
    (fun () ->
      let s3 = Printf.sprintf "input shape is [%s]" (input_str ()) in
      let s4 =
        Printf.sprintf "output' shape is [%s]" (Owl_utils_array.to_string string_of_int output_shp)
      in
      let s5 = Printf.sprintf "kernel shape is [%s]" (kernel_str ()) in
      let s6 = "the 1st dimension of input shape should be equal to the 1st dimension of output' shape" in
      let s7 = "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape" in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s3 s4 s5 s6 s7))

(* Gradient of conv2d with respect to its input.
   [input] is read only for its shape and kind; [output'] is the gradient
   flowing into the conv2d output, laid out as
   [batch; output_column; output_row; out_channel]. Returns a fresh array
   shaped like [input]. Raises Owl_exception.INVALID_ARGUMENT on rank or
   shape mismatch. *)
let conv2d_backward_input input kernel stride output' =
  _check_conv2d_backward_args "conv2d_backward_input" input kernel stride output';
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let input' = empty (kind input) input_shp in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols kernel_rows
      output_cols output_rows row_stride col_stride
  in
  for b = 0 to batches - 1 do
    for in_i = 0 to input_cols - 1 do
      for in_j = 0 to input_rows - 1 do
        for q = 0 to in_channel - 1 do
          let sum = ref 0. in
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              (* In the forward pass, input (in_i, in_j) meets kernel tap
                 (di, dj) at output (out_col, out_row) iff
                 out_col * col_stride + di - pad_left = in_i, and likewise
                 for rows; so the offsets must be multiples of the stride. *)
              let col_off = in_i + pad_left - di in
              let row_off = in_j + pad_top - dj in
              if Stdlib.(mod) col_off col_stride = 0 && Stdlib.(mod) row_off row_stride = 0
              then begin
                let out_col = col_off / col_stride in
                let out_row = row_off / row_stride in
                if 0 <= out_col && out_col < output_cols && 0 <= out_row && out_row < output_rows
                then
                  for k = 0 to out_channel - 1 do
                    let out_grad = get output' [|b; out_col; out_row; k|] in
                    let kernel_val = get kernel [|di; dj; q; k|] in
                    sum := !sum +. (out_grad *. kernel_val)
                  done
              end
            done
          done;
          set input' [|b; in_i; in_j; q|] !sum
        done
      done
    done
  done;
  input'

(* Gradient of conv2d with respect to its kernel.
   For each kernel tap, the gradient accumulates output' times the input
   value it was multiplied with in the forward pass; padded positions
   contribute nothing. Returns a fresh array shaped like [kernel]. Raises
   Owl_exception.INVALID_ARGUMENT on rank or shape mismatch. *)
let conv2d_backward_kernel input kernel stride output' =
  _check_conv2d_backward_args "conv2d_backward_kernel" input kernel stride output';
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let kernel' = empty (kind kernel) kernel_shp in
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols kernel_rows
      output_cols output_rows row_stride col_stride
  in
  for di = 0 to kernel_cols - 1 do
    for dj = 0 to kernel_rows - 1 do
      for q = 0 to in_channel - 1 do
        for k = 0 to out_channel - 1 do
          let sum = ref 0. in
          for b = 0 to batches - 1 do
            for i = 0 to output_cols - 1 do
              for j = 0 to output_rows - 1 do
                let in_col = (i * col_stride) + di - pad_left in
                let in_row = (j * row_stride) + dj - pad_top in
                if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
                then begin
                  let out_grad = get output' [|b; i; j; k|] in
                  let input_val = get input [|b; in_col; in_row; q|] in
                  sum := !sum +. (out_grad *. input_val)
                end
              done
            done
          done;
          set kernel' [|di; dj; q; k|] !sum
        done
      done
    done
  done;
  kernel'

(* Permute the axes of [varr] into a freshly allocated ndarray.
   The result has shape _apply_perm dims axis, and the element at index ind
   of [varr] is written at index _apply_perm ind axis of the result.
   When [axis] is omitted the permutation [|rank-1; ...; 1; 0|] is used.
   Raises Owl_exception.INVALID_ARGUMENT when [axis] is not a permutation of
   0 .. rank-1.
   An ndarray with a zero-sized dimension has no element to copy, so its
   (empty) result is returned directly. The copy loop would otherwise read
   index 0, which is out of bounds for such an array. *)
let transpose ?axis varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let axis_perm =
    match axis with
    | Some perm -> perm
    | None -> Array.init rank (fun i -> rank - i - 1)
  in
  (* a permutation has length [rank] and names every axis exactly once *)
  let valid_perm =
    Array.length axis_perm = rank
    && begin
      let seen = Array.make rank false in
      Array.for_all
        (fun a ->
          let ok = 0 <= a && a < rank && not seen.(a) in
          if ok then seen.(a) <- true;
          ok)
        axis_perm
    end
  in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int axis_perm in
    let s1 = Owl_utils_array.to_string string_of_int dims in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "transpose: axis [%s] is not a permutation of the axes of shape [%s]." s0 s1)
  in
  Owl_exception.verify valid_perm error;
  let new_dims = _apply_perm dims axis_perm in
  let new_varr = empty (kind varr) new_dims in
  if not (Array.exists (fun d -> d = 0) dims) then begin
    let ind = Array.make rank 0 in
    let should_stop = ref false in
    while not !should_stop do
      Genarray.set new_varr (_apply_perm ind axis_perm) (Genarray.get varr ind);
      if not (_next_index ind dims) then should_stop := true
    done
  end;
  new_varr
(* transpose_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)
(* Transposed 2D convolution. It is computed as the input-gradient of conv2d
   (conv2d_backward_input), with the last two kernel axes swapped so that the
   kernel reads [kernel_column; kernel_row; output_channel; input_channel].
   The output spatial size comes from
   Owl_utils_infer_shape.calc_transpose_conv2d_output_shape. The
   freshly-allocated output' buffer is only used for its shape and kind by
   conv2d_backward_input.
   Raises Owl_exception.INVALID_ARGUMENT on rank or channel mismatch. *)
let transpose_conv2d ?(padding=SAME) input kernel stride =
  let p0 = (num_dims input = 4) in
  let p1 = (num_dims kernel = 4) in
  let p2 = (Array.length stride = 2) in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "transpose_conv2d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p3 = (in_channel = kernel_shp.(2)) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 = Printf.sprintf "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape" in
    let s5 = Printf.sprintf "transpose_conv2d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols, output_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape padding
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let output' = empty (kind input) [|batches; output_cols; output_rows; out_channel|] in
  let kernel = transpose ~axis:[|0; 1; 3; 2|] kernel in
  conv2d_backward_input output' kernel stride input

(* Gradient of transpose_conv2d with respect to its input: a forward conv2d
   of output' with the axis-swapped kernel. The padding is SAME if SAME
   reproduces the spatial size of output' from the input size, and VALID
   otherwise. Raises Owl_exception.INVALID_ARGUMENT on rank or shape
   mismatch. *)
let transpose_conv2d_backward_input input kernel stride output' =
  let p0 = (num_dims input = 4) in
  let p1 = (num_dims kernel = 4) in
  let p2 = (num_dims output' = 4) in
  let p3 = (Array.length stride = 2) in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 4)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 4)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s4 = Printf.sprintf "transpose_conv2d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let out_channel = kernel_shp.(3) in
  let p4 = (in_channel = kernel_shp.(2)) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 = Printf.sprintf "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape" in
    let s5 = Printf.sprintf "transpose_conv2d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p5 = (batches = output_shp.(0)) in
  let p6 = (out_channel = output_shp.(3)) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 = Printf.sprintf "the 1st dimension of input shape should be equal to the 1st dimension of output' shape" in
    let s7 = Printf.sprintf "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape" in
    let s8 = Printf.sprintf "transpose_conv2d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_cols_same, output_rows_same =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape SAME
      input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  in
  let p =
    if output_cols_same = output_cols && output_rows_same = output_rows then SAME else VALID
  in
  let kernel = transpose ~axis:[|0; 1; 3; 2|] kernel in
  conv2d ~padding:p output' kernel stride

(* Gradient of transpose_conv2d with respect to its kernel.
   transpose_conv2d is conv2d_backward_input applied to the kernel with its
   last two axes swapped. It is therefore a conv2d that maps the
   transpose-conv output back onto the transpose-conv input.
   The kernel gradient is thus conv2d_backward_kernel with those two roles
   exchanged and the swapped kernel. Its result is laid out as
   [kernel_column; kernel_row; output_channel; input_channel] and is swapped
   back to the layout of [kernel].
   The previous version skipped both swaps. It then failed the channel
   checks whenever the input and output channel counts differed, and
   returned a transposed gradient when they were equal. *)
let transpose_conv2d_backward_kernel input kernel stride output' =
  let kernel_t = transpose ~axis:[|0; 1; 3; 2|] kernel in
  let kernel_t' = conv2d_backward_kernel output' kernel_t stride input in
  transpose ~axis:[|0; 1; 3; 2|] kernel_t'
(* transpose_conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
(* Transposed 1D convolution computed by transpose_conv2d on a single-column
   view. The input is reshaped to [batch; 1; input_column; in_channel], the
   kernel to [1; kernel_column; in_channel; out_channel] and the stride to
   [1; column_stride]; the result is reshaped back to
   [batch; output_column; out_channel].
   Raises Owl_exception.INVALID_ARGUMENT on rank or channel mismatch. *)
let transpose_conv1d ?(padding=SAME) input kernel stride =
  let p0 = (num_dims input = 3) in
  let p1 = (num_dims kernel = 3) in
  let p2 = (Array.length stride = 1) in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "transpose_conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [|batches; 1; input_cols; in_channel|] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = (in_channel = kernel_shp.(1)) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 = Printf.sprintf "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape" in
    let s5 = Printf.sprintf "transpose_conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [|1; kernel_cols; in_channel; out_channel|] in
  let stride = [|1; stride.(0)|] in
  let output = transpose_conv2d ~padding input kernel stride in
  let output_cols = (shape output).(2) in
  reshape output [|batches; output_cols; out_channel|]

(* Shared validation and 2D lifting for the conv1d and transpose_conv1d
   backward passes; [fname] prefixes every error message.
   The checks run in this order:
   1. ranks 3/3/3 and stride length 1;
   2. input dimension 2 equals kernel dimension 1;
   3. output' dimensions 0 and 2 equal the input batch size and kernel
      dimension 2.
   It then returns the operands reshaped for the 2D routines:
     input   -> [batch; 1; input_column; in_channel]
     kernel  -> [1; kernel_column; in_channel; out_channel]
     stride  -> [1; column_stride]
     output' -> [batch; 1; output_column; out_channel]
   Raises Owl_exception.INVALID_ARGUMENT for the first failing group. *)
let _conv1d_backward_operands fname input kernel stride output' =
  let ranks_ok =
    num_dims input = 3 && num_dims kernel = 3 && num_dims output' = 3
    && Array.length stride = 1
  in
  Owl_exception.verify ranks_ok (fun () ->
      let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
      let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
      let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
      let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "%s: %s; %s; %s; %s." fname s0 s1 s2 s3));
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let output_shp = shape output' in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let output_cols = output_shp.(1) in
  let input_str () = Owl_utils_array.to_string string_of_int input_shp in
  let kernel_str () = Owl_utils_array.to_string string_of_int kernel_shp in
  Owl_exception.verify (in_channel = kernel_shp.(1)) (fun () ->
      let s1 = Printf.sprintf "input shape is [%s]" (input_str ()) in
      let s3 = Printf.sprintf "kernel shape is [%s]" (kernel_str ()) in
      let s4 = "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape" in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "%s: %s, %s, %s." fname s1 s3 s4));
  Owl_exception.verify
    (batches = output_shp.(0) && out_channel = output_shp.(2))
    (fun () ->
      let s3 = Printf.sprintf "input shape is [%s]" (input_str ()) in
      let s4 =
        Printf.sprintf "output' shape is [%s]" (Owl_utils_array.to_string string_of_int output_shp)
      in
      let s5 = Printf.sprintf "kernel shape is [%s]" (kernel_str ()) in
      let s6 = "the 1st dimension of input shape should be equal to the 1st dimension of output' shape" in
      let s7 = "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape" in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s3 s4 s5 s6 s7));
  ( reshape input [|batches; 1; input_cols; in_channel|],
    reshape kernel [|1; kernel_cols; in_channel; out_channel|],
    [|1; stride.(0)|],
    reshape output' [|batches; 1; output_cols; out_channel|] )

(* gradient of conv1d w.r.t the input; result has the shape of [input] *)
let conv1d_backward_input input kernel stride output' =
  let input4, kernel4, stride2, output4 =
    _conv1d_backward_operands "conv1d_backward_input" input kernel stride output'
  in
  reshape (conv2d_backward_input input4 kernel4 stride2 output4) (shape input)

(* gradient of conv1d w.r.t the kernel; result has the shape of [kernel] *)
let conv1d_backward_kernel input kernel stride output' =
  let input4, kernel4, stride2, output4 =
    _conv1d_backward_operands "conv1d_backward_kernel" input kernel stride output'
  in
  reshape (conv2d_backward_kernel input4 kernel4 stride2 output4) (shape kernel)

(* gradient of transpose_conv1d w.r.t the input; result has the shape of [input] *)
let transpose_conv1d_backward_input input kernel stride output' =
  let input4, kernel4, stride2, output4 =
    _conv1d_backward_operands "transpose_conv1d_backward_input" input kernel stride output'
  in
  reshape (transpose_conv2d_backward_input input4 kernel4 stride2 output4) (shape input)

(* gradient of transpose_conv1d w.r.t the kernel; result has the shape of [kernel] *)
let transpose_conv1d_backward_kernel input kernel stride output' =
  let input4, kernel4, stride2, output4 =
    _conv1d_backward_operands "transpose_conv1d_backward_kernel" input kernel stride output'
  in
  reshape (transpose_conv2d_backward_kernel input4 kernel4 stride2 output4) (shape kernel)

(*TODO: optimise *)(* gradient of conv3d w.r.t the input *)letconv3d_backward_inputinputkernelstrideoutput'=letp0=(num_dimsinput=5)inletp1=(num_dimskernel=5)inletp2=(num_dimsoutput'=5)inletp3=(Array.lengthstride=3)inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_input: %s; %s; %s; 
%s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=(in_channel=kernel_shp.(3))inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"inlets5=Printf.sprintf"conv3d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=(batches=output_shp.(0))inletp6=(out_channel=output_shp.(4))inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"inlets8=Printf.sprintf"conv3d_backward_input: %s; %s; %s; %s; 
%s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletinput'=empty(kindinput)(shapeinput)inlet(pad_top,pad_left,pad_shallow,_,_,_)=Owl_utils_infer_shape.calc_conv3d_paddinginput_colsinput_rowsinput_dptskernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsrow_stridecol_stridedpt_strideinbeginforb=0tobatches-1doforin_i=0toinput_cols-1doforin_j=0toinput_rows-1doforin_dpt=0toinput_dpts-1doforq=0toin_channel-1doletsum=ref0.infordi=0tokernel_cols-1dofordj=0tokernel_rows-1doford_dpt=0tokernel_dpts-1doif(((Stdlib.(mod)(in_i+pad_left-di)col_stride)=0)&&((Stdlib.(mod)(in_j+pad_top-dj)row_stride)=0)&&((Stdlib.(mod)(in_dpt+pad_shallow-d_dpt)dpt_stride)=0))thenbeginletout_col=(in_i+pad_left-di)/col_strideinletout_row=(in_j+pad_top-dj)/row_strideinletout_dpt=(in_dpt+pad_shallow-d_dpt)/dpt_strideinif((0<=out_col)&&(out_col<output_cols)&&(0<=out_row)&&(out_row<output_rows)&&(0<=out_dpt)&&(out_dpt<output_dpts))thenfork=0toout_channel-1doletout_grad=getoutput'[|b;out_col;out_row;out_dpt;k|]inletkernel_val=getkernel[|di;dj;d_dpt;q;k|]insum:=!sum+.out_grad*.kernel_valdone;(*k*)enddone;(*d_dpt*)done;(*dj*)done;(*di*)(setinput'[|b;in_i;in_j;in_dpt;q|]!sum)done;(*q*)done;(*in_dpt*)done;(*in_j*)done;(*in_i*)done;(*b*)input'end(* gradient of conv3d w.r.t the kernel *)letconv3d_backward_kernelinputkernelstrideoutput'=letp0=(num_dimsinput=5)inletp1=(num_dimskernel=5)inletp2=(num_dimsoutput'=5)inletp3=(Array.lengthstride=3)inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 5)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets4=Printf.sprintf"conv3d_backward_kernel: %s; %s; %s; 
%s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp4=(in_channel=kernel_shp.(3))inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"inlets5=Printf.sprintf"conv2d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletoutput_dpts=output_shp.(3)inletp5=(batches=output_shp.(0))inletp6=(out_channel=output_shp.(4))inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of output' shape"inlets7=Printf.sprintf"the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"inlets8=Printf.sprintf"conv2d_backward_kernel: %s; %s; %s; %s; 
%s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletkernel'=empty(kindkernel)(shapekernel)inlet(pad_top,pad_left,pad_shallow,_,_,_)=Owl_utils_infer_shape.calc_conv3d_paddinginput_colsinput_rowsinput_dptskernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsrow_stridecol_stridedpt_strideinbeginfordi=0tokernel_cols-1dofordj=0tokernel_rows-1doford_dpt=0tokernel_dpts-1doforq=0toin_channel-1dofork=0toout_channel-1doletsum=ref0.inforb=0tobatches-1dofori=0tooutput_cols-1doforj=0tooutput_rows-1dofordpt=0tooutput_dpts-1doletin_col=i*col_stride+di-pad_leftinletin_row=j*row_stride+dj-pad_topinletin_dpt=dpt*dpt_stride+d_dpt-pad_shallowinif((0<=in_col)&&(in_col<input_cols)&&(0<=in_row)&&(in_row<input_rows)&&(0<=in_dpt)&&(in_dpt<input_dpts))thenletout_grad=getoutput'[|b;i;j;dpt;k|]inletinput_val=getinput[|b;in_col;in_row;in_dpt;q|]insum:=!sum+.out_grad*.input_valdone;(*dpt*)done;(*j*)done;(*i*)done;(*b*)setkernel'[|di;dj;d_dpt;q;k|]!sumdone;(*k*)done;(*q*)done;(*d_dpt*)done;(*dj*)done;(*di*)kernel'end(* transpose_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_depth; output_channel]
*)
let transpose_conv3d ?(padding = SAME) input kernel stride =
  (* Checks: [input] and [kernel] are 5d, [stride] has one entry per spatial
     axis, and the input channel count equals the 4th kernel axis.
     Raises [Owl_exception.INVALID_ARGUMENT] otherwise. *)
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "transpose_conv3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p3 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "transpose_conv3d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  (* NOTE(review): the result is sized with calc_conv3d_output_shape, i.e. the
     output shape of a forward conv3d over [input]. For strides greater than 1
     a transposed convolution is normally larger than its input, so this
     presumably wants a transposed-conv shape helper instead. The padding
     inference in transpose_conv3d_backward_input below makes the same choice,
     so the two must change together. Left unchanged pending confirmation. *)
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  (* conv3d_backward_input reads only the kind and shape of its first argument,
     so uninitialised storage is sufficient here. *)
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  (* Swap the two channel axes: the transposed convolution is computed as the
     input gradient of a conv3d whose kernel maps out_channel to in_channel. *)
  let kernel = transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel in
  conv3d_backward_input output kernel stride input

(* gradient of transpose_conv3d w.r.t the input.
   Computed as a forward conv3d of [output'] with the channel-swapped kernel.
   The padding mode is recovered by comparing the spatial shape of [output']
   with the SAME-padding shape computed from [input]: SAME on a match,
   VALID otherwise.
   Raises [Owl_exception.INVALID_ARGUMENT] on rank or shape mismatches. *)
let transpose_conv3d_backward_input input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 =
      Printf.sprintf "transpose_conv3d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3
    in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "transpose_conv3d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "transpose_conv3d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let padding = SAME in
  let output_cols_same, output_rows_same, output_dpts_same =
    Owl_utils_infer_shape.calc_conv3d_output_shape padding input_cols input_rows
      input_dpts kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride
  in
  (* Infer the padding the forward pass used from the shape of output'. *)
  let p =
    if output_cols_same = output_cols
       && output_rows_same = output_rows
       && output_dpts_same = output_dpts
    then SAME
    else VALID
  in
  let kernel = transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel in
  conv3d ~padding:p output' kernel stride

(* gradient of transpose_conv3d w.r.t the kernel.
   transpose_conv3d runs conv3d_backward_input with the two channel axes of the
   kernel swapped. The kernel gradient is therefore the conv3d kernel gradient
   of that swapped kernel, with [output'] in the role of the conv input and
   [input] in the role of the conv output gradient, swapped back to the
   caller's [kernel_cols; kernel_rows; kernel_dpts; in_channel; out_channel]
   layout. Without the two swaps, conv3d_backward_kernel rejects kernels with
   in_channel <> out_channel, and returns the channel axes transposed when
   the two counts are equal. *)
let transpose_conv3d_backward_kernel input kernel stride output' =
  let kernel = transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel in
  let kernel' = conv3d_backward_kernel output' kernel stride input in
  transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel'

(* TODO: definitely optimise *)
(* General backward pass shared by avg_pool2d and max_pool2d.
   For every output position the window is scanned twice. The first pass
   recomputes the pooled value through init_pool_fun, add_val_pool_fun and
   end_pool_fun. The second pass updates every in-bounds input element of the
   window with compute_grad_fun input_val accumulated_grad pooled_val
   output_grad. Out-of-bounds (padding) positions are skipped in both passes.
   [_padding] is unused: the padding offsets are recovered from the shapes of
   [input] and [output'].
   Raises [Owl_exception.INVALID_ARGUMENT] on rank or shape mismatches. *)
let _pool2d_backward _padding input kernel stride output' init_pool_fun
    add_val_pool_fun end_pool_fun compute_grad_fun =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    let s3 = Printf.sprintf "_pool2d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p5 = batches = output_shp.(0) in
  let p6 = in_channel = output_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "input shape is [%s]" s0 in
    let s3 = Printf.sprintf "output' shape is [%s]" s1 in
    let s4 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s5 =
      Printf.sprintf
        "the 4th dimension of input shape should be equal to the 4th dimension of output' shape"
    in
    let s6 = Printf.sprintf "_pool2d_backward: %s; %s; %s; %s." s2 s3 s4 s5 in
    Owl_exception.INVALID_ARGUMENT s6
  in
  Owl_exception.verify (p5 && p6) error;
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding input_cols input_rows kernel_cols
      kernel_rows output_cols output_rows row_stride col_stride
  in
  let input' = zeros (kind input) (shape input) in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for k = 0 to in_channel - 1 do
          (* pass 1: recompute the pooled value of this window *)
          init_pool_fun ();
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              let in_col = (i * col_stride) + di - pad_left in
              let in_row = (j * row_stride) + dj - pad_top in
              if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              then add_val_pool_fun (get input [| b; in_col; in_row; k |])
            done
          done;
          let output_val = end_pool_fun () in
          let output_grad = get output' [| b; i; j; k |] in
          (* pass 2: route this window's output gradient back to its inputs *)
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              let in_col = (i * col_stride) + di - pad_left in
              let in_row = (j * row_stride) + dj - pad_top in
              if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              then begin
                let input_val = get input [| b; in_col; in_row; k |] in
                let input_grad = get input' [| b; in_col; in_row; k |] in
                set input' [| b; in_col; in_row; k |]
                  (compute_grad_fun input_val input_grad output_val output_grad)
              end
            done
          done
        done
      done
    done
  done;
  input'

(* calculate the gradient of max_pool2d.
   Each window's output gradient is added to every in-bounds input whose value
   lies within 1e-8 of the window maximum, so tied maxima all receive it.
   The running maximum starts at neg_infinity. Stdlib.min_float is the
   smallest positive float, so starting from it made windows of all-negative
   values report a maximum of about 2.2e-308. Those windows then matched no
   input, unless one was within 1e-8 of zero, and their gradient was dropped.
   NOTE(review): if the forward max_pool2d elsewhere in this file also starts
   from min_float, it needs the same change so forward and backward agree. *)
let max_pool2d_backward padding input kernel stride output' =
  let max_pool = ref Stdlib.neg_infinity in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  let compute_grad_fun input_val input_grad output_val output_grad =
    (* TODO: change comparison here *)
    if Scalar.abs (input_val -. output_val) < 1e-8
    then input_grad +. output_grad
    else input_grad
  in
  _pool2d_backward padding input kernel stride output' init_pool_fun
    add_val_pool_fun end_pool_fun compute_grad_fun

(* calculate the gradient of avg_pool2d.
   Each window's output gradient is split evenly over the in-bounds inputs it
   averaged. [cnt] holds the element count of the window currently being
   processed, so padded positions take no share. *)
let avg_pool2d_backward padding input kernel stride output' =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  let compute_grad_fun _input_val input_grad _output_val output_grad =
    input_grad +. (output_grad /. !cnt)
  in
  _pool2d_backward padding input kernel stride output' init_pool_fun
    add_val_pool_fun end_pool_fun compute_grad_fun

(* TODO: definitely optimise *)
(* General backward pass shared by avg_pool3d and max_pool3d.
   This is the 3d counterpart of _pool2d_backward: same callback protocol,
   with one extra depth axis. [_padding] is unused: the padding offsets come
   from the shapes of [input] and [output'].
   Raises [Owl_exception.INVALID_ARGUMENT] on rank or shape mismatches. *)
let _pool3d_backward _padding input kernel stride output' init_pool_fun
    add_val_pool_fun end_pool_fun compute_grad_fun =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "_pool3d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = in_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "input shape is [%s]" s0 in
    let s3 = Printf.sprintf "output' shape is [%s]" s1 in
    let s4 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s5 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 5th dimension of output' shape"
    in
    let s6 = Printf.sprintf "_pool3d_backward: %s; %s; %s; %s." s2 s3 s4 s5 in
    Owl_exception.INVALID_ARGUMENT s6
  in
  Owl_exception.verify (p5 && p6) error;
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding input_cols input_rows input_dpts
      kernel_cols kernel_rows kernel_dpts output_cols output_rows output_dpts
      row_stride col_stride dpt_stride
  in
  let input' = zeros (kind input) (shape input) in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for dpt = 0 to output_dpts - 1 do
          for k = 0 to in_channel - 1 do
            (* pass 1: recompute the pooled value of this window *)
            init_pool_fun ();
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for dk = 0 to kernel_dpts - 1 do
                  let in_col = (i * col_stride) + di - pad_left in
                  let in_row = (j * row_stride) + dj - pad_top in
                  let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in
                  if 0 <= in_col && in_col < input_cols
                     && 0 <= in_row && in_row < input_rows
                     && 0 <= in_dpt && in_dpt < input_dpts
                  then add_val_pool_fun (get input [| b; in_col; in_row; in_dpt; k |])
                done
              done
            done;
            let output_val = end_pool_fun () in
            let output_grad = get output' [| b; i; j; dpt; k |] in
            (* pass 2: route this window's output gradient back to its inputs *)
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for dk = 0 to kernel_dpts - 1 do
                  let in_col = (i * col_stride) + di - pad_left in
                  let in_row = (j * row_stride) + dj - pad_top in
                  let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in
                  if 0 <= in_col && in_col < input_cols
                     && 0 <= in_row && in_row < input_rows
                     && 0 <= in_dpt && in_dpt < input_dpts
                  then begin
                    let input_val = get input [| b; in_col; in_row; in_dpt; k |] in
                    let input_grad = get input' [| b; in_col; in_row; in_dpt; k |] in
                    set input' [| b; in_col; in_row; in_dpt; k |]
                      (compute_grad_fun input_val input_grad output_val output_grad)
                  end
                done
              done
            done
          done
        done
      done
    done
  done;
  input'

(* calculate the gradient of max_pool3d.
   Same rule as max_pool2d_backward: every in-bounds input within 1e-8 of the
   window maximum receives the window's output gradient. The running maximum
   starts at neg_infinity rather than Stdlib.min_float, which is the smallest
   positive float and made all-negative windows drop their gradient.
   NOTE(review): the forward max_pool3d elsewhere in this file should start
   from the same value; confirm. *)
let max_pool3d_backward padding input kernel stride output' =
  let max_pool = ref Stdlib.neg_infinity in
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  let compute_grad_fun input_val input_grad output_val output_grad =
    (* TODO: change comparison here *)
    if Scalar.abs (input_val -. output_val) < 1e-8
    then input_grad +. output_grad
    else input_grad
  in
  _pool3d_backward padding input kernel stride output' init_pool_fun
    add_val_pool_fun end_pool_fun compute_grad_fun

(* calculate the gradient of avg_pool3d.
   Each window's output gradient is split evenly over the in-bounds inputs it
   averaged; [cnt] holds the count of the window currently being processed. *)
let avg_pool3d_backward padding input kernel stride output' =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  let compute_grad_fun _input_val input_grad _output_val output_grad =
    input_grad +. (output_grad /. !cnt)
  in
  _pool3d_backward padding input kernel stride output' init_pool_fun
    add_val_pool_fun end_pool_fun compute_grad_fun

(* calculate the gradient of max_pool1d.
   The 1d problem is lifted into the 2d layout by inserting a singleton axis at
   position 1 of [input] and [output'] and prefixing kernel and stride with 1.
   max_pool2d_backward is applied, and the result is reshaped back to the
   shape of [input].
   Raises [Owl_exception.INVALID_ARGUMENT] on rank mismatches. *)
let max_pool1d_backward padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let input' = max_pool2d_backward padding input kernel stride output' in
  reshape input' input_shp

(* calculate the gradient of avg_pool1d
*)
let avg_pool1d_backward padding input kernel stride output' =
  (* Require a 3d input plus single-entry kernel and stride; otherwise raise
     Owl_exception.INVALID_ARGUMENT with all three measurements. *)
  let ranks_ok =
    num_dims input = 3 && Array.length kernel = 1 && Array.length stride = 1
  in
  Owl_exception.verify ranks_ok (fun () ->
      let msg =
        Printf.sprintf "avg_pool1d_backward: %s; %s; %s."
          (Printf.sprintf "input dimension = %i (should be 3)" (num_dims input))
          (Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel))
          (Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride))
      in
      Owl_exception.INVALID_ARGUMENT msg);
  (* Treat the 1d problem as 2d: insert a singleton axis at position 1 of the
     tensors and prefix kernel and stride with 1, then delegate. *)
  let in_shp = shape input in
  let n_batch = in_shp.(0) and n_cols = in_shp.(1) and n_chan = in_shp.(2) in
  let input_4d = reshape input [| n_batch; 1; n_cols; n_chan |] in
  let grad_shp = shape output' in
  let output_4d = reshape output' [| n_batch; 1; grad_shp.(1); grad_shp.(2) |] in
  let kernel_2d = [| 1; kernel.(0) |] in
  let stride_2d = [| 1; stride.(0) |] in
  let input_grad_4d =
    avg_pool2d_backward padding input_4d kernel_2d stride_2d output_4d
  in
  reshape input_grad_4d in_shp

(* create a dilated 2d kernel: tap (c, r) of [kernel] is copied to position
   (c * rate.(0), r * rate.(1)) of a zero-filled kernel; a rate of [|1; 1|]
   returns [kernel] itself *)
let upsample_kernel2d kernel rate =
  if rate = [| 1; 1 |] then kernel
  else begin
    let kshp = shape kernel in
    let kc = kshp.(0) and kr = kshp.(1) and ci = kshp.(2) and co = kshp.(3) in
    let rc = rate.(0) and rr = rate.(1) in
    let dilated =
      zeros (kind kernel)
        [| kc + ((kc - 1) * (rc - 1)); kr + ((kr - 1) * (rr - 1)); ci; co |]
    in
    for c = 0 to kc - 1 do
      for r = 0 to kr - 1 do
        for i = 0 to ci - 1 do
          for o = 0 to co - 1 do
            let tap = get kernel [| c; r; i; o |] in
            set dilated [| c * rc; r * rr; i; o |] tap
          done
        done
      done
    done;
    dilated
  end

(* change a dilated 2d kernel back to normal
*)
(* Inverse of upsample_kernel2d: keep only the taps lying on the dilation grid.
   A rate of [|1; 1|] returns [kernel] itself. *)
let downsample_kernel2d kernel rate =
  if rate = [| 1; 1 |] then kernel
  else begin
    let kshp = shape kernel in
    let kc_dilated = kshp.(0) and kr_dilated = kshp.(1) in
    let ci = kshp.(2) and co = kshp.(3) in
    let rc = rate.(0) and rr = rate.(1) in
    (* ceiling division: number of grid positions along each spatial axis *)
    let kc = (kc_dilated + (rc - 1)) / rc in
    let kr = (kr_dilated + (rr - 1)) / rr in
    let compact = zeros (kind kernel) [| kc; kr; ci; co |] in
    for c = 0 to kc - 1 do
      for r = 0 to kr - 1 do
        for i = 0 to ci - 1 do
          for o = 0 to co - 1 do
            let tap = get kernel [| c * rc; r * rr; i; o |] in
            set compact [| c; r; i; o |] tap
          done
        done
      done
    done;
    compact
  end

(* dilated_conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
rate : [col_dilation_rate; row_dilation_rate]
output: [batch; output_column; output_row; output_channel]
*)
let dilated_conv2d ?(padding = SAME) input kernel stride rate =
  (* [rate] must have exactly two entries; the kernel is dilated and an
     ordinary conv2d is run on it. *)
  Owl_exception.verify (Array.length rate = 2) (fun () ->
      let detail =
        Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate)
      in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "dilated_conv2d: %s." detail));
  let dilated_kernel = upsample_kernel2d kernel rate in
  conv2d ~padding input dilated_kernel stride

(* gradient of dilated_conv2d w.r.t the input: the conv2d input gradient taken
   with the dilated kernel *)
let dilated_conv2d_backward_input input kernel stride rate output' =
  Owl_exception.verify (Array.length rate = 2) (fun () ->
      let detail =
        Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate)
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "dilated_conv2d_backward_input: %s." detail));
  let dilated_kernel = upsample_kernel2d kernel rate in
  conv2d_backward_input input dilated_kernel stride output'

(* gradient of dilated_conv2d w.r.t the kernel: the conv2d kernel gradient of
   the dilated kernel, sampled back onto the original kernel grid *)
let dilated_conv2d_backward_kernel input kernel stride rate output' =
  Owl_exception.verify (Array.length rate = 2) (fun () ->
      let detail =
        Printf.sprintf "rate dimension = %i (should be 2)" (Array.length rate)
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "dilated_conv2d_backward_kernel: %s." detail));
  let dilated_kernel = upsample_kernel2d kernel rate in
  let dilated_grad = conv2d_backward_kernel input dilated_kernel stride output' in
  downsample_kernel2d dilated_grad rate

(* create a dilated 3d kernel
*)
(* Tap (c, r, d) of [kernel] is copied to position
   (c * rate.(0), r * rate.(1), d * rate.(2)) of a zero-filled kernel.
   A rate of [|1; 1; 1|] returns [kernel] itself. *)
let upsample_kernel3d kernel rate =
  if rate = [| 1; 1; 1 |] then kernel
  else begin
    let kshp = shape kernel in
    let kc = kshp.(0) and kr = kshp.(1) and kd = kshp.(2) in
    let ci = kshp.(3) and co = kshp.(4) in
    let rc = rate.(0) and rr = rate.(1) and rd = rate.(2) in
    let dilated =
      zeros (kind kernel)
        [| kc + ((kc - 1) * (rc - 1));
           kr + ((kr - 1) * (rr - 1));
           kd + ((kd - 1) * (rd - 1));
           ci;
           co |]
    in
    for c = 0 to kc - 1 do
      for r = 0 to kr - 1 do
        for d = 0 to kd - 1 do
          for i = 0 to ci - 1 do
            for o = 0 to co - 1 do
              let tap = get kernel [| c; r; d; i; o |] in
              set dilated [| c * rc; r * rr; d * rd; i; o |] tap
            done
          done
        done
      done
    done;
    dilated
  end

(* change a dilated 3d kernel back to normal: keep only the taps lying on the
   dilation grid; a rate of [|1; 1; 1|] returns [kernel] itself *)
let downsample_kernel3d kernel rate =
  if rate = [| 1; 1; 1 |] then kernel
  else begin
    let kshp = shape kernel in
    let kc_dilated = kshp.(0) and kr_dilated = kshp.(1) and kd_dilated = kshp.(2) in
    let ci = kshp.(3) and co = kshp.(4) in
    let rc = rate.(0) and rr = rate.(1) and rd = rate.(2) in
    (* ceiling division: number of grid positions along each spatial axis *)
    let kc = (kc_dilated + (rc - 1)) / rc in
    let kr = (kr_dilated + (rr - 1)) / rr in
    let kd = (kd_dilated + (rd - 1)) / rd in
    let compact = zeros (kind kernel) [| kc; kr; kd; ci; co |] in
    for c = 0 to kc - 1 do
      for r = 0 to kr - 1 do
        for d = 0 to kd - 1 do
          for i = 0 to ci - 1 do
            for o = 0 to co - 1 do
              let tap = get kernel [| c * rc; r * rr; d * rd; i; o |] in
              set compact [| c; r; d; i; o |] tap
            done
          done
        done
      done
    done;
    compact
  end

(* dilated_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
rate : [col_dilation_rate; row_dilation_rate; depth_dilation_rate]
output: [batch; output_column; output_row; output_depth; output_channel]
*)
let dilated_conv3d ?(padding = SAME) input kernel stride rate =
  (* [rate] must have exactly three entries; the kernel is dilated and an
     ordinary conv3d is run on it. *)
  Owl_exception.verify (Array.length rate = 3) (fun () ->
      let detail =
        Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate)
      in
      Owl_exception.INVALID_ARGUMENT (Printf.sprintf "dilated_conv3d: %s." detail));
  let dilated_kernel = upsample_kernel3d kernel rate in
  conv3d ~padding input dilated_kernel stride

(* gradient of dilated_conv3d w.r.t the input: the conv3d input gradient taken
   with the dilated kernel *)
let dilated_conv3d_backward_input input kernel stride rate output' =
  Owl_exception.verify (Array.length rate = 3) (fun () ->
      let detail =
        Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate)
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "dilated_conv3d_backward_input: %s." detail));
  let dilated_kernel = upsample_kernel3d kernel rate in
  conv3d_backward_input input dilated_kernel stride output'

(* gradient of dilated_conv3d w.r.t the kernel: the conv3d kernel gradient of
   the dilated kernel, sampled back onto the original kernel grid *)
let dilated_conv3d_backward_kernel input kernel stride rate output' =
  Owl_exception.verify (Array.length rate = 3) (fun () ->
      let detail =
        Printf.sprintf "rate dimension = %i (should be 3)" (Array.length rate)
      in
      Owl_exception.INVALID_ARGUMENT
        (Printf.sprintf "dilated_conv3d_backward_kernel: %s." detail));
  let dilated_kernel = upsample_kernel3d kernel rate in
  let dilated_grad = conv3d_backward_kernel input dilated_kernel stride output' in
  downsample_kernel3d dilated_grad rate

(* dilated_conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
rate  : dilation rate, forwarded unchanged to dilated_conv2d (so it must have 2 entries)
output: [batch; output_column; output_channel]
*)
let dilated_conv1d ?(padding = SAME) input kernel stride rate =
  let p0 = (num_dims input = 3) in
  let p1 = (num_dims kernel = 3) in
  let p2 = (Array.length stride = 1) in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "dilated_conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  (* lift the 1D problem to 2D by inserting a unit axis at position 1 *)
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = (in_channel = kernel_shp.(1)) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  (* stride and rate are lifted to 2D like the tensors: the unit axis gets 1.
     A 1-element [rate] is expanded; a rate that is already 2D is passed
     through unchanged so existing callers keep working. *)
  let stride = [| 1; stride.(0) |] in
  let rate = if Array.length rate = 1 then [| 1; rate.(0) |] else rate in
  let output = dilated_conv2d ~padding input kernel stride rate in
  let output_cols = (shape output).(2) in
  reshape output [| batches; output_cols; out_channel |]


(* Shared argument handling for the dilated_conv1d backward passes.
   Validates the ranks of [input], [kernel], [output'] and the length of
   [stride], checks that the channel and batch dimensions agree, and lifts
   every argument to the 2D layout used by the dilated_conv2d backward
   functions (unit axis at position 1; a 1-element [rate] becomes [|1; r|],
   any other [rate] is passed through). [fname] only prefixes error messages.
   Returns (input, kernel, stride, rate, output', input_shp, kernel_shp) where
   the shapes are those of the original 3D [input] and [kernel]. *)
let _dilated_conv1d_backward_args fname input kernel stride rate output' =
  let p0 = (num_dims input = 3) in
  let p1 = (num_dims kernel = 3) in
  let p2 = (num_dims output' = 3) in
  let p3 = (Array.length stride = 1) in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "%s: %s; %s; %s; %s." fname s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = (in_channel = kernel_shp.(1)) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "%s: %s, %s, %s." fname s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = (batches = output'_shp.(0)) in
  let p6 = (out_channel = output'_shp.(2)) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 = Printf.sprintf "%s: %s; %s; %s; %s; %s." fname s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output' = reshape output' [| batches; 1; output_cols; out_channel |] in
  let stride = [| 1; stride.(0) |] in
  let rate = if Array.length rate = 1 then [| 1; rate.(0) |] else rate in
  input, kernel, stride, rate, output', input_shp, kernel_shp


(* gradient of dilated_conv1d w.r.t the input *)
let dilated_conv1d_backward_input input kernel stride rate output' =
  let input, kernel, stride, rate, output', input_shp, _ =
    _dilated_conv1d_backward_args "dilated_conv1d_backward_input" input kernel stride
      rate output'
  in
  let input' = dilated_conv2d_backward_input input kernel stride rate output' in
  reshape input' input_shp


(* gradient of dilated_conv1d w.r.t the kernel *)
let dilated_conv1d_backward_kernel input kernel stride rate output' =
  let input, kernel, stride, rate, output', _, kernel_shp =
    _dilated_conv1d_backward_args "dilated_conv1d_backward_kernel" input kernel stride
      rate output'
  in
  let kernel' = dilated_conv2d_backward_kernel input kernel stride rate output' in
  reshape kernel' kernel_shp


(* Repeat every element of a rank-4 [input] size.(0) times along axis 1 and
   size.(1) times along axis 2 (axes 0 and 3 are not repeated). *)
let upsampling2d input size =
  let p0 = (num_dims input = 4) in
  let p1 = (Array.length size = 2) in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  repeat input [| 1; size.(0); size.(1); 1 |]


(* Gradient of upsampling2d w.r.t. its input: every element of [output] is
   added back into the input position it was copied from. The result has the
   shape of [input]; the values of [input] itself are not read. *)
let upsampling2d_backward input size output =
  let p0 = (num_dims input = 4) in
  let p1 = (Array.length size = 2) in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_backward: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  let _kind = kind input in
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let col_scale = size.(0) in
  let row_scale = size.(1) in
  let output_shp = shape output in
  let output_cols = input_cols * col_scale in
  let output_rows = input_rows * row_scale in
  let p2 = (output_cols = output_shp.(1)) in
  let p3 = (output_rows = output_shp.(2)) in
  let error () =
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "output shape is [%s]" s1 in
    let s3 =
      Printf.sprintf
        "scaled output cols is %i, should be equal to the 2nd dimension of output shape"
        output_cols
    in
    let s4 =
      Printf.sprintf
        "scaled output rows is %i, should be equal to the 3rd dimension of output shape"
        output_rows
    in
    let s5 = Printf.sprintf "upsampling2d_backward: %s; %s; %s." s2 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p2 && p3) error;
  let input' = zeros _kind input_shp in
  for b = 0 to batches - 1 do
    for c = 0 to output_cols - 1 do
      let in_c = Stdlib.min (c / col_scale) (input_cols - 1) in
      for r = 0 to output_rows - 1 do
        let in_r = Stdlib.min (r / row_scale) (input_rows - 1) in
        for i = 0 to in_channel - 1 do
          let in_val = get input' [| b; in_c; in_r; i |] in
          let out_val = get output [| b; c; r; i |] in
          set input' [| b; in_c; in_r; i |] (in_val +. out_val)
        done
      done
    done
  done;
  input'


(* matrix functions *)

(* Drop every dimension that is not > 1 from [dims]; if nothing remains,
   return [|1|] so the result is never empty. *)
let _remove_unit_dims dims =
  let removed_ones_list = List.filter (fun x -> x > 1) (Array.to_list dims) in
  let not_empty_list = match removed_ones_list with [] -> [ 1 ] | _ -> removed_ones_list in
  Array.of_list not_empty_list


(* Raise [Invalid_argument] unless [dims] has exactly two entries. *)
let _check_is_matrix dims =
  if Array.length dims <> 2 then raise (Invalid_argument "The given NDarray is not a matrix!")


let row_num varr =
  let dims = shape varr in
  _check_is_matrix dims;
  dims.(0)


let col_num varr =
  let dims = shape varr in
  _check_is_matrix dims;
  dims.(1)


(* NOTE: this is a view into the original array
*)
let row varr ind =
  let dims = shape varr in
  _check_is_matrix dims;
  Genarray.slice_left varr [| ind |]


(* Copy the rows of matrix [varr] listed in [indices] (in that order) into a
   freshly allocated matrix of shape [|Array.length indices; cols of varr|]. *)
let rows varr indices =
  let dims = shape varr in
  _check_is_matrix dims;
  let new_varr = empty (kind varr) [| Array.length indices; dims.(1) |] in
  Array.iteri
    (fun i src ->
      (* row [src] of the original becomes row [i] of the new matrix *)
      Genarray.blit (Genarray.slice_left varr [| src |]) (Genarray.slice_left new_varr [| i |]))
    indices;
  new_varr


(* Overwrite row [ind] of matrix [varr] with the contents of [vec]. *)
let copy_row_to vec varr ind =
  let dims = shape varr in
  _check_is_matrix dims;
  Genarray.blit vec (Genarray.slice_left varr [| ind |])


(* Overwrite column [ind] of matrix [varr] with the elements of [vec].
   [vec] may have any number of unit dimensions but exactly one non-unit one,
   and must hold as many elements as [varr] has rows. *)
let copy_col_to vec varr ind =
  let dims = shape varr in
  _check_is_matrix dims;
  let vec_dims = _remove_unit_dims (shape vec) in
  let vec_len =
    if Array.length vec_dims = 1 then vec_dims.(0)
    else raise (Invalid_argument "Vector is not a column vector")
  in
  let num_rows = dims.(0) in
  let vec_linear = flatten vec |> array1_of_genarray in
  (* _remove_unit_dims also drops 0-sized dimensions, so [vec_len] can exceed
     the real element count; check that too before the unchecked reads below *)
  if num_rows <> vec_len || Array1.dim vec_linear <> num_rows then
    raise
      (Invalid_argument
         "Column vector does not have the same length as the number of rows in the matrix")
  else
    for i = 0 to num_rows - 1 do
      Genarray.set varr [| i; ind |] (Array1.unsafe_get vec_linear i)
    done


(* Naive matrix product of [varr_a] (m x k) and [varr_b] (k x n) as floats;
   returns a new m x n matrix of the same kind as [varr_a]. *)
let dot varr_a varr_b =
  let dims_a = shape varr_a in
  let dims_b = shape varr_b in
  _check_is_matrix dims_a;
  _check_is_matrix dims_b;
  let m = dims_a.(0) in
  let cdim = dims_a.(1) in
  let n = dims_b.(1) in
  if dims_b.(0) <> cdim then raise (Invalid_argument "Matrices cannot be multiplied")
  else begin
    let varr_c = empty (kind varr_a) [| m; n |] in
    for i = 0 to m - 1 do
      for j = 0 to n - 1 do
        let sum = ref 0. in
        for k = 0 to cdim - 1 do
          sum := !sum +. (Genarray.get varr_a [| i; k |] *. Genarray.get varr_b [| k; j |])
        done;
        Genarray.set varr_c [| i; j |] !sum
      done
    done;
    varr_c
  end


(* Sum of the diagonal of a square matrix. *)
let trace varr =
  let dims = shape varr in
  _check_is_matrix dims;
  let n = dims.(0) in
  if dims.(1) <> n then raise (Invalid_argument "Argument is not a square matrix")
  else begin
    let sum = ref 0. in
    for i = 0 to n - 1 do
      sum := !sum +. Genarray.get varr [| i; i |]
    done;
    !sum
  end


(* NOTE: each row is actually a view in the original matrix, no copying involved
*)
let to_rows varr =
  let dims = shape varr in
  _check_is_matrix dims;
  Array.init dims.(0) (fun i -> Genarray.slice_left varr [| i |])


let to_cols _arr = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.to_cols")


(* Stack [rows] along a new leading axis: the result has shape
   [|Array.length rows|] followed by the shape of [rows.(0)], and each input
   is blitted into its slice. *)
let of_rows rows =
  let m = Array.length rows in
  let row_dim = shape rows.(0) in
  let dims = Array.append [| m |] row_dim in
  let varr = empty (kind rows.(0)) dims in
  for i = 0 to m - 1 do
    Genarray.blit rows.(i) (Genarray.slice_left varr [| i |])
  done;
  varr


let of_cols _cols = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.of_cols")


(* Build an m x n matrix of the given [kind] from [m] OCaml arrays that all
   have length [n]. An empty [arrays] yields a 0 x 0 matrix. Rows of unequal
   length raise [Invalid_argument] (previously they were read with
   [Array.unsafe_get] past their end). *)
let of_arrays kind arrays =
  let m = Array.length arrays in
  let n = if m = 0 then 0 else Array.length arrays.(0) in
  Array.iter
    (fun a ->
      if Array.length a <> n then
        raise (Invalid_argument "of_arrays: all rows must have the same length"))
    arrays;
  let varr = empty kind [| m; n |] in
  Array.iteri (fun i a -> Array.iteri (fun j x -> Genarray.set varr [| i; j |] x) a) arrays;
  varr


(* Sample [count] row indices of matrix [varr] (with or without replacement)
   and return the copied rows together with the sampled indices. *)
let draw_rows ?(replacement = true) varr count =
  let dims = shape varr in
  _check_is_matrix dims;
  (* sample among the dims.(0) rows, not among the number of dimensions *)
  let indices = _draw_int_samples replacement dims.(0) count in
  let extracted = rows varr indices in
  extracted, indices


(* Like [draw_rows] on [varr_a], then take the same row indices from [varr_b]. *)
let draw_rows2 ?(replacement = true) varr_a varr_b count =
  let extracted_a, indices = draw_rows ~replacement varr_a count in
  let extracted_b = rows varr_b indices in
  extracted_a, extracted_b, indices


(* TODO: optimise and test *)
(*
Implementing the following algorithm:
http://www.irma-international.org/viewtitle/41011/
   The implementation below no longer follows that pivot-free, in-place
   scheme: it uses Gauss-Jordan elimination with partial (row) pivoting. *)

(* Invert the square matrix [varr], returning a new array of the same kind;
   [varr] itself is only read. Elements are handled as floats.
   Raises [Invalid_argument] if [varr] is not a matrix,
   [Failure "no inverse - the matrix is not square"] for non-square input and
   [Failure "the matrix does not have an inverse"] when some column has no
   non-zero pivot candidate. Thanks to row pivoting, a zero on the diagonal
   (e.g. [[0; 1]; [1; 0]]) no longer makes an invertible matrix fail. *)
let inv varr =
  let dims = shape varr in
  _check_is_matrix dims;
  let n = dims.(0) in
  if dims.(1) <> n then failwith "no inverse - the matrix is not square"
  else begin
    (* [a] starts as a copy of [varr] and is reduced to the identity; the same
       row operations applied to [b] (initially the identity) give the inverse *)
    let a = Array.init n (fun i -> Array.init n (fun j -> get varr [| i; j |])) in
    let b = Array.init n (fun i -> Array.init n (fun j -> if i = j then 1. else 0.)) in
    for p = 0 to n - 1 do
      (* pick the row at or below [p] with the largest |value| in column [p] *)
      let best = ref p in
      for r = p + 1 to n - 1 do
        if abs_float a.(r).(p) > abs_float a.(!best).(p) then best := r
      done;
      if a.(!best).(p) = 0. then failwith "the matrix does not have an inverse";
      if !best <> p then begin
        let swap mat =
          let tmp = mat.(p) in
          mat.(p) <- mat.(!best);
          mat.(!best) <- tmp
        in
        swap a;
        swap b
      end;
      (* scale the pivot row so that a.(p).(p) becomes 1 *)
      let pivot = a.(p).(p) in
      for j = 0 to n - 1 do
        a.(p).(j) <- a.(p).(j) /. pivot;
        b.(p).(j) <- b.(p).(j) /. pivot
      done;
      (* eliminate column [p] from every other row *)
      for i = 0 to n - 1 do
        let f = a.(i).(p) in
        if i <> p && f <> 0. then
          for j = 0 to n - 1 do
            a.(i).(j) <- a.(i).(j) -. (f *. a.(p).(j));
            b.(i).(j) <- b.(i).(j) -. (f *. b.(p).(j))
          done
      done
    done;
    let result_varr = copy varr in
    Array.iteri (fun i row -> Array.iteri (fun j x -> set result_varr [| i; j |] x) row) b;
    result_varr
  end


let logdet _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.logdet")

let qr _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.qr")

let lq _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.lq")

let chol ?(upper = true) _x =
  upper |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.chol")

let svd ?(thin = true) _x =
  thin |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.svd")

let sylvester _a _b _c =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.sylvester")

let lyapunov _a _q =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.lyapunov")

let discrete_lyapunov ?(solver = `default) _a _q =
  solver |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.discrete_lyapunov")

let linsolve ?(trans = false) ?(typ = `n) _a _b =
  trans |> ignore;
  typ |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.linsolve")

let care ?(diag_r = false) _a _b _q _r =
  diag_r |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.care")

let diag ?(k = 0) _x =
  k |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.diag")

let diagm ?(k = 0) _x =
  k |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.diagm")

let tril ?(k = 0) _x =
  k |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.tril")

let triu ?(k = 0) _x =
  k |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.triu")


(* TODO: here k is not used, but neither is it in nonbase dense array?
- investigate *)
let load _k f = Owl_io.marshal_from_file f


(* For each row [i] of matrix [varr], return [(v, i, j)] where [v] is the
   largest element of the row and [j] the column of its first occurrence.
   The search starts from [neg_infinity] (not [min_float], which is the
   smallest *positive* float and made rows of non-positive values report
   column -1). If no element compares greater, e.g. an all-[neg_infinity] or
   all-NaN row, column 0 is reported. A matrix with zero columns yields
   [(neg_infinity, i, -1)] for every row. *)
let max_rows varr =
  let dims = shape varr in
  _check_is_matrix dims;
  let r = dims.(0) in
  let c = dims.(1) in
  Array.init r (fun i ->
      let best = ref neg_infinity in
      let best_pos = ref (-1) in
      for j = 0 to c - 1 do
        let x = get varr [| i; j |] in
        if x > !best then begin
          best := x;
          best_pos := j
        end
      done;
      if !best_pos < 0 && c > 0 then begin
        best := get varr [| i; 0 |];
        best_pos := 0
      end;
      !best, i, !best_pos)


let one_hot _depth _x = failwith "Owl_base_dense_ndarray_generic:one_hot: not implemented"


(* Helper functions *)

let float_to_elt x = x

let elt_to_float x = x

(* ends here *)