vmmd_tls: close sockets appropriately
This commit is contained in:
parent
5e921d7345
commit
296b7a9b01
|
@ -30,12 +30,12 @@ let client_auth ca tls addr =
|
|||
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
|
||||
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
|
||||
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
|
||||
Tls_lwt.Unix.close tls >>= fun () ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Lwt.fail e) >>= fun () ->
|
||||
(match Tls_lwt.Unix.epoch tls with
|
||||
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
|
||||
| `Error ->
|
||||
Tls_lwt.Unix.close tls >>= fun () ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Lwt.fail_with "error while getting epoch")
|
||||
|
||||
let read fd tls =
|
||||
|
@ -63,7 +63,9 @@ let process fd tls =
|
|||
let handle ca (tls, addr) =
|
||||
client_auth ca tls addr >>= fun chain ->
|
||||
match Vmm_tls.handle addr my_version chain with
|
||||
| Error (`Msg m) -> Lwt.fail_with m
|
||||
| Error (`Msg m) ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Lwt.fail_with m
|
||||
| Ok (name, cmd) ->
|
||||
let sock, next = Vmm_commands.endpoint cmd in
|
||||
connect (Vmm_core.socket_path sock) >>= fun fd ->
|
||||
|
@ -73,11 +75,15 @@ let handle ca (tls, addr) =
|
|||
(header, `Command cmd)
|
||||
in
|
||||
Vmm_lwt.write_wire fd wire >>= function
|
||||
| Error `Exception -> Lwt.return (Error (`Msg "couldn't write"))
|
||||
| Error `Exception ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Vmm_lwt.safe_close fd >>= fun () ->
|
||||
Lwt.return (Error (`Msg "couldn't write"))
|
||||
| Ok () ->
|
||||
(match next with
|
||||
| `Read -> read fd tls
|
||||
| `End -> process fd tls) >>= fun res ->
|
||||
Vmm_tls_lwt.close tls >>= fun () ->
|
||||
Vmm_lwt.safe_close fd >|= fun () ->
|
||||
res
|
||||
|
||||
|
|
|
@ -62,3 +62,8 @@ let write_tls s wire =
|
|||
| e ->
|
||||
Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ;
|
||||
Lwt.return (Error `Exception))
|
||||
|
||||
let close tls =
|
||||
Lwt.catch
|
||||
(fun () -> Tls_lwt.Unix.close tls)
|
||||
(fun _ -> Lwt.return_unit)
|
||||
|
|
|
@ -5,3 +5,5 @@ val read_tls : Tls_lwt.Unix.t ->
|
|||
|
||||
val write_tls :
|
||||
Tls_lwt.Unix.t -> Vmm_commands.wire -> (unit, [> `Exception ]) result Lwt.t
|
||||
|
||||
val close : Tls_lwt.Unix.t -> unit Lwt.t
|
||||
|
|
Loading…
Reference in a new issue