safer and clearer error semantics for all processes, fixes #5

This commit is contained in:
Hannes Mehnert 2018-03-18 17:30:43 +00:00
parent 88012094f8
commit cfa7ccd1e0
8 changed files with 304 additions and 186 deletions

View File

@ -52,18 +52,14 @@ let process db hdr data =
| Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg)
let rec read_tls_write_cons db t =
Lwt.catch (fun () ->
Vmm_tls.read_tls t >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ;
read_tls_write_cons db t
| Ok (hdr, data) ->
process db hdr data ;
read_tls_write_cons db t)
(fun e ->
Logs.err (fun m -> m "exception reading TLS stream %s"
(Printexc.to_string e)) ;
Tls_lwt.Unix.close t)
Vmm_tls.read_tls t >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ;
read_tls_write_cons db t
| Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit
| Ok (hdr, data) ->
process db hdr data ;
read_tls_write_cons db t
let rec read_cons_write_tls db t =
Lwt.catch (fun () ->
@ -77,10 +73,14 @@ let rec read_cons_write_tls db t =
| Some cmd ->
let out = Vmm_wire.Client.cmd ?arg cmd !command my_version in
command := succ !command ;
Vmm_tls.write_tls t out >>= fun () ->
Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
read_cons_write_tls db t)
(fun _ -> Lwt.return_unit)
Vmm_tls.write_tls t out >>= function
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit
| Ok () ->
Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
read_cons_write_tls db t)
(fun e ->
Logs.err (fun m -> m "exception %s in read_cons_write_tls" (Printexc.to_string e)) ;
Lwt.return_unit)
let client cas host port cert priv_key db =
Nocrypto_entropy_lwt.initialize () >>= fun () ->

View File

@ -37,8 +37,11 @@ let read_console s name ring channel () =
(if String.Set.mem name !active then
Vmm_lwt.write_raw s (data my_version name t line)
else
Lwt.return_unit) >>= fun () ->
loop ()
Lwt.return (Ok ())) >>= function
| Ok () -> loop ()
| Error _ ->
Logs.err (fun m -> m "error reading console") ;
Lwt_io.close channel
in
loop ())
(fun e ->
@ -102,14 +105,20 @@ let history s name since =
let entries = Vmm_ring.read_history r since in
Logs.debug (fun m -> m "found %d history" (List.length entries)) ;
Lwt_list.iter_s (fun (i, v) ->
Vmm_lwt.write_raw s (data my_version name i v)) entries >|= fun () ->
Vmm_lwt.write_raw s (data my_version name i v) >|= fun _ -> ())
entries >|= fun () ->
Ok "success"
let handle s addr () =
Logs.info (fun m -> m "handling connection %a" pp_sockaddr addr) ;
let rec loop () =
Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop ()
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ;
loop ()
| Error _ ->
Logs.err (fun m -> m "exception while reading") ;
Lwt.return_unit
| Ok (hdr, data) ->
(if not (version_eq hdr.version my_version) then
Lwt.return (Error (`Msg "ignoring data with bad version"))
@ -138,10 +147,14 @@ let handle s addr () =
| Ok msg -> Vmm_lwt.write_raw s (success ~msg hdr.id my_version)
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while processing command: %s" msg) ;
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () ->
loop ()
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= function
| Ok () -> loop ()
| Error _ ->
Logs.err (fun m -> m "exception while writing to socket") ;
Lwt.return_unit
in
loop ()
loop () >>= fun () ->
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
let jump _ file =
Sys.(set_signal sigpipe Signal_ignore) ;

View File

@ -29,8 +29,13 @@ let write_complete s str =
in
w 0
let pp_sockaddr ppf = function
| Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str
| Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d"
(Unix.string_of_inet_addr addr) port
let handle fd ring s addr () =
Logs.info (fun m -> m "handling connection") ;
Logs.info (fun m -> m "handling connection from %a" pp_sockaddr addr) ;
let str = Fmt.strf "%a: CONNECT\n" (Ptime.pp_human ~tz_offset_s:0 ()) (Ptime_clock.now ()) in
write_complete fd str >>= fun () ->
let rec loop () =
@ -38,52 +43,79 @@ let handle fd ring s addr () =
| Error (`Msg e) ->
Logs.err (fun m -> m "error while reading %s" e) ;
loop ()
| Error _ ->
Logs.err (fun m -> m "exception while reading") ;
Lwt.return_unit
| Ok (hdr, data) ->
(if not (version_eq hdr.version my_version) then
Lwt.return (Error (`Msg "unknown version"))
else match int_to_op hdr.tag with
| Some Data ->
( match decode_ts data with
| Ok ts -> Vmm_ring.write ring (ts, data)
| Error _ -> ()) ;
write_complete fd data >>= fun () ->
Lwt.return (Ok None)
| Some History ->
begin match decode_str data with
| Error e -> Lwt.return (Error e)
| Ok (str, off) -> match decode_ts ~off data with
| Error e -> Lwt.return (Error e)
| Ok ts ->
let elements = Vmm_ring.read_history ring ts in
let res = List.fold_left (fun acc (_, x) ->
match Vmm_wire.Log.decode_log_hdr (Cstruct.of_string x) with
| Ok (hdr, _) ->
Logs.debug (fun m -> m "found an entry: %a" (Vmm_core.Log.pp_hdr []) hdr) ;
if String.equal str (Vmm_core.string_of_id hdr.Vmm_core.Log.context) then
x :: acc
else
acc
| _ -> acc)
[] elements
in
(* just need a wrapper in tag = Log.Data, id = reqid *)
Lwt_list.iter_s (fun x ->
let length = String.length x in
let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in
Vmm_lwt.write_raw s (Cstruct.to_string hdr ^ x))
(List.rev res) >>= fun () ->
Lwt.return (Ok None)
end
| _ ->
Logs.err (fun m -> m "didn't understand log command %d" hdr.tag) ;
Lwt.return (Error (`Msg "unknown command"))) >>= (function
| Ok msg -> Vmm_lwt.write_raw s (success ?msg hdr.id my_version)
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while processing: %s" msg) ;
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () ->
loop ()
let out =
(if not (version_eq hdr.version my_version) then
Error (`Msg "unknown version")
else match int_to_op hdr.tag with
| Some Data ->
(match decode_ts data with
| Ok ts -> Vmm_ring.write ring (ts, data)
| Error _ ->
Logs.warn (fun m -> m "ignoring error while decoding timestamp %s" data)) ;
Ok (`Data data)
| Some History ->
begin match decode_str data with
| Error e -> Error e
| Ok (str, off) -> match decode_ts ~off data with
| Error e -> Error e
| Ok ts ->
let elements = Vmm_ring.read_history ring ts in
let res = List.fold_left (fun acc (_, x) ->
match Vmm_wire.Log.decode_log_hdr (Cstruct.of_string x) with
| Ok (hdr, _) ->
Logs.debug (fun m -> m "found an entry: %a" (Vmm_core.Log.pp_hdr []) hdr) ;
if String.equal str (Vmm_core.string_of_id hdr.Vmm_core.Log.context) then
x :: acc
else
acc
| _ -> acc)
[] elements
in
(* just need a wrapper in tag = Log.Data, id = reqid *)
let out =
List.fold_left (fun acc x ->
let length = String.length x in
let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in
(Cstruct.to_string hdr ^ x) :: acc)
[] (List.rev res)
in
Ok (`Out out)
end
| _ ->
Error (`Msg "unknown command"))
in
match out with
| Error (`Msg msg) ->
begin
Logs.err (fun m -> m "error while processing: %s" msg) ;
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version) >>= function
| Error _ -> Logs.err (fun m -> m "error0 while writing") ; Lwt.return_unit
| Ok () -> loop ()
end
| Ok (`Data data) ->
begin
write_complete fd data >>= fun () ->
Vmm_lwt.write_raw s (success hdr.id my_version) >>= function
| Error _ -> Logs.err (fun m -> m "error1 while writing") ; Lwt.return_unit
| Ok () -> loop ()
end
| Ok (`Out datas) ->
Lwt_list.fold_left_s (fun r x -> match r with
| Error e -> Lwt.return (Error e)
| Ok () -> Vmm_lwt.write_raw s x)
(Ok ()) datas >>= function
| Error _ -> Logs.err (fun m -> m "error2 while writing") ; Lwt.return_unit
| Ok () ->
Vmm_lwt.write_raw s (success hdr.id my_version) >>= function
| Error _ -> Logs.err (fun m -> m "error3 while writing") ; Lwt.return_unit
| Ok () -> loop ()
in
Lwt.catch loop (fun e -> Lwt.return_unit)
loop () >>= fun () ->
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
let jump _ file sock =
Sys.(set_signal sigpipe Signal_ignore) ;

View File

@ -128,7 +128,9 @@ let process db tls hdr data =
let out = Vmm_wire.Client.cmd `Info !command my_version in
command := succ !command ;
Logs.debug (fun m -> m "writing %a over TLS" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
Vmm_tls.write_tls tls out
(Vmm_tls.write_tls tls out >|= function
| Ok () -> ()
| Error _ -> Logs.err (fun m -> m "error while writing") ; ())
| _ ->
let r =
match hdr.tag with
@ -176,21 +178,23 @@ let process db tls hdr data =
match r with
| Ok `None -> Lwt.return_unit
| Ok (`Sockaddr s) -> d s
| Ok (`Stat (fd, s, out)) -> Vmm_lwt.write_raw fd out >>= fun () -> d (fd, s)
| Ok (`Stat (fd, s, out)) ->
(Vmm_lwt.write_raw fd out >>= function
| Ok () -> d (fd, s)
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit)
| Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) ; Lwt.return_unit
let rec tls_listener db tls =
Lwt.catch (fun () ->
Vmm_tls.read_tls tls >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ;
Lwt.return (Ok ())
| Ok (hdr, data) ->
process db tls hdr data >>= fun () ->
Lwt.return (Ok ()))
(fun e ->
Logs.err (fun m -> m "received exception in read_tls: %s" (Printexc.to_string e)) ;
Lwt.return (Error ())) >>= function
(Vmm_tls.read_tls tls >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ;
Lwt.return (Ok ())
| Error _ ->
Logs.err (fun m -> m "received exception in read_tls") ;
Lwt.return (Error ())
| Ok (hdr, data) ->
process db tls hdr data >>= fun () ->
Lwt.return (Ok ())) >>= function
| Ok () -> tls_listener db tls
| Error () -> Lwt.return_unit
@ -203,24 +207,32 @@ let hdr =
(* wait for TCP connection, once received request stats from vmmd, and loop *)
let rec tcp_listener db tcp tls =
Lwt_unix.accept tcp >>= fun (cs, sockaddr) ->
Vmm_lwt.write_raw cs hdr >>= fun () ->
let l = List.length !known_vms in
let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in
Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ;
(if l = 0 then
Lwt_unix.close cs
else begin
count := SM.add sockaddr (List.length !known_vms) !count ;
Lwt_list.iter_s
(fun vm ->
let vm_id = translate_name db vm in
let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in
t := IM.add !command (cs, sockaddr, vm) !t ;
command := succ !command ;
Vmm_tls.write_tls tls out)
!known_vms
end) >>= fun () ->
tcp_listener db tcp tls
Vmm_lwt.write_raw cs hdr >>= function
| Error _ -> Logs.err (fun m -> m "exception while accepting") ; Lwt.return_unit
| Ok () ->
let l = List.length !known_vms in
let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in
Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ;
(if l = 0 then
Lwt_unix.close cs >|= fun () -> Error ()
else begin
count := SM.add sockaddr (List.length !known_vms) !count ;
Lwt_list.fold_left_s
(fun r vm ->
match r with
| Error () -> Lwt.return (Error ())
| Ok () ->
let vm_id = translate_name db vm in
let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in
t := IM.add !command (cs, sockaddr, vm) !t ;
command := succ !command ;
Vmm_tls.write_tls tls out >|= function
| Ok () -> Ok ()
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Error ())
(Ok ()) !known_vms
end) >>= function
| Ok () -> tcp_listener db tcp tls
| Error () -> Lwt.return_unit
let client cas host port cert priv_key db listen_ip listen_port =
Nocrypto_entropy_lwt.initialize () >>= fun () ->

View File

@ -2,14 +2,17 @@
open Lwt.Infix
let write_raw s data =
Vmm_lwt.write_raw s data >|= fun _ -> ()
let write_tls state t data =
Lwt.catch (fun () -> Vmm_tls.write_tls (fst t) data)
(fun e ->
let state', out = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) out >>= fun () ->
Tls_lwt.Unix.close (fst t) >>= fun () ->
raise e)
Vmm_tls.write_tls (fst t) data >>= function
| Ok () -> Lwt.return_unit
| Error `Exception ->
let state', out = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> write_raw s data) out >>= fun () ->
Tls_lwt.Unix.close (fst t)
let to_ipaddr (_, sa) = match sa with
| Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address"
@ -22,7 +25,7 @@ let pp_sockaddr ppf (_, sa) = match sa with
let process state xs =
Lwt_list.iter_s (function
| `Raw (s, str) -> Vmm_lwt.write_raw s str
| `Raw (s, str) -> write_raw s str
| `Tls (s, str) -> write_tls state s str)
xs
@ -73,19 +76,19 @@ let handle ca state t =
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading client %a error: %s" pp_sockaddr t msg) ;
loop ()
| Error _ ->
Logs.err (fun m -> m "disconnect from %a" pp_sockaddr t) ;
let state', cons = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> write_raw s data) cons >>= fun () ->
Tls_lwt.Unix.close (fst t)
| Ok (hdr, buf) ->
let state', out = Vmm_engine.handle_command !state t prefix perms hdr buf in
state := state' ;
process state out >>= fun () ->
loop ()
in
Lwt.catch loop
(fun e ->
let state', cons = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) cons >>= fun () ->
Tls_lwt.Unix.close (fst t) >>= fun () ->
raise e)
loop ()
| `Close socks ->
Logs.debug (fun m -> m "closing session with %d active ones" (List.length socks)) ;
Lwt_list.iter_s (fun (t, _) -> Tls_lwt.Unix.close t) socks >>= fun () ->
@ -105,18 +108,26 @@ let server_socket port =
listen s 10 ;
Lwt.return s
let init_exception () =
Lwt.async_exception_hook := (function
| Tls_lwt.Tls_failure a ->
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a))
| exn ->
Logs.err (fun m -> m "exception: %s" (Printexc.to_string exn)))
let init_sock dir name =
let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec c ;
let addr = Fpath.(dir / name + "sock") in
Lwt.catch (fun () ->
Lwt_unix.(connect c (ADDR_UNIX (Fpath.to_string addr))) >|= fun () -> Some c)
(fun e ->
Logs.warn (fun m -> m "error %s connecting to socket %a"
(Printexc.to_string e) Fpath.pp addr) ;
(Lwt.catch (fun () -> Lwt_unix.close c) (fun _ -> Lwt.return_unit)) >|= fun () ->
None)
let rec read_log state s =
Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading log error %s" msg) ;
read_log state s
| Error _ ->
Logs.err (fun m -> m "exception while reading log") ;
invalid_arg "log socket communication issue"
| Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_log !state hdr data in
state := state' ;
@ -128,6 +139,9 @@ let rec read_cons state s =
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading console error %s" msg) ;
read_cons state s
| Error _ ->
Logs.err (fun m -> m "exception while reading console socket") ;
invalid_arg "console socket communication issue"
| Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_cons !state hdr data in
state := state' ;
@ -139,6 +153,10 @@ let rec read_stats state s =
| Error (`Msg msg) ->
Logs.err (fun m -> m "reading stats error %s" msg) ;
read_stats state s
| Error _ ->
Logs.err (fun m -> m "exception while reading stats") ;
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) >|= fun () ->
state := { !state with Vmm_engine.stats_socket = None }
| Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_stat !state hdr data in
state := state' ;
@ -156,23 +174,15 @@ let cmp_s (_, a) (_, b) =
let jump _ dir cacert cert priv_key =
Sys.(set_signal sigpipe Signal_ignore) ;
let dir = Fpath.v dir in
Lwt_main.run
(init_exception () ;
let d = Fpath.v dir in
let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec c ;
Lwt_unix.(connect c (ADDR_UNIX Fpath.(to_string (d / "cons" + "sock")))) >>= fun () ->
Lwt.catch (fun () ->
let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec s ;
Lwt_unix.(connect s (ADDR_UNIX Fpath.(to_string (d / "stat" + "sock")))) >|= fun () ->
Some s)
(function
| Unix.Unix_error (Unix.ENOENT, _, _) -> Lwt.return None
| e -> Lwt.fail e) >>= fun s ->
let l = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec l ;
Lwt_unix.(connect l (ADDR_UNIX Fpath.(to_string (d / "log" + "sock")))) >>= fun () ->
((init_sock dir "cons" >|= function
| None -> invalid_arg "cannot connect to console socket"
| Some c -> c) >>= fun c ->
init_sock dir "stat" >>= fun s ->
(init_sock dir "log" >|= function
| None -> invalid_arg "cannot connect to log socket"
| Some l -> l) >>= fun l ->
server_socket 1025 >>= fun socket ->
X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert ->
X509_lwt.certs_of_pem cacert >>= (function
@ -182,7 +192,7 @@ let jump _ dir cacert cert priv_key =
Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2)
~reneg:true ~certificates:(`Single cert) ())
in
(match Vmm_engine.init d cmp_s c s l with
(match Vmm_engine.init dir cmp_s c s l with
| Ok s -> Lwt.return s
| Error (`Msg m) -> Lwt.fail_with m) >>= fun t ->
let state = ref t in
@ -200,7 +210,13 @@ let jump _ dir cacert cert priv_key =
(fun exn ->
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
Lwt.fail exn) >>= fun t ->
Lwt.async (fun () -> handle ca state t) ;
Lwt.async (fun () ->
Lwt.catch
(fun () -> handle ca state t)
(fun e ->
Logs.err (fun m -> m "error while handle() %s"
(Printexc.to_string e)) ;
Lwt.return_unit)) ;
loop ())
(function
| Unix.Unix_error (e, f, _) ->

View File

@ -39,35 +39,53 @@ let wait_and_clear pid stdout =
let read_exactly s =
let buf = Bytes.create 8 in
let rec r b i l =
Lwt_unix.read s b i l >>= function
| 0 -> Lwt.fail_with "end of file"
| n when n == l -> Lwt.return_unit
| n when n < l -> r b (i + n) (l - n)
| _ -> Lwt.fail_with "read too much"
Lwt.catch (fun () ->
Lwt_unix.read s b i l >>= function
| 0 ->
Logs.err (fun m -> m "end of file while reading") ;
Lwt.return (Error `Eof)
| n when n == l -> Lwt.return (Ok ())
| n when n < l -> r b (i + n) (l - n)
| _ ->
Logs.err (fun m -> m "read too much, shouldn't happen)") ;
Lwt.return (Error `Toomuch))
(fun e ->
let err = Printexc.to_string e in
Logs.err (fun m -> m "exception %s while reading" err) ;
Lwt.return (Error `Exception))
in
r buf 0 8 >>= fun () ->
match Vmm_wire.parse_header (Bytes.to_string buf) with
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
| Ok hdr ->
let l = hdr.Vmm_wire.length in
if l > 0 then
let b = Bytes.create l in
r b 0 l >|= fun () ->
Logs.debug (fun m -> m "read hdr %a, body %a"
Cstruct.hexdump_pp (Cstruct.of_bytes buf)
Cstruct.hexdump_pp (Cstruct.of_bytes b)) ;
Ok (hdr, Bytes.to_string b)
else
Lwt.return (Ok (hdr, ""))
r buf 0 8 >>= function
| Error e -> Lwt.return (Error e)
| Ok () ->
match Vmm_wire.parse_header (Bytes.to_string buf) with
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
| Ok hdr ->
let l = hdr.Vmm_wire.length in
if l > 0 then
let b = Bytes.create l in
r b 0 l >|= function
| Error e -> Error e
| Ok () ->
Logs.debug (fun m -> m "read hdr %a, body %a"
Cstruct.hexdump_pp (Cstruct.of_bytes buf)
Cstruct.hexdump_pp (Cstruct.of_bytes b)) ;
Ok (hdr, Bytes.to_string b)
else
Lwt.return (Ok (hdr, ""))
let write_raw s buf =
let buf = Bytes.unsafe_of_string buf in
let rec w off l =
Lwt_unix.send s buf off l [] >>= fun n ->
if n = l then
Lwt.return_unit
else
w (off + n) (l - n)
Lwt.catch (fun () ->
Lwt_unix.send s buf off l [] >>= fun n ->
if n = l then
Lwt.return (Ok ())
else
w (off + n) (l - n))
(fun e ->
Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string e)) ;
Lwt.return (Error `Exception))
in
Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ;
w 0 (Bytes.length buf)

View File

@ -6,30 +6,54 @@ let read_tls t =
let rec r_n buf off tot =
let l = tot - off in
if l = 0 then
Lwt.return_unit
Lwt.return (Ok ())
else
Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function
| 0 -> Lwt.fail_with "read 0 bytes"
| x when x == l -> Lwt.return_unit
| x when x < l -> r_n buf (off + x) tot
| _ -> Lwt.fail_with "overread, will never happen"
Lwt.catch (fun () ->
Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function
| 0 ->
Logs.err (fun m -> m "TLS: end of file") ;
Lwt.return (Error `Eof)
| x when x == l -> Lwt.return (Ok ())
| x when x < l -> r_n buf (off + x) tot
| _ ->
Logs.err (fun m -> m "TLS: read too much, shouldn't happen") ;
Lwt.return (Error `Toomuch))
(function
| Tls_lwt.Tls_failure a ->
Logs.err (fun m -> m "TLS read failure: %s" (Tls.Engine.string_of_failure a)) ;
Lwt.return (Error `Exception)
| e ->
Logs.err (fun m -> m "TLS read exception %s" (Printexc.to_string e)) ;
Lwt.return (Error `Exception))
in
let buf = Cstruct.create 8 in
r_n buf 0 8 >>= fun () ->
match Vmm_wire.parse_header (Cstruct.to_string buf) with
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
| Ok hdr ->
let l = hdr.Vmm_wire.length in
if l > 0 then
let b = Cstruct.create l in
r_n b 0 l >|= fun () ->
Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a"
hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag
Cstruct.hexdump_pp b) ;
Ok (hdr, Cstruct.to_string b)
else
Lwt.return (Ok (hdr, ""))
r_n buf 0 8 >>= function
| Error e -> Lwt.return (Error e)
| Ok () ->
match Vmm_wire.parse_header (Cstruct.to_string buf) with
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
| Ok hdr ->
let l = hdr.Vmm_wire.length in
if l > 0 then
let b = Cstruct.create l in
r_n b 0 l >|= function
| Error e -> Error e
| Ok () ->
Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a"
hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag
Cstruct.hexdump_pp b) ;
Ok (hdr, Cstruct.to_string b)
else
Lwt.return (Ok (hdr, ""))
let write_tls s buf =
Logs.debug (fun m -> m "TLS write %a" Cstruct.hexdump_pp (Cstruct.of_string buf)) ;
Tls_lwt.Unix.write s (Cstruct.of_string buf)
Lwt.catch
(fun () -> Tls_lwt.Unix.write s (Cstruct.of_string buf) >|= fun () -> Ok ())
(function
| Tls_lwt.Tls_failure a ->
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ;
Lwt.return (Error `Exception)
| e ->
Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ;
Lwt.return (Error `Exception))

View File

@ -26,15 +26,18 @@ let handle s addr () =
let rec loop () =
Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop ()
| Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit
| Ok (hdr, data) ->
Logs.debug (fun m -> m "received %a" Cstruct.hexdump_pp (Cstruct.of_string data)) ;
let t', out = Vmm_stats.handle !t hdr data in
t := t' ;
Logs.debug (fun m -> m "sent %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
Vmm_lwt.write_raw s out >>= fun () ->
loop ()
Vmm_lwt.write_raw s out >>= function
| Ok () -> loop ()
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit
in
loop ()
loop () >>= fun () ->
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
let rec timer () =
t := Vmm_stats.tick !t ;