Skip to content

Commit aaff012

Browse files
maffoopavoljuhas
andauthored
Add cached_method decorator for per-instance method caches (quantumlib#5570)
Co-authored-by: Pavol Juhas <juhas@google.com>
1 parent bab8299 commit aaff012

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

cirq/_compat.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import traceback
2424
import warnings
2525
from types import ModuleType
26-
from typing import Any, Callable, Optional, Dict, Tuple, Type, Set
26+
from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar
2727

2828
import numpy as np
2929
import pandas as pd
@@ -39,6 +39,54 @@
3939
from backports.cached_property import cached_property # type: ignore[no-redef]
4040

4141

42+
TFunc = TypeVar('TFunc', bound=Callable)
43+
44+
45+
@overload
46+
def cached_method(__func: TFunc) -> TFunc:
47+
...
48+
49+
50+
@overload
51+
def cached_method(*, maxsize: int = 128) -> Callable[[TFunc], TFunc]:
52+
...
53+
54+
55+
def cached_method(method: Optional[TFunc] = None, *, maxsize: int = 128) -> Any:
56+
"""Decorator that adds a per-instance LRU cache for a method.
57+
58+
Can be applied with or without parameters to customize the underlying cache:
59+
60+
@cached_method
61+
def foo(self, name: str) -> int:
62+
...
63+
64+
@cached_method(maxsize=1000)
65+
def bar(self, name: str) -> int:
66+
...
67+
"""
68+
69+
def decorator(func):
70+
cache_name = f'_{func.__name__}_cache'
71+
72+
@functools.wraps(func)
73+
def wrapped(self, *args, **kwargs):
74+
cached = getattr(self, cache_name, None)
75+
if cached is None:
76+
77+
@functools.lru_cache(maxsize=maxsize)
78+
def cached_func(*args, **kwargs):
79+
return func(self, *args, **kwargs)
80+
81+
object.__setattr__(self, cache_name, cached_func)
82+
cached = cached_func
83+
return cached(*args, **kwargs)
84+
85+
return wrapped
86+
87+
return decorator if method is None else decorator(method)
88+
89+
4290
def proper_repr(value: Any) -> str:
4391
"""Overrides sympy and numpy returning repr strings that don't parse."""
4492

cirq/_compat_test.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import collections
1415
import dataclasses
1516
import importlib
1617
import logging
@@ -21,7 +22,7 @@
2122
import types
2223
import warnings
2324
from types import ModuleType
24-
from typing import Any, Callable, Optional
25+
from typing import Any, Callable, Dict, Optional, Tuple
2526
from importlib.machinery import ModuleSpec
2627
from unittest import mock
2728

@@ -35,6 +36,7 @@
3536
import cirq.testing
3637
from cirq._compat import (
3738
block_overlapping_deprecation,
39+
cached_method,
3840
cached_property,
3941
proper_repr,
4042
dataclass_repr,
@@ -985,3 +987,33 @@ def bar(self):
985987
bar2 = foo.bar
986988
assert bar2 is bar
987989
assert foo.bar_calls == 1
990+
991+
992+
class Bar:
993+
def __init__(self):
994+
self.foo_calls: Dict[int, int] = collections.Counter()
995+
self.bar_calls: Dict[int, int] = collections.Counter()
996+
997+
@cached_method
998+
def foo(self, n: int) -> Tuple[int, int]:
999+
self.foo_calls[n] += 1
1000+
return (id(self), n)
1001+
1002+
@cached_method(maxsize=1)
1003+
def bar(self, n: int) -> Tuple[int, int]:
1004+
self.bar_calls[n] += 1
1005+
return (id(self), 2 * n)
1006+
1007+
1008+
def test_cached_method():
1009+
b = Bar()
1010+
assert b.foo(123) == b.foo(123) == b.foo(123) == (id(b), 123)
1011+
assert b.foo(234) == b.foo(234) == b.foo(234) == (id(b), 234)
1012+
assert b.foo_calls == {123: 1, 234: 1}
1013+
1014+
assert b.bar(123) == b.bar(123) == (id(b), 123 * 2)
1015+
assert b.bar_calls == {123: 1}
1016+
assert b.bar(234) == b.bar(234) == (id(b), 234 * 2)
1017+
assert b.bar_calls == {123: 1, 234: 1}
1018+
assert b.bar(123) == b.bar(123) == (id(b), 123 * 2)
1019+
assert b.bar_calls == {123: 2, 234: 1}

0 commit comments

Comments
 (0)