Skip to content

Refactor test_parquet.py to use check_round_trip at module level #19332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 23, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 81 additions & 87 deletions pandas/tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,48 +110,79 @@ def df_full():
pd.Timestamp('20130103')]})


def test_invalid_engine(df_compat):
def check_round_trip(df, engine=None, path=None,
                     write_kwargs=None, read_kwargs=None,
                     expected=None, check_names=True,
                     repeat=2):
    """Verify parquet serializer and deserializer produce the same results.

    Performs a pandas to disk and disk to pandas round trip,
    then compares the 2 resulting DataFrames to verify equality.

    Parameters
    ----------
    df: DataFrame
        Frame to serialize.
    engine: str, optional
        'pyarrow' or 'fastparquet'
    path: str, optional
        Target path; a temporary file is used when omitted.
    write_kwargs: dict of str:str, optional
        Extra keyword arguments forwarded to ``DataFrame.to_parquet``.
    read_kwargs: dict of str:str, optional
        Extra keyword arguments forwarded to ``read_parquet``.
    expected: DataFrame, optional
        Expected deserialization result, otherwise will be equal to `df`
    check_names: bool, optional
        Forwarded to ``tm.assert_frame_equal``; whether to compare
        index/column names (fastparquet tests pass False here).
    repeat: int, optional
        How many times to repeat the test
    """

    # Default to uncompressed writes so both engines accept the kwargs.
    write_kwargs = write_kwargs or {'compression': None}
    read_kwargs = read_kwargs or {}

    if expected is None:
        expected = df

    # Only pin the engine when explicitly requested; otherwise the
    # 'io.parquet.engine' option governs which engine is used.
    if engine:
        write_kwargs['engine'] = engine
        read_kwargs['engine'] = engine

    def compare(repeat):
        # Round-trip `repeat` times to also exercise overwriting an
        # existing file at the same path.
        for _ in range(repeat):
            df.to_parquet(path, **write_kwargs)
            actual = read_parquet(path, **read_kwargs)
            tm.assert_frame_equal(expected, actual,
                                  check_names=check_names)

    if path is None:
        # No path supplied: run against a self-cleaning temporary file.
        with tm.ensure_clean() as path:
            compare(repeat)
    else:
        compare(repeat)


def test_invalid_engine(df_compat):
    # Passing an unknown engine name ('foo') must raise ValueError.
    # NOTE: the diff scrape kept both the pre- and post-refactor call
    # lines; only the check_round_trip call belongs in the merged code.
    with pytest.raises(ValueError):
        check_round_trip(df_compat, 'foo', 'bar')


def test_options_py(df_compat, pa):
# use the set option

df = df_compat
with tm.ensure_clean() as path:

with pd.option_context('io.parquet.engine', 'pyarrow'):
df.to_parquet(path)

result = read_parquet(path)
tm.assert_frame_equal(result, df)
with pd.option_context('io.parquet.engine', 'pyarrow'):
check_round_trip(df_compat)


def test_options_fp(df_compat, fp):
# use the set option

df = df_compat
with tm.ensure_clean() as path:

with pd.option_context('io.parquet.engine', 'fastparquet'):
df.to_parquet(path, compression=None)

result = read_parquet(path)
tm.assert_frame_equal(result, df)
with pd.option_context('io.parquet.engine', 'fastparquet'):
check_round_trip(df_compat)


def test_options_auto(df_compat, fp, pa):
# use the set option

df = df_compat
with tm.ensure_clean() as path:

with pd.option_context('io.parquet.engine', 'auto'):
df.to_parquet(path)

result = read_parquet(path)
tm.assert_frame_equal(result, df)
with pd.option_context('io.parquet.engine', 'auto'):
check_round_trip(df_compat)


def test_options_get_engine(fp, pa):
Expand Down Expand Up @@ -228,53 +259,23 @@ def check_error_on_write(self, df, engine, exc):
with tm.ensure_clean() as path:
to_parquet(df, path, engine, compression=None)

def check_round_trip(self, df, engine, expected=None, path=None,
write_kwargs=None, read_kwargs=None,
check_names=True):

if write_kwargs is None:
write_kwargs = {'compression': None}

if read_kwargs is None:
read_kwargs = {}

if expected is None:
expected = df

if path is None:
with tm.ensure_clean() as path:
check_round_trip_equals(df, path, engine,
write_kwargs=write_kwargs,
read_kwargs=read_kwargs,
expected=expected,
check_names=check_names)
else:
check_round_trip_equals(df, path, engine,
write_kwargs=write_kwargs,
read_kwargs=read_kwargs,
expected=expected,
check_names=check_names)


class TestBasic(Base):

def test_error(self, engine):

for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'),
np.array([1, 2, 3])]:
self.check_error_on_write(obj, engine, ValueError)

def test_columns_dtypes(self, engine):

df = pd.DataFrame({'string': list('abc'),
'int': list(range(1, 4))})

# unicode
df.columns = [u'foo', u'bar']
self.check_round_trip(df, engine)
check_round_trip(df, engine)

def test_columns_dtypes_invalid(self, engine):

df = pd.DataFrame({'string': list('abc'),
'int': list(range(1, 4))})

Expand Down Expand Up @@ -302,17 +303,16 @@ def test_compression(self, engine, compression):
pytest.importorskip('brotli')

df = pd.DataFrame({'A': [1, 2, 3]})
self.check_round_trip(df, engine,
write_kwargs={'compression': compression})
check_round_trip(df, engine, write_kwargs={'compression': compression})

def test_read_columns(self, engine):
# GH18154
df = pd.DataFrame({'string': list('abc'),
'int': list(range(1, 4))})

expected = pd.DataFrame({'string': list('abc')})
self.check_round_trip(df, engine, expected=expected,
read_kwargs={'columns': ['string']})
check_round_trip(df, engine, expected=expected,
read_kwargs={'columns': ['string']})

def test_write_index(self, engine):
check_names = engine != 'fastparquet'
Expand All @@ -323,7 +323,7 @@ def test_write_index(self, engine):
pytest.skip("pyarrow is < 0.7.0")

df = pd.DataFrame({'A': [1, 2, 3]})
self.check_round_trip(df, engine)
check_round_trip(df, engine)

indexes = [
[2, 3, 4],
Expand All @@ -334,12 +334,12 @@ def test_write_index(self, engine):
# non-default index
for index in indexes:
df.index = index
self.check_round_trip(df, engine, check_names=check_names)
check_round_trip(df, engine, check_names=check_names)

# index with meta-data
df.index = [0, 1, 2]
df.index.name = 'foo'
self.check_round_trip(df, engine)
check_round_trip(df, engine)

def test_write_multiindex(self, pa_ge_070):
# Not supported in fastparquet as of 0.1.3 or older pyarrow version
Expand All @@ -348,7 +348,7 @@ def test_write_multiindex(self, pa_ge_070):
df = pd.DataFrame({'A': [1, 2, 3]})
index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
df.index = index
self.check_round_trip(df, engine)
check_round_trip(df, engine)

def test_write_column_multiindex(self, engine):
# column multi-index
Expand All @@ -357,7 +357,6 @@ def test_write_column_multiindex(self, engine):
self.check_error_on_write(df, engine, ValueError)

def test_multiindex_with_columns(self, pa_ge_070):

engine = pa_ge_070
dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS')
df = pd.DataFrame(np.random.randn(2 * len(dates), 3),
Expand All @@ -368,14 +367,10 @@ def test_multiindex_with_columns(self, pa_ge_070):
index2 = index1.copy(names=None)
for index in [index1, index2]:
df.index = index
with tm.ensure_clean() as path:
df.to_parquet(path, engine)
result = read_parquet(path, engine)
expected = df
tm.assert_frame_equal(result, expected)
result = read_parquet(path, engine, columns=['A', 'B'])
expected = df[['A', 'B']]
tm.assert_frame_equal(result, expected)

check_round_trip(df, engine)
check_round_trip(df, engine, read_kwargs={'columns': ['A', 'B']},
expected=df[['A', 'B']])


class TestParquetPyArrow(Base):
Expand All @@ -391,7 +386,7 @@ def test_basic(self, pa, df_full):
tz='Europe/Brussels')
df['bool_with_none'] = [True, None, True]

self.check_round_trip(df, pa)
check_round_trip(df, pa)

@pytest.mark.xfail(reason="pyarrow fails on this (ARROW-1883)")
def test_basic_subset_columns(self, pa, df_full):
Expand All @@ -402,8 +397,8 @@ def test_basic_subset_columns(self, pa, df_full):
df['datetime_tz'] = pd.date_range('20130101', periods=3,
tz='Europe/Brussels')

self.check_round_trip(df, pa, expected=df[['string', 'int']],
read_kwargs={'columns': ['string', 'int']})
check_round_trip(df, pa, expected=df[['string', 'int']],
read_kwargs={'columns': ['string', 'int']})

def test_duplicate_columns(self, pa):
# not currently able to handle duplicate columns
Expand Down Expand Up @@ -433,7 +428,7 @@ def test_categorical(self, pa_ge_070):

# de-serialized as object
expected = df.assign(a=df.a.astype(object))
self.check_round_trip(df, pa, expected)
check_round_trip(df, pa, expected=expected)

def test_categorical_unsupported(self, pa_lt_070):
pa = pa_lt_070
Expand All @@ -444,20 +439,19 @@ def test_categorical_unsupported(self, pa_lt_070):

def test_s3_roundtrip(self, df_compat, s3_resource, pa):
# GH #19134
self.check_round_trip(df_compat, pa,
path='s3://pandas-test/pyarrow.parquet')
check_round_trip(df_compat, pa,
path='s3://pandas-test/pyarrow.parquet')


class TestParquetFastParquet(Base):

def test_basic(self, fp, df_full):

df = df_full

# additional supported types for fastparquet
df['timedelta'] = pd.timedelta_range('1 day', periods=3)

self.check_round_trip(df, fp)
check_round_trip(df, fp)

@pytest.mark.skip(reason="not supported")
def test_duplicate_columns(self, fp):
Expand All @@ -470,7 +464,7 @@ def test_duplicate_columns(self, fp):
def test_bool_with_none(self, fp):
df = pd.DataFrame({'a': [True, None, False]})
expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
self.check_round_trip(df, fp, expected=expected)
check_round_trip(df, fp, expected=expected)

def test_unsupported(self, fp):

Expand All @@ -486,7 +480,7 @@ def test_categorical(self, fp):
if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
pytest.skip("CategoricalDtype not supported for older fp")
df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
self.check_round_trip(df, fp)
check_round_trip(df, fp)

def test_datetime_tz(self, fp):
# doesn't preserve tz
Expand All @@ -495,7 +489,7 @@ def test_datetime_tz(self, fp):

# warns on the coercion
with catch_warnings(record=True):
self.check_round_trip(df, fp, df.astype('datetime64[ns]'))
check_round_trip(df, fp, expected=df.astype('datetime64[ns]'))

def test_filter_row_groups(self, fp):
d = {'a': list(range(0, 3))}
Expand All @@ -508,5 +502,5 @@ def test_filter_row_groups(self, fp):

def test_s3_roundtrip(self, df_compat, s3_resource, fp):
# GH #19134
self.check_round_trip(df_compat, fp,
path='s3://pandas-test/fastparquet.parquet')
check_round_trip(df_compat, fp,
path='s3://pandas-test/fastparquet.parquet')