albatross/src/vmm_lwt.ml

172 lines
5.9 KiB
OCaml

(* (c) 2017 Hannes Mehnert, all rights reserved *)
open Lwt.Infix
let pp_sockaddr ppf = function
| Lwt_unix.ADDR_UNIX str -> Fmt.pf ppf "unix domain socket %s" str
| Lwt_unix.ADDR_INET (addr, port) -> Fmt.pf ppf "TCP %s:%d"
(Unix.string_of_inet_addr addr) port
let safe_close fd =
Lwt.catch
(fun () -> Lwt_unix.close fd)
(fun _ -> Lwt.return_unit)
let server_socket systemd sock =
if systemd
then match Vmm_unix.sd_listen_fds () with
| Some [fd] -> Lwt.return (Lwt_unix.of_unix_file_descr fd)
| _ -> failwith "Systemd socket activation error" (* FIXME *)
else
let name = Vmm_core.socket_path sock in
(Lwt_unix.file_exists name >>= function
| true -> Lwt_unix.unlink name
| false -> Lwt.return_unit) >>= fun () ->
let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec s ;
let old_umask = Unix.umask 0 in
let _ = Unix.umask (old_umask land 0o707) in
Lwt_unix.(bind s (ADDR_UNIX name)) >|= fun () ->
Logs.app (fun m -> m "listening on %s" name);
let _ = Unix.umask old_umask in
Lwt_unix.listen s 1 ;
s
let connect addrtype sockaddr =
let c = Lwt_unix.(socket addrtype SOCK_STREAM 0) in
Lwt_unix.set_close_on_exec c ;
Lwt.catch (fun () ->
Lwt_unix.(connect c sockaddr) >|= fun () ->
Some c)
(fun e ->
Logs.warn (fun m -> m "error %s connecting to socket %a"
(Printexc.to_string e) pp_sockaddr sockaddr);
safe_close c >|= fun () ->
None)
let pp_process_status ppf = function
| Unix.WEXITED c -> Fmt.pf ppf "exited with %d" c
| Unix.WSIGNALED s -> Fmt.pf ppf "killed by signal %a" Fmt.Dump.signal s
| Unix.WSTOPPED s -> Fmt.pf ppf "stopped by signal %a" Fmt.Dump.signal s
let ret = function
| Unix.WEXITED c -> `Exit c
| Unix.WSIGNALED s -> `Signal s
| Unix.WSTOPPED s -> `Stop s
let rec waitpid pid =
Lwt.catch
(fun () -> Lwt_unix.waitpid [] pid >|= fun r -> Ok r)
(function
| Unix.(Unix_error (EINTR, _, _)) ->
Logs.debug (fun m -> m "EINTR in waitpid(), %d retrying" pid) ;
waitpid pid
| e ->
Logs.err (fun m -> m "error %s in waitpid() %d"
(Printexc.to_string e) pid) ;
Lwt.return (Error ()))
let wait_and_clear pid =
Logs.debug (fun m -> m "waitpid() for pid %d" pid) ;
waitpid pid >|= fun r ->
match r with
| Error () ->
Logs.err (fun m -> m "waitpid() for %d returned error" pid) ;
`Exit 23
| Ok (_, s) ->
Logs.debug (fun m -> m "pid %d exited: %a" pid pp_process_status s) ;
ret s
let read_wire s =
let buf = Bytes.create 4 in
let rec r b i l =
Lwt.catch (fun () ->
Lwt_unix.read s b i l >>= function
| 0 ->
Logs.debug (fun m -> m "end of file while reading") ;
Lwt.return (Error `Eof)
| n when n == l -> Lwt.return (Ok ())
| n when n < l -> r b (i + n) (l - n)
| _ ->
Logs.err (fun m -> m "read too much, shouldn't happen)") ;
Lwt.return (Error `Toomuch))
(fun e ->
let err = Printexc.to_string e in
Logs.err (fun m -> m "exception %s while reading" err) ;
safe_close s >|= fun () ->
Error `Exception)
in
r buf 0 4 >>= function
| Error e -> Lwt.return (Error e)
| Ok () ->
let len = Cstruct.BE.get_uint32 (Cstruct.of_bytes buf) 0 in
if len > 0l then begin
let b = Bytes.create (Int32.to_int len) in
r b 0 (Int32.to_int len) >|= function
| Error e -> Error e
| Ok () ->
(* Logs.debug (fun m -> m "read hdr %a, body %a"
Cstruct.hexdump_pp (Cstruct.of_bytes buf)
Cstruct.hexdump_pp (Cstruct.of_bytes b)) ; *)
match Vmm_asn.wire_of_cstruct (Cstruct.of_bytes b) with
| Error (`Msg msg) ->
Logs.err (fun m -> m "error %s while parsing data" msg) ;
Error `Exception
| (Ok (hdr, _)) as w ->
if not Vmm_commands.(is_current hdr.version) then
Logs.warn (fun m -> m "version mismatch, received %a current %a"
Vmm_commands.pp_version hdr.Vmm_commands.version
Vmm_commands.pp_version Vmm_commands.current);
w
end else begin
Lwt.return (Error `Eof)
end
let write_raw s buf =
let rec w off l =
Lwt.catch (fun () ->
Lwt_unix.send s buf off l [] >>= fun n ->
if n = l then
Lwt.return (Ok ())
else
w (off + n) (l - n))
(fun e ->
Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string e)) ;
safe_close s >|= fun () ->
Error `Exception)
in
(* Logs.debug (fun m -> m "writing %a" Cstruct.hexdump_pp (Cstruct.of_bytes buf)) ; *)
w 0 (Bytes.length buf)
let write_wire s wire =
let data = Vmm_asn.wire_to_cstruct wire in
let dlen = Cstruct.create 4 in
Cstruct.BE.set_uint32 dlen 0 (Int32.of_int (Cstruct.len data)) ;
let buf = Cstruct.(to_bytes (append dlen data)) in
write_raw s buf
let read_from_file file =
Lwt.catch (fun () ->
Lwt_unix.stat file >>= fun stat ->
let size = stat.Lwt_unix.st_size in
Lwt_unix.openfile file Lwt_unix.[O_RDONLY] 0 >>= fun fd ->
Lwt.catch (fun () ->
let buf = Bytes.create size in
let rec read off =
Lwt_unix.read fd buf off (size - off) >>= fun bytes ->
if bytes + off = size then
Lwt.return_unit
else
read (bytes + off)
in
read 0 >>= fun () ->
safe_close fd >|= fun () ->
Cstruct.of_bytes buf)
(fun e ->
Logs.err (fun m -> m "exception %s while reading %s" (Printexc.to_string e) file) ;
safe_close fd >|= fun () ->
Cstruct.empty))
(fun e ->
Logs.err (fun m -> m "exception %s while reading %s" (Printexc.to_string e) file) ;
Lwt.return Cstruct.empty)