From bd102092978f85b3329ebc16a8c2d7b88b7d4162 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Sun, 9 Sep 2018 20:52:04 +0200 Subject: [PATCH] wip, vmmc and vmmd talk with each other! --- .merlin | 2 +- _tags | 1 - app/vmm_client.ml | 2 +- app/vmm_console.ml | 69 ++-- app/vmm_influxdb_stats.ml | 97 +++--- app/vmm_log.ml | 257 +++++++++----- app/vmm_tls_endpoint.ml | 172 +++++++++ app/vmmc.ml | 264 ++++++++++++++ app/vmmd.ml | 367 +++++++------------ opam | 1 - pkg/pkg.ml | 6 +- src/vmm_asn.ml | 5 +- src/vmm_commands.ml | 223 ++++++++++++ src/vmm_core.ml | 26 +- src/vmm_engine.ml | 584 +++++-------------------------- src/vmm_lwt.ml | 35 +- src/vmm_tls.ml | 8 +- src/vmm_wire.ml | 716 +++++++++++++++++++++----------------- src/vmm_x509.ml | 163 +++++++++ stats/vmm_stats.ml | 11 +- stats/vmm_stats_lwt.ml | 8 +- 21 files changed, 1746 insertions(+), 1271 deletions(-) create mode 100644 app/vmm_tls_endpoint.ml create mode 100644 app/vmmc.ml create mode 100644 src/vmm_commands.ml create mode 100644 src/vmm_x509.ml diff --git a/.merlin b/.merlin index 3014a5d..d00b855 100644 --- a/.merlin +++ b/.merlin @@ -5,6 +5,6 @@ S provision B _build/** -PKG topkg logs ipaddr x509 tls rresult bos lwt cmdliner hex cstruct.ppx duration +PKG topkg logs ipaddr x509 tls rresult bos lwt cmdliner hex duration PKG ptime ptime.clock.os ipaddr.unix decompress PKG lwt.unix \ No newline at end of file diff --git a/_tags b/_tags index 8b1a4c8..fe13147 100644 --- a/_tags +++ b/_tags @@ -4,7 +4,6 @@ true : package(rresult logs ipaddr x509 tls bos hex ptime ptime.clock.os astring "src" : include : package(decompress) -: package(ppx_cstruct) : package(asn1-combinators) : package(lwt lwt.unix) : package(lwt tls.lwt) diff --git a/app/vmm_client.ml b/app/vmm_client.ml index acd7cc7..46b5485 100644 --- a/app/vmm_client.ml +++ b/app/vmm_client.ml @@ -4,7 +4,7 @@ open Lwt.Infix open Vmm_core -let my_version = `WV0 +let my_version = `WV2 let command = ref 1 let process db hdr data = diff --git a/app/vmm_console.ml b/app/vmm_console.ml index 055efa5..65e4563 100644 --- a/app/vmm_console.ml +++ b/app/vmm_console.ml @@ -13,21 +13,14 @@ open Lwt.Infix open Astring -open Vmm_wire -open Vmm_wire.Console - -let my_version = `WV0 - -let pp_sockaddr ppf = function - | Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str - | Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d" - (Unix.string_of_inet_addr addr) port +let my_version = `WV2 let pp_unix_error ppf e = Fmt.string ppf (Unix.error_message e) let active = ref String.Map.empty let read_console name ring channel () = + let id = Vmm_core.id_of_string name in Lwt.catch (fun () -> let rec loop () = Lwt_io.read_line channel >>= fun line -> @@ -37,8 +30,10 @@ let read_console name ring channel () = (match String.Map.find name !active with | None -> Lwt.return_unit | Some fd -> - Vmm_lwt.write_raw fd (data my_version name t line) >>= function - | Error _ -> Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) + Vmm_lwt.write_wire fd (Vmm_wire.Console.data my_version id t line) >>= function + | Error _ -> + Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >|= fun () -> + active := String.Map.remove name !active | Ok () -> Lwt.return_unit) >>= loop in @@ -70,7 +65,8 @@ let open_fifo name = let t = ref String.Map.empty -let add_fifo name = +let add_fifo id = + let name = Vmm_core.string_of_id id in open_fifo name >|= function | Some f -> let ring = Vmm_ring.create () in @@ -82,63 +78,68 @@ let add_fifo name = | None -> Error (`Msg "opening") -let attach s name = - Logs.debug (fun m -> m "attempting to attach %s" name) ; +let attach s id = + let name = Vmm_core.string_of_id id in + Logs.debug (fun m -> m "attempting to attach %a" Vmm_core.pp_id id) ; match String.Map.find name !t with | None -> Lwt.return (Error (`Msg "not found")) | Some _ -> active := String.Map.add name s !active ; Lwt.return (Ok "attached") -let detach name = +let detach id = + let name = Vmm_core.string_of_id id in active := String.Map.remove name !active ; Lwt.return (Ok "removed") let history s name since = - match String.Map.find name !t with - | None -> Lwt.return (Rresult.R.error_msgf "ring %s not found (%d): %a" - name (String.Map.cardinal !t) + match String.Map.find (Vmm_core.string_of_id name) !t with + | None -> Lwt.return (Rresult.R.error_msgf "ring %a not found (%d): %a" + Vmm_core.pp_id name (String.Map.cardinal !t) Fmt.(list ~sep:(unit ";") string) (List.map fst (String.Map.bindings !t))) | Some r -> let entries = Vmm_ring.read_history r since in Logs.debug (fun m -> m "found %d history" (List.length entries)) ; Lwt_list.iter_s (fun (i, v) -> - Vmm_lwt.write_raw s (data my_version name i v) >|= fun _ -> ()) + Vmm_lwt.write_wire s (Vmm_wire.Console.data my_version name i v) >|= fun _ -> ()) entries >|= fun () -> Ok "success" let handle s addr () = - Logs.info (fun m -> m "handling connection %a" pp_sockaddr addr) ; + Logs.info (fun m -> m "handling connection %a" Vmm_lwt.pp_sockaddr addr) ; let rec loop () = - Vmm_lwt.read_exactly s >>= function + Vmm_lwt.read_wire s >>= function | Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop () | Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit + | Ok (hdr, _) when Vmm_wire.is_reply hdr -> + Logs.err (fun m -> m "unexpected reply") ; + loop () | Ok (hdr, data) -> - (if not (version_eq hdr.version my_version) then + (if not (Vmm_wire.version_eq hdr.version my_version) then Lwt.return (Error (`Msg "ignoring data with bad version")) else - match decode_str data with + match Vmm_wire.decode_strings data with | Error e -> Lwt.return (Error e) - | Ok (name, off) -> - match Console.int_to_op hdr.tag with - | Some Add_console -> add_fifo name - | Some Attach_console -> attach s name - | Some Detach_console -> detach name - | Some History -> - (match decode_ts ~off data with + | Ok (id, off) -> match Vmm_wire.Console.int_to_op hdr.tag with + | Some Vmm_wire.Console.Add_console -> add_fifo id + | Some Vmm_wire.Console.Attach_console -> attach s id + | Some Vmm_wire.Console.Detach_console -> detach id + | Some Vmm_wire.Console.History -> + (match Vmm_wire.decode_ptime ~off data with | Error e -> Lwt.return (Error e) - | Ok since -> history s name since) - | _ -> + | Ok since -> history s id since) + | Some Vmm_wire.Console.Data -> Lwt.return (Error (`Msg "unexpected Data")) + | None -> Lwt.return (Error (`Msg "unknown command"))) >>= (function - | Ok msg -> Vmm_lwt.write_raw s (success ~msg hdr.id my_version) + | Ok msg -> Vmm_lwt.write_wire s (Vmm_wire.success ~msg my_version hdr.Vmm_wire.id hdr.Vmm_wire.tag) | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing command: %s" msg) ; - Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= function + Vmm_lwt.write_wire s (Vmm_wire.fail ~msg my_version hdr.Vmm_wire.id)) >>= function | Ok () -> loop () | Error _ -> Logs.err (fun m -> m "exception while writing to socket") ; diff --git a/app/vmm_influxdb_stats.ml b/app/vmm_influxdb_stats.ml index 8a3c89d..a13450c 100644 --- a/app/vmm_influxdb_stats.ml +++ b/app/vmm_influxdb_stats.ml @@ -142,9 +142,9 @@ end let my_version = `WV1 -let command = ref 1 +let command = ref 1L -let (req : string IM.t ref) = ref IM.empty +let (req : string IM64.t ref) = ref IM64.empty let str_of_e = function | `Eof -> "end of file" @@ -192,54 +192,53 @@ let rec read_sock_write_tcp closing db c ?fd addr addrtype = else begin let open Vmm_wire in Logs.debug (fun m -> m "reading from unix socket") ; - Vmm_lwt.read_exactly c >>= function + Vmm_lwt.read_wire c >>= function | Error e -> Logs.err (fun m -> m "error %s while reading vmm socket (return)" (str_of_e e)) ; closing := true ; safe_close fd | Ok (hdr, data) -> - if not (version_eq hdr.version my_version) then begin - Logs.err (fun m -> m "unknown wire protocol version") ; - closing := true ; - safe_close fd - end else - let name = - try IM.find hdr.id !req - with Not_found -> "not found" - in - req := IM.remove hdr.id !req ; - begin match Stats.int_to_op hdr.tag with - | Some Stats.Stat_reply -> - begin match Vmm_wire.Stats.decode_stats (Cstruct.of_string data) with - | Error (`Msg msg) -> - Logs.warn (fun m -> m "error %s while decoding stats %s, ignoring" - msg name) ; - Lwt.return (Some fd) - | Ok (ru, vmm, ifs) -> - let ru = P.encode_ru name ru in - let vmm = P.encode_vmm name vmm in - let taps = List.map (P.encode_if name) ifs in - let out = (String.concat ~sep:"\n" (ru :: vmm :: taps)) ^ "\n" in - Logs.debug (fun m -> m "writing %d via tcp" (String.length out)) ; - Vmm_lwt.write_raw fd out >>= function - | Ok () -> - Logs.debug (fun m -> m "wrote successfully") ; - Lwt.return (Some fd) - | Error e -> - Logs.err (fun m -> m "error %s while writing to tcp (%s)" - (str_of_e e) name) ; - safe_close fd >|= fun () -> - None - end - | _ when hdr.tag = fail_tag -> - Logs.err (fun m -> m "failed to retrieve statistics for %s" name) ; - Lwt.return (Some fd) - | _ -> - Logs.err (fun m -> m "unhandled tag %d for %s" hdr.tag name) ; - Lwt.return (Some fd) - end >>= fun fd -> - read_sock_write_tcp closing db c ?fd addr addrtype + let name = + try IM64.find hdr.id !req + with Not_found -> "not found" + in + req := IM64.remove hdr.id !req ; + (if not (version_eq hdr.version my_version) then begin + Logs.err (fun m -> m "unknown wire protocol version") ; + closing := true ; + safe_close fd >|= fun () -> + None + end else if Vmm_wire.is_fail hdr then begin + Logs.err (fun m -> m "failed to retrieve statistics for %s" name) ; + Lwt.return (Some fd) + end else if Vmm_wire.is_reply hdr then + begin match Vmm_wire.Stats.decode_stats data with + | Error (`Msg msg) -> + Logs.warn (fun m -> m "error %s while decoding stats %s, ignoring" + msg name) ; + Lwt.return (Some fd) + | Ok (ru, vmm, ifs) -> + let ru = P.encode_ru name ru in + let vmm = P.encode_vmm name vmm in + let taps = List.map (P.encode_if name) ifs in + let out = (String.concat ~sep:"\n" (ru :: vmm :: taps)) ^ "\n" in + Logs.debug (fun m -> m "writing %d via tcp" (String.length out)) ; + Vmm_lwt.write_wire fd (Cstruct.of_string out) >>= function + | Ok () -> + Logs.debug (fun m -> m "wrote successfully") ; + Lwt.return (Some fd) + | Error e -> + Logs.err (fun m -> m "error %s while writing to tcp (%s)" + (str_of_e e) name) ; + safe_close fd >|= fun () -> + None + end + else begin + Logs.err (fun m -> m "unhandled tag %lu for %s" hdr.tag name) ; + Lwt.return (Some fd) + end) >>= fun fd -> + read_sock_write_tcp closing db c ?fd addr addrtype end let rec query_sock closing prefix db c interval = @@ -252,12 +251,12 @@ let rec query_sock closing prefix db c interval = | Error e -> Lwt.return (Error e) | Ok () -> let id = identifier id in - let id = match prefix with None -> id | Some p -> p ^ "." ^ id in + let id = match prefix with None -> [ id ] | Some p -> [ p ; id ] in let request = Vmm_wire.Stats.stat !command my_version id in - req := IM.add !command name !req ; - incr command ; - Logs.debug (fun m -> m "%d requesting %s via socket" !command id) ; - Vmm_lwt.write_raw c request) + req := IM64.add !command name !req ; + command := Int64.succ !command ; + Logs.debug (fun m -> m "%Lu requesting %a via socket" !command pp_id id) ; + Vmm_lwt.write_wire c request) (Ok ()) db >>= function | Error e -> Logs.err (fun m -> m "error %s while writing to vmm socket" (str_of_e e)) ; diff --git a/app/vmm_log.ml b/app/vmm_log.ml index 010e4ac..04b1a4e 100644 --- a/app/vmm_log.ml +++ b/app/vmm_log.ml @@ -14,14 +14,72 @@ open Lwt.Infix open Astring -open Vmm_wire -open Vmm_wire.Log +let my_version = `WV2 -let my_version = `WV0 +type t = N of Lwt_unix.file_descr list * t String.Map.t -let write_complete s str = - let l = String.length str in - let b = Bytes.unsafe_of_string str in +let empty = N ([], String.Map.empty) + +let insert id fd t = + let rec go (N (fds, m)) = function + | [] -> N ((fd :: fds), m) + | x::xs -> + let n = match String.Map.find_opt x m with + | None -> empty + | Some n -> n + in + let entry = go n xs in + N (fds, String.Map.add x entry m) + in + go t id + +let remove id fd t = + let rec go (N (fds, m)) = function + | [] -> + begin match List.filter (fun fd' -> fd <> fd') fds with + | [] -> None + | fds' -> Some (N (fds', m)) + end + | x::xs -> + let n' = match String.Map.find_opt x m with + | None -> None + | Some n -> go n xs + in + let m' = match n' with + | None -> String.Map.remove x m + | Some entry -> String.Map.add x entry m + in + if String.Map.is_empty m' && fds = [] then None else Some (N (fds, m')) + in + match go t id with + | None -> empty + | Some n -> n + +let collect id t = + let rec go acc prefix (N (fds, m)) = + let acc' = + let here = List.map (fun fd -> (prefix, fd)) fds in + here @ acc + in + function + | [] -> acc' + | x::xs -> + match String.Map.find_opt x m with + | None -> acc' + | Some n -> go acc' (prefix @ [ x ]) n xs + in + go [] [] t id + +let broadcast prefix data t = + Lwt_list.fold_left_s (fun t (id, s) -> + Vmm_lwt.write_wire s data >|= function + | Ok () -> t + | Error `Exception -> remove id s t) + t (collect prefix t) + +let write_complete s cs = + let l = Cstruct.len cs in + let b = Cstruct.to_bytes cs in let rec w off = let len = l - off in Lwt_unix.write s b off len >>= fun n -> @@ -29,110 +87,141 @@ let write_complete s str = in w 0 -let pp_sockaddr ppf = function - | Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str - | Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d" - (Unix.string_of_inet_addr addr) port +let write_to_file file = + let mvar = Lwt_mvar.create_empty () in + let rec write_loop ?(retry = true) ?data ?fd () = + match fd with + | None when retry -> + Lwt_unix.openfile file Lwt_unix.[O_APPEND;O_CREAT;O_WRONLY] 0o600 >>= fun fd -> + write_loop ~retry:false ?data ~fd () + | None -> + Logs.err (fun m -> m "retry is false, exiting") ; + Lwt.return_unit + | Some fd -> + (match data with + | None -> Lwt_mvar.take mvar + | Some d -> Lwt.return d) >>= fun data -> + Lwt.catch + (fun () -> write_complete fd data >|= fun () -> (true, None, Some fd)) + (fun e -> + Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string e)) ; + Vmm_lwt.safe_close fd >|= fun () -> + (retry, Some data, None)) >>= fun (retry, data, fd) -> + write_loop ~retry ?data ?fd () + in + mvar, write_loop -let handle fd ring s addr () = - Logs.info (fun m -> m "handling connection from %a" pp_sockaddr addr) ; - let str = Fmt.strf "%a: CONNECT\n" (Ptime.pp_human ~tz_offset_s:0 ()) (Ptime_clock.now ()) in - write_complete fd str >>= fun () -> +(* TODO: + - should there be an unsubscribe command? + - should there be acks for history/datain? + *) + +let tree = ref empty + +let bcast = ref 0L + +let handle mvar ring s addr () = + Logs.info (fun m -> m "handling connection from %a" Vmm_lwt.pp_sockaddr addr) ; + let str = Fmt.strf "%a: CONNECT\n" (Ptime.pp_human ()) (Ptime_clock.now ()) in + Lwt_mvar.put mvar (Cstruct.of_string str) >>= fun () -> let rec loop () = - Vmm_lwt.read_exactly s >>= function + Vmm_lwt.read_wire s >>= function | Error (`Msg e) -> Logs.err (fun m -> m "error while reading %s" e) ; loop () | Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit - | Ok (hdr, data) -> - let out = - (if not (version_eq hdr.version my_version) then - Error (`Msg "unknown version") - else match int_to_op hdr.tag with - | Some Data -> - (match decode_ts data with - | Ok ts -> Vmm_ring.write ring (ts, data) - | Error _ -> - Logs.warn (fun m -> m "ignoring error while decoding timestamp %s" data)) ; - Ok (`Data data) - | Some History -> - begin match decode_str data with - | Error e -> Error e - | Ok (str, off) -> match decode_ts ~off data with - | Error e -> Error e - | Ok ts -> - let elements = Vmm_ring.read_history ring ts in - let res = List.fold_left (fun acc (_, x) -> - match Vmm_wire.Log.decode_log_hdr (Cstruct.of_string x) with - | Ok (hdr, _) -> - Logs.debug (fun m -> m "found an entry: %a" (Vmm_core.Log.pp_hdr []) hdr) ; - if String.equal str (Vmm_core.string_of_id hdr.Vmm_core.Log.context) then - x :: acc - else - acc - | _ -> acc) - [] elements - in - (* just need a wrapper in tag = Log.Data, id = reqid *) - let out = - List.fold_left (fun acc x -> - let length = String.length x in - let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in - (Cstruct.to_string hdr ^ x) :: acc) - [] (List.rev res) - in - Ok (`Out out) - end - | _ -> Error (`Msg "unknown command")) - in - match out with - | Error (`Msg msg) -> - begin - Logs.err (fun m -> m "error while processing: %s" msg) ; - Vmm_lwt.write_raw s (fail ~msg hdr.id my_version) >>= function - | Error _ -> Logs.err (fun m -> m "error0 while writing") ; Lwt.return_unit - | Ok () -> loop () + | Ok (hdr, _) when Vmm_wire.is_reply hdr -> + Logs.warn (fun m -> m "ignoring reply") ; + loop () + | Ok (hdr, _) when not (Vmm_wire.version_eq hdr.Vmm_wire.version my_version) -> + Logs.warn (fun m -> m "unsupported version") ; + Lwt.return_unit + | Ok (hdr, data) -> match Vmm_wire.Log.int_to_op hdr.Vmm_wire.tag with + | Some Vmm_wire.Log.Log -> + begin match Vmm_wire.Log.decode_log_hdr data with + | Error (`Msg err) -> + Logs.warn (fun m -> m "ignoring error %s while decoding log" err) ; + loop () + | Ok (hdr, _) -> + Vmm_ring.write ring (hdr.Vmm_core.Log.ts, Cstruct.to_string data) ; + Lwt_mvar.put mvar data >>= fun () -> + let data' = Vmm_wire.encode ~body:data my_version !bcast (Vmm_wire.Log.op_to_int Vmm_wire.Log.Broadcast) in + bcast := Int64.succ !bcast ; + broadcast hdr.Vmm_core.Log.context data' !tree >>= fun tree' -> + tree := tree' ; + loop () end - | Ok (`Data data) -> - begin - write_complete fd data >>= fun () -> - Vmm_lwt.write_raw s (success hdr.id my_version) >>= function - | Error _ -> Logs.err (fun m -> m "error1 while writing") ; Lwt.return_unit - | Ok () -> loop () + | Some Vmm_wire.Log.History -> + begin match Vmm_wire.decode_id_ts data with + | Error (`Msg err) -> + Logs.warn (fun m -> m "ignoring error %s while decoding history" err) ; + loop () + | Ok ((sub, ts), _) -> + let elements = Vmm_ring.read_history ring ts in + let res = + List.fold_left (fun acc (_, x) -> + let cs = Cstruct.of_string x in + match Vmm_wire.Log.decode_log_hdr cs with + | Ok (hdr, _) when Vmm_core.is_sub_id ~super:hdr.Vmm_core.Log.context ~sub -> + cs :: acc + | _ -> acc) + [] elements + in + (* just need a wrapper in tag = Log.Data, id = reqid *) + Lwt_list.fold_left_s (fun r body -> + match r with + | Ok () -> + let data = Vmm_wire.encode ~body my_version hdr.Vmm_wire.id (Vmm_wire.Log.op_to_int Vmm_wire.Log.Log) in + Vmm_lwt.write_wire s data + | Error e -> Lwt.return (Error e)) + (Ok ()) res >>= function + | Ok () -> loop () + | Error _ -> + Logs.err (fun m -> m "error while sending data in history") ; + Lwt.return_unit end - | Ok (`Out datas) -> - Lwt_list.fold_left_s (fun r x -> match r with - | Error e -> Lwt.return (Error e) - | Ok () -> Vmm_lwt.write_raw s x) - (Ok ()) datas >>= function - | Error _ -> Logs.err (fun m -> m "error2 while writing") ; Lwt.return_unit - | Ok () -> - Vmm_lwt.write_raw s (success hdr.id my_version) >>= function - | Error _ -> Logs.err (fun m -> m "error3 while writing") ; Lwt.return_unit - | Ok () -> loop () + | Some Vmm_wire.Log.Subscribe -> + begin match Vmm_wire.decode_strings data with + | Error (`Msg err) -> + Logs.warn (fun m -> m "ignoring error %s while decoding subscribe" err) ; + loop () + | Ok (id, _) -> + tree := insert id s !tree ; + let out = Vmm_wire.success my_version hdr.Vmm_wire.id hdr.Vmm_wire.tag in + Vmm_lwt.write_wire s out >>= function + | Ok () -> loop () + | Error _ -> + Logs.err (fun m -> m "error while sending reply for subscribe") ; + Lwt.return_unit + end + | _ -> + Logs.err (fun m -> m "unknown command") ; + loop () in loop () >>= fun () -> - Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) + Vmm_lwt.safe_close s + (* should remove all the s from the tree above *) let jump _ file sock = Sys.(set_signal sigpipe Signal_ignore) ; Lwt_main.run - (Lwt_unix.openfile file Lwt_unix.[O_APPEND;O_CREAT;O_WRONLY] 0o600 >>= fun fd -> - (Lwt_unix.file_exists sock >>= function + ((Lwt_unix.file_exists sock >>= function | true -> Lwt_unix.unlink sock | false -> Lwt.return_unit) >>= fun () -> let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in Lwt_unix.(bind s (ADDR_UNIX sock)) >>= fun () -> Lwt_unix.listen s 1 ; let ring = Vmm_ring.create () in + let mvar, writer = write_to_file file in let rec loop () = Lwt_unix.accept s >>= fun (cs, addr) -> - Lwt.async (handle fd ring cs addr) ; + Lwt.async (handle mvar ring cs addr) ; loop () in - loop ()) + Lwt.pick [ loop () ; writer () ]) ; + `Ok () let setup_log style_renderer level = Fmt_tty.setup_std_outputs ?style_renderer (); diff --git a/app/vmm_tls_endpoint.ml b/app/vmm_tls_endpoint.ml new file mode 100644 index 0000000..85c100a --- /dev/null +++ b/app/vmm_tls_endpoint.ml @@ -0,0 +1,172 @@ +(* (c) 2017, 2018 Hannes Mehnert, all rights reserved *) + +open Lwt.Infix + +let write_tls state t data = + Vmm_tls.write_tls (fst t) data >>= function + | Ok () -> Lwt.return_unit + | Error `Exception -> + let state', out = Vmm_engine.handle_disconnect !state t in + state := state' ; + Lwt_list.iter_s (fun (s, data) -> write_raw s data) out >>= fun () -> + Tls_lwt.Unix.close (fst t) + +let to_ipaddr (_, sa) = match sa with + | Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address" + | Lwt_unix.ADDR_INET (addr, port) -> Ipaddr_unix.V4.of_inet_addr_exn addr, port + +let pp_sockaddr ppf (_, sa) = match sa with + | Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str + | Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d" + (Unix.string_of_inet_addr addr) port + + +let server_socket port = + let open Lwt_unix in + let s = socket PF_INET SOCK_STREAM 0 in + set_close_on_exec s ; + setsockopt s SO_REUSEADDR true ; + bind s (ADDR_INET (Unix.inet_addr_any, port)) >>= fun () -> + listen s 10 ; + Lwt.return s + +let rec read_log state s = + Vmm_lwt.read_exactly s >>= function + | Error (`Msg msg) -> + Logs.err (fun m -> m "reading log error %s" msg) ; + read_log state s + | Error _ -> + Logs.err (fun m -> m "exception while reading log") ; + invalid_arg "log socket communication issue" + | Ok (hdr, data) -> + let state', outs = Vmm_engine.handle_log !state hdr data in + state := state' ; + process state outs >>= fun () -> + read_log state s + +let rec read_cons state s = + Vmm_lwt.read_exactly s >>= function + | Error (`Msg msg) -> + Logs.err (fun m -> m "reading console error %s" msg) ; + read_cons state s + | Error _ -> + Logs.err (fun m -> m "exception while reading console socket") ; + invalid_arg "console socket communication issue" + | Ok (hdr, data) -> + let state', outs = Vmm_engine.handle_cons !state hdr data in + state := state' ; + process state outs >>= fun () -> + read_cons state s + +let rec read_stats state s = + Vmm_lwt.read_exactly s >>= function + | Error (`Msg msg) -> + Logs.err (fun m -> m "reading stats error %s" msg) ; + read_stats state s + | Error _ -> + Logs.err (fun m -> m "exception while reading stats") ; + Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) >|= fun () -> + invalid_arg "stat socket communication issue" + | Ok (hdr, data) -> + let state', outs = Vmm_engine.handle_stat !state hdr data in + state := state' ; + process state outs >>= fun () -> + read_stats state s + +let cmp_s (_, a) (_, b) = + let open Lwt_unix in + match a, b with + | ADDR_UNIX str, ADDR_UNIX str' -> String.compare str str' = 0 + | ADDR_INET (addr, port), ADDR_INET (addr', port') -> + port = port' && + String.compare (Unix.string_of_inet_addr addr) (Unix.string_of_inet_addr addr') = 0 + | _ -> false + +let jump _ cacert cert priv_key port = + Sys.(set_signal sigpipe Signal_ignore) ; + Lwt_main.run + (Nocrypto_entropy_lwt.initialize () >>= fun () -> + (init_sock Vmm_core.tmpdir "cons" >|= function + | None -> invalid_arg "cannot connect to console socket" + | Some c -> c) >>= fun c -> + init_sock Vmm_core.tmpdir "stat" >>= fun s -> + (init_sock Vmm_core.tmpdir "log" >|= function + | None -> invalid_arg "cannot connect to log socket" + | Some l -> l) >>= fun l -> + server_socket port >>= fun socket -> + X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert -> + X509_lwt.certs_of_pem cacert >>= (function + | [ ca ] -> Lwt.return ca + | _ -> Lwt.fail_with "expect single ca as cacert") >>= fun ca -> + let config = + Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2) + ~reneg:true ~certificates:(`Single cert) ()) + in + (match Vmm_engine.init cmp_s c s l with + | Ok s -> Lwt.return s + | Error (`Msg m) -> Lwt.fail_with m) >>= fun t -> + let state = ref t in + Lwt.async (fun () -> read_cons state c) ; + (match s with + | None -> () + | Some s -> Lwt.async (fun () -> read_stats state s)) ; + Lwt.async (fun () -> read_log state l) ; + Lwt.async stats_loop ; + let rec loop () = + Lwt.catch (fun () -> + Lwt_unix.accept socket >>= fun (fd, addr) -> + Lwt_unix.set_close_on_exec fd ; + Lwt.catch + (fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr)) + (fun exn -> + Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () -> + Lwt.fail exn) >>= fun t -> + Lwt.async (fun () -> + Lwt.catch + (fun () -> handle ca state t) + (fun e -> + Logs.err (fun m -> m "error while handle() %s" + (Printexc.to_string e)) ; + Lwt.return_unit)) ; + loop ()) + (function + | Unix.Unix_error (e, f, _) -> + Logs.err (fun m -> m "Unix error %s in %s" (Unix.error_message e) f) ; + loop () + | Tls_lwt.Tls_failure a -> + Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ; + loop () + | exn -> + Logs.err (fun m -> m "exception %s" (Printexc.to_string exn)) ; + loop ()) + in + loop ()) + +let setup_log style_renderer level = + Fmt_tty.setup_std_outputs ?style_renderer (); + Logs.set_level level; + Logs.set_reporter (Logs_fmt.reporter ~dst:Format.std_formatter ()) + +open Cmdliner + +let setup_log = + Term.(const setup_log + $ Fmt_cli.style_renderer () + $ Logs_cli.level ()) + +let cacert = + let doc = "CA certificate" in + Arg.(required & pos 0 (some file) None & info [] ~doc) + +let cert = + let doc = "Certificate" in + Arg.(required & pos 1 (some file) None & info [] ~doc) + +let key = + let doc = "Private key" in + Arg.(required & pos 2 (some file) None & info [] ~doc) + +let port = + let doc = "TCP listen port" in + Arg.(value & opt int 1025 & info [ "port" ] ~doc) + diff --git a/app/vmmc.ml b/app/vmmc.ml new file mode 100644 index 0000000..df61962 --- /dev/null +++ b/app/vmmc.ml @@ -0,0 +1,264 @@ +(* (c) 2017, 2018 Hannes Mehnert, all rights reserved *) + +open Lwt.Infix + +open Vmm_core + +let my_version = `WV2 +let my_command = 1L + +(* +let process db hdr data = + let open Vmm_wire in + let open Rresult.R.Infix in + if not (version_eq hdr.version my_version) then + Logs.err (fun m -> m "unknown wire protocol version") + else + let r = + match hdr.tag with + | x when x = Client.stat_msg_tag -> + Client.decode_stat data >>= fun (ru, vmm, ifd) -> + Logs.app (fun m -> m "statistics: %a %a %a" + pp_rusage ru + Fmt.(list ~sep:(unit ", ") (pair ~sep:(unit ": ") string uint64)) vmm + Fmt.(list ~sep:(unit ", ") pp_ifdata) ifd) ; + Ok () + | x when x = Client.log_msg_tag -> + Client.decode_log data >>= fun log -> + Logs.app (fun m -> m "log: %a" (Vmm_core.Log.pp db) log) ; + Ok () + | x when x = Client.console_msg_tag -> + Client.decode_console data >>= fun (name, ts, msg) -> + Logs.app (fun m -> m "console %s: %a %s" (translate_serial db name) (Ptime.pp_human ~tz_offset_s:0 ()) ts msg) ; + Ok () + | x when x = Client.info_msg_tag -> + Client.decode_info data >>= fun vms -> + List.iter (fun (name, cmd, pid, taps) -> + Logs.app (fun m -> m "info %s: %s %d taps %a" (translate_serial db name) + cmd pid Fmt.(list ~sep:(unit ", ") string) taps)) + vms ; + Ok () + | x when x = fail_tag -> + decode_str data >>= fun (msg, _) -> + Logs.err (fun m -> m "failed %s" msg) ; + Ok () + | x when x = success_tag -> + decode_str data >>= fun (msg, _) -> + Logs.app (fun m -> m "success %s" msg) ; + Ok () + | x -> Rresult.R.error_msgf "unknown header tag %02X" x + in + match r with + | Ok () -> () + | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) +*) + +let process fd = + Vmm_lwt.read_wire fd >|= function + | Error _ -> Error () + | Ok (hdr, data) -> + if not (Vmm_wire.version_eq hdr.Vmm_wire.version my_version) then begin + Logs.err (fun m -> m "unknown wire protocol version") ; + Error () + end else begin + if Vmm_wire.is_fail hdr then begin + let msg = match Vmm_wire.decode_string data with + | Ok (msg, _) -> Some msg + | Error _ -> None + in + Logs.err (fun m -> m "command failed %a" Fmt.(option ~none:(unit "") string) msg) ; + Error () + end else if Vmm_wire.is_reply hdr && hdr.Vmm_wire.id = my_command then + Ok data + else begin + Logs.err (fun m -> m "received unexpected data") ; + Error () + end + end + +let connect socket = + let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in + Lwt_unix.set_close_on_exec c ; + Lwt_unix.connect c (Lwt_unix.ADDR_UNIX socket) >|= fun () -> + c + +let info_ _ socket name = + Lwt_main.run ( + connect socket >>= fun fd -> + let name' = Astring.String.cuts ~empty:false ~sep:"." name in + let info = Vmm_wire.Vm.info my_command my_version name' in + (Vmm_lwt.write_wire fd info >>= function + | Ok () -> + (process fd >|= function + | Error () -> () + | Ok data -> + match Vmm_wire.Vm.decode_vms data with + | Ok (vms, _) -> + List.iter (fun (id, memory, cmd, pid, taps) -> + Logs.app (fun m -> m "VM %a %dMB command %s pid %d taps %a" + pp_id id memory cmd pid Fmt.(list ~sep:(unit ", ") string) taps)) + vms + | Error (`Msg msg) -> + Logs.err (fun m -> m "error %s while decoding vms" msg)) + | Error `Exception -> Lwt.return_unit) >>= fun () -> + Vmm_lwt.safe_close fd + ) ; + `Ok () + +let really_destroy socket name = + connect socket >>= fun fd -> + let cmd = Vmm_wire.Vm.destroy my_command my_version (Astring.String.cuts ~empty:false ~sep:"." name) in + (Vmm_lwt.write_wire fd cmd >>= function + | Ok () -> + (process fd >|= function + | Error () -> () + | Ok _ -> Logs.app (fun m -> m "destroyed VM")) + | Error `Exception -> Lwt.return_unit) >>= fun () -> + Vmm_lwt.safe_close fd + +let destroy _ socket name = + Lwt_main.run (really_destroy socket name) ; + `Ok () + +let create _ socket force name image cpuid requested_memory boot_params block_device network = + let image' = match Bos.OS.File.read (Fpath.v image) with + | Ok data -> data + | Error (`Msg s) -> invalid_arg s + in + let prefix, vname = match List.rev (Astring.String.cuts ~empty:false ~sep:"." name) with + | [ name ] -> [], name + | name::tl -> List.rev tl, name + | [] -> assert false + and argv = match boot_params with + | [] -> None + | xs -> Some xs + and vmimage = `Ukvm_amd64, Cstruct.of_string image' + in + let vm_config = { + prefix ; vname ; cpuid ; requested_memory ; block_device ; network ; + vmimage ; argv + } in + Lwt_main.run ( + (if force then + really_destroy socket name + else + Lwt.return_unit) >>= fun () -> + connect socket >>= fun fd -> + let vm = Vmm_wire.Vm.create my_command my_version vm_config in + (Vmm_lwt.write_wire fd vm >>= function + | Error `Exception -> Lwt.return_unit + | Ok () -> process fd >|= function + | Ok _ -> Logs.app (fun m -> m "successfully started VM") + | Error () -> ()) >>= fun () -> + Vmm_lwt.safe_close fd + ) ; + `Ok () + +let help _ _ man_format cmds = function + | None -> `Help (`Pager, None) + | Some t when List.mem t cmds -> `Help (man_format, Some t) + | Some _ -> List.iter print_endline cmds; `Ok () + +let setup_log style_renderer level = + Fmt_tty.setup_std_outputs ?style_renderer (); + Logs.set_level level; + Logs.set_reporter (Logs_fmt.reporter ~dst:Format.std_formatter ()) + +open Cmdliner + +let setup_log = + Term.(const setup_log + $ Fmt_cli.style_renderer () + $ Logs_cli.level ()) + +let socket = + let doc = "Socket to connect to" in + let sock = Fpath.(to_string (Vmm_core.tmpdir / "vmmd" + "sock")) in + Arg.(value & opt string sock & info [ "s" ; "socket" ] ~doc) + +let force = + let doc = "force VM creation." in + Arg.(value & flag & info [ "f" ; "force" ] ~doc) + +let image = + let doc = "File of virtual machine image." in + Arg.(required & pos 1 (some file) None & info [] ~doc) + +let vm_name = + let doc = "Name virtual machine config." in + Arg.(required & pos 0 (some string) None & info [] ~doc) + +let destroy_cmd = + let doc = "destroys a virtual machine" in + let man = + [`S "DESCRIPTION"; + `P "Destroy a virtual machine."] + in + Term.(ret (const destroy $ setup_log $ socket $ vm_name)), + Term.info "destroy" ~doc ~man + +let info_cmd = + let doc = "information about VMs" in + let man = + [`S "DESCRIPTION"; + `P "Shows information about VMs."] + in + Term.(ret (const info_ $ setup_log $ socket $ vm_name)), + Term.info "info" ~doc ~man + +let cpu = + let doc = "CPUid" in + Arg.(value & opt int 0 & info [ "cpu" ] ~doc) + +let mem = + let doc = "Memory to provision" in + Arg.(value & opt int 512 & info [ "mem" ] ~doc) + +let args = + let doc = "Boot arguments" in + Arg.(value & opt_all string [] & info [ "arg" ] ~doc) + +let block = + let doc = "Block device name" in + Arg.(value & opt (some string) None & info [ "block" ] ~doc) + +let net = + let doc = "Network device" in + Arg.(value & opt_all string [] & info [ "net" ] ~doc) + +let create_cmd = + let doc = "creates a virtual machine" in + let man = + [`S "DESCRIPTION"; + `P "Creates a virtual machine."] + in + Term.(ret (const create $ setup_log $ socket $ force $ vm_name $ image $ cpu $ mem $ args $ block $ net)), + Term.info "create" ~doc ~man + +let help_cmd = + let topic = + let doc = "The topic to get help on. `topics' lists the topics." in + Arg.(value & pos 0 (some string) None & info [] ~docv:"TOPIC" ~doc) + in + let doc = "display help about vmmc" in + let man = + [`S "DESCRIPTION"; + `P "Prints help about conex commands and subcommands"] + in + Term.(ret (const help $ setup_log $ socket $ Term.man_format $ Term.choice_names $ topic)), + Term.info "help" ~doc ~man + +let default_cmd = + let doc = "VMM client" in + let man = [ + `S "DESCRIPTION" ; + `P "$(tname) connects to vmmd via a local socket" ] + in + Term.(ret (const help $ setup_log $ socket $ Term.man_format $ Term.choice_names $ Term.pure None)), + Term.info "vmmc" ~version:"%%VERSION_NUM%%" ~doc ~man + +let cmds = [ help_cmd ; info_cmd ; destroy_cmd ; create_cmd ] + +let () = + match Term.eval_choice default_cmd cmds + with `Ok () -> exit 0 | _ -> exit 1 diff --git a/app/vmmd.ml b/app/vmmd.ml index 2967d8b..8202458 100644 --- a/app/vmmd.ml +++ b/app/vmmd.ml @@ -16,136 +16,86 @@ let pp_stats ppf s = open Lwt.Infix -let write_raw s data = - Vmm_lwt.write_raw s data >|= fun _ -> () +type out = [ + | `Cons of Cstruct.t + | `Stat of Cstruct.t + | `Log of Cstruct.t +] -let write_tls state t data = - Vmm_tls.write_tls (fst t) data >>= function - | Ok () -> Lwt.return_unit - | Error `Exception -> - let state', out = Vmm_engine.handle_disconnect !state t in - state := state' ; - Lwt_list.iter_s (fun (s, data) -> write_raw s data) out >>= fun () -> - Tls_lwt.Unix.close (fst t) - -let to_ipaddr (_, sa) = match sa with - | Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address" - | Lwt_unix.ADDR_INET (addr, port) -> Ipaddr_unix.V4.of_inet_addr_exn addr, port - -let pp_sockaddr ppf (_, sa) = match sa with - | Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str - | Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d" - (Unix.string_of_inet_addr addr) port - -let process state xs = - Lwt_list.iter_s (function - | `Raw (s, str) -> write_raw s str - | `Tls (s, str) -> write_tls state s str) - xs - -let handle ca state t = - Logs.debug (fun m -> m "connection from %a" pp_sockaddr t) ; - let authenticator = - let time = Ptime_clock.now () in - X509.Authenticator.chain_of_trust ~time ~crls:!state.Vmm_engine.crls [ca] +let handle state out c_fd fd addr = + (* out is for `Log | `Stat | `Cons (including reconnect semantics) *) + (* need to handle data out (+ die on write failure) *) + Logs.debug (fun m -> m "connection from %a" Vmm_lwt.pp_sockaddr addr) ; + (* now we need to read a packet and handle it + (1) + (a) easy for info (look up name/prefix in resources) + (b) destroy looks up vm in resources, executes kill (wait for pid will do the cleanup) + logs "destroy issued" + (c) create initiates the vm startup procedure: + write image file, create fifo, create tap(s), send fifo to console + -- Lwt effects happen (console) -- + executes ukvm-bin + waiter, send stats pid and taps, inserts await into state, logs "created vm" + -- Lwt effects happen (stats, logs, wait_and_clear) -- + (2) goto (1) + *) + let process xs = + Lwt_list.iter_p (function + | #out as o -> out o + | `Data cs -> + (* rather: terminate connection *) + Vmm_lwt.write_wire fd cs >|= fun _ -> ()) xs in - Lwt.catch - (fun () -> Tls_lwt.Unix.reneg ~authenticator (fst t)) - (fun e -> - (match e with - | Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a)) - | Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f)) - | exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ; - Tls_lwt.Unix.close (fst t) >>= fun () -> - Lwt.fail e) >>= fun () -> - (match Tls_lwt.Unix.epoch (fst t) with - | `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain - | `Error -> - Tls_lwt.Unix.close (fst t) >>= fun () -> - Lwt.fail_with "error while getting epoch") >>= fun chain -> - match Vmm_engine.handle_initial !state t (to_ipaddr t) chain ca with - | Ok (state', outs, next) -> - state := state' ; - process state outs >>= fun () -> - begin match next with - | `Create (task, next) -> - (match task with - | None -> Lwt.return_unit - | Some (kill, wait) -> kill () ; wait) >>= fun () -> - let await, wakeme = Lwt.wait () in - begin match next !state await with - | Ok (state', outs, cont) -> - state := state' ; - process state outs >>= fun () -> - begin match cont !state t with - | Ok (state', outs, vm) -> - state := state' ; - s := { !s with vm_created = succ !s.vm_created } ; - Lwt.async (fun () -> - Vmm_lwt.wait_and_clear vm.Vmm_core.pid vm.Vmm_core.stdout >>= fun r -> - let state', outs = Vmm_engine.handle_shutdown !state vm r in - s := { !s with vm_destroyed = succ !s.vm_destroyed } ; - state := state' ; - process state outs >|= fun () -> - Lwt.wakeup wakeme ()) ; - process state outs >>= fun () -> - begin match Vmm_engine.setup_stats !state vm with - | Ok (state', outs) -> - state := state' ; - process state outs - | Error (`Msg e) -> - Logs.warn (fun m -> m "(ignored) error %s while setting up statistics" e) ; - Lwt.return_unit - end - | Error (`Msg e) -> - Logs.err (fun m -> m "error while create %s" e) ; - let err = Vmm_wire.fail ~msg:e 0 !state.Vmm_engine.client_version in - process state [ `Tls (t, err) ] - end - | Error (`Msg e) -> - Logs.err (fun m -> m "error while cont %s" e) ; - let err = Vmm_wire.fail ~msg:e 0 !state.Vmm_engine.client_version in - process state [ `Tls (t, err) ] - end >>= fun () -> - Tls_lwt.Unix.close (fst t) - | `Loop (prefix, perms) -> - let rec loop () = - Vmm_tls.read_tls (fst t) >>= function - | Error (`Msg msg) -> - Logs.err (fun m -> m "reading client %a error: %s" pp_sockaddr t msg) ; - loop () - | Error _ -> - Logs.err (fun m -> m "disconnect from %a" pp_sockaddr t) ; - let state', cons = Vmm_engine.handle_disconnect !state t in - state := state' ; - Lwt_list.iter_s (fun (s, data) -> write_raw s data) cons >>= fun () -> - Tls_lwt.Unix.close (fst t) - | Ok (hdr, buf) -> - let state', out = Vmm_engine.handle_command !state t prefix perms hdr buf in - state := state' ; - process state out >>= fun () -> - loop () - in - loop () - | `Close socks -> - Logs.debug (fun m -> m "closing session with %d active ones" (List.length socks)) ; - Lwt_list.iter_s (fun (t, _) -> Tls_lwt.Unix.close t) socks >>= fun () -> - Tls_lwt.Unix.close (fst t) - end - | Error (`Msg e) -> - Logs.err (fun m -> m "VMM %a %s" pp_sockaddr t e) ; - let err = Vmm_wire.fail ~msg:e 0 !state.Vmm_engine.client_version in - process state [`Tls (t, err)] >>= fun () -> - Tls_lwt.Unix.close (fst t) - -let server_socket port = - let open Lwt_unix in - let s = socket PF_INET SOCK_STREAM 0 in - set_close_on_exec s ; - setsockopt s SO_REUSEADDR true ; - bind s (ADDR_INET (Unix.inet_addr_any, port)) >>= fun () -> - listen s 10 ; - Lwt.return s + Logs.debug (fun m -> m "now reading") ; + (Vmm_lwt.read_wire fd >>= function + | Error _ -> + Logs.err (fun m -> m "error while reading") ; + Lwt.return_unit + | Ok (hdr, buf) -> + Logs.debug (fun m -> m "read sth") ; + let state', data, next = Vmm_engine.handle_command !state hdr buf in + state := state' ; + process data >>= fun () -> + match next with + | `End -> Lwt.return_unit + | `Wait (task, out) -> task >>= fun () -> process out + | `Create cont -> + (* data contained a write to console, we need to wait for its reply first *) + Vmm_lwt.read_wire c_fd >>= function + | Ok (_, data) when Vmm_wire.is_fail hdr -> + Logs.err (fun m -> m "console failed with %s" (Cstruct.to_string data)) ; + Lwt.return_unit + | Ok (_, _) when Vmm_wire.is_reply hdr -> + (* assert hdr.id = id! *) + (* TODO slightly more tricky, since we need to "Vmm_lwt.wait_and_clear" in here *) + let await, wakeme = Lwt.wait () in + begin match cont !state await with + | Error (`Msg msg) -> + Logs.err (fun m -> m "create continuation failed %s" msg) ; + Lwt.return_unit + | Ok (state'', out, vm) -> + state := state'' ; + s := { !s with vm_created = succ !s.vm_created } ; + Lwt.async (fun () -> + Vmm_lwt.wait_and_clear vm.Vmm_core.pid vm.Vmm_core.stdout >>= fun r -> + let state', out' = Vmm_engine.handle_shutdown !state vm r in + s := { !s with vm_destroyed = succ !s.vm_destroyed } ; + state := state' ; + process out' >|= fun () -> + Lwt.wakeup wakeme ()) ; + process out >>= fun () -> + begin match Vmm_engine.setup_stats !state vm with + | Ok (state', out) -> + state := state' ; + process out (* TODO: need to read from stats socket! *) + | Error (`Msg e) -> + Logs.warn (fun m -> m "(ignored) error %s while setting up statistics" e) ; + Lwt.return_unit + end + end + | _ -> + Logs.err (fun m -> m "error while reading from console") ; + Lwt.return_unit) >>= fun () -> + Vmm_lwt.safe_close fd let init_sock dir name = let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in @@ -159,120 +109,63 @@ let init_sock dir name = (Lwt.catch (fun () -> Lwt_unix.close c) (fun _ -> Lwt.return_unit)) >|= fun () -> None) -let rec read_log state s = - Vmm_lwt.read_exactly s >>= function - | Error (`Msg msg) -> - Logs.err (fun m -> m "reading log error %s" msg) ; - read_log state s - | Error _ -> - Logs.err (fun m -> m "exception while reading log") ; - invalid_arg "log socket communication issue" - | Ok (hdr, data) -> - let state', outs = Vmm_engine.handle_log !state hdr data in - state := state' ; - process state outs >>= fun () -> - read_log state s +let create_mbox name = + init_sock Vmm_core.tmpdir name >|= function + | None -> None + | Some fd -> + let mvar = Lwt_mvar.create_empty () in + (* could be more elaborate: + if fails, we can reconnect and spit our more log messages to the new socket + if fails, all running VMs terminate, so we can terminate as well ;) + if fails, we'd need to retransmit all VM info to stat (or stat has to ask at connect) *) + let rec loop () = + Lwt_mvar.take mvar >>= fun data -> + Vmm_lwt.write_wire fd data >>= function + | Ok () -> loop () + | Error `Exception -> invalid_arg ("exception while writing to " ^ name) ; + in + Lwt.async loop ; + Some (mvar, fd) -let rec read_cons state s = - Vmm_lwt.read_exactly s >>= function - | Error (`Msg msg) -> - Logs.err (fun m -> m "reading console error %s" msg) ; - read_cons state s - | Error _ -> - Logs.err (fun m -> m "exception while reading console socket") ; - invalid_arg "console socket communication issue" - | Ok (hdr, data) -> - let state', outs = Vmm_engine.handle_cons !state hdr data in - state := state' ; - process state outs >>= fun () -> - read_cons state s - -let rec read_stats state s = - Vmm_lwt.read_exactly s >>= function - | Error (`Msg msg) -> - Logs.err (fun m -> m "reading stats error %s" msg) ; - read_stats state s - | Error _ -> - Logs.err (fun m -> m "exception while reading stats") ; - Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) >|= fun () -> - invalid_arg "stat socket communication issue" - | Ok (hdr, data) -> - let state', outs = Vmm_engine.handle_stat !state hdr data in - state := state' ; - process state outs >>= fun () -> - read_stats state s - -let cmp_s (_, a) (_, b) = - let open Lwt_unix in - match a, b with - | ADDR_UNIX str, ADDR_UNIX str' -> String.compare str str' = 0 - | ADDR_INET (addr, port), ADDR_INET (addr', port') -> - port = port' && - String.compare (Unix.string_of_inet_addr addr) (Unix.string_of_inet_addr addr') = 0 - | _ -> false +let server_socket dir name = + let file = Fpath.(dir / name + "sock") in + let sock = Fpath.to_string file in + (Lwt_unix.file_exists sock >>= function + | true -> Lwt_unix.unlink sock + | false -> Lwt.return_unit) >>= fun () -> + let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in + Lwt_unix.(bind s (ADDR_UNIX sock)) >|= fun () -> + Lwt_unix.listen s 1 ; + s let rec stats_loop () = Logs.info (fun m -> m "%a" pp_stats !s) ; Lwt_unix.sleep 600. >>= fun () -> stats_loop () -let jump _ cacert cert priv_key port = +let jump _ = Sys.(set_signal sigpipe Signal_ignore) ; Lwt_main.run - (Nocrypto_entropy_lwt.initialize () >>= fun () -> - (init_sock Vmm_core.tmpdir "cons" >|= function + (server_socket Vmm_core.tmpdir "vmmd" >>= fun ss -> + (create_mbox "cons" >|= function | None -> invalid_arg "cannot connect to console socket" - | Some c -> c) >>= fun c -> - init_sock Vmm_core.tmpdir "stat" >>= fun s -> - (init_sock Vmm_core.tmpdir "log" >|= function + | Some c -> c) >>= fun (c, c_fd) -> + create_mbox "stat" >>= fun s -> + (create_mbox "log" >|= function | None -> invalid_arg "cannot connect to log socket" - | Some l -> l) >>= fun l -> - server_socket port >>= fun socket -> - X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert -> - X509_lwt.certs_of_pem cacert >>= (function - | [ ca ] -> Lwt.return ca - | _ -> Lwt.fail_with "expect single ca as cacert") >>= fun ca -> - let config = - Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2) - ~reneg:true ~certificates:(`Single cert) ()) + | Some l -> l) >>= fun (l, _l_fd) -> + let state = ref (Vmm_engine.init ()) in + let out = function + | `Stat data -> (match s with None -> Lwt.return_unit | Some (s, _s_fd) -> Lwt_mvar.put s data) + | `Log data -> Lwt_mvar.put l data + | `Cons data -> Lwt_mvar.put c data in - (match Vmm_engine.init cmp_s c s l with - | Ok s -> Lwt.return s - | Error (`Msg m) -> Lwt.fail_with m) >>= fun t -> - let state = ref t in - Lwt.async (fun () -> read_cons state c) ; - (match s with - | None -> () - | Some s -> Lwt.async (fun () -> read_stats state s)) ; - Lwt.async (fun () -> read_log state l) ; Lwt.async stats_loop ; let rec loop () = - Lwt.catch (fun () -> - Lwt_unix.accept socket >>= fun (fd, addr) -> - Lwt_unix.set_close_on_exec fd ; - Lwt.catch - (fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr)) - (fun exn -> - Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () -> - Lwt.fail exn) >>= fun t -> - Lwt.async (fun () -> - Lwt.catch - (fun () -> handle ca state t) - (fun e -> - Logs.err (fun m -> m "error while handle() %s" - (Printexc.to_string e)) ; - Lwt.return_unit)) ; - loop ()) - (function - | Unix.Unix_error (e, f, _) -> - Logs.err (fun m -> m "Unix error %s in %s" (Unix.error_message e) f) ; - loop () - | Tls_lwt.Tls_failure a -> - Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ; - loop () - | exn -> - Logs.err (fun m -> m "exception %s" (Printexc.to_string exn)) ; - loop ()) + Lwt_unix.accept ss >>= fun (fd, addr) -> + Lwt_unix.set_close_on_exec fd ; + Lwt.async (fun () -> handle state out c_fd fd addr) ; + loop () in loop ()) @@ -288,24 +181,8 @@ let setup_log = $ Fmt_cli.style_renderer () $ Logs_cli.level ()) -let cacert = - let doc = "CA certificate" in - Arg.(required & pos 0 (some file) None & info [] ~doc) - -let cert = - let doc = "Certificate" in - Arg.(required & pos 1 (some file) None & info [] ~doc) - -let key = - let doc = "Private key" in - Arg.(required & pos 2 (some file) None & info [] ~doc) - -let port = - let doc = "TCP listen port" in - Arg.(value & opt int 1025 & info [ "port" ] ~doc) - let cmd = - Term.(ret (const jump $ setup_log $ cacert $ cert $ key $ port)), + Term.(ret (const jump $ setup_log)), Term.info "vmmd" ~version:"%%VERSION_NUM%%" let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 diff --git a/opam b/opam index 1a9aa72..0b5376b 100644 --- a/opam +++ b/opam @@ -14,7 +14,6 @@ depends: [ "ipaddr" {>= "2.2.0"} "hex" "cstruct" - "ppx_cstruct" {build & >= "3.0.0"} "logs" "rresult" "bos" diff --git a/pkg/pkg.ml b/pkg/pkg.ml index 14dd74e..54bbcbd 100644 --- a/pkg/pkg.ml +++ b/pkg/pkg.ml @@ -9,7 +9,9 @@ let () = Pkg.bin "app/vmmd" ; Pkg.bin "app/vmm_console" ; Pkg.bin "app/vmm_log" ; - Pkg.bin "app/vmm_client" ; + (* Pkg.bin "app/vmm_client" ; *) + (* Pkg.bin "app/vmm_tls_endpoint" ; *) + Pkg.bin "app/vmmc" ; Pkg.bin "provision/vmm_req_permissions" ; Pkg.bin "provision/vmm_req_delegation" ; Pkg.bin "provision/vmm_req_vm" ; @@ -18,6 +20,6 @@ let () = Pkg.bin "provision/vmm_gen_ca" ; Pkg.clib "stats/libvmm_stats_stubs.clib" ; Pkg.bin "stats/vmm_stats_lwt" ; - Pkg.bin "app/vmm_prometheus_stats" ; + (* Pkg.bin "app/vmm_prometheus_stats" ; *) Pkg.bin "app/vmm_influxdb_stats" ; ] diff --git a/src/vmm_asn.ml b/src/vmm_asn.ml index 8f38f34..099c1c0 100644 --- a/src/vmm_asn.ml +++ b/src/vmm_asn.ml @@ -38,15 +38,16 @@ end let perms : permission list Asn.t = Asn.S.bit_string_flags [ - 0, `All ; + 0, `All ; (* no *) 1, `Info ; 2, `Create ; - 3, `Block ; + 3, `Block ; (* create [name] [size] ; destroy [name] *) 4, `Statistics ; 5, `Console ; 6, `Log ; 7, `Crl ; 9, `Force_create ; + (* 10, `Destroy ; (* [name] *) *) ] open Rresult.R.Infix diff --git a/src/vmm_commands.ml b/src/vmm_commands.ml new file mode 100644 index 0000000..e4bf64b --- /dev/null +++ b/src/vmm_commands.ml @@ -0,0 +1,223 @@ +(* (c) 2017, 2018 Hannes Mehnert, all rights reserved *) + +open Astring + +open Vmm_core + +open Rresult +open R.Infix + +let handle_command t s prefix perms hdr buf = + let res = + if not (Vmm_wire.version_eq hdr.Vmm_wire.version t.client_version) then + Error (`Msg "unknown client version") + else match Vmm_wire.Client.cmd_of_int hdr.Vmm_wire.tag with + | None -> Error (`Msg "unknown command") + | Some x when cmd_allowed perms x -> + begin + Vmm_wire.decode_str buf >>= fun (buf, _l) -> + let arg = if String.length buf = 0 then prefix else prefix @ [buf] in + let vmid = string_of_id arg in + match x with + | Info -> + begin match Vmm_resources.find t.resources arg with + | None -> + Logs.debug (fun m -> m "info: couldn't find %a" pp_id arg) ; + R.error_msgf "info: %s not found" buf + | Some x -> + let data = + Vmm_resources.fold (fun acc vm -> + acc ^ Vmm_wire.Client.encode_vm vm.config.vname vm) + "" x + in + let out = Vmm_wire.Client.info data hdr.Vmm_wire.id t.client_version in + Ok (t, [ `Tls (s, out) ]) + end + | Destroy_vm -> + begin match Vmm_resources.find_vm t.resources arg with + | Some vm -> + Vmm_unix.destroy vm ; + let out = Vmm_wire.success hdr.Vmm_wire.id t.client_version in + Ok (t, [ `Tls (s, out) ]) + | _ -> + Error (`Msg ("destroy: not found " ^ buf)) + end + | Attach -> + (* TODO: get (optionally) from client, instead of hardcoding Ptime.epoch below *) + let on_success t = + let cons = Vmm_wire.Console.history t.console_counter t.console_version vmid Ptime.epoch in + let old = match String.Map.find vmid t.console_attached with + | None -> [] + | Some s -> + let out = Vmm_wire.success hdr.Vmm_wire.id t.client_version in + [ `Tls (s, out) ] + in + let console_attached = String.Map.add vmid s t.console_attached in + { t with console_counter = succ t.console_counter ; console_attached }, + `Raw (t.console_socket, cons) :: old + in + let cons = Vmm_wire.Console.attach t.console_counter t.console_version vmid in + let console_requests = IM.add t.console_counter on_success t.console_requests in + Ok ({ t with console_counter = succ t.console_counter ; console_requests }, + [ `Raw (t.console_socket, cons) ]) + | Detach -> + let cons = Vmm_wire.Console.detach t.console_counter t.console_version vmid in + (match String.Map.find vmid t.console_attached with + | None -> Error (`Msg "not attached") + | Some x when t.cmp x s -> Ok (String.Map.remove vmid t.console_attached) + | Some _ -> Error (`Msg "this socket is not attached")) >>= fun console_attached -> + let out = Vmm_wire.success hdr.Vmm_wire.id t.client_version in + Ok ({ t with console_counter = succ t.console_counter ; console_attached }, + [ `Raw (t.console_socket, cons) ; `Tls (s, out) ]) + | Statistics -> + begin match t.stats_socket with + | None -> Error (`Msg "no statistics available") + | Some _ -> match Vmm_resources.find_vm t.resources arg with + | Some vm -> + let stat_out = Vmm_wire.Stats.stat t.stats_counter t.stats_version vmid in + let d = (s, hdr.Vmm_wire.id, translate_tap vm) in + let stats_requests = IM.add t.stats_counter d t.stats_requests in + Ok ({ t with stats_counter = succ t.stats_counter ; stats_requests }, + stat t stat_out) + | _ -> Error (`Msg ("statistics: not found " ^ buf)) + end + | Log -> + begin + let log_out = Vmm_wire.Log.history t.log_counter t.log_version (string_of_id prefix) Ptime.epoch in + let log_requests = IM.add t.log_counter (s, hdr.Vmm_wire.id) t.log_requests in + let log_counter = succ t.log_counter in + Ok ({ t with log_counter ; log_requests }, [ `Raw (t.log_socket, log_out) ]) + end + | Create_block | Destroy_block -> Error (`Msg "NYI") + end + | Some _ -> Error (`Msg "unauthorised command") + in + match res with + | Ok r -> r + | Error (`Msg msg) -> + Logs.debug (fun m -> m "error while processing command: %s" msg) ; + let out = Vmm_wire.fail ~msg hdr.Vmm_wire.id t.client_version in + (t, [ `Tls (s, out) ]) + +let handle_stat state hdr data = + let open Vmm_wire in + if not (version_eq hdr.version state.stats_version) then begin + Logs.warn (fun m -> m "ignoring message with unknown stats version") ; + state, [] + end else if hdr.tag = success_tag then + state, [] + else + match IM.find hdr.id state.stats_requests with + | exception Not_found -> + Logs.err (fun m -> m "couldn't find stat request") ; + state, [] + | (s, req_id, f) -> + let stats_requests = IM.remove hdr.id state.stats_requests in + let state = { state with stats_requests } in + let out = + match Stats.int_to_op hdr.tag with + | Some Stats.Stat_reply -> + begin match Stats.decode_stats (Cstruct.of_string data) with + | Ok (ru, vmm, ifs) -> + let ifs = + List.map + (fun x -> + match f x.name with + | Some name -> { x with name } + | None -> x) + ifs + in + let data = Cstruct.to_string (Stats.encode_stats (ru, vmm, ifs)) in + let out = Client.stat data req_id state.client_version in + [ `Tls (s, out) ] + | Error (`Msg msg) -> + Logs.err (fun m -> m "error %s while decode statistics" msg) ; + let out = fail req_id state.client_version in + [ `Tls (s, out) ] + end + | None when hdr.tag = fail_tag -> + let out = fail ~msg:data req_id state.client_version in + [ `Tls (s, out) ] + | _ -> + Logs.err (fun m -> m "unexpected reply from stat") ; + [] + in + (state, out) + +let handle_cons state hdr data = + let open Vmm_wire in + if not (version_eq hdr.version state.console_version) then begin + Logs.warn (fun m -> m "ignoring message with unknown console version") ; + state, [] + end else match Console.int_to_op hdr.tag with + | Some Console.Data -> + begin match decode_str data with + | Error (`Msg msg) -> + Logs.err (fun m -> m "error while decoding console message %s" msg) ; + (state, []) + | Ok (file, off) -> + (match String.Map.find file state.console_attached with + | Some s -> + let out = Client.console off file data state.client_version in + (state, [ `Tls (s, out) ]) + | None -> + (* TODO: should detach? *) + Logs.err (fun m -> m "couldn't find attached console for %s" file) ; + (state, [])) + end + | None when hdr.tag = success_tag -> + (match IM.find hdr.id state.console_requests with + | exception Not_found -> + (state, []) + | cont -> + let state', outs = cont state in + let console_requests = IM.remove hdr.id state.console_requests in + ({ state' with console_requests }, outs)) + | None when hdr.tag = fail_tag -> + (match IM.find hdr.id state.console_requests with + | exception Not_found -> + Logs.err (fun m -> m "fail couldn't find request id") ; + (state, []) + | _ -> + Logs.err (fun m -> m "failed while trying to do something on console") ; + let console_requests = IM.remove hdr.id state.console_requests in + ({ state with console_requests }, [])) + | _ -> + Logs.err (fun m -> m "unexpected message received from console socket") ; + (state, []) + +let handle_log state hdr buf = + let open Vmm_wire in + let open Vmm_wire.Log in + if not (version_eq hdr.version state.log_version) then begin + Logs.warn (fun m -> m "ignoring message with unknown stats version") ; + state, [] + end else match IM.find hdr.id state.log_requests with + | exception Not_found -> + Logs.warn (fun m -> m "(ignored) coudn't find log request") ; + (state, []) + | (s, rid) -> + let r = match int_to_op hdr.tag with + | Some Data -> + decode_log_hdr (Cstruct.of_string buf) >>= fun (hdr, rest) -> + decode_event rest >>= fun event -> + let tls = Vmm_wire.Client.log hdr event state.client_version in + Ok (state, [ `Tls (s, tls) ]) + | None when hdr.tag = success_tag -> + let log_requests = IM.remove hdr.id state.log_requests in + let tls = Vmm_wire.success rid state.client_version in + Ok ({ state with log_requests }, [ `Tls (s, tls) ]) + | None when hdr.tag = fail_tag -> + let log_requests = IM.remove hdr.id state.log_requests in + let tls = Vmm_wire.fail rid state.client_version in + Ok ({ state with log_requests }, [ `Tls (s, tls) ]) + | _ -> + Logs.err (fun m -> m "couldn't parse log reply") ; + let log_requests = IM.remove hdr.id state.log_requests in + Ok ({ state with log_requests }, []) + in + match r with + | Ok (s, out) -> s, out + | Error (`Msg msg) -> + Logs.err (fun m -> m "error while processing log %s" msg) ; + state, [] diff --git a/src/vmm_core.ml b/src/vmm_core.ml index 558d59b..74d2fcf 100644 --- a/src/vmm_core.ml +++ b/src/vmm_core.ml @@ -14,6 +14,7 @@ end module IS = Set.Make(I) module IM = Map.Make(I) +module IM64 = Map.Make(Int64) type permission = [ `All | `Info | `Create | `Block | `Statistics | `Console | `Log | `Crl | `Force_create] @@ -88,6 +89,17 @@ let cmd_allowed permissions cmd = type vmtype = [ `Ukvm_amd64 | `Ukvm_arm64 | `Ukvm_amd64_compressed ] +let vmtype_to_int = function + | `Ukvm_amd64 -> 0 + | `Ukvm_arm64 -> 1 + | `Ukvm_amd64_compressed -> 2 + +let int_to_vmtype = function + | 0 -> Some `Ukvm_amd64 + | 1 -> Some `Ukvm_arm64 + | 2 -> Some `Ukvm_amd64_compressed + | _ -> None + let pp_vmtype ppf = function | `Ukvm_amd64 -> Fmt.pf ppf "ukvm-amd64" | `Ukvm_amd64_compressed -> Fmt.pf ppf "ukvm-amd64-compressed" @@ -340,7 +352,7 @@ module Log = struct let pp_hdr db ppf (hdr : hdr) = let name = translate_serial db hdr.name in - Fmt.pf ppf "%a: %s" (Ptime.pp_human ~tz_offset_s:0 ()) hdr.ts name + Fmt.pf ppf "%a: %s" (Ptime.pp_human ()) hdr.ts name let hdr context name = { ts = Ptime_clock.now () ; context ; name } @@ -350,10 +362,6 @@ module Log = struct | `Logout of Ipaddr.V4.t * int | `VM_start of int * string list * string option | `VM_stop of int * [ `Exit of int | `Signal of int | `Stop of int ] - | `Block_create of string * int - | `Block_destroy of string - | `Delegate of string list * string option - (* | `CRL of string *) ] let pp_event ppf = function @@ -371,14 +379,6 @@ module Log = struct | `Stop n -> "stop", n in Fmt.pf ppf "STOPPED %d with %s %a" pid s Fmt.Dump.signal c - | `Block_create (name, size) -> - Fmt.pf ppf "BLOCK_CREATE %s %d" name size - | `Block_destroy name -> Fmt.pf ppf "BLOCK_DESTROY %s" name - | `Delegate (bridges, block) -> - Fmt.pf ppf "DELEGATE %a, block %a" - Fmt.(list ~sep:(unit "; ") string) bridges - Fmt.(option ~none:(unit "no") string) block - (* | `CRL of string *) type msg = hdr * event diff --git a/src/vmm_engine.ml b/src/vmm_engine.ml index 2e7d6f7..c1f9bba 100644 --- a/src/vmm_engine.ml +++ b/src/vmm_engine.ml @@ -1,4 +1,4 @@ -(* (c) 2017 Hannes Mehnert, all rights reserved *) +(* (c) 2017, 2018 Hannes Mehnert, all rights reserved *) open Astring @@ -7,21 +7,12 @@ open Vmm_core open Rresult open R.Infix -type ('a, 'b, 'c) t = { - cmp : 'b -> 'b -> bool ; - console_socket : 'a ; - console_counter : int ; - console_requests : (('a, 'b, 'c) t -> ('a, 'b, 'c) t * [ `Raw of 'a * string | `Tls of 'b * string ] list) IM.t ; - console_attached : 'b String.Map.t ; (* vm_name -> socket *) +type 'a t = { + console_counter : int64 ; console_version : Vmm_wire.version ; - stats_socket : 'a option ; - stats_counter : int ; - stats_requests : ('b * int * (string -> string option)) IM.t ; + stats_counter : int64 ; stats_version : Vmm_wire.version ; - log_socket : 'a ; - log_counter : int ; - log_requests : ('b * int) IM.t ; - log_attached : ('b * string) list String.Map.t ; + log_counter : int64 ; log_version : Vmm_wire.version ; client_version : Vmm_wire.version ; (* TODO: refine, maybe: @@ -29,121 +20,63 @@ type ('a, 'b, 'c) t = { used_bridges : String.Set.t String.Map.t ; (* TODO: used block devices (since each may only be active once) *) resources : Vmm_resources.t ; - tasks : 'c String.Map.t ; - crls : X509.CRL.c list ; + tasks : 'a String.Map.t ; } -let init cmp console_socket stats_socket log_socket = - (* error hard on permission denied etc. *) - let crls = Fpath.(dbdir / "crls") in - (Bos.OS.Dir.exists crls >>= function - | true -> Ok true - | false -> Bos.OS.Dir.create crls) >>= fun _ -> - let err _ x = x in - Bos.OS.Dir.fold_contents ~elements:`Files ~traverse:`None ~err - (fun f acc -> - acc >>= fun acc -> - Bos.OS.File.read f >>= fun data -> - match X509.Encoding.crl_of_cstruct (Cstruct.of_string data) with - | None -> R.error_msgf "couldn't parse CRL %a" Fpath.pp f - | Some crl -> Ok (crl :: acc)) - (Ok []) - crls >>= fun crls -> - crls >>= fun crls -> - Ok { - cmp ; - console_socket ; console_counter = 1 ; console_requests = IM.empty ; - console_attached = String.Map.empty ; console_version = `WV0 ; - stats_socket ; stats_counter = 1 ; stats_requests = IM.empty ; - stats_version = `WV1 ; - log_socket ; log_counter = 1 ; log_attached = String.Map.empty ; - log_version = `WV0 ; log_requests = IM.empty ; - client_version = `WV0 ; - used_bridges = String.Map.empty ; - resources = Vmm_resources.empty ; - tasks = String.Map.empty ; - crls - } - -let asn_version = `AV0 +let init () = { + console_counter = 1L ; console_version = `WV2 ; + stats_counter = 1L ; stats_version = `WV2 ; + log_counter = 1L ; log_version = `WV2 ; + client_version = `WV2 ; + used_bridges = String.Map.empty ; + resources = Vmm_resources.empty ; + tasks = String.Map.empty ; +} let log state (hdr, event) = - let pre = string_of_id hdr.Log.context in - let out = match String.Map.find pre state.log_attached with - | None -> [] - | Some x -> x - in - let data = Vmm_wire.Log.data state.log_counter state.log_version hdr event in - let tls = Vmm_wire.Client.log hdr event state.client_version in - let log_counter = succ state.log_counter in + let data = Vmm_wire.Log.log state.log_counter state.log_version hdr event in + let log_counter = Int64.succ state.log_counter in Logs.debug (fun m -> m "LOG %a" (Log.pp []) (hdr, event)) ; - ({ state with log_counter }, - `Raw (state.log_socket, data) :: List.map (fun (s, _) -> `Tls (s, tls)) out) + ({ state with log_counter }, `Log data) -let stat state str = - match state.stats_socket with - | None -> [] - | Some s -> [ `Raw (s, str) ] - -let handle_disconnect state t = - Logs.err (fun m -> m "disconnect!!") ; - let rem, console_attached = - String.Map.partition (fun _ s -> state.cmp s t) state.console_attached - in - let out, console_counter = - List.fold_left (fun (acc, ctr) name -> - (acc ^ Vmm_wire.Console.detach ctr state.console_version name, succ ctr)) - ("", state.console_counter) - (fst (List.split (String.Map.bindings rem))) - in - let log_attached = String.Map.fold (fun k v n -> - match List.filter (fun (e, _) -> not (state.cmp t e)) v with - | [] -> n - | xs -> String.Map.add k xs n) - state.log_attached String.Map.empty - in - let out = - if String.length out = 0 then - [] - else - [ (state.console_socket, out) ] - in - { state with console_attached ; console_counter ; log_attached }, out - -let handle_create t vm_config policies = +let handle_create t hdr vm_config (* policies *) = let full = fullname vm_config in (if Vmm_resources.exists t.resources full then Error (`Msg "VM with same name is already running") else Ok ()) >>= fun () -> - Logs.debug (fun m -> m "now checking dynamic policies") ; - Vmm_resources.check_dynamic t.resources vm_config policies >>= fun () -> + (* Logs.debug (fun m -> m "now checking dynamic policies") ; + Vmm_resources.check_dynamic t.resources vm_config policies >>= fun () -> *) (* prepare VM: save VM image to disk, create fifo, ... *) Vmm_unix.prepare vm_config >>= fun taps -> Logs.debug (fun m -> m "prepared vm with taps %a" Fmt.(list ~sep:(unit ",@ ") string) taps) ; - Ok (fun t s -> - (* actually execute the vm *) - Vmm_unix.exec vm_config taps >>= fun vm -> - Logs.debug (fun m -> m "exec()ed vm") ; - Vmm_resources.insert t.resources full vm >>= fun resources -> - let used_bridges = - List.fold_left2 (fun b br ta -> - let old = match String.Map.find br b with - | None -> String.Set.empty - | Some x -> x - in - String.Map.add br (String.Set.add ta old) b) - t.used_bridges vm_config.network taps - in - let t = { t with resources ; used_bridges } in - let t, out = log t (Log.hdr vm_config.prefix vm_config.vname, `VM_start (vm.pid, vm.taps, None)) in - let tls_out = Vmm_wire.success ~msg:"VM started" 0 t.client_version in - Ok (t, `Tls (s, tls_out) :: out, vm)) + (* TODO should we pre-reserve sth in t? *) + let cons = Vmm_wire.Console.add t.console_counter t.console_version full in + Ok ({ t with console_counter = Int64.succ t.console_counter }, [ `Cons cons ], + `Create (fun t task -> + (* actually execute the vm *) + Vmm_unix.exec vm_config taps >>= fun vm -> + Logs.debug (fun m -> m "exec()ed vm") ; + Vmm_resources.insert t.resources full vm >>= fun resources -> + let tasks = String.Map.add (string_of_id full) task t.tasks in + let used_bridges = + List.fold_left2 (fun b br ta -> + let old = match String.Map.find br b with + | None -> String.Set.empty + | Some x -> x + in + String.Map.add br (String.Set.add ta old) b) + t.used_bridges vm_config.network taps + in + let t = { t with resources ; tasks ; used_bridges } in + let t, out = log t (Log.hdr vm_config.prefix vm_config.vname, `VM_start (vm.pid, vm.taps, None)) in + let data = Vmm_wire.success t.client_version hdr.Vmm_wire.id Vmm_wire.Vm.(op_to_int Create) in + Ok (t, [ `Data data ; out ], vm))) let setup_stats t vm = - let stat_out = Vmm_wire.Stats.add t.stats_counter t.stats_version (vm_id vm.config) vm.pid vm.taps in - let t = { t with stats_counter = succ t.stats_counter } in - Ok (t, stat t stat_out) + let stat_out = Vmm_wire.Stats.add t.stats_counter t.stats_version (fullname vm.config) vm.pid vm.taps in + let t = { t with stats_counter = Int64.succ t.stats_counter } in + Ok (t, [ `Stat stat_out ]) let handle_shutdown t vm r = (match Vmm_unix.shutdown vm with @@ -165,386 +98,59 @@ let handle_shutdown t vm r = String.Map.add br (String.Set.remove ta old) b) t.used_bridges vm.config.network vm.taps in - let stat_out = Vmm_wire.Stats.remove t.stats_counter t.stats_version (vm_id vm.config) in + let stat_out = Vmm_wire.Stats.remove t.stats_counter t.stats_version (fullname vm.config) in let tasks = String.Map.remove (vm_id vm.config) t.tasks in - let t = { t with stats_counter = succ t.stats_counter ; resources ; used_bridges ; tasks } in - let t, outs = log t (Log.hdr vm.config.prefix vm.config.vname, - `VM_stop (vm.pid, r)) + let t = { t with stats_counter = Int64.succ t.stats_counter ; resources ; used_bridges ; tasks } in + let t, logout = log t (Log.hdr vm.config.prefix vm.config.vname, + `VM_stop (vm.pid, r)) in - (t, stat t stat_out @ outs) + (t, [ `Stat stat_out ; logout ]) -let handle_command t s prefix perms hdr buf = - let res = - if not (Vmm_wire.version_eq hdr.Vmm_wire.version t.client_version) then +let handle_command t hdr buf = + let msg_to_err = function + | Ok x -> x + | Error (`Msg msg) -> + Logs.debug (fun m -> m "error while processing command: %s" msg) ; + let out = Vmm_wire.fail ~msg t.client_version hdr.Vmm_wire.id in + (t, [ `Data out ], `End) + in + msg_to_err ( + if Vmm_wire.is_reply hdr then begin + Logs.warn (fun m -> m "ignoring reply") ; + Ok (t, [], `End) + end else if not (Vmm_wire.version_eq hdr.Vmm_wire.version t.client_version) then Error (`Msg "unknown client version") - else match Vmm_wire.Client.cmd_of_int hdr.Vmm_wire.tag with + else Vmm_wire.decode_strings buf >>= fun (id, _off) -> + match Vmm_wire.Vm.int_to_op hdr.Vmm_wire.tag with | None -> Error (`Msg "unknown command") - | Some x when cmd_allowed perms x -> - begin - Vmm_wire.decode_str buf >>= fun (buf, _l) -> - let arg = if String.length buf = 0 then prefix else prefix @ [buf] in - let vmid = string_of_id arg in - match x with - | Info -> - begin match Vmm_resources.find t.resources arg with - | None -> - Logs.debug (fun m -> m "info: couldn't find %a" pp_id arg) ; - R.error_msgf "info: %s not found" buf - | Some x -> - let data = - Vmm_resources.fold (fun acc vm -> - acc ^ Vmm_wire.Client.encode_vm vm.config.vname vm) - "" x - in - let out = Vmm_wire.Client.info data hdr.Vmm_wire.id t.client_version in - Ok (t, [ `Tls (s, out) ]) - end - | Destroy_vm -> - begin match Vmm_resources.find_vm t.resources arg with - | Some vm -> - Vmm_unix.destroy vm ; - let out = Vmm_wire.success hdr.Vmm_wire.id t.client_version in - Ok (t, [ `Tls (s, out) ]) - | _ -> - Error (`Msg ("destroy: not found " ^ buf)) - end - | Attach -> - (* TODO: get (optionally) from client, instead of hardcoding Ptime.epoch below *) - let on_success t = - let cons = Vmm_wire.Console.history t.console_counter t.console_version vmid Ptime.epoch in - let old = match String.Map.find vmid t.console_attached with - | None -> [] - | Some s -> - let out = Vmm_wire.success hdr.Vmm_wire.id t.client_version in - [ `Tls (s, out) ] - in - let console_attached = String.Map.add vmid s t.console_attached in - { t with console_counter = succ t.console_counter ; console_attached }, - `Raw (t.console_socket, cons) :: old + | Some Info -> + Logs.debug (fun m -> m "info %a" pp_id id) ; + begin match Vmm_resources.find t.resources id with + | None -> + Logs.debug (fun m -> m "info: couldn't find %a" pp_id id) ; + Error (`Msg "info: not found") + | Some x -> + let data = + Vmm_resources.fold (fun acc vm -> vm :: acc) [] x in - let cons = Vmm_wire.Console.attach t.console_counter t.console_version vmid in - let console_requests = IM.add t.console_counter on_success t.console_requests in - Ok ({ t with console_counter = succ t.console_counter ; console_requests }, - [ `Raw (t.console_socket, cons) ]) - | Detach -> - let cons = Vmm_wire.Console.detach t.console_counter t.console_version vmid in - (match String.Map.find vmid t.console_attached with - | None -> Error (`Msg "not attached") - | Some x when t.cmp x s -> Ok (String.Map.remove vmid t.console_attached) - | Some _ -> Error (`Msg "this socket is not attached")) >>= fun console_attached -> - let out = Vmm_wire.success hdr.Vmm_wire.id t.client_version in - Ok ({ t with console_counter = succ t.console_counter ; console_attached }, - [ `Raw (t.console_socket, cons) ; `Tls (s, out) ]) - | Statistics -> - begin match t.stats_socket with - | None -> Error (`Msg "no statistics available") - | Some _ -> match Vmm_resources.find_vm t.resources arg with - | Some vm -> - let stat_out = Vmm_wire.Stats.stat t.stats_counter t.stats_version vmid in - let d = (s, hdr.Vmm_wire.id, translate_tap vm) in - let stats_requests = IM.add t.stats_counter d t.stats_requests in - Ok ({ t with stats_counter = succ t.stats_counter ; stats_requests }, - stat t stat_out) - | _ -> Error (`Msg ("statistics: not found " ^ buf)) - end - | Log -> - begin - let log_out = Vmm_wire.Log.history t.log_counter t.log_version (string_of_id prefix) Ptime.epoch in - let log_requests = IM.add t.log_counter (s, hdr.Vmm_wire.id) t.log_requests in - let log_counter = succ t.log_counter in - Ok ({ t with log_counter ; log_requests }, [ `Raw (t.log_socket, log_out) ]) - end - | Create_block | Destroy_block -> Error (`Msg "NYI") + let out = Vmm_wire.Vm.info_reply hdr.Vmm_wire.id t.client_version data in + Ok (t, [ `Data out ], `End) end - | Some _ -> Error (`Msg "unauthorised command") - in - match res with - | Ok r -> r - | Error (`Msg msg) -> - Logs.debug (fun m -> m "error while processing command: %s" msg) ; - let out = Vmm_wire.fail ~msg hdr.Vmm_wire.id t.client_version in - (t, [ `Tls (s, out) ]) - -let handle_single_revocation t prefix serial = - let id = identifier serial in - (match Vmm_resources.find t.resources (prefix @ [ id ]) with - | None -> () - | Some e -> Vmm_resources.iter Vmm_unix.destroy e) ; - (* also revoke all active sessions!? *) - (* TODO: maybe we need a vmm_resources like structure for sessions as well!? *) - let log_attached, kill = - let pid = string_of_id prefix in - match String.Map.find pid t.log_attached with - | None -> t.log_attached, [] - | Some xs -> - (* those where snd v = serial: drop *) - let drop, keep = List.partition (fun (_, s) -> String.equal s id) xs in - String.Map.add pid keep t.log_attached, drop - in - (* two things: - 1 revoked LEAF certs need to go (k = prefix, snd v = serial) [see above] - 2 revoked CA certs need to wipe subtree (all entries where k starts with prefix @ serial) *) - let log_attached, kill = - String.Map.fold (fun k' v (l, k) -> - if is_sub_id ~super:(prefix@[id]) ~sub:(id_of_string k') then - (l, v @ k) - else - (String.Map.add k' v l, k)) - log_attached - (String.Map.empty, kill) - in - let state, out = - List.fold_left (fun (s, out) (t, _) -> - let s', out' = handle_disconnect s t in - s', out @ out') - ({ t with log_attached }, []) - kill - in - (state, - List.map (fun x -> `Raw x) out, - List.map fst kill) - -let handle_revocation t s leaf chain ca prefix = - Vmm_asn.crl_of_cert leaf >>= fun crl -> - (* verify data (must be signed by the last cert of the chain (or cacert if chain is empty))! *) - let issuer = match chain with - | subca::_ -> subca - | [] -> ca - in - let time = Ptime_clock.now () in - (if X509.CRL.verify crl ~time issuer then Ok () else Error (`Msg "couldn't verify CRL")) >>= fun () -> - (* the this_update must be > now, next_update < now, this_update > .this_update, number > .number *) - (* TODO: can we have something better for uniqueness of CRL? *) - let local = try Some (List.find (fun crl -> X509.CRL.verify crl issuer) t.crls) with Not_found -> None in - (match local with - | None -> Ok () - | Some local -> match X509.CRL.crl_number local, X509.CRL.crl_number crl with - | None, _ -> Ok () - | Some _, None -> Error (`Msg "CRL number not present") - | Some x, Some y -> if y > x then Ok () else Error (`Msg "CRL number not increased")) >>= fun () -> - (* filename should be whatever_dir / crls / *) - let filename = Fpath.(dbdir / "crls" / string_of_id prefix) in - Bos.OS.File.delete filename >>= fun () -> - Bos.OS.File.write filename (Cstruct.to_string (X509.Encoding.crl_to_cstruct crl)) >>= fun () -> - (* remove crl with same issuer from crls, and inject this one into state *) - let crls = - match local with - | None -> crl :: t.crls - | Some _ -> crl :: List.filter (fun c -> c <> crl) t.crls - in - (* iterate over revoked serials, find active resources, and kill them *) - let newly_revoked = - let old = match local with - | Some x -> List.map (fun rc -> rc.X509.CRL.serial) (X509.CRL.revoked_certificates x) - | None -> [] - in - let new_rev = List.map (fun rc -> rc.X509.CRL.serial) (X509.CRL.revoked_certificates crl) in - List.filter (fun n -> not (List.mem n old)) new_rev - in - let t, out, close = - List.fold_left (fun (t, out, close) serial -> - let t', out', close' = handle_single_revocation t prefix serial in - (t', out @ out', close @ close')) - (t, [], []) newly_revoked - in - let tls_out = Vmm_wire.success ~msg:"updated revocation list" 0 t.client_version in - Ok ({ t with crls }, `Tls (s, tls_out) :: out, `Close close) - -let handle_initial t s addr chain ca = - separate_chain chain >>= fun (leaf, chain) -> - Logs.debug (fun m -> m "leaf is %s, chain %a" - (X509.common_name_to_string leaf) - Fmt.(list ~sep:(unit "->") string) - (List.map X509.common_name_to_string chain)) ; - (* TODO here: inspect top-level-cert of chain. - may need to create bridges and/or block device subdirectory (zfs create) *) - let prefix = List.map id chain in - let login_hdr, login_ev = Log.hdr prefix (id leaf), `Login addr in - let t, out = log t (login_hdr, login_ev) in - let initial_out = `Tls (s, Vmm_wire.Client.log login_hdr login_ev t.client_version) in - Vmm_asn.permissions_of_cert asn_version leaf >>= fun perms -> - (if (List.mem `Create perms || List.mem `Force_create perms) && Vmm_asn.contains_vm leaf then - (* convert certificate to vm_config *) - Vmm_asn.vm_of_cert prefix leaf >>= fun vm_config -> - Logs.debug (fun m -> m "vm %a" pp_vm_config vm_config) ; - (* get names and static resources *) - List.fold_left (fun acc ca -> - acc >>= fun acc -> - Vmm_asn.delegation_of_cert asn_version ca >>= fun res -> - let name = id ca in - Ok ((name, res) :: acc)) - (Ok []) chain >>= fun policies -> - (* check static policies *) - Logs.debug (fun m -> m "now checking static policies") ; - check_policies vm_config (List.map snd policies) >>= fun () -> - let t, task = - let force = List.mem `Force_create perms in - if force then - let fid = vm_id vm_config in - match String.Map.find fid t.tasks with - | None -> t, None - | Some task -> - let kill () = - match Vmm_resources.find_vm t.resources (fullname vm_config) with - | None -> - Logs.err (fun m -> m "found a task, but no vm for %a (%s)" - pp_id (fullname vm_config) fid) - | Some vm -> - Logs.debug (fun m -> m "killing %a now" pp_vm vm) ; - Vmm_unix.destroy vm - in - let tasks = String.Map.remove fid t.tasks in - ({ t with tasks }, Some (kill, task)) - else - t, None - in - let next t sleeper = - handle_create t vm_config policies >>= fun cont -> - let id = vm_id vm_config in - let cons = Vmm_wire.Console.add t.console_counter t.console_version id in - let tasks = String.Map.add id sleeper t.tasks in - Ok ({ t with console_counter = succ t.console_counter ; tasks }, - [ `Raw (t.console_socket, cons) ], - cont) - in - Ok (t, [], `Create (task, next)) - else if List.mem `Crl perms && Vmm_asn.contains_crl leaf then - handle_revocation t s leaf chain ca prefix - else - let log_attached = - if cmd_allowed perms Log then - let pre = string_of_id prefix in - let v = match String.Map.find pre t.log_attached with - | None -> [] - | Some xs -> xs - in - String.Map.add pre ((s, id leaf) :: v) t.log_attached - else - t.log_attached - in - Ok ({ t with log_attached }, [], `Loop (prefix, perms)) - ) >>= fun (t, outs, res) -> - Ok (t, initial_out :: out @ outs, res) - -let handle_stat state hdr data = - let open Vmm_wire in - if not (version_eq hdr.version state.stats_version) then begin - Logs.warn (fun m -> m "ignoring message with unknown stats version") ; - state, [] - end else if hdr.tag = success_tag then - state, [] - else - match IM.find hdr.id state.stats_requests with - | exception Not_found -> - Logs.err (fun m -> m "couldn't find stat request") ; - state, [] - | (s, req_id, f) -> - let stats_requests = IM.remove hdr.id state.stats_requests in - let state = { state with stats_requests } in - let out = - match Stats.int_to_op hdr.tag with - | Some Stats.Stat_reply -> - begin match Stats.decode_stats (Cstruct.of_string data) with - | Ok (ru, vmm, ifs) -> - let ifs = - List.map - (fun x -> - match f x.name with - | Some name -> { x with name } - | None -> x) - ifs - in - let data = Cstruct.to_string (Stats.encode_stats (ru, vmm, ifs)) in - let out = Client.stat data req_id state.client_version in - [ `Tls (s, out) ] - | Error (`Msg msg) -> - Logs.err (fun m -> m "error %s while decode statistics" msg) ; - let out = fail req_id state.client_version in - [ `Tls (s, out) ] - end - | None when hdr.tag = fail_tag -> - let out = fail ~msg:data req_id state.client_version in - [ `Tls (s, out) ] - | _ -> - Logs.err (fun m -> m "unexpected reply from stat") ; - [] - in - (state, out) - -let handle_cons state hdr data = - let open Vmm_wire in - if not (version_eq hdr.version state.console_version) then begin - Logs.warn (fun m -> m "ignoring message with unknown console version") ; - state, [] - end else match Console.int_to_op hdr.tag with - | Some Console.Data -> - begin match decode_str data with - | Error (`Msg msg) -> - Logs.err (fun m -> m "error while decoding console message %s" msg) ; - (state, []) - | Ok (file, off) -> - (match String.Map.find file state.console_attached with - | Some s -> - let out = Client.console off file data state.client_version in - (state, [ `Tls (s, out) ]) - | None -> - (* TODO: should detach? *) - Logs.err (fun m -> m "couldn't find attached console for %s" file) ; - (state, [])) - end - | None when hdr.tag = success_tag -> - (match IM.find hdr.id state.console_requests with - | exception Not_found -> - (state, []) - | cont -> - let state', outs = cont state in - let console_requests = IM.remove hdr.id state.console_requests in - ({ state' with console_requests }, outs)) - | None when hdr.tag = fail_tag -> - (match IM.find hdr.id state.console_requests with - | exception Not_found -> - Logs.err (fun m -> m "fail couldn't find request id") ; - (state, []) - | _ -> - Logs.err (fun m -> m "failed while trying to do something on console") ; - let console_requests = IM.remove hdr.id state.console_requests in - ({ state with console_requests }, [])) - | _ -> - Logs.err (fun m -> m "unexpected message received from console socket") ; - (state, []) - -let handle_log state hdr buf = - let open Vmm_wire in - let open Vmm_wire.Log in - if not (version_eq hdr.version state.log_version) then begin - Logs.warn (fun m -> m "ignoring message with unknown stats version") ; - state, [] - end else match IM.find hdr.id state.log_requests with - | exception Not_found -> - Logs.warn (fun m -> m "(ignored) coudn't find log request") ; - (state, []) - | (s, rid) -> - let r = match int_to_op hdr.tag with - | Some Data -> - decode_log_hdr (Cstruct.of_string buf) >>= fun (hdr, rest) -> - decode_event rest >>= fun event -> - let tls = Vmm_wire.Client.log hdr event state.client_version in - Ok (state, [ `Tls (s, tls) ]) - | None when hdr.tag = success_tag -> - let log_requests = IM.remove hdr.id state.log_requests in - let tls = Vmm_wire.success rid state.client_version in - Ok ({ state with log_requests }, [ `Tls (s, tls) ]) - | None when hdr.tag = fail_tag -> - let log_requests = IM.remove hdr.id state.log_requests in - let tls = Vmm_wire.fail rid state.client_version in - Ok ({ state with log_requests }, [ `Tls (s, tls) ]) - | _ -> - Logs.err (fun m -> m "couldn't parse log reply") ; - let log_requests = IM.remove hdr.id state.log_requests in - Ok ({ state with log_requests }, []) - in - match r with - | Ok (s, out) -> s, out - | Error (`Msg msg) -> - Logs.err (fun m -> m "error while processing log %s" msg) ; - state, [] + | Some Create -> + Vmm_wire.Vm.decode_vm_config buf >>= fun vm_config -> + handle_create t hdr vm_config + | Some Destroy -> + match Vmm_resources.find_vm t.resources id with + | Some vm -> + Vmm_unix.destroy vm ; + let id_str = string_of_id id in + let out, next = + let success = Vmm_wire.success t.client_version hdr.Vmm_wire.id hdr.Vmm_wire.tag in + let s = [ `Data success ] in + match String.Map.find_opt id_str t.tasks with + | None -> s, `End + | Some t -> [], `Wait (t, s) + in + let tasks = String.Map.remove id_str t.tasks in + Ok ({ t with tasks }, out, next) + | None -> Error (`Msg "destroy: not found")) diff --git a/src/vmm_lwt.ml b/src/vmm_lwt.ml index bfaff67..80dfb34 100644 --- a/src/vmm_lwt.ml +++ b/src/vmm_lwt.ml @@ -2,6 +2,11 @@ open Lwt.Infix +let pp_sockaddr ppf = function + | Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str + | Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d" + (Unix.string_of_inet_addr addr) port + let pp_process_status ppf = function | Unix.WEXITED c -> Fmt.pf ppf "exited with %d" c | Unix.WSIGNALED s -> Fmt.pf ppf "killed by signal %a" Fmt.Dump.signal s @@ -36,8 +41,8 @@ let wait_and_clear pid stdout = Logs.debug (fun m -> m "pid %d exited: %a" pid pp_process_status s) ; ret s -let read_exactly s = - let buf = Bytes.create 8 in +let read_wire s = + let buf = Bytes.create (Int32.to_int Vmm_wire.header_size) in let rec r b i l = Lwt.catch (fun () -> Lwt_unix.read s b i l >>= function @@ -53,29 +58,28 @@ let read_exactly s = let err = Printexc.to_string e in Logs.err (fun m -> m "exception %s while reading" err) ; Lwt.return (Error `Exception)) - in - r buf 0 8 >>= function + r buf 0 (Int32.to_int Vmm_wire.header_size) >>= function | Error e -> Lwt.return (Error e) | Ok () -> - match Vmm_wire.parse_header (Bytes.to_string buf) with + match Vmm_wire.decode_header (Cstruct.of_bytes buf) with | Error (`Msg m) -> Lwt.return (Error (`Msg m)) | Ok hdr -> - let l = hdr.Vmm_wire.length in + let l = Int32.to_int hdr.Vmm_wire.length in if l > 0 then let b = Bytes.create l in r b 0 l >|= function | Error e -> Error e | Ok () -> - (* Logs.debug (fun m -> m "read hdr %a, body %a" + Logs.debug (fun m -> m "read hdr %a, body %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf) - Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; *) - Ok (hdr, Bytes.to_string b) + Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; + Ok (hdr, Cstruct.of_bytes b) else - Lwt.return (Ok (hdr, "")) + Lwt.return (Ok (hdr, Cstruct.empty)) -let write_raw s buf = - let buf = Bytes.unsafe_of_string buf in +let write_wire s buf = + let buf = Cstruct.to_bytes buf in let rec w off l = Lwt.catch (fun () -> Lwt_unix.send s buf off l [] >>= fun n -> @@ -87,5 +91,10 @@ let write_raw s buf = Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string e)) ; Lwt.return (Error `Exception)) in - (* Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ; *) + Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ; w 0 (Bytes.length buf) + +let safe_close fd = + Lwt.catch + (fun () -> Lwt_unix.close fd) + (fun _ -> Lwt.return_unit) diff --git a/src/vmm_tls.ml b/src/vmm_tls.ml index 671fea3..b41d40a 100644 --- a/src/vmm_tls.ml +++ b/src/vmm_tls.ml @@ -26,14 +26,14 @@ let read_tls t = Logs.err (fun m -> m "TLS read exception %s" (Printexc.to_string e)) ; Lwt.return (Error `Exception)) in - let buf = Cstruct.create 8 in - r_n buf 0 8 >>= function + let buf = Cstruct.create (Int32.to_int Vmm_wire.header_size) in + r_n buf 0 (Int32.to_int Vmm_wire.header_size) >>= function | Error e -> Lwt.return (Error e) | Ok () -> - match Vmm_wire.parse_header (Cstruct.to_string buf) with + match Vmm_wire.decode_header buf with | Error (`Msg m) -> Lwt.return (Error (`Msg m)) | Ok hdr -> - let l = hdr.Vmm_wire.length in + let l = Int32.to_int hdr.Vmm_wire.length in if l > 0 then let b = Cstruct.create l in r_n b 0 l >|= function diff --git a/src/vmm_wire.ml b/src/vmm_wire.ml index 0fc77d0..26330bf 100644 --- a/src/vmm_wire.ml +++ b/src/vmm_wire.ml @@ -3,124 +3,148 @@ (* the wire protocol - length prepended binary data *) (* each message (on all channels) is prefixed by a common header: - - length (16 bit) spanning the message (excluding the 8 bytes header) - - id (16 bit) unique id chosen by sender (for request/reply) - 0 shouldn't be used (reserved for log/console messages which do not correspond to a request) - - version (16 bit) the version used on this channel - - tag (16 bit) the type of message + - tag (32 bit) the type of message + it is only 31 bit, the highest (leftmost) bit indicates query (0) or reply (1) + a failure is reported with the special tag 0xFFFFFFFF (all bits set) - data is a string + every request leads to a reply + WV0 and WV1 used 16 bit only + - version (16 bit) the version used on this channel (used to be byte 4-6) + - padding (16 bit) + - id (64 bit) unique id chosen by sender (for request/reply) - 0 shouldn't be used (reserved for log/console messages which do not correspond to a request) + - length (32 bit) spanning the message (excluding the 20 bytes header) + - full VM name (i.e. foo.bar.baz) encoded as size of list followed by list of strings + - replies do not contain the VM name Version and tag are protocol-specific - the channel between vmm and console uses different tags and mayuse a different version than between vmm and - client. *) + client. + + every command issued is replied to with success or failure. broadcast + communication (console data, log events) are not acknowledged by the + recipient. + *) + + +(* TODO unlikely that this is 32bit clean *) open Astring open Vmm_core -type version = [ `WV0 | `WV1 ] +type version = [ `WV0 | `WV1 | `WV2 ] let version_to_int = function | `WV0 -> 0 | `WV1 -> 1 + | `WV2 -> 2 let version_of_int = function | 0 -> Ok `WV0 | 1 -> Ok `WV1 + | 2 -> Ok `WV2 | _ -> Error (`Msg "unknown wire version") let version_eq a b = match a, b with | `WV0, `WV0 -> true | `WV1, `WV1 -> true + | `WV2, `WV2 -> true | _ -> false let pp_version ppf v = Fmt.string ppf (match v with | `WV0 -> "wire version 0" - | `WV1 -> "wire version 1") + | `WV1 -> "wire version 1" + | `WV2 -> "wire version 2") type header = { - length : int ; - id : int ; version : version ; - tag : int ; + tag : int32 ; + length : int32 ; + id : int64 ; } +let header_size = 20l + +let max_size = 0x7FFFFFFFl + +(* Throughout this module, we don't expect any cstruct bigger than the above + max_size (encode checks this!) *) + open Rresult open R.Infix + +let cs_create len = Cstruct.create (Int32.to_int len) + +let cs_len cs = + let l = Cstruct.len cs in + assert (l lsr 31 = 0) ; + Int32.of_int l + let check_len cs l = - if Cstruct.len cs < l then + if Int32.compare (cs_len cs) l = -1 then Error (`Msg "underflow") else Ok () +let cs_shift cs num = + check_len cs (Int32.of_int num) >>= fun () -> + Ok (Cstruct.shift cs num) + let check_exact cs l = - if Cstruct.len cs = l then + if cs_len cs = l then Ok () else Error (`Msg "bad length") -let empty = Cstruct.create 0 - let null cs = if Cstruct.len cs = 0 then Ok () else Error (`Msg "trailing bytes") -let parse_header buf = - let cs = Cstruct.of_string buf in - check_len cs 8 >>= fun () -> - let length = Cstruct.BE.get_uint16 cs 0 - and id = Cstruct.BE.get_uint16 cs 2 - and version = Cstruct.BE.get_uint16 cs 4 - and tag = Cstruct.BE.get_uint16 cs 6 - in - version_of_int version >>= fun version -> - Ok { length ; id ; version ; tag } +let decode_header cs = + check_len cs 8l >>= fun () -> + let version = Cstruct.BE.get_uint16 cs 4 in + version_of_int version >>= function + | `WV0 | `WV1 -> Error (`Msg "unsupported version") + | `WV2 as version -> + check_len cs header_size >>= fun () -> + let tag = Cstruct.BE.get_uint32 cs 0 + and id = Cstruct.BE.get_uint64 cs 8 + and length = Cstruct.BE.get_uint32 cs 16 + in + Ok { length ; id ; version ; tag } -let create_header { length ; id ; version ; tag } = - let hdr = Cstruct.create 8 in - Cstruct.BE.set_uint16 hdr 0 length ; - Cstruct.BE.set_uint16 hdr 2 id ; - Cstruct.BE.set_uint16 hdr 4 (version_to_int version) ; - Cstruct.BE.set_uint16 hdr 6 tag ; - hdr +let encode_header { length ; id ; version ; tag } = + match version with + | `WV0 | `WV1 -> invalid_arg "version no longer supported" + | `WV2 -> + let hdr = cs_create header_size in + Cstruct.BE.set_uint32 hdr 0 tag ; + Cstruct.BE.set_uint16 hdr 4 (version_to_int version) ; + Cstruct.BE.set_uint64 hdr 8 id ; + Cstruct.BE.set_uint32 hdr 16 length ; + hdr + +let max_str_len = 0xFFFF let decode_string cs = - check_len cs 2 >>= fun () -> + check_len cs 2l >>= fun () -> let l = Cstruct.BE.get_uint16 cs 0 in - check_len cs (2 + l) >>= fun () -> + check_len cs (Int32.add 2l (Int32.of_int l)) >>= fun () -> let str = Cstruct.(to_string (sub cs 2 l)) in Ok (str, l + 2) -(* external use only *) -let decode_str str = - if String.length str = 0 then - Ok ("", 0) - else - decode_string (Cstruct.of_string str) - -let decode_strings cs = - let rec go acc buf = - if Cstruct.len buf = 0 then - Ok (List.rev acc) - else - decode_string buf >>= fun (x, l) -> - go (x :: acc) (Cstruct.shift buf l) - in - go [] cs - let encode_string str = let l = String.length str in + assert (l < max_str_len) ; let cs = Cstruct.create (2 + l) in Cstruct.BE.set_uint16 cs 0 l ; Cstruct.blit_from_string str 0 cs 2 l ; - cs, 2 + l - -let encode_strings xs = - Cstruct.concat - (List.map (fun s -> fst (encode_string s)) xs) + cs let max = Int64.of_int max_int let min = Int64.of_int min_int let decode_int ?(off = 0) cs = + check_len cs Int32.(add (of_int off) 8l) >>= fun () -> let i = Cstruct.BE.get_uint64 cs off in if i > max then Error (`Msg "int too big") @@ -134,35 +158,64 @@ let encode_int i = Cstruct.BE.set_uint64 cs 0 (Int64.of_int i) ; cs -(* TODO: 32 bit system clean *) -let decode_pid cs = - check_len cs 4 >>= fun () -> - let pid = Cstruct.BE.get_uint32 cs 0 in - Ok (Int32.to_int pid) +let decode_list inner buf = + decode_int buf >>= fun len -> + let rec go acc idx = function + | 0 -> Ok (List.rev acc, idx) + | n -> + cs_shift buf idx >>= fun cs' -> + inner cs' >>= fun (data, len) -> + go (data :: acc) (idx + len) (pred n) + in + go [] 8 len -(* TODO: can we do sth more appropriate than raise? *) -let encode_pid pid = - let cs = Cstruct.create 4 in - if Int32.to_int Int32.max_int > pid && - Int32.to_int Int32.min_int < pid - then begin - Cstruct.BE.set_uint32 cs 0 (Int32.of_int pid) ; - cs - end else - invalid_arg "pid too big" +let encode_list inner data = + let cs = encode_int (List.length data) in + Cstruct.concat (cs :: (List.map inner data)) -let decode_ptime cs = - check_len cs 16 >>= fun () -> - decode_int cs >>= fun d -> - let ps = Cstruct.BE.get_uint64 cs 8 in +let decode_strings = decode_list decode_string + +let encode_strings = encode_list encode_string + +let encode ?name ?body version id tag = + let vm = match name with None -> Cstruct.empty | Some id -> encode_strings id in + let payload = match body with None -> Cstruct.empty | Some x -> x in + let header = + let length = Int32.(add (cs_len payload) (cs_len vm)) in + { length ; id ; version ; tag } + in + Cstruct.concat [ encode_header header ; vm ; payload ] + +let maybe_str = function + | None -> Cstruct.empty + | Some c -> encode_string c + +let fail_tag = 0xFFFFFFFFl + +let reply_tag = 0x80000000l + +let is_tag v tag = Int32.logand v tag = v + +let is_reply { tag ; _ } = is_tag reply_tag tag + +let is_fail { tag ; _ } = is_tag fail_tag tag + +let reply ?body version id tag = + encode ?body version id (Int32.logor reply_tag tag) + +let fail ?msg version id = + encode ~body:(maybe_str msg) version id fail_tag + +let success ?msg version id tag = + reply ~body:(maybe_str msg) version id tag + +let decode_ptime ?(off = 0) cs = + cs_shift cs off >>= fun cs' -> + check_len cs' 16l >>= fun () -> + decode_int cs' >>= fun d -> + let ps = Cstruct.BE.get_uint64 cs' 8 in Ok (Ptime.v (d, ps)) -(* EXPORT only *) -let decode_ts ?(off = 0) buf = - let cs = Cstruct.of_string buf in - let cs = Cstruct.shift cs off in - decode_ptime cs - let encode_ptime ts = let d, ps = Ptime.(Span.to_d_ps (to_span ts)) in let cs = Cstruct.create 16 in @@ -170,99 +223,70 @@ let encode_ptime ts = Cstruct.BE.set_uint64 cs 8 ps ; cs -let fail_tag = 0xFFFE -let success_tag = 0xFFFF - -let may_enc_str = function - | None -> empty, 0 - | Some msg -> encode_string msg - -let success ?msg id version = - let data, length = may_enc_str msg in - let r = - Cstruct.append - (create_header { length ; id ; version ; tag = success_tag }) data - in - Cstruct.to_string r - -let fail ?msg id version = - let data, length = may_enc_str msg in - let r = - Cstruct.append - (create_header { length ; id ; version ; tag = fail_tag }) data - in - Cstruct.to_string r - module Console = struct - [%%cenum - type op = - | Add_console - | Attach_console - | Detach_console - | History - | Data - [@@uint16_t] - ] + type op = + | Add_console + | Attach_console + | Detach_console + | History + | Data (* is a reply, never acked *) - let encode id version op ?payload nam = - let data, l = encode_string nam in - let length, p = - match payload with - | None -> l, empty - | Some x -> l + Cstruct.len x, x - and tag = op_to_int op - in - let r = - Cstruct.concat - [ (create_header { length ; id ; version ; tag }) ; data ; p ] - in - Cstruct.to_string r + let op_to_int = function + | Add_console -> 0x0100l + | Attach_console -> 0x0101l + | Detach_console -> 0x0102l + | History -> 0x0103l + | Data -> 0x0104l - let data ?(id = 0) v file ts msg = - let payload = + let int_to_op = function + | 0x0100l -> Some Add_console + | 0x0101l -> Some Attach_console + | 0x0102l -> Some Detach_console + | 0x0103l -> Some History + | 0x0104l -> Some Data + | _ -> None + + let data version name ts msg = + let body = let ts = encode_ptime ts - and data, _ = encode_string msg + and data = encode_string msg in Cstruct.append ts data in - encode id v Data ~payload file + encode version ~name ~body 0L (op_to_int Data) - let add id v name = encode id v Add_console name + let add id version name = encode ~name version id (op_to_int Add_console) - let attach id v name = encode id v Attach_console name + let attach id version name = encode ~name version id (op_to_int Attach_console) - let detach id v name = encode id v Detach_console name + let detach id version name = encode ~name version id (op_to_int Detach_console) - let history id v name since = - let payload = encode_ptime since in - encode id v History ~payload name + let history id version name since = + let body = encode_ptime since in + encode ~name ~body version id (op_to_int History) end module Stats = struct - [%%cenum - type op = - | Add - | Remove - | Stat_request - | Stat_reply - [@@uint16_t] - ] + type op = + | Add + | Remove + | Stats - let encode id version op ?payload nam = - let data, l = encode_string nam in - let length, p = - match payload with - | None -> l, empty - | Some x -> l + Cstruct.len x, x - and tag = op_to_int op - in - let r = - Cstruct.concat [ create_header { length ; version ; id ; tag } ; data ; p ] - in - Cstruct.to_string r + let op_to_int = function + | Add -> 0x0200l + | Remove -> 0x0201l + | Stats -> 0x0202l + + let int_to_op = function + | 0x0200l -> Some Add + | 0x0201l -> Some Remove + | 0x0202l -> Some Stats + | _ -> None + + let rusage_len = 144l let encode_rusage ru = - let cs = Cstruct.create (18 * 8) in + let cs = cs_create rusage_len in Cstruct.BE.set_uint64 cs 0 (fst ru.utime) ; Cstruct.BE.set_uint64 cs 8 (Int64.of_int (snd ru.utime)) ; Cstruct.BE.set_uint64 cs 16 (fst ru.stime) ; @@ -284,7 +308,7 @@ module Stats = struct cs let decode_rusage cs = - check_exact cs 144 >>= fun () -> + check_exact cs rusage_len >>= fun () -> (decode_int ~off:8 cs >>= fun ms -> Ok (Cstruct.BE.get_uint64 cs 0, ms)) >>= fun utime -> (decode_int ~off:24 cs >>= fun ms -> @@ -307,9 +331,11 @@ module Stats = struct Ok { utime ; stime ; maxrss ; ixrss ; idrss ; isrss ; minflt ; majflt ; nswap ; inblock ; outblock ; msgsnd ; msgrcv ; nsignals ; nvcsw ; nivcsw } + let ifdata_len = 116l + let encode_ifdata i = - let name, _ = encode_string i.name in - let cs = Cstruct.create (12 * 8 + 5 * 4) in + let name = encode_string i.name in + let cs = cs_create ifdata_len in Cstruct.BE.set_uint32 cs 0 i.flags ; Cstruct.BE.set_uint32 cs 4 i.send_length ; Cstruct.BE.set_uint32 cs 8 i.max_send_length ; @@ -331,8 +357,8 @@ module Stats = struct let decode_ifdata buf = decode_string buf >>= fun (name, l) -> - check_len buf (l + 116) >>= fun () -> - let cs = Cstruct.shift buf l in + cs_shift buf l >>= fun cs -> + check_len cs ifdata_len >>= fun () -> let flags = Cstruct.BE.get_uint32 cs 0 and send_length = Cstruct.BE.get_uint32 cs 4 and max_send_length = Cstruct.BE.get_uint32 cs 8 @@ -355,24 +381,18 @@ module Stats = struct baudrate ; input_packets ; input_errors ; output_packets ; output_errors ; collisions ; input_bytes ; output_bytes ; input_mcast ; output_mcast ; input_dropped ; output_dropped }, - l + 116) + Int32.(to_int ifdata_len) + l) - let add id v nam pid taps = - let payload = Cstruct.append (encode_pid pid) (encode_strings taps) in - encode id v Add ~payload nam + let add id version name pid taps = + let body = Cstruct.append (encode_int pid) (encode_strings taps) in + encode ~name ~body version id (op_to_int Add) - let remove id v nam = encode id v Remove nam + let remove id version name = encode ~name version id (op_to_int Remove) - let stat id v nam = encode id v Stat_request nam + let stat id version name = encode ~name version id (op_to_int Stats) - let stat_reply id version payload = - let length = Cstruct.len payload - and tag = op_to_int Stat_reply - in - let r = - Cstruct.append (create_header { length ; id ; version ; tag }) payload - in - Cstruct.to_string r + let stat_reply id version body = + reply ~body version id (op_to_int Stats) let encode_int64 i = let cs = Cstruct.create 8 in @@ -380,87 +400,76 @@ module Stats = struct cs let decode_int64 ?(off = 0) cs = - check_len cs (8 + off) >>= fun () -> + check_len cs (Int32.add 8l (Int32.of_int off)) >>= fun () -> Ok (Cstruct.BE.get_uint64 cs off) - let encode_vmm_stats xs = - encode_int (List.length xs) :: - List.flatten - (List.map (fun (k, v) -> [ fst (encode_string k) ; encode_int64 v ]) xs) + let encode_vmm_stats = + encode_list + (fun (k, v) -> Cstruct.append (encode_string k) (encode_int64 v)) - let decode_vmm_stats cs = - let rec go acc ctr buf = - if ctr = 0 then - Ok (List.rev acc, buf) - else + let decode_vmm_stats = + decode_list (fun buf -> decode_string buf >>= fun (str, off) -> decode_int64 ~off buf >>= fun v -> - go ((str, v) :: acc) (pred ctr) (Cstruct.shift buf (off + 8)) - in - decode_int cs >>= fun stat_num -> - go [] stat_num (Cstruct.shift cs 8) + Ok ((str, v), off + 8)) let encode_stats (ru, vmm, ifd) = Cstruct.concat - (encode_rusage ru :: - encode_vmm_stats vmm @ - encode_int (List.length ifd) :: List.map encode_ifdata ifd) + [ encode_rusage ru ; + encode_vmm_stats vmm ; + encode_list encode_ifdata ifd ] let decode_stats cs = - check_len cs 144 >>= fun () -> - let ru, rest = Cstruct.split cs 144 in + check_len cs rusage_len >>= fun () -> + let ru, rest = Cstruct.split cs (Int32.to_int rusage_len) in decode_rusage ru >>= fun ru -> - decode_vmm_stats rest >>= fun (vmm, rest) -> - let rec go acc ctr buf = - if ctr = 0 then - Ok (List.rev acc, buf) - else - decode_ifdata buf >>= fun (this, used) -> - go (this :: acc) (pred ctr) (Cstruct.shift buf used) - in - decode_int rest >>= fun num_if -> - go [] num_if (Cstruct.shift rest 8) >>= fun (ifs, _rest) -> + decode_vmm_stats rest >>= fun (vmm, off) -> + cs_shift rest off >>= fun rest' -> + decode_list decode_ifdata rest' >>= fun (ifs, _) -> Ok (ru, vmm, ifs) let decode_pid_taps data = - decode_pid data >>= fun pid -> - decode_strings (Cstruct.shift data 4) >>= fun taps -> + decode_int data >>= fun pid -> + decode_strings (Cstruct.shift data 8) >>= fun (taps, _off) -> Ok (pid, taps) end +let decode_id_ts cs = + decode_strings cs >>= fun (id, off) -> + decode_ptime ~off cs >>= fun ts -> + Ok ((id, ts), off + 16) + +let split_id id = match List.rev id with + | [] -> Error (`Msg "bad header") + | name::rest -> Ok (name, List.rev rest) + module Log = struct - [%%cenum - type op = - | Data - | History - [@@uint16_t] - ] + type op = + | Log + | History + | Broadcast + | Subscribe - let history id version ctx ts = - let tag = op_to_int History in - let nam, _ = encode_string ctx in - let payload = Cstruct.append nam (encode_ptime ts) in - let length = Cstruct.len payload in - let r = - Cstruct.append (create_header { length ; version ; id ; tag }) payload - in - Cstruct.to_string r + let op_to_int = function + | Log -> 0x0300l + | History -> 0x0301l + | Broadcast -> 0x0302l + | Subscribe -> 0x0303l - let encode_log_hdr ?(drop_context = false) hdr = - let ts = encode_ptime hdr.Log.ts - and ctx, _ = encode_string (if drop_context then "" else (string_of_id hdr.Log.context)) - and name, _ = encode_string hdr.Log.name - in - Cstruct.concat [ ts ; ctx ; name ] + let int_to_op = function + | 0x0300l -> Some Log + | 0x0301l -> Some History + | 0x0302l -> Some Broadcast + | 0x0303l -> Some Subscribe + | _ -> None + + let history id version name ts = + encode ~name ~body:(encode_ptime ts) version id (op_to_int History) let decode_log_hdr cs = - decode_ptime cs >>= fun ts -> - let r = Cstruct.shift cs 16 in - decode_string r >>= fun (ctx, l) -> - let context = id_of_string ctx in - let r = Cstruct.shift r l in - decode_string r >>= fun (name, l) -> - Ok ({ Log.ts ; context ; name }, Cstruct.shift r l) + decode_id_ts cs >>= fun ((id, ts), off) -> + split_id id >>= fun (name, context) -> + Ok ({ Log.ts ; context ; name }, Cstruct.shift cs (16 + off)) let encode_addr ip port = let cs = Cstruct.create 6 in @@ -469,24 +478,25 @@ module Log = struct cs let decode_addr cs = - check_len cs 6 >>= fun () -> + check_len cs 6l >>= fun () -> let ip = Ipaddr.V4.of_int32 (Cstruct.BE.get_uint32 cs 0) and port = Cstruct.BE.get_uint16 cs 4 in Ok (ip, port) let encode_vm (pid, taps, block) = - let cs = encode_pid pid in - let bl, _ = encode_string (match block with None -> "" | Some x -> x) in + let cs = encode_int pid in + let bl = encode_string (match block with None -> "" | Some x -> x) in let taps = encode_strings taps in Cstruct.concat [ cs ; bl ; taps ] let decode_vm cs = - decode_pid cs >>= fun pid -> - let r = Cstruct.shift cs 4 in + decode_int cs >>= fun pid -> + let r = Cstruct.shift cs 8 in decode_string r >>= fun (block, l) -> let block = if block = "" then None else Some block in - decode_strings (Cstruct.shift r l) >>= fun taps -> + cs_shift r l >>= fun r' -> + decode_strings r' >>= fun taps -> Ok (pid, taps, block) let encode_pid_exit pid c = @@ -495,19 +505,17 @@ module Log = struct | `Signal n -> 1, n | `Stop n -> 2, n in - let cs = Cstruct.create 1 in - Cstruct.set_uint8 cs 0 r ; - let pid = encode_pid pid - and code = encode_int c + let r_cs = encode_int r + and pid_cs = encode_int pid + and c_cs = encode_int c in - Cstruct.concat [ pid ; cs ; code ] + Cstruct.concat [ pid_cs ; r_cs ; c_cs ] let decode_pid_exit cs = - check_len cs 13 >>= fun () -> - decode_pid cs >>= fun pid -> - let r = Cstruct.get_uint8 cs 4 in - let code = Cstruct.shift cs 5 in - decode_int code >>= fun c -> + check_len cs 24l >>= fun () -> + decode_int cs >>= fun pid -> + decode_int ~off:8 cs >>= fun r -> + decode_int ~off:16 cs >>= fun c -> (match r with | 0 -> Ok (`Exit c) | 1 -> Ok (`Signal c) @@ -515,43 +523,20 @@ module Log = struct | _ -> Error (`Msg "couldn't parse exit status")) >>= fun r -> Ok (pid, r) - let encode_block nam siz = - Cstruct.append (fst (encode_string nam)) (encode_int siz) - - let decode_block cs = - decode_string cs >>= fun (nam, l) -> - check_len cs (l + 8) >>= fun () -> - decode_int ~off:l cs >>= fun siz -> - Ok (nam, siz) - - let encode_delegate bridges bs = - Cstruct.append - (fst (encode_string (match bs with None -> "" | Some x -> x))) - (encode_strings bridges) - - let decode_delegate buf = - decode_string buf >>= fun (bs, l) -> - let bs = if bs = "" then None else Some bs in - decode_strings (Cstruct.shift buf l) >>= fun bridges -> - Ok (bridges, bs) - let encode_event ev = let tag, data = match ev with - | `Startup -> 0, empty + | `Startup -> 0, Cstruct.empty | `Login (ip, port) -> 1, encode_addr ip port | `Logout (ip, port) -> 2, encode_addr ip port | `VM_start vm -> 3, encode_vm vm | `VM_stop (pid, c) -> 4, encode_pid_exit pid c - | `Block_create (nam, siz) -> 5, encode_block nam siz - | `Block_destroy nam -> 6, fst (encode_string nam) - | `Delegate (bridges, bs) -> 7, encode_delegate bridges bs in let cs = Cstruct.create 2 in Cstruct.BE.set_uint16 cs 0 tag ; Cstruct.append cs data let decode_event cs = - check_len cs 2 >>= fun () -> + check_len cs 2l >>= fun () -> let data = Cstruct.(shift cs 2) in match Cstruct.BE.get_uint16 cs 0 with | 0 -> Ok `Startup @@ -559,55 +544,139 @@ module Log = struct | 2 -> decode_addr data >>= fun addr -> Ok (`Logout addr) | 3 -> decode_vm data >>= fun vm -> Ok (`VM_start vm) | 4 -> decode_pid_exit data >>= fun ex -> Ok (`VM_stop ex) - | 5 -> decode_block data >>= fun bl -> Ok (`Block_create bl) - | 6 -> decode_string data >>= fun (nam, _) -> Ok (`Block_destroy nam) - | 7 -> decode_delegate data >>= fun d -> Ok (`Delegate d) | x -> R.error_msgf "couldn't parse event type %d" x - let data id version hdr event = - let hdr = encode_log_hdr hdr - and ev = encode_event event + let log id version hdr event = + let body = Cstruct.append (encode_ptime hdr.Log.ts) (encode_event event) + and name = hdr.Log.context @ [ hdr.Log.name ] in - let payload = Cstruct.append hdr ev in - let length = Cstruct.len payload - and tag = op_to_int Data - in - let r = - Cstruct.append (create_header { length ; id ; version ; tag }) payload - in - Cstruct.to_string r + encode ~name ~body version id (op_to_int Log) end -module Client = struct - let cmd_to_int = function - | Info -> 0 - | Destroy_vm -> 1 - | Create_block -> 2 - | Destroy_block -> 3 - | Statistics -> 4 - | Attach -> 5 - | Detach -> 6 - | Log -> 7 - and cmd_of_int = function - | 0 -> Some Info - | 1 -> Some Destroy_vm - | 2 -> Some Create_block - | 3 -> Some Destroy_block - | 4 -> Some Statistics - | 5 -> Some Attach - | 6 -> Some Detach - | 7 -> Some Log +module Vm = struct + type op = + | Create + | Destroy + | Info + (* | Add_policy *) + + let op_to_int = function + | Create -> 0x0400l + | Destroy -> 0x0401l + | Info -> 0x0402l + + let int_to_op = function + | 0x0400l -> Some Create + | 0x0401l -> Some Destroy + | 0x0402l -> Some Info | _ -> None - let console_msg_tag = 0xFFF0 - let log_msg_tag = 0xFFF1 - let stat_msg_tag = 0xFFF2 - let info_msg_tag = 0xFFF3 + let info id version name = + encode ~name version id (op_to_int Info) + + let encode_vm vm = + let name = encode_strings (vm.config.prefix @ [ vm.config.vname ]) + and memory = encode_int vm.config.requested_memory + and cs = encode_string (Bos.Cmd.to_string vm.cmd) + and pid = encode_int vm.pid + and taps = encode_strings vm.taps + in + Cstruct.concat [ name ; memory ; cs ; pid ; taps ] + + let info_reply id version vms = + let body = encode_list encode_vm vms in + reply ~body version id (op_to_int Info) + + let decode_vm cs = + decode_strings cs >>= fun (id, l) -> + cs_shift cs l >>= fun cs' -> + decode_int cs' >>= fun memory -> + cs_shift cs' 8 >>= fun cs'' -> + decode_string cs'' >>= fun (cmd, l') -> + cs_shift cs'' l' >>= fun cs''' -> + decode_int cs''' >>= fun pid -> + cs_shift cs''' 8 >>= fun cs'''' -> + decode_strings cs'''' >>= fun (taps, l'') -> + Ok ((id, memory, cmd, pid, taps), l + 8 + l' + l'') + + let decode_vms buf = decode_list decode_vm buf + + let encode_vm_config vm = + let cpu = encode_int vm.cpuid + and mem = encode_int vm.requested_memory + and block = encode_string (match vm.block_device with None -> "" | Some x -> x) + and network = encode_strings vm.network + and vmimage = Cstruct.concat [ encode_int (vmtype_to_int (fst vm.vmimage)) ; + encode_int (Cstruct.len (snd vm.vmimage)) ; + snd vm.vmimage ] + and args = encode_strings (match vm.argv with None -> [] | Some args -> args) + in + Cstruct.concat [ cpu ; mem ; block ; network ; vmimage ; args ] + + let decode_vm_config buf = + decode_strings buf >>= fun (id, off) -> + Logs.debug (fun m -> m "vm_config id %a" pp_id id) ; + split_id id >>= fun (vname, prefix) -> + cs_shift buf off >>= fun buf' -> + decode_int buf' >>= fun cpuid -> + Logs.debug (fun m -> m "cpuid %d" cpuid) ; + decode_int ~off:8 buf' >>= fun requested_memory -> + Logs.debug (fun m -> m "mem %d" requested_memory) ; + cs_shift buf' 16 >>= fun buf'' -> + decode_string buf'' >>= fun (block, off) -> + Logs.debug (fun m -> m "block %s" block) ; + cs_shift buf'' off >>= fun buf''' -> + let block_device = if block = "" then None else Some block in + decode_strings buf''' >>= fun (network, off') -> + cs_shift buf''' off' >>= fun buf'''' -> + decode_int buf'''' >>= fun vmtype -> + (match int_to_vmtype vmtype with + | Some x -> Ok x + | None -> Error (`Msg "unknown vmtype")) >>= fun vmtype -> + decode_int ~off:8 buf'''' >>= fun size -> + check_len buf'''' (Int32.of_int size) >>= fun () -> + let vmimage = (vmtype, Cstruct.sub buf'''' 16 size) in + cs_shift buf'''' (16 + size) >>= fun buf''''' -> + decode_strings buf''''' >>= fun (argv, _) -> + let argv = match argv with [] -> None | xs -> Some xs in + Ok { vname ; prefix ; cpuid ; requested_memory ; block_device ; network ; vmimage ; argv } + + let create id version vm = + let body = encode_vm_config vm in + let name = vm.prefix @ [ vm.vname ] in + encode ~name ~body version id (op_to_int Create) + + let destroy id version name = + encode ~name version id (op_to_int Destroy) +end + +(* +module Client = struct + let cmd_to_int = function + | Info -> 0x0500l + | Destroy_vm -> 0x0501l + | Create_block -> 0x0502l + | Destroy_block -> 0x0503l + | Statistics -> 0x0504l + | Attach -> 0x0505l + | Detach -> 0x0506l + | Log -> 0x0507l + and cmd_of_int = function + | 0x0500l -> Some Info + | 0x0501l -> Some Destroy_vm + | 0x0502l -> Some Create_block + | 0x0503l -> Some Destroy_block + | 0x0504l -> Some Statistics + | 0x0505l -> Some Attach + | 0x0506l -> Some Detach + | 0x0507l -> Some Log + | _ -> None let cmd ?arg it id version = let pay, length = may_enc_str arg and tag = cmd_to_int it in + let length = Int32.of_int length in let hdr = create_header { length ; id ; version ; tag } in Cstruct.(to_string (append hdr pay)) @@ -617,17 +686,17 @@ module Client = struct (Log.encode_log_hdr ~drop_context:true hdr) (Log.encode_event event) in - let length = Cstruct.len payload in + let length = cs_len payload in let r = Cstruct.append - (create_header { length ; id = 0 ; version ; tag = log_msg_tag }) + (create_header { length ; id = 0L ; version ; tag = Log.(op_to_int Data) }) payload in Cstruct.to_string r let stat data id version = - let length = String.length data in - let hdr = create_header { length ; id ; version ; tag = stat_msg_tag } in + let length = Int32.of_int (String.length data) in + let hdr = create_header { length ; id ; version ; tag = Stats.(op_to_int Stat_reply) } in Cstruct.to_string hdr ^ data let console off name payload version = @@ -640,15 +709,16 @@ module Client = struct let p' = Astring.String.drop ~max:off payload in p', l + String.length p' in + let length = Int32.of_int length in let hdr = - create_header { length ; id = 0 ; version ; tag = console_msg_tag } + create_header { length ; id = 0L ; version ; tag = Console.(op_to_int Data) } in Cstruct.(to_string (append hdr nam)) ^ payload let encode_vm name vm = - let name, _ = encode_string name - and cs, _ = encode_string (Bos.Cmd.to_string vm.cmd) - and pid = encode_pid vm.pid + let name = encode_string name + and cs = encode_string (Bos.Cmd.to_string vm.cmd) + and pid = encode_int vm.pid and taps = encode_strings vm.taps in let tapc = encode_int (Cstruct.len taps) in @@ -657,13 +727,14 @@ module Client = struct let info data id version = let length = String.length data in - let hdr = create_header { length ; id ; version ; tag = info_msg_tag } in + let length = Int32.of_int length in + let hdr = create_header { length ; id ; version ; tag = success_tag } in Cstruct.to_string hdr ^ data let decode_vm cs = decode_string cs >>= fun (name, l) -> decode_string (Cstruct.shift cs l) >>= fun (cmd, l') -> - decode_pid (Cstruct.shift cs (l + l')) >>= fun pid -> + decode_int (Cstruct.shift cs (l + l')) >>= fun pid -> decode_int ~off:(l + l' + 4) cs >>= fun tapc -> let taps = Cstruct.sub cs (l + l' + 12) tapc in decode_strings taps >>= fun taps -> @@ -695,3 +766,4 @@ module Client = struct decode_string (Cstruct.shift cs (l + 16)) >>= fun (line, _) -> Ok (name, ts, line) end + *) diff --git a/src/vmm_x509.ml b/src/vmm_x509.ml new file mode 100644 index 0000000..37657b1 --- /dev/null +++ b/src/vmm_x509.ml @@ -0,0 +1,163 @@ + +let asn_version = `AV0 + +let handle_single_revocation t prefix serial = + let id = identifier serial in + (match Vmm_resources.find t.resources (prefix @ [ id ]) with + | None -> () + | Some e -> Vmm_resources.iter Vmm_unix.destroy e) ; + (* also revoke all active sessions!? *) + (* TODO: maybe we need a vmm_resources like structure for sessions as well!? *) + let log_attached, kill = + let pid = string_of_id prefix in + match String.Map.find pid t.log_attached with + | None -> t.log_attached, [] + | Some xs -> + (* those where snd v = serial: drop *) + let drop, keep = List.partition (fun (_, s) -> String.equal s id) xs in + String.Map.add pid keep t.log_attached, drop + in + (* two things: + 1 revoked LEAF certs need to go (k = prefix, snd v = serial) [see above] + 2 revoked CA certs need to wipe subtree (all entries where k starts with prefix @ serial) *) + let log_attached, kill = + String.Map.fold (fun k' v (l, k) -> + if is_sub_id ~super:(prefix@[id]) ~sub:(id_of_string k') then + (l, v @ k) + else + (String.Map.add k' v l, k)) + log_attached + (String.Map.empty, kill) + in + let state, out = + List.fold_left (fun (s, out) (t, _) -> + let s', out' = handle_disconnect s t in + s', out @ out') + ({ t with log_attached }, []) + kill + in + (state, + List.map (fun x -> `Raw x) out, + List.map fst kill) + +let handle_revocation t s leaf chain ca prefix = + Vmm_asn.crl_of_cert leaf >>= fun crl -> + (* verify data (must be signed by the last cert of the chain (or cacert if chain is empty))! *) + let issuer = match chain with + | subca::_ -> subca + | [] -> ca + in + let time = Ptime_clock.now () in + (if X509.CRL.verify crl ~time issuer then Ok () else Error (`Msg "couldn't verify CRL")) >>= fun () -> + (* the this_update must be > now, next_update < now, this_update > .this_update, number > .number *) + (* TODO: can we have something better for uniqueness of CRL? *) + let local = try Some (List.find (fun crl -> X509.CRL.verify crl issuer) t.crls) with Not_found -> None in + (match local with + | None -> Ok () + | Some local -> match X509.CRL.crl_number local, X509.CRL.crl_number crl with + | None, _ -> Ok () + | Some _, None -> Error (`Msg "CRL number not present") + | Some x, Some y -> if y > x then Ok () else Error (`Msg "CRL number not increased")) >>= fun () -> + (* filename should be whatever_dir / crls / *) + let filename = Fpath.(dbdir / "crls" / string_of_id prefix) in + Bos.OS.File.delete filename >>= fun () -> + Bos.OS.File.write filename (Cstruct.to_string (X509.Encoding.crl_to_cstruct crl)) >>= fun () -> + (* remove crl with same issuer from crls, and inject this one into state *) + let crls = + match local with + | None -> crl :: t.crls + | Some _ -> crl :: List.filter (fun c -> c <> crl) t.crls + in + (* iterate over revoked serials, find active resources, and kill them *) + let newly_revoked = + let old = match local with + | Some x -> List.map (fun rc -> rc.X509.CRL.serial) (X509.CRL.revoked_certificates x) + | None -> [] + in + let new_rev = List.map (fun rc -> rc.X509.CRL.serial) (X509.CRL.revoked_certificates crl) in + List.filter (fun n -> not (List.mem n old)) new_rev + in + let t, out, close = + List.fold_left (fun (t, out, close) serial -> + let t', out', close' = handle_single_revocation t prefix serial in + (t', out @ out', close @ close')) + (t, [], []) newly_revoked + in + let tls_out = Vmm_wire.success ~msg:"updated revocation list" 0 t.client_version in + Ok ({ t with crls }, `Tls (s, tls_out) :: out, `Close close) + +let handle_initial t s addr chain ca = + separate_chain chain >>= fun (leaf, chain) -> + Logs.debug (fun m -> m "leaf is %s, chain %a" + (X509.common_name_to_string leaf) + Fmt.(list ~sep:(unit "->") string) + (List.map X509.common_name_to_string chain)) ; + (* TODO here: inspect top-level-cert of chain. + may need to create bridges and/or block device subdirectory (zfs create) *) + let prefix = List.map id chain in + let login_hdr, login_ev = Log.hdr prefix (id leaf), `Login addr in + let t, out = log t (login_hdr, login_ev) in + let initial_out = `Tls (s, Vmm_wire.Client.log login_hdr login_ev t.client_version) in + Vmm_asn.permissions_of_cert asn_version leaf >>= fun perms -> + (if (List.mem `Create perms || List.mem `Force_create perms) && Vmm_asn.contains_vm leaf then + (* convert certificate to vm_config *) + Vmm_asn.vm_of_cert prefix leaf >>= fun vm_config -> + Logs.debug (fun m -> m "vm %a" pp_vm_config vm_config) ; + (* get names and static resources *) + List.fold_left (fun acc ca -> + acc >>= fun acc -> + Vmm_asn.delegation_of_cert asn_version ca >>= fun res -> + let name = id ca in + Ok ((name, res) :: acc)) + (Ok []) chain >>= fun policies -> + (* check static policies *) + Logs.debug (fun m -> m "now checking static policies") ; + check_policies vm_config (List.map snd policies) >>= fun () -> + let t, task = + let force = List.mem `Force_create perms in + if force then + let fid = vm_id vm_config in + match String.Map.find fid t.tasks with + | None -> t, None + | Some task -> + let kill () = + match Vmm_resources.find_vm t.resources (fullname vm_config) with + | None -> + Logs.err (fun m -> m "found a task, but no vm for %a (%s)" + pp_id (fullname vm_config) fid) + | Some vm -> + Logs.debug (fun m -> m "killing %a now" pp_vm vm) ; + Vmm_unix.destroy vm + in + let tasks = String.Map.remove fid t.tasks in + ({ t with tasks }, Some (kill, task)) + else + t, None + in + let next t sleeper = + handle_create t vm_config policies >>= fun cont -> + let id = vm_id vm_config in + let cons = Vmm_wire.Console.add t.console_counter t.console_version id in + let tasks = String.Map.add id sleeper t.tasks in + Ok ({ t with console_counter = succ t.console_counter ; tasks }, + [ `Raw (t.console_socket, cons) ], + cont) + in + Ok (t, [], `Create (task, next)) + else if List.mem `Crl perms && Vmm_asn.contains_crl leaf then + handle_revocation t s leaf chain ca prefix + else + let log_attached = + if cmd_allowed perms Log then + let pre = string_of_id prefix in + let v = match String.Map.find pre t.log_attached with + | None -> [] + | Some xs -> xs + in + String.Map.add pre ((s, id leaf) :: v) t.log_attached + else + t.log_attached + in + Ok ({ t with log_attached }, [], `Loop (prefix, perms)) + ) >>= fun (t, outs, res) -> + Ok (t, initial_out :: out @ outs, res) diff --git a/stats/vmm_stats.ml b/stats/vmm_stats.ml index b7cf686..e8268d7 100644 --- a/stats/vmm_stats.ml +++ b/stats/vmm_stats.ml @@ -192,10 +192,9 @@ let remove_vmid t vmid = let remove_vmids t vmids = List.fold_left remove_vmid t vmids -let handle t hdr buf = +let handle t hdr cs = let open Vmm_wire in let open Vmm_wire.Stats in - let cs = Cstruct.of_string buf in let r = if not (version_eq my_version hdr.version) then Error (`Msg "cannot handle version") @@ -205,11 +204,11 @@ let handle t hdr buf = | Some Add -> decode_pid_taps (Cstruct.shift cs off) >>= fun (pid, taps) -> add_pid t name pid taps >>= fun t -> - Ok (t, `Add name, success ~msg:"added" hdr.id my_version) + Ok (t, `Add name, success ~msg:"added" my_version hdr.id (op_to_int Add)) | Some Remove -> let t = remove_vmid t name in - Ok (t, `Remove name, success ~msg:"removed" hdr.id my_version) - | Some Stat_request -> + Ok (t, `Remove name, success ~msg:"removed" my_version hdr.id (op_to_int Remove)) + | Some Stats -> stats t name >>= fun s -> Ok (t, `None, stat_reply hdr.id my_version (encode_stats s)) | _ -> Error (`Msg "unknown command") @@ -218,4 +217,4 @@ let handle t hdr buf = | Ok (t, action, out) -> t, action, out | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing %s" msg) ; - t, `None, fail ~msg hdr.id my_version + t, `None, fail ~msg my_version hdr.id diff --git a/stats/vmm_stats_lwt.ml b/stats/vmm_stats_lwt.ml index c8a7731..644cc62 100644 --- a/stats/vmm_stats_lwt.ml +++ b/stats/vmm_stats_lwt.ml @@ -24,11 +24,11 @@ let pp_sockaddr ppf = function let handle s addr () = Logs.info (fun m -> m "handling stats connection %a" pp_sockaddr addr) ; let rec loop acc = - Vmm_lwt.read_exactly s >>= function + Vmm_lwt.read_wire s >>= function | Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop acc | Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return acc | Ok (hdr, data) -> - Logs.debug (fun m -> m "received %a" Cstruct.hexdump_pp (Cstruct.of_string data)) ; + Logs.debug (fun m -> m "received %a" Cstruct.hexdump_pp data) ; let t', action, out = Vmm_stats.handle !t hdr data in let acc = match action with | `Add pid -> pid :: acc @@ -36,8 +36,8 @@ let handle s addr () = | `None -> acc in t := t' ; - Logs.debug (fun m -> m "sent %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ; - Vmm_lwt.write_raw s out >>= function + Logs.debug (fun m -> m "sent %a" Cstruct.hexdump_pp out) ; + Vmm_lwt.write_wire s out >>= function | Ok () -> loop acc | Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return acc in