diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 diff --git a/testgres/enums.py b/testgres/enums.py index 1f6869de..fb68f2bb 100644 --- a/testgres/enums.py +++ b/testgres/enums.py @@ -85,3 +85,14 @@ def from_process(process): # default return ProcessType.Unknown + + +class DumpFormat(Enum): + """ + Available dump formats + """ + + Plain = 'plain' + Custom = 'custom' + Directory = 'directory' + Tar = 'tar' diff --git a/testgres/node.py b/testgres/node.py index f93f8787..d8ce1f03 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -10,7 +10,10 @@ from six import raise_from, iteritems from tempfile import mkstemp, mkdtemp -from .enums import NodeStatus, ProcessType +from .enums import \ + NodeStatus, \ + ProcessType, \ + DumpFormat from .cache import cached_initdb @@ -54,7 +57,8 @@ QueryException, \ StartNodeException, \ TimeoutException, \ - TestgresException + TestgresException, \ + BackupException from .logger import TestgresLogger @@ -803,7 +807,11 @@ def safe_psql(self, query=None, **kwargs): return out - def dump(self, filename=None, dbname=None, username=None): + def dump(self, + filename=None, + dbname=None, + username=None, + format=DumpFormat.Plain): """ Dump database into a file using pg_dump. NOTE: the file is not removed automatically. @@ -812,14 +820,27 @@ def dump(self, filename=None, dbname=None, username=None): filename: database dump taken by pg_dump. dbname: database name to connect to. username: database user name. + format: format argument plain/custom/directory/tar. Returns: Path to a file containing dump. """ + # Check arguments + if not isinstance(format, DumpFormat): + try: + format = DumpFormat(format) + except ValueError: + msg = 'Invalid format "{}"'.format(format) + raise BackupException(msg) + + # Generate tmpfile or tmpdir def tmpfile(): - fd, fname = mkstemp(prefix=TMP_DUMP) - os.close(fd) + if format == DumpFormat.Directory: + fname = mkdtemp(prefix=TMP_DUMP) + else: + fd, fname = mkstemp(prefix=TMP_DUMP) + os.close(fd) return fname # Set default arguments @@ -833,7 +854,8 @@ def tmpfile(): "-h", self.host, "-f", filename, "-U", username, - "-d", dbname + "-d", dbname, + "-F", format.value ] # yapf: disable execute_utility(_params, self.utils_log_file) @@ -845,12 +867,29 @@ def restore(self, filename, dbname=None, username=None): Restore database from pg_dump's file. Args: - filename: database dump taken by pg_dump. + filename: database dump taken by pg_dump in custom/directory/tar formats. dbname: database name to connect to. username: database user name. """ - self.psql(filename=filename, dbname=dbname, username=username) + # Set default arguments + dbname = dbname or default_dbname() + username = username or default_username() + + _params = [ + get_bin_path("pg_restore"), + "-p", str(self.port), + "-h", self.host, + "-U", username, + "-d", dbname, + filename + ] # yapf: disable + + # try pg_restore if dump is binary formate, and psql if not + try: + execute_utility(_params, self.utils_log_name) + except ExecUtilException: + self.psql(filename=filename, dbname=dbname, username=username) @method_decorator(positional_args_hack(['dbname', 'query'])) def poll_query_until(self, diff --git a/tests/test_simple.py b/tests/test_simple.py index afa142ac..33defb12 100755 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -67,6 +67,8 @@ def removing(f): finally: if os.path.isfile(f): os.remove(f) + elif os.path.isdir(f): + rmtree(f, ignore_errors=True) class TestgresTests(unittest.TestCase): @@ -426,16 +428,17 @@ def test_dump(self): with get_new_node().init().start() as node1: node1.execute(query_create) - - # take a new dump - with removing(node1.dump()) as dump: - with get_new_node().init().start() as node2: - # restore dump - self.assertTrue(os.path.isfile(dump)) - node2.restore(filename=dump) - - res = node2.execute(query_select) - self.assertListEqual(res, [(1, ), (2, )]) + for format in ['plain', 'custom', 'directory', 'tar']: + with removing(node1.dump(format=format)) as dump: + with get_new_node().init().start() as node3: + if format == 'directory': + self.assertTrue(os.path.isdir(dump)) + else: + self.assertTrue(os.path.isfile(dump)) + # restore dump + node3.restore(filename=dump) + res = node3.execute(query_select) + self.assertListEqual(res, [(1, ), (2, )]) def test_users(self): with get_new_node().init().start() as node: