From c3007734006fd84082872e5fb051e28a9fe12bd6 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 5 Dec 2022 23:33:33 +0000 Subject: [PATCH] Fix waiting for vm code, add keepalive --- main.py | 3 +-- sshim_patch.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index be00937..3ae4d19 100644 --- a/main.py +++ b/main.py @@ -15,8 +15,7 @@ def connect_handler(script: sshim.Script): pass -server = sshim.Server(connect_handler, address='0.0.0.0' - '', port=3022) +server = sshim.Server(connect_handler, address='0.0.0.0', port=3022) try: server.run() except KeyboardInterrupt: diff --git a/sshim_patch.py b/sshim_patch.py index ccf3450..297c34e 100644 --- a/sshim_patch.py +++ b/sshim_patch.py @@ -82,23 +82,22 @@ class Runner(threading.Thread): logger.debug(e) vm_ip = lxd_interface.get_networking(self.instance_name)['address'] - if self.channel is None: - return None + if self.channel is not None: + self.channel.get_transport().set_keepalive(5) # TODO: make config option with paramiko.SSHClient() as ssh_client: - ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) - try_limit = 10 # TODO: load from config - tries = 0 - while tries <= try_limit: + # wait for instance to start if it hasn't started yet + is_not_int = True + while is_not_int and self.channel is not None: try: - ssh_client.connect(vm_ip, username='root', password=self.instance_password) - break - except Exception as e: # TODO: narrow exception - if tries >= try_limit: - raise + int(lxd_interface.get_description(self.instance_name)) + is_not_int = False + except ValueError as e: logger.debug(e) - tries += 1 time.sleep(0.2) + + ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) + ssh_client.connect(vm_ip, username='root', password=self.instance_password) if self.channel_type == "shell" or self.channel_type == "exec": client_channel = ssh_client.invoke_shell() elif self.channel_type == "sftp": @@ -113,7 +112,8 @@ class Runner(threading.Thread): lxd_interface.set_description(self.instance_name, str(current_time_rounded)) last_save_time = current_time_rounded - forward_channel_return = self.forward_channel(client_channel) + if "client_channel" in locals(): + forward_channel_return = self.forward_channel(client_channel) exit_time = round(time.time()) time.sleep(10)