
Commit efdbb72

danieldk and svlandeg authored
Store activations in Docs when save_activations is enabled (#11002)
* Store activations in Doc when `store_activations` is enabled. This change adds the new `activations` attribute to `Doc`. This attribute can be used by trainable pipes to store their activations, probabilities, and guesses for downstream users. As an example, this change modifies the `tagger` and `senter` pipes to add a `store_activations` option. When this option is enabled, the probabilities and guesses are stored in `set_annotations`.
* Change type of `store_activations` to `Union[bool, List[str]]`. When the value is:
  - a bool: all activations are stored when set to `True`;
  - a `List[str]`: the activations named in the list are stored.
* Formatting fixes in Tagger.
* Support `store_activations` in spancat and morphologizer.
* Make `Doc.activations` type visible to MyPy.
* textcat/textcat_multilabel: add `store_activations` option.
* trainable_lemmatizer/entity_linker: add `store_activations` option.
* parser/ner: do not currently support returning activations.
* Extend tagger and senter tests so that they, like the other tests, also check that we get no activations if no activations were requested.
* Document `Doc.activations` and `store_activations` in the relevant pipes.
* Start errors/warnings at higher numbers to avoid merge conflicts between the master and v4 branches.
* Add `store_activations` to docstrings.
* Replace the `store_activations` setter by a `set_store_activations` method. Setters that take a different type than what the getter returns are still problematic for MyPy. Replacing the setter by a method makes type inference work everywhere.
* Use dict comprehension suggested by @svlandeg.
* Revert "Use dict comprehension suggested by @svlandeg". This reverts commit 6e7b958.
* EntityLinker: add type annotations to `_add_activations`.
* `_store_activations`: make kwarg-only, remove `doc_scores_lens` arg.
* `set_annotations`: add type annotations.
* Apply suggestions from code review. Co-authored-by: Sofie Van Landeghem <[email protected]>
* TextCat.predict: return dict.
* Make the `TrainablePipe.store_activations` property a bool. This means that we can also bring back the `store_activations` setter.
* Remove `TrainablePipe.activations`. We do not need to enumerate the activations anymore since `store_activations` is a bool.
* Add type annotations for activations in `predict`/`set_annotations`.
* Rename `TrainablePipe.store_activations` to `save_activations`.
* Error E1400 is not used anymore. This error was used when activations were still `Union[bool, List[str]]`.
* Change wording in API docs after the store -> save change.
* docs: tag (save_)activations as new in spaCy 4.0.
* Fix copied line in morphologizer activations test.
* Don't train in any `test_save_activations` test.
* Rename activations: "probs" -> "probabilities" and "guesses" -> "label_ids", except in the edit tree lemmatizer, where "guesses" -> "tree_ids".
* Remove unused W400 warning. This warning was used when we still allowed the user to specify which activations to save.
* Formatting fixes. Co-authored-by: Sofie Van Landeghem <[email protected]>
* Replace "kb_ids" by a constant.
* spancat: replace a cast by an assertion.
* Fix EOF spacing.
* Fix comments in `test_save_activations` tests.
* Do not set RNG seed in activation saving tests.
* Revert "spancat: replace a cast by an assertion". This reverts commit 0bd5730.

Co-authored-by: Sofie Van Landeghem <[email protected]>
1 parent 60c050e commit efdbb72
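
For orientation, here is a minimal usage sketch of the feature introduced by this commit. It assumes a pipeline whose trainable components were built from this branch; the pipeline name is illustrative, while the `Doc.activations` attribute, the `save_activations` bool property, and the tagger's "probabilities"/"label_ids" keys follow the commit message and the diffs below.

# Hedged usage sketch: assumes a spaCy pipeline built with this branch (v4 development).
import spacy

nlp = spacy.load("en_core_web_sm")  # illustrative pipeline name
# Per this commit, save_activations is a plain bool on TrainablePipe and can also
# be set via the component config ("save_activations": True).
nlp.get_pipe("tagger").save_activations = True

doc = nlp("This is a sentence.")
# Activations are stored on the Doc, keyed by component name.
tagger_acts = doc.activations["tagger"]
print(tagger_acts["probabilities"].shape)  # per-token class probabilities
print(tagger_acts["label_ids"])            # per-token predicted tag IDs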

28 files changed (+580, -130 lines)

spacy/pipeline/edit_tree_lemmatizer.py

Lines changed: 23 additions & 6 deletions
@@ -7,7 +7,7 @@

 import srsly
 from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import ArrayXd, Floats2d, Ints1d

 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -21,6 +21,9 @@
 from .. import util


+ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
+
+
 default_model_config = """
 [model]
 @architectures = "spacy.Tagger.v2"
@@ -49,6 +52,7 @@
         "overwrite": False,
         "top_k": 1,
         "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
+        "save_activations": False,
     },
     default_score_weights={"lemma_acc": 1.0},
 )
@@ -61,6 +65,7 @@ def make_edit_tree_lemmatizer(
     overwrite: bool,
     top_k: int,
     scorer: Optional[Callable],
+    save_activations: bool,
 ):
     """Construct an EditTreeLemmatizer component."""
     return EditTreeLemmatizer(
@@ -72,6 +77,7 @@ def make_edit_tree_lemmatizer(
         overwrite=overwrite,
         top_k=top_k,
         scorer=scorer,
+        save_activations=save_activations,
     )


@@ -91,6 +97,7 @@ def __init__(
         overwrite: bool = False,
         top_k: int = 1,
         scorer: Optional[Callable] = lemmatizer_score,
+        save_activations: bool = False,
     ):
         """
         Construct an edit tree lemmatizer.
@@ -102,6 +109,7 @@ def __init__(
             frequency in the training data.
         overwrite (bool): overwrite existing lemma annotations.
         top_k (int): try to apply at most the k most probable edit trees.
+        save_activations (bool): save model activations in Doc when annotating.
         """
         self.vocab = vocab
         self.model = model
@@ -116,6 +124,7 @@ def __init__(

         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.save_activations = save_activations

     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
@@ -144,21 +153,24 @@ def get_loss(

         return float(loss), d_scores

-    def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
             n_labels = len(self.cfg["labels"])
-            guesses: List[Ints2d] = [
+            guesses: List[Ints1d] = [
+                self.model.ops.alloc((0,), dtype="i") for doc in docs
+            ]
+            scores: List[Floats2d] = [
                 self.model.ops.alloc((0, n_labels), dtype="i") for doc in docs
             ]
             assert len(guesses) == n_docs
-            return guesses
+            return {"probabilities": scores, "tree_ids": guesses}
         scores = self.model.predict(docs)
         assert len(scores) == n_docs
         guesses = self._scores2guesses(docs, scores)
         assert len(guesses) == n_docs
-        return guesses
+        return {"probabilities": scores, "tree_ids": guesses}

     def _scores2guesses(self, docs, scores):
         guesses = []
@@ -186,8 +198,13 @@ def _scores2guesses(self, docs, scores):

         return guesses

-    def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
+        batch_tree_ids = activations["tree_ids"]
         for i, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
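
The pattern above recurs in the other pipes touched by this commit: predict now returns a dict of per-doc activation lists, and set_annotations copies the i-th entry of every activation into doc.activations, keyed by the component name, when save_activations is enabled. The standalone sketch below restates that storage step outside spaCy; plain dicts and numpy arrays stand in for Doc.activations and the model output, and the helper name store_batch_activations is hypothetical.

# Not spaCy code: a self-contained restatement of the set_annotations storage loop.
from typing import Dict, List
import numpy as np

def store_batch_activations(
    docs: List[dict],
    activations: Dict[str, List[np.ndarray]],
    pipe_name: str,
) -> None:
    # Each activation value holds one array per doc, so slice by the doc index.
    for i, doc in enumerate(docs):
        doc.setdefault("activations", {})[pipe_name] = {
            act_name: acts[i] for act_name, acts in activations.items()
        }

docs = [{}, {}]
acts = {
    # "probabilities": one row of edit-tree probabilities per token
    "probabilities": [np.zeros((3, 10)), np.zeros((5, 10))],
    # "tree_ids": the chosen edit tree ID per token
    "tree_ids": [np.zeros(3, dtype="i"), np.zeros(5, dtype="i")],
}
store_batch_activations(docs, acts, "trainable_lemmatizer")
print(docs[0]["activations"]["trainable_lemmatizer"]["tree_ids"].shape)  # (3,)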

spacy/pipeline/entity_linker.py

Lines changed: 98 additions & 10 deletions
@@ -1,5 +1,7 @@
-from typing import Optional, Iterable, Callable, Dict, Union, List, Any
-from thinc.types import Floats2d
+from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
+from typing import cast
+from numpy import dtype
+from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
 from pathlib import Path
 from itertools import islice
 import srsly
@@ -21,6 +23,11 @@
 from .. import util
 from ..scorer import Scorer

+
+ActivationsT = Dict[str, Union[List[Ragged], List[str]]]
+
+KNOWLEDGE_BASE_IDS = "kb_ids"
+
 # See #9050
 BACKWARD_OVERWRITE = True

@@ -57,6 +64,7 @@
         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
         "use_gold_ents": True,
         "threshold": None,
+        "save_activations": False,
     },
     default_score_weights={
         "nel_micro_f": 1.0,
@@ -79,6 +87,7 @@ def make_entity_linker(
     scorer: Optional[Callable],
     use_gold_ents: bool,
     threshold: Optional[float] = None,
+    save_activations: bool,
 ):
     """Construct an EntityLinker component.

@@ -97,6 +106,7 @@ def make_entity_linker(
         component must provide entity annotations.
     threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
         prediction is discarded. If None, predictions are not filtered by any threshold.
+    save_activations (bool): save model activations in Doc when annotating.
     """

     if not model.attrs.get("include_span_maker", False):
@@ -128,6 +138,7 @@ def make_entity_linker(
         scorer=scorer,
         use_gold_ents=use_gold_ents,
         threshold=threshold,
+        save_activations=save_activations,
     )


@@ -164,6 +175,7 @@ def __init__(
         scorer: Optional[Callable] = entity_linker_score,
         use_gold_ents: bool,
         threshold: Optional[float] = None,
+        save_activations: bool = False,
     ) -> None:
         """Initialize an entity linker.

@@ -212,6 +224,7 @@ def __init__(
         self.scorer = scorer
         self.use_gold_ents = use_gold_ents
         self.threshold = threshold
+        self.save_activations = save_activations

     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
@@ -397,7 +410,7 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
         loss = loss / len(entity_encodings)
         return float(loss), out

-    def predict(self, docs: Iterable[Doc]) -> List[str]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.
         Returns the KB IDs for each entity in each doc, including NIL if there is
         no prediction.
@@ -410,13 +423,20 @@ def predict(self, docs: Iterable[Doc]) -> List[str]:
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
-        xp = self.model.ops.xp
+        ops = self.model.ops
+        xp = ops.xp
+        docs_ents: List[Ragged] = []
+        docs_scores: List[Ragged] = []
         if not docs:
-            return final_kb_ids
+            return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
         if isinstance(docs, Doc):
             docs = [docs]
-        for i, doc in enumerate(docs):
+        for doc in docs:
+            doc_ents: List[Ints1d] = []
+            doc_scores: List[Floats1d] = []
             if len(doc) == 0:
+                docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
+                docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
                 continue
             sentences = [s for s in doc.sents]
             # Looping through each entity (TODO: rewrite)
@@ -439,14 +459,32 @@ def predict(self, docs: Iterable[Doc]) -> List[str]:
                 if ent.label_ in self.labels_discard:
                     # ignoring this entity - setting to NIL
                     final_kb_ids.append(self.NIL)
+                    self._add_activations(
+                        doc_scores=doc_scores,
+                        doc_ents=doc_ents,
+                        scores=[0.0],
+                        ents=[0],
+                    )
                 else:
                     candidates = list(self.get_candidates(self.kb, ent))
                     if not candidates:
                         # no prediction possible for this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
+                        self._add_activations(
+                            doc_scores=doc_scores,
+                            doc_ents=doc_ents,
+                            scores=[0.0],
+                            ents=[0],
+                        )
                     elif len(candidates) == 1 and self.threshold is None:
                         # shortcut for efficiency reasons: take the 1 candidate
                         final_kb_ids.append(candidates[0].entity_)
+                        self._add_activations(
+                            doc_scores=doc_scores,
+                            doc_ents=doc_ents,
+                            scores=[1.0],
+                            ents=[candidates[0].entity_],
+                        )
                     else:
                         random.shuffle(candidates)
                         # set all prior probabilities to 0 if incl_prior=False
@@ -479,27 +517,48 @@ def predict(self, docs: Iterable[Doc]) -> List[str]:
                             if self.threshold is None or scores.max() >= self.threshold
                             else EntityLinker.NIL
                         )
+                        self._add_activations(
+                            doc_scores=doc_scores,
+                            doc_ents=doc_ents,
+                            scores=scores,
+                            ents=[c.entity for c in candidates],
+                        )
+            self._add_doc_activations(
+                docs_scores=docs_scores,
+                docs_ents=docs_ents,
+                doc_scores=doc_scores,
+                doc_ents=doc_ents,
+            )
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
             )
             raise RuntimeError(err)
-        return final_kb_ids
+        return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}

-    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
         """Modify a batch of documents, using pre-computed scores.

         docs (Iterable[Doc]): The documents to modify.
-        kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
+        activations (ActivationsT): The activations used for setting annotations, produced
+            by EntityLinker.predict.

         DOCS: https://spacy.io/api/entitylinker#set_annotations
         """
+        kb_ids = cast(List[str], activations[KNOWLEDGE_BASE_IDS])
         count_ents = len([ent for doc in docs for ent in doc.ents])
         if count_ents != len(kb_ids):
             raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
         i = 0
         overwrite = self.cfg["overwrite"]
-        for doc in docs:
+        for j, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    if act_name != KNOWLEDGE_BASE_IDS:
+                        # We only copy activations that are Ragged.
+                        doc.activations[self.name][act_name] = cast(Ragged, acts[j])
+
             for ent in doc.ents:
                 kb_id = kb_ids[i]
                 i += 1
@@ -598,3 +657,32 @@ def rehearse(self, examples, *, sgd=None, losses=None, **config):

     def add_label(self, label):
         raise NotImplementedError
+
+    def _add_doc_activations(
+        self,
+        *,
+        docs_scores: List[Ragged],
+        docs_ents: List[Ragged],
+        doc_scores: List[Floats1d],
+        doc_ents: List[Ints1d],
+    ):
+        if not self.save_activations:
+            return
+        ops = self.model.ops
+        lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
+        docs_scores.append(Ragged(ops.flatten(doc_scores), lengths))
+        docs_ents.append(Ragged(ops.flatten(doc_ents), lengths))
+
+    def _add_activations(
+        self,
+        *,
+        doc_scores: List[Floats1d],
+        doc_ents: List[Ints1d],
+        scores: Sequence[float],
+        ents: Sequence[int],
+    ):
+        if not self.save_activations:
+            return
+        ops = self.model.ops
+        doc_scores.append(ops.asarray1f(scores))
+        doc_ents.append(ops.asarray1i(ents, dtype="uint64"))
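
Because the EntityLinker produces a variable number of candidate scores and KB entity hashes per entity, _add_doc_activations above packs them into one Ragged per doc: a flat data array plus a lengths array with one entry per entity. The snippet below illustrates that packing in isolation, with thinc's NumpyOps standing in for self.model.ops; the example scores are made up.

# Illustration of the Ragged packing used by _add_doc_activations (not part of the diff).
from thinc.api import NumpyOps
from thinc.types import Ragged

ops = NumpyOps()

# One score array per entity; lengths differ with the number of KB candidates.
doc_scores = [
    ops.asarray1f([1.0]),            # single-candidate shortcut
    ops.asarray1f([0.7, 0.2, 0.1]),  # three candidates for the second entity
]

# Pack into one Ragged: flattened data plus one length per entity.
lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
scores_ragged = Ragged(ops.flatten(doc_scores), lengths)

print(scores_ragged.dataXd.shape)  # (4,): all candidate scores concatenated
print(scores_ragged.lengths)       # [1 3]: candidates per entity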
