(* TODO: use a better queue for the tasks *)

module A = Atomic_
include Runner

let ( let@ ) = ( @@ )

type thread_loop_wrapper =
  thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit
(** A wrapper around a worker thread's main loop. Given the worker thread,
    the pool it belongs to, and the loop function, it returns the function
    that the thread actually runs. *)

(** Wrappers applied to the worker threads of every pool. Each worker
    thread reads this list once, when it starts. *)
let global_thread_wrappers_ : thread_loop_wrapper list A.t = A.make []

(** [add_global_thread_loop_wrapper f] prepends [f] to
    [global_thread_wrappers_]. The compare-and-set retry loop ensures that
    concurrent registrations are not lost. Worker threads that are already
    running are not affected. *)
let add_global_thread_loop_wrapper f : unit =
  while
    let l = A.get global_thread_wrappers_ in
    not (A.compare_and_set global_thread_wrappers_ l (f :: l))
  do
    Domain_.relax ()
  done

type state = {
  active: bool A.t;  (** Set to [false] by [shutdown_] *)
  threads: Thread.t array;  (** Worker threads, indexed by thread index *)
  qs: task Bb_queue.t array;  (** Task queues, shared by all workers *)
  cur_q: int A.t;
      (** Selects queue into which to push (incremented on every push) *)
}
(** internal state *)

(** Run [task] as is, on the pool: push it onto one of the queues, without
    any suspension handling.

    A non-blocking push is attempted on every queue, starting from a
    round-robin offset, for up to 10 rounds. If every attempt fails, this
    falls back to a blocking push on the starting queue.

    @raise Shutdown if the queues have been closed. *)
let run_direct_ (self : state) (task : task) : unit =
  let n_qs = Array.length self.qs in
  (* [cur_q] grows without bound and eventually wraps around to a negative
     value; clearing the sign bit keeps [start] a valid, non-negative index.
     Reducing once here also means [start + i] below cannot overflow. *)
  let start = A.fetch_and_add self.cur_q 1 land max_int mod n_qs in

  (* blocking push, last resort *)
  let[@inline] push_wait f = Bb_queue.push self.qs.(start) f in

  try
    (* try each queue with a round-robin initial offset *)
    for _retry = 1 to 10 do
      for i = 0 to n_qs - 1 do
        let q = self.qs.((start + i) mod n_qs) in
        if Bb_queue.try_push q task then raise_notrace Exit
      done
    done;
    push_wait task
  with
  | Exit -> ()
  | Bb_queue.Closed -> raise Shutdown

(** Run [task] on the pool, wrapped with [Suspend_.with_suspend] so that it
    may suspend. The [~run] callback given to [Suspend_.with_suspend]
    reschedules tasks: through [run_async_] when [with_handler] is [true],
    and directly through [run_direct_] otherwise. *)
let rec run_async_ (self : state) (task : task) : unit =
  let task' () =
    (* run [task ()] and handle [suspend] in it *)
    Suspend_.with_suspend task ~run:(fun ~with_handler task ->
        if with_handler then
          run_async_ self task
        else
          run_direct_ self task)
  in
  run_direct_ self task'

(* alias for [run_async], which comes from the included [Runner] *)
let run = run_async

(** Number of worker threads of the pool. *)
let size_ (self : state) = Array.length self.threads

(** Sum of [Bb_queue.size] over all the task queues. *)
let num_tasks_ (self : state) : int =
  Array.fold_left (fun n q -> n + Bb_queue.size q) 0 self.qs

[@@@ifge 5.0]
(* DLA interop *)

(** Domain-local-await support. [await] suspends the current task and
    stores its continuation; [release] takes the stored continuation (if
    any) and resumes it via [run ~with_handler:true]. *)
let prepare_for_await () : Dla_.t =
  (* current state: the pending continuation, if any *)
  let st :
      ((with_handler:bool -> task -> unit) * Suspend_.suspension) option A.t =
    A.make None
  in

  let release () : unit =
    match A.exchange st None with
    | None -> ()
    | Some (run, k) -> run ~with_handler:true (fun () -> k (Ok ()))
  and await () : unit =
    Suspend_.suspend
      { Suspend_.handle = (fun ~run k -> A.set st (Some (run, k))) }
  in
  { Dla_.release; await }

[@@@else_]

(** Without effects (OCaml < 5.0), both operations are no-ops. *)
let prepare_for_await () = { Dla_.release = ignore; await = ignore }

[@@@endif]

(** Used to leave the queue-scanning loop early with the task found. *)
exception Got_task of task

type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
(** Hooks run around each task: the value returned by the first function
    (before the task) is passed to the second one (after the task). *)

(** Main loop of a worker thread.

    Before each task, the loop checks that [active] is still [true]. To
    find a task, it scans all the queues without blocking, starting from
    the thread's own queue [qs.(offset mod num_qs)]. If every queue is
    empty, it blocks on its own queue. Exceptions raised by a task are
    passed to [on_exn] with their backtrace. The loop ends when [active]
    is [false] or when a blocking pop raises [Bb_queue.Closed].

    @param runner the pool, passed to the [around_task] hooks
    @param offset this thread's index; selects its own queue and scan order *)
let worker_thread_ (runner : t) ~on_exn ~around_task (active : bool A.t)
    (qs : task Bb_queue.t array) ~(offset : int) : unit =
  let num_qs = Array.length qs in
  let (AT_pair (before_task, after_task)) = around_task in
  (* the queue this thread blocks on when there is no work anywhere *)
  let my_q = qs.(offset mod num_qs) in

  (* Get the next task: try every queue without blocking, then block on
     [my_q] as a last resort. *)
  let next_task () : task =
    try
      for i = 0 to num_qs - 1 do
        match Bb_queue.try_pop ~force_lock:false qs.((offset + i) mod num_qs) with
        | Some f -> raise_notrace (Got_task f)
        | None -> ()
      done;
      (* last resort: block on my queue *)
      Bb_queue.pop my_q
    with Got_task f -> f
  in

  let main_loop () =
    while A.get active do
      let task = next_task () in
      let ctx = before_task runner in
      (* run the task now, catching errors *)
      (try task ()
       with e ->
         let bt = Printexc.get_raw_backtrace () in
         on_exn e bt);
      after_task runner ctx
    done
  in

  try
    (* handle domain-local await *)
    Dla_.using ~prepare_for_await ~while_running:main_loop
  with Bb_queue.Closed -> ()

(** Default [on_init_thread]/[on_exit_thread] hook: does nothing. *)
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()

(** We want a reasonable number of queues. Even if your system is a beast
    with hundreds of cores, trying to work-steal through hundreds of queues
    will have a cost.

    Hence, we limit the number of queues to at most 32 (number picked
    via the ancestral technique of the pifomètre, i.e. a rough guess). *)
let max_queues = 32

(** Stop the pool. The first call marks it inactive and closes the job
    queues; later calls do not close them again.

    @param wait if [true], join all the worker threads before returning. *)
let shutdown_ ~wait (self : state) : unit =
  let was_active = A.exchange self.active false in
  (* close the job queues, which will fail future calls to [run],
     and wake up the subset of [self.threads] that are waiting on them. *)
  if was_active then Array.iter Bb_queue.close self.qs;
  if wait then Array.iter Thread.join self.threads

type ('a, 'b) create_args =
  ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
  ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
  ?thread_wrappers:thread_loop_wrapper list ->
  ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
  ?around_task:(t -> 'b) * (t -> 'b -> unit) ->
  ?min:int ->
  ?per_domain:int ->
  'a
(** Arguments used in {!create}. See {!create} for explanations. *)

(** Create a new pool and start its worker threads.

    The number of threads is [max min (num_domains * per_domain)], where
    [min] is clamped to be at least 1. Threads are placed on the domains of
    [D_pool_] round-robin, starting from a random domain. The number of
    queues is the smallest of the number of domains, the number of
    threads, and {!max_queues}.

    @param on_init_thread called in each worker thread when it starts
    @param on_exit_thread called in each worker thread after its loop
      returns normally
    @param thread_wrappers wrappers for this pool's worker loops, combined
      with the global ones ({!add_global_thread_loop_wrapper})
    @param on_exn called with any exception raised by a task, and its
      backtrace
    @param around_task hooks run before and after each task
    @param min minimum number of threads (default 1)
    @param per_domain number of threads per domain (default 0) *)
let create ?(on_init_thread = default_thread_init_exit_)
    ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = [])
    ?(on_exn = fun _ _ -> ()) ?around_task ?min:(min_threads = 1)
    ?(per_domain = 0) () : t =
  (* wrapper *)
  let around_task =
    match around_task with
    | Some (f, g) -> AT_pair (f, g)
    | None -> AT_pair (ignore, fun _ _ -> ())
  in

  (* number of threads to run *)
  let min_threads = max 1 min_threads in
  let num_domains = D_pool_.n_domains () in
  assert (num_domains >= 1);
  let num_threads = max min_threads (num_domains * per_domain) in

  (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
  let offset = Random.int num_domains in

  let active = A.make true in
  let qs =
    let num_qs = min (min num_domains num_threads) max_queues in
    Array.init num_qs (fun _ -> Bb_queue.create ())
  in

  let pool =
    (* placeholder, overwritten below once the real threads are received *)
    let dummy = Thread.self () in
    { active; threads = Array.make num_threads dummy; qs; cur_q = A.make 0 }
  in

  let runner =
    Runner.For_runner_implementors.create
      ~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
      ~run_async:(fun f -> run_async_ pool f)
      ~size:(fun () -> size_ pool)
      ~num_tasks:(fun () -> num_tasks_ pool)
      ()
  in

  (* temporary queue used to obtain thread handles from domains
     on which the threads are started. *)
  let receive_threads = Bb_queue.create () in

  (* start the thread with index [i] *)
  let start_thread_with_idx i =
    let dom_idx = (offset + i) mod num_domains in

    (* function run in the thread itself *)
    let main_thread_fun () : unit =
      let thread = Thread.self () in
      let t_id = Thread.id thread in
      (* Everything from [on_init_thread] to the end of the worker loop is
         protected, so that the domain's refcount is released even if
         [on_init_thread] or one of the wrappers raises. *)
      Fun.protect
        ~finally:(fun () ->
          (* on termination, decrease refcount of underlying domain *)
          D_pool_.decr_on dom_idx)
        (fun () ->
          on_init_thread ~dom_id:dom_idx ~t_id ();

          let all_wrappers =
            List.rev_append thread_wrappers (A.get global_thread_wrappers_)
          in

          let run () =
            worker_thread_ runner ~on_exn ~around_task active qs ~offset:i
          in
          (* the actual worker loop is [worker_thread_], with all
             wrappers for this pool and for all pools (global_thread_wrappers_) *)
          let run' =
            List.fold_left
              (fun run f -> f ~thread ~pool:runner run)
              run all_wrappers
          in

          (* now run the main loop *)
          run' ());
      on_exit_thread ~dom_id:dom_idx ~t_id ()
    in

    (* function called in domain with index [dom_idx], to
       create the thread and push it into [receive_threads] *)
    let create_thread_in_domain () =
      let thread = Thread.create main_thread_fun () in
      (* send the thread from the domain back to us *)
      Bb_queue.push receive_threads (i, thread)
    in

    D_pool_.run_on dom_idx create_thread_in_domain
  in

  (* start all threads, placing them on the domains
     according to their index and [offset] in a round-robin fashion. *)
  for i = 0 to num_threads - 1 do
    start_thread_with_idx i
  done;

  (* receive the newly created threads back from domains *)
  for _j = 1 to num_threads do
    let i, th = Bb_queue.pop receive_threads in
    pool.threads.(i) <- th
  done;

  runner

(** [with_ … () f] creates a pool with the same arguments as {!create},
    calls [f] on it, and shuts the pool down when [f] returns or raises. *)
let with_ ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn ?around_task
    ?min ?per_domain () f =
  let pool =
    create ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn
      ?around_task ?min ?per_domain ()
  in
  let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
  f pool