revise vmmd_tls file descriptor handling

This commit is contained in:
Hannes Mehnert 2018-10-29 19:00:13 +01:00
parent c669be8e02
commit 824f5f3418

View file

@ -30,13 +30,10 @@ let client_auth ca tls addr =
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a)) | Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f)) | Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ; | exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail e) >>= fun () -> Lwt.fail e) >>= fun () ->
(match Tls_lwt.Unix.epoch tls with (match Tls_lwt.Unix.epoch tls with
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain | `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
| `Error -> | `Error -> Lwt.fail_with "error while getting epoch")
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail_with "error while getting epoch")
let read fd tls = let read fd tls =
(* now we busy read and process output *) (* now we busy read and process output *)
@ -64,9 +61,7 @@ let process fd tls =
let handle ca (tls, addr) = let handle ca (tls, addr) =
client_auth ca tls addr >>= fun chain -> client_auth ca tls addr >>= fun chain ->
match Vmm_tls.handle addr my_version chain with match Vmm_tls.handle addr my_version chain with
| Error (`Msg m) -> | Error (`Msg m) -> Lwt.fail_with m
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail_with m
| Ok (name, policies, cmd) -> | Ok (name, policies, cmd) ->
let sock, next = Vmm_commands.endpoint cmd in let sock, next = Vmm_commands.endpoint cmd in
connect (Vmm_core.socket_path sock) >>= fun fd -> connect (Vmm_core.socket_path sock) >>= fun fd ->
@ -98,7 +93,6 @@ let handle ca (tls, addr) =
header, `Failure msg header, `Failure msg
in in
Vmm_tls_lwt.write_tls tls wire >>= fun _ -> Vmm_tls_lwt.write_tls tls wire >>= fun _ ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >>= fun () -> Vmm_lwt.safe_close fd >>= fun () ->
Lwt.fail_with msg Lwt.fail_with msg
end end
@ -110,14 +104,12 @@ let handle ca (tls, addr) =
in in
Vmm_lwt.write_wire fd wire >>= function Vmm_lwt.write_wire fd wire >>= function
| Error `Exception -> | Error `Exception ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >>= fun () -> Vmm_lwt.safe_close fd >>= fun () ->
Lwt.return (Error (`Msg "couldn't write")) Lwt.return (Error (`Msg "couldn't write"))
| Ok () -> | Ok () ->
(match next with (match next with
| `Read -> read fd tls | `Read -> read fd tls
| `End -> process fd tls) >>= fun res -> | `End -> process fd tls) >>= fun res ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >|= fun () -> Vmm_lwt.safe_close fd >|= fun () ->
res res
@ -150,17 +142,18 @@ let jump _ cacert cert priv_key port =
Lwt.catch Lwt.catch
(fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr)) (fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr))
(fun exn -> (fun exn ->
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () -> Vmm_lwt.safe_close fd >>= fun () ->
Lwt.fail exn) >>= fun t -> Lwt.fail exn) >>= fun t ->
Lwt.async (fun () -> Lwt.async (fun () ->
Lwt.catch Lwt.catch
(fun () -> handle ca t >|= function (fun () ->
| Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg) (handle ca t >|= function
| Ok () -> ()) | Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
| Ok () -> ()) >>= fun () ->
Vmm_tls_lwt.close (fst t))
(fun e -> (fun e ->
Logs.err (fun m -> m "error while handle() %s" Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;
(Printexc.to_string e)) ; Vmm_tls_lwt.close (fst t))) ;
Lwt.return_unit)) ;
loop ()) loop ())
(function (function
| Unix.Unix_error (e, f, _) -> | Unix.Unix_error (e, f, _) ->