diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index 22f3ca8e258..49ebced478e 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -3,15 +3,25 @@ from __future__ import annotations + +__all__ = ["_diff_text", "assertrepr_compare", "format_explanation"] + +from collections.abc import Callable from collections.abc import Generator +import os import sys from typing import Any +from typing import Literal from typing import Protocol from typing import TYPE_CHECKING from _pytest.assertion import rewrite from _pytest.assertion import truncate from _pytest.assertion import util +from _pytest.assertion._compare_eq import _diff_text +import _pytest.assertion._typing +from _pytest.assertion.assertrepr_compare import assertrepr_compare +from _pytest.assertion.format_explanation import format_explanation from _pytest.assertion.rewrite import assertstate_key from _pytest.config import Config from _pytest.config import hookimpl @@ -177,22 +187,28 @@ def callbinrepr(op, left: object, right: object) -> str | None: return res return None - saved_assert_hooks = util._reprcompare, util._assertion_pass - util._reprcompare = callbinrepr - util._config = item.config + saved_assert_hooks = ( + _pytest.assertion._typing._reprcompare, + _pytest.assertion._typing._assertion_pass, + ) + _pytest.assertion._typing._reprcompare = callbinrepr + _pytest.assertion._typing._config = item.config if ihook.pytest_assertion_pass.get_hookimpls(): def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None: ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl) - util._assertion_pass = call_assertion_pass_hook + _pytest.assertion._typing._assertion_pass = call_assertion_pass_hook try: return (yield) finally: - util._reprcompare, util._assertion_pass = saved_assert_hooks - util._config = None + ( + _pytest.assertion._typing._reprcompare, + _pytest.assertion._typing._assertion_pass, + ) = saved_assert_hooks + _pytest.assertion._typing._config = None def pytest_sessionfinish(session: Session) -> None: diff --git a/src/_pytest/assertion/_compare_eq.py b/src/_pytest/assertion/_compare_eq.py new file mode 100644 index 00000000000..a942e615273 --- /dev/null +++ b/src/_pytest/assertion/_compare_eq.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +from collections.abc import Iterable +from collections.abc import Mapping +from collections.abc import Sequence +import pprint +from typing import Any + +from _pytest._io.pprint import PrettyPrinter +from _pytest._io.saferepr import saferepr +from _pytest.assertion._isx import isattrs +from _pytest.assertion._typing import _HighlightFunc +from _pytest.assertion.util import running_on_ci + + +def _compare_eq_sequence( + left: Sequence[Any], + right: Sequence[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) + explanation: list[str] = [] + len_left = len(left) + len_right = len(right) + for i in range(min(len_left, len_right)): + if left[i] != right[i]: + if comparing_bytes: + # when comparing bytes, we want to see their ascii representation + # instead of their numeric values (#5260) + # using a slice gives us the ascii representation: + # >>> s = b'foo' + # >>> s[0] + # 102 + # >>> s[0:1] + # b'f' + left_value = left[i : i + 1] + right_value = right[i : i + 1] + else: + left_value = left[i] + right_value = right[i] + + explanation.append( + f"At index {i} diff:" + f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}" + ) + break + + if comparing_bytes: + # when comparing bytes, it doesn't help to show the "sides contain one or more + # items" longer explanation, so skip it + + return explanation + + len_diff = len_left - len_right + if len_diff: + if len_diff > 0: + dir_with_more = "Left" + extra = saferepr(left[len_right]) + else: + len_diff = 0 - len_diff + dir_with_more = "Right" + extra = saferepr(right[len_left]) + + if len_diff == 1: + explanation += [ + f"{dir_with_more} contains one more item: {highlighter(extra)}" + ] + else: + explanation += [ + f"{dir_with_more} contains {len_diff} more items, first extra item: {highlighter(extra)}" + ] + return explanation + + +def _compare_eq_dict( + left: Mapping[Any, Any], + right: Mapping[Any, Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation: list[str] = [] + set_left = set(left) + set_right = set(right) + common = set_left.intersection(set_right) + same = {k: left[k] for k in common if left[k] == right[k]} + if same and verbose < 2: + explanation += [f"Omitting {len(same)} identical items, use -vv to show"] + elif same: + explanation += ["Common items:"] + explanation += highlighter(pprint.pformat(same)).splitlines() + diff = {k for k in common if left[k] != right[k]} + if diff: + explanation += ["Differing items:"] + for k in diff: + explanation += [ + highlighter(saferepr({k: left[k]})) + + " != " + + highlighter(saferepr({k: right[k]})) + ] + extra_left = set_left - set_right + len_extra_left = len(extra_left) + if len_extra_left: + explanation.append( + f"Left contains {len_extra_left} more item{'' if len_extra_left == 1 else 's'}:" + ) + explanation.extend( + highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines() + ) + extra_right = set_right - set_left + len_extra_right = len(extra_right) + if len_extra_right: + explanation.append( + f"Right contains {len_extra_right} more item{'' if len_extra_right == 1 else 's'}:" + ) + explanation.extend( + highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines() + ) + return explanation + + +def _diff_text( + left: str, right: str, highlighter: _HighlightFunc, verbose: int = 0 +) -> list[str]: + """Return the explanation for the diff between text. + + Unless --verbose is used this will skip leading and trailing + characters which are identical to keep the diff minimal. + """ + from difflib import ndiff + + explanation: list[str] = [] + + if verbose < 1: + i = 0 # just in case left or right has zero length + for i in range(min(len(left), len(right))): + if left[i] != right[i]: + break + if i > 42: + i -= 10 # Provide some context + explanation = [ + f"Skipping {i} identical leading characters in diff, use -v to show" + ] + left = left[i:] + right = right[i:] + if len(left) == len(right): + for i in range(len(left)): + if left[-i] != right[-i]: + break + if i > 42: + i -= 10 # Provide some context + explanation += [ + f"Skipping {i} identical trailing " + "characters in diff, use -v to show" + ] + left = left[:-i] + right = right[:-i] + keepends = True + if left.isspace() or right.isspace(): + left = repr(str(left)) + right = repr(str(right)) + explanation += ["Strings contain only whitespace, escaping them using repr()"] + # "right" is the expected base against which we compare "left", + # see https://github.com/pytest-dev/pytest/issues/3333 + explanation.extend( + highlighter( + "\n".join( + line.strip("\n") + for line in ndiff(right.splitlines(keepends), left.splitlines(keepends)) + ), + lexer="diff", + ).splitlines() + ) + return explanation + + +def _compare_eq_iterable( + left: Iterable[Any], + right: Iterable[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + if verbose <= 0 and not running_on_ci(): + return ["Use -v to get more diff"] + # dynamic import to speedup pytest + import difflib + + left_formatting = PrettyPrinter().pformat(left).splitlines() + right_formatting = PrettyPrinter().pformat(right).splitlines() + + explanation = ["", "Full diff:"] + # "right" is the expected base against which we compare "left", + # see https://github.com/pytest-dev/pytest/issues/3333 + explanation.extend( + highlighter( + "\n".join( + line.rstrip() + for line in difflib.ndiff(right_formatting, left_formatting) + ), + lexer="diff", + ).splitlines() + ) + return explanation + + +def has_default_eq( + obj: object, +) -> bool: + """Check if an instance of an object contains the default eq + + First, we check if the object's __eq__ attribute has __code__, + if so, we check the equally of the method code filename (__code__.co_filename) + to the default one generated by the dataclass and attr module + for dataclasses the default co_filename is , for attrs class, the __eq__ should contain "attrs eq generated" + """ + # inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68 + if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"): + code_filename = obj.__eq__.__code__.co_filename + + if isattrs(obj): + return "attrs generated " in code_filename + + return code_filename == "" # data class + return True diff --git a/src/_pytest/assertion/_compare_set.py b/src/_pytest/assertion/_compare_set.py new file mode 100644 index 00000000000..480f172c4b3 --- /dev/null +++ b/src/_pytest/assertion/_compare_set.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from collections.abc import Callable +from collections.abc import Set as AbstractSet +from typing import Any + +from _pytest._io.saferepr import saferepr +from _pytest.assertion._typing import _HighlightFunc + + +def _set_one_sided_diff( + posn: str, + set1: AbstractSet[Any], + set2: AbstractSet[Any], + highlighter: _HighlightFunc, +) -> list[str]: + explanation = [] + diff = set1 - set2 + if diff: + explanation.append(f"Extra items in the {posn} set:") + for item in diff: + explanation.append(highlighter(saferepr(item))) + return explanation + + +def _compare_eq_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = [] + explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) + explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) + return explanation + + +def _compare_gt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = _compare_gte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_lt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = _compare_lte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_gte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + return _set_one_sided_diff("right", right, left, highlighter) + + +def _compare_lte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + return _set_one_sided_diff("left", left, right, highlighter) + + +SetComparisonFunction = dict[ + str, + Callable[ + [AbstractSet[Any], AbstractSet[Any], _HighlightFunc, int], + list[str], + ], +] + +SET_COMPARISON_FUNCTIONS: SetComparisonFunction = { + "==": _compare_eq_set, + "!=": lambda *a, **kw: ["Both sets are equal"], + ">=": _compare_gte_set, + "<=": _compare_lte_set, + ">": _compare_gt_set, + "<": _compare_lt_set, +} diff --git a/src/_pytest/assertion/_isx.py b/src/_pytest/assertion/_isx.py new file mode 100644 index 00000000000..ee6c4db43d3 --- /dev/null +++ b/src/_pytest/assertion/_isx.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import collections.abc +from typing import Any + + +def issequence(x: Any) -> bool: + return isinstance(x, collections.abc.Sequence) and not isinstance(x, str) + + +def istext(x: Any) -> bool: + return isinstance(x, str) + + +def isdict(x: Any) -> bool: + return isinstance(x, dict) + + +def isnamedtuple(obj: Any) -> bool: + return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None + + +def isdatacls(obj: Any) -> bool: + return getattr(obj, "__dataclass_fields__", None) is not None + + +def isattrs(obj: Any) -> bool: + return getattr(obj, "__attrs_attrs__", None) is not None + + +def isiterable(obj: Any) -> bool: + try: + iter(obj) + return not istext(obj) + except Exception: + return False diff --git a/src/_pytest/assertion/_notin_text.py b/src/_pytest/assertion/_notin_text.py new file mode 100644 index 00000000000..305142c4553 --- /dev/null +++ b/src/_pytest/assertion/_notin_text.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Literal + +from _pytest._io.saferepr import saferepr +from _pytest.assertion._compare_eq import _diff_text + + +def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]: + index = text.find(term) + head = text[:index] + tail = text[index + len(term) :] + correct_text = head + tail + diff = _diff_text(text, correct_text, dummy_highlighter, verbose) + newdiff = [f"{saferepr(term, maxsize=42)} is contained here:"] + for line in diff: + if line.startswith("Skipping"): + continue + if line.startswith("- "): + continue + if line.startswith("+ "): + newdiff.append(" " + line[2:]) + else: + newdiff.append(line) + return newdiff + + +def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str: + """Dummy highlighter that returns the text unprocessed. + + Needed for _notin_text, as the diff gets post-processed to only show the "+" part. + """ + return source diff --git a/src/_pytest/assertion/_typing.py b/src/_pytest/assertion/_typing.py new file mode 100644 index 00000000000..780996e7e79 --- /dev/null +++ b/src/_pytest/assertion/_typing.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Literal +from typing import Protocol + +from _pytest.config import Config + + +class _HighlightFunc(Protocol): # noqa: PYI046 + def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: + """Apply highlighting to the given source.""" + + +# The _reprcompare attribute on the util module is used by the new assertion +# interpretation code and assertion rewriter to detect this plugin was +# loaded and in turn call the hooks defined here as part of the +# DebugInterpreter. +_reprcompare: Callable[[str, object, object], str | None] | None = None + +# Works similarly as _reprcompare attribute. Is populated with the hook call +# when pytest_runtest_setup is called. +_assertion_pass: Callable[[int, str, str], None] | None = None + +# Config object which is assigned during pytest_runtest_protocol. +_config: Config | None = None diff --git a/src/_pytest/assertion/assertrepr_compare.py b/src/_pytest/assertion/assertrepr_compare.py new file mode 100644 index 00000000000..a83616c4b39 --- /dev/null +++ b/src/_pytest/assertion/assertrepr_compare.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from collections.abc import Sequence +import pprint +from typing import Any +from unicodedata import normalize + +from _pytest import outcomes +import _pytest._code +from _pytest._io.saferepr import saferepr +from _pytest._io.saferepr import saferepr_unlimited +from _pytest.assertion._compare_eq import _compare_eq_dict +from _pytest.assertion._compare_eq import _compare_eq_iterable +from _pytest.assertion._compare_eq import _compare_eq_sequence +from _pytest.assertion._compare_eq import _diff_text +from _pytest.assertion._compare_eq import has_default_eq +from _pytest.assertion._compare_set import SET_COMPARISON_FUNCTIONS +from _pytest.assertion._isx import isattrs +from _pytest.assertion._isx import isdatacls +from _pytest.assertion._isx import isiterable +from _pytest.assertion._isx import isnamedtuple +from _pytest.assertion._notin_text import _notin_text +from _pytest.assertion._typing import _HighlightFunc +from _pytest.config import Config + + +def assertrepr_compare( + config: Config, op: str, left: Any, right: Any, use_ascii: bool = False +) -> list[str] | None: + """Return specialised explanations for some operators/operands.""" + verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS) + + # Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier. + # See issue #3246. + use_ascii = ( + isinstance(left, str) + and isinstance(right, str) + and normalize("NFD", left) == normalize("NFD", right) + ) + + if verbose > 1: + left_repr = saferepr_unlimited(left, use_ascii=use_ascii) + right_repr = saferepr_unlimited(right, use_ascii=use_ascii) + else: + # XXX: "15 chars indentation" is wrong + # ("E AssertionError: assert "); should use term width. + maxsize = ( + 80 - 15 - len(op) - 2 + ) // 2 # 15 chars indentation, 1 space around op + + left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii) + right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii) + + summary = f"{left_repr} {op} {right_repr}" + highlighter = config.get_terminal_writer()._highlight + explanation: list[str] | None + try: + match (left, op, right): + case ( + set() | frozenset(), + "==" | "!=" | ">=" | "<=" | ">" | "<", + set() | frozenset(), + ): + explanation = SET_COMPARISON_FUNCTIONS[op]( + left, right, highlighter, verbose + ) + case (_, "==", _): + explanation = _compare_eq_any(left, right, highlighter, verbose) + case (str(), "not in", str()): + explanation = _notin_text(left, right, verbose) + case _: + explanation = None + except outcomes.Exit: + raise + except Exception: + repr_crash = _pytest._code.ExceptionInfo.from_current()._getreprcrash() + explanation = [ + f"(pytest_assertion plugin: representation of details failed: {repr_crash}.", + " Probably an object has a faulty __repr__.)", + ] + + if not explanation: + return None + + if explanation[0] != "": + explanation = ["", *explanation] + return [summary, *explanation] + + +def _compare_eq_any( + left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0 +) -> list[str]: + from _pytest.python_api import ApproxBase + + explanation: list[str] = [] + match (left, right): + case (str(), str()): + return _diff_text(left, right, highlighter, verbose) + case (_, ApproxBase() as approx_side): + explanation = approx_side._repr_compare(left) + case (ApproxBase() as approx_side, _): + explanation = approx_side._repr_compare(right) + case (tuple(), _) if getattr(left, "_fields", None) is not None: + explanation = _compare_eq_cls(left, right, highlighter, verbose) + case (Sequence(), Sequence()): + explanation = _compare_eq_sequence(left, right, highlighter, verbose) + case (dict(), dict()): + explanation = _compare_eq_dict(left, right, highlighter, verbose) + case _ if type(left) is type(right) and ( + getattr(left, "__dataclass_fields__", None) is not None + or getattr(left, "__attrs_attrs__", None) is not None + ): + explanation = _compare_eq_cls(left, right, highlighter, verbose) + case _: + explanation = [] + + if isiterable(left) and isiterable(right): + expl = _compare_eq_iterable(left, right, highlighter, verbose) + explanation.extend(expl) + return explanation + + +def _compare_eq_cls( + left: Any, right: Any, highlighter: _HighlightFunc, verbose: int +) -> list[str]: + if not has_default_eq(left): + return [] + if isdatacls(left): + import dataclasses + + all_fields = dataclasses.fields(left) + fields_to_check = [info.name for info in all_fields if info.compare] + elif isattrs(left): + all_fields = left.__attrs_attrs__ + fields_to_check = [field.name for field in all_fields if getattr(field, "eq")] + elif isnamedtuple(left): + fields_to_check = left._fields + else: + assert False + + indent = " " + same = [] + diff = [] + for field in fields_to_check: + if getattr(left, field) == getattr(right, field): + same.append(field) + else: + diff.append(field) + + explanation = [] + if same or diff: + explanation += [""] + if same and verbose < 2: + explanation.append(f"Omitting {len(same)} identical items, use -vv to show") + elif same: + explanation += ["Matching attributes:"] + explanation += highlighter(pprint.pformat(same)).splitlines() + if diff: + explanation += ["Differing attributes:"] + explanation += highlighter(pprint.pformat(diff)).splitlines() + for field in diff: + field_left = getattr(left, field) + field_right = getattr(right, field) + explanation += [ + "", + f"Drill down into differing attribute {field}:", + f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}", + ] + explanation += [ + indent + line + for line in _compare_eq_any( + field_left, field_right, highlighter, verbose + ) + ] + return explanation diff --git a/src/_pytest/assertion/format_explanation.py b/src/_pytest/assertion/format_explanation.py new file mode 100644 index 00000000000..430b852666e --- /dev/null +++ b/src/_pytest/assertion/format_explanation.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from collections.abc import Sequence + + +def format_explanation(explanation: str) -> str: + r"""Format an explanation. + + Normally all embedded newlines are escaped, however there are + three exceptions: \n{, \n} and \n~. The first two are intended + cover nested explanations, see function and attribute explanations + for examples (.visit_Call(), visit_Attribute()). The last one is + for when one explanation needs to span multiple lines, e.g. when + displaying diffs. + """ + lines = _split_explanation(explanation) + result = _format_lines(lines) + return "\n".join(result) + + +def _split_explanation(explanation: str) -> list[str]: + r"""Return a list of individual lines in the explanation. + + This will return a list of lines split on '\n{', '\n}' and '\n~'. + Any other newlines will be escaped and appear in the line as the + literal '\n' characters. + """ + raw_lines = (explanation or "").split("\n") + lines = [raw_lines[0]] + for values in raw_lines[1:]: + if values and values[0] in ["{", "}", "~", ">"]: + lines.append(values) + else: + lines[-1] += "\\n" + values + return lines + + +def _format_lines(lines: Sequence[str]) -> list[str]: + """Format the individual lines. + + This will replace the '{', '}' and '~' characters of our mini formatting + language with the proper 'where ...', 'and ...' and ' + ...' text, taking + care of indentation along the way. + + Return a list of formatted lines. + """ + result = list(lines[:1]) + stack = [0] + stackcnt = [0] + for line in lines[1:]: + if line.startswith("{"): + if stackcnt[-1]: + s = "and " + else: + s = "where " + stack.append(len(result)) + stackcnt[-1] += 1 + stackcnt.append(0) + result.append(" +" + " " * (len(stack) - 1) + s + line[1:]) + elif line.startswith("}"): + stack.pop() + stackcnt.pop() + result[stack[-1]] += line[1:] + else: + assert line[0] in ["~", ">"] + stack[-1] += 1 + indent = len(stack) if line.startswith("~") else len(stack) - 1 + result.append(" " * indent + line[1:]) + assert len(stack) == 1 + return result diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 566549d66f2..e3036278dff 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -26,6 +26,8 @@ from typing import IO from typing import TYPE_CHECKING +import _pytest.assertion._typing + if sys.version_info >= (3, 12): from importlib.resources.abc import TraversableResources @@ -41,7 +43,6 @@ from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr_unlimited from _pytest._version import version -from _pytest.assertion import util from _pytest.config import Config from _pytest.fixtures import FixtureFunctionDefinition from _pytest.main import Session @@ -433,7 +434,7 @@ def _saferepr(obj: object) -> str: # for bound methods, skip redundant information return obj.__name__ - maxsize = _get_maxsize_for_saferepr(util._config) + maxsize = _get_maxsize_for_saferepr(_pytest.assertion._typing._config) if not maxsize: return saferepr_unlimited(obj).replace("\n", "\\n") return saferepr(obj, maxsize=maxsize).replace("\n", "\\n") @@ -465,7 +466,9 @@ def _format_assertmsg(obj: object) -> str: # However in either case we want to preserve the newline. replaces = [("\n", "\n~"), ("%", "%%")] if not isinstance(obj, str): - obj = saferepr(obj, _get_maxsize_for_saferepr(util._config)) + obj = saferepr( + obj, _get_maxsize_for_saferepr(_pytest.assertion._typing._config) + ) replaces.append(("\\n", "\n~")) for r1, r2 in replaces: @@ -503,22 +506,24 @@ def _call_reprcompare( done = True if done: break - if util._reprcompare is not None: - custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) + if _pytest.assertion._typing._reprcompare is not None: + custom = _pytest.assertion._typing._reprcompare( + ops[i], each_obj[i], each_obj[i + 1] + ) if custom is not None: return custom return expl def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None: - if util._assertion_pass is not None: - util._assertion_pass(lineno, orig, expl) + if _pytest.assertion._typing._assertion_pass is not None: + _pytest.assertion._typing._assertion_pass(lineno, orig, expl) def _check_if_assertion_pass_impl() -> bool: """Check if any plugins implement the pytest_assertion_pass hook in order not to generate explanation unnecessarily (might be expensive).""" - return True if util._assertion_pass else False + return True if _pytest.assertion._typing._assertion_pass else False UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index cc499f7186f..52438c63404 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -3,616 +3,7 @@ from __future__ import annotations -import collections.abc -from collections.abc import Callable -from collections.abc import Iterable -from collections.abc import Mapping -from collections.abc import Sequence -from collections.abc import Set as AbstractSet import os -import pprint -from typing import Any -from typing import Literal -from typing import Protocol -from unicodedata import normalize - -from _pytest import outcomes -import _pytest._code -from _pytest._io.pprint import PrettyPrinter -from _pytest._io.saferepr import saferepr -from _pytest._io.saferepr import saferepr_unlimited -from _pytest.config import Config - - -# The _reprcompare attribute on the util module is used by the new assertion -# interpretation code and assertion rewriter to detect this plugin was -# loaded and in turn call the hooks defined here as part of the -# DebugInterpreter. -_reprcompare: Callable[[str, object, object], str | None] | None = None - -# Works similarly as _reprcompare attribute. Is populated with the hook call -# when pytest_runtest_setup is called. -_assertion_pass: Callable[[int, str, str], None] | None = None - -# Config object which is assigned during pytest_runtest_protocol. -_config: Config | None = None - - -class _HighlightFunc(Protocol): - def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: - """Apply highlighting to the given source.""" - - -def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str: - """Dummy highlighter that returns the text unprocessed. - - Needed for _notin_text, as the diff gets post-processed to only show the "+" part. - """ - return source - - -def format_explanation(explanation: str) -> str: - r"""Format an explanation. - - Normally all embedded newlines are escaped, however there are - three exceptions: \n{, \n} and \n~. The first two are intended - cover nested explanations, see function and attribute explanations - for examples (.visit_Call(), visit_Attribute()). The last one is - for when one explanation needs to span multiple lines, e.g. when - displaying diffs. - """ - lines = _split_explanation(explanation) - result = _format_lines(lines) - return "\n".join(result) - - -def _split_explanation(explanation: str) -> list[str]: - r"""Return a list of individual lines in the explanation. - - This will return a list of lines split on '\n{', '\n}' and '\n~'. - Any other newlines will be escaped and appear in the line as the - literal '\n' characters. - """ - raw_lines = (explanation or "").split("\n") - lines = [raw_lines[0]] - for values in raw_lines[1:]: - if values and values[0] in ["{", "}", "~", ">"]: - lines.append(values) - else: - lines[-1] += "\\n" + values - return lines - - -def _format_lines(lines: Sequence[str]) -> list[str]: - """Format the individual lines. - - This will replace the '{', '}' and '~' characters of our mini formatting - language with the proper 'where ...', 'and ...' and ' + ...' text, taking - care of indentation along the way. - - Return a list of formatted lines. - """ - result = list(lines[:1]) - stack = [0] - stackcnt = [0] - for line in lines[1:]: - if line.startswith("{"): - if stackcnt[-1]: - s = "and " - else: - s = "where " - stack.append(len(result)) - stackcnt[-1] += 1 - stackcnt.append(0) - result.append(" +" + " " * (len(stack) - 1) + s + line[1:]) - elif line.startswith("}"): - stack.pop() - stackcnt.pop() - result[stack[-1]] += line[1:] - else: - assert line[0] in ["~", ">"] - stack[-1] += 1 - indent = len(stack) if line.startswith("~") else len(stack) - 1 - result.append(" " * indent + line[1:]) - assert len(stack) == 1 - return result - - -def issequence(x: Any) -> bool: - return isinstance(x, collections.abc.Sequence) and not isinstance(x, str) - - -def istext(x: Any) -> bool: - return isinstance(x, str) - - -def isdict(x: Any) -> bool: - return isinstance(x, dict) - - -def isset(x: Any) -> bool: - return isinstance(x, set | frozenset) - - -def isnamedtuple(obj: Any) -> bool: - return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None - - -def isdatacls(obj: Any) -> bool: - return getattr(obj, "__dataclass_fields__", None) is not None - - -def isattrs(obj: Any) -> bool: - return getattr(obj, "__attrs_attrs__", None) is not None - - -def isiterable(obj: Any) -> bool: - try: - iter(obj) - return not istext(obj) - except Exception: - return False - - -def has_default_eq( - obj: object, -) -> bool: - """Check if an instance of an object contains the default eq - - First, we check if the object's __eq__ attribute has __code__, - if so, we check the equally of the method code filename (__code__.co_filename) - to the default one generated by the dataclass and attr module - for dataclasses the default co_filename is , for attrs class, the __eq__ should contain "attrs eq generated" - """ - # inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68 - if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"): - code_filename = obj.__eq__.__code__.co_filename - - if isattrs(obj): - return "attrs generated " in code_filename - - return code_filename == "" # data class - return True - - -def assertrepr_compare( - config, op: str, left: Any, right: Any, use_ascii: bool = False -) -> list[str] | None: - """Return specialised explanations for some operators/operands.""" - verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS) - - # Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier. - # See issue #3246. - use_ascii = ( - isinstance(left, str) - and isinstance(right, str) - and normalize("NFD", left) == normalize("NFD", right) - ) - - if verbose > 1: - left_repr = saferepr_unlimited(left, use_ascii=use_ascii) - right_repr = saferepr_unlimited(right, use_ascii=use_ascii) - else: - # XXX: "15 chars indentation" is wrong - # ("E AssertionError: assert "); should use term width. - maxsize = ( - 80 - 15 - len(op) - 2 - ) // 2 # 15 chars indentation, 1 space around op - - left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii) - right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii) - - summary = f"{left_repr} {op} {right_repr}" - highlighter = config.get_terminal_writer()._highlight - - explanation = None - try: - if op == "==": - explanation = _compare_eq_any(left, right, highlighter, verbose) - elif op == "not in": - if istext(left) and istext(right): - explanation = _notin_text(left, right, verbose) - elif op == "!=": - if isset(left) and isset(right): - explanation = ["Both sets are equal"] - elif op == ">=": - if isset(left) and isset(right): - explanation = _compare_gte_set(left, right, highlighter, verbose) - elif op == "<=": - if isset(left) and isset(right): - explanation = _compare_lte_set(left, right, highlighter, verbose) - elif op == ">": - if isset(left) and isset(right): - explanation = _compare_gt_set(left, right, highlighter, verbose) - elif op == "<": - if isset(left) and isset(right): - explanation = _compare_lt_set(left, right, highlighter, verbose) - - except outcomes.Exit: - raise - except Exception: - repr_crash = _pytest._code.ExceptionInfo.from_current()._getreprcrash() - explanation = [ - f"(pytest_assertion plugin: representation of details failed: {repr_crash}.", - " Probably an object has a faulty __repr__.)", - ] - - if not explanation: - return None - - if explanation[0] != "": - explanation = ["", *explanation] - return [summary, *explanation] - - -def _compare_eq_any( - left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0 -) -> list[str]: - explanation = [] - if istext(left) and istext(right): - explanation = _diff_text(left, right, highlighter, verbose) - else: - from _pytest.python_api import ApproxBase - - if isinstance(left, ApproxBase) or isinstance(right, ApproxBase): - # Although the common order should be obtained == expected, this ensures both ways - approx_side = left if isinstance(left, ApproxBase) else right - other_side = right if isinstance(left, ApproxBase) else left - - explanation = approx_side._repr_compare(other_side) - elif type(left) is type(right) and ( - isdatacls(left) or isattrs(left) or isnamedtuple(left) - ): - # Note: unlike dataclasses/attrs, namedtuples compare only the - # field values, not the type or field names. But this branch - # intentionally only handles the same-type case, which was often - # used in older code bases before dataclasses/attrs were available. - explanation = _compare_eq_cls(left, right, highlighter, verbose) - elif issequence(left) and issequence(right): - explanation = _compare_eq_sequence(left, right, highlighter, verbose) - elif isset(left) and isset(right): - explanation = _compare_eq_set(left, right, highlighter, verbose) - elif isdict(left) and isdict(right): - explanation = _compare_eq_dict(left, right, highlighter, verbose) - - if isiterable(left) and isiterable(right): - expl = _compare_eq_iterable(left, right, highlighter, verbose) - explanation.extend(expl) - - return explanation - - -def _diff_text( - left: str, right: str, highlighter: _HighlightFunc, verbose: int = 0 -) -> list[str]: - """Return the explanation for the diff between text. - - Unless --verbose is used this will skip leading and trailing - characters which are identical to keep the diff minimal. - """ - from difflib import ndiff - - explanation: list[str] = [] - - if verbose < 1: - i = 0 # just in case left or right has zero length - for i in range(min(len(left), len(right))): - if left[i] != right[i]: - break - if i > 42: - i -= 10 # Provide some context - explanation = [ - f"Skipping {i} identical leading characters in diff, use -v to show" - ] - left = left[i:] - right = right[i:] - if len(left) == len(right): - for i in range(len(left)): - if left[-i] != right[-i]: - break - if i > 42: - i -= 10 # Provide some context - explanation += [ - f"Skipping {i} identical trailing " - "characters in diff, use -v to show" - ] - left = left[:-i] - right = right[:-i] - keepends = True - if left.isspace() or right.isspace(): - left = repr(str(left)) - right = repr(str(right)) - explanation += ["Strings contain only whitespace, escaping them using repr()"] - # "right" is the expected base against which we compare "left", - # see https://github.com/pytest-dev/pytest/issues/3333 - explanation.extend( - highlighter( - "\n".join( - line.strip("\n") - for line in ndiff(right.splitlines(keepends), left.splitlines(keepends)) - ), - lexer="diff", - ).splitlines() - ) - return explanation - - -def _compare_eq_iterable( - left: Iterable[Any], - right: Iterable[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - if verbose <= 0 and not running_on_ci(): - return ["Use -v to get more diff"] - # dynamic import to speedup pytest - import difflib - - left_formatting = PrettyPrinter().pformat(left).splitlines() - right_formatting = PrettyPrinter().pformat(right).splitlines() - - explanation = ["", "Full diff:"] - # "right" is the expected base against which we compare "left", - # see https://github.com/pytest-dev/pytest/issues/3333 - explanation.extend( - highlighter( - "\n".join( - line.rstrip() - for line in difflib.ndiff(right_formatting, left_formatting) - ), - lexer="diff", - ).splitlines() - ) - return explanation - - -def _compare_eq_sequence( - left: Sequence[Any], - right: Sequence[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) - explanation: list[str] = [] - len_left = len(left) - len_right = len(right) - for i in range(min(len_left, len_right)): - if left[i] != right[i]: - if comparing_bytes: - # when comparing bytes, we want to see their ascii representation - # instead of their numeric values (#5260) - # using a slice gives us the ascii representation: - # >>> s = b'foo' - # >>> s[0] - # 102 - # >>> s[0:1] - # b'f' - left_value = left[i : i + 1] - right_value = right[i : i + 1] - else: - left_value = left[i] - right_value = right[i] - - explanation.append( - f"At index {i} diff:" - f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}" - ) - break - - if comparing_bytes: - # when comparing bytes, it doesn't help to show the "sides contain one or more - # items" longer explanation, so skip it - - return explanation - - len_diff = len_left - len_right - if len_diff: - if len_diff > 0: - dir_with_more = "Left" - extra = saferepr(left[len_right]) - else: - len_diff = 0 - len_diff - dir_with_more = "Right" - extra = saferepr(right[len_left]) - - if len_diff == 1: - explanation += [ - f"{dir_with_more} contains one more item: {highlighter(extra)}" - ] - else: - explanation += [ - f"{dir_with_more} contains {len_diff} more items, first extra item: {highlighter(extra)}" - ] - return explanation - - -def _compare_eq_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = [] - explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) - explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) - return explanation - - -def _compare_gt_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = _compare_gte_set(left, right, highlighter) - if not explanation: - return ["Both sets are equal"] - return explanation - - -def _compare_lt_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = _compare_lte_set(left, right, highlighter) - if not explanation: - return ["Both sets are equal"] - return explanation - - -def _compare_gte_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - return _set_one_sided_diff("right", right, left, highlighter) - - -def _compare_lte_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - return _set_one_sided_diff("left", left, right, highlighter) - - -def _set_one_sided_diff( - posn: str, - set1: AbstractSet[Any], - set2: AbstractSet[Any], - highlighter: _HighlightFunc, -) -> list[str]: - explanation = [] - diff = set1 - set2 - if diff: - explanation.append(f"Extra items in the {posn} set:") - for item in diff: - explanation.append(highlighter(saferepr(item))) - return explanation - - -def _compare_eq_dict( - left: Mapping[Any, Any], - right: Mapping[Any, Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation: list[str] = [] - set_left = set(left) - set_right = set(right) - common = set_left.intersection(set_right) - same = {k: left[k] for k in common if left[k] == right[k]} - if same and verbose < 2: - explanation += [f"Omitting {len(same)} identical items, use -vv to show"] - elif same: - explanation += ["Common items:"] - explanation += highlighter(pprint.pformat(same)).splitlines() - diff = {k for k in common if left[k] != right[k]} - if diff: - explanation += ["Differing items:"] - for k in diff: - explanation += [ - highlighter(saferepr({k: left[k]})) - + " != " - + highlighter(saferepr({k: right[k]})) - ] - extra_left = set_left - set_right - len_extra_left = len(extra_left) - if len_extra_left: - explanation.append( - f"Left contains {len_extra_left} more item{'' if len_extra_left == 1 else 's'}:" - ) - explanation.extend( - highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines() - ) - extra_right = set_right - set_left - len_extra_right = len(extra_right) - if len_extra_right: - explanation.append( - f"Right contains {len_extra_right} more item{'' if len_extra_right == 1 else 's'}:" - ) - explanation.extend( - highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines() - ) - return explanation - - -def _compare_eq_cls( - left: Any, right: Any, highlighter: _HighlightFunc, verbose: int -) -> list[str]: - if not has_default_eq(left): - return [] - if isdatacls(left): - import dataclasses - - all_fields = dataclasses.fields(left) - fields_to_check = [info.name for info in all_fields if info.compare] - elif isattrs(left): - all_fields = left.__attrs_attrs__ - fields_to_check = [field.name for field in all_fields if getattr(field, "eq")] - elif isnamedtuple(left): - fields_to_check = left._fields - else: - assert False - - indent = " " - same = [] - diff = [] - for field in fields_to_check: - if getattr(left, field) == getattr(right, field): - same.append(field) - else: - diff.append(field) - - explanation = [] - if same or diff: - explanation += [""] - if same and verbose < 2: - explanation.append(f"Omitting {len(same)} identical items, use -vv to show") - elif same: - explanation += ["Matching attributes:"] - explanation += highlighter(pprint.pformat(same)).splitlines() - if diff: - explanation += ["Differing attributes:"] - explanation += highlighter(pprint.pformat(diff)).splitlines() - for field in diff: - field_left = getattr(left, field) - field_right = getattr(right, field) - explanation += [ - "", - f"Drill down into differing attribute {field}:", - f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}", - ] - explanation += [ - indent + line - for line in _compare_eq_any( - field_left, field_right, highlighter, verbose - ) - ] - return explanation - - -def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]: - index = text.find(term) - head = text[:index] - tail = text[index + len(term) :] - correct_text = head + tail - diff = _diff_text(text, correct_text, dummy_highlighter, verbose) - newdiff = [f"{saferepr(term, maxsize=42)} is contained here:"] - for line in diff: - if line.startswith("Skipping"): - continue - if line.startswith("- "): - continue - if line.startswith("+ "): - newdiff.append(" " + line[2:]) - else: - newdiff.append(line) - return newdiff def running_on_ci() -> bool: diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 2c2830eb929..1e5c6804360 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -1976,10 +1976,10 @@ def f(): def test_exit_from_assertrepr_compare(monkeypatch) -> None: - def raise_exit(obj): + def raise_exit(*args, **kwargs): outcomes.exit("Quitting debugger") - monkeypatch.setattr(util, "istext", raise_exit) + monkeypatch.setattr(util, "_compare_eq_any", raise_exit) with pytest.raises(outcomes.Exit, match="Quitting debugger"): callequal(1, 1)