diff --git a/lxd_interface.py b/lxd_interface.py index ad16d61..6f9c264 100644 --- a/lxd_interface.py +++ b/lxd_interface.py @@ -12,11 +12,10 @@ def create_instance(container_name: str, instance_password: str): instance = lxd_client.instances.create(config, wait=True) instance.start(wait=True) + setup_ssh(container_name, instance_password) while type(ipaddress.ip_address(instance.state().network['eth0']['addresses'][0]['address'])) != ipaddress.IPv4Address: time.sleep(0.1) - setup_ssh(container_name, instance_password) - return instance.state().network['eth0']['addresses'][0] @@ -58,6 +57,10 @@ def setup_ssh(container_name: str, instance_password: str): def get_networking(container_name: str): instance = lxd_client.instances.get(container_name) + while instance.state().network is None or \ + type(ipaddress.ip_address(instance.state().network['eth0']['addresses'][0]['address'])) != ipaddress.IPv4Address: + time.sleep(0.1) + return instance.state().network['eth0']['addresses'][0] diff --git a/main.py b/main.py index a1e612a..be00937 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,8 @@ def connect_handler(script: sshim.Script): pass -server = sshim.Server(connect_handler, address='127.0.0.1', port=3022) +server = sshim.Server(connect_handler, address='0.0.0.0' + '', port=3022) try: server.run() except KeyboardInterrupt: diff --git a/sshim_patch.py b/sshim_patch.py index 98bbc08..ccf3450 100644 --- a/sshim_patch.py +++ b/sshim_patch.py @@ -8,6 +8,8 @@ import threading import logging import select import time +import ipaddress +import secrets import inspect logger = logging.getLogger(__name__) @@ -22,37 +24,38 @@ def check_channel_request(self, kind, channel_id): def check_channel_shell_request(self, channel): logger.debug("Check channel shell request: %s" % channel.get_id()) - self.runner.set_shell_channel(channel) + Runner(self, self.username, 'shell', channel).start() return True def check_channel_exec_request(self, channel): logger.debug("Check channel exec request: %s" % channel.get_id()) - self.runner.set_shell_channel(channel) + Runner(self, self.username, 'exec', channel).start() return True def check_channel_subsystem_request(self, channel, name): if name == 'sftp': - self.runner.set_sftp_channel(channel) + Runner(self, self.username, 'sftp', channel).start() return True else: return False def check_auth_none(self, username): - if username == os.environ["SSH_USERNAME"]: - return paramiko.AUTH_PARTIALLY_SUCCESSFUL return paramiko.AUTH_FAILED def check_auth_password(self, username, password): logger.debug(f"{username} just tried to connect") - if username == os.environ["SSH_USERNAME"] and password == os.environ["SSH_PASSWORD"]: - self.runner = Runner(self, os.environ["SSH_USERNAME"], self.transport) - self.runner.start() + # ensure that the connection is made from a local ip + if ipaddress.ip_address(self.address).is_private is not True: + return paramiko.AUTH_FAILED + if secrets.compare_digest(password, os.environ["SSH_PASSWORD"]): + self.username = username + Runner(self, self.username).start() return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -62,17 +65,15 @@ def check_auth_publickey(self, username, key): class Runner(threading.Thread): - def __init__(self, client, username: str, transport: paramiko.Transport, start_instance: bool = True): + def __init__(self, client, username: str, channel_type: str = None, channel: paramiko.Channel = None): self.instance_name = "instance-" + username self.runner_identifier = "runner-" + str(uuid.uuid4()) threading.Thread.__init__(self, name=f'sshim.Runner {self.instance_name} {self.runner_identifier}') self.instance_password = self.instance_name # TODO: fix - VERY INSECURE! self.daemon = True self.client = client - self.transport = transport - self.start_instance = start_instance - self.shell_channel = None - self.sftp_channel = None + self.channel_type = channel_type + self.channel = channel def run(self) -> None: try: @@ -81,56 +82,38 @@ class Runner(threading.Thread): logger.debug(e) vm_ip = lxd_interface.get_networking(self.instance_name)['address'] + if self.channel is None: + return None + with paramiko.SSHClient() as ssh_client: ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) - ssh_client.connect(vm_ip, username='root', password=self.instance_password) - client_shell_channel = ssh_client.invoke_shell() - client_sftp_channel = ssh_client.open_sftp().get_channel() + try_limit = 10 # TODO: load from config + tries = 0 + while tries <= try_limit: + try: + ssh_client.connect(vm_ip, username='root', password=self.instance_password) + break + except Exception as e: # TODO: narrow exception + if tries >= try_limit: + raise + logger.debug(e) + tries += 1 + time.sleep(0.2) + if self.channel_type == "shell" or self.channel_type == "exec": + client_channel = ssh_client.invoke_shell() + elif self.channel_type == "sftp": + client_channel = ssh_client.open_sftp().get_channel() last_save_time = round(time.time()) lxd_interface.set_description(self.instance_name, str(last_save_time)) - while True: + forward_channel_return = True + while forward_channel_return is True: current_time_rounded = round(time.time()) if current_time_rounded != last_save_time: lxd_interface.set_description(self.instance_name, str(current_time_rounded)) last_save_time = current_time_rounded - if self.shell_channel is not None: - r, w, e = select.select([client_shell_channel, self.shell_channel], [], []) - if self.shell_channel in r: - x = self.shell_channel.recv(1024) - if len(x) == 0: - self.shell_channel.close() - self.shell_channel = None - break - client_shell_channel.send(x) - if client_shell_channel in r: - x = client_shell_channel.recv(1024) - if len(x) == 0: - self.shell_channel.close() - self.shell_channel = None - break - self.shell_channel.send(x) - - if self.sftp_channel is not None: # TODO: move this to function - r, w, e = select.select([client_sftp_channel, self.sftp_channel], [], []) - if self.sftp_channel in r: - x = self.sftp_channel.recv(1024) - if len(x) == 0: - self.sftp_channel.close() - self.sftp_channel = None - break - client_sftp_channel.send(x) - if client_sftp_channel in r: - x = client_sftp_channel.recv(1024) - if len(x) == 0: - self.sftp_channel.close() - self.sftp_channel = None - break - self.sftp_channel.send(x) - - if self.transport.is_active() is False: - break + forward_channel_return = self.forward_channel(client_channel) exit_time = round(time.time()) time.sleep(10) @@ -143,13 +126,24 @@ class Runner(threading.Thread): lxd_interface.destroy_instance(self.instance_name) break - def set_shell_channel(self, channel): - self.shell_channel = channel - self.shell_channel.settimeout(None) - - def set_sftp_channel(self, channel): - self.sftp_channel = channel - self.sftp_channel.settimeout(None) + def forward_channel(self, client_channel) -> bool: + if self.channel is None: + return False + else: + r, w, e = select.select([client_channel, self.channel], [], []) + if self.channel in r: + x = self.channel.recv(1024) + if len(x) == 0: + self.channel.close() + return False + client_channel.send(x) + if client_channel in r: + x = client_channel.recv(1024) + if len(x) == 0: + self.channel.close() + return False + self.channel.send(x) + return True Handler.check_channel_request = check_channel_request