diff --git a/app/vmmd.ml b/app/vmmd.ml index ee2d10d..e7acdd8 100644 --- a/app/vmmd.ml +++ b/app/vmmd.ml @@ -149,99 +149,115 @@ let rec stats_loop () = let jump _ = Sys.(set_signal sigpipe Signal_ignore); - Lwt_main.run - (server_socket `Vmmd >>= fun ss -> - (create_mbox `Log >|= function - | None -> invalid_arg "cannot connect to log socket" - | Some l -> l) >>= fun (l, l_fd, l_mut) -> - let self_destruct_mutex = Lwt_mutex.create () in - let self_destruct () = - Lwt_mutex.with_lock self_destruct_mutex (fun () -> - (if Vmm_vmmd.killall !state then - (* not too happy about the sleep here, but cleaning up resources - is really important (fifos, vm images, tap devices) - which is - done asynchronously (in the task waitpid() on the pid) *) - Lwt_unix.sleep 1. - else - Lwt.return_unit) >>= fun () -> - Vmm_lwt.safe_close ss) - in - Sys.(set_signal sigterm (Signal_handle (fun _ -> Lwt.async self_destruct))); - (create_mbox `Console >|= function - | None -> invalid_arg "cannot connect to console socket" - | Some c -> c) >>= fun (c, c_fd, c_mut) -> - create_mbox `Stats >>= fun s -> + match Vmm_vmmd.restore_unikernels () with + | Error (`Msg msg) -> Logs.err (fun m -> m "bailing out: %s" msg) + | Ok old_unikernels -> - let write_reply txt (header, cmd) name mvar fd mut = - Lwt_mutex.with_lock mut (fun () -> - Lwt_mvar.put mvar (header, cmd) >>= fun () -> - Vmm_lwt.read_wire fd) >|= function - | Ok (header', reply) -> - if not Vmm_commands.(version_eq header.version header'.version) then begin - Logs.err (fun m -> m "%s: wrong version (got %a, expected %a) in reply from %s" - txt - Vmm_commands.pp_version header'.Vmm_commands.version - Vmm_commands.pp_version header.Vmm_commands.version - name) ; - invalid_arg "bad version received" - end else if not Vmm_commands.(Int64.equal header.sequence header'.sequence) then begin - Logs.err (fun m -> m "%s: wrong id %Lu (expected %Lu) in reply from %s" - txt header'.Vmm_commands.sequence header.Vmm_commands.sequence name) ; - invalid_arg "wrong sequence number received" - end else begin - Logs.debug (fun m -> m "%s: received valid reply from %s %a" - txt name Vmm_commands.pp_wire (header', reply)) ; - match reply with - | `Success _ -> () - | `Failure msg -> - (* can we programatically solve such a situation? *) - (* we at least know e.g when writing to console resulted in an error, - that we can't continue but need to roll back -- and not continue - with execvp() *) - Logs.err (fun m -> m "%s: received failure %s from %s" txt msg name) - | _ -> - Logs.err (fun m -> m "%s: unexpected data from %s" txt name) ; - invalid_arg "unexpected data" - end - | Error _ -> - Logs.err (fun m -> m "error in read from %s" name) ; - invalid_arg "communication failure" - in - let out txt = function - | `Stat wire -> - begin match s with - | None -> Lwt.return_unit - | Some (s, s_fd, s_mut) -> write_reply txt wire "stats" s s_fd s_mut - end - | `Log wire -> write_reply txt wire "log" l l_fd l_mut - | `Cons wire -> write_reply txt wire "console" c c_fd c_mut - in - let process ?fd txt wires = - Lwt_list.iter_p (function - | (#Vmm_vmmd.service_out as o) -> out txt o - | `Data wire -> match fd with - | None -> - Logs.app (fun m -> m "%s received %a" txt Vmm_commands.pp_wire wire) ; - Lwt.return_unit - | Some fd -> - (* TODO should we terminate the connection on write failure? *) - Vmm_lwt.write_wire fd wire >|= fun _ -> - ()) - wires - in + Lwt_main.run + (server_socket `Vmmd >>= fun ss -> + (create_mbox `Log >|= function + | None -> invalid_arg "cannot connect to log socket" + | Some l -> l) >>= fun (l, l_fd, l_mut) -> + let self_destruct_mutex = Lwt_mutex.create () in + let self_destruct () = + Lwt_mutex.with_lock self_destruct_mutex (fun () -> + (if Vmm_vmmd.killall !state then + (* not too happy about the sleep here, but cleaning up resources + is really important (fifos, vm images, tap devices) - which + is done asynchronously (in the task waitpid() on the pid) *) + Lwt_unix.sleep 1. + else + Lwt.return_unit) >>= fun () -> + Vmm_lwt.safe_close ss) + in + Sys.(set_signal sigterm (Signal_handle (fun _ -> Lwt.async self_destruct))); + (create_mbox `Console >|= function + | None -> invalid_arg "cannot connect to console socket" + | Some c -> c) >>= fun (c, c_fd, c_mut) -> + create_mbox `Stats >>= fun s -> - Lwt.async stats_loop ; - Lwt.catch (fun () -> - let rec loop () = - Lwt_unix.accept ss >>= fun (fd, addr) -> - Lwt_unix.set_close_on_exec fd ; - Lwt.async (fun () -> handle (process ~fd) fd addr) ; - loop () - in - loop ()) - (fun e -> - Logs.err (fun m -> m "exception %s, shutting down" (Printexc.to_string e)); - self_destruct ())) + let write_reply txt (header, cmd) name mvar fd mut = + Lwt_mutex.with_lock mut (fun () -> + Lwt_mvar.put mvar (header, cmd) >>= fun () -> + Vmm_lwt.read_wire fd) >|= function + | Ok (header', reply) -> + if not Vmm_commands.(version_eq header.version header'.version) then begin + Logs.err (fun m -> m "%s: wrong version (got %a, expected %a) in reply from %s" + txt + Vmm_commands.pp_version header'.Vmm_commands.version + Vmm_commands.pp_version header.Vmm_commands.version + name) ; + invalid_arg "bad version received" + end else if not Vmm_commands.(Int64.equal header.sequence header'.sequence) then begin + Logs.err (fun m -> m "%s: wrong id %Lu (expected %Lu) in reply from %s" + txt header'.Vmm_commands.sequence header.Vmm_commands.sequence name) ; + invalid_arg "wrong sequence number received" + end else begin + Logs.debug (fun m -> m "%s: received valid reply from %s %a" + txt name Vmm_commands.pp_wire (header', reply)) ; + match reply with + | `Success _ -> () + | `Failure msg -> + (* can we programatically solve such a situation? *) + (* we at least know e.g when writing to console resulted in an error, + that we can't continue but need to roll back -- and not continue + with execvp() *) + Logs.err (fun m -> m "%s: received failure %s from %s" txt msg name) + | _ -> + Logs.err (fun m -> m "%s: unexpected data from %s" txt name) ; + invalid_arg "unexpected data" + end + | Error _ -> + Logs.err (fun m -> m "error in read from %s" name) ; + invalid_arg "communication failure" + in + let out txt = function + | `Stat wire -> + begin match s with + | None -> Lwt.return_unit + | Some (s, s_fd, s_mut) -> write_reply txt wire "stats" s s_fd s_mut + end + | `Log wire -> write_reply txt wire "log" l l_fd l_mut + | `Cons wire -> write_reply txt wire "console" c c_fd c_mut + in + let process ?fd txt wires = + Lwt_list.iter_p (function + | (#Vmm_vmmd.service_out as o) -> out txt o + | `Data wire -> match fd with + | None -> + Logs.app (fun m -> m "%s received %a" txt Vmm_commands.pp_wire wire) ; + Lwt.return_unit + | Some fd -> + (* TODO should we terminate the connection on write failure? *) + Vmm_lwt.write_wire fd wire >|= fun _ -> + ()) wires + in + + Lwt.async stats_loop ; + + let start_unikernel (name, config) = + match Vmm_vmmd.handle_create !state [] name config with + | Error (`Msg msg) -> + Logs.err (fun m -> m "failed to restart %a: %s" Name.pp name msg) ; + Lwt.return_unit + | Ok (state', out, `Create next) -> + state := state' ; + process "create from dump" out >>= fun () -> + create process next + in + Lwt_list.iter_p start_unikernel (Vmm_trie.all old_unikernels) >>= fun () -> + + Lwt.catch (fun () -> + let rec loop () = + Lwt_unix.accept ss >>= fun (fd, addr) -> + Lwt_unix.set_close_on_exec fd ; + Lwt.async (fun () -> handle (process ~fd) fd addr) ; + loop () + in + loop ()) + (fun e -> + Logs.err (fun m -> m "exception %s, shutting down" (Printexc.to_string e)); + self_destruct ())) open Cmdliner diff --git a/src/vmm_asn.ml b/src/vmm_asn.ml index 699c808..5064f09 100644 --- a/src/vmm_asn.ml +++ b/src/vmm_asn.ml @@ -521,6 +521,42 @@ let logs_of_disk buf = in next [] buf +let trie e = + let f elts = + List.fold_left (fun trie (key, value) -> + match Name.of_string key with + | Error (`Msg m) -> invalid_arg m + | Ok name -> + let trie, ret = Vmm_trie.insert name value trie in + assert (ret = None); + trie) Vmm_trie.empty elts + and g trie = + List.map (fun (k, v) -> Name.to_string k, v) (Vmm_trie.all trie) + in + Asn.S.map f g @@ + Asn.S.(sequence_of + (sequence2 + (required ~label:"name" utf8_string) + (required ~label:"value" e))) + +let version0_unikernels = trie unikernel_config + +let unikernels = + (* the choice is the implicit version + migration... be aware when + any dependent data layout changes .oO(/o\) *) + let f = function + | `C1 () -> Asn.S.error (`Parse "shouldn't happen") + | `C2 data -> data + and g data = + `C2 data + in + Asn.S.map f g @@ + Asn.S.(choice2 + (explicit 0 null) + (explicit 1 version0_unikernels)) + +let unikernels_of_cstruct, unikernels_to_cstruct = projections_of unikernels + type cert_extension = version * t let cert_extension = diff --git a/src/vmm_asn.mli b/src/vmm_asn.mli index 475d447..6bab81f 100644 --- a/src/vmm_asn.mli +++ b/src/vmm_asn.mli @@ -25,3 +25,6 @@ type cert_extension = Vmm_commands.version * Vmm_commands.t val cert_extension_of_cstruct : Cstruct.t -> (cert_extension, [> `Msg of string ]) result val cert_extension_to_cstruct : cert_extension -> Cstruct.t + +val unikernels_to_cstruct : Unikernel.config Vmm_trie.t -> Cstruct.t +val unikernels_of_cstruct : Cstruct.t -> (Unikernel.config Vmm_trie.t, [> `Msg of string ]) result diff --git a/src/vmm_unix.ml b/src/vmm_unix.ml index 4656450..a877a4b 100644 --- a/src/vmm_unix.ml +++ b/src/vmm_unix.ml @@ -54,6 +54,24 @@ let close_no_err fd = try close fd with _ -> () open Vmm_core let dbdir = Fpath.(v "/var" / "db" / "albatross") + +let dump, restore = + let open R.Infix in + let state_file = Fpath.(dbdir / "state") in + (fun data -> + Bos.OS.File.exists state_file >>= fun exists -> + (if exists then begin + let bak = Fpath.(state_file + "bak") in + Bos.OS.U.(error_to_msg @@ rename state_file bak) + end else Ok ()) >>= fun () -> + Bos.OS.File.write state_file (Cstruct.to_string data)), + (fun () -> + Bos.OS.File.exists state_file >>= fun exists -> + if exists then + Bos.OS.File.read state_file >>| fun data -> + Cstruct.of_string data + else Error `NoFile) + let blockdir = Fpath.(dbdir / "block") let block_file name = diff --git a/src/vmm_unix.mli b/src/vmm_unix.mli index f6606e3..81fbf8b 100644 --- a/src/vmm_unix.mli +++ b/src/vmm_unix.mli @@ -20,3 +20,7 @@ val create_block : Name.t -> int -> (unit, [> R.msg ]) result val destroy_block : Name.t -> (unit, [> R.msg ]) result val find_block_devices : unit -> ((Name.t * int) list, [> R.msg ]) result + +val dump : Cstruct.t -> (unit, [> R.msg ]) result + +val restore : unit -> (Cstruct.t, [> R.msg | `NoFile ]) result diff --git a/src/vmm_vmmd.ml b/src/vmm_vmmd.ml index 4c40d29..4a4a245 100644 --- a/src/vmm_vmmd.ml +++ b/src/vmm_vmmd.ml @@ -16,10 +16,12 @@ type 'a t = { waiters : 'a String.Map.t ; } +let in_shutdown = ref false + let killall t = match List.map snd (Vmm_trie.all t.resources.Vmm_resources.unikernels) with | [] -> false - | vms -> List.iter Vmm_unix.destroy vms ; true + | vms -> in_shutdown := true ; List.iter Vmm_unix.destroy vms ; true let waiter t id = let name = Name.to_string id in @@ -80,6 +82,30 @@ let log t name event = Logs.debug (fun m -> m "log %a" Log.pp data) ; ({ t with log_counter }, `Log (header, `Data (`Log_data data))) +let restore_unikernels () = + match Vmm_unix.restore () with + | Error `NoFile -> + Logs.warn (fun m -> m "no state dump found, starting with no unikernels") ; + Ok Vmm_trie.empty + | Error (`Msg msg) -> Error (`Msg ("while reading state: " ^ msg)) + | Ok data -> + match Vmm_asn.unikernels_of_cstruct data with + | Error (`Msg msg) -> Error (`Msg ("couldn't parse state: " ^ msg)) + | Ok unikernels -> + Logs.info (fun m -> m "restored some unikernels") ; + Ok unikernels + +let dump_unikernels t = + let unikernels = Vmm_trie.all t.resources.Vmm_resources.unikernels in + let trie = List.fold_left (fun t (name, unik) -> + fst @@ Vmm_trie.insert name unik.Unikernel.config t) + Vmm_trie.empty unikernels + in + let data = Vmm_asn.unikernels_to_cstruct trie in + match Vmm_unix.dump data with + | Error (`Msg msg) -> Logs.err (fun m -> m "failed to dump unikernels: %s" msg) + | Ok () -> Logs.info (fun m -> m "dumped current state") + let setup_stats t name vm = let stat_out = let pid = vm.Unikernel.pid in @@ -122,6 +148,7 @@ let handle_create t reply name vm_config = Logs.debug (fun m -> m "exec()ed vm") ; Vmm_resources.insert_vm t.resources name vm >>= fun resources -> let t = { t with resources } in + dump_unikernels t ; let t, out = log t name (`Unikernel_start (name, vm.Unikernel.pid, vm.Unikernel.taps, None)) in let t, stat_out = setup_stats t name vm in Ok (t, stat_out :: out :: reply, name, vm))) @@ -137,6 +164,7 @@ let handle_shutdown t name vm r = | Ok resources -> resources in let t = { t with resources } in + if not !in_shutdown then dump_unikernels t ; let t, logout = log t name (`Unikernel_stop (name, vm.Unikernel.pid, r)) in let t, stat_out = remove_stats t name in (t, [ stat_out ; logout ]) diff --git a/src/vmm_vmmd.mli b/src/vmm_vmmd.mli index f81a9b4..d26f7c3 100644 --- a/src/vmm_vmmd.mli +++ b/src/vmm_vmmd.mli @@ -24,6 +24,10 @@ type 'a create = val handle_shutdown : 'a t -> Name.t -> Unikernel.t -> [ `Exit of int | `Signal of int | `Stop of int ] -> 'a t * out list +val handle_create : 'a t -> out list -> + Name.t -> Unikernel.config -> + ('a t * out list * [ `Create of 'a create ], [> `Msg of string ]) result + val handle_command : 'a t -> Vmm_commands.wire -> 'a t * out list * [ `Create of 'a create @@ -33,3 +37,7 @@ val handle_command : 'a t -> Vmm_commands.wire -> | `Wait_and_create of Name.t * ('a t -> 'a t * out list * [ `Create of 'a create | `End ]) ] val killall : 'a t -> bool + +val restore_unikernels : unit -> (Unikernel.config Vmm_trie.t, [> `Msg of string ]) result + +val dump_unikernels : 'a t -> unit