diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index 3dc4290953360..bfe8ed8ddafd0 100644 --- a/pandas/io/formats/format.py +++ b/pandas/io/formats/format.py @@ -938,17 +938,18 @@ def to_latex( """ from pandas.io.formats.latex import LatexFormatter - return LatexFormatter( + latex_formatter = LatexFormatter( self, - column_format=column_format, longtable=longtable, + column_format=column_format, multicolumn=multicolumn, multicolumn_format=multicolumn_format, multirow=multirow, caption=caption, label=label, position=position, - ).get_result(buf=buf, encoding=encoding) + ) + return latex_formatter.get_result(buf=buf, encoding=encoding) def _format_col(self, i: int) -> List[str]: frame = self.tr_frame diff --git a/pandas/io/formats/latex.py b/pandas/io/formats/latex.py index 715b8bbdf5672..8080d953da308 100644 --- a/pandas/io/formats/latex.py +++ b/pandas/io/formats/latex.py @@ -1,7 +1,8 @@ """ Module for formatting output data in Latex. """ -from typing import IO, List, Optional, Tuple +from abc import ABC, abstractmethod +from typing import IO, Iterator, List, Optional, Type import numpy as np @@ -10,56 +11,95 @@ from pandas.io.formats.format import DataFrameFormatter, TableFormatter -class LatexFormatter(TableFormatter): - """ - Used to render a DataFrame to a LaTeX tabular/longtable environment output. +class RowStringConverter(ABC): + r"""Converter for dataframe rows into LaTeX strings. Parameters ---------- formatter : `DataFrameFormatter` - column_format : str, default None - The columns format as specified in `LaTeX table format - `__ e.g 'rcl' for 3 columns - longtable : boolean, default False - Use a longtable environment instead of tabular. + Instance of `DataFrameFormatter`. + multicolumn: bool, optional + Whether to use \multicolumn macro. + multicolumn_format: str, optional + Multicolumn format. + multirow: bool, optional + Whether to use \multirow macro. - See Also - -------- - HTMLFormatter """ def __init__( self, formatter: DataFrameFormatter, - column_format: Optional[str] = None, - longtable: bool = False, multicolumn: bool = False, multicolumn_format: Optional[str] = None, multirow: bool = False, - caption: Optional[str] = None, - label: Optional[str] = None, - position: Optional[str] = None, ): self.fmt = formatter self.frame = self.fmt.frame - self.bold_rows = self.fmt.bold_rows - self.column_format = column_format - self.longtable = longtable self.multicolumn = multicolumn self.multicolumn_format = multicolumn_format self.multirow = multirow - self.caption = caption - self.label = label - self.escape = self.fmt.escape - self.position = position - self._table_float = any(p is not None for p in (caption, label, position)) + self.clinebuf: List[List[int]] = [] + self.strcols = self._get_strcols() + self.strrows: List[List[str]] = ( + list(zip(*self.strcols)) # type: ignore[arg-type] + ) + + def get_strrow(self, row_num: int) -> str: + """Get string representation of the row.""" + row = self.strrows[row_num] + + is_multicol = ( + row_num < self.column_levels and self.fmt.header and self.multicolumn + ) + + is_multirow = ( + row_num >= self.header_levels + and self.fmt.index + and self.multirow + and self.index_levels > 1 + ) + + is_cline_maybe_required = is_multirow and row_num < len(self.strrows) - 1 + + crow = self._preprocess_row(row) + + if is_multicol: + crow = self._format_multicolumn(crow) + if is_multirow: + crow = self._format_multirow(crow, row_num) + + lst = [] + lst.append(" & ".join(crow)) + lst.append(" \\\\") + if is_cline_maybe_required: + cline = self._compose_cline(row_num, len(self.strcols)) + lst.append(cline) + return "".join(lst) + + @property + def _header_row_num(self) -> int: + """Number of rows in header.""" + return self.header_levels if self.fmt.header else 0 + + @property + def index_levels(self) -> int: + """Integer number of levels in index.""" + return self.frame.index.nlevels + + @property + def column_levels(self) -> int: + return self.frame.columns.nlevels + + @property + def header_levels(self) -> int: + nlevels = self.column_levels + if self.fmt.has_index_names and self.fmt.show_index_names: + nlevels += 1 + return nlevels - def write_result(self, buf: IO[str]) -> None: - """ - Render a DataFrame to a LaTeX tabular, longtable, or table/tabular - environment output. - """ - # string representation of the columns + def _get_strcols(self) -> List[List[str]]: + """String representation of the columns.""" if len(self.frame.columns) == 0 or len(self.frame.index) == 0: info_line = ( f"Empty {type(self.frame).__name__}\n" @@ -70,12 +110,6 @@ def write_result(self, buf: IO[str]) -> None: else: strcols = self.fmt._to_str_columns() - def get_col_type(dtype): - if issubclass(dtype.type, np.number): - return "r" - else: - return "l" - # reestablish the MultiIndex that has been joined by _to_str_column if self.fmt.index and isinstance(self.frame.index, ABCMultiIndex): out = self.frame.index.format( @@ -107,89 +141,19 @@ def pad_empties(x): # Get rid of old multiindex column and add new ones strcols = out + strcols[1:] + return strcols - if self.column_format is None: - dtypes = self.frame.dtypes._values - column_format = "".join(map(get_col_type, dtypes)) - if self.fmt.index: - index_format = "l" * self.frame.index.nlevels - column_format = index_format + column_format - elif not isinstance(self.column_format, str): # pragma: no cover - raise AssertionError( - f"column_format must be str or unicode, not {type(column_format)}" - ) + def _preprocess_row(self, row: List[str]) -> List[str]: + """Preprocess elements of the row.""" + if self.fmt.escape: + crow = _escape_symbols(row) else: - column_format = self.column_format - - self._write_tabular_begin(buf, column_format) - - buf.write("\\toprule\n") + crow = [x if x else "{}" for x in row] + if self.fmt.bold_rows and self.fmt.index: + crow = _convert_to_bold(crow, self.index_levels) + return crow - ilevels = self.frame.index.nlevels - clevels = self.frame.columns.nlevels - nlevels = clevels - if self.fmt.has_index_names and self.fmt.show_index_names: - nlevels += 1 - strrows = list(zip(*strcols)) - self.clinebuf: List[List[int]] = [] - - for i, row in enumerate(strrows): - if i == nlevels and self.fmt.header: - buf.write("\\midrule\n") # End of header - if self.longtable: - buf.write("\\endhead\n") - buf.write("\\midrule\n") - buf.write( - f"\\multicolumn{{{len(row)}}}{{r}}" - "{{Continued on next page}} \\\\\n" - ) - buf.write("\\midrule\n") - buf.write("\\endfoot\n\n") - buf.write("\\bottomrule\n") - buf.write("\\endlastfoot\n") - if self.escape: - # escape backslashes first - crow = [ - ( - x.replace("\\", "\\textbackslash ") - .replace("_", "\\_") - .replace("%", "\\%") - .replace("$", "\\$") - .replace("#", "\\#") - .replace("{", "\\{") - .replace("}", "\\}") - .replace("~", "\\textasciitilde ") - .replace("^", "\\textasciicircum ") - .replace("&", "\\&") - if (x and x != "{}") - else "{}" - ) - for x in row - ] - else: - crow = [x if x else "{}" for x in row] - if self.bold_rows and self.fmt.index: - # bold row labels - crow = [ - f"\\textbf{{{x}}}" - if j < ilevels and x.strip() not in ["", "{}"] - else x - for j, x in enumerate(crow) - ] - if i < clevels and self.fmt.header and self.multicolumn: - # sum up columns to multicolumns - crow = self._format_multicolumn(crow, ilevels) - if i >= nlevels and self.fmt.index and self.multirow and ilevels > 1: - # sum up rows to multirows - crow = self._format_multirow(crow, ilevels, i, strrows) - buf.write(" & ".join(crow)) - buf.write(" \\\\\n") - if self.multirow and i < len(strrows) - 1: - self._print_cline(buf, i, len(strcols)) - - self._write_tabular_end(buf) - - def _format_multicolumn(self, row: List[str], ilevels: int) -> List[str]: + def _format_multicolumn(self, row: List[str]) -> List[str]: r""" Combine columns belonging to a group to a single multicolumn entry according to self.multicolumn_format @@ -199,7 +163,7 @@ def _format_multicolumn(self, row: List[str], ilevels: int) -> List[str]: will become \multicolumn{3}{l}{a} & b & \multicolumn{2}{l}{c} """ - row2 = list(row[:ilevels]) + row2 = row[: self.index_levels] ncol = 1 coltext = "" @@ -214,7 +178,7 @@ def append_col(): else: row2.append(coltext) - for c in row[ilevels:]: + for c in row[self.index_levels :]: # if next col has text, write the previous if c.strip(): if coltext: @@ -229,9 +193,7 @@ def append_col(): append_col() return row2 - def _format_multirow( - self, row: List[str], ilevels: int, i: int, rows: List[Tuple[str, ...]] - ) -> List[str]: + def _format_multirow(self, row: List[str], i: int) -> List[str]: r""" Check following rows, whether row should be a multirow @@ -241,10 +203,10 @@ def _format_multirow( b & 0 & \cline{1-2} b & 0 & """ - for j in range(ilevels): + for j in range(self.index_levels): if row[j].strip(): nrow = 1 - for r in rows[i + 1 :]: + for r in self.strrows[i + 1 :]: if not r[j].strip(): nrow += 1 else: @@ -256,88 +218,524 @@ def _format_multirow( self.clinebuf.append([i + nrow - 1, j + 1]) return row - def _print_cline(self, buf: IO[str], i: int, icol: int) -> None: + def _compose_cline(self, i: int, icol: int) -> str: """ - Print clines after multirow-blocks are finished. + Create clines after multirow-blocks are finished. """ + lst = [] for cl in self.clinebuf: if cl[0] == i: - buf.write(f"\\cline{{{cl[1]:d}-{icol:d}}}\n") - # remove entries that have been written to buffer - self.clinebuf = [x for x in self.clinebuf if x[0] != i] + lst.append(f"\n\\cline{{{cl[1]:d}-{icol:d}}}") + # remove entries that have been written to buffer + self.clinebuf = [x for x in self.clinebuf if x[0] != i] + return "".join(lst) + + +class RowStringIterator(RowStringConverter): + """Iterator over rows of the header or the body of the table.""" + + @abstractmethod + def __iter__(self) -> Iterator[str]: + """Iterate over LaTeX string representations of rows.""" + + +class RowHeaderIterator(RowStringIterator): + """Iterator for the table header rows.""" + + def __iter__(self) -> Iterator[str]: + for row_num in range(len(self.strrows)): + if row_num < self._header_row_num: + yield self.get_strrow(row_num) + + +class RowBodyIterator(RowStringIterator): + """Iterator for the table body rows.""" + + def __iter__(self) -> Iterator[str]: + for row_num in range(len(self.strrows)): + if row_num >= self._header_row_num: + yield self.get_strrow(row_num) - def _write_tabular_begin(self, buf, column_format: str): - """ - Write the beginning of a tabular environment or - nested table/tabular environments including caption and label. + +class TableBuilderAbstract(ABC): + """ + Abstract table builder producing string representation of LaTeX table. + + Parameters + ---------- + formatter : `DataFrameFormatter` + Instance of `DataFrameFormatter`. + column_format: str, optional + Column format, for example, 'rcl' for three columns. + multicolumn: bool, optional + Use multicolumn to enhance MultiIndex columns. + multicolumn_format: str, optional + The alignment for multicolumns, similar to column_format. + multirow: bool, optional + Use multirow to enhance MultiIndex rows. + caption: str, optional + Table caption. + label: str, optional + LaTeX label. + position: str, optional + Float placement specifier, for example, 'htb'. + """ + + def __init__( + self, + formatter: DataFrameFormatter, + column_format: Optional[str] = None, + multicolumn: bool = False, + multicolumn_format: Optional[str] = None, + multirow: bool = False, + caption: Optional[str] = None, + label: Optional[str] = None, + position: Optional[str] = None, + ): + self.fmt = formatter + self.column_format = column_format + self.multicolumn = multicolumn + self.multicolumn_format = multicolumn_format + self.multirow = multirow + self.caption = caption + self.label = label + self.position = position + + def get_result(self) -> str: + """String representation of LaTeX table.""" + elements = [ + self.env_begin, + self.top_separator, + self.header, + self.middle_separator, + self.env_body, + self.bottom_separator, + self.env_end, + ] + result = "\n".join([item for item in elements if item]) + trailing_newline = "\n" + result += trailing_newline + return result + + @property + @abstractmethod + def env_begin(self) -> str: + """Beginning of the environment.""" + + @property + @abstractmethod + def top_separator(self) -> str: + """Top level separator.""" + + @property + @abstractmethod + def header(self) -> str: + """Header lines.""" + + @property + @abstractmethod + def middle_separator(self) -> str: + """Middle level separator.""" + + @property + @abstractmethod + def env_body(self) -> str: + """Environment body.""" + + @property + @abstractmethod + def bottom_separator(self) -> str: + """Bottom level separator.""" + + @property + @abstractmethod + def env_end(self) -> str: + """End of the environment.""" + + +class GenericTableBuilder(TableBuilderAbstract): + """Table builder producing string representation of LaTeX table.""" + + @property + def header(self) -> str: + iterator = self._create_row_iterator(over="header") + return "\n".join(list(iterator)) + + @property + def top_separator(self) -> str: + return "\\toprule" + + @property + def middle_separator(self) -> str: + return "\\midrule" if self._is_separator_required() else "" + + @property + def env_body(self) -> str: + iterator = self._create_row_iterator(over="body") + return "\n".join(list(iterator)) + + def _is_separator_required(self) -> bool: + return bool(self.header and self.env_body) + + @property + def _position_macro(self) -> str: + r"""Position macro, extracted from self.position, like [h].""" + return f"[{self.position}]" if self.position else "" + + @property + def _caption_macro(self) -> str: + r"""Caption macro, extracted from self.caption, like \caption{cap}.""" + return f"\\caption{{{self.caption}}}" if self.caption else "" + + @property + def _label_macro(self) -> str: + r"""Label macro, extracted from self.label, like \label{ref}.""" + return f"\\label{{{self.label}}}" if self.label else "" + + def _create_row_iterator(self, over: str) -> RowStringIterator: + """Create iterator over header or body of the table. Parameters ---------- - buf : string or file handle - File path or object. If not specified, the result is returned as - a string. - column_format : str - The columns format as specified in `LaTeX table format - `__ e.g 'rcl' - for 3 columns + over : {'body', 'header'} + Over what to iterate. + + Returns + ------- + RowStringIterator + Iterator over body or header. """ - if self._table_float: - # then write output in a nested table/tabular or longtable environment - if self.caption is None: - caption_ = "" - else: - caption_ = f"\n\\caption{{{self.caption}}}" + iterator_kind = self._select_iterator(over) + return iterator_kind( + formatter=self.fmt, + multicolumn=self.multicolumn, + multicolumn_format=self.multicolumn_format, + multirow=self.multirow, + ) + + def _select_iterator(self, over: str) -> Type[RowStringIterator]: + """Select proper iterator over table rows.""" + if over == "header": + return RowHeaderIterator + elif over == "body": + return RowBodyIterator + else: + msg = f"'over' must be either 'header' or 'body', but {over} was provided" + raise ValueError(msg) + + +class LongTableBuilder(GenericTableBuilder): + """Concrete table builder for longtable. + + >>> from pandas import DataFrame + >>> from pandas.io.formats import format as fmt + >>> df = DataFrame({"a": [1, 2], "b": ["b1", "b2"]}) + >>> formatter = fmt.DataFrameFormatter(df) + >>> builder = LongTableBuilder(formatter, caption='caption', label='lab', + ... column_format='lrl') + >>> table = builder.get_result() + >>> print(table) + \\begin{longtable}{lrl} + \\caption{caption} + \\label{lab}\\\\ + \\toprule + {} & a & b \\\\ + \\midrule + \\endhead + \\midrule + \\multicolumn{3}{r}{{Continued on next page}} \\\\ + \\midrule + \\endfoot + + \\bottomrule + \\endlastfoot + 0 & 1 & b1 \\\\ + 1 & 2 & b2 \\\\ + \\end{longtable} + + """ - if self.label is None: - label_ = "" - else: - label_ = f"\n\\label{{{self.label}}}" + @property + def env_begin(self) -> str: + first_row = ( + f"\\begin{{longtable}}{self._position_macro}{{{self.column_format}}}" + ) + elements = [first_row, f"{self._caption_and_label()}"] + return "\n".join([item for item in elements if item]) + + def _caption_and_label(self) -> str: + if self.caption or self.label: + double_backslash = "\\\\" + elements = [f"{self._caption_macro}", f"{self._label_macro}"] + caption_and_label = "\n".join([item for item in elements if item]) + caption_and_label += double_backslash + return caption_and_label + else: + return "" + + @property + def middle_separator(self) -> str: + iterator = self._create_row_iterator(over="header") + elements = [ + "\\midrule", + "\\endhead", + "\\midrule", + f"\\multicolumn{{{len(iterator.strcols)}}}{{r}}" + "{{Continued on next page}} \\\\", + "\\midrule", + "\\endfoot\n", + "\\bottomrule", + "\\endlastfoot", + ] + if self._is_separator_required(): + return "\n".join(elements) + return "" + + @property + def bottom_separator(self) -> str: + return "" + + @property + def env_end(self) -> str: + return "\\end{longtable}" + + +class RegularTableBuilder(GenericTableBuilder): + """Concrete table builder for regular table. + + >>> from pandas import DataFrame + >>> from pandas.io.formats import format as fmt + >>> df = DataFrame({"a": [1, 2], "b": ["b1", "b2"]}) + >>> formatter = fmt.DataFrameFormatter(df) + >>> builder = RegularTableBuilder(formatter, caption='caption', label='lab', + ... column_format='lrc') + >>> table = builder.get_result() + >>> print(table) + \\begin{table} + \\centering + \\caption{caption} + \\label{lab} + \\begin{tabular}{lrc} + \\toprule + {} & a & b \\\\ + \\midrule + 0 & 1 & b1 \\\\ + 1 & 2 & b2 \\\\ + \\bottomrule + \\end{tabular} + \\end{table} + + """ - if self.position is None: - position_ = "" - else: - position_ = f"[{self.position}]" + @property + def env_begin(self) -> str: + elements = [ + f"\\begin{{table}}{self._position_macro}", + "\\centering", + f"{self._caption_macro}", + f"{self._label_macro}", + f"\\begin{{tabular}}{{{self.column_format}}}", + ] + return "\n".join([item for item in elements if item]) + + @property + def bottom_separator(self) -> str: + return "\\bottomrule" + + @property + def env_end(self) -> str: + return "\n".join(["\\end{tabular}", "\\end{table}"]) + + +class TabularBuilder(GenericTableBuilder): + """Concrete table builder for tabular environment. + + >>> from pandas import DataFrame + >>> from pandas.io.formats import format as fmt + >>> df = DataFrame({"a": [1, 2], "b": ["b1", "b2"]}) + >>> formatter = fmt.DataFrameFormatter(df) + >>> builder = TabularBuilder(formatter, column_format='lrc') + >>> table = builder.get_result() + >>> print(table) + \\begin{tabular}{lrc} + \\toprule + {} & a & b \\\\ + \\midrule + 0 & 1 & b1 \\\\ + 1 & 2 & b2 \\\\ + \\bottomrule + \\end{tabular} + + """ - if self.longtable: - table_ = f"\\begin{{longtable}}{position_}{{{column_format}}}" - tabular_ = "\n" - else: - table_ = f"\\begin{{table}}{position_}\n\\centering" - tabular_ = f"\n\\begin{{tabular}}{{{column_format}}}\n" - - if self.longtable and (self.caption is not None or self.label is not None): - # a double-backslash is required at the end of the line - # as discussed here: - # https://tex.stackexchange.com/questions/219138 - backlash_ = "\\\\" - else: - backlash_ = "" - buf.write(f"{table_}{caption_}{label_}{backlash_}{tabular_}") - else: - if self.longtable: - tabletype_ = "longtable" - else: - tabletype_ = "tabular" - buf.write(f"\\begin{{{tabletype_}}}{{{column_format}}}\n") + @property + def env_begin(self) -> str: + return f"\\begin{{tabular}}{{{self.column_format}}}" + + @property + def bottom_separator(self) -> str: + return "\\bottomrule" + + @property + def env_end(self) -> str: + return "\\end{tabular}" + + +class LatexFormatter(TableFormatter): + """ + Used to render a DataFrame to a LaTeX tabular/longtable environment output. + + Parameters + ---------- + formatter : `DataFrameFormatter` + column_format : str, default None + The columns format as specified in `LaTeX table format + `__ e.g 'rcl' for 3 columns + + See Also + -------- + HTMLFormatter + """ + + def __init__( + self, + formatter: DataFrameFormatter, + longtable: bool = False, + column_format: Optional[str] = None, + multicolumn: bool = False, + multicolumn_format: Optional[str] = None, + multirow: bool = False, + caption: Optional[str] = None, + label: Optional[str] = None, + position: Optional[str] = None, + ): + self.fmt = formatter + self.frame = self.fmt.frame + self.longtable = longtable + self.column_format = column_format # type: ignore[assignment] + self.multicolumn = multicolumn + self.multicolumn_format = multicolumn_format + self.multirow = multirow + self.caption = caption + self.label = label + self.position = position - def _write_tabular_end(self, buf): + def write_result(self, buf: IO[str]) -> None: """ - Write the end of a tabular environment or nested table/tabular - environment. + Render a DataFrame to a LaTeX tabular, longtable, or table/tabular + environment output. + """ + table_string = self.builder.get_result() + buf.write(table_string) - Parameters - ---------- - buf : string or file handle - File path or object. If not specified, the result is returned as - a string. + @property + def builder(self) -> TableBuilderAbstract: + """Concrete table builder. + Returns + ------- + TableBuilder """ + builder = self._select_builder() + return builder( + formatter=self.fmt, + column_format=self.column_format, + multicolumn=self.multicolumn, + multicolumn_format=self.multicolumn_format, + multirow=self.multirow, + caption=self.caption, + label=self.label, + position=self.position, + ) + + def _select_builder(self) -> Type[TableBuilderAbstract]: + """Select proper table builder.""" if self.longtable: - buf.write("\\end{longtable}\n") + return LongTableBuilder + if any([self.caption, self.label, self.position]): + return RegularTableBuilder + return TabularBuilder + + @property + def column_format(self) -> str: + """Column format.""" + return self._column_format + + @column_format.setter + def column_format(self, input_column_format: Optional[str]) -> None: + """Setter for column format.""" + if input_column_format is None: + self._column_format = ( + self._get_index_format() + self._get_column_format_based_on_dtypes() + ) + elif not isinstance(input_column_format, str): + raise ValueError( + f"column_format must be str or unicode, " + f"not {type(input_column_format)}" + ) else: - buf.write("\\bottomrule\n") - buf.write("\\end{tabular}\n") - if self._table_float: - buf.write("\\end{table}\n") - else: - pass + self._column_format = input_column_format + + def _get_column_format_based_on_dtypes(self) -> str: + """Get column format based on data type. + + Right alignment for numbers and left - for strings. + """ + + def get_col_type(dtype): + if issubclass(dtype.type, np.number): + return "r" + return "l" + + dtypes = self.frame.dtypes._values + return "".join(map(get_col_type, dtypes)) + + def _get_index_format(self) -> str: + """Get index column format.""" + return "l" * self.frame.index.nlevels if self.fmt.index else "" + + +def _escape_symbols(row: List[str]) -> List[str]: + """Carry out string replacements for special symbols. + + Parameters + ---------- + row : list + List of string, that may contain special symbols. + + Returns + ------- + list + list of strings with the special symbols replaced. + """ + return [ + ( + x.replace("\\", "\\textbackslash ") + .replace("_", "\\_") + .replace("%", "\\%") + .replace("$", "\\$") + .replace("#", "\\#") + .replace("{", "\\{") + .replace("}", "\\}") + .replace("~", "\\textasciitilde ") + .replace("^", "\\textasciicircum ") + .replace("&", "\\&") + if (x and x != "{}") + else "{}" + ) + for x in row + ] + + +def _convert_to_bold(crow: List[str], ilevels: int) -> List[str]: + """Convert elements in ``crow`` to bold.""" + return [ + f"\\textbf{{{x}}}" if j < ilevels and x.strip() not in ["", "{}"] else x + for j, x in enumerate(crow) + ] + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/pandas/tests/io/formats/test_to_latex.py b/pandas/tests/io/formats/test_to_latex.py index 96a9ed2b86cf4..a2cb8f52dfd5b 100644 --- a/pandas/tests/io/formats/test_to_latex.py +++ b/pandas/tests/io/formats/test_to_latex.py @@ -7,6 +7,14 @@ from pandas import DataFrame, Series import pandas._testing as tm +from pandas.io.formats.format import DataFrameFormatter +from pandas.io.formats.latex import ( + RegularTableBuilder, + RowBodyIterator, + RowHeaderIterator, + RowStringConverter, +) + class TestToLatex: def test_to_latex_filename(self, float_frame): @@ -60,6 +68,16 @@ def test_to_latex(self, float_frame): assert withoutindex_result == withoutindex_expected + @pytest.mark.parametrize( + "bad_column_format", + [5, 1.2, ["l", "r"], ("r", "c"), {"r", "c", "l"}, dict(a="r", b="l")], + ) + def test_to_latex_bad_column_format(self, bad_column_format): + df = DataFrame({"a": [1, 2], "b": ["b1", "b2"]}) + msg = r"column_format must be str or unicode" + with pytest.raises(ValueError, match=msg): + df.to_latex(column_format=bad_column_format) + def test_to_latex_format(self, float_frame): # GH Bug #9402 float_frame.to_latex(column_format="ccc") @@ -930,3 +948,87 @@ def test_to_latex_multindex_header(self): \end{tabular} """ assert observed == expected + + +class TestTableBuilder: + @pytest.fixture + def dataframe(self): + return DataFrame({"a": [1, 2], "b": ["b1", "b2"]}) + + @pytest.fixture + def table_builder(self, dataframe): + return RegularTableBuilder(formatter=DataFrameFormatter(dataframe)) + + def test_create_row_iterator(self, table_builder): + iterator = table_builder._create_row_iterator(over="header") + assert isinstance(iterator, RowHeaderIterator) + + def test_create_body_iterator(self, table_builder): + iterator = table_builder._create_row_iterator(over="body") + assert isinstance(iterator, RowBodyIterator) + + def test_create_body_wrong_kwarg_raises(self, table_builder): + with pytest.raises(ValueError, match="must be either 'header' or 'body'"): + table_builder._create_row_iterator(over="SOMETHING BAD") + + +class TestRowStringConverter: + @pytest.mark.parametrize( + "row_num, expected", + [ + (0, r"{} & Design & ratio & xy \\"), + (1, r"0 & 1 & 4 & 10 \\"), + (2, r"1 & 2 & 5 & 11 \\"), + ], + ) + def test_get_strrow_normal_without_escape(self, row_num, expected): + df = DataFrame({r"Design": [1, 2, 3], r"ratio": [4, 5, 6], r"xy": [10, 11, 12]}) + row_string_converter = RowStringConverter( + formatter=DataFrameFormatter(df, escape=True), + ) + assert row_string_converter.get_strrow(row_num=row_num) == expected + + @pytest.mark.parametrize( + "row_num, expected", + [ + (0, r"{} & Design \# & ratio, \% & x\&y \\"), + (1, r"0 & 1 & 4 & 10 \\"), + (2, r"1 & 2 & 5 & 11 \\"), + ], + ) + def test_get_strrow_normal_with_escape(self, row_num, expected): + df = DataFrame( + {r"Design #": [1, 2, 3], r"ratio, %": [4, 5, 6], r"x&y": [10, 11, 12]} + ) + row_string_converter = RowStringConverter( + formatter=DataFrameFormatter(df, escape=True), + ) + assert row_string_converter.get_strrow(row_num=row_num) == expected + + @pytest.mark.parametrize( + "row_num, expected", + [ + (0, r"{} & \multicolumn{2}{r}{c1} & \multicolumn{2}{r}{c2} & c3 \\"), + (1, r"{} & 0 & 1 & 0 & 1 & 0 \\"), + (2, r"0 & 0 & 5 & 0 & 5 & 0 \\"), + ], + ) + def test_get_strrow_multindex_multicolumn(self, row_num, expected): + df = DataFrame( + { + ("c1", 0): {x: x for x in range(5)}, + ("c1", 1): {x: x + 5 for x in range(5)}, + ("c2", 0): {x: x for x in range(5)}, + ("c2", 1): {x: x + 5 for x in range(5)}, + ("c3", 0): {x: x for x in range(5)}, + } + ) + + row_string_converter = RowStringConverter( + formatter=DataFrameFormatter(df), + multicolumn=True, + multicolumn_format="r", + multirow=True, + ) + + assert row_string_converter.get_strrow(row_num=row_num) == expected