diff --git a/app/vmmd_tls.ml b/app/vmmd_tls.ml index 8de112f..49f8b8c 100644 --- a/app/vmmd_tls.ml +++ b/app/vmmd_tls.ml @@ -30,12 +30,12 @@ let client_auth ca tls addr = | Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a)) | Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f)) | exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ; - Tls_lwt.Unix.close tls >>= fun () -> + Vmm_tls_lwt.close tls >>= fun () -> Lwt.fail e) >>= fun () -> (match Tls_lwt.Unix.epoch tls with | `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain | `Error -> - Tls_lwt.Unix.close tls >>= fun () -> + Vmm_tls_lwt.close tls >>= fun () -> Lwt.fail_with "error while getting epoch") let read fd tls = @@ -63,7 +63,9 @@ let process fd tls = let handle ca (tls, addr) = client_auth ca tls addr >>= fun chain -> match Vmm_tls.handle addr my_version chain with - | Error (`Msg m) -> Lwt.fail_with m + | Error (`Msg m) -> + Vmm_tls_lwt.close tls >>= fun () -> + Lwt.fail_with m | Ok (name, cmd) -> let sock, next = Vmm_commands.endpoint cmd in connect (Vmm_core.socket_path sock) >>= fun fd -> @@ -73,11 +75,15 @@ let handle ca (tls, addr) = (header, `Command cmd) in Vmm_lwt.write_wire fd wire >>= function - | Error `Exception -> Lwt.return (Error (`Msg "couldn't write")) + | Error `Exception -> + Vmm_tls_lwt.close tls >>= fun () -> + Vmm_lwt.safe_close fd >>= fun () -> + Lwt.return (Error (`Msg "couldn't write")) | Ok () -> (match next with | `Read -> read fd tls | `End -> process fd tls) >>= fun res -> + Vmm_tls_lwt.close tls >>= fun () -> Vmm_lwt.safe_close fd >|= fun () -> res diff --git a/src/vmm_tls_lwt.ml b/src/vmm_tls_lwt.ml index 4bd3daf..51c74d9 100644 --- a/src/vmm_tls_lwt.ml +++ b/src/vmm_tls_lwt.ml @@ -62,3 +62,8 @@ let write_tls s wire = | e -> Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ; Lwt.return (Error `Exception)) + +let close tls = + Lwt.catch + (fun () -> Tls_lwt.Unix.close tls) + (fun _ -> Lwt.return_unit) diff --git a/src/vmm_tls_lwt.mli b/src/vmm_tls_lwt.mli index 39886d6..bf6762d 100644 --- a/src/vmm_tls_lwt.mli +++ b/src/vmm_tls_lwt.mli @@ -5,3 +5,5 @@ val read_tls : Tls_lwt.Unix.t -> val write_tls : Tls_lwt.Unix.t -> Vmm_commands.wire -> (unit, [> `Exception ]) result Lwt.t + +val close : Tls_lwt.Unix.t -> unit Lwt.t