vmmd_tls: close sockets appropriately

This commit is contained in:
Hannes Mehnert 2018-10-28 19:19:38 +01:00
parent 5e921d7345
commit 296b7a9b01
3 changed files with 17 additions and 4 deletions

View file

@ -30,12 +30,12 @@ let client_auth ca tls addr =
| Tls_lwt.Tls_alert a -> Logs.err (fun m -> m "TLS ALERT %s" (Tls.Packet.alert_type_to_string a))
| Tls_lwt.Tls_failure f -> Logs.err (fun m -> m "TLS FAILURE %s" (Tls.Engine.string_of_failure f))
| exn -> Logs.err (fun m -> m "%s" (Printexc.to_string exn))) ;
Tls_lwt.Unix.close tls >>= fun () ->
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail e) >>= fun () ->
(match Tls_lwt.Unix.epoch tls with
| `Ok epoch -> Lwt.return epoch.Tls.Core.peer_certificate_chain
| `Error ->
Tls_lwt.Unix.close tls >>= fun () ->
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail_with "error while getting epoch")
let read fd tls =
@ -63,7 +63,9 @@ let process fd tls =
let handle ca (tls, addr) =
client_auth ca tls addr >>= fun chain ->
match Vmm_tls.handle addr my_version chain with
| Error (`Msg m) -> Lwt.fail_with m
| Error (`Msg m) ->
Vmm_tls_lwt.close tls >>= fun () ->
Lwt.fail_with m
| Ok (name, cmd) ->
let sock, next = Vmm_commands.endpoint cmd in
connect (Vmm_core.socket_path sock) >>= fun fd ->
@ -73,11 +75,15 @@ let handle ca (tls, addr) =
(header, `Command cmd)
in
Vmm_lwt.write_wire fd wire >>= function
| Error `Exception -> Lwt.return (Error (`Msg "couldn't write"))
| Error `Exception ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >>= fun () ->
Lwt.return (Error (`Msg "couldn't write"))
| Ok () ->
(match next with
| `Read -> read fd tls
| `End -> process fd tls) >>= fun res ->
Vmm_tls_lwt.close tls >>= fun () ->
Vmm_lwt.safe_close fd >|= fun () ->
res

View file

@ -62,3 +62,8 @@ let write_tls s wire =
| e ->
Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ;
Lwt.return (Error `Exception))
let close tls =
Lwt.catch
(fun () -> Tls_lwt.Unix.close tls)
(fun _ -> Lwt.return_unit)

View file

@ -5,3 +5,5 @@ val read_tls : Tls_lwt.Unix.t ->
val write_tls :
Tls_lwt.Unix.t -> Vmm_commands.wire -> (unit, [> `Exception ]) result Lwt.t
val close : Tls_lwt.Unix.t -> unit Lwt.t