diff --git a/sshim_patch.py b/sshim_patch.py index e9ed9a9..3613a44 100644 --- a/sshim_patch.py +++ b/sshim_patch.py @@ -12,9 +12,16 @@ import inspect logger = logging.getLogger(__name__) +def check_channel_request(self, kind, channel_id): + logger.debug(f"Client requested {kind}") + if kind in ('session', 'sftp'): + return paramiko.OPEN_SUCCEEDED + return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + + def check_channel_shell_request(self, channel): logger.debug("Check channel shell request: %s" % channel.get_id()) - Runner(self, channel).start() + self.runner.set_shell_channel(channel) return True @@ -28,6 +35,8 @@ def check_auth_none(self, username): def check_auth_password(self, username, password): logger.debug(f"{username} just tried to connect") if username == os.environ["SSH_USERNAME"] and password == os.environ["SSH_PASSWORD"]: + self.runner = Runner(self, self.transport) + self.runner.start() return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -37,15 +46,15 @@ def check_auth_publickey(self, username, key): class Runner(threading.Thread): - def __init__(self, client, channel: paramiko.Channel): - threading.Thread.__init__(self, name='sshim.Runner(%s)' % channel.chanid) + def __init__(self, client, transport: paramiko.Transport): + threading.Thread.__init__(self, name='sshim.Runner') self.instance_name = "instance-" + str(uuid.uuid4()) self.instance_password = str(uuid.uuid4()) # TODO: secure password generation self.daemon = True self.client = client - self.channel = channel - self.channel.settimeout(None) - self.transport = None + self.transport = transport + self.shell_channel = None + self.sftp_channel = None def run(self) -> None: vm_ip = lxd_interface.create_instance(self.instance_name, self.instance_password)['address'] @@ -53,27 +62,51 @@ class Runner(threading.Thread): with paramiko.SSHClient() as ssh_client: ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) ssh_client.connect(vm_ip, username='root', password=self.instance_password) - self.transport = ssh_client.get_transport() - client_channel = ssh_client.invoke_shell() + client_shell_channel = ssh_client.invoke_shell() + client_sftp_channel = ssh_client.open_sftp().get_channel() while True: - r, w, e = select.select([client_channel, self.channel], [], []) - if self.channel in r: - x = self.channel.recv(1024) - if len(x) == 0: - break - client_channel.send(x) + if self.shell_channel is not None: + r, w, e = select.select([client_shell_channel, self.shell_channel], [], []) + if self.shell_channel in r: + x = self.shell_channel.recv(1024) + if len(x) == 0: + break + client_shell_channel.send(x) + if client_shell_channel in r: + x = client_shell_channel.recv(1024) + if len(x) == 0: + break + self.shell_channel.send(x) - if client_channel in r: - x = client_channel.recv(1024) - if len(x) == 0: - break - self.channel.send(x) + if self.sftp_channel is not None: # TODO: move this to function + r, w, e = select.select([client_sftp_channel, self.sftp_channel], [], []) + if self.sftp_channel in r: + x = self.sftp_channel.recv(1024) + if len(x) == 0: + break + client_sftp_channel.send(x) + if client_sftp_channel in r: + x = client_sftp_channel.recv(1024) + if len(x) == 0: + break + self.sftp_channel.send(x) - client_channel.close() - self.channel.close() + if self.transport.is_active() is False: + break + + client_shell_channel.close() + self.shell_channel.close() lxd_interface.destroy_instance(self.instance_name) + def set_shell_channel(self, channel): + self.shell_channel = channel + self.shell_channel.settimeout(None) + + def set_scp_channel(self, channel): + self.sftp_channel = channel + self.sftp_channel.settimeout(None) + Handler.check_channel_shell_request = check_channel_shell_request Handler.check_auth_none = check_auth_none