Skip to content

Commit 4543f80

Browse files
authored
Make use of the default SSH user (#133)
Make use of the default SSH user
1 parent 2bf16a5 commit 4543f80

File tree

4 files changed

+22
-41
lines changed

4 files changed

+22
-41
lines changed

testgres/node.py

+7-19
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363

6464
from .defaults import \
6565
default_dbname, \
66-
default_username, \
6766
generate_app_name
6867

6968
from .exceptions import \
@@ -683,8 +682,6 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem
683682
If False, waits for the instance to be in primary mode. Default is False.
684683
max_attempts:
685684
"""
686-
if not username:
687-
username = default_username()
688685
self.start()
689686

690687
if replica:
@@ -694,7 +691,7 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem
694691
# Call poll_query_until until the expected value is returned
695692
self.poll_query_until(query=query,
696693
dbname=dbname,
697-
username=username,
694+
username=username or self.os_ops.username,
698695
suppress={InternalError,
699696
QueryException,
700697
ProgrammingError,
@@ -967,15 +964,13 @@ def psql(self,
967964
>>> psql(query='select 3', ON_ERROR_STOP=1)
968965
"""
969966

970-
# Set default arguments
971967
dbname = dbname or default_dbname()
972-
username = username or default_username()
973968

974969
psql_params = [
975970
self._get_bin_path("psql"),
976971
"-p", str(self.port),
977972
"-h", self.host,
978-
"-U", username,
973+
"-U", username or self.os_ops.username,
979974
"-X", # no .psqlrc
980975
"-A", # unaligned output
981976
"-t", # print rows only
@@ -1087,18 +1082,15 @@ def tmpfile():
10871082
fname = self.os_ops.mkstemp(prefix=TMP_DUMP)
10881083
return fname
10891084

1090-
# Set default arguments
1091-
dbname = dbname or default_dbname()
1092-
username = username or default_username()
10931085
filename = filename or tmpfile()
10941086

10951087
_params = [
10961088
self._get_bin_path("pg_dump"),
10971089
"-p", str(self.port),
10981090
"-h", self.host,
10991091
"-f", filename,
1100-
"-U", username,
1101-
"-d", dbname,
1092+
"-U", username or self.os_ops.username,
1093+
"-d", dbname or default_dbname(),
11021094
"-F", format.value
11031095
] # yapf: disable
11041096

@@ -1118,7 +1110,7 @@ def restore(self, filename, dbname=None, username=None):
11181110

11191111
# Set default arguments
11201112
dbname = dbname or default_dbname()
1121-
username = username or default_username()
1113+
username = username or self.os_ops.username
11221114

11231115
_params = [
11241116
self._get_bin_path("pg_restore"),
@@ -1388,15 +1380,13 @@ def pgbench(self,
13881380
if options is None:
13891381
options = []
13901382

1391-
# Set default arguments
13921383
dbname = dbname or default_dbname()
1393-
username = username or default_username()
13941384

13951385
_params = [
13961386
self._get_bin_path("pgbench"),
13971387
"-p", str(self.port),
13981388
"-h", self.host,
1399-
"-U", username,
1389+
"-U", username or self.os_ops.username
14001390
] + options # yapf: disable
14011391

14021392
# should be the last one
@@ -1463,15 +1453,13 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs):
14631453
>>> pgbench_run(time=10)
14641454
"""
14651455

1466-
# Set default arguments
14671456
dbname = dbname or default_dbname()
1468-
username = username or default_username()
14691457

14701458
_params = [
14711459
self._get_bin_path("pgbench"),
14721460
"-p", str(self.port),
14731461
"-h", self.host,
1474-
"-U", username,
1462+
"-U", username or self.os_ops.username
14751463
] + options # yapf: disable
14761464

14771465
for key, value in iteritems(kwargs):

testgres/operations/local_ops.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, conn_params=None):
3838
self.host = conn_params.host
3939
self.ssh_key = None
4040
self.remote = False
41-
self.username = conn_params.username or self.get_user()
41+
self.username = conn_params.username or getpass.getuser()
4242

4343
@staticmethod
4444
def _raise_exec_exception(message, command, exit_code, output):
@@ -130,10 +130,6 @@ def set_env(self, var_name, var_val):
130130
# Check if the directory is already in PATH
131131
os.environ[var_name] = var_val
132132

133-
# Get environment variables
134-
def get_user(self):
135-
return self.username or getpass.getuser()
136-
137133
def get_name(self):
138134
return os.name
139135

testgres/operations/os_ops.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@ def set_env(self, var_name, var_val):
4545
# Check if the directory is already in PATH
4646
raise NotImplementedError()
4747

48-
# Get environment variables
4948
def get_user(self):
50-
raise NotImplementedError()
49+
return self.username
5150

5251
def get_name(self):
5352
raise NotImplementedError()

testgres/operations/remote_ops.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import logging
1+
import getpass
22
import os
3+
import logging
4+
import platform
35
import subprocess
46
import tempfile
5-
import platform
67

78
# we support both pg8000 and psycopg2
89
try:
@@ -52,7 +53,8 @@ def __init__(self, conn_params: ConnectionParams):
5253
if self.port:
5354
self.ssh_args += ["-p", self.port]
5455
self.remote = True
55-
self.username = conn_params.username or self.get_user()
56+
self.username = conn_params.username or getpass.getuser()
57+
self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host
5658
self.add_known_host(self.host)
5759
self.tunnel_process = None
5860

@@ -97,9 +99,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
9799
"""
98100
ssh_cmd = []
99101
if isinstance(cmd, str):
100-
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + [cmd]
102+
ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_args + [cmd]
101103
elif isinstance(cmd, list):
102-
ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_args + cmd
104+
ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_args + cmd
103105
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
104106
if get_process:
105107
return process
@@ -174,10 +176,6 @@ def set_env(self, var_name: str, var_val: str):
174176
"""
175177
return self.exec_command("export {}={}".format(var_name, var_val))
176178

177-
# Get environment variables
178-
def get_user(self):
179-
return self.exec_command("echo $USER", encoding=get_default_encoding()).strip()
180-
181179
def get_name(self):
182180
cmd = 'python3 -c "import os; print(os.name)"'
183181
return self.exec_command(cmd, encoding=get_default_encoding()).strip()
@@ -248,9 +246,9 @@ def mkdtemp(self, prefix=None):
248246
- prefix (str): The prefix of the temporary directory name.
249247
"""
250248
if prefix:
251-
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
249+
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"mktemp -d {prefix}XXXXX"]
252250
else:
253-
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", "mktemp -d"]
251+
command = ["ssh"] + self.ssh_args + [self.ssh_dest, "mktemp -d"]
254252

255253
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
256254

@@ -296,7 +294,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
296294
# For scp the port is specified by a "-P" option
297295
scp_args = ['-P' if x == '-p' else x for x in self.ssh_args]
298296
if not truncate:
299-
scp_cmd = ['scp'] + scp_args + [f"{self.username}@{self.host}:{filename}", tmp_file.name]
297+
scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name]
300298
subprocess.run(scp_cmd, check=False) # The file might not exist yet
301299
tmp_file.seek(0, os.SEEK_END)
302300

@@ -312,11 +310,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
312310
tmp_file.write(data)
313311

314312
tmp_file.flush()
315-
scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
313+
scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"]
316314
subprocess.run(scp_cmd, check=True)
317315

318316
remote_directory = os.path.dirname(filename)
319-
mkdir_cmd = ['ssh'] + self.ssh_args + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"]
317+
mkdir_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, f"mkdir -p {remote_directory}"]
320318
subprocess.run(mkdir_cmd, check=True)
321319

322320
os.remove(tmp_file.name)
@@ -381,7 +379,7 @@ def get_pid(self):
381379
return int(self.exec_command("echo $$", encoding=get_default_encoding()))
382380

383381
def get_process_children(self, pid):
384-
command = ["ssh"] + self.ssh_args + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
382+
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"pgrep -P {pid}"]
385383

386384
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
387385

0 commit comments

Comments
 (0)