diff --git a/app/vmmd.ml b/app/vmmd.ml index 07d8a88..b756504 100644 --- a/app/vmmd.ml +++ b/app/vmmd.ml @@ -37,20 +37,20 @@ let create process cont = let state', out' = Vmm_vmmd.handle_shutdown !state name vm r in state := state' ; s := { !s with vm_destroyed = succ !s.vm_destroyed } ; - (process "handle shutdown" out' >|= fun _ -> ()) >|= fun () -> + process "handle shutdown (stat, log)" out' >|= fun () -> let state', waiter_opt = Vmm_vmmd.waiter !state name in state := state' ; (match waiter_opt with | None -> () | Some wakeme -> Lwt.wakeup wakeme ())) ; - (process "setting up statistics, log, reply" out >|= fun _ -> ()) + process "setting up stat, log, reply" out let register who header = match Vmm_vmmd.register !state who Lwt.task with | None -> Error (`Data (header, `Failure "task already registered")) | Some (state', task) -> state := state' ; Ok task -let handle out fd addr = +let handle process fd addr = Logs.debug (fun m -> m "connection from %a" Vmm_lwt.pp_sockaddr addr) ; (* now we need to read a packet and handle it (1) @@ -64,21 +64,6 @@ let handle out fd addr = -- Lwt effects happen (stats, logs, wait_and_clear) -- (2) goto (1) *) - let process txt wires = - Lwt_list.fold_left_s (fun r data -> - match r, data with - | Ok (), (#Vmm_vmmd.service_out as o) -> out o - | Ok (), `Data wire -> - (* rather: terminate connection *) - Vmm_lwt.write_wire fd wire >|= fun _ -> - Ok () - | Error e, _ -> Lwt.return (Error e)) - (Ok ()) wires >|= function - | Ok () -> Ok () - | Error (`Msg msg) -> - Logs.err (fun m -> m "error in processing data %s: %s" txt msg) ; - Error () - in let rec loop () = Logs.debug (fun m -> m "now reading") ; Vmm_lwt.read_wire fd >>= function @@ -86,32 +71,31 @@ let handle out fd addr = Logs.err (fun m -> m "error while reading") ; Lwt.return_unit | Ok wire -> - Logs.debug (fun m -> m "read sth") ; + Logs.debug (fun m -> m "read %a" Vmm_commands.pp_wire wire) ; let state', data, next = Vmm_vmmd.handle_command !state wire in state := state' ; - process "handle_command" data >>= function - | Error () -> Lwt.return_unit - | Ok () -> match next with - | `Loop -> loop () - | `End -> Lwt.return_unit - | `Create cont -> create process cont - | `Wait (who, out) -> - (match register who (fst wire) with - | Error out' -> process "wait" [ out' ] >|= ignore - | Ok task -> - task >>= fun () -> - process "wait" [ out ] >|= ignore) - | `Wait_and_create (who, next) -> - (match register who (fst wire) with - | Error out' -> process "wait and create" [ out' ] >|= ignore - | Ok task -> - task >>= fun () -> - let state', data, n = next !state in - state := state' ; - process "wait and create" data >>= fun _ -> - match n with - | `End -> Lwt.return_unit - | `Create cont -> create process cont >|= ignore) + process "handle command" data >>= fun () -> + match next with + | `Loop -> loop () + | `End -> Lwt.return_unit + | `Create cont -> create process cont + | `Wait (who, out) -> + (match register who (fst wire) with + | Error out' -> process "wait" [ out' ] + | Ok task -> + task >>= fun () -> + process "wait" [ out ]) + | `Wait_and_create (who, next) -> + (match register who (fst wire) with + | Error out' -> process "wait and create" [ out' ] + | Ok task -> + task >>= fun () -> + let state', data, n = next !state in + state := state' ; + process "wait and create" data >>= fun () -> + match n with + | `End -> Lwt.return_unit + | `Create cont -> create process cont) in loop () >>= fun () -> Vmm_lwt.safe_close fd @@ -185,39 +169,70 @@ let jump _ = | Some c -> c) >>= fun (c, c_fd, c_mut) -> create_mbox `Stats >>= fun s -> - let write_reply (header, cmd) name mvar fd mut = + let write_reply txt (header, cmd) name mvar fd mut = Lwt_mutex.with_lock mut (fun () -> Lwt_mvar.put mvar (header, cmd) >>= fun () -> Vmm_lwt.read_wire fd) >|= function | Ok (header', reply) -> - if not Vmm_commands.(version_eq header.version header'.version) then - Error (`Msg ("wrong version in reply from " ^ name)) - else if not Vmm_commands.(Int64.equal header.sequence header'.sequence) then - Error (`Msg ( - Fmt.strf "wrong id %Lu (expected %Lu) in reply from %s" - header'.Vmm_commands.sequence header.Vmm_commands.sequence name)) - else begin match reply with - | `Success _ -> Ok () - | `Failure msg -> Error (`Msg (msg ^ " from " ^ name)) - | _ -> Error (`Msg ("unexpected data from " ^ name)) + if not Vmm_commands.(version_eq header.version header'.version) then begin + Logs.err (fun m -> m "%s: wrong version (got %a, expected %a) in reply from %s" + txt + Vmm_commands.pp_version header'.Vmm_commands.version + Vmm_commands.pp_version header.Vmm_commands.version + name) ; + invalid_arg "bad version received" + end else if not Vmm_commands.(Int64.equal header.sequence header'.sequence) then begin + Logs.err (fun m -> m "%s: wrong id %Lu (expected %Lu) in reply from %s" + txt header'.Vmm_commands.sequence header.Vmm_commands.sequence name) ; + invalid_arg "wrong sequence number received" + end else begin + Logs.debug (fun m -> m "%s: received valid reply from %s %a" + txt name Vmm_commands.pp_wire (header', reply)) ; + match reply with + | `Success _ -> () + | `Failure msg -> + (* can we programatically solve such a situation? *) + (* we at least know e.g when writing to console resulted in an error, + that we can't continue but need to roll back -- and not continue + with execvp() *) + Logs.err (fun m -> m "%s: received failure %s from %s" txt msg name) + | _ -> + Logs.err (fun m -> m "%s: unexpected data from %s" txt name) ; + invalid_arg "unexpected data" end - | Error _ -> Error (`Msg ("error in read from " ^ name)) + | Error _ -> + Logs.err (fun m -> m "error in read from %s" name) ; + invalid_arg "communication failure" in - let out = function + let out txt = function | `Stat wire -> begin match s with - | None -> Lwt.return (Ok ()) - | Some (s, s_fd, s_mut) -> write_reply wire "stats" s s_fd s_mut + | None -> Lwt.return_unit + | Some (s, s_fd, s_mut) -> write_reply txt wire "stats" s s_fd s_mut end - | `Log wire -> write_reply wire "log" l l_fd l_mut - | `Cons wire -> write_reply wire "console" c c_fd c_mut + | `Log wire -> write_reply txt wire "log" l l_fd l_mut + | `Cons wire -> write_reply txt wire "console" c c_fd c_mut in + let process ?fd txt wires = + Lwt_list.iter_p (function + | (#Vmm_vmmd.service_out as o) -> out txt o + | `Data wire -> match fd with + | None -> + Logs.app (fun m -> m "%s received %a" txt Vmm_commands.pp_wire wire) ; + Lwt.return_unit + | Some fd -> + (* TODO should we terminate the connection on write failure? *) + Vmm_lwt.write_wire fd wire >|= fun _ -> + ()) + wires + in + Lwt.async stats_loop ; Lwt.catch (fun () -> let rec loop () = Lwt_unix.accept ss >>= fun (fd, addr) -> Lwt_unix.set_close_on_exec fd ; - Lwt.async (fun () -> handle out fd addr) ; + Lwt.async (fun () -> handle (process ~fd) fd addr) ; loop () in loop ())