Skip to content

Commit a128b12

Browse files
authored
Remove SSH tunnel (#136)
1 parent 4543f80 commit a128b12

File tree

4 files changed

+37
-54
lines changed

4 files changed

+37
-54
lines changed

testgres/node.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,9 @@ def get_auth_method(t):
528528
u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host),
529529
u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host),
530530
u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host),
531-
u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host)
531+
u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host),
532+
u"host\tall\tall\tall\t{}\n".format(auth_host),
533+
u"host\treplication\tall\tall\t{}\n".format(auth_host)
532534
] # yapf: disable
533535

534536
# write missing lines
@@ -1671,9 +1673,15 @@ def _get_bin_path(self, filename):
16711673

16721674
class NodeApp:
16731675

1674-
def __init__(self, test_path, nodes_to_cleanup, os_ops=LocalOperations()):
1675-
self.test_path = test_path
1676-
self.nodes_to_cleanup = nodes_to_cleanup
1676+
def __init__(self, test_path=None, nodes_to_cleanup=None, os_ops=LocalOperations()):
1677+
if test_path:
1678+
if os.path.isabs(test_path):
1679+
self.test_path = test_path
1680+
else:
1681+
self.test_path = os.path.join(os_ops.cwd(), test_path)
1682+
else:
1683+
self.test_path = os_ops.cwd()
1684+
self.nodes_to_cleanup = nodes_to_cleanup if nodes_to_cleanup else []
16771685
self.os_ops = os_ops
16781686

16791687
def make_empty(

testgres/operations/os_ops.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import getpass
12
import locale
3+
import sys
24

35
try:
46
import psycopg2 as pglib # noqa: F401
@@ -24,7 +26,7 @@ def get_default_encoding():
2426
class OsOperations:
2527
def __init__(self, username=None):
2628
self.ssh_key = None
27-
self.username = username
29+
self.username = username or getpass.getuser()
2830

2931
# Command execution
3032
def exec_command(self, cmd, **kwargs):
@@ -34,6 +36,13 @@ def exec_command(self, cmd, **kwargs):
3436
def environ(self, var_name):
3537
raise NotImplementedError()
3638

39+
def cwd(self):
40+
if sys.platform == 'linux':
41+
cmd = 'pwd'
42+
elif sys.platform == 'win32':
43+
cmd = 'cd'
44+
return self.exec_command(cmd).decode().rstrip()
45+
3746
def find_executable(self, executable):
3847
raise NotImplementedError()
3948

testgres/operations/remote_ops.py

+9-46
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import getpass
22
import os
3-
import logging
43
import platform
54
import subprocess
65
import tempfile
@@ -55,40 +54,10 @@ def __init__(self, conn_params: ConnectionParams):
5554
self.remote = True
5655
self.username = conn_params.username or getpass.getuser()
5756
self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host
58-
self.add_known_host(self.host)
59-
self.tunnel_process = None
6057

6158
def __enter__(self):
6259
return self
6360

64-
def __exit__(self, exc_type, exc_val, exc_tb):
65-
self.close_ssh_tunnel()
66-
67-
def establish_ssh_tunnel(self, local_port, remote_port):
68-
"""
69-
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
70-
"""
71-
ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
72-
self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)
73-
74-
def close_ssh_tunnel(self):
75-
if hasattr(self, 'tunnel_process'):
76-
self.tunnel_process.terminate()
77-
self.tunnel_process.wait()
78-
del self.tunnel_process
79-
else:
80-
print("No active tunnel to close.")
81-
82-
def add_known_host(self, host):
83-
known_hosts_path = os.path.expanduser("~/.ssh/known_hosts")
84-
cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path)
85-
86-
try:
87-
subprocess.check_call(cmd, shell=True)
88-
logging.info("Successfully added %s to known_hosts." % host)
89-
except subprocess.CalledProcessError as e:
90-
raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e)))
91-
9261
def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
9362
encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
9463
stderr=None, get_process=None, timeout=None):
@@ -293,6 +262,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
293262
with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
294263
# For scp the port is specified by a "-P" option
295264
scp_args = ['-P' if x == '-p' else x for x in self.ssh_args]
265+
296266
if not truncate:
297267
scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name]
298268
subprocess.run(scp_cmd, check=False) # The file might not exist yet
@@ -391,18 +361,11 @@ def get_process_children(self, pid):
391361

392362
# Database control
393363
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
394-
"""
395-
Established SSH tunnel and Connects to a PostgreSQL
396-
"""
397-
self.establish_ssh_tunnel(local_port=port, remote_port=5432)
398-
try:
399-
conn = pglib.connect(
400-
host=host,
401-
port=port,
402-
database=dbname,
403-
user=user,
404-
password=password,
405-
)
406-
return conn
407-
except Exception as e:
408-
raise Exception(f"Could not connect to the database. Error: {e}")
364+
conn = pglib.connect(
365+
host=host,
366+
port=port,
367+
database=dbname,
368+
user=user,
369+
password=password,
370+
)
371+
return conn

testgres/plugins/pg_probackup2/pg_probackup2/app.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ def __str__(self):
4343
class ProbackupApp:
4444

4545
def __init__(self, test_class: unittest.TestCase,
46-
pg_node, pb_log_path, test_env, auto_compress_alg, backup_dir):
46+
pg_node, pb_log_path, test_env, auto_compress_alg, backup_dir, probackup_path=None):
4747
self.test_class = test_class
4848
self.pg_node = pg_node
4949
self.pb_log_path = pb_log_path
5050
self.test_env = test_env
5151
self.auto_compress_alg = auto_compress_alg
5252
self.backup_dir = backup_dir
53-
self.probackup_path = init_params.probackup_path
53+
self.probackup_path = probackup_path or init_params.probackup_path
5454
self.probackup_old_path = init_params.probackup_old_path
5555
self.remote = init_params.remote
5656
self.verbose = init_params.verbose
@@ -388,6 +388,7 @@ def catchup_node(
388388
backup_mode, source_pgdata, destination_node,
389389
options=None,
390390
remote_host='localhost',
391+
remote_port=None,
391392
expect_error=False,
392393
gdb=False
393394
):
@@ -401,7 +402,9 @@ def catchup_node(
401402
'--destination-pgdata={0}'.format(destination_node.data_dir)
402403
]
403404
if self.remote:
404-
cmd_list += ['--remote-proto=ssh', '--remote-host=%s' % remote_host]
405+
cmd_list += ['--remote-proto=ssh', f'--remote-host={remote_host}']
406+
if remote_port:
407+
cmd_list.append(f'--remote-port={remote_port}')
405408
if self.verbose:
406409
cmd_list += [
407410
'--log-level-file=VERBOSE',

0 commit comments

Comments
 (0)