Skip to content

Commit 35ce37d

Browse files
committed
[Feature] Add test for recompiles of ReplayBuffer.extend
ghstack-source-id: f50d4ec Pull Request resolved: #2504
1 parent a27514c commit 35ce37d

File tree

3 files changed

+116
-2
lines changed

3 files changed

+116
-2
lines changed

test/_utils_internal.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from __future__ import annotations
66

77
import contextlib
8+
import logging
89
import os
910

1011
import os.path
1112
import time
13+
import unittest
1214
from functools import wraps
1315

1416
# Get relative file path
@@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
204206
return deco_retry
205207

206208

209+
# After calling this function, any log record whose name contains 'record_name'
210+
# and is emitted from the logger that has qualified name 'logger_qname' is
211+
# appended to the 'records' list.
212+
# NOTE: This function is based on testing utilities for 'torch._logging'
213+
def capture_log_records(records, logger_qname, record_name):
214+
assert isinstance(records, list)
215+
logger = logging.getLogger(logger_qname)
216+
217+
class EmitWrapper:
218+
def __init__(self, old_emit):
219+
self.old_emit = old_emit
220+
221+
def __call__(self, record):
222+
nonlocal records
223+
self.old_emit(record)
224+
if record_name in record.name:
225+
records.append(record)
226+
227+
for handler in logger.handlers:
228+
new_emit = EmitWrapper(handler.emit)
229+
contextlib.ExitStack().enter_context(
230+
unittest.mock.patch.object(handler, "emit", new_emit)
231+
)
232+
233+
207234
@pytest.fixture
208235
def dtype_fixture():
209236
dtype = torch.get_default_dtype()

test/test_rb.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
import pytest
1818
import torch
1919

20-
from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
20+
from _utils_internal import (
21+
capture_log_records,
22+
CARTPOLE_VERSIONED,
23+
get_default_devices,
24+
make_tc,
25+
)
2126

2227
from mocking_classes import CountingEnv
2328
from packaging import version
@@ -399,6 +404,63 @@ def data_iter():
399404
) if cond else contextlib.nullcontext():
400405
rb.extend(data2)
401406

407+
def test_extend_recompile(self, rb_type, sampler, writer, storage, size, datatype):
408+
if rb_type is not ReplayBuffer:
409+
pytest.skip(
410+
"Only replay buffer of type 'ReplayBuffer' is currently supported."
411+
)
412+
if sampler in (PrioritizedSampler,):
413+
pytest.skip(f"Sampler of type '{sampler.__name__}' is not yet supported.")
414+
if storage is not LazyTensorStorage:
415+
pytest.skip(
416+
"Only storage of type 'LazyTensorStorage' is currently supported."
417+
)
418+
if writer is not RoundRobinWriter:
419+
pytest.skip(
420+
"Only writer of type 'RoundRobinWriter' is currently supported."
421+
)
422+
423+
torch.compiler.reset()
424+
425+
storage_size = 10 * size
426+
rb = self._get_rb(
427+
rb_type=rb_type,
428+
sampler=sampler,
429+
writer=writer,
430+
storage=storage,
431+
size=storage_size,
432+
)
433+
data_size = size
434+
data = self._get_data(datatype, size=data_size)
435+
436+
@torch.compile
437+
def extend(data):
438+
rb.extend(data)
439+
440+
# Number of times to extend the replay buffer
441+
num_extend = 30
442+
443+
# NOTE: The first two calls to 'extend' currently cause recompilations,
444+
# so avoid capturing those for now.
445+
num_extend_before_capture = 2
446+
447+
for _ in range(num_extend_before_capture):
448+
extend(data)
449+
450+
try:
451+
torch._logging.set_logs(recompiles=True)
452+
records = []
453+
capture_log_records(records, "torch._dynamo", "recompiles")
454+
455+
for _ in range(num_extend - num_extend_before_capture):
456+
extend(data)
457+
458+
assert len(records) == 0
459+
assert len(rb) == storage_size
460+
461+
finally:
462+
torch._logging.set_logs()
463+
402464
def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
403465
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
404466
pytest.skip(

test/test_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import torch
1616

17-
from _utils_internal import get_default_devices
17+
from _utils_internal import capture_log_records, get_default_devices
1818
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
1919

2020
from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
@@ -380,6 +380,31 @@ def test_rng_decorator(device):
380380
torch.testing.assert_close(s0b, s1b)
381381

382382

383+
# Check that 'capture_log_records' captures records emitted when torch
384+
# recompiles a function.
385+
def test_capture_log_records_recompile():
386+
torch.compiler.reset()
387+
388+
# This function recompiles each time it is called with a different string
389+
# input.
390+
@torch.compile
391+
def str_to_tensor(s):
392+
return bytes(s, "utf8")
393+
394+
str_to_tensor("a")
395+
396+
try:
397+
torch._logging.set_logs(recompiles=True)
398+
records = []
399+
capture_log_records(records, "torch._dynamo", "recompiles")
400+
str_to_tensor("b")
401+
402+
finally:
403+
torch._logging.set_logs()
404+
405+
assert len(records) == 1
406+
407+
383408
if __name__ == "__main__":
384409
args, unknown = argparse.ArgumentParser().parse_known_args()
385410
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)