Fix waiting for vm code, add keepalive

master
root 2022-12-05 23:33:33 +00:00
parent d5a2326d2a
commit c300773400
2 changed files with 14 additions and 15 deletions

View File

@ -15,8 +15,7 @@ def connect_handler(script: sshim.Script):
pass pass
server = sshim.Server(connect_handler, address='0.0.0.0' server = sshim.Server(connect_handler, address='0.0.0.0', port=3022)
'', port=3022)
try: try:
server.run() server.run()
except KeyboardInterrupt: except KeyboardInterrupt:

View File

@ -82,23 +82,22 @@ class Runner(threading.Thread):
logger.debug(e) logger.debug(e)
vm_ip = lxd_interface.get_networking(self.instance_name)['address'] vm_ip = lxd_interface.get_networking(self.instance_name)['address']
if self.channel is None: if self.channel is not None:
return None self.channel.get_transport().set_keepalive(5) # TODO: make config option
with paramiko.SSHClient() as ssh_client: with paramiko.SSHClient() as ssh_client:
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) # wait for instance to start if it hasn't started yet
try_limit = 10 # TODO: load from config is_not_int = True
tries = 0 while is_not_int and self.channel is not None:
while tries <= try_limit:
try: try:
ssh_client.connect(vm_ip, username='root', password=self.instance_password) int(lxd_interface.get_description(self.instance_name))
break is_not_int = False
except Exception as e: # TODO: narrow exception except ValueError as e:
if tries >= try_limit:
raise
logger.debug(e) logger.debug(e)
tries += 1
time.sleep(0.2) time.sleep(0.2)
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy)
ssh_client.connect(vm_ip, username='root', password=self.instance_password)
if self.channel_type == "shell" or self.channel_type == "exec": if self.channel_type == "shell" or self.channel_type == "exec":
client_channel = ssh_client.invoke_shell() client_channel = ssh_client.invoke_shell()
elif self.channel_type == "sftp": elif self.channel_type == "sftp":
@ -113,6 +112,7 @@ class Runner(threading.Thread):
lxd_interface.set_description(self.instance_name, str(current_time_rounded)) lxd_interface.set_description(self.instance_name, str(current_time_rounded))
last_save_time = current_time_rounded last_save_time = current_time_rounded
if "client_channel" in locals():
forward_channel_return = self.forward_channel(client_channel) forward_channel_return = self.forward_channel(client_channel)
exit_time = round(time.time()) exit_time = round(time.time())