safer and clearer error semantics for all processes, fixes #5
This commit is contained in:
parent
88012094f8
commit
cfa7ccd1e0
|
@ -52,18 +52,14 @@ let process db hdr data =
|
|||
| Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg)
|
||||
|
||||
let rec read_tls_write_cons db t =
|
||||
Lwt.catch (fun () ->
|
||||
Vmm_tls.read_tls t >>= function
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "error while reading %s" msg) ;
|
||||
read_tls_write_cons db t
|
||||
| Ok (hdr, data) ->
|
||||
process db hdr data ;
|
||||
read_tls_write_cons db t)
|
||||
(fun e ->
|
||||
Logs.err (fun m -> m "exception reading TLS stream %s"
|
||||
(Printexc.to_string e)) ;
|
||||
Tls_lwt.Unix.close t)
|
||||
Vmm_tls.read_tls t >>= function
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "error while reading %s" msg) ;
|
||||
read_tls_write_cons db t
|
||||
| Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit
|
||||
| Ok (hdr, data) ->
|
||||
process db hdr data ;
|
||||
read_tls_write_cons db t
|
||||
|
||||
let rec read_cons_write_tls db t =
|
||||
Lwt.catch (fun () ->
|
||||
|
@ -77,10 +73,14 @@ let rec read_cons_write_tls db t =
|
|||
| Some cmd ->
|
||||
let out = Vmm_wire.Client.cmd ?arg cmd !command my_version in
|
||||
command := succ !command ;
|
||||
Vmm_tls.write_tls t out >>= fun () ->
|
||||
Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
|
||||
read_cons_write_tls db t)
|
||||
(fun _ -> Lwt.return_unit)
|
||||
Vmm_tls.write_tls t out >>= function
|
||||
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit
|
||||
| Ok () ->
|
||||
Logs.debug (fun m -> m "wrote %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
|
||||
read_cons_write_tls db t)
|
||||
(fun e ->
|
||||
Logs.err (fun m -> m "exception %s in read_cons_write_tls" (Printexc.to_string e)) ;
|
||||
Lwt.return_unit)
|
||||
|
||||
let client cas host port cert priv_key db =
|
||||
Nocrypto_entropy_lwt.initialize () >>= fun () ->
|
||||
|
|
|
@ -37,8 +37,11 @@ let read_console s name ring channel () =
|
|||
(if String.Set.mem name !active then
|
||||
Vmm_lwt.write_raw s (data my_version name t line)
|
||||
else
|
||||
Lwt.return_unit) >>= fun () ->
|
||||
loop ()
|
||||
Lwt.return (Ok ())) >>= function
|
||||
| Ok () -> loop ()
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "error reading console") ;
|
||||
Lwt_io.close channel
|
||||
in
|
||||
loop ())
|
||||
(fun e ->
|
||||
|
@ -102,14 +105,20 @@ let history s name since =
|
|||
let entries = Vmm_ring.read_history r since in
|
||||
Logs.debug (fun m -> m "found %d history" (List.length entries)) ;
|
||||
Lwt_list.iter_s (fun (i, v) ->
|
||||
Vmm_lwt.write_raw s (data my_version name i v)) entries >|= fun () ->
|
||||
Vmm_lwt.write_raw s (data my_version name i v) >|= fun _ -> ())
|
||||
entries >|= fun () ->
|
||||
Ok "success"
|
||||
|
||||
let handle s addr () =
|
||||
Logs.info (fun m -> m "handling connection %a" pp_sockaddr addr) ;
|
||||
let rec loop () =
|
||||
Vmm_lwt.read_exactly s >>= function
|
||||
| Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop ()
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "error while reading %s" msg) ;
|
||||
loop ()
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "exception while reading") ;
|
||||
Lwt.return_unit
|
||||
| Ok (hdr, data) ->
|
||||
(if not (version_eq hdr.version my_version) then
|
||||
Lwt.return (Error (`Msg "ignoring data with bad version"))
|
||||
|
@ -138,10 +147,14 @@ let handle s addr () =
|
|||
| Ok msg -> Vmm_lwt.write_raw s (success ~msg hdr.id my_version)
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "error while processing command: %s" msg) ;
|
||||
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () ->
|
||||
loop ()
|
||||
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= function
|
||||
| Ok () -> loop ()
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "exception while writing to socket") ;
|
||||
Lwt.return_unit
|
||||
in
|
||||
loop ()
|
||||
loop () >>= fun () ->
|
||||
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
|
||||
|
||||
let jump _ file =
|
||||
Sys.(set_signal sigpipe Signal_ignore) ;
|
||||
|
|
122
app/vmm_log.ml
122
app/vmm_log.ml
|
@ -29,8 +29,13 @@ let write_complete s str =
|
|||
in
|
||||
w 0
|
||||
|
||||
let pp_sockaddr ppf = function
|
||||
| Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str
|
||||
| Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d"
|
||||
(Unix.string_of_inet_addr addr) port
|
||||
|
||||
let handle fd ring s addr () =
|
||||
Logs.info (fun m -> m "handling connection") ;
|
||||
Logs.info (fun m -> m "handling connection from %a" pp_sockaddr addr) ;
|
||||
let str = Fmt.strf "%a: CONNECT\n" (Ptime.pp_human ~tz_offset_s:0 ()) (Ptime_clock.now ()) in
|
||||
write_complete fd str >>= fun () ->
|
||||
let rec loop () =
|
||||
|
@ -38,52 +43,79 @@ let handle fd ring s addr () =
|
|||
| Error (`Msg e) ->
|
||||
Logs.err (fun m -> m "error while reading %s" e) ;
|
||||
loop ()
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "exception while reading") ;
|
||||
Lwt.return_unit
|
||||
| Ok (hdr, data) ->
|
||||
(if not (version_eq hdr.version my_version) then
|
||||
Lwt.return (Error (`Msg "unknown version"))
|
||||
else match int_to_op hdr.tag with
|
||||
| Some Data ->
|
||||
( match decode_ts data with
|
||||
| Ok ts -> Vmm_ring.write ring (ts, data)
|
||||
| Error _ -> ()) ;
|
||||
write_complete fd data >>= fun () ->
|
||||
Lwt.return (Ok None)
|
||||
| Some History ->
|
||||
begin match decode_str data with
|
||||
| Error e -> Lwt.return (Error e)
|
||||
| Ok (str, off) -> match decode_ts ~off data with
|
||||
| Error e -> Lwt.return (Error e)
|
||||
| Ok ts ->
|
||||
let elements = Vmm_ring.read_history ring ts in
|
||||
let res = List.fold_left (fun acc (_, x) ->
|
||||
match Vmm_wire.Log.decode_log_hdr (Cstruct.of_string x) with
|
||||
| Ok (hdr, _) ->
|
||||
Logs.debug (fun m -> m "found an entry: %a" (Vmm_core.Log.pp_hdr []) hdr) ;
|
||||
if String.equal str (Vmm_core.string_of_id hdr.Vmm_core.Log.context) then
|
||||
x :: acc
|
||||
else
|
||||
acc
|
||||
| _ -> acc)
|
||||
[] elements
|
||||
in
|
||||
(* just need a wrapper in tag = Log.Data, id = reqid *)
|
||||
Lwt_list.iter_s (fun x ->
|
||||
let length = String.length x in
|
||||
let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in
|
||||
Vmm_lwt.write_raw s (Cstruct.to_string hdr ^ x))
|
||||
(List.rev res) >>= fun () ->
|
||||
Lwt.return (Ok None)
|
||||
end
|
||||
| _ ->
|
||||
Logs.err (fun m -> m "didn't understand log command %d" hdr.tag) ;
|
||||
Lwt.return (Error (`Msg "unknown command"))) >>= (function
|
||||
| Ok msg -> Vmm_lwt.write_raw s (success ?msg hdr.id my_version)
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "error while processing: %s" msg) ;
|
||||
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version)) >>= fun () ->
|
||||
loop ()
|
||||
let out =
|
||||
(if not (version_eq hdr.version my_version) then
|
||||
Error (`Msg "unknown version")
|
||||
else match int_to_op hdr.tag with
|
||||
| Some Data ->
|
||||
(match decode_ts data with
|
||||
| Ok ts -> Vmm_ring.write ring (ts, data)
|
||||
| Error _ ->
|
||||
Logs.warn (fun m -> m "ignoring error while decoding timestamp %s" data)) ;
|
||||
Ok (`Data data)
|
||||
| Some History ->
|
||||
begin match decode_str data with
|
||||
| Error e -> Error e
|
||||
| Ok (str, off) -> match decode_ts ~off data with
|
||||
| Error e -> Error e
|
||||
| Ok ts ->
|
||||
let elements = Vmm_ring.read_history ring ts in
|
||||
let res = List.fold_left (fun acc (_, x) ->
|
||||
match Vmm_wire.Log.decode_log_hdr (Cstruct.of_string x) with
|
||||
| Ok (hdr, _) ->
|
||||
Logs.debug (fun m -> m "found an entry: %a" (Vmm_core.Log.pp_hdr []) hdr) ;
|
||||
if String.equal str (Vmm_core.string_of_id hdr.Vmm_core.Log.context) then
|
||||
x :: acc
|
||||
else
|
||||
acc
|
||||
| _ -> acc)
|
||||
[] elements
|
||||
in
|
||||
(* just need a wrapper in tag = Log.Data, id = reqid *)
|
||||
let out =
|
||||
List.fold_left (fun acc x ->
|
||||
let length = String.length x in
|
||||
let hdr = Vmm_wire.create_header { length ; id = hdr.id ; tag = op_to_int Data ; version = my_version } in
|
||||
(Cstruct.to_string hdr ^ x) :: acc)
|
||||
[] (List.rev res)
|
||||
in
|
||||
Ok (`Out out)
|
||||
end
|
||||
| _ ->
|
||||
Error (`Msg "unknown command"))
|
||||
in
|
||||
match out with
|
||||
| Error (`Msg msg) ->
|
||||
begin
|
||||
Logs.err (fun m -> m "error while processing: %s" msg) ;
|
||||
Vmm_lwt.write_raw s (fail ~msg hdr.id my_version) >>= function
|
||||
| Error _ -> Logs.err (fun m -> m "error0 while writing") ; Lwt.return_unit
|
||||
| Ok () -> loop ()
|
||||
end
|
||||
| Ok (`Data data) ->
|
||||
begin
|
||||
write_complete fd data >>= fun () ->
|
||||
Vmm_lwt.write_raw s (success hdr.id my_version) >>= function
|
||||
| Error _ -> Logs.err (fun m -> m "error1 while writing") ; Lwt.return_unit
|
||||
| Ok () -> loop ()
|
||||
end
|
||||
| Ok (`Out datas) ->
|
||||
Lwt_list.fold_left_s (fun r x -> match r with
|
||||
| Error e -> Lwt.return (Error e)
|
||||
| Ok () -> Vmm_lwt.write_raw s x)
|
||||
(Ok ()) datas >>= function
|
||||
| Error _ -> Logs.err (fun m -> m "error2 while writing") ; Lwt.return_unit
|
||||
| Ok () ->
|
||||
Vmm_lwt.write_raw s (success hdr.id my_version) >>= function
|
||||
| Error _ -> Logs.err (fun m -> m "error3 while writing") ; Lwt.return_unit
|
||||
| Ok () -> loop ()
|
||||
in
|
||||
Lwt.catch loop (fun e -> Lwt.return_unit)
|
||||
loop () >>= fun () ->
|
||||
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
|
||||
|
||||
let jump _ file sock =
|
||||
Sys.(set_signal sigpipe Signal_ignore) ;
|
||||
|
|
|
@ -128,7 +128,9 @@ let process db tls hdr data =
|
|||
let out = Vmm_wire.Client.cmd `Info !command my_version in
|
||||
command := succ !command ;
|
||||
Logs.debug (fun m -> m "writing %a over TLS" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
|
||||
Vmm_tls.write_tls tls out
|
||||
(Vmm_tls.write_tls tls out >|= function
|
||||
| Ok () -> ()
|
||||
| Error _ -> Logs.err (fun m -> m "error while writing") ; ())
|
||||
| _ ->
|
||||
let r =
|
||||
match hdr.tag with
|
||||
|
@ -176,21 +178,23 @@ let process db tls hdr data =
|
|||
match r with
|
||||
| Ok `None -> Lwt.return_unit
|
||||
| Ok (`Sockaddr s) -> d s
|
||||
| Ok (`Stat (fd, s, out)) -> Vmm_lwt.write_raw fd out >>= fun () -> d (fd, s)
|
||||
| Ok (`Stat (fd, s, out)) ->
|
||||
(Vmm_lwt.write_raw fd out >>= function
|
||||
| Ok () -> d (fd, s)
|
||||
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit)
|
||||
| Error (`Msg msg) -> Logs.err (fun m -> m "error while processing: %s" msg) ; Lwt.return_unit
|
||||
|
||||
let rec tls_listener db tls =
|
||||
Lwt.catch (fun () ->
|
||||
Vmm_tls.read_tls tls >>= function
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "error while reading %s" msg) ;
|
||||
Lwt.return (Ok ())
|
||||
| Ok (hdr, data) ->
|
||||
process db tls hdr data >>= fun () ->
|
||||
Lwt.return (Ok ()))
|
||||
(fun e ->
|
||||
Logs.err (fun m -> m "received exception in read_tls: %s" (Printexc.to_string e)) ;
|
||||
Lwt.return (Error ())) >>= function
|
||||
(Vmm_tls.read_tls tls >>= function
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "error while reading %s" msg) ;
|
||||
Lwt.return (Ok ())
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "received exception in read_tls") ;
|
||||
Lwt.return (Error ())
|
||||
| Ok (hdr, data) ->
|
||||
process db tls hdr data >>= fun () ->
|
||||
Lwt.return (Ok ())) >>= function
|
||||
| Ok () -> tls_listener db tls
|
||||
| Error () -> Lwt.return_unit
|
||||
|
||||
|
@ -203,24 +207,32 @@ let hdr =
|
|||
(* wait for TCP connection, once received request stats from vmmd, and loop *)
|
||||
let rec tcp_listener db tcp tls =
|
||||
Lwt_unix.accept tcp >>= fun (cs, sockaddr) ->
|
||||
Vmm_lwt.write_raw cs hdr >>= fun () ->
|
||||
let l = List.length !known_vms in
|
||||
let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in
|
||||
Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ;
|
||||
(if l = 0 then
|
||||
Lwt_unix.close cs
|
||||
else begin
|
||||
count := SM.add sockaddr (List.length !known_vms) !count ;
|
||||
Lwt_list.iter_s
|
||||
(fun vm ->
|
||||
let vm_id = translate_name db vm in
|
||||
let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in
|
||||
t := IM.add !command (cs, sockaddr, vm) !t ;
|
||||
command := succ !command ;
|
||||
Vmm_tls.write_tls tls out)
|
||||
!known_vms
|
||||
end) >>= fun () ->
|
||||
tcp_listener db tcp tls
|
||||
Vmm_lwt.write_raw cs hdr >>= function
|
||||
| Error _ -> Logs.err (fun m -> m "exception while accepting") ; Lwt.return_unit
|
||||
| Ok () ->
|
||||
let l = List.length !known_vms in
|
||||
let ip, port = match sockaddr with Lwt_unix.ADDR_INET (ip, port) -> ip, port | _ -> invalid_arg "unexpected" in
|
||||
Logs.info (fun m -> m "connection from %s:%d with %d known" (Unix.string_of_inet_addr ip) port l) ;
|
||||
(if l = 0 then
|
||||
Lwt_unix.close cs >|= fun () -> Error ()
|
||||
else begin
|
||||
count := SM.add sockaddr (List.length !known_vms) !count ;
|
||||
Lwt_list.fold_left_s
|
||||
(fun r vm ->
|
||||
match r with
|
||||
| Error () -> Lwt.return (Error ())
|
||||
| Ok () ->
|
||||
let vm_id = translate_name db vm in
|
||||
let out = Vmm_wire.Client.cmd `Statistics ~arg:vm_id !command my_version in
|
||||
t := IM.add !command (cs, sockaddr, vm) !t ;
|
||||
command := succ !command ;
|
||||
Vmm_tls.write_tls tls out >|= function
|
||||
| Ok () -> Ok ()
|
||||
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Error ())
|
||||
(Ok ()) !known_vms
|
||||
end) >>= function
|
||||
| Ok () -> tcp_listener db tcp tls
|
||||
| Error () -> Lwt.return_unit
|
||||
|
||||
let client cas host port cert priv_key db listen_ip listen_port =
|
||||
Nocrypto_entropy_lwt.initialize () >>= fun () ->
|
||||
|
|
94
app/vmmd.ml
94
app/vmmd.ml
|
@ -2,14 +2,17 @@
|
|||
|
||||
open Lwt.Infix
|
||||
|
||||
let write_raw s data =
|
||||
Vmm_lwt.write_raw s data >|= fun _ -> ()
|
||||
|
||||
let write_tls state t data =
|
||||
Lwt.catch (fun () -> Vmm_tls.write_tls (fst t) data)
|
||||
(fun e ->
|
||||
let state', out = Vmm_engine.handle_disconnect !state t in
|
||||
state := state' ;
|
||||
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) out >>= fun () ->
|
||||
Tls_lwt.Unix.close (fst t) >>= fun () ->
|
||||
raise e)
|
||||
Vmm_tls.write_tls (fst t) data >>= function
|
||||
| Ok () -> Lwt.return_unit
|
||||
| Error `Exception ->
|
||||
let state', out = Vmm_engine.handle_disconnect !state t in
|
||||
state := state' ;
|
||||
Lwt_list.iter_s (fun (s, data) -> write_raw s data) out >>= fun () ->
|
||||
Tls_lwt.Unix.close (fst t)
|
||||
|
||||
let to_ipaddr (_, sa) = match sa with
|
||||
| Lwt_unix.ADDR_UNIX _ -> invalid_arg "cannot convert unix address"
|
||||
|
@ -22,7 +25,7 @@ let pp_sockaddr ppf (_, sa) = match sa with
|
|||
|
||||
let process state xs =
|
||||
Lwt_list.iter_s (function
|
||||
| `Raw (s, str) -> Vmm_lwt.write_raw s str
|
||||
| `Raw (s, str) -> write_raw s str
|
||||
| `Tls (s, str) -> write_tls state s str)
|
||||
xs
|
||||
|
||||
|
@ -73,19 +76,19 @@ let handle ca state t =
|
|||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "reading client %a error: %s" pp_sockaddr t msg) ;
|
||||
loop ()
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "disconnect from %a" pp_sockaddr t) ;
|
||||
let state', cons = Vmm_engine.handle_disconnect !state t in
|
||||
state := state' ;
|
||||
Lwt_list.iter_s (fun (s, data) -> write_raw s data) cons >>= fun () ->
|
||||
Tls_lwt.Unix.close (fst t)
|
||||
| Ok (hdr, buf) ->
|
||||
let state', out = Vmm_engine.handle_command !state t prefix perms hdr buf in
|
||||
state := state' ;
|
||||
process state out >>= fun () ->
|
||||
loop ()
|
||||
in
|
||||
Lwt.catch loop
|
||||
(fun e ->
|
||||
let state', cons = Vmm_engine.handle_disconnect !state t in
|
||||
state := state' ;
|
||||
Lwt_list.iter_s (fun (s, data) -> Vmm_lwt.write_raw s data) cons >>= fun () ->
|
||||
Tls_lwt.Unix.close (fst t) >>= fun () ->
|
||||
raise e)
|
||||
loop ()
|
||||
| `Close socks ->
|
||||
Logs.debug (fun m -> m "closing session with %d active ones" (List.length socks)) ;
|
||||
Lwt_list.iter_s (fun (t, _) -> Tls_lwt.Unix.close t) socks >>= fun () ->
|
||||
|
@ -105,18 +108,26 @@ let server_socket port =
|
|||
listen s 10 ;
|
||||
Lwt.return s
|
||||
|
||||
let init_exception () =
|
||||
Lwt.async_exception_hook := (function
|
||||
| Tls_lwt.Tls_failure a ->
|
||||
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a))
|
||||
| exn ->
|
||||
Logs.err (fun m -> m "exception: %s" (Printexc.to_string exn)))
|
||||
let init_sock dir name =
|
||||
let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
|
||||
Lwt_unix.set_close_on_exec c ;
|
||||
let addr = Fpath.(dir / name + "sock") in
|
||||
Lwt.catch (fun () ->
|
||||
Lwt_unix.(connect c (ADDR_UNIX (Fpath.to_string addr))) >|= fun () -> Some c)
|
||||
(fun e ->
|
||||
Logs.warn (fun m -> m "error %s connecting to socket %a"
|
||||
(Printexc.to_string e) Fpath.pp addr) ;
|
||||
(Lwt.catch (fun () -> Lwt_unix.close c) (fun _ -> Lwt.return_unit)) >|= fun () ->
|
||||
None)
|
||||
|
||||
let rec read_log state s =
|
||||
Vmm_lwt.read_exactly s >>= function
|
||||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "reading log error %s" msg) ;
|
||||
read_log state s
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "exception while reading log") ;
|
||||
invalid_arg "log socket communication issue"
|
||||
| Ok (hdr, data) ->
|
||||
let state', outs = Vmm_engine.handle_log !state hdr data in
|
||||
state := state' ;
|
||||
|
@ -128,6 +139,9 @@ let rec read_cons state s =
|
|||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "reading console error %s" msg) ;
|
||||
read_cons state s
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "exception while reading console socket") ;
|
||||
invalid_arg "console socket communication issue"
|
||||
| Ok (hdr, data) ->
|
||||
let state', outs = Vmm_engine.handle_cons !state hdr data in
|
||||
state := state' ;
|
||||
|
@ -139,6 +153,10 @@ let rec read_stats state s =
|
|||
| Error (`Msg msg) ->
|
||||
Logs.err (fun m -> m "reading stats error %s" msg) ;
|
||||
read_stats state s
|
||||
| Error _ ->
|
||||
Logs.err (fun m -> m "exception while reading stats") ;
|
||||
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit) >|= fun () ->
|
||||
state := { !state with Vmm_engine.stats_socket = None }
|
||||
| Ok (hdr, data) ->
|
||||
let state', outs = Vmm_engine.handle_stat !state hdr data in
|
||||
state := state' ;
|
||||
|
@ -156,23 +174,15 @@ let cmp_s (_, a) (_, b) =
|
|||
|
||||
let jump _ dir cacert cert priv_key =
|
||||
Sys.(set_signal sigpipe Signal_ignore) ;
|
||||
let dir = Fpath.v dir in
|
||||
Lwt_main.run
|
||||
(init_exception () ;
|
||||
let d = Fpath.v dir in
|
||||
let c = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
|
||||
Lwt_unix.set_close_on_exec c ;
|
||||
Lwt_unix.(connect c (ADDR_UNIX Fpath.(to_string (d / "cons" + "sock")))) >>= fun () ->
|
||||
Lwt.catch (fun () ->
|
||||
let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
|
||||
Lwt_unix.set_close_on_exec s ;
|
||||
Lwt_unix.(connect s (ADDR_UNIX Fpath.(to_string (d / "stat" + "sock")))) >|= fun () ->
|
||||
Some s)
|
||||
(function
|
||||
| Unix.Unix_error (Unix.ENOENT, _, _) -> Lwt.return None
|
||||
| e -> Lwt.fail e) >>= fun s ->
|
||||
let l = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
|
||||
Lwt_unix.set_close_on_exec l ;
|
||||
Lwt_unix.(connect l (ADDR_UNIX Fpath.(to_string (d / "log" + "sock")))) >>= fun () ->
|
||||
((init_sock dir "cons" >|= function
|
||||
| None -> invalid_arg "cannot connect to console socket"
|
||||
| Some c -> c) >>= fun c ->
|
||||
init_sock dir "stat" >>= fun s ->
|
||||
(init_sock dir "log" >|= function
|
||||
| None -> invalid_arg "cannot connect to log socket"
|
||||
| Some l -> l) >>= fun l ->
|
||||
server_socket 1025 >>= fun socket ->
|
||||
X509_lwt.private_of_pems ~cert ~priv_key >>= fun cert ->
|
||||
X509_lwt.certs_of_pem cacert >>= (function
|
||||
|
@ -182,7 +192,7 @@ let jump _ dir cacert cert priv_key =
|
|||
Tls.(Config.server ~version:(Core.TLS_1_2, Core.TLS_1_2)
|
||||
~reneg:true ~certificates:(`Single cert) ())
|
||||
in
|
||||
(match Vmm_engine.init d cmp_s c s l with
|
||||
(match Vmm_engine.init dir cmp_s c s l with
|
||||
| Ok s -> Lwt.return s
|
||||
| Error (`Msg m) -> Lwt.fail_with m) >>= fun t ->
|
||||
let state = ref t in
|
||||
|
@ -200,7 +210,13 @@ let jump _ dir cacert cert priv_key =
|
|||
(fun exn ->
|
||||
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) >>= fun () ->
|
||||
Lwt.fail exn) >>= fun t ->
|
||||
Lwt.async (fun () -> handle ca state t) ;
|
||||
Lwt.async (fun () ->
|
||||
Lwt.catch
|
||||
(fun () -> handle ca state t)
|
||||
(fun e ->
|
||||
Logs.err (fun m -> m "error while handle() %s"
|
||||
(Printexc.to_string e)) ;
|
||||
Lwt.return_unit)) ;
|
||||
loop ())
|
||||
(function
|
||||
| Unix.Unix_error (e, f, _) ->
|
||||
|
|
|
@ -39,35 +39,53 @@ let wait_and_clear pid stdout =
|
|||
let read_exactly s =
|
||||
let buf = Bytes.create 8 in
|
||||
let rec r b i l =
|
||||
Lwt_unix.read s b i l >>= function
|
||||
| 0 -> Lwt.fail_with "end of file"
|
||||
| n when n == l -> Lwt.return_unit
|
||||
| n when n < l -> r b (i + n) (l - n)
|
||||
| _ -> Lwt.fail_with "read too much"
|
||||
Lwt.catch (fun () ->
|
||||
Lwt_unix.read s b i l >>= function
|
||||
| 0 ->
|
||||
Logs.err (fun m -> m "end of file while reading") ;
|
||||
Lwt.return (Error `Eof)
|
||||
| n when n == l -> Lwt.return (Ok ())
|
||||
| n when n < l -> r b (i + n) (l - n)
|
||||
| _ ->
|
||||
Logs.err (fun m -> m "read too much, shouldn't happen)") ;
|
||||
Lwt.return (Error `Toomuch))
|
||||
(fun e ->
|
||||
let err = Printexc.to_string e in
|
||||
Logs.err (fun m -> m "exception %s while reading" err) ;
|
||||
Lwt.return (Error `Exception))
|
||||
|
||||
in
|
||||
r buf 0 8 >>= fun () ->
|
||||
match Vmm_wire.parse_header (Bytes.to_string buf) with
|
||||
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
|
||||
| Ok hdr ->
|
||||
let l = hdr.Vmm_wire.length in
|
||||
if l > 0 then
|
||||
let b = Bytes.create l in
|
||||
r b 0 l >|= fun () ->
|
||||
Logs.debug (fun m -> m "read hdr %a, body %a"
|
||||
Cstruct.hexdump_pp (Cstruct.of_bytes buf)
|
||||
Cstruct.hexdump_pp (Cstruct.of_bytes b)) ;
|
||||
Ok (hdr, Bytes.to_string b)
|
||||
else
|
||||
Lwt.return (Ok (hdr, ""))
|
||||
r buf 0 8 >>= function
|
||||
| Error e -> Lwt.return (Error e)
|
||||
| Ok () ->
|
||||
match Vmm_wire.parse_header (Bytes.to_string buf) with
|
||||
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
|
||||
| Ok hdr ->
|
||||
let l = hdr.Vmm_wire.length in
|
||||
if l > 0 then
|
||||
let b = Bytes.create l in
|
||||
r b 0 l >|= function
|
||||
| Error e -> Error e
|
||||
| Ok () ->
|
||||
Logs.debug (fun m -> m "read hdr %a, body %a"
|
||||
Cstruct.hexdump_pp (Cstruct.of_bytes buf)
|
||||
Cstruct.hexdump_pp (Cstruct.of_bytes b)) ;
|
||||
Ok (hdr, Bytes.to_string b)
|
||||
else
|
||||
Lwt.return (Ok (hdr, ""))
|
||||
|
||||
let write_raw s buf =
|
||||
let buf = Bytes.unsafe_of_string buf in
|
||||
let rec w off l =
|
||||
Lwt_unix.send s buf off l [] >>= fun n ->
|
||||
if n = l then
|
||||
Lwt.return_unit
|
||||
else
|
||||
w (off + n) (l - n)
|
||||
Lwt.catch (fun () ->
|
||||
Lwt_unix.send s buf off l [] >>= fun n ->
|
||||
if n = l then
|
||||
Lwt.return (Ok ())
|
||||
else
|
||||
w (off + n) (l - n))
|
||||
(fun e ->
|
||||
Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string e)) ;
|
||||
Lwt.return (Error `Exception))
|
||||
in
|
||||
Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ;
|
||||
w 0 (Bytes.length buf)
|
||||
|
|
|
@ -6,30 +6,54 @@ let read_tls t =
|
|||
let rec r_n buf off tot =
|
||||
let l = tot - off in
|
||||
if l = 0 then
|
||||
Lwt.return_unit
|
||||
Lwt.return (Ok ())
|
||||
else
|
||||
Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function
|
||||
| 0 -> Lwt.fail_with "read 0 bytes"
|
||||
| x when x == l -> Lwt.return_unit
|
||||
| x when x < l -> r_n buf (off + x) tot
|
||||
| _ -> Lwt.fail_with "overread, will never happen"
|
||||
Lwt.catch (fun () ->
|
||||
Tls_lwt.Unix.read t (Cstruct.shift buf off) >>= function
|
||||
| 0 ->
|
||||
Logs.err (fun m -> m "TLS: end of file") ;
|
||||
Lwt.return (Error `Eof)
|
||||
| x when x == l -> Lwt.return (Ok ())
|
||||
| x when x < l -> r_n buf (off + x) tot
|
||||
| _ ->
|
||||
Logs.err (fun m -> m "TLS: read too much, shouldn't happen") ;
|
||||
Lwt.return (Error `Toomuch))
|
||||
(function
|
||||
| Tls_lwt.Tls_failure a ->
|
||||
Logs.err (fun m -> m "TLS read failure: %s" (Tls.Engine.string_of_failure a)) ;
|
||||
Lwt.return (Error `Exception)
|
||||
| e ->
|
||||
Logs.err (fun m -> m "TLS read exception %s" (Printexc.to_string e)) ;
|
||||
Lwt.return (Error `Exception))
|
||||
in
|
||||
let buf = Cstruct.create 8 in
|
||||
r_n buf 0 8 >>= fun () ->
|
||||
match Vmm_wire.parse_header (Cstruct.to_string buf) with
|
||||
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
|
||||
| Ok hdr ->
|
||||
let l = hdr.Vmm_wire.length in
|
||||
if l > 0 then
|
||||
let b = Cstruct.create l in
|
||||
r_n b 0 l >|= fun () ->
|
||||
Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a"
|
||||
hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag
|
||||
Cstruct.hexdump_pp b) ;
|
||||
Ok (hdr, Cstruct.to_string b)
|
||||
else
|
||||
Lwt.return (Ok (hdr, ""))
|
||||
r_n buf 0 8 >>= function
|
||||
| Error e -> Lwt.return (Error e)
|
||||
| Ok () ->
|
||||
match Vmm_wire.parse_header (Cstruct.to_string buf) with
|
||||
| Error (`Msg m) -> Lwt.return (Error (`Msg m))
|
||||
| Ok hdr ->
|
||||
let l = hdr.Vmm_wire.length in
|
||||
if l > 0 then
|
||||
let b = Cstruct.create l in
|
||||
r_n b 0 l >|= function
|
||||
| Error e -> Error e
|
||||
| Ok () ->
|
||||
Logs.debug (fun m -> m "TLS read id %d %a tag %d data %a"
|
||||
hdr.Vmm_wire.id Vmm_wire.pp_version hdr.Vmm_wire.version hdr.Vmm_wire.tag
|
||||
Cstruct.hexdump_pp b) ;
|
||||
Ok (hdr, Cstruct.to_string b)
|
||||
else
|
||||
Lwt.return (Ok (hdr, ""))
|
||||
|
||||
let write_tls s buf =
|
||||
Logs.debug (fun m -> m "TLS write %a" Cstruct.hexdump_pp (Cstruct.of_string buf)) ;
|
||||
Tls_lwt.Unix.write s (Cstruct.of_string buf)
|
||||
Lwt.catch
|
||||
(fun () -> Tls_lwt.Unix.write s (Cstruct.of_string buf) >|= fun () -> Ok ())
|
||||
(function
|
||||
| Tls_lwt.Tls_failure a ->
|
||||
Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ;
|
||||
Lwt.return (Error `Exception)
|
||||
| e ->
|
||||
Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ;
|
||||
Lwt.return (Error `Exception))
|
||||
|
|
|
@ -26,15 +26,18 @@ let handle s addr () =
|
|||
let rec loop () =
|
||||
Vmm_lwt.read_exactly s >>= function
|
||||
| Error (`Msg msg) -> Logs.err (fun m -> m "error while reading %s" msg) ; loop ()
|
||||
| Error _ -> Logs.err (fun m -> m "exception while reading") ; Lwt.return_unit
|
||||
| Ok (hdr, data) ->
|
||||
Logs.debug (fun m -> m "received %a" Cstruct.hexdump_pp (Cstruct.of_string data)) ;
|
||||
let t', out = Vmm_stats.handle !t hdr data in
|
||||
t := t' ;
|
||||
Logs.debug (fun m -> m "sent %a" Cstruct.hexdump_pp (Cstruct.of_string out)) ;
|
||||
Vmm_lwt.write_raw s out >>= fun () ->
|
||||
loop ()
|
||||
Vmm_lwt.write_raw s out >>= function
|
||||
| Ok () -> loop ()
|
||||
| Error _ -> Logs.err (fun m -> m "exception while writing") ; Lwt.return_unit
|
||||
in
|
||||
loop ()
|
||||
loop () >>= fun () ->
|
||||
Lwt.catch (fun () -> Lwt_unix.close s) (fun _ -> Lwt.return_unit)
|
||||
|
||||
let rec timer () =
|
||||
t := Vmm_stats.tick !t ;
|
||||
|
|
Loading…
Reference in a new issue