albatross/app/vmm_tls_endpoint.ml

168 lines
5.7 KiB
OCaml
Raw Normal View History

(* (c) 2017, 2018 Hannes Mehnert, all rights reserved *)
open Lwt.Infix
2018-10-22 22:12:06 +00:00
(* Wire protocol version placed in the header of every command this endpoint
   sends (used when building the header in [handle]). *)
let my_version = `AV2
(* Sequence number for the next outgoing command header; [handle] reads the
   current value and then increments it for each command it forwards. *)
let command = ref 0L
2018-10-13 23:02:52 +00:00
(** [pp_sockaddr ppf addr] pretty-prints the socket address [addr] on [ppf]:
    the path for a unix domain socket, or ["ip:port"] for an internet
    address. *)
let pp_sockaddr ppf sockaddr =
  match sockaddr with
  | Lwt_unix.ADDR_INET (ip, tcp_port) ->
    Fmt.pf ppf "TCP %s:%d" (Unix.string_of_inet_addr ip) tcp_port
  | Lwt_unix.ADDR_UNIX path ->
    Fmt.pf ppf "unix domain socket %s" path
2018-10-13 23:02:52 +00:00
(** [connect socket_path] opens a unix domain stream socket (with
    close-on-exec set) and connects it to [socket_path].

    Resolves to the connected file descriptor. If the connection attempt
    fails, the freshly created socket is closed before the original exception
    is re-raised, so a failed connect does not leak a descriptor. *)
let connect socket_path =
  let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
  Lwt_unix.set_close_on_exec c ;
  Lwt.catch
    (fun () ->
       Lwt_unix.connect c (Lwt_unix.ADDR_UNIX socket_path) >|= fun () ->
       c)
    (fun exn ->
       (* Best-effort cleanup: a failure of [close] must not mask the
          connect error the caller needs to see. *)
       Lwt.catch (fun () -> Lwt_unix.close c) (fun _ -> Lwt.return_unit)
       >>= fun () ->
       Lwt.fail exn)
(** [client_auth ca tls addr] renegotiates the TLS session [tls] using a
    chain-of-trust authenticator rooted at [ca] (validated against the
    current time) and resolves to the peer certificate chain of the resulting
    epoch.

    On any failure the session is closed before the promise is rejected:
    a renegotiation failure is logged and the original exception re-raised;
    a missing epoch rejects with [Failure "error while getting epoch"]. *)
let client_auth ca tls addr =
  Logs.debug (fun m -> m "connection from %a" pp_sockaddr addr) ;
  let authenticator =
    (* revocation lists are not wired in yet: ~crls:!state.Vmm_engine.crls *)
    X509.Authenticator.chain_of_trust ~time:(Ptime_clock.now ()) [ca]
  in
  let log_failure = function
    | Tls_lwt.Tls_alert a ->
      Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
    | Tls_lwt.Tls_failure f ->
      Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
    | exn ->
      Logs.err (fun m -> m "%s" (Printexc.to_string exn))
  in
  let renegotiate () = Tls_lwt.Unix.reneg ~authenticator tls in
  let abort e =
    log_failure e ;
    Tls_lwt.Unix.close tls >>= fun () ->
    Lwt.fail e
  in
  Lwt.catch renegotiate abort >>= fun () ->
  match Tls_lwt.Unix.epoch tls with
  | `Error ->
    Tls_lwt.Unix.close tls >>= fun () ->
    Lwt.fail_with "error while getting epoch"
  | `Ok epoch ->
    Lwt.return epoch.Tls.Core.peer_certificate_chain
2018-10-14 00:18:33 +00:00
(** [read fd tls] keeps forwarding: it reads one wire message from [fd],
    writes it to the TLS session [tls], and repeats. The loop only stops on
    failure, so the result is always an [Error (`Msg _)] describing whether
    reading or writing failed. *)
let read fd tls =
  let rec forward () =
    Vmm_lwt.read_wire fd >>= fun received ->
    match received with
    | Error _ -> Lwt.return (Error (`Msg "exception while reading"))
    | Ok wire ->
      Logs.debug (fun m -> m "read proxying %a" Vmm_asn.pp_wire wire) ;
      Vmm_tls.write_tls tls wire >>= fun sent ->
      (match sent with
       | Error `Exception -> Lwt.return (Error (`Msg "exception"))
       | Ok () -> forward ())
  in
  forward ()
(** [process fd tls] forwards exactly one wire message: it reads a single
    message from [fd] and writes it to the TLS session [tls].

    Resolves to [Ok ()] when both steps succeed, otherwise to
    [Error (`Msg _)] naming the step that failed. *)
let process fd tls =
  Vmm_lwt.read_wire fd >>= fun reply ->
  match reply with
  | Error _ -> Lwt.return (Error (`Msg "read error"))
  | Ok wire ->
    Logs.debug (fun m -> m "proxying %a" Vmm_asn.pp_wire wire) ;
    Vmm_tls.write_tls tls wire >|= fun sent ->
    (match sent with
     | Error `Exception -> Error (`Msg "exception on write")
     | Ok () -> Ok ())
2018-10-13 23:02:52 +00:00
(** [handle ca (tls, addr)] serves one TLS client connection.

    The client is authenticated against [ca] via [client_auth]; the resulting
    certificate chain is handed to [Vmm_x509.handle] to obtain a name and a
    command. The command is sent, with a fresh header, to the unix socket
    selected by [Vmm_commands.handle], and the answer(s) are relayed back over
    [tls]: continuously for [`Read] (see [read]), once for [`End] (see
    [process]).

    Resolves to the result of the relay step, or to
    [Error (`Msg "couldn't write")] if the command could not be sent. The
    promise is rejected if authentication fails, if the certificate chain is
    rejected (with [Failure]), or if connecting to the socket fails.

    Resource handling: once authentication has succeeded this function owns
    both [tls] and the daemon socket, and releases them on every exit path
    (success, [Error], or exception). [client_auth] closes [tls] itself when
    it fails. *)
let handle ca (tls, addr) =
  client_auth ca tls addr >>= fun chain ->
  let close_tls () =
    (* Best-effort: the session may already be broken after a failed write,
       and a close error must not replace the result being returned. *)
    Lwt.catch (fun () -> Tls_lwt.Unix.close tls) (fun _ -> Lwt.return_unit)
  in
  Lwt.finalize
    (fun () ->
       match Vmm_x509.handle addr chain with
       | Error (`Msg m) -> Lwt.fail_with m
       | Ok (name, cmd) ->
         let sock, next = Vmm_commands.handle cmd in
         connect (Vmm_core.socket_path sock) >>= fun fd ->
         Lwt.finalize
           (fun () ->
              let wire =
                let header =
                  Vmm_asn.{version = my_version ; sequence = !command ; id = name }
                in
                (* consume this sequence number so the next command gets a
                   new one *)
                command := Int64.succ !command ;
                (header, `Command cmd)
              in
              Vmm_lwt.write_wire fd wire >>= function
              | Error `Exception -> Lwt.return (Error (`Msg "couldn't write"))
              | Ok () ->
                (match next with
                 | `Read -> read fd tls
                 | `End -> process fd tls))
           (* NOTE(review): [write_wire] may already have closed [fd] on
              failure; [safe_close] is presumed to tolerate that - confirm. *)
           (fun () -> Vmm_lwt.safe_close fd))
    close_tls
(** [server_socket port] creates an IPv4 TCP socket with close-on-exec and
    [SO_REUSEADDR] set, binds it to [port] on the wildcard address, and puts
    it into listening state with a backlog of 10. Resolves to the listening
    socket. *)
let server_socket port =
  let listener = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in
  Lwt_unix.set_close_on_exec listener ;
  Lwt_unix.(setsockopt listener SO_REUSEADDR true) ;
  let local_addr = Lwt_unix.ADDR_INET (Unix.inet_addr_any, port) in
  Lwt_unix.bind listener local_addr >>= fun () ->
  Lwt_unix.listen listener 10 ;
  Lwt.return listener
(** [jump _ cacert cert priv_key port] is the program entry point: it runs
    the TLS endpoint forever on TCP [port].

    - [cacert]: PEM file that must contain exactly one CA certificate, used
      to authenticate clients; anything else fails with
      [Failure "expect single ca as cacert"].
    - [cert], [priv_key]: PEM files with the server certificate and key.
    - The first argument is ignored (presumably the unit result of the
      logging setup term - confirm against the command-line definition).

    SIGPIPE is ignored so that writing to a closed peer surfaces as an error
    instead of terminating the process. Each accepted connection is upgraded
    to TLS 1.2 and handed to [handle] in a detached thread; failures while
    accepting or handling are logged and the server keeps accepting. *)
let jump _ cacert cert priv_key port =
  Sys.(set_signal sigpipe Signal_ignore) ;
  Lwt_main.run
    (Nocrypto_entropy_lwt.initialize () >>= fun () ->
     server_socket port >>= fun socket ->
     X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert ->
     X509_lwt.certs_of_pem cacert >>= (function
         | [ ca ] -> Lwt.return ca
         | _ -> Lwt.fail_with "expect single ca as cacert") >>= fun ca ->
     let config =
       Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2)
              ~reneg:true ~certificates:(`Single cert) ())
     in
     (* Accept one connection, upgrade it to TLS and start serving it in the
        background. *)
     let accept_one () =
       Lwt_unix.accept socket >>= fun (fd, addr) ->
       Lwt_unix.set_close_on_exec fd ;
       Lwt.catch
         (fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr))
         (fun exn ->
            (* the handshake failed: release the raw descriptor, then let the
               error be logged by the accept loop *)
            Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
            Lwt.fail exn) >|= fun t ->
       Lwt.async (fun () ->
           Lwt.catch
             (fun () -> handle ca t >|= function
                | Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
                | Ok () -> ())
             (fun e ->
                Logs.err (fun m -> m "error while handle() %s"
                             (Printexc.to_string e)) ;
                Lwt.return_unit))
     in
     (* Log why an accept iteration failed; the loop continues regardless. *)
     let log_accept_error = function
       | Unix.Unix_error (e, f, _) ->
         Logs.err (fun m -> m "Unix error %s in %s" (Unix.error_message e) f) ;
         Lwt.return_unit
       | Tls_lwt.Tls_failure a ->
         Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ;
         Lwt.return_unit
       | exn ->
         Logs.err (fun m -> m "exception %s" (Printexc.to_string exn)) ;
         Lwt.return_unit
     in
     (* The recursive call is made outside of [Lwt.catch]: recursing inside
        the protected thunk would leave one pending promise per accepted
        connection waiting on the remainder of the (infinite) loop. *)
     let rec loop () =
       Lwt.catch accept_one log_accept_error >>= fun () ->
       loop ()
     in
     loop ())
(** [setup_log style_renderer level] configures logging: it sets up the
    standard output formatters with the optional [style_renderer], sets the
    global log level to [level], and installs a reporter that writes to
    standard output. *)
let setup_log style_renderer level =
  let () = Fmt_tty.setup_std_outputs ?style_renderer () in
  let () = Logs.set_level level in
  Logs_fmt.reporter ~dst:Format.std_formatter () |> Logs.set_reporter
open Cmdliner
(* Cmdliner term that applies the [setup_log] function defined above to the
   style renderer and log level taken from the command line. From this point
   on the name [setup_log] refers to this term, not to the function. *)
let setup_log =
Term.(const setup_log
$ Fmt_cli.style_renderer ()
$ Logs_cli.level ())
(* TODO needs CRL as well, plus socket(s) *)
(* First positional argument (required): path to an existing file holding the
   CA certificate. *)
let cacert =
  Arg.(required & pos 0 (some file) None & info [] ~doc:"CA certificate")
(* Second positional argument (required): path to an existing file holding
   the certificate. *)
let cert =
  Arg.(required & pos 1 (some file) None & info [] ~doc:"Certificate")
(* Third positional argument (required): path to an existing file holding the
   private key. *)
let key =
  Arg.(required & pos 2 (some file) None & info [] ~doc:"Private key")
(* Optional [--port] argument: TCP port to listen on, 1025 when not given. *)
let port =
  Arg.(value & opt int 1025 & info [ "port" ] ~doc:"TCP listen port")