open Base
open Hardcaml
include Mul_intf

(* Selects how [Weights.layer] treats a leftover pair of wires at a given weight:
   [Wallace] always combines the pair with a half adder, whereas [Dadda] only does so
   when the count of wires already emitted at that weight calls for it. *)
module Config = struct
  type t =
    | Dadda
    | Wallace
  [@@deriving enumerate, sexp_of]
end

(* Tree multiplier built over any implementation [B] of the [Gen] signature. Adders
   come from [Add.Make_gen (B)]. *)
module Make_gen (B : Gen) = struct
  open B
  module A = Add.Make_gen (B)

  (* A [Weights.t] represents the product of two [B.t]s as a set of weighted single-bit
     wires, where the value of a wire is zero if its bit is zero, or [Int.pow 2 w] if its
     bit is one, where [w] is the weight of the wire. [Weights.t] stores the wires
     indexed by weight. *)
  module Weights : sig
    type t [@@deriving sexp_of]

    (* [create a b] creates the trivially correct product by multiplying every bit
       in [a] by every bit in [b]. *)
    val create : B.t -> B.t -> t

    (* [max_wires_at_any_weight t] is the length of the longest per-weight wire list in
       [t], or [0] if [t] has no weights. *)
    val max_wires_at_any_weight : t -> int

    (* [layer t ~config] does one round of simplification on [t] by combining groups of
       three wires of the same weight using a full adder, and possibly combining leftover
       groups of two wires of the same weight using a half adder, as per [config]. *)
    val layer : t -> config:Config.t -> t

    (* [sum t] should only be called if [max_wires_at_any_weight t <= 2]. [sum]
       constructs two numbers out of the wires in [t] and uses [B.(+:)] to add them. *)
    val sum : t -> target_width:int -> B.t
  end = struct
    (* [t.(w)] holds all wires of weight [w]. *)
    type t = B.bit list array [@@deriving sexp_of]

    let max_wires_at_any_weight t =
      Array.fold t ~init:0 ~f:(fun max wires -> Int.max max (List.length wires))
    ;;

    let create a b =
      let wa = width a in
      let wb = width b in
      (* The heaviest partial product pairs the top bits of both operands:
         [(wa - 1) + (wb - 1)]. *)
      let max_weight = wa + wb - 2 in
      Array.init (max_weight + 1) ~f:(fun w ->
        List.init wa ~f:(fun i ->
          List.init wb ~f:(fun j ->
            if i + j = w then Some (bit a i &: bit b j) else None))
        |> List.concat
        |> List.filter_opt)
    ;;

    let layer (t : t) ~(config : Config.t) =
      let is_final_stage = max_wires_at_any_weight t = 3 in
      (* One extra slot, because a carry out of the top weight lands at [weight + 1]. *)
      let result = Array.create ~len:(Array.length t + 1) [] in
      let add weight wire = result.(weight) <- wire :: result.(weight) in
      Array.iteri t ~f:(fun weight wires ->
        let rec loop wires =
          match wires with
          | [] -> ()
          | [ a ] -> add weight a
          | [ a; b ] ->
            let use_half_adder =
              match config with
              | Wallace -> true
              | Dadda ->
                (* We're trying to make [length result.(weight)] be zero mod 3 after
                   including [a] and [b]. [m] measures the current length mod 3, not
                   including [a] and [b]. So if [m = 2], we use a half adder, which will
                   make the length be zero mod 3 after including the [sum] output of the
                   half adder. *)
                let m = List.length result.(weight) % 3 in
                if is_final_stage then m <> 0 else m = 2
            in
            if use_half_adder
            then (
              let { A.carry; sum } = A.half_adder a b in
              add weight sum;
              add (weight + 1) carry)
            else (
              add weight a;
              add weight b)
          (* The latest wikipedia description does something a bit different for the
             dadda tree multiplier - it reduces each level less aggressively with full
             adders according to a target weight count. I don't see it using less logic
             resources, but it will push partial products further down the tree. *)
          | a :: b :: c :: wires ->
            let { A.carry; sum } = A.full_adder a b c in
            add weight sum;
            add (weight + 1) carry;
            loop wires
        in
        loop wires);
      result
    ;;

    let sum t ~target_width =
      (* Folding from weight 0 upwards while prepending leaves the highest weight at the
         head of each list, which is the order [concat_msb] expects. Missing wires are
         padded with [gnd]. *)
      let a, b =
        Array.fold t ~init:([], []) ~f:(fun (a, b) ab ->
          let a', b' =
            match ab with
            | [ a'; b' ] -> a', b'
            | [ a' ] -> a', gnd
            | [] -> gnd, gnd
            (* More than two wires at one weight violates the precondition of [sum]. *)
            | _ -> assert false
          in
          a' :: a, b' :: b)
      in
      uresize (concat_msb a) target_width +: uresize (concat_msb b) target_width
    ;;
  end

  (* [create config a b] builds the partial products of [a] and [b], applies
     [Weights.layer] until no weight holds more than two wires, and then adds the two
     remaining rows, each resized to [width a + width b]. *)
  let create config a b =
    let rec optimise (weights : Weights.t) =
      let max_wires_at_any_weight = Weights.max_wires_at_any_weight weights in
      if max_wires_at_any_weight <= 2
      then weights
      else optimise (Weights.layer weights ~config)
    in
    Weights.sum (optimise (Weights.create a b)) ~target_width:(width a + width b)
  ;;
end

(* [create_gen ~config (module B) a b] instantiates [Make_gen] with the first-class
   module [B] and multiplies [a] by [b] using it. *)
let create_gen (type a) ~(config : Config.t) (module B : Gen with type t = a) a b =
  let module M = Make_gen (B) in
  M.create config a b
;;

(* Adapts a [Comb.S] implementation to the [Gen] signature by using the signal type
   itself as the single-bit type. *)
module Comb_as_gen (B : Comb.S) : Gen with type t = B.t with type bit = B.t = struct
  include B

  type bit = t [@@deriving sexp_of]
end

(* [create ~config (module B) a b] multiplies [a] by [b] over the [Comb.S]
   implementation [B], via [Comb_as_gen] and [create_gen]. *)
let create (type a) ~config (module B : Comb.S with type t = a) a b =
  create_gen ~config (module Comb_as_gen (B)) a b
;;