albatross/app/vmmd.ml

245 lines
9.0 KiB
OCaml

(* (c) 2017 Hannes Mehnert, all rights reserved *)
open Lwt.Infix
let write_tls state t data =
Lwt.catch (fun () -> Vmm_tls.write_tls (fst t) data)
(fun e ->
let state', out = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) out >>= fun () ->
raise e)
let to_ipaddr (_, sa) = match sa with
| Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address"
| Lwt_unix.ADDR_INET (addr, port) -> Ipaddr_unix.V4.of_inet_addr_exn addr, port
let pp_sockaddr ppf (_, sa) = match sa with
| Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str
| Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d"
(Unix.string_of_inet_addr addr) port
let process state xs =
Lwt_list.iter_s (function
| `Raw (s, str) -> Vmm_lwt.write_raw s str
| `Tls (s, str) -> write_tls state s str)
xs
let handle ca state t =
Logs.debug (fun m -> m "connection from %a" pp_sockaddr t) ;
let authenticator =
let time = Unix.gettimeofday () in
X509.Authenticator.chain_of_trust ~time ~crls:!state.Vmm_engine.crls [ca]
in
Lwt.catch (fun () ->
Tls_lwt.Unix.reneg ~authenticator (fst t))
(fun e ->
(match e with
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
Lwt.fail e) >>= fun () ->
(match Tls_lwt.Unix.epoch (fst t) with
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
| `Error -> Lwt.fail_with "error while getting epoch") >>= fun chain ->
match Vmm_engine.handle_initial !state t (to_ipaddr t) chain ca with
| Ok (state', outs, next) ->
state := state' ;
process state outs >>= fun () ->
(match next with
| `Create cont ->
(match cont !state t with
| Ok (state', outs, vm) ->
state := state' ;
process state outs >>= fun () ->
Lwt.async (fun () ->
Vmm_lwt.wait_and_clear vm.Vmm_core.pid vm.Vmm_core.stdout >>= fun r ->
let state', outs = Vmm_engine.handle_shutdown !state vm r in
state := state' ;
process state outs) ;
Lwt.return_unit
| Error (`Msg e) ->
Logs.err (fun m -> m "error while cont %s" e) ;
let err = Vmm_wire.fail ~msg:e 0 !state.Vmm_engine.client_version in
process state [ `Tls (t, err) ]) >>= fun () ->
Tls_lwt.Unix.close (fst t)
| `Loop (prefix, perms) ->
let rec loop () =
Vmm_tls.read_tls (fst t) >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading client %a error: %s" pp_sockaddr t msg) ;
loop ()
| Ok (hdr, buf) ->
let state', out = Vmm_engine.handle_command !state t prefix perms hdr buf in
state := state' ;
process state out >>= fun () ->
loop ()
in
Lwt.catch loop (fun e ->
let state', cons = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) cons >>= fun () ->
raise e)
| `Close socks ->
Logs.debug (fun m -> m "closing session with %d active ones" (List.length socks)) ;
Lwt_list.iter_s (fun (t, _) -> Tls_lwt.Unix.close t) socks >>= fun () ->
Tls_lwt.Unix.close (fst t))
| Error (`Msg e) ->
Logs.err (fun m -> m "VMM %a %s" pp_sockaddr t e) ;
let err = Vmm_wire.fail ~msg:e 0 !state.Vmm_engine.client_version in
process state [`Tls (t, err)] >>= fun () ->
Tls_lwt.Unix.close (fst t)
let server_socket port =
let open Lwt_unix in
let s = socket PF_INET SOCK_STREAM 0 in
set_close_on_exec s ;
setsockopt s SO_REUSEADDR true ;
Versioned.bind_2 s (ADDR_INET (Unix.inet_addr_any, port)) >>= fun () ->
listen s 10 ;
Lwt.return s
let init_exception () =
Lwt.async_exception_hook := (function
| Tls_lwt.Tls_failure a ->
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a))
| exn ->
Logs.err (fun m -> m "exception: %s" (Printexc.to_string exn)))
let rec read_log state s =
Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading log error %s" msg) ;
read_log state s
| Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_log !state hdr data in
state := state' ;
process state outs >>= fun () ->
read_log state s
let rec read_cons state s =
Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading console error %s" msg) ;
read_cons state s
| Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_cons !state hdr data in
state := state' ;
process state outs >>= fun () ->
read_cons state s
let rec read_stats state s =
Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading stats error %s" msg) ;
read_stats state s
| Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_stat !state hdr data in
state := state' ;
process state outs >>= fun () ->
read_stats state s
let cmp_s (_, a) (_, b) =
let open Lwt_unix in
match a, b with
| ADDR_UNIX str, ADDR_UNIX str' -> String.compare str str' = 0
| ADDR_INET (addr, port), ADDR_INET (addr', port') ->
port = port' &&
String.compare (Unix.string_of_inet_addr addr) (Unix.string_of_inet_addr addr') = 0
| _ -> false
let jump _ dir cacert cert priv_key =
Sys.(set_signal sigpipe Signal_ignore) ;
Lwt_main.run
(init_exception () ;
let d = Fpath.v dir in
let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec c ;
Lwt_unix.(connect c (ADDR_UNIX Fpath.(to_string (d / "cons" + "sock")))) >>= fun () ->
Lwt.catch (fun () ->
let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec s ;
Lwt_unix.(connect s (ADDR_UNIX Fpath.(to_string (d / "stat" + "sock")))) >|= fun () ->
Some s)
(function
| Unix.Unix_error (Unix.ENOENT, _, _) -> Lwt.return None
| e -> Lwt.fail e) >>= fun s ->
let l = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec l ;
Lwt_unix.(connect l (ADDR_UNIX Fpath.(to_string (d / "log" + "sock")))) >>= fun () ->
server_socket 1025 >>= fun socket ->
X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert ->
X509_lwt.certs_of_pem cacert >>= (function
| [ ca ] -> Lwt.return ca
| _ -> Lwt.fail_with "expect single ca as cacert") >>= fun ca ->
let config =
Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2)
~reneg:true ~certificates:(`Single cert) ())
in
(match Vmm_engine.init d cmp_s c s l with
| Ok s -> Lwt.return s
| Error (`Msg m) -> Lwt.fail_with m) >>= fun t ->
let state = ref t in
Lwt.async (fun () -> read_cons state c) ;
(match s with
| None -> ()
| Some s -> Lwt.async (fun () -> read_stats state s)) ;
Lwt.async (fun () -> read_log state l) ;
let rec loop () =
Lwt.catch (fun () ->
Lwt_unix.accept socket >>= fun (fd, addr) ->
Lwt_unix.set_close_on_exec fd ;
Lwt.catch
(fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr))
(fun exn ->
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
Lwt.fail exn) >>= fun t ->
Lwt.async (fun () -> handle ca state t) ;
loop ())
(function
| Unix.Unix_error (e, f, _) ->
Logs.err (fun m -> m "Unix error %s in %s" (Unix.error_message e) f) ;
loop ()
| Tls_lwt.Tls_failure a ->
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ;
loop ()
| exn ->
Logs.err (fun m -> m "exception %s" (Printexc.to_string exn)) ;
loop ())
in
loop ())
let setup_log style_renderer level =
Fmt_tty.setup_std_outputs ?style_renderer ();
Logs.set_level level;
Logs.set_reporter (Logs_fmt.reporter ~dst:Format.std_formatter ())
open Cmdliner
let setup_log =
Term.(const setup_log
$ Fmt_cli.style_renderer ()
$ Logs_cli.level ())
let wdir =
let doc = "Working directory (unix domain sockets, etc.)" in
Arg.(required & pos 0 (some dir) None & info [] ~doc)
let cacert =
let doc = "CA certificate" in
Arg.(required & pos 1 (some file) None & info [] ~doc)
let cert =
let doc = "Certificate" in
Arg.(required & pos 2 (some file) None & info [] ~doc)
let key =
let doc = "Private key" in
Arg.(required & pos 3 (some file) None & info [] ~doc)
let cmd =
Term.(ret (const jump $ setup_log $ wdir $ cacert $ cert $ key)),
Term.info "vmmd" ~version:"%%VERSION_NUM%%"
let () = match Term.eval cmd with `Ok () -> exit 0 | _ -> exit 1