safer and clearer error semantics for all processes, fixes #5

This commit is contained in:
Hannes Mehnert 2018-03-18 17:30:43 +00:00
parent 88012094f8
commit cfa7ccd1e0
8 changed files with 304 additions and 186 deletions

View file

@ -52,18 +52,14 @@ let process db hdr data =
| Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg)
let rec read_tls_write_cons db t = let rec read_tls_write_cons db t =
Lwt.catch (fun () ->
Vmm_tls.read_tls t >>= function Vmm_tls.read_tls t >>= function
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ; Logs.err (fun m -> m "error while reading %s" msg) ;
read_tls_write_cons db t read_tls_write_cons db t
| Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit
| Ok (hdr, data) -> | Ok (hdr, data) ->
process db hdr data ; process db hdr data ;
read_tls_write_cons db t) read_tls_write_cons db t
(fun e ->
Logs.err (fun m -> m "exception reading TLS stream %s"
(Printexc.to_string e)) ;
Tls_lwt.Unix.close t)
let rec read_cons_write_tls db t = let rec read_cons_write_tls db t =
Lwt.catch (fun () -> Lwt.catch (fun () ->
@ -77,10 +73,14 @@ let rec read_cons_write_tls db t =
| Some cmd -> | Some cmd ->
let out = Vmm_wire.Client.cmd ?arg cmd !command my_version in let out = Vmm_wire.Client.cmd ?arg cmd !command my_version in
command := succ !command ; command := succ !command ;
Vmm_tls.write_tls t out >>= fun () -> Vmm_tls.write_tls t out >>= function
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit
| Ok () ->
Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ; Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
read_cons_write_tls db t) read_cons_write_tls db t)
(fun _ -> Lwt.return_unit) (fun e ->
Logs.err (fun m -> m "exception %s in read_cons_write_tls" (Printexc.to_string e)) ;
Lwt.return_unit)
let client cas host port cert priv_key db = let client cas host port cert priv_key db =
Nocrypto_entropy_lwt.initialize () >>= fun () -> Nocrypto_entropy_lwt.initialize () >>= fun () ->

View file

@ -37,8 +37,11 @@ let read_console s name ring channel () =
(if String.Set.mem name !active then (if String.Set.mem name !active then
Vmm_lwt.write_raw s (data my_version name t line) Vmm_lwt.write_raw s (data my_version name t line)
else else
Lwt.return_unit) >>= fun () -> Lwt.return (Ok ())) >>= function
loop () | Ok () -> loop ()
| Error _ ->
Logs.err (fun m -> m "error reading console") ;
Lwt_io.close channel
in in
loop ()) loop ())
(fun e -> (fun e ->
@ -102,14 +105,20 @@ let history s name since =
let entries = Vmm_ring.read_history r since in let entries = Vmm_ring.read_history r since in
Logs.debug (fun m -> m "found %d history" (List.length entries)) ; Logs.debug (fun m -> m "found %d history" (List.length entries)) ;
Lwt_list.iter_s (fun (i, v) -> Lwt_list.iter_s (fun (i, v) ->
Vmm_lwt.write_raw s (data my_version name i v)) entries >|= fun () -> Vmm_lwt.write_raw s (data my_version name i v) >|= fun _ -> ())
entries >|= fun () ->
Ok "success" Ok "success"
let handle s addr () = let handle s addr () =
Logs.info (fun m -> m "handling connection %a" pp_sockaddr addr) ; Logs.info (fun m -> m "handling connection %a" pp_sockaddr addr) ;
let rec loop () = let rec loop () =
Vmm_lwt.read_exactly s >>= function Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop () | Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ;
loop ()
| Error _ ->
Logs.err (fun m -> m "exception while reading") ;
Lwt.return_unit
| Ok (hdr, data) -> | Ok (hdr, data) ->
(if not (version_eq hdr.version my_version) then (if not (version_eq hdr.version my_version) then
Lwt.return (Error (`Msg "ignoring data with bad version")) Lwt.return (Error (`Msg "ignoring data with bad version"))
@ -138,10 +147,14 @@ let handle s addr () =
| Ok msg -> Vmm_lwt.write_raw s (success ~msg hdr.id my_version) | Ok msg -> Vmm_lwt.write_raw s (success ~msg hdr.id my_version)
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "error while processing command: %s" msg) ; Logs.err (fun m -> m "error while processing command: %s" msg) ;
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () -> Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= function
loop () | Ok () -> loop ()
| Error _ ->
Logs.err (fun m -> m "exception while writing to socket") ;
Lwt.return_unit
in in
loop () loop () >>= fun () ->
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
let jump _ file = let jump _ file =
Sys.(set_signal sigpipe Signal_ignore) ; Sys.(set_signal sigpipe Signal_ignore) ;

View file

@ -29,8 +29,13 @@ let write_complete s str =
in in
w 0 w 0
let pp_sockaddr ppf = function
| Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str
| Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d"
(Unix.string_of_inet_addr addr) port
let handle fd ring s addr () = let handle fd ring s addr () =
Logs.info (fun m -> m "handling connection") ; Logs.info (fun m -> m "handling connection from %a" pp_sockaddr addr) ;
let str = Fmt.strf "%a: CONNECT\n" (Ptime.pp_human ~tz_offset_s:0 ()) (Ptime_clock.now ()) in let str = Fmt.strf "%a: CONNECT\n" (Ptime.pp_human ~tz_offset_s:0 ()) (Ptime_clock.now ()) in
write_complete fd str >>= fun () -> write_complete fd str >>= fun () ->
let rec loop () = let rec loop () =
@ -38,21 +43,25 @@ let handle fd ring s addr () =
| Error (`Msg e) -> | Error (`Msg e) ->
Logs.err (fun m -> m "error while reading %s" e) ; Logs.err (fun m -> m "error while reading %s" e) ;
loop () loop ()
| Error _ ->
Logs.err (fun m -> m "exception while reading") ;
Lwt.return_unit
| Ok (hdr, data) -> | Ok (hdr, data) ->
let out =
(if not (version_eq hdr.version my_version) then (if not (version_eq hdr.version my_version) then
Lwt.return (Error (`Msg "unknown version")) Error (`Msg "unknown version")
else match int_to_op hdr.tag with else match int_to_op hdr.tag with
| Some Data -> | Some Data ->
(match decode_ts data with (match decode_ts data with
| Ok ts -> Vmm_ring.write ring (ts, data) | Ok ts -> Vmm_ring.write ring (ts, data)
| Error _ -> ()) ; | Error _ ->
write_complete fd data >>= fun () -> Logs.warn (fun m -> m "ignoring error while decoding timestamp %s" data)) ;
Lwt.return (Ok None) Ok (`Data data)
| Some History -> | Some History ->
begin match decode_str data with begin match decode_str data with
| Error e -> Lwt.return (Error e) | Error e -> Error e
| Ok (str, off) -> match decode_ts ~off data with | Ok (str, off) -> match decode_ts ~off data with
| Error e -> Lwt.return (Error e) | Error e -> Error e
| Ok ts -> | Ok ts ->
let elements = Vmm_ring.read_history ring ts in let elements = Vmm_ring.read_history ring ts in
let res = List.fold_left (fun acc (_, x) -> let res = List.fold_left (fun acc (_, x) ->
@ -67,23 +76,46 @@ let handle fd ring s addr () =
[] elements [] elements
in in
(* just need a wrapper in tag = Log.Data, id = reqid *) (* just need a wrapper in tag = Log.Data, id = reqid *)
Lwt_list.iter_s (fun x -> let out =
List.fold_left (fun acc x ->
let length = String.length x in let length = String.length x in
let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in
Vmm_lwt.write_raw s (Cstruct.to_string hdr ^ x)) (Cstruct.to_string hdr ^ x) :: acc)
(List.rev res) >>= fun () -> [] (List.rev res)
Lwt.return (Ok None) in
Ok (`Out out)
end end
| _ -> | _ ->
Logs.err (fun m -> m "didn't understand log command %d" hdr.tag) ; Error (`Msg "unknown command"))
Lwt.return (Error (`Msg "unknown command"))) >>= (function
| Ok msg -> Vmm_lwt.write_raw s (success ?msg hdr.id my_version)
| Error (`Msg msg) ->
Logs.err (fun m -> m "error while processing: %s" msg) ;
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () ->
loop ()
in in
Lwt.catch loop (fun e -> Lwt.return_unit) match out with
| Error (`Msg msg) ->
begin
Logs.err (fun m -> m "error while processing: %s" msg) ;
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version) >>= function
| Error _ -> Logs.err (fun m -> m "error0 while writing") ; Lwt.return_unit
| Ok () -> loop ()
end
| Ok (`Data data) ->
begin
write_complete fd data >>= fun () ->
Vmm_lwt.write_raw s (success hdr.id my_version) >>= function
| Error _ -> Logs.err (fun m -> m "error1 while writing") ; Lwt.return_unit
| Ok () -> loop ()
end
| Ok (`Out datas) ->
Lwt_list.fold_left_s (fun r x -> match r with
| Error e -> Lwt.return (Error e)
| Ok () -> Vmm_lwt.write_raw s x)
(Ok ()) datas >>= function
| Error _ -> Logs.err (fun m -> m "error2 while writing") ; Lwt.return_unit
| Ok () ->
Vmm_lwt.write_raw s (success hdr.id my_version) >>= function
| Error _ -> Logs.err (fun m -> m "error3 while writing") ; Lwt.return_unit
| Ok () -> loop ()
in
loop () >>= fun () ->
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
let jump _ file sock = let jump _ file sock =
Sys.(set_signal sigpipe Signal_ignore) ; Sys.(set_signal sigpipe Signal_ignore) ;

View file

@ -128,7 +128,9 @@ let process db tls hdr data =
let out = Vmm_wire.Client.cmd `Info !command my_version in let out = Vmm_wire.Client.cmd `Info !command my_version in
command := succ !command ; command := succ !command ;
Logs.debug (fun m -> m "writing %a over TLS" Cstruct.hexdump_pp (Cstruct.of_string out)) ; Logs.debug (fun m -> m "writing %a over TLS" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
Vmm_tls.write_tls tls out (Vmm_tls.write_tls tls out >|= function
| Ok () -> ()
| Error _ -> Logs.err (fun m -> m "error while writing") ; ())
| _ -> | _ ->
let r = let r =
match hdr.tag with match hdr.tag with
@ -176,21 +178,23 @@ let process db tls hdr data =
match r with match r with
| Ok `None -> Lwt.return_unit | Ok `None -> Lwt.return_unit
| Ok (`Sockaddr s) -> d s | Ok (`Sockaddr s) -> d s
| Ok (`Stat (fd, s, out)) -> Vmm_lwt.write_raw fd out >>= fun () -> d (fd, s) | Ok (`Stat (fd, s, out)) ->
(Vmm_lwt.write_raw fd out >>= function
| Ok () -> d (fd, s)
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit)
| Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) ; Lwt.return_unit | Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) ; Lwt.return_unit
let rec tls_listener db tls = let rec tls_listener db tls =
Lwt.catch (fun () -> (Vmm_tls.read_tls tls >>= function
Vmm_tls.read_tls tls >>= function
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "error while reading %s" msg) ; Logs.err (fun m -> m "error while reading %s" msg) ;
Lwt.return (Ok ()) Lwt.return (Ok ())
| Error _ ->
Logs.err (fun m -> m "received exception in read_tls") ;
Lwt.return (Error ())
| Ok (hdr, data) -> | Ok (hdr, data) ->
process db tls hdr data >>= fun () -> process db tls hdr data >>= fun () ->
Lwt.return (Ok ())) Lwt.return (Ok ())) >>= function
(fun e ->
Logs.err (fun m -> m "received exception in read_tls: %s" (Printexc.to_string e)) ;
Lwt.return (Error ())) >>= function
| Ok () -> tls_listener db tls | Ok () -> tls_listener db tls
| Error () -> Lwt.return_unit | Error () -> Lwt.return_unit
@ -203,24 +207,32 @@ let hdr =
(* wait for TCP connection, once received request stats from vmmd, and loop *) (* wait for TCP connection, once received request stats from vmmd, and loop *)
let rec tcp_listener db tcp tls = let rec tcp_listener db tcp tls =
Lwt_unix.accept tcp >>= fun (cs, sockaddr) -> Lwt_unix.accept tcp >>= fun (cs, sockaddr) ->
Vmm_lwt.write_raw cs hdr >>= fun () -> Vmm_lwt.write_raw cs hdr >>= function
| Error _ -> Logs.err (fun m -> m "exception while accepting") ; Lwt.return_unit
| Ok () ->
let l = List.length !known_vms in let l = List.length !known_vms in
let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in
Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ; Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ;
(if l = 0 then (if l = 0 then
Lwt_unix.close cs Lwt_unix.close cs >|= fun () -> Error ()
else begin else begin
count := SM.add sockaddr (List.length !known_vms) !count ; count := SM.add sockaddr (List.length !known_vms) !count ;
Lwt_list.iter_s Lwt_list.fold_left_s
(fun vm -> (fun r vm ->
match r with
| Error () -> Lwt.return (Error ())
| Ok () ->
let vm_id = translate_name db vm in let vm_id = translate_name db vm in
let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in
t := IM.add !command (cs, sockaddr, vm) !t ; t := IM.add !command (cs, sockaddr, vm) !t ;
command := succ !command ; command := succ !command ;
Vmm_tls.write_tls tls out) Vmm_tls.write_tls tls out >|= function
!known_vms | Ok () -> Ok ()
end) >>= fun () -> | Error _ -> Logs.err (fun m -> m "exception while writing") ; Error ())
tcp_listener db tcp tls (Ok ()) !known_vms
end) >>= function
| Ok () -> tcp_listener db tcp tls
| Error () -> Lwt.return_unit
let client cas host port cert priv_key db listen_ip listen_port = let client cas host port cert priv_key db listen_ip listen_port =
Nocrypto_entropy_lwt.initialize () >>= fun () -> Nocrypto_entropy_lwt.initialize () >>= fun () ->

View file

@ -2,14 +2,17 @@
open Lwt.Infix open Lwt.Infix
let write_raw s data =
Vmm_lwt.write_raw s data >|= fun _ -> ()
let write_tls state t data = let write_tls state t data =
Lwt.catch (fun () -> Vmm_tls.write_tls (fst t) data) Vmm_tls.write_tls (fst t) data >>= function
(fun e -> | Ok () -> Lwt.return_unit
| Error `Exception ->
let state', out = Vmm_engine.handle_disconnect !state t in let state', out = Vmm_engine.handle_disconnect !state t in
state := state' ; state := state' ;
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) out >>= fun () -> Lwt_list.iter_s (fun (s, data) -> write_raw s data) out >>= fun () ->
Tls_lwt.Unix.close (fst t) >>= fun () -> Tls_lwt.Unix.close (fst t)
raise e)
let to_ipaddr (_, sa) = match sa with let to_ipaddr (_, sa) = match sa with
| Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address" | Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address"
@ -22,7 +25,7 @@ let pp_sockaddr ppf (_, sa) = match sa with
let process state xs = let process state xs =
Lwt_list.iter_s (function Lwt_list.iter_s (function
| `Raw (s, str) -> Vmm_lwt.write_raw s str | `Raw (s, str) -> write_raw s str
| `Tls (s, str) -> write_tls state s str) | `Tls (s, str) -> write_tls state s str)
xs xs
@ -73,19 +76,19 @@ let handle ca state t =
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "reading client %a error: %s" pp_sockaddr t msg) ; Logs.err (fun m -> m "reading client %a error: %s" pp_sockaddr t msg) ;
loop () loop ()
| Error _ ->
Logs.err (fun m -> m "disconnect from %a" pp_sockaddr t) ;
let state', cons = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> write_raw s data) cons >>= fun () ->
Tls_lwt.Unix.close (fst t)
| Ok (hdr, buf) -> | Ok (hdr, buf) ->
let state', out = Vmm_engine.handle_command !state t prefix perms hdr buf in let state', out = Vmm_engine.handle_command !state t prefix perms hdr buf in
state := state' ; state := state' ;
process state out >>= fun () -> process state out >>= fun () ->
loop () loop ()
in in
Lwt.catch loop loop ()
(fun e ->
let state', cons = Vmm_engine.handle_disconnect !state t in
state := state' ;
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) cons >>= fun () ->
Tls_lwt.Unix.close (fst t) >>= fun () ->
raise e)
| `Close socks -> | `Close socks ->
Logs.debug (fun m -> m "closing session with %d active ones" (List.length socks)) ; Logs.debug (fun m -> m "closing session with %d active ones" (List.length socks)) ;
Lwt_list.iter_s (fun (t, _) -> Tls_lwt.Unix.close t) socks >>= fun () -> Lwt_list.iter_s (fun (t, _) -> Tls_lwt.Unix.close t) socks >>= fun () ->
@ -105,18 +108,26 @@ let server_socket port =
listen s 10 ; listen s 10 ;
Lwt.return s Lwt.return s
let init_exception () = let init_sock dir name =
Lwt.async_exception_hook := (function let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
| Tls_lwt.Tls_failure a -> Lwt_unix.set_close_on_exec c ;
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) let addr = Fpath.(dir / name + "sock") in
| exn -> Lwt.catch (fun () ->
Logs.err (fun m -> m "exception: %s" (Printexc.to_string exn))) Lwt_unix.(connect c (ADDR_UNIX (Fpath.to_string addr))) >|= fun () -> Some c)
(fun e ->
Logs.warn (fun m -> m "error %s connecting to socket %a"
(Printexc.to_string e) Fpath.pp addr) ;
(Lwt.catch (fun () -> Lwt_unix.close c) (fun _ -> Lwt.return_unit)) >|= fun () ->
None)
let rec read_log state s = let rec read_log state s =
Vmm_lwt.read_exactly s >>= function Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "reading log error %s" msg) ; Logs.err (fun m -> m "reading log error %s" msg) ;
read_log state s read_log state s
| Error _ ->
Logs.err (fun m -> m "exception while reading log") ;
invalid_arg "log socket communication issue"
| Ok (hdr, data) -> | Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_log !state hdr data in let state', outs = Vmm_engine.handle_log !state hdr data in
state := state' ; state := state' ;
@ -128,6 +139,9 @@ let rec read_cons state s =
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "reading console error %s" msg) ; Logs.err (fun m -> m "reading console error %s" msg) ;
read_cons state s read_cons state s
| Error _ ->
Logs.err (fun m -> m "exception while reading console socket") ;
invalid_arg "console socket communication issue"
| Ok (hdr, data) -> | Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_cons !state hdr data in let state', outs = Vmm_engine.handle_cons !state hdr data in
state := state' ; state := state' ;
@ -139,6 +153,10 @@ let rec read_stats state s =
| Error (`Msg msg) -> | Error (`Msg msg) ->
Logs.err (fun m -> m "reading stats error %s" msg) ; Logs.err (fun m -> m "reading stats error %s" msg) ;
read_stats state s read_stats state s
| Error _ ->
Logs.err (fun m -> m "exception while reading stats") ;
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) >|= fun () ->
state := { !state with Vmm_engine.stats_socket = None }
| Ok (hdr, data) -> | Ok (hdr, data) ->
let state', outs = Vmm_engine.handle_stat !state hdr data in let state', outs = Vmm_engine.handle_stat !state hdr data in
state := state' ; state := state' ;
@ -156,23 +174,15 @@ let cmp_s (_, a) (_, b) =
let jump _ dir cacert cert priv_key = let jump _ dir cacert cert priv_key =
Sys.(set_signal sigpipe Signal_ignore) ; Sys.(set_signal sigpipe Signal_ignore) ;
let dir = Fpath.v dir in
Lwt_main.run Lwt_main.run
(init_exception () ; ((init_sock dir "cons" >|= function
let d = Fpath.v dir in | None -> invalid_arg "cannot connect to console socket"
let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in | Some c -> c) >>= fun c ->
Lwt_unix.set_close_on_exec c ; init_sock dir "stat" >>= fun s ->
Lwt_unix.(connect c (ADDR_UNIX Fpath.(to_string (d / "cons" + "sock")))) >>= fun () -> (init_sock dir "log" >|= function
Lwt.catch (fun () -> | None -> invalid_arg "cannot connect to log socket"
let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in | Some l -> l) >>= fun l ->
Lwt_unix.set_close_on_exec s ;
Lwt_unix.(connect s (ADDR_UNIX Fpath.(to_string (d / "stat" + "sock")))) >|= fun () ->
Some s)
(function
| Unix.Unix_error (Unix.ENOENT, _, _) -> Lwt.return None
| e -> Lwt.fail e) >>= fun s ->
let l = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec l ;
Lwt_unix.(connect l (ADDR_UNIX Fpath.(to_string (d / "log" + "sock")))) >>= fun () ->
server_socket 1025 >>= fun socket -> server_socket 1025 >>= fun socket ->
X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert -> X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert ->
X509_lwt.certs_of_pem cacert >>= (function X509_lwt.certs_of_pem cacert >>= (function
@ -182,7 +192,7 @@ let jump _ dir cacert cert priv_key =
Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2) Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2)
~reneg:true ~certificates:(`Single cert) ()) ~reneg:true ~certificates:(`Single cert) ())
in in
(match Vmm_engine.init d cmp_s c s l with (match Vmm_engine.init dir cmp_s c s l with
| Ok s -> Lwt.return s | Ok s -> Lwt.return s
| Error (`Msg m) -> Lwt.fail_with m) >>= fun t -> | Error (`Msg m) -> Lwt.fail_with m) >>= fun t ->
let state = ref t in let state = ref t in
@ -200,7 +210,13 @@ let jump _ dir cacert cert priv_key =
(fun exn -> (fun exn ->
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () -> Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
Lwt.fail exn) >>= fun t -> Lwt.fail exn) >>= fun t ->
Lwt.async (fun () -> handle ca state t) ; Lwt.async (fun () ->
Lwt.catch
(fun () -> handle ca state t)
(fun e ->
Logs.err (fun m -> m "error while handle() %s"
(Printexc.to_string e)) ;
Lwt.return_unit)) ;
loop ()) loop ())
(function (function
| Unix.Unix_error (e, f, _) -> | Unix.Unix_error (e, f, _) ->

View file

@ -39,20 +39,34 @@ let wait_and_clear pid stdout =
let read_exactly s = let read_exactly s =
let buf = Bytes.create 8 in let buf = Bytes.create 8 in
let rec r b i l = let rec r b i l =
Lwt.catch (fun () ->
Lwt_unix.read s b i l >>= function Lwt_unix.read s b i l >>= function
| 0 -> Lwt.fail_with "end of file" | 0 ->
| n when n == l -> Lwt.return_unit Logs.err (fun m -> m "end of file while reading") ;
Lwt.return (Error `Eof)
| n when n == l -> Lwt.return (Ok ())
| n when n < l -> r b (i + n) (l - n) | n when n < l -> r b (i + n) (l - n)
| _ -> Lwt.fail_with "read too much" | _ ->
Logs.err (fun m -> m "read too much, shouldn't happen)") ;
Lwt.return (Error `Toomuch))
(fun e ->
let err = Printexc.to_string e in
Logs.err (fun m -> m "exception %s while reading" err) ;
Lwt.return (Error `Exception))
in in
r buf 0 8 >>= fun () -> r buf 0 8 >>= function
| Error e -> Lwt.return (Error e)
| Ok () ->
match Vmm_wire.parse_header (Bytes.to_string buf) with match Vmm_wire.parse_header (Bytes.to_string buf) with
| Error (`Msg m) -> Lwt.return (Error (`Msg m)) | Error (`Msg m) -> Lwt.return (Error (`Msg m))
| Ok hdr -> | Ok hdr ->
let l = hdr.Vmm_wire.length in let l = hdr.Vmm_wire.length in
if l > 0 then if l > 0 then
let b = Bytes.create l in let b = Bytes.create l in
r b 0 l >|= fun () -> r b 0 l >|= function
| Error e -> Error e
| Ok () ->
Logs.debug (fun m -> m "read hdr %a, body %a" Logs.debug (fun m -> m "read hdr %a, body %a"
Cstruct.hexdump_pp (Cstruct.of_bytes buf) Cstruct.hexdump_pp (Cstruct.of_bytes buf)
Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; Cstruct.hexdump_pp (Cstruct.of_bytes b)) ;
@ -63,11 +77,15 @@ let read_exactly s =
let write_raw s buf = let write_raw s buf =
let buf = Bytes.unsafe_of_string buf in let buf = Bytes.unsafe_of_string buf in
let rec w off l = let rec w off l =
Lwt.catch (fun () ->
Lwt_unix.send s buf off l [] >>= fun n -> Lwt_unix.send s buf off l [] >>= fun n ->
if n = l then if n = l then
Lwt.return_unit Lwt.return (Ok ())
else else
w (off + n) (l - n) w (off + n) (l - n))
(fun e ->
Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string e)) ;
Lwt.return (Error `Exception))
in in
Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ; Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ;
w 0 (Bytes.length buf) w 0 (Bytes.length buf)

View file

@ -6,23 +6,39 @@ let read_tls t =
let rec r_n buf off tot = let rec r_n buf off tot =
let l = tot - off in let l = tot - off in
if l = 0 then if l = 0 then
Lwt.return_unit Lwt.return (Ok ())
else else
Lwt.catch (fun () ->
Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function
| 0 -> Lwt.fail_with "read 0 bytes" | 0 ->
| x when x == l -> Lwt.return_unit Logs.err (fun m -> m "TLS: end of file") ;
Lwt.return (Error `Eof)
| x when x == l -> Lwt.return (Ok ())
| x when x < l -> r_n buf (off + x) tot | x when x < l -> r_n buf (off + x) tot
| _ -> Lwt.fail_with "overread, will never happen" | _ ->
Logs.err (fun m -> m "TLS: read too much, shouldn't happen") ;
Lwt.return (Error `Toomuch))
(function
| Tls_lwt.Tls_failure a ->
Logs.err (fun m -> m "TLS read failure: %s" (Tls.Engine.string_of_failure a)) ;
Lwt.return (Error `Exception)
| e ->
Logs.err (fun m -> m "TLS read exception %s" (Printexc.to_string e)) ;
Lwt.return (Error `Exception))
in in
let buf = Cstruct.create 8 in let buf = Cstruct.create 8 in
r_n buf 0 8 >>= fun () -> r_n buf 0 8 >>= function
| Error e -> Lwt.return (Error e)
| Ok () ->
match Vmm_wire.parse_header (Cstruct.to_string buf) with match Vmm_wire.parse_header (Cstruct.to_string buf) with
| Error (`Msg m) -> Lwt.return (Error (`Msg m)) | Error (`Msg m) -> Lwt.return (Error (`Msg m))
| Ok hdr -> | Ok hdr ->
let l = hdr.Vmm_wire.length in let l = hdr.Vmm_wire.length in
if l > 0 then if l > 0 then
let b = Cstruct.create l in let b = Cstruct.create l in
r_n b 0 l >|= fun () -> r_n b 0 l >|= function
| Error e -> Error e
| Ok () ->
Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a" Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a"
hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag
Cstruct.hexdump_pp b) ; Cstruct.hexdump_pp b) ;
@ -32,4 +48,12 @@ let read_tls t =
let write_tls s buf = let write_tls s buf =
Logs.debug (fun m -> m "TLS write %a" Cstruct.hexdump_pp (Cstruct.of_string buf)) ; Logs.debug (fun m -> m "TLS write %a" Cstruct.hexdump_pp (Cstruct.of_string buf)) ;
Tls_lwt.Unix.write s (Cstruct.of_string buf) Lwt.catch
(fun () -> Tls_lwt.Unix.write s (Cstruct.of_string buf) >|= fun () -> Ok ())
(function
| Tls_lwt.Tls_failure a ->
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ;
Lwt.return (Error `Exception)
| e ->
Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ;
Lwt.return (Error `Exception))

View file

@ -26,15 +26,18 @@ let handle s addr () =
let rec loop () = let rec loop () =
Vmm_lwt.read_exactly s >>= function Vmm_lwt.read_exactly s >>= function
| Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop () | Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop ()
| Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit
| Ok (hdr, data) -> | Ok (hdr, data) ->
Logs.debug (fun m -> m "received %a" Cstruct.hexdump_pp (Cstruct.of_string data)) ; Logs.debug (fun m -> m "received %a" Cstruct.hexdump_pp (Cstruct.of_string data)) ;
let t', out = Vmm_stats.handle !t hdr data in let t', out = Vmm_stats.handle !t hdr data in
t := t' ; t := t' ;
Logs.debug (fun m -> m "sent %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ; Logs.debug (fun m -> m "sent %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
Vmm_lwt.write_raw s out >>= fun () -> Vmm_lwt.write_raw s out >>= function
loop () | Ok () -> loop ()
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit
in in
loop () loop () >>= fun () ->
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
let rec timer () = let rec timer () =
t := Vmm_stats.tick !t ; t := Vmm_stats.tick !t ;