diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py
index d472a5ed23c75..8a6a22abe23fa 100644
--- a/pandas/tests/io/test_parquet.py
+++ b/pandas/tests/io/test_parquet.py
@@ -110,48 +110,79 @@ def df_full():
                               pd.Timestamp('20130103')]})
 
 
-def test_invalid_engine(df_compat):
+def check_round_trip(df, engine=None, path=None,
+                     write_kwargs=None, read_kwargs=None,
+                     expected=None, check_names=True,
+                     repeat=2):
+    """Verify parquet serializer and deserializer produce the same results.
+
+    Performs a pandas to disk and disk to pandas round trip,
+    then compares the 2 resulting DataFrames to verify equality.
+
+    Parameters
+    ----------
+    df: DataFrame
+    engine: str, optional
+        'pyarrow' or 'fastparquet'
+    path: str, optional
+    write_kwargs: dict of str:str, optional
+    read_kwargs: dict of str:str, optional
+    expected: DataFrame, optional
+        Expected deserialization result, otherwise will be equal to `df`
+    check_names: bool, optional
+        Whether to compare the index/column names of the two frames
+    repeat: int, optional
+        How many times to repeat the test
+    """
+
+    write_kwargs = write_kwargs or {'compression': None}
+    read_kwargs = read_kwargs or {}
+
+    if expected is None:
+        expected = df
+
+    if engine:
+        write_kwargs['engine'] = engine
+        read_kwargs['engine'] = engine
+
+    def compare(repeat):
+        for _ in range(repeat):
+            df.to_parquet(path, **write_kwargs)
+            actual = read_parquet(path, **read_kwargs)
+            tm.assert_frame_equal(expected, actual,
+                                  check_names=check_names)
+
+    if path is None:
+        with tm.ensure_clean() as path:
+            compare(repeat)
+    else:
+        compare(repeat)
+
+def test_invalid_engine(df_compat):
     with pytest.raises(ValueError):
-        df_compat.to_parquet('foo', 'bar')
+        check_round_trip(df_compat, 'foo', 'bar')
 
 
 def test_options_py(df_compat, pa):
     # use the set option
 
-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'pyarrow'):
-            df.to_parquet(path)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'pyarrow'):
+        check_round_trip(df_compat)
 
 
 def test_options_fp(df_compat, fp):
     # use the set option
 
-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'fastparquet'):
-            df.to_parquet(path, compression=None)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'fastparquet'):
+        check_round_trip(df_compat)
 
 
 def test_options_auto(df_compat, fp, pa):
+
     # use the set option
 
-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'auto'):
-            df.to_parquet(path)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'auto'):
+        check_round_trip(df_compat)
 
 
 def test_options_get_engine(fp, pa):
@@ -228,53 +259,23 @@ def check_error_on_write(self, df, engine, exc):
         with tm.ensure_clean() as path:
             to_parquet(df, path, engine, compression=None)
 
-    def check_round_trip(self, df, engine, expected=None, path=None,
-                         write_kwargs=None, read_kwargs=None,
-                         check_names=True):
-
-        if write_kwargs is None:
-            write_kwargs = {'compression': None}
-
-        if read_kwargs is None:
-            read_kwargs = {}
-
-        if expected is None:
-            expected = df
-
-        if path is None:
-            with tm.ensure_clean() as path:
-                check_round_trip_equals(df, path, engine,
-                                        write_kwargs=write_kwargs,
-                                        read_kwargs=read_kwargs,
-                                        expected=expected,
-                                        check_names=check_names)
-        else:
-            check_round_trip_equals(df, path, engine,
-                                    write_kwargs=write_kwargs,
-                                    read_kwargs=read_kwargs,
-                                    expected=expected,
-                                    check_names=check_names)
-
 
 
 class TestBasic(Base):
 
     def test_error(self, engine):
-
         for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'),
                     np.array([1, 2, 3])]:
             self.check_error_on_write(obj, engine, ValueError)
 
     def test_columns_dtypes(self, engine):
-
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})
 
         # unicode
         df.columns = [u'foo', u'bar']
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
     def test_columns_dtypes_invalid(self, engine):
-
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})
@@ -302,8 +303,7 @@ def test_compression(self, engine, compression):
             pytest.importorskip('brotli')
 
         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine,
-                              write_kwargs={'compression': compression})
+        check_round_trip(df, engine, write_kwargs={'compression': compression})
 
     def test_read_columns(self, engine):
         # GH18154
@@ -311,8 +311,8 @@ def test_read_columns(self, engine):
                            'int': list(range(1, 4))})
 
         expected = pd.DataFrame({'string': list('abc')})
-        self.check_round_trip(df, engine, expected=expected,
-                              read_kwargs={'columns': ['string']})
+        check_round_trip(df, engine, expected=expected,
+                         read_kwargs={'columns': ['string']})
 
     def test_write_index(self, engine):
         check_names = engine != 'fastparquet'
@@ -323,7 +323,7 @@ def test_write_index(self, engine):
             pytest.skip("pyarrow is < 0.7.0")
 
         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
         indexes = [
             [2, 3, 4],
@@ -334,12 +334,12 @@ def test_write_index(self, engine):
         # non-default index
         for index in indexes:
             df.index = index
-            self.check_round_trip(df, engine, check_names=check_names)
+            check_round_trip(df, engine, check_names=check_names)
 
         # index with meta-data
         df.index = [0, 1, 2]
         df.index.name = 'foo'
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
     def test_write_multiindex(self, pa_ge_070):
         # Not supported in fastparquet as of 0.1.3 or older pyarrow version
@@ -348,7 +348,7 @@ def test_write_multiindex(self, pa_ge_070):
         df = pd.DataFrame({'A': [1, 2, 3]})
         index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
         df.index = index
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
     def test_write_column_multiindex(self, engine):
         # column multi-index
@@ -357,7 +357,6 @@ def test_write_column_multiindex(self, engine):
         self.check_error_on_write(df, engine, ValueError)
 
     def test_multiindex_with_columns(self, pa_ge_070):
-
         engine = pa_ge_070
         dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS')
         df = pd.DataFrame(np.random.randn(2 * len(dates), 3),
@@ -368,14 +367,10 @@ def test_multiindex_with_columns(self, pa_ge_070):
         index2 = index1.copy(names=None)
         for index in [index1, index2]:
             df.index = index
-            with tm.ensure_clean() as path:
-                df.to_parquet(path, engine)
-                result = read_parquet(path, engine)
-                expected = df
-                tm.assert_frame_equal(result, expected)
-                result = read_parquet(path, engine, columns=['A', 'B'])
-                expected = df[['A', 'B']]
-                tm.assert_frame_equal(result, expected)
+
+            check_round_trip(df, engine)
+            check_round_trip(df, engine, read_kwargs={'columns': ['A', 'B']},
+                             expected=df[['A', 'B']])
 
 
 class TestParquetPyArrow(Base):
@@ -391,7 +386,7 @@ def test_basic(self, pa, df_full):
                                           tz='Europe/Brussels')
         df['bool_with_none'] = [True, None, True]
 
-        self.check_round_trip(df, pa)
+        check_round_trip(df, pa)
 
     @pytest.mark.xfail(reason="pyarrow fails on this (ARROW-1883)")
     def test_basic_subset_columns(self, pa, df_full):
@@ -402,8 +397,8 @@ def test_basic_subset_columns(self, pa, df_full):
         df['datetime_tz'] = pd.date_range('20130101', periods=3,
                                           tz='Europe/Brussels')
 
-        self.check_round_trip(df, pa, expected=df[['string', 'int']],
-                              read_kwargs={'columns': ['string', 'int']})
+        check_round_trip(df, pa, expected=df[['string', 'int']],
+                         read_kwargs={'columns': ['string', 'int']})
 
     def test_duplicate_columns(self, pa):
         # not currently able to handle duplicate columns
@@ -433,7 +428,7 @@ def test_categorical(self, pa_ge_070):
 
         # de-serialized as object
         expected = df.assign(a=df.a.astype(object))
-        self.check_round_trip(df, pa, expected)
+        check_round_trip(df, pa, expected=expected)
 
     def test_categorical_unsupported(self, pa_lt_070):
         pa = pa_lt_070
@@ -444,20 +439,19 @@ def test_categorical_unsupported(self, pa_lt_070):
 
     def test_s3_roundtrip(self, df_compat, s3_resource, pa):
         # GH #19134
-        self.check_round_trip(df_compat, pa,
-                              path='s3://pandas-test/pyarrow.parquet')
+        check_round_trip(df_compat, pa,
+                         path='s3://pandas-test/pyarrow.parquet')
 
 
 class TestParquetFastParquet(Base):
 
     def test_basic(self, fp, df_full):
-
         df = df_full
 
         # additional supported types for fastparquet
         df['timedelta'] = pd.timedelta_range('1 day', periods=3)
 
-        self.check_round_trip(df, fp)
+        check_round_trip(df, fp)
 
     @pytest.mark.skip(reason="not supported")
     def test_duplicate_columns(self, fp):
@@ -470,7 +464,7 @@ def test_duplicate_columns(self, fp):
 
     def test_bool_with_none(self, fp):
         df = pd.DataFrame({'a': [True, None, False]})
         expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
-        self.check_round_trip(df, fp, expected=expected)
+        check_round_trip(df, fp, expected=expected)
 
     def test_unsupported(self, fp):
@@ -486,7 +480,7 @@ def test_categorical(self, fp):
         if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
             pytest.skip("CategoricalDtype not supported for older fp")
         df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
-        self.check_round_trip(df, fp)
+        check_round_trip(df, fp)
 
     def test_datetime_tz(self, fp):
         # doesn't preserve tz
@@ -495,7 +489,7 @@ def test_datetime_tz(self, fp):
 
         # warns on the coercion
         with catch_warnings(record=True):
-            self.check_round_trip(df, fp, df.astype('datetime64[ns]'))
+            check_round_trip(df, fp, expected=df.astype('datetime64[ns]'))
 
     def test_filter_row_groups(self, fp):
         d = {'a': list(range(0, 3))}
@@ -508,5 +502,5 @@ def test_filter_row_groups(self, fp):
 
     def test_s3_roundtrip(self, df_compat, s3_resource, fp):
         # GH #19134
-        self.check_round_trip(df_compat, fp,
-                              path='s3://pandas-test/fastparquet.parquet')
+        check_round_trip(df_compat, fp,
+                         path='s3://pandas-test/fastparquet.parquet')
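
Usage sketch (illustrative, not part of the patch): with the helper hoisted to module level, a new round-trip test reduces to building a frame and making one call. The test below is hypothetical and assumes this file's parametrized engine fixture:

    def test_mixed_frame(engine):
        # hypothetical example: round-trip a frame with int and string
        # columns through either engine; check_round_trip writes to a
        # temporary file, reads it back (repeat=2 by default) and asserts
        # the deserialized frame equals the original
        df = pd.DataFrame({'int': [1, 2, 3], 'string': list('abc')})
        check_round_trip(df, engine)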