Skip to content

Commit a988fda

Browse files
Fix model.copy() bug where layer used more than once (#659)
* Fix model.copy() bug where layer used more than once * Expand functionality to include shims * Corrections after review * Added default for Model._copy()
1 parent 6d84d00 commit a988fda

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

thinc/model.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,42 @@ def copy(self: SelfT) -> SelfT:
462462
layers will also be deep-copied. The copy will receive a distinct `model.id`
463463
value.
464464
"""
465+
return self._copy()
466+
467+
def _copy(self: SelfT, seen: Optional[Dict[int, Union["Model", Shim]]] = None) -> SelfT:
468+
if seen is None:
469+
seen = {}
465470
params = {}
466471
for name in self.param_names:
467472
params[name] = self.get_param(name) if self.has_param(name) else None
468473

474+
copied_layers: List[Model] = []
475+
for layer in self.layers:
476+
if id(layer) in seen:
477+
copied_layers.append(cast(Model, seen[id(layer)]))
478+
else:
479+
copied_layer = layer._copy(seen)
480+
seen[id(layer)] = copied_layer
481+
copied_layers.append(copied_layer)
482+
483+
copied_shims = []
484+
for shim in self.shims:
485+
if id(shim) in seen:
486+
copied_shims.append(cast(Shim, seen[id(shim)]))
487+
else:
488+
copied_shim = shim.copy()
489+
seen[id(shim)] = copied_shim
490+
copied_shims.append(copied_shim)
491+
469492
copied: Model[InT, OutT] = Model(
470493
self.name,
471494
self._func,
472495
init=self.init,
473496
params=copy.deepcopy(params),
474497
dims=copy.deepcopy(self._dims),
475498
attrs=copy.deepcopy(self._attrs),
476-
layers=[layer.copy() for layer in self.layers],
477-
shims=[shim.copy() for shim in self.shims],
499+
layers=copied_layers,
500+
shims=copied_shims,
478501
)
479502
for name in self.grad_names:
480503
copied.set_grad(name, self.get_grad(name).copy())

thinc/tests/model/test_model.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,44 @@ def test_walk_bfs_post_order_fails():
615615
relu = Relu(5)
616616
with pytest.raises(ValueError, match="Invalid order"):
617617
relu.walk(order="dfs_post_order")
618+
619+
620+
def test_model_copy_with_loop():
621+
class MyShim(Shim):
622+
name = "testshim"
623+
624+
def to_bytes(self):
625+
return test_replace_node_with_indirect_node_ref
626+
627+
def from_bytes(self, bytes):
628+
pass
629+
630+
model_a = create_model("a")
631+
working_shim = MyShim(None)
632+
layer = Model(
633+
"test",
634+
lambda X: (X, lambda dY: dY),
635+
dims={"nI": 5, "nO": 5},
636+
params={"W": numpy.zeros((10,)), "b": None},
637+
refs={"a": model_a, "b": None},
638+
attrs={"foo": "bar"},
639+
shims=[working_shim],
640+
layers=[model_a, model_a],
641+
)
642+
layer2 = Model(
643+
"test2",
644+
lambda X: (X, lambda dY: dY),
645+
dims={"nI": 5, "nO": 5},
646+
params={"W": numpy.zeros((10,)), "b": None},
647+
refs={"a": model_a, "b": None},
648+
attrs={"foo": "bar"},
649+
shims=[working_shim],
650+
layers=[model_a, model_a],
651+
)
652+
relu = Relu(5)
653+
model = chain(layer, relu, layer, layer2)
654+
model2 = model.copy()
655+
model.from_dict(model2.to_dict())
656+
assert model2.name == "test>>relu>>test>>test2"
657+
assert model2.layers[0] == model2.layers[2]
658+
assert id(model2.layers[0].shims[0]) == id(model2.layers[3].shims[0])

0 commit comments

Comments
 (0)