|
1 | 1 | import getpass
|
2 | 2 | import os
|
3 |
| -import logging |
4 | 3 | import platform
|
5 | 4 | import subprocess
|
6 | 5 | import tempfile
|
@@ -55,40 +54,10 @@ def __init__(self, conn_params: ConnectionParams):
|
55 | 54 | self.remote = True
|
56 | 55 | self.username = conn_params.username or getpass.getuser()
|
57 | 56 | self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host
|
58 |
| - self.add_known_host(self.host) |
59 |
| - self.tunnel_process = None |
60 | 57 |
|
61 | 58 | def __enter__(self):
|
62 | 59 | return self
|
63 | 60 |
|
64 |
| - def __exit__(self, exc_type, exc_val, exc_tb): |
65 |
| - self.close_ssh_tunnel() |
66 |
| - |
67 |
| - def establish_ssh_tunnel(self, local_port, remote_port): |
68 |
| - """ |
69 |
| - Establish an SSH tunnel from a local port to a remote PostgreSQL port. |
70 |
| - """ |
71 |
| - ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"] |
72 |
| - self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300) |
73 |
| - |
74 |
| - def close_ssh_tunnel(self): |
75 |
| - if hasattr(self, 'tunnel_process'): |
76 |
| - self.tunnel_process.terminate() |
77 |
| - self.tunnel_process.wait() |
78 |
| - del self.tunnel_process |
79 |
| - else: |
80 |
| - print("No active tunnel to close.") |
81 |
| - |
82 |
| - def add_known_host(self, host): |
83 |
| - known_hosts_path = os.path.expanduser("~/.ssh/known_hosts") |
84 |
| - cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path) |
85 |
| - |
86 |
| - try: |
87 |
| - subprocess.check_call(cmd, shell=True) |
88 |
| - logging.info("Successfully added %s to known_hosts." % host) |
89 |
| - except subprocess.CalledProcessError as e: |
90 |
| - raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e))) |
91 |
| - |
92 | 61 | def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
|
93 | 62 | encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
|
94 | 63 | stderr=None, get_process=None, timeout=None):
|
@@ -293,6 +262,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
|
293 | 262 | with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
|
294 | 263 | # For scp the port is specified by a "-P" option
|
295 | 264 | scp_args = ['-P' if x == '-p' else x for x in self.ssh_args]
|
| 265 | + |
296 | 266 | if not truncate:
|
297 | 267 | scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name]
|
298 | 268 | subprocess.run(scp_cmd, check=False) # The file might not exist yet
|
@@ -391,18 +361,11 @@ def get_process_children(self, pid):
|
391 | 361 |
|
392 | 362 | # Database control
|
393 | 363 | def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
|
394 |
| - """ |
395 |
| - Established SSH tunnel and Connects to a PostgreSQL |
396 |
| - """ |
397 |
| - self.establish_ssh_tunnel(local_port=port, remote_port=5432) |
398 |
| - try: |
399 |
| - conn = pglib.connect( |
400 |
| - host=host, |
401 |
| - port=port, |
402 |
| - database=dbname, |
403 |
| - user=user, |
404 |
| - password=password, |
405 |
| - ) |
406 |
| - return conn |
407 |
| - except Exception as e: |
408 |
| - raise Exception(f"Could not connect to the database. Error: {e}") |
| 364 | + conn = pglib.connect( |
| 365 | + host=host, |
| 366 | + port=port, |
| 367 | + database=dbname, |
| 368 | + user=user, |
| 369 | + password=password, |
| 370 | + ) |
| 371 | + return conn |
0 commit comments