Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/dune_rpc_impl/server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,16 @@ module Clients = struct

type t = entry Session.Id.Map.t

let set_menu t id menu =
match Session.Id.Map.find t id with
| None -> ()
| Some entry -> entry.menu <- Some menu

let empty = Session.Id.Map.empty

let add_session t (session : _ Session.Stage1.t) =
let id = Session.Stage1.id session in
let result = { menu = None; session } in
Session.Stage1.register_upgrade_callback session (fun menu ->
result.menu <- Some menu);
Session.Id.Map.add_exn t id result

let remove_session t (session : _ Session.t) =
Expand Down Expand Up @@ -230,13 +233,18 @@ let handler (t : t Fdecl.t) : 'a Dune_rpc_server.Handler.t =
t.clients <- Clients.add_session t.clients session;
Fiber.return client
in
let on_upgrade session menu =
let+ () = Fiber.return () in
let t = Fdecl.get t in
Clients.set_menu t.clients (Session.id session) menu
in
let on_terminate session =
let t = Fdecl.get t in
t.clients <- Clients.remove_session t.clients session;
Fiber.return ()
in
let rpc =
Handler.create ~on_terminate ~on_init
Handler.create ~on_terminate ~on_init ~on_upgrade
~version:Dune_rpc_private.Version.latest ()
in
let () =
Expand Down
40 changes: 23 additions & 17 deletions src/dune_rpc_server/dune_rpc_server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ module Session = struct
; send : Packet.t list option -> unit Fiber.t
; pool : Fiber.Pool.t
; mutable state : 'a state
; mutable on_upgrade : (Menu.t -> unit) option
}

let set t state =
Expand All @@ -120,7 +119,6 @@ module Session = struct
; send
; state = Uninitialized (Close.create ())
; id = Id.gen ()
; on_upgrade = None
; pool = Fiber.Pool.create ()
}

Expand Down Expand Up @@ -150,12 +148,9 @@ module Session = struct
in
variant "Initialized" [ record ]

let to_dyn f { id; state; queries = _; send = _; on_upgrade = _; pool = _ }
=
let to_dyn f { id; state; queries = _; send = _; pool = _ } =
let open Dyn in
record [ ("id", Id.to_dyn id); ("state", dyn_of_state f state) ]

let register_upgrade_callback t f = t.on_upgrade <- Some f
end

type 'a t =
Expand Down Expand Up @@ -187,8 +182,7 @@ module Session = struct

let id t = t.base.id

let of_stage1 (base : _ Stage1.t) handler menu =
let () = Option.iter base.on_upgrade ~f:(fun f -> f menu) in
let of_stage1 (base : _ Stage1.t) handler =
{ base; handler; pollers = Dune_rpc_private.Id.Map.empty }

let notification t decl n =
Expand Down Expand Up @@ -279,6 +273,7 @@ module H = struct
type 'a base =
{ on_init : 'a Session.Stage1.t -> Initialize.Request.t -> 'a Fiber.t
; on_terminate : 'a Session.t -> unit Fiber.t
; on_upgrade : 'a Session.t -> Menu.t -> unit Fiber.t
; version : int * int
}

Expand Down Expand Up @@ -410,18 +405,19 @@ module H = struct
Menu.select_common ~remote_versions:client_versions
~local_versions:t.known_versions
with
| None ->
abort session
~message:"Server and client have no method versions in common"
| Some menu ->
let response =
Version_negotiation.(
Conv.to_sexp Response.sexp (Response.create (Menu.to_list menu)))
in
let* () = session.send (Some [ Response (id, Ok response) ]) in
let handler = t.to_handler menu in
run_session { base = t.base; handler } stats
(Session.of_stage1 session handler menu)
| None ->
abort session
~message:"Server and client have no method versions in common")))
let session = Session.of_stage1 session handler in
let* () = t.base.on_upgrade session menu in
run_session { base = t.base; handler } stats session)))

let handle (type a) (t : a stage1) stats (session : a Session.Stage1.t) =
let open Fiber.O in
Expand Down Expand Up @@ -468,10 +464,11 @@ module H = struct
{ builder : 's Session.t V.Builder.t
; on_terminate : 's Session.t -> unit Fiber.t
; on_init : 's Session.Stage1.t -> Initialize.Request.t -> 's Fiber.t
; on_upgrade : 's Session.t -> Menu.t -> unit Fiber.t
; version : int * int
}

let to_handler { builder; on_terminate; on_init; version } =
let to_handler { builder; on_terminate; on_init; version; on_upgrade } =
let to_handler menu =
V.Builder.to_handler builder
~session_version:(fun s -> (Session.initialize s).dune_version)
Expand All @@ -482,10 +479,19 @@ module H = struct
|> String.Map.of_list_map_exn ~f:(fun (name, gens) ->
(name, Int.Set.of_list gens))
in
{ to_handler; base = { on_init; on_terminate; version }; known_versions }
{ to_handler
; base = { on_init; on_terminate; on_upgrade; version }
; known_versions
}

let create ?(on_terminate = fun _ -> Fiber.return ()) ~on_init ~version () =
{ builder = V.Builder.create (); on_init; on_terminate; version }
let create ?(on_terminate = fun _ -> Fiber.return ()) ~on_init
?(on_upgrade = fun _ _ -> Fiber.return ()) ~version () =
{ builder = V.Builder.create ()
; on_init
; on_terminate
; version
; on_upgrade
}

let implement_request (t : _ t) = V.Builder.implement_request t.builder

Expand Down
9 changes: 2 additions & 7 deletions src/dune_rpc_server/dune_rpc_server.mli
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,6 @@ module Session : sig
val request_close : 'a t -> unit Fiber.t

val to_dyn : ('a -> Dyn.t) -> 'a t -> Dyn.t

(** Register a callback to be called once version negotiation has concluded.
At most one callback can be set at once; calling this function multiple
times for the same session will override previous invocations.

The registered callback is guaranteed to be called at most once. *)
val register_upgrade_callback : _ t -> (Menu.t -> unit) -> unit
end
end

Expand All @@ -89,6 +82,8 @@ module Handler : sig
(** Initiation hook. It's guaranteed to be called before any
requests/notifications. It's job is to initialize the session
state. *)
-> ?on_upgrade:('a Session.t -> Menu.t -> unit Fiber.t)
(** called immediately after the client has finished negotitation *)
-> version:int * int
(** version of the rpc. it's expected to support all earlier versions *)
-> unit
Expand Down