Refactor code, increase security
parent
51304e3426
commit
d5a2326d2a
|
@ -12,11 +12,10 @@ def create_instance(container_name: str, instance_password: str):
|
||||||
|
|
||||||
instance = lxd_client.instances.create(config, wait=True)
|
instance = lxd_client.instances.create(config, wait=True)
|
||||||
instance.start(wait=True)
|
instance.start(wait=True)
|
||||||
|
setup_ssh(container_name, instance_password)
|
||||||
while type(ipaddress.ip_address(instance.state().network['eth0']['addresses'][0]['address'])) != ipaddress.IPv4Address:
|
while type(ipaddress.ip_address(instance.state().network['eth0']['addresses'][0]['address'])) != ipaddress.IPv4Address:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
setup_ssh(container_name, instance_password)
|
|
||||||
|
|
||||||
return instance.state().network['eth0']['addresses'][0]
|
return instance.state().network['eth0']['addresses'][0]
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,6 +57,10 @@ def setup_ssh(container_name: str, instance_password: str):
|
||||||
def get_networking(container_name: str):
|
def get_networking(container_name: str):
|
||||||
instance = lxd_client.instances.get(container_name)
|
instance = lxd_client.instances.get(container_name)
|
||||||
|
|
||||||
|
while instance.state().network is None or \
|
||||||
|
type(ipaddress.ip_address(instance.state().network['eth0']['addresses'][0]['address'])) != ipaddress.IPv4Address:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
return instance.state().network['eth0']['addresses'][0]
|
return instance.state().network['eth0']['addresses'][0]
|
||||||
|
|
||||||
|
|
||||||
|
|
3
main.py
3
main.py
|
@ -15,7 +15,8 @@ def connect_handler(script: sshim.Script):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
server = sshim.Server(connect_handler, address='127.0.0.1', port=3022)
|
server = sshim.Server(connect_handler, address='0.0.0.0'
|
||||||
|
'', port=3022)
|
||||||
try:
|
try:
|
||||||
server.run()
|
server.run()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
114
sshim_patch.py
114
sshim_patch.py
|
@ -8,6 +8,8 @@ import threading
|
||||||
import logging
|
import logging
|
||||||
import select
|
import select
|
||||||
import time
|
import time
|
||||||
|
import ipaddress
|
||||||
|
import secrets
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -22,37 +24,38 @@ def check_channel_request(self, kind, channel_id):
|
||||||
|
|
||||||
def check_channel_shell_request(self, channel):
|
def check_channel_shell_request(self, channel):
|
||||||
logger.debug("Check channel shell request: %s" % channel.get_id())
|
logger.debug("Check channel shell request: %s" % channel.get_id())
|
||||||
self.runner.set_shell_channel(channel)
|
Runner(self, self.username, 'shell', channel).start()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def check_channel_exec_request(self, channel):
|
def check_channel_exec_request(self, channel):
|
||||||
logger.debug("Check channel exec request: %s" % channel.get_id())
|
logger.debug("Check channel exec request: %s" % channel.get_id())
|
||||||
self.runner.set_shell_channel(channel)
|
Runner(self, self.username, 'exec', channel).start()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def check_channel_subsystem_request(self, channel, name):
|
def check_channel_subsystem_request(self, channel, name):
|
||||||
if name == 'sftp':
|
if name == 'sftp':
|
||||||
self.runner.set_sftp_channel(channel)
|
Runner(self, self.username, 'sftp', channel).start()
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def check_auth_none(self, username):
|
def check_auth_none(self, username):
|
||||||
if username == os.environ["SSH_USERNAME"]:
|
|
||||||
return paramiko.AUTH_PARTIALLY_SUCCESSFUL
|
|
||||||
return paramiko.AUTH_FAILED
|
return paramiko.AUTH_FAILED
|
||||||
|
|
||||||
|
|
||||||
def check_auth_password(self, username, password):
|
def check_auth_password(self, username, password):
|
||||||
logger.debug(f"{username} just tried to connect")
|
logger.debug(f"{username} just tried to connect")
|
||||||
if username == os.environ["SSH_USERNAME"] and password == os.environ["SSH_PASSWORD"]:
|
# ensure that the connection is made from a local ip
|
||||||
self.runner = Runner(self, os.environ["SSH_USERNAME"], self.transport)
|
if ipaddress.ip_address(self.address).is_private is not True:
|
||||||
self.runner.start()
|
return paramiko.AUTH_FAILED
|
||||||
|
if secrets.compare_digest(password, os.environ["SSH_PASSWORD"]):
|
||||||
|
self.username = username
|
||||||
|
Runner(self, self.username).start()
|
||||||
return paramiko.AUTH_SUCCESSFUL
|
return paramiko.AUTH_SUCCESSFUL
|
||||||
return paramiko.AUTH_FAILED
|
return paramiko.AUTH_FAILED
|
||||||
|
|
||||||
|
@ -62,17 +65,15 @@ def check_auth_publickey(self, username, key):
|
||||||
|
|
||||||
|
|
||||||
class Runner(threading.Thread):
|
class Runner(threading.Thread):
|
||||||
def __init__(self, client, username: str, transport: paramiko.Transport, start_instance: bool = True):
|
def __init__(self, client, username: str, channel_type: str = None, channel: paramiko.Channel = None):
|
||||||
self.instance_name = "instance-" + username
|
self.instance_name = "instance-" + username
|
||||||
self.runner_identifier = "runner-" + str(uuid.uuid4())
|
self.runner_identifier = "runner-" + str(uuid.uuid4())
|
||||||
threading.Thread.__init__(self, name=f'sshim.Runner {self.instance_name} {self.runner_identifier}')
|
threading.Thread.__init__(self, name=f'sshim.Runner {self.instance_name} {self.runner_identifier}')
|
||||||
self.instance_password = self.instance_name # TODO: fix - VERY INSECURE!
|
self.instance_password = self.instance_name # TODO: fix - VERY INSECURE!
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
self.client = client
|
self.client = client
|
||||||
self.transport = transport
|
self.channel_type = channel_type
|
||||||
self.start_instance = start_instance
|
self.channel = channel
|
||||||
self.shell_channel = None
|
|
||||||
self.sftp_channel = None
|
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -81,56 +82,38 @@ class Runner(threading.Thread):
|
||||||
logger.debug(e)
|
logger.debug(e)
|
||||||
vm_ip = lxd_interface.get_networking(self.instance_name)['address']
|
vm_ip = lxd_interface.get_networking(self.instance_name)['address']
|
||||||
|
|
||||||
|
if self.channel is None:
|
||||||
|
return None
|
||||||
|
|
||||||
with paramiko.SSHClient() as ssh_client:
|
with paramiko.SSHClient() as ssh_client:
|
||||||
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy)
|
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy)
|
||||||
ssh_client.connect(vm_ip, username='root', password=self.instance_password)
|
try_limit = 10 # TODO: load from config
|
||||||
client_shell_channel = ssh_client.invoke_shell()
|
tries = 0
|
||||||
client_sftp_channel = ssh_client.open_sftp().get_channel()
|
while tries <= try_limit:
|
||||||
|
try:
|
||||||
|
ssh_client.connect(vm_ip, username='root', password=self.instance_password)
|
||||||
|
break
|
||||||
|
except Exception as e: # TODO: narrow exception
|
||||||
|
if tries >= try_limit:
|
||||||
|
raise
|
||||||
|
logger.debug(e)
|
||||||
|
tries += 1
|
||||||
|
time.sleep(0.2)
|
||||||
|
if self.channel_type == "shell" or self.channel_type == "exec":
|
||||||
|
client_channel = ssh_client.invoke_shell()
|
||||||
|
elif self.channel_type == "sftp":
|
||||||
|
client_channel = ssh_client.open_sftp().get_channel()
|
||||||
|
|
||||||
last_save_time = round(time.time())
|
last_save_time = round(time.time())
|
||||||
lxd_interface.set_description(self.instance_name, str(last_save_time))
|
lxd_interface.set_description(self.instance_name, str(last_save_time))
|
||||||
while True:
|
forward_channel_return = True
|
||||||
|
while forward_channel_return is True:
|
||||||
current_time_rounded = round(time.time())
|
current_time_rounded = round(time.time())
|
||||||
if current_time_rounded != last_save_time:
|
if current_time_rounded != last_save_time:
|
||||||
lxd_interface.set_description(self.instance_name, str(current_time_rounded))
|
lxd_interface.set_description(self.instance_name, str(current_time_rounded))
|
||||||
last_save_time = current_time_rounded
|
last_save_time = current_time_rounded
|
||||||
|
|
||||||
if self.shell_channel is not None:
|
forward_channel_return = self.forward_channel(client_channel)
|
||||||
r, w, e = select.select([client_shell_channel, self.shell_channel], [], [])
|
|
||||||
if self.shell_channel in r:
|
|
||||||
x = self.shell_channel.recv(1024)
|
|
||||||
if len(x) == 0:
|
|
||||||
self.shell_channel.close()
|
|
||||||
self.shell_channel = None
|
|
||||||
break
|
|
||||||
client_shell_channel.send(x)
|
|
||||||
if client_shell_channel in r:
|
|
||||||
x = client_shell_channel.recv(1024)
|
|
||||||
if len(x) == 0:
|
|
||||||
self.shell_channel.close()
|
|
||||||
self.shell_channel = None
|
|
||||||
break
|
|
||||||
self.shell_channel.send(x)
|
|
||||||
|
|
||||||
if self.sftp_channel is not None: # TODO: move this to function
|
|
||||||
r, w, e = select.select([client_sftp_channel, self.sftp_channel], [], [])
|
|
||||||
if self.sftp_channel in r:
|
|
||||||
x = self.sftp_channel.recv(1024)
|
|
||||||
if len(x) == 0:
|
|
||||||
self.sftp_channel.close()
|
|
||||||
self.sftp_channel = None
|
|
||||||
break
|
|
||||||
client_sftp_channel.send(x)
|
|
||||||
if client_sftp_channel in r:
|
|
||||||
x = client_sftp_channel.recv(1024)
|
|
||||||
if len(x) == 0:
|
|
||||||
self.sftp_channel.close()
|
|
||||||
self.sftp_channel = None
|
|
||||||
break
|
|
||||||
self.sftp_channel.send(x)
|
|
||||||
|
|
||||||
if self.transport.is_active() is False:
|
|
||||||
break
|
|
||||||
|
|
||||||
exit_time = round(time.time())
|
exit_time = round(time.time())
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
@ -143,13 +126,24 @@ class Runner(threading.Thread):
|
||||||
lxd_interface.destroy_instance(self.instance_name)
|
lxd_interface.destroy_instance(self.instance_name)
|
||||||
break
|
break
|
||||||
|
|
||||||
def set_shell_channel(self, channel):
|
def forward_channel(self, client_channel) -> bool:
|
||||||
self.shell_channel = channel
|
if self.channel is None:
|
||||||
self.shell_channel.settimeout(None)
|
return False
|
||||||
|
else:
|
||||||
def set_sftp_channel(self, channel):
|
r, w, e = select.select([client_channel, self.channel], [], [])
|
||||||
self.sftp_channel = channel
|
if self.channel in r:
|
||||||
self.sftp_channel.settimeout(None)
|
x = self.channel.recv(1024)
|
||||||
|
if len(x) == 0:
|
||||||
|
self.channel.close()
|
||||||
|
return False
|
||||||
|
client_channel.send(x)
|
||||||
|
if client_channel in r:
|
||||||
|
x = client_channel.recv(1024)
|
||||||
|
if len(x) == 0:
|
||||||
|
self.channel.close()
|
||||||
|
return False
|
||||||
|
self.channel.send(x)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
Handler.check_channel_request = check_channel_request
|
Handler.check_channel_request = check_channel_request
|
||||||
|
|
Loading…
Reference in New Issue