From ceafacbd2a1b800e97df96710583338de1cca53a Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 19 May 2020 21:07:39 +0200 Subject: [PATCH] require tls 1.3, avoid renegotiation (client certificate is now already encrypted) --- albatross.opam | 2 +- tls/albatross_tls_common.ml | 141 +++++++++++++++------------------- tls/albatross_tls_endpoint.ml | 4 +- tls/albatross_tls_inetd.ml | 4 +- 4 files changed, 69 insertions(+), 82 deletions(-) diff --git a/albatross.opam b/albatross.opam index 8cd9ab1..0007f3f 100644 --- a/albatross.opam +++ b/albatross.opam @@ -22,7 +22,7 @@ depends: [ "astring" "jsonm" "x509" {>= "0.11.0"} - "tls" {>= "0.11.0"} + "tls" {>= "0.12.0"} "mirage-crypto-pk" "mirage-crypto-rng" "asn1-combinators" {>= "0.2.0"} diff --git a/tls/albatross_tls_common.ml b/tls/albatross_tls_common.ml index b4a06f8..14f62c2 100644 --- a/tls/albatross_tls_common.ml +++ b/tls/albatross_tls_common.ml @@ -9,26 +9,11 @@ let tls_config cacert cert priv_key = X509_lwt.certs_of_pem cacert >>= (function | [ ca ] -> Lwt.return ca | _ -> Lwt.fail_with "expect single ca as cacert") >|= fun ca -> - (Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2) - ~reneg:true ~certificates:(`Single cert) ()), - ca) - -let client_auth ca tls = - let authenticator = - let time () = Some (Ptime_clock.now ()) in - X509.Authenticator.chain_of_trust ~time (* ~crls:!state.Vmm_engine.crls *) [ca] - in - Lwt.catch - (fun () -> Tls_lwt.Unix.reneg ~authenticator tls) - (fun e -> - (match e with - | Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a)) - | Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f)) - | exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ; - Lwt.fail e) >>= fun () -> - (match Tls_lwt.Unix.epoch tls with - | `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain - | `Error -> Lwt.fail_with "error while getting epoch") + let time () = Some (Ptime_clock.now ()) in + Tls.Config.server + ~version:(`TLS_1_3, `TLS_1_3) + ~authenticator:(X509.Authenticator.chain_of_trust ~time [ca]) + ~certificates:(`Single cert) () let read version fd tls = (* now we busy read and process output *) @@ -51,66 +36,68 @@ let process fd = Logs.debug (fun m -> m "proxying %a" Vmm_commands.pp_wire (hdr, pay)); pay -let handle ca tls = - client_auth ca tls >>= fun chain -> - match Vmm_tls.handle chain with - | Error `Msg msg -> - Logs.err (fun m -> m "failed to handle TLS connection %s" msg); - Lwt.return_unit - | Ok (name, policies, version, cmd) -> - begin - let sock, next = Vmm_commands.endpoint cmd in - let sockaddr = Lwt_unix.ADDR_UNIX (Vmm_core.socket_path sock) in - Vmm_lwt.connect Lwt_unix.PF_UNIX sockaddr >>= function - | None -> - Logs.warn (fun m -> m "failed to connect to %a" Vmm_lwt.pp_sockaddr sockaddr); - Lwt.return (`Failure "couldn't reach service") - | Some fd -> - (match sock with - | `Vmmd -> - Lwt_list.fold_left_s (fun r (id, policy) -> - match r with - | Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) - | Ok () -> - Logs.debug (fun m -> m "adding policy for %a: %a" Vmm_core.Name.pp id Vmm_core.Policy.pp policy) ; - let header = Vmm_commands.header ~sequence:!command id in - command := Int64.succ !command ; - Vmm_lwt.write_wire fd (header, `Command (`Policy_cmd (`Policy_add policy))) >>= function - | Error `Exception -> Lwt.return (Error (`Msg "failed to write policy")) +let handle tls = + match Tls_lwt.Unix.epoch tls with + | `Error -> Lwt.fail_with "error while getting epoch" + | `Ok epoch -> + match Vmm_tls.handle epoch.Tls.Core.peer_certificate_chain with + | Error `Msg msg -> + Logs.err (fun m -> m "failed to handle TLS connection %s" msg); + Lwt.return_unit + | Ok (name, policies, version, cmd) -> + begin + let sock, next = Vmm_commands.endpoint cmd in + let sockaddr = Lwt_unix.ADDR_UNIX (Vmm_core.socket_path sock) in + Vmm_lwt.connect Lwt_unix.PF_UNIX sockaddr >>= function + | None -> + Logs.warn (fun m -> m "failed to connect to %a" Vmm_lwt.pp_sockaddr sockaddr); + Lwt.return (`Failure "couldn't reach service") + | Some fd -> + (match sock with + | `Vmmd -> + Lwt_list.fold_left_s (fun r (id, policy) -> + match r with + | Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) | Ok () -> - Vmm_lwt.read_wire fd >|= function - | Error _ -> Error (`Msg "read error after writing policy") - | Ok (_, `Success _) -> Ok () - | Ok wire -> - Rresult.R.error_msgf - "expected success when adding policy, got: %a" - Vmm_commands.pp_wire wire) - (Ok ()) policies - | _ -> Lwt.return (Ok ())) >>= function - | Error (`Msg msg) -> - Vmm_lwt.safe_close fd >|= fun () -> - Logs.warn (fun m -> m "error while applying policies %s" msg) ; - `Failure msg - | Ok () -> - let wire = - let header = Vmm_commands.header ~sequence:!command name in - command := Int64.succ !command ; - (header, `Command cmd) - in - Vmm_lwt.write_wire fd wire >>= function - | Error `Exception -> + Logs.debug (fun m -> m "adding policy for %a: %a" Vmm_core.Name.pp id Vmm_core.Policy.pp policy) ; + let header = Vmm_commands.header ~sequence:!command id in + command := Int64.succ !command ; + Vmm_lwt.write_wire fd (header, `Command (`Policy_cmd (`Policy_add policy))) >>= function + | Error `Exception -> Lwt.return (Error (`Msg "failed to write policy")) + | Ok () -> + Vmm_lwt.read_wire fd >|= function + | Error _ -> Error (`Msg "read error after writing policy") + | Ok (_, `Success _) -> Ok () + | Ok wire -> + Rresult.R.error_msgf + "expected success when adding policy, got: %a" + Vmm_commands.pp_wire wire) + (Ok ()) policies + | _ -> Lwt.return (Ok ())) >>= function + | Error (`Msg msg) -> Vmm_lwt.safe_close fd >|= fun () -> - `Failure "couldn't write unikernel to VMMD" + Logs.warn (fun m -> m "error while applying policies %s" msg) ; + `Failure msg | Ok () -> - (match next with - | `Read -> read version fd tls - | `End -> process fd) >>= fun res -> - Vmm_lwt.safe_close fd >|= fun () -> - res - end >>= fun reply -> - Vmm_tls_lwt.write_tls tls - (Vmm_commands.header ~version name, reply) >|= fun _ -> - () + let wire = + let header = Vmm_commands.header ~sequence:!command name in + command := Int64.succ !command ; + (header, `Command cmd) + in + Vmm_lwt.write_wire fd wire >>= function + | Error `Exception -> + Vmm_lwt.safe_close fd >|= fun () -> + `Failure "couldn't write unikernel to VMMD" + | Ok () -> + (match next with + | `Read -> read version fd tls + | `End -> process fd) >>= fun res -> + Vmm_lwt.safe_close fd >|= fun () -> + res + end >>= fun reply -> + Vmm_tls_lwt.write_tls tls + (Vmm_commands.header ~version name, reply) >|= fun _ -> + () let classify_tls_error = function | Tls_lwt.Tls_alert diff --git a/tls/albatross_tls_endpoint.ml b/tls/albatross_tls_endpoint.ml index 2c1f668..177e4af 100644 --- a/tls/albatross_tls_endpoint.ml +++ b/tls/albatross_tls_endpoint.ml @@ -19,7 +19,7 @@ let jump _ cacert cert priv_key port tmpdir = Albatross_cli.set_tmpdir tmpdir; Lwt_main.run (server_socket port >>= fun socket -> - tls_config cacert cert priv_key >>= fun (config, ca) -> + tls_config cacert cert priv_key >>= fun config -> let rec loop () = Lwt.catch (fun () -> Lwt_unix.accept socket >>= fun (fd, _addr) -> @@ -31,7 +31,7 @@ let jump _ cacert cert priv_key port tmpdir = Lwt.async (fun () -> Lwt.catch (fun () -> - handle ca t >>= fun () -> + handle t >>= fun () -> Vmm_tls_lwt.close t) (fun e -> Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ; diff --git a/tls/albatross_tls_inetd.ml b/tls/albatross_tls_inetd.ml index 9b9096d..5aee82a 100644 --- a/tls/albatross_tls_inetd.ml +++ b/tls/albatross_tls_inetd.ml @@ -8,7 +8,7 @@ let jump cacert cert priv_key tmpdir = Mirage_crypto_rng_unix.initialize (); Albatross_cli.set_tmpdir tmpdir; Lwt_main.run - (tls_config cacert cert priv_key >>= fun (config, ca) -> + (tls_config cacert cert priv_key >>= fun config -> let fd = Lwt_unix.of_unix_file_descr Unix.stdin in Lwt.catch (fun () -> Tls_lwt.Unix.server_of_fd config fd) @@ -17,7 +17,7 @@ let jump cacert cert priv_key tmpdir = Lwt.fail exn) >>= fun t -> Lwt.catch (fun () -> - handle ca t >>= fun () -> + handle t >>= fun () -> Vmm_tls_lwt.close t) (fun e -> Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;