revise vmmd_tls file descriptor handling

This commit is contained in:
Hannes Mehnert 2018-10-29 19:00:13 +01:00
parent c669be8e02
commit 824f5f3418

View file

@ -30,13 +30,10 @@ let client_auth ca tls addr =
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail e) >>= fun () ->
(match Tls_lwt.Unix.epoch tls with
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
| `Error ->
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail_with "error while getting epoch")
| `Error -> Lwt.fail_with "error while getting epoch")
let read fd tls =
(* now we busy read and process output *)
@ -64,9 +61,7 @@ let process fd tls =
let handle ca (tls, addr) =
client_auth ca tls addr >>= fun chain ->
match Vmm_tls.handle addr my_version chain with
| Error (`Msg m) ->
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail_with m
| Error (`Msg m) -> Lwt.fail_with m
| Ok (name, policies, cmd) ->
let sock, next = Vmm_commands.endpoint cmd in
connect (Vmm_core.socket_path sock) >>= fun fd ->
@ -98,7 +93,6 @@ let handle ca (tls, addr) =
header, `Failure msg
in
Vmm_tls_lwt.write_tls tls wire >>= fun _ ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >>= fun () ->
Lwt.fail_with msg
end
@ -110,14 +104,12 @@ let handle ca (tls, addr) =
in
Vmm_lwt.write_wire fd wire >>= function
| Error `Exception ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >>= fun () ->
Lwt.return (Error (`Msg "couldn't write"))
| Ok () ->
(match next with
| `Read -> read fd tls
| `End -> process fd tls) >>= fun res ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >|= fun () ->
res
@ -150,17 +142,18 @@ let jump _ cacert cert priv_key port =
Lwt.catch
(fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr))
(fun exn ->
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
Vmm_lwt.safe_close fd >>= fun () ->
Lwt.fail exn) >>= fun t ->
Lwt.async (fun () ->
Lwt.catch
(fun () -> handle ca t >|= function
(fun () ->
(handle ca t >|= function
| Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
| Ok () -> ())
| Ok () -> ()) >>= fun () ->
Vmm_tls_lwt.close (fst t))
(fun e ->
Logs.err (fun m -> m "error while handle() %s"
(Printexc.to_string e)) ;
Lwt.return_unit)) ;
Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;
Vmm_tls_lwt.close (fst t))) ;
loop ())
(function
| Unix.Unix_error (e, f, _) ->