1
- import logging
1
+ import getpass
2
2
import os
3
+ import logging
4
+ import platform
3
5
import subprocess
4
6
import tempfile
5
- import platform
6
7
7
8
# we support both pg8000 and psycopg2
8
9
try :
@@ -52,7 +53,8 @@ def __init__(self, conn_params: ConnectionParams):
52
53
if self .port :
53
54
self .ssh_args += ["-p" , self .port ]
54
55
self .remote = True
55
- self .username = conn_params .username or self .get_user ()
56
+ self .username = conn_params .username or getpass .getuser ()
57
+ self .ssh_dest = f"{ self .username } @{ self .host } " if conn_params .username else self .host
56
58
self .add_known_host (self .host )
57
59
self .tunnel_process = None
58
60
@@ -97,9 +99,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
97
99
"""
98
100
ssh_cmd = []
99
101
if isinstance (cmd , str ):
100
- ssh_cmd = ['ssh' , f" { self .username } @ { self . host } " ] + self .ssh_args + [cmd ]
102
+ ssh_cmd = ['ssh' , self .ssh_dest ] + self .ssh_args + [cmd ]
101
103
elif isinstance (cmd , list ):
102
- ssh_cmd = ['ssh' , f" { self .username } @ { self . host } " ] + self .ssh_args + cmd
104
+ ssh_cmd = ['ssh' , self .ssh_dest ] + self .ssh_args + cmd
103
105
process = subprocess .Popen (ssh_cmd , stdin = subprocess .PIPE , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
104
106
if get_process :
105
107
return process
@@ -174,10 +176,6 @@ def set_env(self, var_name: str, var_val: str):
174
176
"""
175
177
return self .exec_command ("export {}={}" .format (var_name , var_val ))
176
178
177
- # Get environment variables
178
- def get_user (self ):
179
- return self .exec_command ("echo $USER" , encoding = get_default_encoding ()).strip ()
180
-
181
179
def get_name (self ):
182
180
cmd = 'python3 -c "import os; print(os.name)"'
183
181
return self .exec_command (cmd , encoding = get_default_encoding ()).strip ()
@@ -248,9 +246,9 @@ def mkdtemp(self, prefix=None):
248
246
- prefix (str): The prefix of the temporary directory name.
249
247
"""
250
248
if prefix :
251
- command = ["ssh" ] + self .ssh_args + [f" { self .username } @ { self . host } " , f"mktemp -d { prefix } XXXXX" ]
249
+ command = ["ssh" ] + self .ssh_args + [self .ssh_dest , f"mktemp -d { prefix } XXXXX" ]
252
250
else :
253
- command = ["ssh" ] + self .ssh_args + [f" { self .username } @ { self . host } " , "mktemp -d" ]
251
+ command = ["ssh" ] + self .ssh_args + [self .ssh_dest , "mktemp -d" ]
254
252
255
253
result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
256
254
@@ -296,7 +294,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
296
294
# For scp the port is specified by a "-P" option
297
295
scp_args = ['-P' if x == '-p' else x for x in self .ssh_args ]
298
296
if not truncate :
299
- scp_cmd = ['scp' ] + scp_args + [f"{ self .username } @ { self . host } :{ filename } " , tmp_file .name ]
297
+ scp_cmd = ['scp' ] + scp_args + [f"{ self .ssh_dest } :{ filename } " , tmp_file .name ]
300
298
subprocess .run (scp_cmd , check = False ) # The file might not exist yet
301
299
tmp_file .seek (0 , os .SEEK_END )
302
300
@@ -312,11 +310,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
312
310
tmp_file .write (data )
313
311
314
312
tmp_file .flush ()
315
- scp_cmd = ['scp' ] + scp_args + [tmp_file .name , f"{ self .username } @ { self . host } :{ filename } " ]
313
+ scp_cmd = ['scp' ] + scp_args + [tmp_file .name , f"{ self .ssh_dest } :{ filename } " ]
316
314
subprocess .run (scp_cmd , check = True )
317
315
318
316
remote_directory = os .path .dirname (filename )
319
- mkdir_cmd = ['ssh' ] + self .ssh_args + [f" { self .username } @ { self . host } " , f"mkdir -p { remote_directory } " ]
317
+ mkdir_cmd = ['ssh' ] + self .ssh_args + [self .ssh_dest , f"mkdir -p { remote_directory } " ]
320
318
subprocess .run (mkdir_cmd , check = True )
321
319
322
320
os .remove (tmp_file .name )
@@ -381,7 +379,7 @@ def get_pid(self):
381
379
return int (self .exec_command ("echo $$" , encoding = get_default_encoding ()))
382
380
383
381
def get_process_children (self , pid ):
384
- command = ["ssh" ] + self .ssh_args + [f" { self .username } @ { self . host } " , f"pgrep -P { pid } " ]
382
+ command = ["ssh" ] + self .ssh_args + [self .ssh_dest , f"pgrep -P { pid } " ]
385
383
386
384
result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
387
385
0 commit comments