Skip to content

Commit b9d7d91

Browse files
committed
wip: atexit
1 parent 60c365a commit b9d7d91

File tree

1 file changed

+34
-15
lines changed

1 file changed

+34
-15
lines changed

src/_pytest/capture.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import sys
1111
from io import UnsupportedOperation
1212
from tempfile import TemporaryFile
13+
from typing import Callable
14+
from typing import List
1315

1416
import pytest
1517
from _pytest.compat import CaptureIO
@@ -77,19 +79,28 @@ def __init__(self, method):
7779
self._method = method
7880
self._global_capturing = None
7981
self._current_item = None
82+
self._atexit_funcs: List[Callable] = []
83+
atexit.register(self._atexit_run)
8084

8185
def __repr__(self):
8286
return "<CaptureManager _method={!r} _global_capturing={!r} _current_item={!r}>".format(
8387
self._method, self._global_capturing, self._current_item
8488
)
8589

90+
def _atexit_register(self, func):
91+
self._atexit_funcs.append(func)
92+
93+
def _atexit_run(self):
94+
for func in self._atexit_funcs:
95+
func()
96+
8697
def _getcapture(self, method):
8798
if method == "fd":
88-
return MultiCapture(out=True, err=True, Capture=FDCapture)
99+
return MultiCapture(out=True, err=True, Capture=FDCapture, capman=self)
89100
elif method == "sys":
90-
return MultiCapture(out=True, err=True, Capture=SysCapture)
101+
return MultiCapture(out=True, err=True, Capture=SysCapture, capman=self)
91102
elif method == "no":
92-
return MultiCapture(out=False, err=False, in_=False)
103+
return MultiCapture(out=False, err=False, in_=False, capman=self)
93104
raise ValueError("unknown capturing method: %r" % method) # pragma: no cover
94105

95106
def is_capturing(self):
@@ -451,13 +462,13 @@ class MultiCapture:
451462
out = err = in_ = None
452463
_state = None
453464

454-
def __init__(self, out=True, err=True, in_=True, Capture=None):
465+
def __init__(self, out=True, err=True, in_=True, Capture=None, capman: CaptureManager = None):
455466
if in_:
456-
self.in_ = Capture(0)
467+
self.in_ = Capture(0, capman=capman)
457468
if out:
458-
self.out = Capture(1)
469+
self.out = Capture(1, capman=capman)
459470
if err:
460-
self.err = Capture(2)
471+
self.err = Capture(2, capman=capman)
461472

462473
def __repr__(self):
463474
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
@@ -540,8 +551,9 @@ class FDCaptureBinary:
540551
EMPTY_BUFFER = b""
541552
_state = None
542553

543-
def __init__(self, targetfd, tmpfile=None):
554+
def __init__(self, targetfd, tmpfile=None, capman: CaptureManager = None):
544555
self.targetfd = targetfd
556+
self._capman = capman
545557
try:
546558
self.targetfd_save = os.dup(self.targetfd)
547559
except OSError:
@@ -553,14 +565,14 @@ def __init__(self, targetfd, tmpfile=None):
553565
if targetfd == 0:
554566
assert not tmpfile, "cannot set tmpfile with stdin"
555567
tmpfile = open(os.devnull, "r")
556-
self.syscapture = SysCapture(targetfd)
568+
self.syscapture = SysCapture(targetfd, capman=self._capman)
557569
else:
558570
if tmpfile is None:
559571
f = TemporaryFile()
560572
with f:
561573
tmpfile = safe_text_dupfile(f, mode="wb+")
562574
if targetfd in patchsysdict:
563-
self.syscapture = SysCapture(targetfd, tmpfile)
575+
self.syscapture = SysCapture(targetfd, tmpfile, capman)
564576
else:
565577
self.syscapture = NoCapture()
566578
self.tmpfile = tmpfile
@@ -595,9 +607,12 @@ def _done(self):
595607
os.dup2(targetfd_save, self.targetfd)
596608
os.close(targetfd_save)
597609
self.syscapture.done()
598-
# Redirect any remaining output.
599-
os.dup2(self.targetfd, self.tmpfile_fd)
600-
atexit.register(self.tmpfile.close)
610+
if self._capman:
611+
# Redirect any remaining output.
612+
os.dup2(self.targetfd, self.tmpfile_fd)
613+
self._capman._atexit_register(self.tmpfile.close)
614+
else:
615+
self.tmpfile.close()
601616
self._state = "done"
602617

603618
def suspend(self):
@@ -639,8 +654,9 @@ class SysCapture:
639654
EMPTY_BUFFER = str()
640655
_state = None
641656

642-
def __init__(self, fd, tmpfile=None):
657+
def __init__(self, fd, tmpfile=None, capman: CaptureManager = None):
643658
name = patchsysdict[fd]
659+
self._capman = capman
644660
self._old = getattr(sys, name)
645661
self.name = name
646662
if tmpfile is None:
@@ -668,7 +684,10 @@ def snap(self):
668684
def done(self):
669685
setattr(sys, self.name, self._old)
670686
del self._old
671-
atexit.register(self.tmpfile.close)
687+
if self._capman:
688+
self._capman._atexit_register(self.tmpfile.close)
689+
else:
690+
self.tmpfile.close()
672691
self._state = "done"
673692

674693
def suspend(self):

0 commit comments

Comments
 (0)