diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst
index e61f50afea838..ac1f74a5fceb6 100644
--- a/doc/source/whatsnew/v1.3.0.rst
+++ b/doc/source/whatsnew/v1.3.0.rst
@@ -55,6 +55,7 @@ Other enhancements
 - :meth:`DataFrame.plot.scatter` can now accept a categorical column as the argument to ``c`` (:issue:`12380`, :issue:`31357`)
 - :meth:`.Styler.set_tooltips` allows on hover tooltips to be added to styled HTML dataframes.
 - :meth:`Series.loc.__getitem__` and :meth:`Series.loc.__setitem__` with :class:`MultiIndex` now raising helpful error message when indexer has too many dimensions (:issue:`35349`)
+- :meth:`pandas.read_stata` and :class:`StataReader` support reading data from compressed files.
 
 .. ---------------------------------------------------------------------------
 
diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 9a5c9e4a2e2b2..8f8c435fae4f3 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -98,6 +98,19 @@
     Return StataReader object for iterations, returns chunks with given number
     of lines."""
 
+_compression_params = f"""\
+compression : str or dict, default 'infer'
+    If string, specifies compression mode. If dict, value at key 'method'
+    specifies compression mode. Compression mode must be one of {{'infer',
+    'gzip', 'bz2', 'zip', 'xz', None}}. If compression mode is 'infer'
+    and `filepath_or_buffer` is path-like, then detect compression from
+    the following extensions: '.gz', '.bz2', '.zip', or '.xz' (otherwise
+    no compression). If dict and compression mode is one of
+    {{'zip', 'gzip', 'bz2'}}, or inferred as one of the above,
+    other entries passed as additional compression options.
+{generic._shared_docs["storage_options"]}"""
+
+
 _iterator_params = """\
 iterator : bool, default False
     Return StataReader object."""
@@ -129,6 +142,7 @@
 {_statafile_processing_params2}
 {_chunksize_params}
 {_iterator_params}
+{_compression_params}
 
 Returns
 -------
@@ -180,6 +194,7 @@
 {_statafile_processing_params1}
 {_statafile_processing_params2}
 {_chunksize_params}
+{_compression_params}
 
 {_reader_notes}
 """
@@ -1026,6 +1041,7 @@ def __init__(
         columns: Optional[Sequence[str]] = None,
         order_categoricals: bool = True,
         chunksize: Optional[int] = None,
+        compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
     ):
         super().__init__()
@@ -1064,10 +1080,10 @@ def __init__(
             "rb",
             storage_options=storage_options,
             is_text=False,
+            compression=compression,
         ) as handles:
             # Copy to BytesIO, and ensure no encoding
-            contents = handles.handle.read()
-            self.path_or_buf = BytesIO(contents)  # type: ignore[arg-type]
+            self.path_or_buf = BytesIO(handles.handle.read())  # type: ignore[arg-type]
 
         self._read_header()
         self._setup_dtype()
@@ -1898,6 +1914,7 @@ def read_stata(
     order_categoricals: bool = True,
     chunksize: Optional[int] = None,
     iterator: bool = False,
+    compression: CompressionOptions = "infer",
     storage_options: StorageOptions = None,
 ) -> Union[DataFrame, StataReader]:
 
@@ -1912,6 +1929,7 @@ def read_stata(
         order_categoricals=order_categoricals,
         chunksize=chunksize,
         storage_options=storage_options,
+        compression=compression,
     )
 
     if iterator or chunksize:
diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py
index 5897b91a5fa70..058dc7659fc95 100644
--- a/pandas/tests/io/test_stata.py
+++ b/pandas/tests/io/test_stata.py
@@ -2003,3 +2003,50 @@ def test_precision_loss():
     tm.assert_series_equal(reread.dtypes, expected_dt)
     assert reread.loc[0, "little"] == df.loc[0, "little"]
     assert reread.loc[0, "big"] == float(df.loc[0, "big"])
+
+
+def test_compression_roundtrip(compression):
+    # Round-trip through to_stata/read_stata using the same compression setting.
+    df = DataFrame(
+        [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]],
+        index=["A", "B"],
+        columns=["X", "Y", "Z"],
+    )
+    df.index.name = "index"
+
+    with tm.ensure_clean() as path:
+
+        df.to_stata(path, compression=compression)
+        reread = read_stata(path, compression=compression, index_col="index")
+        tm.assert_frame_equal(df, reread)
+
+        # explicitly ensure file was compressed.
+        with tm.decompress_file(path, compression) as fh:
+            contents = io.BytesIO(fh.read())
+        reread = pd.read_stata(contents, index_col="index")
+        tm.assert_frame_equal(df, reread)
+
+
+@pytest.mark.parametrize("to_infer", [True, False])
+@pytest.mark.parametrize("read_infer", [True, False])
+def test_stata_compression(compression_only, read_infer, to_infer):
+    # Explicit and inferred compression must be interchangeable on write and read.
+    compression = compression_only
+
+    ext = "gz" if compression == "gzip" else compression
+    filename = f"test.{ext}"
+
+    df = DataFrame(
+        [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]],
+        index=["A", "B"],
+        columns=["X", "Y", "Z"],
+    )
+    df.index.name = "index"
+
+    to_compression = "infer" if to_infer else compression
+    read_compression = "infer" if read_infer else compression
+
+    with tm.ensure_clean(filename) as path:
+        df.to_stata(path, compression=to_compression)
+        result = pd.read_stata(path, compression=read_compression, index_col="index")
+        tm.assert_frame_equal(result, df)