Skip to content

Commit c78569b

Browse files
author
ibriquem
committed
Fix reload with assertion rewrite import hook
1 parent cc793a8 commit c78569b

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

src/_pytest/assertion/rewrite.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,17 @@ def _warn_already_imported(self, name):
197197
)
198198

199199
def load_module(self, name):
200-
# If there is an existing module object named 'fullname' in
201-
# sys.modules, the loader must use that existing module. (Otherwise,
202-
# the reload() builtin will not work correctly.)
203-
if name in sys.modules:
204-
return sys.modules[name]
205-
206200
co, pyc = self.modules.pop(name)
207-
# I wish I could just call imp.load_compiled here, but __file__ has to
208-
# be set properly. In Python 3.2+, this all would be handled correctly
209-
# by load_compiled.
210-
mod = sys.modules[name] = imp.new_module(name)
201+
if name in sys.modules:
202+
# If there is an existing module object named 'fullname' in
203+
# sys.modules, the loader must use that existing module. (Otherwise,
204+
# the reload() builtin will not work correctly.)
205+
mod = sys.modules[name]
206+
else:
207+
# I wish I could just call imp.load_compiled here, but __file__ has to
208+
# be set properly. In Python 3.2+, this all would be handled correctly
209+
# by load_compiled.
210+
mod = sys.modules[name] = imp.new_module(name)
211211
try:
212212
mod.__file__ = co.co_filename
213213
# Normally, this attribute is 3.2+.

testing/test_assertrewrite.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,48 @@ def test_loader():
10501050
result = testdir.runpytest("-s")
10511051
result.stdout.fnmatch_lines(["* 1 passed*"])
10521052

1053+
def test_reload_reloads(self, testdir):
1054+
"""Reloading a module after change picks up the change."""
1055+
testdir.tmpdir.join("file.py").write(
1056+
textwrap.dedent(
1057+
"""
1058+
def reloaded():
1059+
return False
1060+
1061+
def rewrite_self():
1062+
with open(__file__, 'w') as self:
1063+
self.write('def reloaded(): return True')
1064+
"""
1065+
)
1066+
)
1067+
testdir.tmpdir.join("pytest.ini").write(
1068+
textwrap.dedent(
1069+
"""
1070+
[pytest]
1071+
python_files = *.py
1072+
"""
1073+
)
1074+
)
1075+
1076+
testdir.makepyfile(
1077+
test_fun="""
1078+
import sys
1079+
try:
1080+
from imp import reload
1081+
except ImportError:
1082+
pass
1083+
1084+
def test_loader():
1085+
import file
1086+
assert not file.reloaded()
1087+
file.rewrite_self()
1088+
reload(file)
1089+
assert file.reloaded()
1090+
"""
1091+
)
1092+
result = testdir.runpytest("-s")
1093+
result.stdout.fnmatch_lines(["* 1 passed*"])
1094+
10531095
def test_get_data_support(self, testdir):
10541096
"""Implement optional PEP302 api (#808).
10551097
"""

0 commit comments

Comments
 (0)