Skip to content

Fix model.copy() bug where layer used more than once #659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions thinc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,19 +462,42 @@ def copy(self: SelfT) -> SelfT:
layers will also be deep-copied. The copy will receive a distinct `model.id`
value.
"""
return self._copy()

def _copy(self: SelfT, seen: Optional[Dict[int, Union["Model", Shim]]] = None) -> SelfT:
if seen is None:
seen = {}
params = {}
for name in self.param_names:
params[name] = self.get_param(name) if self.has_param(name) else None

copied_layers: List[Model] = []
for layer in self.layers:
if id(layer) in seen:
copied_layers.append(cast(Model, seen[id(layer)]))
else:
copied_layer = layer._copy(seen)
seen[id(layer)] = copied_layer
copied_layers.append(copied_layer)

copied_shims = []
for shim in self.shims:
if id(shim) in seen:
copied_shims.append(cast(Shim, seen[id(shim)]))
else:
copied_shim = shim.copy()
seen[id(shim)] = copied_shim
copied_shims.append(copied_shim)

copied: Model[InT, OutT] = Model(
self.name,
self._func,
init=self.init,
params=copy.deepcopy(params),
dims=copy.deepcopy(self._dims),
attrs=copy.deepcopy(self._attrs),
layers=[layer.copy() for layer in self.layers],
shims=[shim.copy() for shim in self.shims],
layers=copied_layers,
shims=copied_shims,
)
for name in self.grad_names:
copied.set_grad(name, self.get_grad(name).copy())
Expand Down
45 changes: 43 additions & 2 deletions thinc/tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,10 @@ def test_all_operators(op):
with pytest.raises(TypeError):
value = m1 % m2
if op == "**":
value = m1 ** m2
value = m1**m2
else:
with pytest.raises(TypeError):
value = m1 ** m2
value = m1**m2
if op == "<<":
value = m1 << m2
else:
Expand Down Expand Up @@ -614,3 +614,44 @@ def test_walk_bfs_post_order_fails():
relu = Relu(5)
with pytest.raises(ValueError, match="Invalid order"):
relu.walk(order="dfs_post_order")


def test_model_copy_with_loop():
    """Copying a model whose graph reuses a sublayer/shim must preserve
    that sharing in the copy (regression test for model.copy() with
    a layer used more than once)."""

    class MyShim(Shim):
        name = "testshim"

        def to_bytes(self):
            return test_replace_node_with_indirect_node_ref

        def from_bytes(self, bytes):
            pass

    # One child model and one shim, deliberately shared between layers.
    shared_child = create_model("a")
    shared_shim = MyShim(None)

    def build_layer(layer_name):
        # Each layer lists the shared child twice and holds the shared shim.
        return Model(
            layer_name,
            lambda X: (X, lambda dY: dY),
            dims={"nI": 5, "nO": 5},
            params={"W": numpy.zeros((10,)), "b": None},
            refs={"a": shared_child, "b": None},
            attrs={"foo": "bar"},
            shims=[shared_shim],
            layers=[shared_child, shared_child],
        )

    first = build_layer("test")
    second = build_layer("test2")
    activation = Relu(5)
    # `first` appears twice in the chain, so the copy must deduplicate it.
    pipeline = chain(first, activation, first, second)
    duplicate = pipeline.copy()
    # Round-trip serialization should still work on the original.
    pipeline.from_dict(duplicate.to_dict())
    assert duplicate.name == "test>>relu>>test>>test2"
    # The repeated layer must be one shared object in the copy, too.
    assert duplicate.layers[0] == duplicate.layers[2]
    # Likewise, the shim shared across layers must stay a single object.
    assert id(duplicate.layers[0].shims[0]) == id(duplicate.layers[3].shims[0])