diff --git a/app/vmmd.ml b/app/vmmd.ml index eadd7ce..4ae9236 100644 --- a/app/vmmd.ml +++ b/app/vmmd.ml @@ -71,7 +71,7 @@ let handle out fd addr = (Ok ()) wires >|= function | Ok () -> Ok () | Error (`Msg msg) -> - Logs.err (fun m -> m "error in process %s: %s" txt msg) ; + Logs.err (fun m -> m "error in processing data %s: %s" txt msg) ; Error () in let rec loop () = @@ -128,12 +128,13 @@ let create_mbox sock = if fails, we'd need to retransmit all VM info to stat (or stat has to ask at connect) *) let rec loop () = Lwt_mvar.take mvar >>= fun data -> + Logs.debug (fun m -> m "writing %a" Vmm_commands.pp_wire data) ; Vmm_lwt.write_wire fd data >>= function | Ok () -> loop () | Error `Exception -> invalid_arg ("exception while writing to " ^ Fmt.to_to_string pp_socket sock) ; in Lwt.async loop ; - Some (mvar, fd) + Some (mvar, fd, Lwt_mutex.create ()) let server_socket sock = let name = socket_path sock in @@ -152,53 +153,71 @@ let rec stats_loop () = stats_loop () let jump _ = - Sys.(set_signal sigpipe Signal_ignore) ; + Sys.(set_signal sigpipe Signal_ignore); Lwt_main.run (server_socket `Vmmd >>= fun ss -> - (create_mbox `Console >|= function - | None -> invalid_arg "cannot connect to console socket" - | Some c -> c) >>= fun (c, c_fd) -> - create_mbox `Stats >>= fun s -> (create_mbox `Log >|= function | None -> invalid_arg "cannot connect to log socket" - | Some l -> l) >>= fun (l, l_fd) -> - let write_reply (header, cmd) mvar fd = - Lwt_mvar.put mvar (header, cmd) >>= fun () -> - Vmm_lwt.read_wire fd >|= function + | Some l -> l) >>= fun (l, l_fd, l_mut) -> + let self_destruct () = + Vmm_vmmd.kill !state ; + (* not too happy about the sleep here, but cleaning up resources is + really important (fifos, vm images, tap devices) - which is done in + asynchronous (waiter tasks) + *) + Lwt_unix.sleep 1. >>= fun () -> + Vmm_lwt.safe_close ss + in + Sys.(set_signal sigterm (Signal_handle (fun _ -> Lwt.async self_destruct))); + (create_mbox `Console >|= function + | None -> invalid_arg "cannot connect to console socket" + | Some c -> c) >>= fun (c, c_fd, c_mut) -> + create_mbox `Stats >>= fun s -> + + let write_reply (header, cmd) name mvar fd mut = + Lwt_mutex.with_lock mut (fun () -> + Lwt_mvar.put mvar (header, cmd) >>= fun () -> + Vmm_lwt.read_wire fd) >|= function | Ok (header', reply) -> if not Vmm_commands.(version_eq header.version header'.version) then - Error (`Msg "wrong version in reply") + Error (`Msg ("wrong version in reply from " ^ name)) else if not Vmm_commands.(Int64.equal header.sequence header'.sequence) then - Error (`Msg "wrong id in reply") + Error (`Msg ( + Fmt.strf "wrong id %Lu (expected %Lu) in reply from %s" + header'.Vmm_commands.sequence header.Vmm_commands.sequence name)) else begin match reply with | `Success _ -> Ok () - | `Failure msg -> Error (`Msg msg) - | _ -> Error (`Msg "unexpected data") + | `Failure msg -> Error (`Msg (msg ^ " from " ^ name)) + | _ -> Error (`Msg ("unexpected data from " ^ name)) end - | Error _ -> Error (`Msg "error in read") + | Error _ -> Error (`Msg ("error in read from " ^ name)) in let out = function | `Stat wire -> begin match s with | None -> Lwt.return (Ok ()) - | Some (s, s_fd) -> write_reply wire s s_fd + | Some (s, s_fd, s_mut) -> write_reply wire "stats" s s_fd s_mut end - | `Log wire -> write_reply wire l l_fd - | `Cons wire -> write_reply wire c c_fd + | `Log wire -> write_reply wire "log" l l_fd l_mut + | `Cons wire -> write_reply wire "console" c c_fd c_mut in Lwt.async stats_loop ; - let rec loop () = - Lwt_unix.accept ss >>= fun (fd, addr) -> - Lwt_unix.set_close_on_exec fd ; - Lwt.async (fun () -> handle out fd addr) ; - loop () - in - loop ()) + Lwt.catch (fun () -> + let rec loop () = + Lwt_unix.accept ss >>= fun (fd, addr) -> + Lwt_unix.set_close_on_exec fd ; + Lwt.async (fun () -> handle out fd addr) ; + loop () + in + loop ()) + (fun e -> + Logs.err (fun m -> m "exception %s, shutting down" (Printexc.to_string e)); + self_destruct ())) open Cmdliner let cmd = - Term.(ret (const jump $ setup_log)), + Term.(const jump $ setup_log), Term.info "vmmd" ~version:"%%VERSION_NUM%%" let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 diff --git a/src/vmm_vmmd.ml b/src/vmm_vmmd.ml index 16c3241..22d54a6 100644 --- a/src/vmm_vmmd.ml +++ b/src/vmm_vmmd.ml @@ -16,6 +16,10 @@ type 'a t = { tasks : 'a String.Map.t ; } +let kill t = + List.iter Vmm_unix.destroy + (List.map snd (Vmm_trie.all t.resources.Vmm_resources.unikernels)) + let init wire_version = let t = { wire_version ; diff --git a/src/vmm_vmmd.mli b/src/vmm_vmmd.mli index d12fc7b..f14d9cb 100644 --- a/src/vmm_vmmd.mli +++ b/src/vmm_vmmd.mli @@ -28,3 +28,5 @@ val handle_command : 'a t -> Vmm_commands.wire -> | `End ]) ] val setup_stats : 'a t -> Name.t -> Unikernel.t -> 'a t * out + +val kill : 'a t -> unit