diff --git a/app/vmm_client.ml b/app/vmm_client.ml index a6bc641..acd7cc7 100644 --- a/app/vmm_client.ml +++ b/app/vmm_client.ml @@ -52,18 +52,14 @@ let process db hdr data = | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) let rec read_tls_write_cons db t = - Lwt.catch (fun () -> - Vmm_tls.read_tls t >>= function - | Error (`Msg msg) -> - Logs.err (fun m -> m "error while reading %s" msg) ; - read_tls_write_cons db t - | Ok (hdr, data) -> - process db hdr data ; - read_tls_write_cons db t) - (fun e -> - Logs.err (fun m -> m "exception reading TLS stream %s" - (Printexc.to_string e)) ; - Tls_lwt.Unix.close t) + Vmm_tls.read_tls t >>= function + | Error (`Msg msg) -> + Logs.err (fun m -> m "error while reading %s" msg) ; + read_tls_write_cons db t + | Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit + | Ok (hdr, data) -> + process db hdr data ; + read_tls_write_cons db t let rec read_cons_write_tls db t = Lwt.catch (fun () -> @@ -77,10 +73,14 @@ let rec read_cons_write_tls db t = | Some cmd -> let out = Vmm_wire.Client.cmd ?arg cmd !command my_version in command := succ !command ; - Vmm_tls.write_tls t out >>= fun () -> - Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ; - read_cons_write_tls db t) - (fun _ -> Lwt.return_unit) + Vmm_tls.write_tls t out >>= function + | Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit + | Ok () -> + Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ; + read_cons_write_tls db t) + (fun e -> + Logs.err (fun m -> m "exception %s in read_cons_write_tls" (Printexc.to_string e)) ; + Lwt.return_unit) let client cas host port cert priv_key db = Nocrypto_entropy_lwt.initialize () >>= fun () -> diff --git a/app/vmm_console.ml b/app/vmm_console.ml index a0c3ff2..5f2baa4 100644 --- a/app/vmm_console.ml +++ b/app/vmm_console.ml @@ -37,8 +37,11 @@ let read_console s name ring channel () = (if String.Set.mem name !active then Vmm_lwt.write_raw s (data my_version name t line) else - Lwt.return_unit) >>= fun () -> - loop () + Lwt.return (Ok ())) >>= function + | Ok () -> loop () + | Error _ -> + Logs.err (fun m -> m "error reading console") ; + Lwt_io.close channel in loop ()) (fun e -> @@ -102,14 +105,20 @@ let history s name since = let entries = Vmm_ring.read_history r since in Logs.debug (fun m -> m "found %d history" (List.length entries)) ; Lwt_list.iter_s (fun (i, v) -> - Vmm_lwt.write_raw s (data my_version name i v)) entries >|= fun () -> + Vmm_lwt.write_raw s (data my_version name i v) >|= fun _ -> ()) + entries >|= fun () -> Ok "success" let handle s addr () = Logs.info (fun m -> m "handling connection %a" pp_sockaddr addr) ; let rec loop () = Vmm_lwt.read_exactly s >>= function - | Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop () + | Error (`Msg msg) -> + Logs.err (fun m -> m "error while reading %s" msg) ; + loop () + | Error _ -> + Logs.err (fun m -> m "exception while reading") ; + Lwt.return_unit | Ok (hdr, data) -> (if not (version_eq hdr.version my_version) then Lwt.return (Error (`Msg "ignoring data with bad version")) @@ -138,10 +147,14 @@ let handle s addr () = | Ok msg -> Vmm_lwt.write_raw s (success ~msg hdr.id my_version) | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing command: %s" msg) ; - Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () -> - loop () + Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= function + | Ok () -> loop () + | Error _ -> + Logs.err (fun m -> m "exception while writing to socket") ; + Lwt.return_unit in - loop () + loop () >>= fun () -> + Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) let jump _ file = Sys.(set_signal sigpipe Signal_ignore) ; diff --git a/app/vmm_log.ml b/app/vmm_log.ml index 46bd879..48583b1 100644 --- a/app/vmm_log.ml +++ b/app/vmm_log.ml @@ -29,8 +29,13 @@ let write_complete s str = in w 0 +let pp_sockaddr ppf = function + | Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str + | Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d" + (Unix.string_of_inet_addr addr) port + let handle fd ring s addr () = - Logs.info (fun m -> m "handling connection") ; + Logs.info (fun m -> m "handling connection from %a" pp_sockaddr addr) ; let str = Fmt.strf "%a: CONNECT\n" (Ptime.pp_human ~tz_offset_s:0 ()) (Ptime_clock.now ()) in write_complete fd str >>= fun () -> let rec loop () = @@ -38,52 +43,79 @@ let handle fd ring s addr () = | Error (`Msg e) -> Logs.err (fun m -> m "error while reading %s" e) ; loop () + | Error _ -> + Logs.err (fun m -> m "exception while reading") ; + Lwt.return_unit | Ok (hdr, data) -> - (if not (version_eq hdr.version my_version) then - Lwt.return (Error (`Msg "unknown version")) - else match int_to_op hdr.tag with - | Some Data -> - ( match decode_ts data with - | Ok ts -> Vmm_ring.write ring (ts, data) - | Error _ -> ()) ; - write_complete fd data >>= fun () -> - Lwt.return (Ok None) - | Some History -> - begin match decode_str data with - | Error e -> Lwt.return (Error e) - | Ok (str, off) -> match decode_ts ~off data with - | Error e -> Lwt.return (Error e) - | Ok ts -> - let elements = Vmm_ring.read_history ring ts in - let res = List.fold_left (fun acc (_, x) -> - match Vmm_wire.Log.decode_log_hdr (Cstruct.of_string x) with - | Ok (hdr, _) -> - Logs.debug (fun m -> m "found an entry: %a" (Vmm_core.Log.pp_hdr []) hdr) ; - if String.equal str (Vmm_core.string_of_id hdr.Vmm_core.Log.context) then - x :: acc - else - acc - | _ -> acc) - [] elements - in - (* just need a wrapper in tag = Log.Data, id = reqid *) - Lwt_list.iter_s (fun x -> - let length = String.length x in - let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in - Vmm_lwt.write_raw s (Cstruct.to_string hdr ^ x)) - (List.rev res) >>= fun () -> - Lwt.return (Ok None) - end - | _ -> - Logs.err (fun m -> m "didn't understand log command %d" hdr.tag) ; - Lwt.return (Error (`Msg "unknown command"))) >>= (function - | Ok msg -> Vmm_lwt.write_raw s (success ?msg hdr.id my_version) - | Error (`Msg msg) -> - Logs.err (fun m -> m "error while processing: %s" msg) ; - Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () -> - loop () + let out = + (if not (version_eq hdr.version my_version) then + Error (`Msg "unknown version") + else match int_to_op hdr.tag with + | Some Data -> + (match decode_ts data with + | Ok ts -> Vmm_ring.write ring (ts, data) + | Error _ -> + Logs.warn (fun m -> m "ignoring error while decoding timestamp %s" data)) ; + Ok (`Data data) + | Some History -> + begin match decode_str data with + | Error e -> Error e + | Ok (str, off) -> match decode_ts ~off data with + | Error e -> Error e + | Ok ts -> + let elements = Vmm_ring.read_history ring ts in + let res = List.fold_left (fun acc (_, x) -> + match Vmm_wire.Log.decode_log_hdr (Cstruct.of_string x) with + | Ok (hdr, _) -> + Logs.debug (fun m -> m "found an entry: %a" (Vmm_core.Log.pp_hdr []) hdr) ; + if String.equal str (Vmm_core.string_of_id hdr.Vmm_core.Log.context) then + x :: acc + else + acc + | _ -> acc) + [] elements + in + (* just need a wrapper in tag = Log.Data, id = reqid *) + let out = + List.fold_left (fun acc x -> + let length = String.length x in + let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in + (Cstruct.to_string hdr ^ x) :: acc) + [] (List.rev res) + in + Ok (`Out out) + end + | _ -> + Error (`Msg "unknown command")) + in + match out with + | Error (`Msg msg) -> + begin + Logs.err (fun m -> m "error while processing: %s" msg) ; + Vmm_lwt.write_raw s (fail ~msg hdr.id my_version) >>= function + | Error _ -> Logs.err (fun m -> m "error0 while writing") ; Lwt.return_unit + | Ok () -> loop () + end + | Ok (`Data data) -> + begin + write_complete fd data >>= fun () -> + Vmm_lwt.write_raw s (success hdr.id my_version) >>= function + | Error _ -> Logs.err (fun m -> m "error1 while writing") ; Lwt.return_unit + | Ok () -> loop () + end + | Ok (`Out datas) -> + Lwt_list.fold_left_s (fun r x -> match r with + | Error e -> Lwt.return (Error e) + | Ok () -> Vmm_lwt.write_raw s x) + (Ok ()) datas >>= function + | Error _ -> Logs.err (fun m -> m "error2 while writing") ; Lwt.return_unit + | Ok () -> + Vmm_lwt.write_raw s (success hdr.id my_version) >>= function + | Error _ -> Logs.err (fun m -> m "error3 while writing") ; Lwt.return_unit + | Ok () -> loop () in - Lwt.catch loop (fun e -> Lwt.return_unit) + loop () >>= fun () -> + Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) let jump _ file sock = Sys.(set_signal sigpipe Signal_ignore) ; diff --git a/app/vmm_prometheus_stats.ml b/app/vmm_prometheus_stats.ml index a95f01d..113af0f 100644 --- a/app/vmm_prometheus_stats.ml +++ b/app/vmm_prometheus_stats.ml @@ -128,7 +128,9 @@ let process db tls hdr data = let out = Vmm_wire.Client.cmd `Info !command my_version in command := succ !command ; Logs.debug (fun m -> m "writing %a over TLS" Cstruct.hexdump_pp (Cstruct.of_string out)) ; - Vmm_tls.write_tls tls out + (Vmm_tls.write_tls tls out >|= function + | Ok () -> () + | Error _ -> Logs.err (fun m -> m "error while writing") ; ()) | _ -> let r = match hdr.tag with @@ -176,21 +178,23 @@ let process db tls hdr data = match r with | Ok `None -> Lwt.return_unit | Ok (`Sockaddr s) -> d s - | Ok (`Stat (fd, s, out)) -> Vmm_lwt.write_raw fd out >>= fun () -> d (fd, s) + | Ok (`Stat (fd, s, out)) -> + (Vmm_lwt.write_raw fd out >>= function + | Ok () -> d (fd, s) + | Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit) | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) ; Lwt.return_unit let rec tls_listener db tls = - Lwt.catch (fun () -> - Vmm_tls.read_tls tls >>= function - | Error (`Msg msg) -> - Logs.err (fun m -> m "error while reading %s" msg) ; - Lwt.return (Ok ()) - | Ok (hdr, data) -> - process db tls hdr data >>= fun () -> - Lwt.return (Ok ())) - (fun e -> - Logs.err (fun m -> m "received exception in read_tls: %s" (Printexc.to_string e)) ; - Lwt.return (Error ())) >>= function + (Vmm_tls.read_tls tls >>= function + | Error (`Msg msg) -> + Logs.err (fun m -> m "error while reading %s" msg) ; + Lwt.return (Ok ()) + | Error _ -> + Logs.err (fun m -> m "received exception in read_tls") ; + Lwt.return (Error ()) + | Ok (hdr, data) -> + process db tls hdr data >>= fun () -> + Lwt.return (Ok ())) >>= function | Ok () -> tls_listener db tls | Error () -> Lwt.return_unit @@ -203,24 +207,32 @@ let hdr = (* wait for TCP connection, once received request stats from vmmd, and loop *) let rec tcp_listener db tcp tls = Lwt_unix.accept tcp >>= fun (cs, sockaddr) -> - Vmm_lwt.write_raw cs hdr >>= fun () -> - let l = List.length !known_vms in - let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in - Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ; - (if l = 0 then - Lwt_unix.close cs - else begin - count := SM.add sockaddr (List.length !known_vms) !count ; - Lwt_list.iter_s - (fun vm -> - let vm_id = translate_name db vm in - let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in - t := IM.add !command (cs, sockaddr, vm) !t ; - command := succ !command ; - Vmm_tls.write_tls tls out) - !known_vms - end) >>= fun () -> - tcp_listener db tcp tls + Vmm_lwt.write_raw cs hdr >>= function + | Error _ -> Logs.err (fun m -> m "exception while accepting") ; Lwt.return_unit + | Ok () -> + let l = List.length !known_vms in + let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in + Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ; + (if l = 0 then + Lwt_unix.close cs >|= fun () -> Error () + else begin + count := SM.add sockaddr (List.length !known_vms) !count ; + Lwt_list.fold_left_s + (fun r vm -> + match r with + | Error () -> Lwt.return (Error ()) + | Ok () -> + let vm_id = translate_name db vm in + let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in + t := IM.add !command (cs, sockaddr, vm) !t ; + command := succ !command ; + Vmm_tls.write_tls tls out >|= function + | Ok () -> Ok () + | Error _ -> Logs.err (fun m -> m "exception while writing") ; Error ()) + (Ok ()) !known_vms + end) >>= function + | Ok () -> tcp_listener db tcp tls + | Error () -> Lwt.return_unit let client cas host port cert priv_key db listen_ip listen_port = Nocrypto_entropy_lwt.initialize () >>= fun () -> diff --git a/app/vmmd.ml b/app/vmmd.ml index f048071..3c07261 100644 --- a/app/vmmd.ml +++ b/app/vmmd.ml @@ -2,14 +2,17 @@ open Lwt.Infix +let write_raw s data = + Vmm_lwt.write_raw s data >|= fun _ -> () + let write_tls state t data = - Lwt.catch (fun () -> Vmm_tls.write_tls (fst t) data) - (fun e -> - let state', out = Vmm_engine.handle_disconnect !state t in - state := state' ; - Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) out >>= fun () -> - Tls_lwt.Unix.close (fst t) >>= fun () -> - raise e) + Vmm_tls.write_tls (fst t) data >>= function + | Ok () -> Lwt.return_unit + | Error `Exception -> + let state', out = Vmm_engine.handle_disconnect !state t in + state := state' ; + Lwt_list.iter_s (fun (s, data) -> write_raw s data) out >>= fun () -> + Tls_lwt.Unix.close (fst t) let to_ipaddr (_, sa) = match sa with | Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address" @@ -22,7 +25,7 @@ let pp_sockaddr ppf (_, sa) = match sa with let process state xs = Lwt_list.iter_s (function - | `Raw (s, str) -> Vmm_lwt.write_raw s str + | `Raw (s, str) -> write_raw s str | `Tls (s, str) -> write_tls state s str) xs @@ -73,19 +76,19 @@ let handle ca state t = | Error (`Msg msg) -> Logs.err (fun m -> m "reading client %a error: %s" pp_sockaddr t msg) ; loop () + | Error _ -> + Logs.err (fun m -> m "disconnect from %a" pp_sockaddr t) ; + let state', cons = Vmm_engine.handle_disconnect !state t in + state := state' ; + Lwt_list.iter_s (fun (s, data) -> write_raw s data) cons >>= fun () -> + Tls_lwt.Unix.close (fst t) | Ok (hdr, buf) -> let state', out = Vmm_engine.handle_command !state t prefix perms hdr buf in state := state' ; process state out >>= fun () -> loop () in - Lwt.catch loop - (fun e -> - let state', cons = Vmm_engine.handle_disconnect !state t in - state := state' ; - Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) cons >>= fun () -> - Tls_lwt.Unix.close (fst t) >>= fun () -> - raise e) + loop () | `Close socks -> Logs.debug (fun m -> m "closing session with %d active ones" (List.length socks)) ; Lwt_list.iter_s (fun (t, _) -> Tls_lwt.Unix.close t) socks >>= fun () -> @@ -105,18 +108,26 @@ let server_socket port = listen s 10 ; Lwt.return s -let init_exception () = - Lwt.async_exception_hook := (function - | Tls_lwt.Tls_failure a -> - Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) - | exn -> - Logs.err (fun m -> m "exception: %s" (Printexc.to_string exn))) +let init_sock dir name = + let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in + Lwt_unix.set_close_on_exec c ; + let addr = Fpath.(dir / name + "sock") in + Lwt.catch (fun () -> + Lwt_unix.(connect c (ADDR_UNIX (Fpath.to_string addr))) >|= fun () -> Some c) + (fun e -> + Logs.warn (fun m -> m "error %s connecting to socket %a" + (Printexc.to_string e) Fpath.pp addr) ; + (Lwt.catch (fun () -> Lwt_unix.close c) (fun _ -> Lwt.return_unit)) >|= fun () -> + None) let rec read_log state s = Vmm_lwt.read_exactly s >>= function | Error (`Msg msg) -> Logs.err (fun m -> m "reading log error %s" msg) ; read_log state s + | Error _ -> + Logs.err (fun m -> m "exception while reading log") ; + invalid_arg "log socket communication issue" | Ok (hdr, data) -> let state', outs = Vmm_engine.handle_log !state hdr data in state := state' ; @@ -128,6 +139,9 @@ let rec read_cons state s = | Error (`Msg msg) -> Logs.err (fun m -> m "reading console error %s" msg) ; read_cons state s + | Error _ -> + Logs.err (fun m -> m "exception while reading console socket") ; + invalid_arg "console socket communication issue" | Ok (hdr, data) -> let state', outs = Vmm_engine.handle_cons !state hdr data in state := state' ; @@ -139,6 +153,10 @@ let rec read_stats state s = | Error (`Msg msg) -> Logs.err (fun m -> m "reading stats error %s" msg) ; read_stats state s + | Error _ -> + Logs.err (fun m -> m "exception while reading stats") ; + Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) >|= fun () -> + state := { !state with Vmm_engine.stats_socket = None } | Ok (hdr, data) -> let state', outs = Vmm_engine.handle_stat !state hdr data in state := state' ; @@ -156,23 +174,15 @@ let cmp_s (_, a) (_, b) = let jump _ dir cacert cert priv_key = Sys.(set_signal sigpipe Signal_ignore) ; + let dir = Fpath.v dir in Lwt_main.run - (init_exception () ; - let d = Fpath.v dir in - let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in - Lwt_unix.set_close_on_exec c ; - Lwt_unix.(connect c (ADDR_UNIX Fpath.(to_string (d / "cons" + "sock")))) >>= fun () -> - Lwt.catch (fun () -> - let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in - Lwt_unix.set_close_on_exec s ; - Lwt_unix.(connect s (ADDR_UNIX Fpath.(to_string (d / "stat" + "sock")))) >|= fun () -> - Some s) - (function - | Unix.Unix_error (Unix.ENOENT, _, _) -> Lwt.return None - | e -> Lwt.fail e) >>= fun s -> - let l = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in - Lwt_unix.set_close_on_exec l ; - Lwt_unix.(connect l (ADDR_UNIX Fpath.(to_string (d / "log" + "sock")))) >>= fun () -> + ((init_sock dir "cons" >|= function + | None -> invalid_arg "cannot connect to console socket" + | Some c -> c) >>= fun c -> + init_sock dir "stat" >>= fun s -> + (init_sock dir "log" >|= function + | None -> invalid_arg "cannot connect to log socket" + | Some l -> l) >>= fun l -> server_socket 1025 >>= fun socket -> X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert -> X509_lwt.certs_of_pem cacert >>= (function @@ -182,7 +192,7 @@ let jump _ dir cacert cert priv_key = Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2) ~reneg:true ~certificates:(`Single cert) ()) in - (match Vmm_engine.init d cmp_s c s l with + (match Vmm_engine.init dir cmp_s c s l with | Ok s -> Lwt.return s | Error (`Msg m) -> Lwt.fail_with m) >>= fun t -> let state = ref t in @@ -200,7 +210,13 @@ let jump _ dir cacert cert priv_key = (fun exn -> Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () -> Lwt.fail exn) >>= fun t -> - Lwt.async (fun () -> handle ca state t) ; + Lwt.async (fun () -> + Lwt.catch + (fun () -> handle ca state t) + (fun e -> + Logs.err (fun m -> m "error while handle() %s" + (Printexc.to_string e)) ; + Lwt.return_unit)) ; loop ()) (function | Unix.Unix_error (e, f, _) -> diff --git a/src/vmm_lwt.ml b/src/vmm_lwt.ml index 65156d6..870f2a6 100644 --- a/src/vmm_lwt.ml +++ b/src/vmm_lwt.ml @@ -39,35 +39,53 @@ let wait_and_clear pid stdout = let read_exactly s = let buf = Bytes.create 8 in let rec r b i l = - Lwt_unix.read s b i l >>= function - | 0 -> Lwt.fail_with "end of file" - | n when n == l -> Lwt.return_unit - | n when n < l -> r b (i + n) (l - n) - | _ -> Lwt.fail_with "read too much" + Lwt.catch (fun () -> + Lwt_unix.read s b i l >>= function + | 0 -> + Logs.err (fun m -> m "end of file while reading") ; + Lwt.return (Error `Eof) + | n when n == l -> Lwt.return (Ok ()) + | n when n < l -> r b (i + n) (l - n) + | _ -> + Logs.err (fun m -> m "read too much, shouldn't happen)") ; + Lwt.return (Error `Toomuch)) + (fun e -> + let err = Printexc.to_string e in + Logs.err (fun m -> m "exception %s while reading" err) ; + Lwt.return (Error `Exception)) + in - r buf 0 8 >>= fun () -> - match Vmm_wire.parse_header (Bytes.to_string buf) with - | Error (`Msg m) -> Lwt.return (Error (`Msg m)) - | Ok hdr -> - let l = hdr.Vmm_wire.length in - if l > 0 then - let b = Bytes.create l in - r b 0 l >|= fun () -> - Logs.debug (fun m -> m "read hdr %a, body %a" - Cstruct.hexdump_pp (Cstruct.of_bytes buf) - Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; - Ok (hdr, Bytes.to_string b) - else - Lwt.return (Ok (hdr, "")) + r buf 0 8 >>= function + | Error e -> Lwt.return (Error e) + | Ok () -> + match Vmm_wire.parse_header (Bytes.to_string buf) with + | Error (`Msg m) -> Lwt.return (Error (`Msg m)) + | Ok hdr -> + let l = hdr.Vmm_wire.length in + if l > 0 then + let b = Bytes.create l in + r b 0 l >|= function + | Error e -> Error e + | Ok () -> + Logs.debug (fun m -> m "read hdr %a, body %a" + Cstruct.hexdump_pp (Cstruct.of_bytes buf) + Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; + Ok (hdr, Bytes.to_string b) + else + Lwt.return (Ok (hdr, "")) let write_raw s buf = let buf = Bytes.unsafe_of_string buf in let rec w off l = - Lwt_unix.send s buf off l [] >>= fun n -> - if n = l then - Lwt.return_unit - else - w (off + n) (l - n) + Lwt.catch (fun () -> + Lwt_unix.send s buf off l [] >>= fun n -> + if n = l then + Lwt.return (Ok ()) + else + w (off + n) (l - n)) + (fun e -> + Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string e)) ; + Lwt.return (Error `Exception)) in Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ; w 0 (Bytes.length buf) diff --git a/src/vmm_tls.ml b/src/vmm_tls.ml index 656d45f..f8f8989 100644 --- a/src/vmm_tls.ml +++ b/src/vmm_tls.ml @@ -6,30 +6,54 @@ let read_tls t = let rec r_n buf off tot = let l = tot - off in if l = 0 then - Lwt.return_unit + Lwt.return (Ok ()) else - Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function - | 0 -> Lwt.fail_with "read 0 bytes" - | x when x == l -> Lwt.return_unit - | x when x < l -> r_n buf (off + x) tot - | _ -> Lwt.fail_with "overread, will never happen" + Lwt.catch (fun () -> + Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function + | 0 -> + Logs.err (fun m -> m "TLS: end of file") ; + Lwt.return (Error `Eof) + | x when x == l -> Lwt.return (Ok ()) + | x when x < l -> r_n buf (off + x) tot + | _ -> + Logs.err (fun m -> m "TLS: read too much, shouldn't happen") ; + Lwt.return (Error `Toomuch)) + (function + | Tls_lwt.Tls_failure a -> + Logs.err (fun m -> m "TLS read failure: %s" (Tls.Engine.string_of_failure a)) ; + Lwt.return (Error `Exception) + | e -> + Logs.err (fun m -> m "TLS read exception %s" (Printexc.to_string e)) ; + Lwt.return (Error `Exception)) in let buf = Cstruct.create 8 in - r_n buf 0 8 >>= fun () -> - match Vmm_wire.parse_header (Cstruct.to_string buf) with - | Error (`Msg m) -> Lwt.return (Error (`Msg m)) - | Ok hdr -> - let l = hdr.Vmm_wire.length in - if l > 0 then - let b = Cstruct.create l in - r_n b 0 l >|= fun () -> - Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a" - hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag - Cstruct.hexdump_pp b) ; - Ok (hdr, Cstruct.to_string b) - else - Lwt.return (Ok (hdr, "")) + r_n buf 0 8 >>= function + | Error e -> Lwt.return (Error e) + | Ok () -> + match Vmm_wire.parse_header (Cstruct.to_string buf) with + | Error (`Msg m) -> Lwt.return (Error (`Msg m)) + | Ok hdr -> + let l = hdr.Vmm_wire.length in + if l > 0 then + let b = Cstruct.create l in + r_n b 0 l >|= function + | Error e -> Error e + | Ok () -> + Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a" + hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag + Cstruct.hexdump_pp b) ; + Ok (hdr, Cstruct.to_string b) + else + Lwt.return (Ok (hdr, "")) let write_tls s buf = Logs.debug (fun m -> m "TLS write %a" Cstruct.hexdump_pp (Cstruct.of_string buf)) ; - Tls_lwt.Unix.write s (Cstruct.of_string buf) + Lwt.catch + (fun () -> Tls_lwt.Unix.write s (Cstruct.of_string buf) >|= fun () -> Ok ()) + (function + | Tls_lwt.Tls_failure a -> + Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ; + Lwt.return (Error `Exception) + | e -> + Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ; + Lwt.return (Error `Exception)) diff --git a/stats/vmm_stats_lwt.ml b/stats/vmm_stats_lwt.ml index a29e041..f79f8a4 100644 --- a/stats/vmm_stats_lwt.ml +++ b/stats/vmm_stats_lwt.ml @@ -26,15 +26,18 @@ let handle s addr () = let rec loop () = Vmm_lwt.read_exactly s >>= function | Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop () + | Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit | Ok (hdr, data) -> Logs.debug (fun m -> m "received %a" Cstruct.hexdump_pp (Cstruct.of_string data)) ; let t', out = Vmm_stats.handle !t hdr data in t := t' ; Logs.debug (fun m -> m "sent %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ; - Vmm_lwt.write_raw s out >>= fun () -> - loop () + Vmm_lwt.write_raw s out >>= function + | Ok () -> loop () + | Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit in - loop () + loop () >>= fun () -> + Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) let rec timer () = t := Vmm_stats.tick !t ;