10
10
import sys
11
11
from io import UnsupportedOperation
12
12
from tempfile import TemporaryFile
13
+ from typing import Callable
14
+ from typing import List
13
15
14
16
import pytest
15
17
from _pytest .compat import CaptureIO
@@ -77,19 +79,28 @@ def __init__(self, method):
77
79
self ._method = method
78
80
self ._global_capturing = None
79
81
self ._current_item = None
82
+ self ._atexit_funcs : List [Callable ] = []
83
+ atexit .register (self ._atexit_run )
80
84
81
85
def __repr__ (self ):
82
86
return "<CaptureManager _method={!r} _global_capturing={!r} _current_item={!r}>" .format (
83
87
self ._method , self ._global_capturing , self ._current_item
84
88
)
85
89
90
+ def _atexit_register (self , func ):
91
+ self ._atexit_funcs .append (func )
92
+
93
+ def _atexit_run (self ):
94
+ for func in self ._atexit_funcs :
95
+ func ()
96
+
86
97
def _getcapture (self , method ):
87
98
if method == "fd" :
88
- return MultiCapture (out = True , err = True , Capture = FDCapture )
99
+ return MultiCapture (out = True , err = True , Capture = FDCapture , capman = self )
89
100
elif method == "sys" :
90
- return MultiCapture (out = True , err = True , Capture = SysCapture )
101
+ return MultiCapture (out = True , err = True , Capture = SysCapture , capman = self )
91
102
elif method == "no" :
92
- return MultiCapture (out = False , err = False , in_ = False )
103
+ return MultiCapture (out = False , err = False , in_ = False , capman = self )
93
104
raise ValueError ("unknown capturing method: %r" % method ) # pragma: no cover
94
105
95
106
def is_capturing (self ):
@@ -451,13 +462,13 @@ class MultiCapture:
451
462
out = err = in_ = None
452
463
_state = None
453
464
454
- def __init__ (self , out = True , err = True , in_ = True , Capture = None ):
465
+ def __init__ (self , out = True , err = True , in_ = True , Capture = None , capman : CaptureManager = None ):
455
466
if in_ :
456
- self .in_ = Capture (0 )
467
+ self .in_ = Capture (0 , capman = capman )
457
468
if out :
458
- self .out = Capture (1 )
469
+ self .out = Capture (1 , capman = capman )
459
470
if err :
460
- self .err = Capture (2 )
471
+ self .err = Capture (2 , capman = capman )
461
472
462
473
def __repr__ (self ):
463
474
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>" .format (
@@ -540,8 +551,9 @@ class FDCaptureBinary:
540
551
EMPTY_BUFFER = b""
541
552
_state = None
542
553
543
- def __init__ (self , targetfd , tmpfile = None ):
554
+ def __init__ (self , targetfd , tmpfile = None , capman : CaptureManager = None ):
544
555
self .targetfd = targetfd
556
+ self ._capman = capman
545
557
try :
546
558
self .targetfd_save = os .dup (self .targetfd )
547
559
except OSError :
@@ -553,14 +565,14 @@ def __init__(self, targetfd, tmpfile=None):
553
565
if targetfd == 0 :
554
566
assert not tmpfile , "cannot set tmpfile with stdin"
555
567
tmpfile = open (os .devnull , "r" )
556
- self .syscapture = SysCapture (targetfd )
568
+ self .syscapture = SysCapture (targetfd , capman = self . _capman )
557
569
else :
558
570
if tmpfile is None :
559
571
f = TemporaryFile ()
560
572
with f :
561
573
tmpfile = safe_text_dupfile (f , mode = "wb+" )
562
574
if targetfd in patchsysdict :
563
- self .syscapture = SysCapture (targetfd , tmpfile )
575
+ self .syscapture = SysCapture (targetfd , tmpfile , capman )
564
576
else :
565
577
self .syscapture = NoCapture ()
566
578
self .tmpfile = tmpfile
@@ -595,9 +607,12 @@ def _done(self):
595
607
os .dup2 (targetfd_save , self .targetfd )
596
608
os .close (targetfd_save )
597
609
self .syscapture .done ()
598
- # Redirect any remaining output.
599
- os .dup2 (self .targetfd , self .tmpfile_fd )
600
- atexit .register (self .tmpfile .close )
610
+ if self ._capman :
611
+ # Redirect any remaining output.
612
+ os .dup2 (self .targetfd , self .tmpfile_fd )
613
+ self ._capman ._atexit_register (self .tmpfile .close )
614
+ else :
615
+ self .tmpfile .close ()
601
616
self ._state = "done"
602
617
603
618
def suspend (self ):
@@ -639,8 +654,9 @@ class SysCapture:
639
654
EMPTY_BUFFER = str ()
640
655
_state = None
641
656
642
- def __init__ (self , fd , tmpfile = None ):
657
+ def __init__ (self , fd , tmpfile = None , capman : CaptureManager = None ):
643
658
name = patchsysdict [fd ]
659
+ self ._capman = capman
644
660
self ._old = getattr (sys , name )
645
661
self .name = name
646
662
if tmpfile is None :
@@ -668,7 +684,10 @@ def snap(self):
668
684
def done (self ):
669
685
setattr (sys , self .name , self ._old )
670
686
del self ._old
671
- atexit .register (self .tmpfile .close )
687
+ if self ._capman :
688
+ self ._capman ._atexit_register (self .tmpfile .close )
689
+ else :
690
+ self .tmpfile .close ()
672
691
self ._state = "done"
673
692
674
693
def suspend (self ):
0 commit comments