Skip to content

Commit db5feae

Browse files
committed
Refactor Capture classes
Moves {Passthrough,CaptureIO} to capture module.
1 parent 4c9b850 commit db5feae

File tree

3 files changed

+32
-37
lines changed

3 files changed

+32
-37
lines changed

src/_pytest/capture.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from tempfile import TemporaryFile
1212
from typing import BinaryIO
1313
from typing import Generator
14+
from typing import IO
1415
from typing import Iterable
1516
from typing import Optional
1617

1718
import pytest
18-
from _pytest.compat import CaptureAndPassthroughIO
19-
from _pytest.compat import CaptureIO
2019
from _pytest.config import Config
2120
from _pytest.fixtures import FixtureRequest
2221

@@ -98,7 +97,8 @@ def _getcapture(self, method):
9897
return MultiCapture(out=False, err=False, in_=False)
9998
elif method == "tee-sys":
10099
return MultiCapture(out=True, err=True, in_=False, Capture=TeeSysCapture)
101-
raise ValueError("unknown capturing method: %r" % method) # pragma: no cover
100+
else:
101+
assert False, "unknown capturing method: {}".format(method)
102102

103103
def is_capturing(self):
104104
if self.is_globally_capturing():
@@ -323,6 +323,25 @@ def capfdbinary(request):
323323
yield fixture
324324

325325

326+
class CaptureIO(io.TextIOWrapper):
327+
def __init__(self) -> None:
328+
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
329+
330+
def getvalue(self) -> str:
331+
assert isinstance(self.buffer, io.BytesIO)
332+
return self.buffer.getvalue().decode("UTF-8")
333+
334+
335+
class PassthroughCaptureIO(CaptureIO):
336+
def __init__(self, other: IO) -> None:
337+
self._other = other
338+
super().__init__()
339+
340+
def write(self, s) -> int:
341+
super().write(s)
342+
return self._other.write(s)
343+
344+
326345
class CaptureFixture:
327346
"""
328347
Object returned by :py:func:`capsys`, :py:func:`capsysbinary`, :py:func:`capfd` and :py:func:`capfdbinary`
@@ -686,16 +705,13 @@ def snap(self):
686705

687706

688707
class TeeSysCapture(SysCapture):
689-
def __init__(self, fd, tmpfile=None):
690-
name = patchsysdict[fd]
691-
self._old = getattr(sys, name)
692-
self.name = name
693-
if tmpfile is None:
694-
if name == "stdin":
695-
tmpfile = DontReadFromInput()
696-
else:
697-
tmpfile = CaptureAndPassthroughIO(self._old)
698-
self.tmpfile = tmpfile
708+
def __init__(self, fd: int) -> None:
709+
old = getattr(sys, patchsysdict[fd])
710+
if fd == 0:
711+
super().__init__(fd)
712+
else:
713+
super().__init__(fd, PassthroughCaptureIO(old))
714+
assert self._old == old, (self._old, old)
699715

700716

701717
map_fixname_class = {

src/_pytest/compat.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
import functools
55
import inspect
6-
import io
76
import os
87
import re
98
import sys
@@ -13,7 +12,6 @@
1312
from typing import Any
1413
from typing import Callable
1514
from typing import Generic
16-
from typing import IO
1715
from typing import Optional
1816
from typing import overload
1917
from typing import Tuple
@@ -360,25 +358,6 @@ def _setup_collect_fakemodule() -> None:
360358
setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
361359

362360

363-
class CaptureIO(io.TextIOWrapper):
364-
def __init__(self) -> None:
365-
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
366-
367-
def getvalue(self) -> str:
368-
assert isinstance(self.buffer, io.BytesIO)
369-
return self.buffer.getvalue().decode("UTF-8")
370-
371-
372-
class CaptureAndPassthroughIO(CaptureIO):
373-
def __init__(self, other: IO) -> None:
374-
self._other = other
375-
super().__init__()
376-
377-
def write(self, s) -> int:
378-
super().write(s)
379-
return self._other.write(s)
380-
381-
382361
if sys.version_info < (3, 5, 2):
383362

384363
def overload(f): # noqa: F811

testing/test_capture.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -822,10 +822,10 @@ def test_write_bytes_to_buffer(self):
822822
assert f.getvalue() == "foo\r\n"
823823

824824

825-
class TestCaptureAndPassthroughIO(TestCaptureIO):
825+
class TestPassthroughCaptureIO(TestCaptureIO):
826826
def test_text(self):
827827
sio = io.StringIO()
828-
f = capture.CaptureAndPassthroughIO(sio)
828+
f = capture.PassthroughCaptureIO(sio)
829829
f.write("hello")
830830
s1 = f.getvalue()
831831
assert s1 == "hello"
@@ -836,7 +836,7 @@ def test_text(self):
836836

837837
def test_unicode_and_str_mixture(self):
838838
sio = io.StringIO()
839-
f = capture.CaptureAndPassthroughIO(sio)
839+
f = capture.PassthroughCaptureIO(sio)
840840
f.write("\u00f6")
841841
pytest.raises(TypeError, f.write, b"hello")
842842

0 commit comments

Comments
 (0)