diff --git a/app/vmmd_tls.ml b/app/vmmd_tls.ml index b074887..d32cbd7 100644 --- a/app/vmmd_tls.ml +++ b/app/vmmd_tls.ml @@ -30,13 +30,10 @@ let client_auth ca tls addr = | Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a)) | Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f)) | exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ; - Vmm_tls_lwt.close tls >>= fun () -> Lwt.fail e) >>= fun () -> (match Tls_lwt.Unix.epoch tls with | `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain - | `Error -> - Vmm_tls_lwt.close tls >>= fun () -> - Lwt.fail_with "error while getting epoch") + | `Error -> Lwt.fail_with "error while getting epoch") let read fd tls = (* now we busy read and process output *) @@ -64,9 +61,7 @@ let process fd tls = let handle ca (tls, addr) = client_auth ca tls addr >>= fun chain -> match Vmm_tls.handle addr my_version chain with - | Error (`Msg m) -> - Vmm_tls_lwt.close tls >>= fun () -> - Lwt.fail_with m + | Error (`Msg m) -> Lwt.fail_with m | Ok (name, policies, cmd) -> let sock, next = Vmm_commands.endpoint cmd in connect (Vmm_core.socket_path sock) >>= fun fd -> @@ -98,7 +93,6 @@ let handle ca (tls, addr) = header, `Failure msg in Vmm_tls_lwt.write_tls tls wire >>= fun _ -> - Vmm_tls_lwt.close tls >>= fun () -> Vmm_lwt.safe_close fd >>= fun () -> Lwt.fail_with msg end @@ -110,14 +104,12 @@ let handle ca (tls, addr) = in Vmm_lwt.write_wire fd wire >>= function | Error `Exception -> - Vmm_tls_lwt.close tls >>= fun () -> Vmm_lwt.safe_close fd >>= fun () -> Lwt.return (Error (`Msg "couldn't write")) | Ok () -> (match next with | `Read -> read fd tls | `End -> process fd tls) >>= fun res -> - Vmm_tls_lwt.close tls >>= fun () -> Vmm_lwt.safe_close fd >|= fun () -> res @@ -150,17 +142,18 @@ let jump _ cacert cert priv_key port = Lwt.catch (fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr)) (fun exn -> - Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () -> + Vmm_lwt.safe_close fd >>= fun () -> Lwt.fail exn) >>= fun t -> Lwt.async (fun () -> Lwt.catch - (fun () -> handle ca t >|= function - | Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg) - | Ok () -> ()) + (fun () -> + (handle ca t >|= function + | Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg) + | Ok () -> ()) >>= fun () -> + Vmm_tls_lwt.close (fst t)) (fun e -> - Logs.err (fun m -> m "error while handle() %s" - (Printexc.to_string e)) ; - Lwt.return_unit)) ; + Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ; + Vmm_tls_lwt.close (fst t))) ; loop ()) (function | Unix.Unix_error (e, f, _) ->