revise vmmd_tls file descriptor handling
This commit is contained in:
parent
c669be8e02
commit
824f5f3418
|
@ -30,13 +30,10 @@ let client_auth ca tls addr =
|
||||||
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
|
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
|
||||||
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
|
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
|
||||||
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
|
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
|
||||||
Vmm_tls_lwt.close tls >>= fun () ->
|
|
||||||
Lwt.fail e) >>= fun () ->
|
Lwt.fail e) >>= fun () ->
|
||||||
(match Tls_lwt.Unix.epoch tls with
|
(match Tls_lwt.Unix.epoch tls with
|
||||||
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
|
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
|
||||||
| `Error ->
|
| `Error -> Lwt.fail_with "error while getting epoch")
|
||||||
Vmm_tls_lwt.close tls >>= fun () ->
|
|
||||||
Lwt.fail_with "error while getting epoch")
|
|
||||||
|
|
||||||
let read fd tls =
|
let read fd tls =
|
||||||
(* now we busy read and process output *)
|
(* now we busy read and process output *)
|
||||||
|
@ -64,9 +61,7 @@ let process fd tls =
|
||||||
let handle ca (tls, addr) =
|
let handle ca (tls, addr) =
|
||||||
client_auth ca tls addr >>= fun chain ->
|
client_auth ca tls addr >>= fun chain ->
|
||||||
match Vmm_tls.handle addr my_version chain with
|
match Vmm_tls.handle addr my_version chain with
|
||||||
| Error (`Msg m) ->
|
| Error (`Msg m) -> Lwt.fail_with m
|
||||||
Vmm_tls_lwt.close tls >>= fun () ->
|
|
||||||
Lwt.fail_with m
|
|
||||||
| Ok (name, policies, cmd) ->
|
| Ok (name, policies, cmd) ->
|
||||||
let sock, next = Vmm_commands.endpoint cmd in
|
let sock, next = Vmm_commands.endpoint cmd in
|
||||||
connect (Vmm_core.socket_path sock) >>= fun fd ->
|
connect (Vmm_core.socket_path sock) >>= fun fd ->
|
||||||
|
@ -98,7 +93,6 @@ let handle ca (tls, addr) =
|
||||||
header, `Failure msg
|
header, `Failure msg
|
||||||
in
|
in
|
||||||
Vmm_tls_lwt.write_tls tls wire >>= fun _ ->
|
Vmm_tls_lwt.write_tls tls wire >>= fun _ ->
|
||||||
Vmm_tls_lwt.close tls >>= fun () ->
|
|
||||||
Vmm_lwt.safe_close fd >>= fun () ->
|
Vmm_lwt.safe_close fd >>= fun () ->
|
||||||
Lwt.fail_with msg
|
Lwt.fail_with msg
|
||||||
end
|
end
|
||||||
|
@ -110,14 +104,12 @@ let handle ca (tls, addr) =
|
||||||
in
|
in
|
||||||
Vmm_lwt.write_wire fd wire >>= function
|
Vmm_lwt.write_wire fd wire >>= function
|
||||||
| Error `Exception ->
|
| Error `Exception ->
|
||||||
Vmm_tls_lwt.close tls >>= fun () ->
|
|
||||||
Vmm_lwt.safe_close fd >>= fun () ->
|
Vmm_lwt.safe_close fd >>= fun () ->
|
||||||
Lwt.return (Error (`Msg "couldn't write"))
|
Lwt.return (Error (`Msg "couldn't write"))
|
||||||
| Ok () ->
|
| Ok () ->
|
||||||
(match next with
|
(match next with
|
||||||
| `Read -> read fd tls
|
| `Read -> read fd tls
|
||||||
| `End -> process fd tls) >>= fun res ->
|
| `End -> process fd tls) >>= fun res ->
|
||||||
Vmm_tls_lwt.close tls >>= fun () ->
|
|
||||||
Vmm_lwt.safe_close fd >|= fun () ->
|
Vmm_lwt.safe_close fd >|= fun () ->
|
||||||
res
|
res
|
||||||
|
|
||||||
|
@ -150,17 +142,18 @@ let jump _ cacert cert priv_key port =
|
||||||
Lwt.catch
|
Lwt.catch
|
||||||
(fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr))
|
(fun () -> Tls_lwt.Unix.server_of_fd config fd >|= fun t -> (t, addr))
|
||||||
(fun exn ->
|
(fun exn ->
|
||||||
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
|
Vmm_lwt.safe_close fd >>= fun () ->
|
||||||
Lwt.fail exn) >>= fun t ->
|
Lwt.fail exn) >>= fun t ->
|
||||||
Lwt.async (fun () ->
|
Lwt.async (fun () ->
|
||||||
Lwt.catch
|
Lwt.catch
|
||||||
(fun () -> handle ca t >|= function
|
(fun () ->
|
||||||
|
(handle ca t >|= function
|
||||||
| Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
|
| Error (`Msg msg) -> Logs.err (fun m -> m "error in handle %s" msg)
|
||||||
| Ok () -> ())
|
| Ok () -> ()) >>= fun () ->
|
||||||
|
Vmm_tls_lwt.close (fst t))
|
||||||
(fun e ->
|
(fun e ->
|
||||||
Logs.err (fun m -> m "error while handle() %s"
|
Logs.err (fun m -> m "error while handle() %s" (Printexc.to_string e)) ;
|
||||||
(Printexc.to_string e)) ;
|
Vmm_tls_lwt.close (fst t))) ;
|
||||||
Lwt.return_unit)) ;
|
|
||||||
loop ())
|
loop ())
|
||||||
(function
|
(function
|
||||||
| Unix.Unix_error (e, f, _) ->
|
| Unix.Unix_error (e, f, _) ->
|
||||||
|
|
Loading…
Reference in a new issue