Skip to content

Commit 7847380

Browse files
authored
Add an SSH port parameter (#131)
1 parent 529b4df commit 7847380

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

testgres/operations/os_ops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111

1212
class ConnectionParams:
13-
def __init__(self, host='127.0.0.1', ssh_key=None, username=None):
13+
def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None):
1414
self.host = host
15+
self.port = port
1516
self.ssh_key = ssh_key
1617
self.username = username
1718

testgres/operations/remote_ops.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ def __init__(self, conn_params: ConnectionParams):
4444
super().__init__(conn_params.username)
4545
self.conn_params = conn_params
4646
self.host = conn_params.host
47+
self.port = conn_params.port
4748
self.ssh_key = conn_params.ssh_key
49+
self.ssh_args = []
4850
if self.ssh_key:
49-
self.ssh_cmd = ["-i", self.ssh_key]
50-
else:
51-
self.ssh_cmd = []
51+
self.ssh_args += ["-i", self.ssh_key]
52+
if self.port:
53+
self.ssh_args += ["-p", self.port]
5254
self.remote = True
5355
self.username = conn_params.username or self.get_user()
5456
self.add_known_host(self.host)
@@ -95,9 +97,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
9597
"""
9698
ssh_cmd = []
9799
if isinstance(cmd, str):
98-
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd]
100+
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + [cmd]
99101
elif isinstance(cmd, list):
100-
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd
102+
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + cmd
101103
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
102104
if get_process:
103105
return process
@@ -246,9 +248,9 @@ def mkdtemp(self, prefix=None):
246248
- prefix (str): The prefix of the temporary directory name.
247249
"""
248250
if prefix:
249-
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
251+
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
250252
else:
251-
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"]
253+
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", "mktemp -d"]
252254

253255
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
254256

@@ -291,8 +293,10 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
291293
mode = "r+b" if binary else "r+"
292294

293295
with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
296+
# For scp the port is specified by a "-P" option
297+
scp_args = ['-P' if x == '-p' else x for x in self.ssh_args]
294298
if not truncate:
295-
scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
299+
scp_cmd = ['scp'] + scp_args + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
296300
subprocess.run(scp_cmd, check=False) # The file might not exist yet
297301
tmp_file.seek(0, os.SEEK_END)
298302

@@ -308,11 +312,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
308312
tmp_file.write(data)
309313

310314
tmp_file.flush()
311-
scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
315+
scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
312316
subprocess.run(scp_cmd, check=True)
313317

314318
remote_directory = os.path.dirname(filename)
315-
mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
319+
mkdir_cmd = ['ssh'] + self.ssh_args + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
316320
subprocess.run(mkdir_cmd, check=True)
317321

318322
os.remove(tmp_file.name)
@@ -377,7 +381,7 @@ def get_pid(self):
377381
return int(self.exec_command("echo $$", encoding=get_default_encoding()))
378382

379383
def get_process_children(self, pid):
380-
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
384+
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
381385

382386
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
383387

0 commit comments

Comments
 (0)