-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
We could use the following registry, which stores our `f -> jet_f` mapping:
@dataclass
class JetRegistry:
    """Registry mapping primal functions ``f`` to jet implementations ``jet_f``.

    Register an implementation for one or more primal functions with the
    :meth:`register` decorator, look one up with :meth:`get`, and take a
    snapshot of all entries with :meth:`mapping`. After :meth:`freeze` has
    been called, further registrations are rejected.
    """

    # Primal function -> jet implementation.
    _map: dict[Callable, Callable] = field(default_factory=dict)
    # When True, register() refuses to add or override entries.
    _frozen: bool = False

    def register(self, *funcs: Callable, allow_override: bool = False):
        """Return a decorator that registers ``jet_f`` for every function in *funcs*.

        Args:
            *funcs: Primal functions that the decorated jet implementation handles.
            allow_override: If True, existing entries for *funcs* are replaced
                instead of raising.

        Returns:
            A decorator that records the mapping and returns ``jet_f`` unchanged.

        Raises:
            RuntimeError: When the decorator is applied while the registry is
                frozen, or when a function in *funcs* already has a jet and
                *allow_override* is False. In both cases nothing is registered.
        """

        def deco(jet_f: Callable) -> Callable:
            if self._frozen:
                raise RuntimeError("JetRegistry is frozen; cannot register new jets")
            if not allow_override:
                # Validate every function before writing any of them, so that a
                # conflict does not leave the registry partially updated.
                for f in funcs:
                    if f in self._map:
                        raise RuntimeError(
                            f"Jet for {getattr(f,'__name__',f)} already exists"
                        )
            for f in funcs:
                self._map[f] = jet_f
            return jet_f

        return deco

    def get(self, f: Callable) -> Callable:
        """Return the jet implementation registered for *f*.

        Raises:
            KeyError: If no jet has been registered for *f*.
        """
        try:
            return self._map[f]
        except KeyError:
            raise KeyError(
                f"No jet registered for {getattr(f, '__name__', f)}"
            ) from None

    def mapping(self) -> dict[Callable, Callable]:
        """Return a shallow copy of the registry; mutating it does not affect the registry."""
        return dict(self._map)

    def freeze(self) -> None:
        """Prevent any further registrations (existing entries stay readable)."""
        self._frozen = True
Maybe we could also add the possibility to freeze the registry:

```python
def freeze(self):
    self._frozen = True
```

If `JET = JetRegistry()`, users could extend the registry using:
```python
@JET.register(torch.sin, math.sin)
def jet_sin(
    s: PrimalAndCoefficients, K: int, is_taylor: tuple[bool, ...]
) -> ValueAndCoefficients:
    pass
```

What do you think?
Metadata
Metadata
Assignees
Labels
No labels