require tls 1.3, avoid renegotiation (client certificate is now already encrypted)

This commit is contained in:
Hannes Mehnert 2020-05-19 21:07:39 +02:00
parent ccf3cae68c
commit ceafacbd2a
4 changed files with 69 additions and 82 deletions

View file

@ -22,7 +22,7 @@ depends: [
"astring" "astring"
"jsonm" "jsonm"
"x509" {>= "0.11.0"} "x509" {>= "0.11.0"}
"tls" {>= "0.11.0"} "tls" {>= "0.12.0"}
"mirage-crypto-pk" "mirage-crypto-pk"
"mirage-crypto-rng" "mirage-crypto-rng"
"asn1-combinators" {>= "0.2.0"} "asn1-combinators" {>= "0.2.0"}

View file

@ -9,26 +9,11 @@ let tls_config cacert cert priv_key =
X509_lwt.certs_of_pem cacert >>= (function X509_lwt.certs_of_pem cacert >>= (function
| [ ca ] -> Lwt.return ca | [ ca ] -> Lwt.return ca
| _ -> Lwt.fail_with "expect single ca as cacert") >|= fun ca -> | _ -> Lwt.fail_with "expect single ca as cacert") >|= fun ca ->
(Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2) let time () = Some (Ptime_clock.now ()) in
~reneg:true ~certificates:(`Single cert) ()), Tls.Config.server
ca) ~version:(`TLS_1_3, `TLS_1_3)
~authenticator:(X509.Authenticator.chain_of_trust ~time [ca])
let client_auth ca tls = ~certificates:(`Single cert) ()
let authenticator =
let time () = Some (Ptime_clock.now ()) in
X509.Authenticator.chain_of_trust ~time (* ~crls:!state.Vmm_engine.crls *) [ca]
in
Lwt.catch
(fun () -> Tls_lwt.Unix.reneg ~authenticator tls)
(fun e ->
(match e with
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
Lwt.fail e) >>= fun () ->
(match Tls_lwt.Unix.epoch tls with
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
| `Error -> Lwt.fail_with "error while getting epoch")
let read version fd tls = let read version fd tls =
(* now we busy read and process output *) (* now we busy read and process output *)
@ -51,66 +36,68 @@ let process fd =
Logs.debug (fun m -> m "proxying %a" Vmm_commands.pp_wire (hdr, pay)); Logs.debug (fun m -> m "proxying %a" Vmm_commands.pp_wire (hdr, pay));
pay pay
let handle ca tls = let handle tls =
client_auth ca tls >>= fun chain -> match Tls_lwt.Unix.epoch tls with
match Vmm_tls.handle chain with | `Error -> Lwt.fail_with "error while getting epoch"
| Error `Msg msg -> | `Ok epoch ->
Logs.err (fun m -> m "failed to handle TLS connection %s" msg); match Vmm_tls.handle epoch.Tls.Core.peer_certificate_chain with
Lwt.return_unit | Error `Msg msg ->
| Ok (name, policies, version, cmd) -> Logs.err (fun m -> m "failed to handle TLS connection %s" msg);
begin Lwt.return_unit
let sock, next = Vmm_commands.endpoint cmd in | Ok (name, policies, version, cmd) ->
let sockaddr = Lwt_unix.ADDR_UNIX (Vmm_core.socket_path sock) in begin
Vmm_lwt.connect Lwt_unix.PF_UNIX sockaddr >>= function let sock, next = Vmm_commands.endpoint cmd in
| None -> let sockaddr = Lwt_unix.ADDR_UNIX (Vmm_core.socket_path sock) in
Logs.warn (fun m -> m "failed to connect to %a" Vmm_lwt.pp_sockaddr sockaddr); Vmm_lwt.connect Lwt_unix.PF_UNIX sockaddr >>= function
Lwt.return (`Failure "couldn't reach service") | None ->
| Some fd -> Logs.warn (fun m -> m "failed to connect to %a" Vmm_lwt.pp_sockaddr sockaddr);
(match sock with Lwt.return (`Failure "couldn't reach service")
| `Vmmd -> | Some fd ->
Lwt_list.fold_left_s (fun r (id, policy) -> (match sock with
match r with | `Vmmd ->
| Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) Lwt_list.fold_left_s (fun r (id, policy) ->
| Ok () -> match r with
Logs.debug (fun m -> m "adding policy for %a: %a" Vmm_core.Name.pp id Vmm_core.Policy.pp policy) ; | Error (`Msg msg) -> Lwt.return (Error (`Msg msg))
let header = Vmm_commands.header ~sequence:!command id in
command := Int64.succ !command ;
Vmm_lwt.write_wire fd (header, `Command (`Policy_cmd (`Policy_add policy))) >>= function
| Error `Exception -> Lwt.return (Error (`Msg "failed to write policy"))
| Ok () -> | Ok () ->
Vmm_lwt.read_wire fd >|= function Logs.debug (fun m -> m "adding policy for %a: %a" Vmm_core.Name.pp id Vmm_core.Policy.pp policy) ;
| Error _ -> Error (`Msg "read error after writing policy") let header = Vmm_commands.header ~sequence:!command id in
| Ok (_, `Success _) -> Ok () command := Int64.succ !command ;
| Ok wire -> Vmm_lwt.write_wire fd (header, `Command (`Policy_cmd (`Policy_add policy))) >>= function
Rresult.R.error_msgf | Error `Exception -> Lwt.return (Error (`Msg "failed to write policy"))
"expected success when adding policy, got: %a" | Ok () ->
Vmm_commands.pp_wire wire) Vmm_lwt.read_wire fd >|= function
(Ok ()) policies | Error _ -> Error (`Msg "read error after writing policy")
| _ -> Lwt.return (Ok ())) >>= function | Ok (_, `Success _) -> Ok ()
| Error (`Msg msg) -> | Ok wire ->
Vmm_lwt.safe_close fd >|= fun () -> Rresult.R.error_msgf
Logs.warn (fun m -> m "error while applying policies %s" msg) ; "expected success when adding policy, got: %a"
`Failure msg Vmm_commands.pp_wire wire)
| Ok () -> (Ok ()) policies
let wire = | _ -> Lwt.return (Ok ())) >>= function
let header = Vmm_commands.header ~sequence:!command name in | Error (`Msg msg) ->
command := Int64.succ !command ;
(header, `Command cmd)
in
Vmm_lwt.write_wire fd wire >>= function
| Error `Exception ->
Vmm_lwt.safe_close fd >|= fun () -> Vmm_lwt.safe_close fd >|= fun () ->
`Failure "couldn't write unikernel to VMMD" Logs.warn (fun m -> m "error while applying policies %s" msg) ;
`Failure msg
| Ok () -> | Ok () ->
(match next with let wire =
| `Read -> read version fd tls let header = Vmm_commands.header ~sequence:!command name in
| `End -> process fd) >>= fun res -> command := Int64.succ !command ;
Vmm_lwt.safe_close fd >|= fun () -> (header, `Command cmd)
res in
end >>= fun reply -> Vmm_lwt.write_wire fd wire >>= function
Vmm_tls_lwt.write_tls tls | Error `Exception ->
(Vmm_commands.header ~version name, reply) >|= fun _ -> Vmm_lwt.safe_close fd >|= fun () ->
() `Failure "couldn't write unikernel to VMMD"
| Ok () ->
(match next with
| `Read -> read version fd tls
| `End -> process fd) >>= fun res ->
Vmm_lwt.safe_close fd >|= fun () ->
res
end >>= fun reply ->
Vmm_tls_lwt.write_tls tls
(Vmm_commands.header ~version name, reply) >|= fun _ ->
()
let classify_tls_error = function let classify_tls_error = function
| Tls_lwt.Tls_alert | Tls_lwt.Tls_alert

View file

@ -19,7 +19,7 @@ let jump _ cacert cert priv_key port tmpdir =
Albatross_cli.set_tmpdir tmpdir; Albatross_cli.set_tmpdir tmpdir;
Lwt_main.run Lwt_main.run
(server_socket port >>= fun socket -> (server_socket port >>= fun socket ->
tls_config cacert cert priv_key >>= fun (config, ca) -> tls_config cacert cert priv_key >>= fun config ->
let rec loop () = let rec loop () =
Lwt.catch (fun () -> Lwt.catch (fun () ->
Lwt_unix.accept socket >>= fun (fd, _addr) -> Lwt_unix.accept socket >>= fun (fd, _addr) ->
@ -31,7 +31,7 @@ let jump _ cacert cert priv_key port tmpdir =
Lwt.async (fun () -> Lwt.async (fun () ->
Lwt.catch Lwt.catch
(fun () -> (fun () ->
handle ca t >>= fun () -> handle t >>= fun () ->
Vmm_tls_lwt.close t) Vmm_tls_lwt.close t)
(fun e -> (fun e ->
Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ; Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;

View file

@ -8,7 +8,7 @@ let jump cacert cert priv_key tmpdir =
Mirage_crypto_rng_unix.initialize (); Mirage_crypto_rng_unix.initialize ();
Albatross_cli.set_tmpdir tmpdir; Albatross_cli.set_tmpdir tmpdir;
Lwt_main.run Lwt_main.run
(tls_config cacert cert priv_key >>= fun (config, ca) -> (tls_config cacert cert priv_key >>= fun config ->
let fd = Lwt_unix.of_unix_file_descr Unix.stdin in let fd = Lwt_unix.of_unix_file_descr Unix.stdin in
Lwt.catch Lwt.catch
(fun () -> Tls_lwt.Unix.server_of_fd config fd) (fun () -> Tls_lwt.Unix.server_of_fd config fd)
@ -17,7 +17,7 @@ let jump cacert cert priv_key tmpdir =
Lwt.fail exn) >>= fun t -> Lwt.fail exn) >>= fun t ->
Lwt.catch Lwt.catch
(fun () -> (fun () ->
handle ca t >>= fun () -> handle t >>= fun () ->
Vmm_tls_lwt.close t) Vmm_tls_lwt.close t)
(fun e -> (fun e ->
Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ; Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;