versioning: revise it all, use a 'current' in Vmm_commands, all daemons reply with the received version on that particular stream

This commit is contained in:
Hannes Mehnert 2019-11-11 21:49:51 +01:00
parent 365a569b25
commit 784429744c
27 changed files with 337 additions and 384 deletions

View file

@ -3,8 +3,6 @@
open Lwt.Infix open Lwt.Infix
open X509 open X509
let version = `AV4
let read fd = let read fd =
(* now we busy read and process output *) (* now we busy read and process output *)
Logs.debug (fun m -> m "reading tls stream") ; Logs.debug (fun m -> m "reading tls stream") ;
@ -15,7 +13,7 @@ let read fd =
Lwt.return (Ok ()) Lwt.return (Ok ())
| Error _ -> Lwt.return (Error (`Msg ("read failure"))) | Error _ -> Lwt.return (Error (`Msg ("read failure")))
| Ok wire -> | Ok wire ->
Albatross_cli.print_result version wire ; Albatross_cli.print_result wire ;
loop () loop ()
in in
loop () loop ()
@ -37,6 +35,10 @@ let timestamps validity =
| Some now, Some exp -> (now, exp) | Some now, Some exp -> (now, exp)
let handle (host, port) cert key ca id (cmd : Vmm_commands.t) = let handle (host, port) cert key ca id (cmd : Vmm_commands.t) =
Printexc.register_printer (function
| Tls_lwt.Tls_alert x -> Some ("TLS alert: " ^ Tls.Packet.alert_type_to_string x)
| Tls_lwt.Tls_failure f -> Some ("TLS failure: " ^ Tls.Engine.string_of_failure f)
| _ -> None) ;
Vmm_lwt.read_from_file cert >>= fun cert_cs -> Vmm_lwt.read_from_file cert >>= fun cert_cs ->
Vmm_lwt.read_from_file key >>= fun key_cs -> Vmm_lwt.read_from_file key >>= fun key_cs ->
match Certificate.decode_pem cert_cs, Private_key.decode_pem key_cs with match Certificate.decode_pem cert_cs, Private_key.decode_pem key_cs with
@ -48,7 +50,7 @@ let handle (host, port) cert key ca id (cmd : Vmm_commands.t) =
let tmpkey = Nocrypto.Rsa.generate 4096 in let tmpkey = Nocrypto.Rsa.generate 4096 in
let name = Vmm_core.Name.to_string id in let name = Vmm_core.Name.to_string id in
let extensions = let extensions =
let v = Vmm_asn.cert_extension_to_cstruct (version, cmd) in let v = Vmm_asn.to_cert_extension cmd in
Extension.(add Key_usage (true, [ `Digital_signature ; `Key_encipherment ]) Extension.(add Key_usage (true, [ `Digital_signature ; `Key_encipherment ])
(add Basic_constraints (true, (false, None)) (add Basic_constraints (true, (false, None))
(add Ext_key_usage (true, [ `Client_auth ]) (add Ext_key_usage (true, [ `Client_auth ])
@ -285,7 +287,7 @@ let default_cmd =
`P "$(tname) executes the provided subcommand on a remote albatross" ] `P "$(tname) executes the provided subcommand on a remote albatross" ]
in in
Term.(ret (const help $ setup_log $ destination $ Term.man_format $ Term.choice_names $ Term.pure None)), Term.(ret (const help $ setup_log $ destination $ Term.man_format $ Term.choice_names $ Term.pure None)),
Term.info "albatross_client_bistro" ~version:"%%VERSION_NUM%%" ~doc ~man Term.info "albatross_client_bistro" ~version ~doc ~man
let cmds = [ help_cmd ; info_cmd ; let cmds = [ help_cmd ; info_cmd ;
policy_cmd ; remove_policy_cmd ; add_policy_cmd ; policy_cmd ; remove_policy_cmd ; add_policy_cmd ;

View file

@ -2,8 +2,6 @@
open Lwt.Infix open Lwt.Infix
let version = `AV4
let socket t = function let socket t = function
| Some x -> x | Some x -> x
| None -> Vmm_core.socket_path t | None -> Vmm_core.socket_path t
@ -11,7 +9,7 @@ let socket t = function
let process fd = let process fd =
Vmm_lwt.read_wire fd >|= function Vmm_lwt.read_wire fd >|= function
| Error _ -> Error () | Error _ -> Error ()
| Ok wire -> Ok (Albatross_cli.print_result version wire) | Ok wire -> Ok (Albatross_cli.print_result wire)
let read fd = let read fd =
(* now we busy read and process output *) (* now we busy read and process output *)
@ -32,7 +30,7 @@ let handle opt_socket name (cmd : Vmm_commands.t) =
in in
Lwt.return err Lwt.return err
| Some fd -> | Some fd ->
let header = Vmm_commands.{ version ; sequence = 0L ; name } in let header = Vmm_commands.header name in
Vmm_lwt.write_wire fd (header, `Command cmd) >>= function Vmm_lwt.write_wire fd (header, `Command cmd) >>= function
| Error `Exception -> Lwt.return (Error (`Msg "exception")) | Error `Exception -> Lwt.return (Error (`Msg "exception"))
| Ok () -> | Ok () ->
@ -248,7 +246,7 @@ let default_cmd =
`P "$(tname) connects to albatrossd via a local socket" ] `P "$(tname) connects to albatrossd via a local socket" ]
in in
Term.(ret (const help $ setup_log $ socket $ Term.man_format $ Term.choice_names $ Term.pure None)), Term.(ret (const help $ setup_log $ socket $ Term.man_format $ Term.choice_names $ Term.pure None)),
Term.info "albatross_client_local" ~version:"%%VERSION_NUM%%" ~doc ~man Term.info "albatross_client_local" ~version ~doc ~man
let cmds = [ help_cmd ; info_cmd ; let cmds = [ help_cmd ; info_cmd ;
policy_cmd ; remove_policy_cmd ; add_policy_cmd ; policy_cmd ; remove_policy_cmd ; add_policy_cmd ;

View file

@ -2,8 +2,6 @@
open Lwt.Infix open Lwt.Infix
let version = `AV4
let rec read_tls_write_cons t = let rec read_tls_write_cons t =
Vmm_tls_lwt.read_tls t >>= function Vmm_tls_lwt.read_tls t >>= function
| Error `Eof -> | Error `Eof ->
@ -12,7 +10,7 @@ let rec read_tls_write_cons t =
| Error _ -> | Error _ ->
Lwt.return (Error (`Msg ("read failure"))) Lwt.return (Error (`Msg ("read failure")))
| Ok wire -> | Ok wire ->
Albatross_cli.print_result version wire ; Albatross_cli.print_result wire ;
read_tls_write_cons t read_tls_write_cons t
let client cas host port cert priv_key = let client cas host port cert priv_key =
@ -82,7 +80,7 @@ let cmd =
`P "$(tname) connects to an Albatross server and initiates a TLS handshake" ] `P "$(tname) connects to an Albatross server and initiates a TLS handshake" ]
in in
Term.(pure run_client $ setup_log $ cas $ client_cert $ client_key $ destination), Term.(pure run_client $ setup_log $ cas $ client_cert $ client_key $ destination),
Term.info "albatross_client_remote_tls" ~version:"%%VERSION_NUM%%" ~doc ~man Term.info "albatross_client_remote_tls" ~version ~doc ~man
let () = let () =
match Term.eval cmd match Term.eval cmd

View file

@ -48,14 +48,12 @@ let init_influx name data =
in in
Lwt.async report Lwt.async report
let print_result version (header, reply) = let print_result ((_, reply) as wire) =
if not (Vmm_commands.version_eq header.Vmm_commands.version version) then match reply with
Logs.err (fun m -> m "version not equal") | `Success _ -> Logs.app (fun m -> m "%a" Vmm_commands.pp_wire wire)
else match reply with | `Data _ -> Logs.app (fun m -> m "%a" Vmm_commands.pp_wire wire)
| `Success _ -> Logs.app (fun m -> m "%a" Vmm_commands.pp_wire (header, reply)) | `Failure _ -> Logs.warn (fun m -> m "%a" Vmm_commands.pp_wire wire)
| `Data _ -> Logs.app (fun m -> m "%a" Vmm_commands.pp_wire (header, reply)) | `Command _ -> Logs.err (fun m -> m "unexpected command %a" Vmm_commands.pp_wire wire)
| `Failure _ -> Logs.warn (fun m -> m "%a" Vmm_commands.pp_wire (header, reply))
| `Command _ -> Logs.err (fun m -> m "unexpected command %a" Vmm_commands.pp_wire (header, reply))
let setup_log style_renderer level = let setup_log style_renderer level =
Fmt_tty.setup_std_outputs ?style_renderer (); Fmt_tty.setup_std_outputs ?style_renderer ();
@ -252,3 +250,7 @@ let count =
let since_count since count = match since with let since_count since count = match since with
| None -> `Count count | None -> `Count count
| Some since -> `Since since | Some since -> `Since since
let version =
Fmt.strf "version %%VERSION%% protocol version %a"
Vmm_commands.pp_version Vmm_commands.current

View file

@ -14,8 +14,6 @@ open Lwt.Infix
open Astring open Astring
let my_version = `AV4
let pp_unix_error ppf e = Fmt.string ppf (Unix.error_message e) let pp_unix_error ppf e = Fmt.string ppf (Unix.error_message e)
let active = ref String.Map.empty let active = ref String.Map.empty
@ -31,8 +29,8 @@ let read_console id name ring fd =
Vmm_ring.write ring (t, line) ; Vmm_ring.write ring (t, line) ;
(match String.Map.find name !active with (match String.Map.find name !active with
| None -> Lwt.return_unit | None -> Lwt.return_unit
| Some fd -> | Some (version, fd) ->
let header = Vmm_commands.{ version = my_version ; sequence = 0L ; name = id } in let header = Vmm_commands.header ~version id in
Vmm_lwt.write_wire fd (header, `Data (`Console_data (t, line))) >>= function Vmm_lwt.write_wire fd (header, `Data (`Console_data (t, line))) >>= function
| Error _ -> | Error _ ->
Vmm_lwt.safe_close fd >|= fun () -> Vmm_lwt.safe_close fd >|= fun () ->
@ -87,21 +85,21 @@ let add_fifo id =
Lwt.async (fun () -> read_console id name ring f >|= fun () -> fifos `Close) ; Lwt.async (fun () -> read_console id name ring f >|= fun () -> fifos `Close) ;
Ok () Ok ()
let subscribe s id = let subscribe s version id =
let name = Vmm_core.Name.to_string id in let name = Vmm_core.Name.to_string id in
Logs.debug (fun m -> m "attempting to subscribe %a" Vmm_core.Name.pp id) ; Logs.debug (fun m -> m "attempting to subscribe %a" Vmm_core.Name.pp id) ;
match String.Map.find name !t with match String.Map.find name !t with
| None -> | None ->
active := String.Map.add name s !active ; active := String.Map.add name (version, s) !active ;
Lwt.return (None, "waiting for VM") Lwt.return (None, "waiting for VM")
| Some r -> | Some r ->
(match String.Map.find name !active with (match String.Map.find name !active with
| None -> Lwt.return_unit | None -> Lwt.return_unit
| Some s -> Vmm_lwt.safe_close s) >|= fun () -> | Some (_, s) -> Vmm_lwt.safe_close s) >|= fun () ->
active := String.Map.add name s !active ; active := String.Map.add name (version, s) !active ;
(Some r, "subscribed") (Some r, "subscribed")
let send_history s r id since = let send_history s version r id since =
let entries = let entries =
match since with match since with
| `Count n -> Vmm_ring.read_last r n | `Count n -> Vmm_ring.read_last r n
@ -109,7 +107,7 @@ let send_history s r id since =
in in
Logs.debug (fun m -> m "%a found %d history" Vmm_core.Name.pp id (List.length entries)) ; Logs.debug (fun m -> m "%a found %d history" Vmm_core.Name.pp id (List.length entries)) ;
Lwt_list.iter_s (fun (i, v) -> Lwt_list.iter_s (fun (i, v) ->
let header = Vmm_commands.{ version = my_version ; sequence = 0L ; name = id } in let header = Vmm_commands.header ~version id in
Vmm_lwt.write_wire s (header, `Data (`Console_data (i, v))) >>= function Vmm_lwt.write_wire s (header, `Data (`Console_data (i, v))) >>= function
| Ok () -> Lwt.return_unit | Ok () -> Lwt.return_unit
| Error _ -> Vmm_lwt.safe_close s) | Error _ -> Vmm_lwt.safe_close s)
@ -123,10 +121,7 @@ let handle s addr =
Logs.err (fun m -> m "exception while reading") ; Logs.err (fun m -> m "exception while reading") ;
Lwt.return_unit Lwt.return_unit
| Ok (header, `Command (`Console_cmd cmd)) -> | Ok (header, `Command (`Console_cmd cmd)) ->
if not (Vmm_commands.version_eq header.Vmm_commands.version my_version) then begin begin
Logs.err (fun m -> m "ignoring data with bad version") ;
Lwt.return_unit
end else begin
let name = header.Vmm_commands.name in let name = header.Vmm_commands.name in
match cmd with match cmd with
| `Console_add -> | `Console_add ->
@ -143,13 +138,13 @@ let handle s addr =
Lwt.return_unit Lwt.return_unit
end end
| `Console_subscribe ts -> | `Console_subscribe ts ->
subscribe s name >>= fun (ring, res) -> subscribe s header.Vmm_commands.version name >>= fun (ring, res) ->
Vmm_lwt.write_wire s (header, `Success (`String res)) >>= function Vmm_lwt.write_wire s (header, `Success (`String res)) >>= function
| Error _ -> Vmm_lwt.safe_close s | Error _ -> Vmm_lwt.safe_close s
| Ok () -> | Ok () ->
(match ring with (match ring with
| None -> Lwt.return_unit | None -> Lwt.return_unit
| Some r -> send_history s r name ts) >>= fun () -> | Some r -> send_history s header.Vmm_commands.version r name ts) >>= fun () ->
(* now we wait for the next read and terminate*) (* now we wait for the next read and terminate*)
Vmm_lwt.read_wire s >|= fun _ -> () Vmm_lwt.read_wire s >|= fun _ -> ()
end end
@ -182,6 +177,6 @@ open Albatross_cli
let cmd = let cmd =
Term.(term_result (const jump $ setup_log $ influx)), Term.(term_result (const jump $ setup_log $ influx)),
Term.info "albatross_console" ~version:"%%VERSION_NUM%%" Term.info "albatross_console" ~version
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1

View file

@ -156,8 +156,6 @@ module P = struct
vm ifd.bridge (String.concat ~sep:"," fields) vm ifd.bridge (String.concat ~sep:"," fields)
end end
let my_version = `AV4
let command = ref 1L let command = ref 1L
let str_of_e = function let str_of_e = function
@ -199,13 +197,6 @@ let rec read_sock_write_tcp drop c ?fd addr =
safe_close c >|= fun () -> safe_close c >|= fun () ->
true true
| Ok (hdr, `Data (`Stats_data (ru, mem, vmm, ifs))) -> | Ok (hdr, `Data (`Stats_data (ru, mem, vmm, ifs))) ->
begin
if not (Vmm_commands.version_eq hdr.Vmm_commands.version my_version) then begin
Logs.err (fun m -> m "unknown wire protocol version") ;
safe_close fd >>= fun () ->
safe_close c >|= fun () ->
false
end else
let name = let name =
let orig = hdr.Vmm_commands.name let orig = hdr.Vmm_commands.name
and f = if drop then Name.drop_front else (fun a -> a) and f = if drop then Name.drop_front else (fun a -> a)
@ -220,6 +211,7 @@ let rec read_sock_write_tcp drop c ?fd addr =
let taps = List.map (P.encode_if name) ifs in let taps = List.map (P.encode_if name) ifs in
let out = (String.concat ~sep:"\n" (ru :: mem @ vmm @ taps)) ^ "\n" in let out = (String.concat ~sep:"\n" (ru :: mem @ vmm @ taps)) ^ "\n" in
Logs.debug (fun m -> m "writing %d via tcp" (String.length out)) ; Logs.debug (fun m -> m "writing %d via tcp" (String.length out)) ;
begin
Vmm_lwt.write_raw fd (Bytes.unsafe_of_string out) >>= function Vmm_lwt.write_raw fd (Bytes.unsafe_of_string out) >>= function
| Ok () -> | Ok () ->
Logs.debug (fun m -> m "wrote successfully") ; Logs.debug (fun m -> m "wrote successfully") ;
@ -236,7 +228,7 @@ let rec read_sock_write_tcp drop c ?fd addr =
read_sock_write_tcp drop c ?fd addr read_sock_write_tcp drop c ?fd addr
let query_sock vm c = let query_sock vm c =
let header = Vmm_commands.{ version = my_version ; sequence = !command ; name = vm } in let header = Vmm_commands.header ~sequence:!command vm in
command := Int64.succ !command ; command := Int64.succ !command ;
Logs.debug (fun m -> m "%Lu requesting %a via socket" !command Name.pp vm) ; Logs.debug (fun m -> m "%Lu requesting %a via socket" !command Name.pp vm) ;
Vmm_lwt.write_wire c (header, `Command (`Stats_cmd `Stats_subscribe)) Vmm_lwt.write_wire c (header, `Command (`Stats_cmd `Stats_subscribe))
@ -306,7 +298,7 @@ let cmd =
`P "$(tname) connects to a albatross stats socket, pulls statistics and pushes them via TCP to influxdb" ] `P "$(tname) connects to a albatross stats socket, pulls statistics and pushes them via TCP to influxdb" ]
in in
Term.(term_result (const run_client $ setup_log $ influx $ opt_vm_name $ drop_label)), Term.(term_result (const run_client $ setup_log $ influx $ opt_vm_name $ drop_label)),
Term.info "albatross_influx" ~version:"%%VERSION_NUM%%" ~doc ~man Term.info "albatross_influx" ~version ~doc ~man
let () = let () =
match Term.eval cmd match Term.eval cmd

View file

@ -10,11 +10,10 @@
open Lwt.Infix open Lwt.Infix
let my_version = `AV4
let broadcast prefix wire t = let broadcast prefix wire t =
Lwt_list.fold_left_s (fun t (id, s) -> Lwt_list.fold_left_s (fun t (id, (version, s)) ->
Vmm_lwt.write_wire s wire >|= function let hdr = Vmm_commands.header ~version prefix in
Vmm_lwt.write_wire s (hdr, wire) >|= function
| Ok () -> t | Ok () -> t
| Error `Exception -> Vmm_trie.remove id t) | Error `Exception -> Vmm_trie.remove id t)
t (Vmm_trie.collect prefix t) t (Vmm_trie.collect prefix t)
@ -47,7 +46,7 @@ let write_to_file mvar file =
get_fd () >>= fun fd -> get_fd () >>= fun fd ->
loop ~log_entry:(Ptime_clock.now (), `Hup) fd loop ~log_entry:(Ptime_clock.now (), `Hup) fd
| `Entry log_entry -> Lwt.return log_entry) >>= fun log_entry -> | `Entry log_entry -> Lwt.return log_entry) >>= fun log_entry ->
let data = Vmm_asn.log_to_disk my_version log_entry in let data = Vmm_asn.log_to_disk log_entry in
Lwt.catch Lwt.catch
(fun () -> (fun () ->
write_complete fd data >>= fun () -> write_complete fd data >>= fun () ->
@ -67,7 +66,7 @@ let write_to_file mvar file =
loop fd >|= fun _ -> loop fd >|= fun _ ->
() ()
let send_history s ring id what = let send_history s version ring id what =
let tst event = let tst event =
let sub = Vmm_core.Log.name event in let sub = Vmm_core.Log.name event in
Vmm_core.Name.is_sub ~super:id ~sub Vmm_core.Name.is_sub ~super:id ~sub
@ -81,7 +80,7 @@ let send_history s ring id what =
Lwt_list.fold_left_s (fun r (ts, event) -> Lwt_list.fold_left_s (fun r (ts, event) ->
match r with match r with
| Ok () -> | Ok () ->
let header = Vmm_commands.{ version = my_version ; sequence = 0L ; name = id } in let header = Vmm_commands.header ~version id in
Vmm_lwt.write_wire s (header, `Data (`Log_data (ts, event))) Vmm_lwt.write_wire s (header, `Data (`Log_data (ts, event)))
| Error e -> Lwt.return (Error e)) | Error e -> Lwt.return (Error e))
(Ok ()) elements (Ok ()) elements
@ -89,17 +88,12 @@ let send_history s ring id what =
let tree = ref Vmm_trie.empty let tree = ref Vmm_trie.empty
let handle_data s mvar ring hdr entry = let handle_data s mvar ring hdr entry =
if not (Vmm_commands.version_eq hdr.Vmm_commands.version my_version) then begin
Logs.warn (fun m -> m "unsupported version") ;
Lwt.return_unit
end else begin
Vmm_lwt.write_wire s (hdr, `Success `Empty) >>= fun _ -> Vmm_lwt.write_wire s (hdr, `Success `Empty) >>= fun _ ->
Vmm_ring.write ring entry ; Vmm_ring.write ring entry ;
Lwt_mvar.put mvar (`Entry entry) >>= fun () -> Lwt_mvar.put mvar (`Entry entry) >>= fun () ->
let data' = (hdr, `Data (`Log_data entry)) in let data' = `Data (`Log_data entry) in
broadcast hdr.Vmm_commands.name data' !tree >|= fun tree' -> broadcast hdr.Vmm_commands.name data' !tree >|= fun tree' ->
tree := tree' tree := tree'
end
let read_data mvar ring s = let read_data mvar ring s =
let rec loop () = let rec loop () =
@ -122,27 +116,24 @@ let handle mvar ring s addr =
| Error _ -> | Error _ ->
Logs.err (fun m -> m "error while reading") ; Logs.err (fun m -> m "error while reading") ;
Lwt.return_unit Lwt.return_unit
| Ok (hdr, `Data (`Log_data entry)) -> | Ok (hdr, `Data `Log_data entry) ->
handle_data s mvar ring hdr entry >>= fun () -> handle_data s mvar ring hdr entry >>= fun () ->
read_data mvar ring s read_data mvar ring s
| Ok (hdr, `Command (`Log_cmd lc)) -> | Ok (hdr, `Command `Log_cmd `Log_subscribe ts) ->
if not (Vmm_commands.version_eq hdr.Vmm_commands.version my_version) then begin let tree', ret =
Logs.warn (fun m -> m "unsupported version") ; Vmm_trie.insert hdr.Vmm_commands.name (hdr.Vmm_commands.version, s) !tree
Lwt.return_unit in
end else begin
match lc with
| `Log_subscribe ts ->
let tree', ret = Vmm_trie.insert hdr.Vmm_commands.name s !tree in
tree := tree' ; tree := tree' ;
(match ret with (match ret with
| None -> Lwt.return_unit | None -> Lwt.return_unit
| Some s' -> Vmm_lwt.safe_close s') >>= fun () -> | Some (_, s') -> Vmm_lwt.safe_close s') >>= fun () ->
let out = `Success `Empty in let out = `Success `Empty in
begin
Vmm_lwt.write_wire s (hdr, out) >>= function Vmm_lwt.write_wire s (hdr, out) >>= function
| Error _ -> Logs.err (fun m -> m "error while sending reply for subscribe") ; | Error _ -> Logs.err (fun m -> m "error while sending reply for subscribe") ;
Lwt.return_unit Lwt.return_unit
| Ok () -> | Ok () ->
send_history s ring hdr.Vmm_commands.name ts >>= function send_history s hdr.Vmm_commands.version ring hdr.Vmm_commands.name ts >>= function
| Error _ -> Logs.err (fun m -> m "error while sending history") ; Lwt.return_unit | Error _ -> Logs.err (fun m -> m "error while sending history") ; Lwt.return_unit
| Ok () -> | Ok () ->
(* command processing is finished, but we leave the socket open (* command processing is finished, but we leave the socket open
@ -201,6 +192,6 @@ let read_only =
let cmd = let cmd =
Term.(const jump $ setup_log $ file $ read_only $ influx), Term.(const jump $ setup_log $ file $ read_only $ influx),
Term.info "albatross_log" ~version:"%%VERSION_NUM%%" Term.info "albatross_log" ~version
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1

View file

@ -6,11 +6,8 @@ open Vmm_core
open Lwt.Infix open Lwt.Infix
let version = `AV4 let state = ref (Vmm_vmmd.init ())
let state = ref (Vmm_vmmd.init version)
let stub_hdr = Vmm_commands.{ version ; sequence = 0L ; name = Name.root }
let stub_data_out _ = Lwt.return_unit let stub_data_out _ = Lwt.return_unit
let create_lock = Lwt_mutex.create () let create_lock = Lwt_mutex.create ()
@ -18,11 +15,11 @@ let create_lock = Lwt_mutex.create ()
Vmm_vmmd.handle is getting called, and while communicating via log / Vmm_vmmd.handle is getting called, and while communicating via log /
console / stat socket communication. *) console / stat socket communication. *)
let rec create stat_out log_out cons_out data_out hdr name config = let rec create stat_out log_out cons_out data_out name config =
(match Vmm_vmmd.handle_create !state hdr name config with (match Vmm_vmmd.handle_create !state name config with
| Error `Msg msg -> | Error `Msg msg ->
Logs.err (fun m -> m "failed to create %a: %s" Name.pp name msg) ; Logs.err (fun m -> m "failed to create %a: %s" Name.pp name msg) ;
Lwt.return (None, (hdr, `Failure msg)) Lwt.return (None, `Failure msg)
| Ok (state', (cons, succ_cont, fail_cont)) -> | Ok (state', (cons, succ_cont, fail_cont)) ->
state := state'; state := state';
cons_out "create" cons >>= function cons_out "create" cons >>= function
@ -43,7 +40,7 @@ let rec create stat_out log_out cons_out data_out hdr name config =
if should_restart config name r then if should_restart config name r then
Lwt_mutex.with_lock create_lock (fun () -> Lwt_mutex.with_lock create_lock (fun () ->
create stat_out log_out cons_out stub_data_out create stat_out log_out cons_out stub_data_out
stub_hdr name vm.Unikernel.config) name vm.Unikernel.config)
else else
Lwt.return_unit)); Lwt.return_unit));
stat_out "setting up stat" stat >>= fun () -> stat_out "setting up stat" stat >>= fun () ->
@ -68,29 +65,28 @@ let rec create stat_out log_out cons_out data_out hdr name config =
let handle log_out cons_out stat_out fd addr = let handle log_out cons_out stat_out fd addr =
Logs.debug (fun m -> m "connection from %a" Vmm_lwt.pp_sockaddr addr) ; Logs.debug (fun m -> m "connection from %a" Vmm_lwt.pp_sockaddr addr) ;
let out wire =
(* TODO should we terminate the connection on write failure? *)
Vmm_lwt.write_wire fd wire >|= fun _ -> ()
in
let rec loop () = let rec loop () =
Logs.debug (fun m -> m "now reading") ; Logs.debug (fun m -> m "now reading") ;
Vmm_lwt.read_wire fd >>= function Vmm_lwt.read_wire fd >>= function
| Error _ -> | Error _ ->
Logs.err (fun m -> m "error while reading") ; Logs.err (fun m -> m "error while reading") ;
Lwt.return_unit Lwt.return_unit
| Ok wire -> | Ok (hdr, wire) ->
Logs.debug (fun m -> m "read %a" Vmm_commands.pp_wire wire) ; let out wire' =
(* TODO should we terminate the connection on write failure? *)
Vmm_lwt.write_wire fd (hdr, wire') >|= fun _ -> ()
in
Logs.debug (fun m -> m "read %a" Vmm_commands.pp_wire (hdr, wire));
Lwt_mutex.lock create_lock >>= fun () -> Lwt_mutex.lock create_lock >>= fun () ->
match Vmm_vmmd.handle_command !state wire with match Vmm_vmmd.handle_command !state (hdr, wire) with
| Error wire -> Lwt_mutex.unlock create_lock; out wire | Error wire' -> Lwt_mutex.unlock create_lock; out wire'
| Ok (state', next) -> | Ok (state', next) ->
state := state' ; state := state' ;
match next with match next with
| `Loop wire -> Lwt_mutex.unlock create_lock; out wire >>= loop | `Loop wire -> Lwt_mutex.unlock create_lock; out wire >>= loop
| `End wire -> Lwt_mutex.unlock create_lock; out wire | `End wire -> Lwt_mutex.unlock create_lock; out wire
| `Create (hdr, id, vm) -> | `Create (id, vm) ->
create stat_out log_out cons_out out hdr id vm >|= fun () -> create stat_out log_out cons_out out id vm >|= fun () ->
Lwt_mutex.unlock create_lock Lwt_mutex.unlock create_lock
| `Wait (who, data) -> | `Wait (who, data) ->
let state', task = Vmm_vmmd.register !state who Lwt.task in let state', task = Vmm_vmmd.register !state who Lwt.task in
@ -98,38 +94,32 @@ let handle log_out cons_out stat_out fd addr =
Lwt_mutex.unlock create_lock; Lwt_mutex.unlock create_lock;
task >>= fun r -> task >>= fun r ->
out (data r) out (data r)
| `Wait_and_create (who, (hdr, id, vm)) -> | `Wait_and_create (who, (id, vm)) ->
let state', task = Vmm_vmmd.register !state who Lwt.task in let state', task = Vmm_vmmd.register !state who Lwt.task in
state := state'; state := state';
Lwt_mutex.unlock create_lock; Lwt_mutex.unlock create_lock;
task >>= fun r -> task >>= fun r ->
Logs.info (fun m -> m "wait returned %a" pp_process_exit r); Logs.info (fun m -> m "wait returned %a" pp_process_exit r);
Lwt_mutex.with_lock create_lock (fun () -> Lwt_mutex.with_lock create_lock (fun () ->
create stat_out log_out cons_out out hdr id vm) create stat_out log_out cons_out out id vm)
in in
loop () >>= fun () -> loop () >>= fun () ->
Vmm_lwt.safe_close fd Vmm_lwt.safe_close fd
let write_reply name fd txt (header, cmd) = let write_reply name fd txt (hdr, cmd) =
Vmm_lwt.write_wire fd (header, cmd) >>= function Vmm_lwt.write_wire fd (hdr, cmd) >>= function
| Error `Exception -> invalid_arg ("exception during " ^ txt ^ " while writing to " ^ name) | Error `Exception ->
invalid_arg ("exception during " ^ txt ^ " while writing to " ^ name)
| Ok () -> | Ok () ->
Vmm_lwt.read_wire fd >|= function Vmm_lwt.read_wire fd >|= function
| Ok (header', reply) -> | Ok (hdr', reply) ->
if not Vmm_commands.(version_eq header.version header'.version) then begin if not Vmm_commands.(Int64.equal hdr.sequence hdr'.sequence) then begin
Logs.err (fun m -> m "%s: wrong version (got %a, expected %a) in reply from %s"
txt
Vmm_commands.pp_version header'.Vmm_commands.version
Vmm_commands.pp_version header.Vmm_commands.version
name) ;
invalid_arg "bad version received"
end else if not Vmm_commands.(Int64.equal header.sequence header'.sequence) then begin
Logs.err (fun m -> m "%s: wrong id %Lu (expected %Lu) in reply from %s" Logs.err (fun m -> m "%s: wrong id %Lu (expected %Lu) in reply from %s"
txt header'.Vmm_commands.sequence header.Vmm_commands.sequence name) ; txt hdr'.Vmm_commands.sequence hdr.Vmm_commands.sequence name) ;
invalid_arg "wrong sequence number received" invalid_arg "wrong sequence number received"
end else begin end else begin
Logs.debug (fun m -> m "%s: received valid reply from %s %a (request %a)" Logs.debug (fun m -> m "%s: received valid reply from %s %a (request %a)"
txt name Vmm_commands.pp_wire (header', reply) Vmm_commands.pp_wire (header,cmd)) ; txt name Vmm_commands.pp_wire (hdr', reply) Vmm_commands.pp_wire (hdr, cmd)) ;
match reply with match reply with
| `Success _ -> Ok () | `Success _ -> Ok ()
| `Failure msg -> | `Failure msg ->
@ -187,7 +177,7 @@ let jump _ influx =
in in
Lwt_list.iter_s (fun (name, config) -> Lwt_list.iter_s (fun (name, config) ->
create stat_out log_out cons_out stub_data_out stub_hdr name config) create stat_out log_out cons_out stub_data_out name config)
(Vmm_trie.all old_unikernels) >>= fun () -> (Vmm_trie.all old_unikernels) >>= fun () ->
Lwt.catch (fun () -> Lwt.catch (fun () ->
@ -209,6 +199,6 @@ open Cmdliner
let cmd = let cmd =
Term.(const jump $ setup_log $ influx), Term.(const jump $ setup_log $ influx),
Term.info "albatrossd" ~version:"%%VERSION_NUM%%" Term.info "albatrossd" ~version:Albatross_cli.version
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1

View file

@ -1,7 +1,5 @@
(* (c) 2017 Hannes Mehnert, all rights reserved *) (* (c) 2017 Hannes Mehnert, all rights reserved *)
let asn_version = `AV4
let timestamps validity = let timestamps validity =
let now = Ptime_clock.now () in let now = Ptime_clock.now () in
match Ptime.add_span now (Ptime.Span.of_int_s (Duration.to_sec validity)) with match Ptime.add_span now (Ptime.Span.of_int_s (Duration.to_sec validity)) with

View file

@ -40,10 +40,10 @@ let sign_csr dbname cacert key csr days =
(* TODO: check delegation! verify whitelisted commands!? *) (* TODO: check delegation! verify whitelisted commands!? *)
match albatross_extension csr with match albatross_extension csr with
| Ok v -> | Ok v ->
Vmm_asn.cert_extension_of_cstruct v >>= fun (version, cmd) -> Vmm_asn.of_cert_extension v >>= fun (version, cmd) ->
if not (Vmm_commands.version_eq asn_version version) then if not Vmm_commands.(is_current version) then
Logs.warn (fun m -> m "version in request (%a) different from our version %a, using ours" Logs.warn (fun m -> m "version in request (%a) different from our version %a, using ours"
Vmm_commands.pp_version version Vmm_commands.pp_version asn_version); Vmm_commands.pp_version version Vmm_commands.pp_version Vmm_commands.current);
let exts, default_days = match cmd with let exts, default_days = match cmd with
| `Policy_cmd (`Policy_add _) -> d_exts (), 365 | `Policy_cmd (`Policy_add _) -> d_exts (), 365
| _ -> l_exts, 1 | _ -> l_exts, 1
@ -53,7 +53,7 @@ let sign_csr dbname cacert key csr days =
(* the "false" is here since X509 validation bails on exts marked as (* the "false" is here since X509 validation bails on exts marked as
critical (as required), but has no way to supply which extensions critical (as required), but has no way to supply which extensions
are actually handled by the application / caller *) are actually handled by the application / caller *)
let v' = Vmm_asn.cert_extension_to_cstruct (asn_version, cmd) in let v' = Vmm_asn.to_cert_extension cmd in
let extensions = Extension.(add (Unsupported Vmm_asn.oid) (false, v') exts) in let extensions = Extension.(add (Unsupported Vmm_asn.oid) (false, v') exts) in
sign ~dbname extensions issuer key csr (Duration.of_day days) sign ~dbname extensions issuer key csr (Duration.of_day days)
| Error e -> Error e | Error e -> Error e
@ -157,7 +157,7 @@ let default_cmd =
`P "$(tname) does CA operations (creation, sign, etc.)" ] `P "$(tname) does CA operations (creation, sign, etc.)" ]
in in
Term.(ret (const help $ setup_log $ Term.man_format $ Term.choice_names $ Term.pure None)), Term.(ret (const help $ setup_log $ Term.man_format $ Term.choice_names $ Term.pure None)),
Term.info "albatross_provision_ca" ~version:"%%VERSION_NUM%%" ~doc ~man Term.info "albatross_provision_ca" ~version ~doc ~man
let cmds = [ help_cmd ; sign_cmd ; generate_cmd ; (* TODO revoke_cmd *)] let cmds = [ help_cmd ; sign_cmd ; generate_cmd ; (* TODO revoke_cmd *)]

View file

@ -5,11 +5,9 @@ open Vmm_asn
open Rresult.R.Infix open Rresult.R.Infix
let version = `AV4
let csr priv name cmd = let csr priv name cmd =
let ext = let ext =
let v = cert_extension_to_cstruct (version, cmd) in let v = to_cert_extension cmd in
X509.Extension.(singleton (Unsupported oid) (false, v)) X509.Extension.(singleton (Unsupported oid) (false, v))
and name = and name =
[ X509.Distinguished_name.(Relative_distinguished_name.singleton (CN name)) ] [ X509.Distinguished_name.(Relative_distinguished_name.singleton (CN name)) ]
@ -199,7 +197,7 @@ let default_cmd =
`P "$(tname) creates a certificate signing request for Albatross" ] `P "$(tname) creates a certificate signing request for Albatross" ]
in in
Term.(ret (const help $ setup_log $ Term.man_format $ Term.choice_names $ Term.pure None)), Term.(ret (const help $ setup_log $ Term.man_format $ Term.choice_names $ Term.pure None)),
Term.info "albatross_provision_request" ~version:"%%VERSION_NUM%%" ~doc ~man Term.info "albatross_provision_request" ~version ~doc ~man
let cmds = [ help_cmd ; info_cmd ; let cmds = [ help_cmd ; info_cmd ;
policy_cmd ; remove_policy_cmd ; add_policy_cmd ; policy_cmd ; remove_policy_cmd ; add_policy_cmd ;

View file

@ -455,12 +455,10 @@ let version =
let f data = match data with let f data = match data with
| 4 -> `AV4 | 4 -> `AV4
| 3 -> `AV3 | 3 -> `AV3
| 2 -> `AV2
| x -> Asn.S.error (`Parse (Printf.sprintf "unknown version number 0x%X" x)) | x -> Asn.S.error (`Parse (Printf.sprintf "unknown version number 0x%X" x))
and g = function and g = function
| `AV4 -> 4 | `AV4 -> 4
| `AV3 -> 3 | `AV3 -> 3
| `AV2 -> 2
in in
Asn.S.map f g Asn.S.int Asn.S.map f g Asn.S.int
@ -602,8 +600,7 @@ let log_disk_of_cstruct, log_disk_to_cstruct =
let c = Asn.codec Asn.der log_disk in let c = Asn.codec Asn.der log_disk in
(Asn.decode c, Asn.encode c) (Asn.decode c, Asn.encode c)
let log_to_disk version entry = let log_to_disk entry = log_disk_to_cstruct (current, entry)
log_disk_to_cstruct (version, entry)
let logs_of_disk buf = let logs_of_disk buf =
let rec next acc buf = let rec next acc buf =
@ -655,12 +652,11 @@ let unikernels =
let unikernels_of_cstruct, unikernels_to_cstruct = projections_of unikernels let unikernels_of_cstruct, unikernels_to_cstruct = projections_of unikernels
type cert_extension = version * t
let cert_extension = let cert_extension =
Asn.S.(sequence2 Asn.S.(sequence2
(required ~label:"version" version) (required ~label:"version" version)
(required ~label:"command" wire_command)) (required ~label:"command" wire_command))
let cert_extension_of_cstruct, cert_extension_to_cstruct = let of_cert_extension, to_cert_extension =
projections_of cert_extension let a, b = projections_of cert_extension in
a, (fun d -> b (current, d))

View file

@ -17,14 +17,13 @@ val log_entry_to_cstruct : Log.t -> Cstruct.t
val log_entry_of_cstruct : Cstruct.t -> (Log.t, [> `Msg of string ]) result val log_entry_of_cstruct : Cstruct.t -> (Log.t, [> `Msg of string ]) result
val log_to_disk : Vmm_commands.version -> Log.t -> Cstruct.t val log_to_disk : Log.t -> Cstruct.t
val logs_of_disk : Cstruct.t -> Log.t list val logs_of_disk : Cstruct.t -> Log.t list
type cert_extension = Vmm_commands.version * Vmm_commands.t val of_cert_extension :
Cstruct.t -> (Vmm_commands.version * Vmm_commands.t, [> `Msg of string ]) result
val cert_extension_of_cstruct : Cstruct.t -> (cert_extension, [> `Msg of string ]) result val to_cert_extension : Vmm_commands.t -> Cstruct.t
val cert_extension_to_cstruct : cert_extension -> Cstruct.t
val unikernels_to_cstruct : Unikernel.config Vmm_trie.t -> Cstruct.t val unikernels_to_cstruct : Unikernel.config Vmm_trie.t -> Cstruct.t
val unikernels_of_cstruct : Cstruct.t -> (Unikernel.config Vmm_trie.t, [> `Msg of string ]) result val unikernels_of_cstruct : Cstruct.t -> (Unikernel.config Vmm_trie.t, [> `Msg of string ]) result

View file

@ -3,22 +3,24 @@
(* the wire protocol *) (* the wire protocol *)
open Vmm_core open Vmm_core
type version = [ `AV2 | `AV3 | `AV4 ] type version = [ `AV3 | `AV4 ]
let current = `AV4
let pp_version ppf v = let pp_version ppf v =
Fmt.int ppf Fmt.int ppf
(match v with (match v with
| `AV4 -> 4 | `AV4 -> 4
| `AV3 -> 3 | `AV3 -> 3)
| `AV2 -> 2)
let version_eq a b = let version_eq a b =
match a, b with match a, b with
| `AV4, `AV4 -> true | `AV4, `AV4 -> true
| `AV3, `AV3 -> true | `AV3, `AV3 -> true
| `AV2, `AV2 -> true
| _ -> false | _ -> false
let is_current = version_eq current
type since_count = [ `Since of Ptime.t | `Count of int ] type since_count = [ `Since of Ptime.t | `Count of int ]
let pp_since_count ppf = function let pp_since_count ppf = function
@ -124,6 +126,8 @@ type header = {
name : Name.t ; name : Name.t ;
} }
let header ?(version = current) ?(sequence = 0L) name = { version ; sequence ; name }
type success = [ type success = [
| `Empty | `Empty
| `String of string | `String of string
@ -142,11 +146,14 @@ let pp_success ppf = function
| `Unikernels vms -> Fmt.(list ~sep:(unit "@.") (pair ~sep:(unit ": ") Name.pp Unikernel.pp_config)) ppf vms | `Unikernels vms -> Fmt.(list ~sep:(unit "@.") (pair ~sep:(unit ": ") Name.pp Unikernel.pp_config)) ppf vms
| `Block_devices blocks -> Fmt.(list ~sep:(unit "@.") pp_block) ppf blocks | `Block_devices blocks -> Fmt.(list ~sep:(unit "@.") pp_block) ppf blocks
type wire = header * [ type res = [
| `Command of t | `Command of t
| `Success of success | `Success of success
| `Failure of string | `Failure of string
| `Data of data ] | `Data of data
]
type wire = header * res
let pp_wire ppf (header, data) = let pp_wire ppf (header, data) =
let name = header.name in let name = header.name in

View file

@ -3,10 +3,12 @@
open Vmm_core open Vmm_core
(** The type of versions of the grammar defined below. *) (** The type of versions of the grammar defined below. *)
type version = [ `AV2 | `AV3 | `AV4 ] type version = [ `AV3 | `AV4 ]
(** [version_eq a b] is true if [a] and [b] are equal. *) (** [current] is the current version. *)
val version_eq : version -> version -> bool val current : version
val is_current : version -> bool
(** [pp_version ppf version] pretty prints [version] onto [ppf]. *) (** [pp_version ppf version] pretty prints [version] onto [ppf]. *)
val pp_version : version Fmt.t val pp_version : version Fmt.t
@ -72,6 +74,8 @@ type header = {
name : Name.t ; name : Name.t ;
} }
val header : ?version:version -> ?sequence:int64 -> Name.t -> header
type success = [ type success = [
| `Empty | `Empty
| `String of string | `String of string
@ -80,11 +84,14 @@ type success = [
| `Block_devices of (Name.t * int * bool) list | `Block_devices of (Name.t * int * bool) list
] ]
type wire = header * [ type res = [
| `Command of t | `Command of t
| `Success of success | `Success of success
| `Failure of string | `Failure of string
| `Data of data ] | `Data of data
]
type wire = header * res
val pp_wire : wire Fmt.t val pp_wire : wire Fmt.t

View file

@ -104,10 +104,15 @@ let read_wire s =
Cstruct.hexdump_pp (Cstruct.of_bytes buf) Cstruct.hexdump_pp (Cstruct.of_bytes buf)
Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; *) Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; *)
match Vmm_asn.wire_of_cstruct (Cstruct.of_bytes b) with match Vmm_asn.wire_of_cstruct (Cstruct.of_bytes b) with
| Ok w -> Ok w
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "error %s while parsing data" msg) ; Logs.err (fun m -> m "error %s while parsing data" msg) ;
Error `Exception Error `Exception
| (Ok (hdr, _)) as w ->
if not Vmm_commands.(is_current hdr.version) then
Logs.warn (fun m -> m "version mismatch, received %a current %a"
Vmm_commands.pp_version hdr.Vmm_commands.version
Vmm_commands.pp_version Vmm_commands.current);
w
end else begin end else begin
Lwt.return (Error `Eof) Lwt.return (Error `Eof)
end end

View file

@ -8,7 +8,6 @@ open Rresult
open R.Infix open R.Infix
type 'a t = { type 'a t = {
wire_version : Vmm_commands.version ;
console_counter : int64 ; console_counter : int64 ;
stats_counter : int64 ; stats_counter : int64 ;
log_counter : int64 ; log_counter : int64 ;
@ -64,9 +63,8 @@ let register_restart t id create =
| Some _ -> Logs.err (fun m -> m "restart attempted to overwrite waiter"); None | Some _ -> Logs.err (fun m -> m "restart attempted to overwrite waiter"); None
| _ -> Some (register t id create) | _ -> Some (register t id create)
let init wire_version = let init () =
let t = { let t = {
wire_version ;
console_counter = 1L ; console_counter = 1L ;
stats_counter = 1L ; stats_counter = 1L ;
log_counter = 1L ; log_counter = 1L ;
@ -91,12 +89,12 @@ let init wire_version =
type 'a create = type 'a create =
Vmm_commands.wire * Vmm_commands.wire *
('a t -> ('a t * Vmm_commands.wire * Vmm_commands.wire * Vmm_commands.wire * Name.t * Unikernel.t, [ `Msg of string ]) result) * ('a t -> ('a t * Vmm_commands.wire * Vmm_commands.wire * Vmm_commands.res * Name.t * Unikernel.t, [ `Msg of string ]) result) *
(unit -> Vmm_commands.wire) (unit -> Vmm_commands.res)
let log t name event = let log t name event =
let data = (Ptime_clock.now (), event) in let data = (Ptime_clock.now (), event) in
let header = Vmm_commands.{ version = t.wire_version ; sequence = t.log_counter ; name } in let header = Vmm_commands.header ~sequence:t.log_counter name in
let log_counter = Int64.succ t.log_counter in let log_counter = Int64.succ t.log_counter in
Logs.debug (fun m -> m "log %a" Log.pp data) ; Logs.debug (fun m -> m "log %a" Log.pp data) ;
({ t with log_counter }, (header, `Data (`Log_data data))) ({ t with log_counter }, (header, `Data (`Log_data data)))
@ -123,16 +121,16 @@ let setup_stats t name vm =
in in
`Stats_add (name, vm.Unikernel.pid, ifs) `Stats_add (name, vm.Unikernel.pid, ifs)
in in
let header = Vmm_commands.{ version = t.wire_version ; sequence = t.stats_counter ; name } in let header = Vmm_commands.header ~sequence:t.stats_counter name in
let t = { t with stats_counter = Int64.succ t.stats_counter } in let t = { t with stats_counter = Int64.succ t.stats_counter } in
t, (header, `Command (`Stats_cmd stat_out)) t, (header, `Command (`Stats_cmd stat_out))
let remove_stats t name = let remove_stats t name =
let header = Vmm_commands.{ version = t.wire_version ; sequence = t.stats_counter ; name } in let header = Vmm_commands.header ~sequence:t.stats_counter name in
let t = { t with stats_counter = Int64.succ t.stats_counter } in let t = { t with stats_counter = Int64.succ t.stats_counter } in
(t, (header, `Command (`Stats_cmd `Stats_remove))) (t, (header, `Command (`Stats_cmd `Stats_remove)))
let handle_create t hdr name vm_config = let handle_create t name vm_config =
(match Vmm_resources.find_vm t.resources name with (match Vmm_resources.find_vm t.resources name with
| Some _ -> Error (`Msg "VM with same name is already running") | Some _ -> Error (`Msg "VM with same name is already running")
| None -> Ok ()) >>= fun () -> | None -> Ok ()) >>= fun () ->
@ -142,7 +140,7 @@ let handle_create t hdr name vm_config =
Vmm_unix.prepare name vm_config >>= fun taps -> Vmm_unix.prepare name vm_config >>= fun taps ->
Logs.debug (fun m -> m "prepared vm with taps %a" Fmt.(list ~sep:(unit ",@ ") string) taps) ; Logs.debug (fun m -> m "prepared vm with taps %a" Fmt.(list ~sep:(unit ",@ ") string) taps) ;
let cons_out = let cons_out =
let header = Vmm_commands.{ version = t.wire_version ; sequence = t.console_counter ; name } in let header = Vmm_commands.header ~sequence:t.console_counter name in
(header, `Command (`Console_cmd `Console_add)) (header, `Command (`Console_cmd `Console_add))
in in
let success t = let success t =
@ -170,13 +168,13 @@ let handle_create t hdr name vm_config =
log t name start log t name start
in in
let t, stat_out = setup_stats t name vm in let t, stat_out = setup_stats t name vm in
(t, stat_out, log_out, (hdr, `Success (`String "created VM")), name, vm) (t, stat_out, log_out, `Success (`String "created VM"), name, vm)
and fail () = and fail () =
match Vmm_unix.free_system_resources name taps with match Vmm_unix.free_system_resources name taps with
| Ok () -> (hdr, `Failure "could not create VM: console failed") | Ok () -> `Failure "could not create VM: console failed"
| Error (`Msg msg) -> | Error (`Msg msg) ->
let m = "could not create VM: console failed, and also " ^ msg ^ " while cleaning resources" in let m = "could not create VM: console failed, and also " ^ msg ^ " while cleaning resources" in
(hdr, `Failure m) `Failure m
in in
Ok ({ t with console_counter = Int64.succ t.console_counter }, Ok ({ t with console_counter = Int64.succ t.console_counter },
(cons_out, success, fail)) (cons_out, success, fail))
@ -190,11 +188,11 @@ let handle_shutdown t name vm r =
let t, stat_out = remove_stats t name in let t, stat_out = remove_stats t name in
(t, stat_out, log_out) (t, stat_out, log_out)
let handle_policy_cmd t reply id = function let handle_policy_cmd t id = function
| `Policy_remove -> | `Policy_remove ->
Logs.debug (fun m -> m "remove policy %a" Name.pp id) ; Logs.debug (fun m -> m "remove policy %a" Name.pp id) ;
Vmm_resources.remove_policy t.resources id >>= fun resources -> Vmm_resources.remove_policy t.resources id >>= fun resources ->
Ok ({ t with resources }, `End (reply (`String "removed policy"))) Ok ({ t with resources }, `End (`Success (`String "removed policy")))
| `Policy_add policy -> | `Policy_add policy ->
Logs.debug (fun m -> m "insert policy %a" Name.pp id) ; Logs.debug (fun m -> m "insert policy %a" Name.pp id) ;
let same_policy = match Vmm_resources.find_policy t.resources id with let same_policy = match Vmm_resources.find_policy t.resources id with
@ -202,10 +200,10 @@ let handle_policy_cmd t reply id = function
| Some p' -> Policy.equal policy p' | Some p' -> Policy.equal policy p'
in in
if same_policy then if same_policy then
Ok (t, `Loop (reply (`String "no modification of policy"))) Ok (t, `Loop (`Success (`String "no modification of policy")))
else else
Vmm_resources.insert_policy t.resources id policy >>= fun resources -> Vmm_resources.insert_policy t.resources id policy >>= fun resources ->
Ok ({ t with resources }, `Loop (reply (`String "added policy"))) Ok ({ t with resources }, `Loop (`Success (`String "added policy")))
| `Policy_info -> | `Policy_info ->
Logs.debug (fun m -> m "policy %a" Name.pp id) ; Logs.debug (fun m -> m "policy %a" Name.pp id) ;
let policies = let policies =
@ -218,9 +216,9 @@ let handle_policy_cmd t reply id = function
Logs.debug (fun m -> m "policies: couldn't find %a" Name.pp id) ; Logs.debug (fun m -> m "policies: couldn't find %a" Name.pp id) ;
Error (`Msg "policy: not found") Error (`Msg "policy: not found")
| _ -> | _ ->
Ok (t, `End (reply (`Policies policies))) Ok (t, `End (`Success (`Policies policies)))
let handle_unikernel_cmd t reply header id = function let handle_unikernel_cmd t id = function
| `Unikernel_info -> | `Unikernel_info ->
Logs.debug (fun m -> m "info %a" Name.pp id) ; Logs.debug (fun m -> m "info %a" Name.pp id) ;
let vms = let vms =
@ -233,9 +231,9 @@ let handle_unikernel_cmd t reply header id = function
Logs.debug (fun m -> m "info: couldn't find %a" Name.pp id) ; Logs.debug (fun m -> m "info: couldn't find %a" Name.pp id) ;
Error (`Msg "info: no unikernel found") Error (`Msg "info: no unikernel found")
| _ -> | _ ->
Ok (t, `End (reply (`Unikernels vms))) Ok (t, `End (`Success (`Unikernels vms)))
end end
| `Unikernel_create vm_config -> Ok (t, `Create (header, id, vm_config)) | `Unikernel_create vm_config -> Ok (t, `Create (id, vm_config))
| `Unikernel_force_create vm_config -> | `Unikernel_force_create vm_config ->
begin begin
let resources = let resources =
@ -244,12 +242,12 @@ let handle_unikernel_cmd t reply header id = function
in in
Vmm_resources.check_vm resources id vm_config >>= fun () -> Vmm_resources.check_vm resources id vm_config >>= fun () ->
match Vmm_resources.find_vm t.resources id with match Vmm_resources.find_vm t.resources id with
| None -> Ok (t, `Create (header, id, vm_config)) | None -> Ok (t, `Create (id, vm_config))
| Some vm -> | Some vm ->
(match Vmm_unix.destroy vm with (match Vmm_unix.destroy vm with
| exception Unix.Unix_error _ -> () | exception Unix.Unix_error _ -> ()
| () -> ()); | () -> ());
Ok (t, `Wait_and_create (id, (header, id, vm_config))) Ok (t, `Wait_and_create (id, (id, vm_config)))
end end
| `Unikernel_destroy -> | `Unikernel_destroy ->
match Vmm_resources.find_vm t.resources id with match Vmm_resources.find_vm t.resources id with
@ -263,11 +261,11 @@ let handle_unikernel_cmd t reply header id = function
in in
let s ex = let s ex =
let data = Fmt.strf "%a %s %a" Name.pp id answer pp_process_exit ex in let data = Fmt.strf "%a %s %a" Name.pp id answer pp_process_exit ex in
reply (`String data) `Success (`String data)
in in
Ok (t, `Wait (id, s)) Ok (t, `Wait (id, s))
let handle_block_cmd t reply id = function let handle_block_cmd t id = function
| `Block_remove -> | `Block_remove ->
Logs.debug (fun m -> m "removing block %a" Name.pp id) ; Logs.debug (fun m -> m "removing block %a" Name.pp id) ;
begin match Vmm_resources.find_block t.resources id with begin match Vmm_resources.find_block t.resources id with
@ -276,7 +274,7 @@ let handle_block_cmd t reply id = function
| Some (_, false) -> | Some (_, false) ->
Vmm_unix.destroy_block id >>= fun () -> Vmm_unix.destroy_block id >>= fun () ->
Vmm_resources.remove_block t.resources id >>= fun resources -> Vmm_resources.remove_block t.resources id >>= fun resources ->
Ok ({ t with resources }, `End (reply (`String "removed block"))) Ok ({ t with resources }, `End (`Success (`String "removed block")))
end end
| `Block_add size -> | `Block_add size ->
begin begin
@ -287,7 +285,7 @@ let handle_block_cmd t reply id = function
Vmm_resources.check_block t.resources id size >>= fun () -> Vmm_resources.check_block t.resources id size >>= fun () ->
Vmm_unix.create_block id size >>= fun () -> Vmm_unix.create_block id size >>= fun () ->
Vmm_resources.insert_block t.resources id size >>= fun resources -> Vmm_resources.insert_block t.resources id size >>= fun resources ->
Ok ({ t with resources }, `Loop (reply (`String "added block device"))) Ok ({ t with resources }, `Loop (`Success (`String "added block device")))
end end
| `Block_info -> | `Block_info ->
Logs.debug (fun m -> m "block %a" Name.pp id) ; Logs.debug (fun m -> m "block %a" Name.pp id) ;
@ -301,22 +299,21 @@ let handle_block_cmd t reply id = function
Logs.debug (fun m -> m "block: couldn't find %a" Name.pp id) ; Logs.debug (fun m -> m "block: couldn't find %a" Name.pp id) ;
Error (`Msg "block: not found") Error (`Msg "block: not found")
| _ -> | _ ->
Ok (t, `End (reply (`Block_devices blocks))) Ok (t, `End (`Success (`Block_devices blocks)))
let handle_command t (header, payload) = let handle_command t (header, payload) =
let msg_to_err = function let msg_to_err = function
| Ok x -> Ok x | Ok x -> Ok x
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "error while processing command: %s" msg) ; Logs.err (fun m -> m "error while processing command: %s" msg) ;
Error (header, `Failure msg) Error (`Failure msg)
and reply x = (header, `Success x)
and id = header.Vmm_commands.name and id = header.Vmm_commands.name
in in
msg_to_err ( msg_to_err (
match payload with match payload with
| `Command (`Policy_cmd pc) -> handle_policy_cmd t reply id pc | `Command (`Policy_cmd pc) -> handle_policy_cmd t id pc
| `Command (`Unikernel_cmd vc) -> handle_unikernel_cmd t reply header id vc | `Command (`Unikernel_cmd vc) -> handle_unikernel_cmd t id vc
| `Command (`Block_cmd bc) -> handle_block_cmd t reply id bc | `Command (`Block_cmd bc) -> handle_block_cmd t id bc
| _ -> | _ ->
Logs.err (fun m -> m "ignoring %a" Vmm_commands.pp_wire (header, payload)) ; Logs.err (fun m -> m "ignoring %a" Vmm_commands.pp_wire (header, payload)) ;
Error (`Msg "unknown command")) Error (`Msg "unknown command"))

View file

@ -4,7 +4,7 @@ open Vmm_core
type 'a t type 'a t
val init : Vmm_commands.version -> 'a t val init : unit -> 'a t
val waiter : 'a t -> Name.t -> 'a t * 'a option val waiter : 'a t -> Name.t -> 'a t * 'a option
@ -14,25 +14,23 @@ val register_restart : 'a t -> Name.t -> (unit -> 'b * 'a) -> ('a t * 'b) option
type 'a create = type 'a create =
Vmm_commands.wire * Vmm_commands.wire *
('a t -> ('a t * Vmm_commands.wire * Vmm_commands.wire * Vmm_commands.wire * ('a t -> ('a t * Vmm_commands.wire * Vmm_commands.wire * Vmm_commands.res * Name.t * Unikernel.t, [ `Msg of string ]) result) *
Name.t * Unikernel.t, [ `Msg of string ]) result) * (unit -> Vmm_commands.res)
(unit -> Vmm_commands.wire)
val handle_shutdown : 'a t -> Name.t -> Unikernel.t -> val handle_shutdown : 'a t -> Name.t -> Unikernel.t ->
[ `Exit of int | `Signal of int | `Stop of int ] -> 'a t * Vmm_commands.wire * Vmm_commands.wire [ `Exit of int | `Signal of int | `Stop of int ] -> 'a t * Vmm_commands.wire * Vmm_commands.wire
val handle_create : 'a t -> Vmm_commands.header -> val handle_create : 'a t -> Name.t -> Unikernel.config ->
Name.t -> Unikernel.config ->
('a t * 'a create, [> `Msg of string ]) result ('a t * 'a create, [> `Msg of string ]) result
val handle_command : 'a t -> Vmm_commands.wire -> val handle_command : 'a t -> Vmm_commands.wire ->
('a t * ('a t *
[ `Create of Vmm_commands.header * Name.t * Unikernel.config [ `Create of Name.t * Unikernel.config
| `Loop of Vmm_commands.wire | `Loop of Vmm_commands.res
| `End of Vmm_commands.wire | `End of Vmm_commands.res
| `Wait of Name.t * (process_exit -> Vmm_commands.wire) | `Wait of Name.t * (process_exit -> Vmm_commands.res)
| `Wait_and_create of Name.t * (Vmm_commands.header * Name.t * Unikernel.config) ], | `Wait_and_create of Name.t * (Name.t * Unikernel.config) ],
Vmm_commands.wire) result Vmm_commands.res) result
val killall : 'a t -> bool val killall : 'a t -> bool

View file

@ -69,6 +69,6 @@ let vmname =
let cmd = let cmd =
Term.(term_result (const jump $ setup_log $ pid $ vmname $ interval)), Term.(term_result (const jump $ setup_log $ pid $ vmname $ interval)),
Term.info "albatross_stat_client" ~version:"%%VERSION_NUM%%" Term.info "albatross_stat_client" ~version
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1

View file

@ -40,7 +40,7 @@ let handle s addr =
Vmm_lwt.write_wire s (fst wire, `Success (`String out)) >>= function Vmm_lwt.write_wire s (fst wire, `Success (`String out)) >>= function
| Ok () -> | Ok () ->
(match close with (match close with
| Some s' -> | Some (_, s') ->
Vmm_lwt.safe_close s' >>= fun () -> Vmm_lwt.safe_close s' >>= fun () ->
(* read the next *) (* read the next *)
loop () loop ()
@ -90,6 +90,6 @@ let interval =
let cmd = let cmd =
Term.(term_result (const jump $ setup_log $ interval $ influx)), Term.(term_result (const jump $ setup_log $ interval $ influx)),
Term.info "albatross_stats" ~version:"%%VERSION_NUM%%" Term.info "albatross_stats" ~version
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1

View file

@ -17,8 +17,6 @@ external vmmapi_close : vmctx -> unit = "vmmanage_vmmapi_close"
external vmmapi_statnames : vmctx -> string list = "vmmanage_vmmapi_statnames" external vmmapi_statnames : vmctx -> string list = "vmmanage_vmmapi_statnames"
external vmmapi_stats : vmctx -> int64 list = "vmmanage_vmmapi_stats" external vmmapi_stats : vmctx -> int64 list = "vmmanage_vmmapi_stats"
let my_version = `AV4
let descr = ref [] let descr = ref []
type 'a t = { type 'a t = {
@ -139,11 +137,11 @@ let tick t =
ru', mem, vmm', ifd ru', mem, vmm', ifd
in in
let outs = let outs =
List.fold_left (fun out (id, socket) -> List.fold_left (fun out (id, (version, socket)) ->
match Vmm_core.Name.drop_super ~super:id ~sub:vmid with match Vmm_core.Name.drop_super ~super:id ~sub:vmid with
| None -> Logs.err (fun m -> m "couldn't drop super %a from sub %a" Vmm_core.Name.pp id Vmm_core.Name.pp vmid) ; out | None -> Logs.err (fun m -> m "couldn't drop super %a from sub %a" Vmm_core.Name.pp id Vmm_core.Name.pp vmid) ; out
| Some real_id -> | Some real_id ->
let header = Vmm_commands.{ version = my_version ; sequence = 0L ; name = real_id } in let header = Vmm_commands.header ~version real_id in
((socket, id, (header, `Data (`Stats_data stats))) :: out)) ((socket, id, (header, `Data (`Stats_data stats))) :: out))
out xs out xs
in in
@ -178,17 +176,11 @@ let add_pid t vmid vmmdev pid nics =
assert (ret = None) ; assert (ret = None) ;
Ok { t with pid_nic ; vmid_pid } Ok { t with pid_nic ; vmid_pid }
let handle t socket (header, wire) = let handle t socket (hdr, wire) =
if not (Vmm_commands.version_eq my_version header.Vmm_commands.version) then begin
Logs.err (fun m -> m "invalid version %a (mine is %a)"
Vmm_commands.pp_version header.Vmm_commands.version
Vmm_commands.pp_version my_version) ;
Error (`Msg "cannot handle version")
end else
match wire with match wire with
| `Command (`Stats_cmd cmd) -> | `Command (`Stats_cmd cmd) ->
begin begin
let id = header.Vmm_commands.name in let id = hdr.Vmm_commands.name in
match cmd with match cmd with
| `Stats_add (vmmdev, pid, taps) -> | `Stats_add (vmmdev, pid, taps) ->
add_pid t id vmmdev pid taps >>= fun t -> add_pid t id vmmdev pid taps >>= fun t ->
@ -197,9 +189,11 @@ let handle t socket (header, wire) =
let t = remove_vmid t id in let t = remove_vmid t id in
Ok (t, None, "removed") Ok (t, None, "removed")
| `Stats_subscribe -> | `Stats_subscribe ->
let name_sockets, close = Vmm_trie.insert id socket t.name_sockets in let name_sockets, close =
Vmm_trie.insert id (hdr.Vmm_commands.version, socket) t.name_sockets
in
Ok ({ t with name_sockets }, close, "subscribed") Ok ({ t with name_sockets }, close, "subscribed")
end end
| _ -> | _ ->
Logs.err (fun m -> m "unexpected wire %a" Vmm_commands.pp_wire (header, wire)) ; Logs.err (fun m -> m "unexpected wire %a" Vmm_commands.pp_wire (hdr, wire)) ;
Error (`Msg "unexpected command") Error (`Msg "unexpected command")

View file

@ -2,8 +2,6 @@
open Lwt.Infix open Lwt.Infix
let my_version = `AV4
let command = ref 0L let command = ref 0L
let tls_config cacert cert priv_key = let tls_config cacert cert priv_key =
@ -32,42 +30,41 @@ let client_auth ca tls =
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain | `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
| `Error -> Lwt.fail_with "error while getting epoch") | `Error -> Lwt.fail_with "error while getting epoch")
let read fd tls = let read version fd tls =
(* now we busy read and process output *) (* now we busy read and process output *)
let rec loop () = let rec loop () =
Vmm_lwt.read_wire fd >>= function Vmm_lwt.read_wire fd >>= function
| Error _ -> Lwt.return (Error (`Msg "exception while reading")) | Error _ -> Lwt.return (`Failure "exception while reading from fd")
| Ok wire -> | Ok (hdr, pay) ->
Logs.debug (fun m -> m "read proxying %a" Vmm_commands.pp_wire wire) ; Logs.debug (fun m -> m "read proxying %a" Vmm_commands.pp_wire (hdr, pay)) ;
let wire = { hdr with version }, pay in
Vmm_tls_lwt.write_tls tls wire >>= function Vmm_tls_lwt.write_tls tls wire >>= function
| Ok () -> loop () | Ok () -> loop ()
| Error `Exception -> Lwt.return (Error (`Msg "exception")) | Error `Exception -> Lwt.return (`Failure "exception")
in in
loop () loop ()
let process fd tls = let process fd =
Vmm_lwt.read_wire fd >>= function Vmm_lwt.read_wire fd >|= function
| Error _ -> Lwt.return (Error (`Msg "read error")) | Error _ -> `Failure "error reading from fd"
| Ok wire -> | Ok (hdr, pay) ->
(* TODO check version *) Logs.debug (fun m -> m "proxying %a" Vmm_commands.pp_wire (hdr, pay));
Logs.debug (fun m -> m "proxying %a" Vmm_commands.pp_wire wire) ; pay
Vmm_tls_lwt.write_tls tls wire >|= function
| Ok () -> Ok ()
| Error `Exception -> Error (`Msg "exception on write")
let handle ca tls = let handle ca tls =
client_auth ca tls >>= fun chain -> client_auth ca tls >>= fun chain ->
match Vmm_tls.handle my_version chain with match Vmm_tls.handle chain with
| Error (`Msg m) -> Lwt.fail_with m | Error `Msg msg ->
| Ok (name, policies, cmd) -> Logs.err (fun m -> m "failed to handle TLS connection %s" msg);
Lwt.return_unit
| Ok (name, policies, version, cmd) ->
begin
let sock, next = Vmm_commands.endpoint cmd in let sock, next = Vmm_commands.endpoint cmd in
let sockaddr = Lwt_unix.ADDR_UNIX (Vmm_core.socket_path sock) in let sockaddr = Lwt_unix.ADDR_UNIX (Vmm_core.socket_path sock) in
Vmm_lwt.connect Lwt_unix.PF_UNIX sockaddr >>= function Vmm_lwt.connect Lwt_unix.PF_UNIX sockaddr >>= function
| None -> | None ->
let err = Logs.warn (fun m -> m "failed to connect to %a" Vmm_lwt.pp_sockaddr sockaddr);
Rresult.R.error_msgf "failed to connect to %a" Vmm_lwt.pp_sockaddr sockaddr Lwt.return (`Failure "couldn't reach service")
in
Lwt.return err
| Some fd -> | Some fd ->
(match sock with (match sock with
| `Vmmd -> | `Vmmd ->
@ -76,13 +73,12 @@ let handle ca tls =
| Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) | Error (`Msg msg) -> Lwt.return (Error (`Msg msg))
| Ok () -> | Ok () ->
Logs.debug (fun m -> m "adding policy for %a: %a" Vmm_core.Name.pp id Vmm_core.Policy.pp policy) ; Logs.debug (fun m -> m "adding policy for %a: %a" Vmm_core.Name.pp id Vmm_core.Policy.pp policy) ;
let header = Vmm_commands.{version = my_version ; sequence = !command ; name = id } in let header = Vmm_commands.header ~sequence:!command id in
command := Int64.succ !command ; command := Int64.succ !command ;
Vmm_lwt.write_wire fd (header, `Command (`Policy_cmd (`Policy_add policy))) >>= function Vmm_lwt.write_wire fd (header, `Command (`Policy_cmd (`Policy_add policy))) >>= function
| Error `Exception -> Lwt.return (Error (`Msg "failed to write policy")) | Error `Exception -> Lwt.return (Error (`Msg "failed to write policy"))
| Ok () -> | Ok () ->
Vmm_lwt.read_wire fd >|= function Vmm_lwt.read_wire fd >|= function
(* TODO check version *)
| Error _ -> Error (`Msg "read error after writing policy") | Error _ -> Error (`Msg "read error after writing policy")
| Ok (_, `Success _) -> Ok () | Ok (_, `Success _) -> Ok ()
| Ok wire -> | Ok wire ->
@ -92,32 +88,29 @@ let handle ca tls =
(Ok ()) policies (Ok ()) policies
| _ -> Lwt.return (Ok ())) >>= function | _ -> Lwt.return (Ok ())) >>= function
| Error (`Msg msg) -> | Error (`Msg msg) ->
begin Vmm_lwt.safe_close fd >|= fun () ->
Logs.warn (fun m -> m "error while applying policies %s" msg) ; Logs.warn (fun m -> m "error while applying policies %s" msg) ;
let wire = `Failure msg
let header = Vmm_commands.{version = my_version ; sequence = 0L ; name } in
header, `Failure msg
in
Vmm_tls_lwt.write_tls tls wire >>= fun _ ->
Vmm_lwt.safe_close fd >>= fun () ->
Lwt.fail_with msg
end
| Ok () -> | Ok () ->
let wire = let wire =
let header = Vmm_commands.{version = my_version ; sequence = !command ; name } in let header = Vmm_commands.header ~sequence:!command name in
command := Int64.succ !command ; command := Int64.succ !command ;
(header, `Command cmd) (header, `Command cmd)
in in
Vmm_lwt.write_wire fd wire >>= function Vmm_lwt.write_wire fd wire >>= function
| Error `Exception -> | Error `Exception ->
Vmm_lwt.safe_close fd >>= fun () -> Vmm_lwt.safe_close fd >|= fun () ->
Lwt.return (Error (`Msg "couldn't write")) `Failure "couldn't write unikernel to VMMD"
| Ok () -> | Ok () ->
(match next with (match next with
| `Read -> read fd tls | `Read -> read version fd tls
| `End -> process fd tls) >>= fun res -> | `End -> process fd) >>= fun res ->
Vmm_lwt.safe_close fd >|= fun () -> Vmm_lwt.safe_close fd >|= fun () ->
res res
end >>= fun reply ->
Vmm_tls_lwt.write_tls tls
(Vmm_commands.header ~version name, reply) >|= fun _ ->
()
open Cmdliner open Cmdliner

View file

@ -30,9 +30,7 @@ let jump _ cacert cert priv_key port =
Lwt.async (fun () -> Lwt.async (fun () ->
Lwt.catch Lwt.catch
(fun () -> (fun () ->
(handle ca t >|= function handle ca t >>= fun () ->
| Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
| Ok () -> ()) >>= fun () ->
Vmm_tls_lwt.close t) Vmm_tls_lwt.close t)
(fun e -> (fun e ->
Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ; Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;
@ -60,6 +58,6 @@ let port =
let cmd = let cmd =
Term.(const jump $ setup_log $ cacert $ cert $ key $ port), Term.(const jump $ setup_log $ cacert $ cert $ key $ port),
Term.info "albatross_tls_endpoint" ~version:"%%VERSION_NUM%%" Term.info "albatross_tls_endpoint" ~version
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1

View file

@ -16,9 +16,7 @@ let jump cacert cert priv_key =
Lwt.fail exn) >>= fun t -> Lwt.fail exn) >>= fun t ->
Lwt.catch Lwt.catch
(fun () -> (fun () ->
(handle ca t >|= function handle ca t >>= fun () ->
| Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
| Ok () -> ()) >>= fun () ->
Vmm_tls_lwt.close t) Vmm_tls_lwt.close t)
(fun e -> (fun e ->
Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ; Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;
@ -28,6 +26,6 @@ open Cmdliner
let cmd = let cmd =
Term.(const jump $ cacert $ cert $ key), Term.(const jump $ cacert $ cert $ key),
Term.info "albatross_tls_inetd" ~version:"%%VERSION_NUM%%" Term.info "albatross_tls_inetd" ~version:Albatross_cli.version
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1 let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1

View file

@ -11,7 +11,7 @@ let cert_name cert =
| Some (_, data) -> | Some (_, data) ->
match X509.(Distinguished_name.common_name (Certificate.subject cert)) with match X509.(Distinguished_name.common_name (Certificate.subject cert)) with
| Some name -> Ok (Some name) | Some name -> Ok (Some name)
| None -> match Vmm_asn.cert_extension_of_cstruct data with | None -> match Vmm_asn.of_cert_extension data with
| Error (`Msg _) -> Error (`Msg "couldn't parse albatross extension") | Error (`Msg _) -> Error (`Msg "couldn't parse albatross extension")
| Ok (_, `Policy_cmd pc) -> | Ok (_, `Policy_cmd pc) ->
begin match pc with begin match pc with
@ -44,36 +44,33 @@ let separate_chain = function
| [ leaf ] -> Ok (leaf, []) | [ leaf ] -> Ok (leaf, [])
| leaf :: xs -> Ok (leaf, List.rev xs) | leaf :: xs -> Ok (leaf, List.rev xs)
let wire_command_of_cert version cert = let wire_command_of_cert cert =
match Extension.(find (Unsupported Vmm_asn.oid) (Certificate.extensions cert)) with match Extension.(find (Unsupported Vmm_asn.oid) (Certificate.extensions cert)) with
| None -> Error `Not_present | None -> Error `Not_present
| Some (_, data) -> | Some (_, data) ->
Vmm_asn.cert_extension_of_cstruct data >>= fun (v, wire) -> Vmm_asn.of_cert_extension data >>= fun (v, wire) ->
if not (Vmm_commands.version_eq v version) then if not Vmm_commands.(is_current v) then
Error (`Version v) Logs.warn (fun m -> m "version mismatch, received %a current %a"
else Vmm_commands.pp_version v
Ok wire Vmm_commands.pp_version Vmm_commands.current);
Ok (v, wire)
let extract_policies version chain = let extract_policies chain =
List.fold_left (fun acc cert -> List.fold_left (fun acc cert ->
match acc, wire_command_of_cert version cert with match acc, wire_command_of_cert cert with
| Error e, _ -> Error e | Error e, _ -> Error e
| Ok acc, Error `Not_present -> Ok acc | Ok acc, Error `Not_present -> Ok acc
| Ok _, Error (`Msg msg) -> Error (`Msg msg) | Ok _, Error (`Msg msg) -> Error (`Msg msg)
| Ok _, Error (`Version received) -> | Ok (prefix, acc), Ok (_, `Policy_cmd `Policy_add p) ->
R.error_msgf "unexpected version %a (expected %a)"
Vmm_commands.pp_version received
Vmm_commands.pp_version version
| Ok (prefix, acc), Ok (`Policy_cmd (`Policy_add p)) ->
(cert_name cert >>= function (cert_name cert >>= function
| None -> Ok prefix | None -> Ok prefix
| Some x -> Vmm_core.Name.prepend x prefix) >>| fun name -> | Some x -> Vmm_core.Name.prepend x prefix) >>| fun name ->
(name, (name, p) :: acc) (name, (name, p) :: acc)
| _, Ok wire -> | _, Ok wire ->
R.error_msgf "unexpected wire %a" Vmm_commands.pp wire) R.error_msgf "unexpected wire %a" Vmm_commands.pp (snd wire))
(Ok (Vmm_core.Name.root, [])) chain (Ok (Vmm_core.Name.root, [])) chain
let handle version chain = let handle chain =
(if List.length chain < 10 then (if List.length chain < 10 then
Ok () Ok ()
else else
@ -90,22 +87,18 @@ let handle version chain =
Logs.debug (fun m -> m "name is %a leaf is %a, chain %a" Logs.debug (fun m -> m "name is %a leaf is %a, chain %a"
Vmm_core.Name.pp name Certificate.pp leaf Vmm_core.Name.pp name Certificate.pp leaf
Fmt.(list ~sep:(unit " -> ") Certificate.pp) rest); Fmt.(list ~sep:(unit " -> ") Certificate.pp) rest);
extract_policies version rest >>= fun (_, policies) -> extract_policies rest >>= fun (_, policies) ->
(* TODO: logging let login_hdr, login_ev = Log.hdr name, `Login addr in *) (* TODO: logging let login_hdr, login_ev = Log.hdr name, `Login addr in *)
match wire_command_of_cert version leaf with match wire_command_of_cert leaf with
| Error (`Msg p) -> Error (`Msg p) | Error `Msg p -> Error (`Msg p)
| Error (`Not_present) -> | Error `Not_present ->
Error (`Msg "leaf certificate does not contain an albatross extension") Error (`Msg "leaf certificate does not contain an albatross extension")
| Error (`Version received) -> | Ok (v, wire) ->
R.error_msgf "unexpected version %a (expected %a)"
Vmm_commands.pp_version received
Vmm_commands.pp_version version
| Ok wire ->
(* we only allow some commands via certificate *) (* we only allow some commands via certificate *)
match wire with match wire with
| `Console_cmd (`Console_subscribe _) | `Console_cmd (`Console_subscribe _)
| `Stats_cmd `Stats_subscribe | `Stats_cmd `Stats_subscribe
| `Log_cmd (`Log_subscribe _) | `Log_cmd (`Log_subscribe _)
| `Unikernel_cmd _ | `Unikernel_cmd _
| `Policy_cmd `Policy_info -> Ok (name, policies, wire) | `Policy_cmd `Policy_info -> Ok (name, policies, v, wire)
| _ -> Error (`Msg "unexpected command") | _ -> Error (`Msg "unexpected command")

View file

@ -1,10 +1,9 @@
(* (c) 2018 Hannes Mehnert, all rights reserved *) (* (c) 2018 Hannes Mehnert, all rights reserved *)
val wire_command_of_cert : Vmm_commands.version -> X509.Certificate.t -> val wire_command_of_cert : X509.Certificate.t ->
(Vmm_commands.t, [> `Msg of string | `Not_present | `Version of Vmm_commands.version ]) result (Vmm_commands.version * Vmm_commands.t, [> `Msg of string | `Not_present ]) result
val handle : val handle :
Vmm_commands.version ->
X509.Certificate.t list -> X509.Certificate.t list ->
(Vmm_core.Name.t * (Vmm_core.Name.t * Vmm_core.Policy.t) list * Vmm_commands.t, (Vmm_core.Name.t * (Vmm_core.Name.t * Vmm_core.Policy.t) list * Vmm_commands.version * Vmm_commands.t,
[> `Msg of string ]) Result.result [> `Msg of string ]) result

View file

@ -40,10 +40,15 @@ let read_tls t =
hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag
Cstruct.hexdump_pp b) ; *) Cstruct.hexdump_pp b) ; *)
match Vmm_asn.wire_of_cstruct b with match Vmm_asn.wire_of_cstruct b with
| Ok w -> Ok w
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "error %s while parsing data" msg) ; Logs.err (fun m -> m "error %s while parsing data" msg) ;
Error `Exception Error `Exception
| (Ok (hdr, _)) as w ->
if not Vmm_commands.(is_current hdr.version) then
Logs.warn (fun m -> m "version mismatch, received %a current %a"
Vmm_commands.pp_version hdr.Vmm_commands.version
Vmm_commands.pp_version Vmm_commands.current);
w
else else
Lwt.return (Error `Eof) Lwt.return (Error `Eof)