# 1 "src/base/linalg/owl_base_linalg_generic.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

type ('a, 'b) t = ('a, 'b) Owl_base_dense_ndarray_generic.t

module M = Owl_base_dense_ndarray_generic

(* Check matrix properties *)

(* [_for_all_range lo hi p] is [true] iff [p i] holds for every index
 * [lo <= i <= hi]. It stops at the first failing index and is vacuously
 * [true] when [lo > hi].
 *
 * The predicates below use it instead of [assert] inside [try ... with],
 * because assertions are compiled out under [-noassert], which would make
 * every predicate return [true]. *)
let _for_all_range lo hi p =
  let rec go i = i > hi || (p i && go (i + 1)) in
  go lo


(* [is_triu x] is [true] iff every entry strictly below the main diagonal of
 * the matrix [x] equals the zero of [x]'s kind.
 *
 * All [m] rows are inspected, so for a tall matrix ([m > n]) the rows below
 * the leading square block must be entirely zero.
 *
 * NOTE(review): [x] is expected to be 2-dimensional; errors raised by
 * [M.get] are no longer swallowed and turned into [false]. *)
let is_triu x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  let _a0 = Owl_const.zero (M.kind x) in
  _for_all_range 0 (m - 1) (fun i ->
      (* columns left of the diagonal in row [i], clipped to the matrix width *)
      _for_all_range 0 (Stdlib.min i n - 1) (fun j -> M.get x [| i; j |] = _a0))


(* [is_tril x] is [true] iff every entry strictly above the main diagonal of
 * the matrix [x] equals the zero of [x]'s kind.
 *
 * All [n] columns are inspected, so for a wide matrix ([n > m]) the columns
 * right of the leading square block must be entirely zero. *)
let is_tril x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  let _a0 = Owl_const.zero (M.kind x) in
  _for_all_range 0 (m - 1) (fun i ->
      _for_all_range (i + 1) (n - 1) (fun j -> M.get x [| i; j |] = _a0))


(* [is_symmetric x] is [true] iff [x] is square and [x.(j).(i) = x.(i).(j)]
 * for every pair [i < j]. Non-square matrices are never symmetric. *)
let is_symmetric x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  m = n
  && _for_all_range 0 (n - 1) (fun i ->
         _for_all_range (i + 1) (n - 1) (fun j ->
             M.get x [| j; i |] = M.get x [| i; j |]))


(* [is_hermitian x] is [true] iff the complex matrix [x] is square and
 * [x.(j).(i)] equals the complex conjugate of [x.(i).(j)] for every
 * [i <= j]. The diagonal is included, so diagonal entries must equal their
 * own conjugate. *)
let is_hermitian x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  m = n
  && _for_all_range 0 (n - 1) (fun i ->
         _for_all_range i (n - 1) (fun j ->
             M.get x [| j; i |] = Complex.conj (M.get x [| i; j |])))


(* [is_diag x] is [true] iff [x] is both upper and lower triangular. *)
let is_diag x = is_triu x && is_tril x


(* [_check_is_matrix dims] raises [Invalid_argument] unless the shape array
 * [dims] has exactly two entries; otherwise it returns [()]. *)
let _check_is_matrix dims =
  if Array.length dims <> 2
  then raise (Invalid_argument "The given NDarray is not a matrix!")


(* ======= WARNING: the linalg functions below are experimental. ======= *)
(* ========= Corner cases etc. are not sufficiently tested. ============ *)

(* Linear equation solution by Gauss-Jordan elimination.
* Input matrix: a[n][n], b[n][m];
* Output: ``ainv``, the inverse of a; ``x``, such that ax = b.
* TODO: Extend to multiple types: double, complex; unify with existing owl
* structures e.g. naming.
* Test: https://github.com/scipy/scipy/blob/master/scipy/linalg/tests/test_basic.py#L496 *)

(* [linsolve_gauss a b] returns the pair [(ainv, x)] computed by Gauss-Jordan
 * elimination with full pivoting, working on copies of [a] and [b] (the
 * arguments themselves are not modified).
 *
 * Raises [Invalid_argument] if [a] or [b] is not 2-dimensional, if [a] is not
 * square, or if [b] does not have as many rows as [a].
 * Raises [Owl_exception.SINGULAR] if a selected pivot is exactly [0.0]. *)
let linsolve_gauss a b =
  let dims_a, dims_b = M.shape a, M.shape b in
  _check_is_matrix dims_a;
  _check_is_matrix dims_b;
  let n = dims_a.(0) in
  let m = dims_b.(1) in
  (* Reject shapes the elimination below cannot handle correctly, instead of
   * failing with a raw index error or silently using a sub-block. *)
  if dims_a.(1) <> n
  then raise (Invalid_argument "linsolve_gauss: a must be a square matrix");
  if dims_b.(0) <> n
  then
    raise
      (Invalid_argument "linsolve_gauss: a and b must have the same number of rows");
  let a = M.copy a in
  let b = M.copy b in
  (* position of the pivot chosen in the current outer iteration *)
  let icol = ref 0 in
  let irow = ref 0 in
  (* bookkeeping for the pivoting: row/column of the i-th pivot, and how many
   * times each column has been used as a pivot column *)
  let indxc = Array.make n 0 in
  let indxr = Array.make n 0 in
  let ipiv = Array.make n 0 in
  (* exchange rows [r1] and [r2] of [mat] over its first [ncols] columns *)
  let swap_rows mat ncols r1 r2 =
    for l = 0 to ncols - 1 do
      let u = M.get mat [| r1; l |] in
      let v = M.get mat [| r2; l |] in
      M.set mat [| r2; l |] u;
      M.set mat [| r1; l |] v
    done
  in
  (* multiply row [r] of [mat] by [factor] over its first [ncols] columns *)
  let scale_row mat ncols r factor =
    for l = 0 to ncols - 1 do
      let prev = M.get mat [| r; l |] in
      M.set mat [| r; l |] (prev *. factor)
    done
  in
  (* row [dst] <- row [dst] - [factor] * row [src], over [ncols] columns *)
  let subtract_row mat ncols dst src factor =
    for l = 0 to ncols - 1 do
      let p = M.get mat [| src; l |] in
      let prev = M.get mat [| dst; l |] in
      M.set mat [| dst; l |] (prev -. (p *. factor))
    done
  in
  (* Main loop over the columns to be reduced. *)
  for i = 0 to n - 1 do
    let big = ref 0.0 in
    (* Search for a pivot element among the rows and columns not yet used. *)
    for j = 0 to n - 1 do
      if ipiv.(j) <> 1
      then (
        for k = 0 to n - 1 do
          if ipiv.(k) = 0
          then (
            let v = abs_float (M.get a [| j; k |]) in
            if v >= !big
            then (
              big := v;
              irow := j;
              icol := k))
        done)
    done;
    ipiv.(!icol) <- ipiv.(!icol) + 1;
    (* Bring the pivot onto the diagonal by exchanging rows of both a and b. *)
    if !irow <> !icol
    then (
      swap_rows a n !irow !icol;
      swap_rows b m !irow !icol);
    indxr.(i) <- !irow;
    indxc.(i) <- !icol;
    let pc = !icol in
    let p = M.get a [| pc; pc |] in
    if p = 0.0 then raise Owl_exception.SINGULAR;
    let pivinv = 1.0 /. p in
    (* Setting the pivot to 1 before scaling leaves 1/p in its place, which is
     * how the inverse is accumulated in the storage of a. *)
    M.set a [| pc; pc |] 1.0;
    scale_row a n pc pivinv;
    scale_row b m pc pivinv;
    (* Reduce every other row using the pivot row. *)
    for ll = 0 to n - 1 do
      if ll <> pc
      then (
        let dum = M.get a [| ll; pc |] in
        M.set a [| ll; pc |] 0.0;
        subtract_row a n ll pc dum;
        subtract_row b m ll pc dum)
    done
  done;
  (* Undo the column permutation implied by the row exchanges, in reverse
   * order of application. *)
  for l = n - 1 downto 0 do
    let r = indxr.(l) in
    let c = indxc.(l) in
    if r <> c
    then (
      for k = 0 to n - 1 do
        let u = M.get a [| k; r |] in
        let v = M.get a [| k; c |] in
        M.set a [| k; c |] u;
        M.set a [| k; r |] v
      done)
  done;
  a, b


(* LU decomposition.
* Input matrix: a[n][n]; returns L and U packed in one matrix, the row permutation vector, and the sign of the permutation.
* Test: https://github.com/scipy/scipy/blob/master/scipy/linalg/tests/test_decomp.py
*)

(* [_lu_base a] returns [(lu, indx, d)] for the square matrix [a]:
 * - [lu] is a copy of [a] overwritten with the factors: the multipliers of L
 *   below the diagonal (unit diagonal implied) and U on and above it;
 * - [indx.(k)] is the row that was exchanged with row [k] at step [k];
 * - [d] is [1.0] or [-1.0] depending on the parity of the row exchanges.
 *
 * Fails with [Assert_failure] if [a] is not square, and raises
 * [Owl_exception.SINGULAR] if some row of [a] is entirely zero. A zero pivot
 * met during elimination is replaced by a tiny value rather than reported. *)
let _lu_base a =
  let lu = M.copy a in
  let shp = M.shape a in
  let n = shp.(0) in
  let m = shp.(1) in
  assert (n = m);
  let indx = Array.make n 0 in
  (* implicit scaling of each row *)
  let vv = Array.make n 0. in
  let tiny = 1.0e-40 in
  (* flag of row exchange *)
  let d = ref 1.0 in
  (* loop over rows to get the implicit scaling information *)
  for i = 0 to n - 1 do
    let big = ref 0. in
    for j = 0 to n - 1 do
      let v = abs_float (M.get lu [| i; j |]) in
      if v > !big then big := v
    done;
    if !big = 0. then raise Owl_exception.SINGULAR;
    vv.(i) <- 1.0 /. !big
  done;
  for k = 0 to n - 1 do
    (* choose suitable pivot. [imax] starts at [k] for every column: if the
     * whole column is zero no candidate beats [big], and a stale index from
     * the previous column must not be reused, as it could point at a row
     * that has already been eliminated. *)
    let big = ref 0. in
    let imax = ref k in
    for i = k to n - 1 do
      let v = abs_float (M.get lu [| i; k |]) *. vv.(i) in
      if v > !big
      then (
        big := v;
        imax := i)
    done;
    let p = !imax in
    (* interchange rows *)
    if k <> p
    then (
      for j = 0 to n - 1 do
        let upper = M.get lu [| k; j |] in
        let lower = M.get lu [| p; j |] in
        M.set lu [| p; j |] upper;
        M.set lu [| k; j |] lower
      done;
      d := -. !d;
      vv.(p) <- vv.(k));
    indx.(k) <- p;
    (* avoid dividing by an exact zero pivot *)
    if M.get lu [| k; k |] = 0. then M.set lu [| k; k |] tiny;
    let pivot = M.get lu [| k; k |] in
    for i = k + 1 to n - 1 do
      let factor = M.get lu [| i; k |] /. pivot in
      M.set lu [| i; k |] factor;
      for j = k + 1 to n - 1 do
        let prev = M.get lu [| i; j |] in
        M.set lu [| i; j |] (prev -. (factor *. M.get lu [| k; j |]))
      done
    done
  done;
  lu, indx, !d


(* LU decomposition, return L, U, and permutation vector.
 *
 * [lu a] splits the packed result of [_lu_base a] into a unit lower
 * triangular matrix and an upper triangular matrix, and returns them together
 * with the row-exchange vector produced by [_lu_base].
 * Fails with [Assert_failure] unless the matrix is square with at least two
 * rows. *)
let lu a =
  let k = M.kind a in
  let lu, indx, _ = _lu_base a in
  let shp = M.shape lu in
  let n = shp.(0) in
  let m = shp.(1) in
  assert (n = m && n >= 2);
  (* L starts as the identity; its strict lower part is moved out of [lu],
   * which then holds U only. *)
  let l = M.eye k n in
  for r = 1 to n - 1 do
    for c = 0 to r - 1 do
      M.set l [| r; c |] (M.get lu [| r; c |]);
      M.set lu [| r; c |] 0.
    done
  done;
  l, lu, indx


(* [_lu_solve_vec a b] solves [a x = b] for the 1-dimensional right-hand side
 * [b] by LU-factorising [a] and performing forward and back substitution on a
 * copy of [b], which is returned.
 *
 * Fails with [Assert_failure] if [b] is not 1-dimensional and with
 * [Failure "LUdcmp::solve bad sizes"] if its length differs from the number
 * of rows of [a]. Errors of [_lu_base] propagate. *)
let _lu_solve_vec a b =
  assert (Array.length (M.shape b) = 1);
  let n = (M.shape a).(0) in
  if (M.shape b).(0) <> n then failwith "LUdcmp::solve bad sizes";
  (* [ii] is 0 until the first non-zero entry of the permuted right-hand side
   * is met, then it is that entry's index plus one; the forward substitution
   * skips the leading zeros. *)
  let ii = ref 0 in
  let x = M.copy b in
  let lu, indx, _ = _lu_base a in
  (* forward substitution, undoing the row exchanges on the fly *)
  for i = 0 to n - 1 do
    let ip = indx.(i) in
    let sum = ref (M.get x [| ip |]) in
    M.set x [| ip |] (M.get x [| i |]);
    if !ii <> 0
    then (
      for j = !ii - 1 to i - 1 do
        sum := !sum -. (M.get lu [| i; j |] *. M.get x [| j |])
      done)
    else if !sum <> 0.
    then ii := i + 1;
    M.set x [| i |] !sum
  done;
  (* back substitution *)
  for i = n - 1 downto 0 do
    let sum = ref (M.get x [| i |]) in
    for j = i + 1 to n - 1 do
      sum := !sum -. (M.get lu [| i; j |] *. M.get x [| j |])
    done;
    M.set x [| i |] (!sum /. M.get lu [| i; i |])
  done;
  x


(* Linear equation solution by LU decomposition.
* Input matrix: a[n][n], b[n][m];
* Output: ``x``, so that ax = b. *)

(* [_lu_substitute_col lu indx n x col] overwrites column [col] of the matrix
 * [x] with the solution of the system whose packed LU factors and
 * row-exchange vector are [lu] and [indx] (as returned by [_lu_base]), using
 * the current content of that column as the right-hand side. [n] is the
 * order of the system. *)
let _lu_substitute_col lu indx n x col =
  (* 0 until the first non-zero right-hand-side entry, then its index + 1 *)
  let ii = ref 0 in
  (* forward substitution, undoing the row exchanges on the fly *)
  for i = 0 to n - 1 do
    let ip = indx.(i) in
    let sum = ref (M.get x [| ip; col |]) in
    M.set x [| ip; col |] (M.get x [| i; col |]);
    if !ii <> 0
    then (
      for j = !ii - 1 to i - 1 do
        sum := !sum -. (M.get lu [| i; j |] *. M.get x [| j; col |])
      done)
    else if !sum <> 0.
    then ii := i + 1;
    M.set x [| i; col |] !sum
  done;
  (* back substitution *)
  for i = n - 1 downto 0 do
    let sum = ref (M.get x [| i; col |]) in
    for j = i + 1 to n - 1 do
      sum := !sum -. (M.get lu [| i; j |] *. M.get x [| j; col |])
    done;
    M.set x [| i; col |] (!sum /. M.get lu [| i; i |])
  done


(* [linsolve_lu a b] returns a new matrix [x] such that [a x = b], solving one
 * column of [b] at a time.
 *
 * [a] is factorised once and the factors are reused for every column, instead
 * of being recomputed per column.
 *
 * Raises [Invalid_argument] if [a] or [b] is not 2-dimensional, fails with
 * [Assert_failure] if [a] is not square, and with
 * [Failure "LUdcmp::solve bad sizes"] if [b] has at least one column and its
 * row count differs from that of [a]. Errors of [_lu_base] propagate. *)
let linsolve_lu a b =
  let dims_a, dims_b = M.shape a, M.shape b in
  _check_is_matrix dims_a;
  _check_is_matrix dims_b;
  assert (dims_a.(0) = dims_a.(1));
  let n = dims_a.(0) in
  let m = dims_b.(1) in
  let x = M.copy b in
  (* With no column to solve, nothing is validated or factorised. *)
  if m > 0
  then (
    if dims_b.(0) <> n then failwith "LUdcmp::solve bad sizes";
    let lu, indx, _ = _lu_base a in
    for j = 0 to m - 1 do
      _lu_substitute_col lu indx n x j
    done);
  x


(* Matrix inverse.
 *
 * [inv a] solves [a x = I] with [linsolve_lu], where [I] is the identity of
 * the same kind and order as [a].
 * Raises [Invalid_argument] if [a] is not 2-dimensional and fails with
 * [Assert_failure] if it is not square. *)
let inv a =
  let dims_a = M.shape a in
  _check_is_matrix dims_a;
  assert (dims_a.(0) = dims_a.(1));
  let identity = M.eye (M.kind a) dims_a.(0) in
  linsolve_lu a identity


(* Determinant of matrix a.
 *
 * [det a] is the product of the diagonal of the LU factors of [a] times the
 * sign of the row permutation. When [_lu_base] reports an all-zero row by
 * raising [Owl_exception.SINGULAR], the determinant is [0.] and that value is
 * returned instead of propagating the exception.
 * Raises [Invalid_argument] if [a] is not 2-dimensional and fails with
 * [Assert_failure] if it is not square. *)
let det a =
  let dims_a = M.shape a in
  _check_is_matrix dims_a;
  assert (dims_a.(0) = dims_a.(1));
  let n = dims_a.(0) in
  match _lu_base a with
  | exception Owl_exception.SINGULAR -> 0.
  | lu, _, sign ->
    let acc = ref sign in
    for i = 0 to n - 1 do
      acc := !acc *. M.get lu [| i; i |]
    done;
    !acc


(* Solver for tridiagonal matrix
* Input: a[n], b[n], c[n], which together constitute the tridiagonal matrix A (sub-diagonal, diagonal and super-diagonal), and the right-hand side vector r[n]. Return: x[n].
*)

(* [tridiag_solve_vec a b c r] solves the tridiagonal system whose
 * sub-diagonal, diagonal and super-diagonal are [a], [b] and [c] (all of
 * length [n]; [a.(0)] and [c.(n - 1)] are not read) for the right-hand side
 * [r], and returns the solution as a fresh array of length [n].
 *
 * An empty system yields an empty solution. [r] must hold at least [n]
 * entries; entries beyond the first [n] are ignored.
 *
 * Fails with [Assert_failure] if [a], [b] and [c] differ in length.
 * Raises [Invalid_argument] if [r] is shorter than [n], if [b.(0)] is zero,
 * or if a zero pivot appears during the elimination (no pivoting is done). *)
let tridiag_solve_vec a b c r =
  let n = Array.length a in
  let n1 = Array.length b in
  let n2 = Array.length c in
  assert (n = n1 && n = n2);
  if Array.length r < n
  then
    raise
      (Invalid_argument "tridiag_solve_vec: right-hand side is shorter than the diagonals");
  if n = 0
  then [||]
  else (
    if b.(0) = 0.
    then
      raise
        (Invalid_argument "tridiag_solve_vec: 0 at the beginning of diagonal vector");
    (* current pivot of the forward sweep *)
    let bet = ref b.(0) in
    (* scaled super-diagonal, kept for the back substitution *)
    let gam = Array.make n 0. in
    let x = Array.make n 0. in
    x.(0) <- r.(0) /. !bet;
    (* forward elimination *)
    for j = 1 to n - 1 do
      gam.(j) <- c.(j - 1) /. !bet;
      bet := b.(j) -. (a.(j) *. gam.(j));
      if !bet = 0. then raise (Invalid_argument "tridiag_solve_vec: algorithm fails");
      x.(j) <- (r.(j) -. (a.(j) *. x.(j - 1))) /. !bet
    done;
    (* back substitution *)
    for j = n - 2 downto 0 do
      x.(j) <- x.(j) -. (gam.(j + 1) *. x.(j + 1))
    done;
    x)

(* ends here *)