deploy_nixos: improve host key check

This commit is contained in:
Jörg Thalheim 2022-02-04 09:28:09 +01:00
parent 6d32ac1baa
commit a585a0af44

View file

@ -2,12 +2,13 @@
import os
from contextlib import contextmanager
from typing import List, Dict, Tuple, IO, Iterator, Optional, Callable, Any, Union
from typing import List, Dict, Tuple, IO, Iterator, Optional, Callable, Any, Union, Text
from threading import Thread
import subprocess
import fcntl
import select
from contextlib import ExitStack
from enum import Enum
@contextmanager
@ -29,6 +30,13 @@ def pipe() -> Iterator[Tuple[IO[str], IO[str]]]:
FILE = Union[None, int]
class HostKeyCheck(Enum):
STRICT = 0
# trust-on-first-use
TOFU = 1
NONE = 2
class DeployHost:
def __init__(
self,
@ -37,6 +45,7 @@ class DeployHost:
port: int = 22,
forward_agent: bool = False,
command_prefix: Optional[str] = None,
host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
meta: Dict[str, Any] = {},
) -> None:
self.host = host
@ -47,6 +56,7 @@ class DeployHost:
else:
self.command_prefix = host
self.forward_agent = forward_agent
self.host_key_check = host_key_check
self.meta = meta
def _prefix_output(
@ -91,8 +101,8 @@ class DeployHost:
return stdout_buf, stderr_buf
def _run(
self, cmd: str, shell: bool, stdout: FILE = None, stderr: FILE = None
) -> subprocess.CompletedProcess:
self, cmd: List[str], shell: bool, stdout: FILE = None, stderr: FILE = None
) -> subprocess.CompletedProcess[Text]:
with ExitStack() as stack:
if stdout is None or stderr is None:
read_fd, write_fd = stack.enter_context(pipe())
@ -117,32 +127,39 @@ class DeployHost:
stdout_write.close()
if stderr == subprocess.PIPE:
stderr_write.close()
stdout, stderr = self._prefix_output(read_fd, stdout_read, stderr_read)
stdout_data, stderr_data = self._prefix_output(read_fd, stdout_read, stderr_read)
ret = p.wait()
return subprocess.CompletedProcess(
cmd, ret, stdout=stdout, stderr=stderr
cmd, ret, stdout=stdout_data, stderr=stderr_data
)
raise RuntimeError("unreachable")
def run_local(
self, cmd: str, stdout: FILE = None, stderr: FILE = None
) -> subprocess.CompletedProcess:
print(f"[{self.command_prefix}] {cmd}")
return self._run(cmd, shell=True, stdout=stdout, stderr=stderr)
return self._run([cmd], shell=True, stdout=stdout, stderr=stderr)
def run(
self, cmd: str, stdout: FILE = None, stderr: FILE = None
) -> subprocess.CompletedProcess:
print(f"[{self.command_prefix}] {cmd}")
ssh_opts = ["-A"] if self.forward_agent else []
cmd = (
if self.host_key_check != HostKeyCheck.STRICT:
ssh_opts.extend(["-o", "StrictHostKeyChecking=no"])
if self.host_key_check == HostKeyCheck.NONE:
ssh_opts.extend(["-o", "UserKnownHostsFile=/dev/null"])
ssh_cmd = (
["ssh", f"{self.user}@{self.host}", "-p", str(self.port)]
+ ssh_opts
+ ["--", cmd]
)
return self._run(cmd, shell=False, stdout=stdout, stderr=stderr)
return self._run(ssh_cmd, shell=False, stdout=stdout, stderr=stderr)
DeployResults = List[Tuple[DeployHost, int]]
DeployResults = List[Tuple[DeployHost, subprocess.CompletedProcess[Text]]]
class DeployGroup:
@ -212,3 +229,25 @@ class DeployGroup:
for thread in threads:
thread.join()
def parse_hosts(
hosts: str,
host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
forward_agent: bool = False,
) -> DeployGroup:
deploy_hosts = []
for h in hosts.split(","):
parts = h.split("@")
if len(parts) > 1:
user = parts[0]
hostname = parts[1]
else:
user = "root"
hostname = parts[0]
deploy_hosts.append(
DeployHost(
hostname, user=user, host_key_check=host_key_check, forward_agent=forward_agent
)
)
return DeployGroup(deploy_hosts)