Skip to content

Commit b7054cd

Browse files
committed
include parent classes in process name/class reverse lookup
1 parent 3e2b8b0 commit b7054cd

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

xsimlab/model.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, processes_cls):
4343
self._processes_cls = processes_cls
4444
self._processes_obj = {k: cls() for k, cls in processes_cls.items()}
4545

46-
self._reverse_lookup = {cls: k for k, cls in processes_cls.items()}
46+
self._reverse_lookup = self._get_reverse_lookup(processes_cls)
4747

4848
self._input_vars = None
4949

@@ -53,6 +53,24 @@ def __init__(self, processes_cls):
5353
# a cache for group keys
5454
self._group_keys = {}
5555

56+
def _get_reverse_lookup(self, processes_cls):
57+
"""Return a dictionary with process classes as keys and process names
58+
as values.
59+
60+
Additionally, the returned dictionary maps all parent classes
61+
to one (str) or several (list) process names.
62+
63+
"""
64+
reverse_lookup = defaultdict(list)
65+
66+
for p_name, p_cls in processes_cls.items():
67+
# exclude `object` base class from lookup
68+
for cls in p_cls.mro()[:-1]:
69+
reverse_lookup[cls].append(p_name)
70+
71+
return {k: v[0] if len(v) == 1 else v
72+
for k, v in reverse_lookup.items()}
73+
5674
def bind_processes(self, model_obj):
5775
for p_name, p_obj in self._processes_obj.items():
5876
p_obj.__xsimlab_model__ = model_obj
@@ -92,6 +110,17 @@ def _get_var_key(self, p_name, var):
92110
.format(target_p_cls.__name__, var.name, p_name)
93111
)
94112

113+
elif isinstance(target_p_name, list):
114+
raise ValueError(
115+
"Process class {!r} required by foreign variable '{}.{}' "
116+
"is used (possibly via one its child classes) by multiple "
117+
"processes: {}"
118+
.format(
119+
target_p_cls.__name__, p_name, var.name,
120+
', '.join(['{!r}'.format(n) for n in target_p_name])
121+
)
122+
)
123+
95124
store_key, od_key = self._get_var_key(target_p_name, target_var)
96125

97126
elif var_type == VarType.GROUP:

0 commit comments

Comments
 (0)