diff --git a/Makefile b/Makefile index 0c140e18..adf95352 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ clean: @dune clean test: - @dune runtest $(DUNE_OPTS) + @dune runtest $(DUNE_OPTS) --no-buffer test-autopromote: @dune runtest $(DUNE_OPTS) --auto-promote diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index 7581636f..e10529d9 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -6,18 +6,7 @@ include Runner let ( let@ ) = ( @@ ) -module Id = struct - type t = unit ref - (** Unique identifier for a pool *) - - let create () : t = Sys.opaque_identity (ref ()) - let equal : t -> t -> bool = ( == ) -end - type state = { - id_: Id.t; - (** Unique to this pool. Used to make sure tasks stay within the same - pool. *) active: bool A.t; (** Becomes [false] when the pool is shutdown. *) mutable workers: worker_state array; (** Fixed set of workers. *) main_q: WL.task_full Queue.t; @@ -99,12 +88,15 @@ let schedule_in_main_queue (self : state) task : unit = longer permitted *) raise Shutdown -let schedule_from_w (self : worker_state) (task : WL.task_full) : unit = +let schedule_from_anywhere_ (st : state) (task : WL.task_full) : unit = match get_current_worker_ () with - | Some w when Id.equal self.st.id_ w.st.id_ -> + | Some w when st == w.st -> (* use worker from the same pool *) schedule_on_current_worker w task - | _ -> schedule_in_main_queue self.st task + | _ -> schedule_in_main_queue st task + +let schedule_from_w (w : worker_state) task : unit = + schedule_from_anywhere_ w.st task exception Got_task of WL.task_full @@ -223,7 +215,8 @@ let as_runner_ (self : state) : t = Runner.For_runner_implementors.create ~shutdown:(fun ~wait () -> shutdown_ self ~wait) ~run_async:(fun ~fiber f -> - schedule_in_main_queue self @@ T_start { fiber; f }) + let task = WL.T_start { fiber; f } in + schedule_from_anywhere_ self task) ~size:(fun () -> size_ self) ~num_tasks:(fun () -> num_tasks_ self) () @@ -240,7 +233,6 @@ type ('a, 'b) create_args = let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?num_threads ?name () : t = - let pool_id_ = Id.create () in let num_domains = Domain_pool_.max_number_of_domains () in let num_threads = Util_pool_.num_threads ?num_threads () in @@ -249,7 +241,6 @@ let create ?(on_init_thread = default_thread_init_exit_) let pool = { - id_ = pool_id_; active = A.make true; workers = [||]; main_q = Queue.create (); diff --git a/test/dune b/test/dune index 38b6a9c8..0e2a295c 100644 --- a/test/dune +++ b/test/dune @@ -20,3 +20,10 @@ unix trace-tef trace)) + +(test + (name t_fib_await_mem) + (package moonpool) + (enabled_if + (= %{system} linux)) + (libraries moonpool)) diff --git a/test/t_fib_await_mem.ml b/test/t_fib_await_mem.ml new file mode 100644 index 00000000..f255fdac --- /dev/null +++ b/test/t_fib_await_mem.ml @@ -0,0 +1,54 @@ +(* regression test for #45 *) + +open Moonpool + +let ( let@ ) = ( @@ ) + +let rec fib_direct x = + if x <= 1 then + 1 + else + fib_direct (x - 1) + fib_direct (x - 2) + +let cutoff = 8 + +let rec fib_await ~on x : int Fut.t = + if x <= cutoff then + Fut.spawn ~on (fun () -> fib_direct x) + else + Fut.spawn ~on (fun () -> + let n1 = fib_await ~on (x - 1) in + let n2 = fib_await ~on (x - 2) in + let n1 = Fut.await n1 in + let n2 = Fut.await n2 in + n1 + n2) + +(** Read VmHWM (peak RSS in kB) from /proc/self/status. *) +let get_vmhwm_kb () : int option = + let path = "/proc/self/status" in + match In_channel.with_open_bin path In_channel.input_all with + | exception Sys_error _ -> None + | content -> + let lines = String.split_on_char '\n' content in + List.find_map + (fun line -> Scanf.sscanf_opt line "VmHWM: %d kB" Fun.id) + lines + +let max_rss_bytes = 150_000_000 + +let () = + let@ pool = Ws_pool.with_ ~num_threads:4 () in + let result = fib_await ~on:pool 40 |> Fut.wait_block_exn in + assert (result = 165580141); + match get_vmhwm_kb () with + | None -> + Printf.printf "fib 40 = %d (skip RSS check: no /proc/self/status)\n%!" + result + | Some hwm_kb -> + let hwm_bytes = hwm_kb * 1024 in + Printf.printf "fib 40 = %d, peak RSS = %d bytes\n%!" result hwm_bytes; + if hwm_bytes > max_rss_bytes then ( + Printf.eprintf "FAIL: peak RSS %d bytes exceeds limit %d bytes\n%!" + hwm_bytes max_rss_bytes; + exit 1 + )