deploy_nixos: improve host key check
This commit is contained in:
parent
6d32ac1baa
commit
a585a0af44
1 changed files with 48 additions and 9 deletions
|
@ -2,12 +2,13 @@
|
|||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Dict, Tuple, IO, Iterator, Optional, Callable, Any, Union
|
||||
from typing import List, Dict, Tuple, IO, Iterator, Optional, Callable, Any, Union, Text
|
||||
from threading import Thread
|
||||
import subprocess
|
||||
import fcntl
|
||||
import select
|
||||
from contextlib import ExitStack
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -29,6 +30,13 @@ def pipe() -> Iterator[Tuple[IO[str], IO[str]]]:
|
|||
FILE = Union[None, int]
|
||||
|
||||
|
||||
class HostKeyCheck(Enum):
|
||||
STRICT = 0
|
||||
# trust-on-first-use
|
||||
TOFU = 1
|
||||
NONE = 2
|
||||
|
||||
|
||||
class DeployHost:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -37,6 +45,7 @@ class DeployHost:
|
|||
port: int = 22,
|
||||
forward_agent: bool = False,
|
||||
command_prefix: Optional[str] = None,
|
||||
host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
|
||||
meta: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
self.host = host
|
||||
|
@ -47,6 +56,7 @@ class DeployHost:
|
|||
else:
|
||||
self.command_prefix = host
|
||||
self.forward_agent = forward_agent
|
||||
self.host_key_check = host_key_check
|
||||
self.meta = meta
|
||||
|
||||
def _prefix_output(
|
||||
|
@ -91,8 +101,8 @@ class DeployHost:
|
|||
return stdout_buf, stderr_buf
|
||||
|
||||
def _run(
|
||||
self, cmd: str, shell: bool, stdout: FILE = None, stderr: FILE = None
|
||||
) -> subprocess.CompletedProcess:
|
||||
self, cmd: List[str], shell: bool, stdout: FILE = None, stderr: FILE = None
|
||||
) -> subprocess.CompletedProcess[Text]:
|
||||
with ExitStack() as stack:
|
||||
if stdout is None or stderr is None:
|
||||
read_fd, write_fd = stack.enter_context(pipe())
|
||||
|
@ -117,32 +127,39 @@ class DeployHost:
|
|||
stdout_write.close()
|
||||
if stderr == subprocess.PIPE:
|
||||
stderr_write.close()
|
||||
stdout, stderr = self._prefix_output(read_fd, stdout_read, stderr_read)
|
||||
stdout_data, stderr_data = self._prefix_output(read_fd, stdout_read, stderr_read)
|
||||
ret = p.wait()
|
||||
return subprocess.CompletedProcess(
|
||||
cmd, ret, stdout=stdout, stderr=stderr
|
||||
cmd, ret, stdout=stdout_data, stderr=stderr_data
|
||||
)
|
||||
raise RuntimeError("unreachable")
|
||||
|
||||
def run_local(
|
||||
self, cmd: str, stdout: FILE = None, stderr: FILE = None
|
||||
) -> subprocess.CompletedProcess:
|
||||
print(f"[{self.command_prefix}] {cmd}")
|
||||
return self._run(cmd, shell=True, stdout=stdout, stderr=stderr)
|
||||
return self._run([cmd], shell=True, stdout=stdout, stderr=stderr)
|
||||
|
||||
def run(
|
||||
self, cmd: str, stdout: FILE = None, stderr: FILE = None
|
||||
) -> subprocess.CompletedProcess:
|
||||
print(f"[{self.command_prefix}] {cmd}")
|
||||
ssh_opts = ["-A"] if self.forward_agent else []
|
||||
cmd = (
|
||||
|
||||
if self.host_key_check != HostKeyCheck.STRICT:
|
||||
ssh_opts.extend(["-o", "StrictHostKeyChecking=no"])
|
||||
if self.host_key_check == HostKeyCheck.NONE:
|
||||
ssh_opts.extend(["-o", "UserKnownHostsFile=/dev/null"])
|
||||
|
||||
ssh_cmd = (
|
||||
["ssh", f"{self.user}@{self.host}", "-p", str(self.port)]
|
||||
+ ssh_opts
|
||||
+ ["--", cmd]
|
||||
)
|
||||
return self._run(cmd, shell=False, stdout=stdout, stderr=stderr)
|
||||
return self._run(ssh_cmd, shell=False, stdout=stdout, stderr=stderr)
|
||||
|
||||
|
||||
DeployResults = List[Tuple[DeployHost, int]]
|
||||
DeployResults = List[Tuple[DeployHost, subprocess.CompletedProcess[Text]]]
|
||||
|
||||
|
||||
class DeployGroup:
|
||||
|
@ -212,3 +229,25 @@ class DeployGroup:
|
|||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
def parse_hosts(
|
||||
hosts: str,
|
||||
host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
|
||||
forward_agent: bool = False,
|
||||
) -> DeployGroup:
|
||||
deploy_hosts = []
|
||||
for h in hosts.split(","):
|
||||
parts = h.split("@")
|
||||
if len(parts) > 1:
|
||||
user = parts[0]
|
||||
hostname = parts[1]
|
||||
else:
|
||||
user = "root"
|
||||
hostname = parts[0]
|
||||
deploy_hosts.append(
|
||||
DeployHost(
|
||||
hostname, user=user, host_key_check=host_key_check, forward_agent=forward_agent
|
||||
)
|
||||
)
|
||||
return DeployGroup(deploy_hosts)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue