diff --git a/app/vmmd_tls.ml b/app/vmmd_tls.ml index 49f8b8c..2d5cc65 100644 --- a/app/vmmd_tls.ml +++ b/app/vmmd_tls.ml @@ -55,6 +55,7 @@ let process fd tls = Vmm_lwt.read_wire fd >>= function | Error _ -> Lwt.return (Error (`Msg "read error")) | Ok wire -> + (* TODO check version *) Logs.debug (fun m -> m "proxying %a" Vmm_commands.pp_wire wire) ; Vmm_tls_lwt.write_tls tls wire >|= function | Ok () -> Ok () @@ -66,26 +67,59 @@ let handle ca (tls, addr) = | Error (`Msg m) -> Vmm_tls_lwt.close tls >>= fun () -> Lwt.fail_with m - | Ok (name, cmd) -> + | Ok (name, policies, cmd) -> let sock, next = Vmm_commands.endpoint cmd in connect (Vmm_core.socket_path sock) >>= fun fd -> - let wire = - let header = Vmm_commands.{version = my_version ; sequence = !command ; id = name } in - command := Int64.succ !command ; - (header, `Command cmd) - in - Vmm_lwt.write_wire fd wire >>= function - | Error `Exception -> - Vmm_tls_lwt.close tls >>= fun () -> - Vmm_lwt.safe_close fd >>= fun () -> - Lwt.return (Error (`Msg "couldn't write")) + (match sock with + | `Vmmd -> + Lwt_list.fold_left_s (fun r (id, policy) -> + match r with + | Error (`Msg msg) -> Lwt.return (Error (`Msg msg)) + | Ok () -> + Logs.debug (fun m -> m "adding policy for %a: %a" Vmm_core.pp_id id Vmm_core.pp_policy policy) ; + let header = Vmm_commands.{version = my_version ; sequence = !command ; id } in + command := Int64.succ !command ; + Vmm_lwt.write_wire fd (header, `Command (`Policy_cmd (`Policy_add policy))) >>= function + | Error `Exception -> Lwt.return (Error (`Msg "failed to write policy")) + | Ok () -> + Vmm_lwt.read_wire fd >|= function + | Error _ -> Error (`Msg "read error") + | Ok (_, `Success _) -> Ok () + | Ok _ -> + (* TODO check version *) + Error (`Msg ("expected success, received something else when adding policy"))) + (Ok ()) policies + | _ -> Lwt.return (Ok ())) >>= function + | Error (`Msg msg) -> + begin + Logs.debug (fun m -> m "error while applying policies %s" msg) ; + let wire = + let header = Vmm_commands.{version = my_version ; sequence = 0L ; id = name } in + header, `Failure msg + in + Vmm_tls_lwt.write_tls tls wire >>= fun _ -> + Vmm_tls_lwt.close tls >>= fun () -> + Vmm_lwt.safe_close fd >>= fun () -> + Lwt.fail_with msg + end | Ok () -> - (match next with - | `Read -> read fd tls - | `End -> process fd tls) >>= fun res -> - Vmm_tls_lwt.close tls >>= fun () -> - Vmm_lwt.safe_close fd >|= fun () -> - res + let wire = + let header = Vmm_commands.{version = my_version ; sequence = !command ; id = name } in + command := Int64.succ !command ; + (header, `Command cmd) + in + Vmm_lwt.write_wire fd wire >>= function + | Error `Exception -> + Vmm_tls_lwt.close tls >>= fun () -> + Vmm_lwt.safe_close fd >>= fun () -> + Lwt.return (Error (`Msg "couldn't write")) + | Ok () -> + (match next with + | `Read -> read fd tls + | `End -> process fd tls) >>= fun res -> + Vmm_tls_lwt.close tls >>= fun () -> + Vmm_lwt.safe_close fd >|= fun () -> + res let server_socket port = let open Lwt_unix in diff --git a/src/vmm_tls.ml b/src/vmm_tls.ml index df68224..864d010 100644 --- a/src/vmm_tls.ml +++ b/src/vmm_tls.ml @@ -4,17 +4,22 @@ open Rresult open Rresult.R.Infix (* we skip all non-albatross certificates *) +let cert_name cert = + match X509.Extension.unsupported cert Vmm_asn.oid with + | None -> None + | Some _ -> + let data = X509.common_name_to_string cert in + (* if the common name is empty, skip [useful for vmmc_bistro at least] + TODO: document properly and investigate potential security issue with + multi-tenant system (likely ca should ensure to never sign a delegation + with empty common name) *) + if data = "" then None else Some data + let name chain = List.fold_left (fun acc cert -> - match X509.Extension.unsupported cert Vmm_asn.oid with + match cert_name cert with | None -> acc - | Some _ -> - let data = X509.common_name_to_string cert in - (* if the common name is empty, skip [useful for vmmc_bistro at least] - TODO: document properly and investigate potential security issue with - multi-tenant system (likely ca should ensure to never sign a delegation - with empty common name) *) - if data = "" then acc else data :: acc) + | Some data -> data :: acc) [] chain (* this separates the leaf and top-level certificate from the chain, @@ -27,15 +32,15 @@ let separate_chain = function let wire_command_of_cert version cert = match X509.Extension.unsupported cert Vmm_asn.oid with - | None -> R.error_msgf "albatross OID is not present in certificate (%a)" Asn.OID.pp Vmm_asn.oid + | None -> Error `Not_present | Some (_, data) -> - Vmm_asn.cert_extension_of_cstruct data >>= fun (v, wire) -> - if not (Vmm_commands.version_eq v version) then - R.error_msgf "unexpected version %a (expected %a)" - Vmm_commands.pp_version v - Vmm_commands.pp_version version - else - Ok wire + match Vmm_asn.cert_extension_of_cstruct data with + | Error (`Msg p) -> Error (`Parse p) + | Ok (v, wire) -> + if not (Vmm_commands.version_eq v version) then + Error (`Version v) + else + Ok wire (* let check_policy = (* get names and static resources *) @@ -50,6 +55,26 @@ let wire_command_of_cert version cert = check_policies vm_config (List.map snd policies) >>= fun () -> *) +let extract_policies version chain = + List.fold_left (fun acc cert -> + match acc, wire_command_of_cert version cert with + | Error e, _ -> Error e + | Ok acc, Error `Not_present -> Ok acc + | Ok _, Error (`Parse msg) -> Error (`Msg msg) + | Ok _, Error (`Version received) -> + R.error_msgf "unexpected version %a (expected %a)" + Vmm_commands.pp_version received + Vmm_commands.pp_version version + | Ok (prefix, acc), Ok (`Policy_cmd (`Policy_add p)) -> + let name = match cert_name cert with + | None -> prefix + | Some x -> x :: prefix + in + Ok (name, (name, p) :: acc) + | _, Ok wire -> + R.error_msgf "unexpected wire %a" Vmm_commands.pp wire) + (Ok ([], [])) chain + let handle _addr version chain = separate_chain chain >>= fun (leaf, rest) -> let name = name chain in @@ -57,15 +82,22 @@ let handle _addr version chain = (X509.common_name_to_string leaf) Fmt.(list ~sep:(unit " -> ") string) (List.map X509.common_name_to_string rest)) ; - (* TODO: inspect top-level-cert of chain. *) + extract_policies version rest >>= fun (_, policies) -> (* TODO: logging let login_hdr, login_ev = Log.hdr name, `Login addr in *) - (* TODO: update policies (parse chain for policy, and apply them)! *) - wire_command_of_cert version leaf >>= fun wire -> - (* we only allow some commands via certificate *) - match wire with - | `Console_cmd (`Console_subscribe _) - | `Stats_cmd `Stats_subscribe - | `Log_cmd (`Log_subscribe _) - | `Vm_cmd _ - | `Policy_cmd _ -> Ok (name, wire) (* TODO policy_cmd is special (via delegation chain) *) - | _ -> Error (`Msg "unexpected command") + match wire_command_of_cert version leaf with + | Error (`Parse p) -> Error (`Msg p) + | Error (`Not_present) -> + Error (`Msg "leaf certificate does not contain an albatross extension") + | Error (`Version received) -> + R.error_msgf "unexpected version %a (expected %a)" + Vmm_commands.pp_version received + Vmm_commands.pp_version version + | Ok wire -> + (* we only allow some commands via certificate *) + match wire with + | `Console_cmd (`Console_subscribe _) + | `Stats_cmd `Stats_subscribe + | `Log_cmd (`Log_subscribe _) + | `Vm_cmd _ + | `Policy_cmd `Policy_info -> Ok (name, policies, wire) + | _ -> Error (`Msg "unexpected command") diff --git a/src/vmm_tls.mli b/src/vmm_tls.mli index 6505d41..61b5674 100644 --- a/src/vmm_tls.mli +++ b/src/vmm_tls.mli @@ -1,9 +1,10 @@ (* (c) 2018 Hannes Mehnert, all rights reserved *) val wire_command_of_cert : Vmm_commands.version -> X509.t -> - (Vmm_commands.t, [> `Msg of string ]) result + (Vmm_commands.t, [> `Parse of string | `Not_present | `Version of Vmm_commands.version ]) result val handle : 'a -> Vmm_commands.version -> X509.t list -> - (string list * Vmm_commands.t, [> `Msg of string ]) Result.result + (string list * (Vmm_core.id * Vmm_core.policy) list * Vmm_commands.t, + [> `Msg of string ]) Result.result