diff --git a/lxd_interface.py b/lxd_interface.py index 3f848fc..ad16d61 100644 --- a/lxd_interface.py +++ b/lxd_interface.py @@ -53,3 +53,23 @@ def setup_ssh(container_name: str, instance_password: str): execute_command(container_name, ["passwd", "root"], stdin_payload=f"{instance_password}\n{instance_password}") return True + + +def get_networking(container_name: str): + instance = lxd_client.instances.get(container_name) + + return instance.state().network['eth0']['addresses'][0] + + +def set_description(container_name: str, new_value: str): + instance = lxd_client.instances.get(container_name) + instance.description = new_value + instance.save() + + return True + + +def get_description(container_name: str) -> str: + instance = lxd_client.instances.get(container_name) + + return instance.description diff --git a/sshim_patch.py b/sshim_patch.py index 879468e..f97bd53 100644 --- a/sshim_patch.py +++ b/sshim_patch.py @@ -1,4 +1,5 @@ from sshim import * +import pylxd import paramiko import os import uuid @@ -43,7 +44,7 @@ def check_auth_none(self, username): def check_auth_password(self, username, password): logger.debug(f"{username} just tried to connect") if username == os.environ["SSH_USERNAME"] and password == os.environ["SSH_PASSWORD"]: - self.runner = Runner(self, self.transport) + self.runner = Runner(self, os.environ["SSH_USERNAME"], self.transport) self.runner.start() return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -54,19 +55,24 @@ def check_auth_publickey(self, username, key): class Runner(threading.Thread): - def __init__(self, client, transport: paramiko.Transport): - self.instance_name = "instance-" + str(uuid.uuid4()) - threading.Thread.__init__(self, name=f'sshim.Runner {self.instance_name}') - self.instance_password = str(uuid.uuid4()) # TODO: secure password generation + def __init__(self, client, username: str, transport: paramiko.Transport, start_instance: bool = True): + self.instance_name = "instance-" + username + self.runner_identifier = "runner-" + str(uuid.uuid4()) + threading.Thread.__init__(self, name=f'sshim.Runner {self.instance_name} {self.runner_identifier}') + self.instance_password = self.instance_name # TODO: fix - VERY INSECURE! self.daemon = True self.client = client self.transport = transport - # self.transport.set_subsystem_handler('sftp', handler=paramiko.SFTPServer) + self.start_instance = start_instance self.shell_channel = None self.sftp_channel = None def run(self) -> None: - vm_ip = lxd_interface.create_instance(self.instance_name, self.instance_password)['address'] + try: + vm_ip = lxd_interface.create_instance(self.instance_name, self.instance_password)['address'] + except pylxd.exceptions.LXDAPIException as e: + logger.debug(e) + vm_ip = lxd_interface.get_networking(self.instance_name)['address'] with paramiko.SSHClient() as ssh_client: ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) @@ -74,7 +80,13 @@ class Runner(threading.Thread): client_shell_channel = ssh_client.invoke_shell() client_sftp_channel = ssh_client.open_sftp().get_channel() + last_save_time = round(time.time()) while True: + current_time_rounded = round(time.time()) + if current_time_rounded != last_save_time: + lxd_interface.set_description(self.instance_name, str(current_time_rounded)) + last_save_time = current_time_rounded + if self.shell_channel is not None: r, w, e = select.select([client_shell_channel, self.shell_channel], [], []) if self.shell_channel in r: @@ -82,14 +94,14 @@ class Runner(threading.Thread): if len(x) == 0: self.shell_channel.close() self.shell_channel = None - continue + break client_shell_channel.send(x) if client_shell_channel in r: x = client_shell_channel.recv(1024) if len(x) == 0: self.shell_channel.close() self.shell_channel = None - continue + break self.shell_channel.send(x) if self.sftp_channel is not None: # TODO: move this to function @@ -99,20 +111,29 @@ class Runner(threading.Thread): if len(x) == 0: self.sftp_channel.close() self.sftp_channel = None - continue + break client_sftp_channel.send(x) if client_sftp_channel in r: x = client_sftp_channel.recv(1024) if len(x) == 0: self.sftp_channel.close() self.sftp_channel = None - continue + break self.sftp_channel.send(x) if self.transport.is_active() is False: break - lxd_interface.destroy_instance(self.instance_name) + exit_time = round(time.time()) + time.sleep(10) + while True: + last_run_time = int(lxd_interface.get_description(self.instance_name)) + + if exit_time < last_run_time: + break + elif round(time.time()) > (last_run_time + 15): + lxd_interface.destroy_instance(self.instance_name) + break def set_shell_channel(self, channel): self.shell_channel = channel