""" Simple, robust SSH client(s) for basic I/O with remote hosts
"""
import asyncio
import logging
import os
import pty
import stat
import subprocess
import tempfile
import typing
from contextlib import contextmanager
import retrying
from dcos_test_utils import helpers
log = logging.getLogger(__name__)
SHARED_SSH_OPTS = [
'-oConnectTimeout=10',
'-oStrictHostKeyChecking=no',
'-oUserKnownHostsFile=/dev/null',
'-oLogLevel=ERROR',
'-oBatchMode=yes',
'-oPasswordAuthentication=no']
[docs]class Tunnelled():
""" Abstraction of an already instantiated SSH-tunnel
Args:
opt_list: list of SSH options strings. E.G. '-oControlPath=foo'
target: string in the form user@host
port: port number to be used for SSH or SCP
"""
def __init__(self, opt_list: list, target: str, port: int):
self.opt_list = opt_list
self.target = target
self.port = port
[docs] def command(self, cmd: list, **kwargs) -> bytes:
""" Run a command at the tunnel target
Args:
cmd: list of strings that will be sent as a command to the target
**kwargs: any keywork args that can be passed into
subprocess.check_output. For more information, see:
https://docs.python.org/3/library/subprocess.html#subprocess.check_output
"""
run_cmd = ['ssh', '-p', str(self.port)] + self.opt_list + [self.target] + cmd
log.debug('Running socket cmd: ' + ' '.join(run_cmd))
if 'stdout' in kwargs:
return subprocess.run(run_cmd, **kwargs, check=True, env={"PATH": os.environ["PATH"]})
else:
return subprocess.run(run_cmd, **kwargs, check=True, env={"PATH": os.environ["PATH"]},
stdout=subprocess.PIPE).stdout
[docs] def copy_file(self, src: str, dst: str, to_remote=True) -> None:
""" Copy a path from localhost to target. If path is a local directory, then
recursive copy will be used.
Args:
src: local or remote representing source data
dst: local or remote destination path
to_remote: Whether copying from remote->local or local->remote
"""
copy_command = []
if to_remote:
if os.path.isdir(src):
copy_command.append('-r')
remote_full_path = '{}:{}'.format(self.target, dst)
copy_command += [src, remote_full_path]
else:
remote_full_path = '{}:{}'.format(self.target, src)
copy_command += [remote_full_path, dst]
cmd = ['scp'] + self.opt_list + ['-P', str(self.port)] + copy_command
log.debug('Copying {} to {}'.format(*copy_command[-2:]))
log.debug('scp command: {}'.format(cmd))
subprocess.run(cmd, check=True, env={"PATH": os.environ["PATH"]})
[docs]def temp_ssh_key(key: str) -> str:
""" Dumps an SSH key string to a temp file that will be deleted at session close and returns the path
"""
key_path = helpers.session_tempfile(key)
os.chmod(str(key_path), stat.S_IREAD | stat.S_IWRITE)
return key_path
[docs]@contextmanager
def open_tunnel(
user: str,
host: str,
port: int,
control_path: str,
key_path: str) -> Tunnelled:
""" Provides clean setup/tear down for an SSH tunnel
Args:
user: SSH user
key_path: path to a private SSH key
host: string containing target host
port: target's SSH port
"""
target = user + '@' + host
opt_list = SHARED_SSH_OPTS + [
'-oControlPath=' + control_path,
'-oControlMaster=auto']
base_cmd = ['ssh', '-p', str(port)] + opt_list
start_tunnel = base_cmd + ['-fnN', '-i', key_path, target]
log.debug('Starting SSH tunnel: ' + ' '.join(start_tunnel))
subprocess.run(start_tunnel, check=True, env={"PATH": os.environ["PATH"]})
log.debug('SSH Tunnel established!')
yield Tunnelled(opt_list, target, port)
close_tunnel = base_cmd + ['-O', 'exit', target]
log.debug('Closing SSH Tunnel: ' + ' '.join(close_tunnel))
# after we are done using the tunnel, we do not care about its output
subprocess.run(close_tunnel, check=True, env={"PATH": os.environ["PATH"]}, stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL)
[docs]class SshClient:
""" class for binding SSH user and key to tunnel
:param user: SSH user to connect with
:type user: str
:param key: SSH private key for user to connect with
:type key: str
"""
def __init__(self, user: str, key: str):
self.user = user
self.key = key
self.key_path = temp_ssh_key(key)
[docs] def tunnel(self, host: str, port: int=22) -> typing.Generator[Tunnelled, None, None]:
""" wrapper for the :func:`open_tunnel` context manager
:param host: host IP to open the tunnel to
:type host: str
:param port: SSH port of the host (defaults to 22)
:type port: int
"""
with tempfile.NamedTemporaryFile() as f:
return open_tunnel(self.user, host, port, f.name, self.key_path)
[docs] def command(self, host: str, cmd: list, port: int=22, **kwargs) -> bytes:
""" Opens a tunnel and runs a single command
:param host: host IP to open the tunnel to
:type host: str
:param cmd: list of shell args to run on the host
:type cmd: list
:param port: SSH port of the host (defaults to 22)
:type port: int
:param kwargs: see args used in :func:`Tunnelled.command`
"""
with self.tunnel(host, port) as t:
return t.command(cmd, **kwargs)
[docs] def get_home_dir(self, host: str, port: int=22) -> str:
""" Returns the SSH home dir
:param host: host IP to get the home directory from
:type host: str
:param port: SSH port of the host (defaults to 22)
:type port: int
"""
return self.command(host, ['pwd'], port=port).decode().strip()
[docs] @retrying.retry(wait_fixed=1000, stop_max_attempt_number=600)
def wait_for_ssh_connection(self, host: str, port: int=22) -> None:
""" Blocks until SSH connection can be established
:param host: host IP to wait for connection to
:type host: str
:param port: SSH port of the host (defaults to 22)
:type port: int
"""
self.get_home_dir(host, port)
[docs] def add_ssh_user_to_docker_users(self, host: str, port: int=22):
""" Runs user mod on remote host to add this user to docker users
:param host: host to add usergroup memership too
:type host: str
:param port: SSH port of the host (defaults to 22)
:type port: int
"""
self.command(host, ['sudo', 'usermod', '-aG', 'docker', self.user], port=port)
@contextmanager
def _make_slave_pty():
master_pty, slave_pty = pty.openpty()
yield slave_pty
os.close(slave_pty)
os.close(master_pty)
[docs]def parse_ip(ip: str) -> (str, int):
""" takes an IP string and either a hostname and either the given port or
the default ssh port of 22
"""
tmp = ip.split(':')
if len(tmp) == 2:
return tmp[0], int(tmp[1])
elif len(tmp) == 1:
# no port, assume default SSH
return ip, 22
else:
raise ValueError(
"Expected a string of form <ip> or <ip>:<port> but found a string with more than one " +
"colon in it. NOTE: IPv6 is not supported at this time. Got: {}".format(ip))
[docs]class AsyncSshClient(SshClient):
""" SshClient for running against a set of hosts in parallel
Args:
user: ssh user name
key: ssh private key contents
targets: list of host strings for SSH use (hostname:optional_port)
process_timeout (optional): how many seconds any given process can run for
parallelism (optional): how many processes to run at the same time. Rarely is
a SSH command CPU bound, so this number can be greater than CPU concurrency
"""
def __init__(
self,
user: str,
key: str,
targets: list,
process_timeout=120,
parallelism=10):
super().__init__(user, key)
self.process_timeout = process_timeout
self.__targets = targets
self.__parallelism = parallelism
async def _run_cmd_return_dict_async(self, cmd: list) -> dict:
""" Runs an arbitrary command as an asynchronous subprocess
Args:
cmd: list or argument to initialize the process
Returns:
dict of the command args, output, returncode, and pid
"""
log.debug('Starting command: {}'.format(str(cmd)))
with _make_slave_pty() as slave_pty:
process = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=slave_pty,
env={'TERM': 'linux'})
stdout = b''
stderr = b''
try:
stdout, stderr = await asyncio.wait_for(process.communicate(), self.process_timeout)
except asyncio.TimeoutError:
try:
process.terminate()
except ProcessLookupError:
log.info('process with pid {} not found'.format(process.pid))
log.error('timeout of {} sec reached. PID {} killed'.format(self.process_timeout, process.pid))
return {
"cmd": cmd,
"stdout": stdout,
"stderr": stderr,
"returncode": process.returncode,
"pid": process.pid
}
[docs] async def run(self, sem: asyncio.Semaphore, host: str, cmd: list) -> dict:
""" Uses SSH tunnel to run a command against a host
Args:
sem: semaphore for concurrency control
host: host string to run copy to
cmd: argument list to be executed on the remote host
Returns:
command result dict (see _run_cmd_return_dict_async)
"""
hostname, port = parse_ip(host)
async with sem:
log.debug('Starting run command on {}'.format(host))
with self.tunnel(hostname, port) as t:
full_cmd = ['ssh', '-p', str(t.port)] + t.opt_list + [t.target] + cmd
result = await self._run_cmd_return_dict_async(full_cmd)
result['host'] = host
return result
[docs] async def copy(
self,
sem: asyncio.Semaphore,
host: str,
local_path: str,
remote_path: str,
recursive: bool) -> dict:
""" uses SCP to copy files to remote host
Args:
sem: semaphore for concurrency control
host: host string to run copy to
local_path: path that will be copied
remote_path: where the data will be copied to
recursive: if True, recursive SCP the local_path to remote_path
Returns:
command result dict (see _run_cmd_return_dict_async)
"""
async with sem:
log.debug('Starting copy command on {}'.format(host))
hostname, port = parse_ip(host)
copy_command = []
if recursive:
copy_command.append('-r')
remote_full_path = '{}@{}:{}'.format(self.user, hostname, remote_path)
copy_command += [local_path, remote_full_path]
full_cmd = ['scp'] + SHARED_SSH_OPTS + ['-P', str(port), '-i', self.key_path] + copy_command
log.debug('copy with command {}'.format(full_cmd))
result = await self._run_cmd_return_dict_async(full_cmd)
result['host'] = host
return result
[docs] async def run_command_on_hosts(self, coroutine_name: str, *args, sem: asyncio.Semaphore=None) -> list:
""" Starts and waits upon tasks running across all hosts
Args:
coroutine_name: either 'copy' or 'run'
*args: arg list to be passed to copy or run
sem (optional): semaphore for controlling concurrency. If not supplied, a semaphore
of the default parallelism will be created
Returns:
list of result dicts from _run_cmd_return_dict_async
"""
if not sem:
sem = asyncio.Semaphore(self.__parallelism)
tasks = self.start_command_on_hosts(sem, coroutine_name, *args)
log.debug('Waiting for asynchonrous processes to finish')
await asyncio.wait(tasks)
return [task.result() for task in tasks]
[docs] def start_command_on_hosts(self, sem: asyncio.Semaphore, coroutine_name: str, *args) -> list:
""" Starts coroutines against all hosts and returns futures
Args:
sem: semaphore for blocking job creation to control concurrency
coroutine_name: either 'copy' or 'run'
*args: args to be passed to copy or run
Returns:
list of futures of the commands that were started
"""
log.debug('Starting {} with {} to execute on all hosts'.format(coroutine_name, str(args)))
tasks = []
for host in self.__targets:
log.debug('Starting {} on {}'.format(coroutine_name, host))
tasks.append(asyncio.ensure_future(getattr(self, coroutine_name)(sem, host, *args)))
return tasks
[docs] def run_command(self, coroutine_name: str, *args) -> list:
""" Runs a _run_command_on_hosts in an async loop
Args:
coroutine_name: either 'copy' or 'run'
*args: args to pass to copy or run
Returns:
list of result dicts
"""
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
results = loop.run_until_complete(
self.run_command_on_hosts(coroutine_name, *args))
finally:
loop.close()
return results