Add persistent instances

master
root 2022-12-01 18:03:17 +00:00
parent b7e4e26eda
commit f8cbbe3703
2 changed files with 53 additions and 12 deletions

View File

@ -53,3 +53,23 @@ def setup_ssh(container_name: str, instance_password: str):
execute_command(container_name, ["passwd", "root"], stdin_payload=f"{instance_password}\n{instance_password}") execute_command(container_name, ["passwd", "root"], stdin_payload=f"{instance_password}\n{instance_password}")
return True return True
def get_networking(container_name: str):
instance = lxd_client.instances.get(container_name)
return instance.state().network['eth0']['addresses'][0]
def set_description(container_name: str, new_value: str):
instance = lxd_client.instances.get(container_name)
instance.description = new_value
instance.save()
return True
def get_description(container_name: str) -> str:
instance = lxd_client.instances.get(container_name)
return instance.description

View File

@ -1,4 +1,5 @@
from sshim import * from sshim import *
import pylxd
import paramiko import paramiko
import os import os
import uuid import uuid
@ -43,7 +44,7 @@ def check_auth_none(self, username):
def check_auth_password(self, username, password): def check_auth_password(self, username, password):
logger.debug(f"{username} just tried to connect") logger.debug(f"{username} just tried to connect")
if username == os.environ["SSH_USERNAME"] and password == os.environ["SSH_PASSWORD"]: if username == os.environ["SSH_USERNAME"] and password == os.environ["SSH_PASSWORD"]:
self.runner = Runner(self, self.transport) self.runner = Runner(self, os.environ["SSH_USERNAME"], self.transport)
self.runner.start() self.runner.start()
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
@ -54,19 +55,24 @@ def check_auth_publickey(self, username, key):
class Runner(threading.Thread): class Runner(threading.Thread):
def __init__(self, client, transport: paramiko.Transport): def __init__(self, client, username: str, transport: paramiko.Transport, start_instance: bool = True):
self.instance_name = "instance-" + str(uuid.uuid4()) self.instance_name = "instance-" + username
threading.Thread.__init__(self, name=f'sshim.Runner {self.instance_name}') self.runner_identifier = "runner-" + str(uuid.uuid4())
self.instance_password = str(uuid.uuid4()) # TODO: secure password generation threading.Thread.__init__(self, name=f'sshim.Runner {self.instance_name} {self.runner_identifier}')
self.instance_password = self.instance_name # TODO: fix - VERY INSECURE!
self.daemon = True self.daemon = True
self.client = client self.client = client
self.transport = transport self.transport = transport
# self.transport.set_subsystem_handler('sftp', handler=paramiko.SFTPServer) self.start_instance = start_instance
self.shell_channel = None self.shell_channel = None
self.sftp_channel = None self.sftp_channel = None
def run(self) -> None: def run(self) -> None:
try:
vm_ip = lxd_interface.create_instance(self.instance_name, self.instance_password)['address'] vm_ip = lxd_interface.create_instance(self.instance_name, self.instance_password)['address']
except pylxd.exceptions.LXDAPIException as e:
logger.debug(e)
vm_ip = lxd_interface.get_networking(self.instance_name)['address']
with paramiko.SSHClient() as ssh_client: with paramiko.SSHClient() as ssh_client:
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy)
@ -74,7 +80,13 @@ class Runner(threading.Thread):
client_shell_channel = ssh_client.invoke_shell() client_shell_channel = ssh_client.invoke_shell()
client_sftp_channel = ssh_client.open_sftp().get_channel() client_sftp_channel = ssh_client.open_sftp().get_channel()
last_save_time = round(time.time())
while True: while True:
current_time_rounded = round(time.time())
if current_time_rounded != last_save_time:
lxd_interface.set_description(self.instance_name, str(current_time_rounded))
last_save_time = current_time_rounded
if self.shell_channel is not None: if self.shell_channel is not None:
r, w, e = select.select([client_shell_channel, self.shell_channel], [], []) r, w, e = select.select([client_shell_channel, self.shell_channel], [], [])
if self.shell_channel in r: if self.shell_channel in r:
@ -82,14 +94,14 @@ class Runner(threading.Thread):
if len(x) == 0: if len(x) == 0:
self.shell_channel.close() self.shell_channel.close()
self.shell_channel = None self.shell_channel = None
continue break
client_shell_channel.send(x) client_shell_channel.send(x)
if client_shell_channel in r: if client_shell_channel in r:
x = client_shell_channel.recv(1024) x = client_shell_channel.recv(1024)
if len(x) == 0: if len(x) == 0:
self.shell_channel.close() self.shell_channel.close()
self.shell_channel = None self.shell_channel = None
continue break
self.shell_channel.send(x) self.shell_channel.send(x)
if self.sftp_channel is not None: # TODO: move this to function if self.sftp_channel is not None: # TODO: move this to function
@ -99,20 +111,29 @@ class Runner(threading.Thread):
if len(x) == 0: if len(x) == 0:
self.sftp_channel.close() self.sftp_channel.close()
self.sftp_channel = None self.sftp_channel = None
continue break
client_sftp_channel.send(x) client_sftp_channel.send(x)
if client_sftp_channel in r: if client_sftp_channel in r:
x = client_sftp_channel.recv(1024) x = client_sftp_channel.recv(1024)
if len(x) == 0: if len(x) == 0:
self.sftp_channel.close() self.sftp_channel.close()
self.sftp_channel = None self.sftp_channel = None
continue break
self.sftp_channel.send(x) self.sftp_channel.send(x)
if self.transport.is_active() is False: if self.transport.is_active() is False:
break break
exit_time = round(time.time())
time.sleep(10)
while True:
last_run_time = int(lxd_interface.get_description(self.instance_name))
if exit_time < last_run_time:
break
elif round(time.time()) > (last_run_time + 15):
lxd_interface.destroy_instance(self.instance_name) lxd_interface.destroy_instance(self.instance_name)
break
def set_shell_channel(self, channel): def set_shell_channel(self, channel):
self.shell_channel = channel self.shell_channel = channel