revise vmmd_tls file descriptor handling
This commit is contained in:
parent
c669be8e02
commit
824f5f3418
|
@ -30,13 +30,10 @@ let client_auth ca tls addr =
|
|||
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
|
||||
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
|
||||
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Lwt.fail e) >>= fun () ->
|
||||
(match Tls_lwt.Unix.epoch tls with
|
||||
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
|
||||
| `Error ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Lwt.fail_with "error while getting epoch")
|
||||
| `Error -> Lwt.fail_with "error while getting epoch")
|
||||
|
||||
let read fd tls =
|
||||
(* now we busy read and process output *)
|
||||
|
@ -64,9 +61,7 @@ let process fd tls =
|
|||
let handle ca (tls, addr) =
|
||||
client_auth ca tls addr >>= fun chain ->
|
||||
match Vmm_tls.handle addr my_version chain with
|
||||
| Error (`Msg m) ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Lwt.fail_with m
|
||||
| Error (`Msg m) -> Lwt.fail_with m
|
||||
| Ok (name, policies, cmd) ->
|
||||
let sock, next = Vmm_commands.endpoint cmd in
|
||||
connect (Vmm_core.socket_path sock) >>= fun fd ->
|
||||
|
@ -98,7 +93,6 @@ let handle ca (tls, addr) =
|
|||
header, `Failure msg
|
||||
in
|
||||
Vmm_tls_lwt.write_tls tls wire >>= fun _ ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Vmm_lwt.safe_close fd >>= fun () ->
|
||||
Lwt.fail_with msg
|
||||
end
|
||||
|
@ -110,14 +104,12 @@ let handle ca (tls, addr) =
|
|||
in
|
||||
Vmm_lwt.write_wire fd wire >>= function
|
||||
| Error `Exception ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Vmm_lwt.safe_close fd >>= fun () ->
|
||||
Lwt.return (Error (`Msg "couldn't write"))
|
||||
| Ok () ->
|
||||
(match next with
|
||||
| `Read -> read fd tls
|
||||
| `End -> process fd tls) >>= fun res ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Vmm_lwt.safe_close fd >|= fun () ->
|
||||
res
|
||||
|
||||
|
@ -150,17 +142,18 @@ let jump _ cacert cert priv_key port =
|
|||
Lwt.catch
|
||||
(fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr))
|
||||
(fun exn ->
|
||||
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
|
||||
Vmm_lwt.safe_close fd >>= fun () ->
|
||||
Lwt.fail exn) >>= fun t ->
|
||||
Lwt.async (fun () ->
|
||||
Lwt.catch
|
||||
(fun () -> handle ca t >|= function
|
||||
(fun () ->
|
||||
(handle ca t >|= function
|
||||
| Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
|
||||
| Ok () -> ())
|
||||
| Ok () -> ()) >>= fun () ->
|
||||
Vmm_tls_lwt.close (fst t))
|
||||
(fun e ->
|
||||
Logs.err (fun m -> m "error while handle() %s"
|
||||
(Printexc.to_string e)) ;
|
||||
Lwt.return_unit)) ;
|
||||
Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;
|
||||
Vmm_tls_lwt.close (fst t))) ;
|
||||
loop ())
|
||||
(function
|
||||
| Unix.Unix_error (e, f, _) ->
|
||||
|
|
Loading…
Reference in a new issue