|
14 | 14 |
|
15 | 15 | from fairseq import utils |
16 | 16 | from fairseq.models import ( |
17 | | - FairseqDecoder, |
18 | | - FairseqLanguageModel, |
| 17 | + FairseqEncoder, |
| 18 | + FairseqEncoderModel, |
19 | 19 | register_model, |
20 | 20 | register_model_architecture, |
21 | 21 | ) |
|
33 | 33 |
|
34 | 34 |
|
35 | 35 | @register_model('roberta') |
36 | | -class RobertaModel(FairseqLanguageModel): |
| 36 | +class RobertaModel(FairseqEncoderModel): |
37 | 37 |
|
38 | 38 | @classmethod |
39 | 39 | def hub_models(cls): |
@@ -116,12 +116,20 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, cla |
116 | 116 | if classification_head_name is not None: |
117 | 117 | features_only = True |
118 | 118 |
|
119 | | - x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs) |
| 119 | + x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs) |
120 | 120 |
|
121 | 121 | if classification_head_name is not None: |
122 | 122 | x = self.classification_heads[classification_head_name](x) |
123 | 123 | return x, extra |
124 | 124 |
|
| 125 | + def get_normalized_probs(self, net_output, log_probs, sample=None): |
| 126 | + """Get normalized probabilities (or log probs) from a net's output.""" |
| 127 | + logits = net_output[0].float() |
| 128 | + if log_probs: |
| 129 | + return F.log_softmax(logits, dim=-1) |
| 130 | + else: |
| 131 | + return F.softmax(logits, dim=-1) |
| 132 | + |
125 | 133 | def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): |
126 | 134 | """Register a classification head.""" |
127 | 135 | if name in self.classification_heads: |
@@ -163,13 +171,23 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na |
163 | 171 | return RobertaHubInterface(x['args'], x['task'], x['models'][0]) |
164 | 172 |
|
165 | 173 | def upgrade_state_dict_named(self, state_dict, name): |
166 | | - super().upgrade_state_dict_named(state_dict, name) |
167 | | - |
168 | 174 | prefix = name + '.' if name != '' else '' |
169 | | - current_head_names = [] if not hasattr(self, 'classification_heads') else \ |
170 | | - self.classification_heads.keys() |
| 175 | + |
| 176 | + # rename decoder -> encoder before upgrading children modules |
| 177 | + for k in list(state_dict.keys()): |
| 178 | + if k.startswith(prefix + 'decoder'): |
| 179 | + new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):] |
| 180 | + state_dict[new_k] = state_dict[k] |
| 181 | + del state_dict[k] |
| 182 | + |
| 183 | + # upgrade children modules |
| 184 | + super().upgrade_state_dict_named(state_dict, name) |
171 | 185 |
|
172 | 186 | # Handle new classification heads present in the state dict. |
| 187 | + current_head_names = ( |
| 188 | + [] if not hasattr(self, 'classification_heads') |
| 189 | + else self.classification_heads.keys() |
| 190 | + ) |
173 | 191 | keys_to_delete = [] |
174 | 192 | for k in state_dict.keys(): |
175 | 193 | if not k.startswith(prefix + 'classification_heads.'): |
@@ -261,24 +279,15 @@ def forward(self, features, **kwargs): |
261 | 279 | return x |
262 | 280 |
|
263 | 281 |
|
264 | | -class RobertaEncoder(FairseqDecoder): |
265 | | - """RoBERTa encoder. |
266 | | -
|
267 | | - Implements the :class:`~fairseq.models.FairseqDecoder` interface required |
268 | | - by :class:`~fairseq.models.FairseqLanguageModel`. |
269 | | - """ |
| 282 | +class RobertaEncoder(FairseqEncoder): |
| 283 | + """RoBERTa encoder.""" |
270 | 284 |
|
271 | 285 | def __init__(self, args, dictionary): |
272 | 286 | super().__init__(dictionary) |
273 | 287 | self.args = args |
274 | 288 |
|
275 | | - # RoBERTa is a sentence encoder model, so users will intuitively trim |
276 | | - # encoder layers. However, the implementation uses the fairseq decoder, |
277 | | - # so we fix here. |
278 | 289 | if args.encoder_layers_to_keep: |
279 | 290 | args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) |
280 | | - args.decoder_layers_to_keep = args.encoder_layers_to_keep |
281 | | - args.encoder_layers_to_keep = None |
282 | 291 |
|
283 | 292 | self.sentence_encoder = TransformerSentenceEncoder( |
284 | 293 | padding_idx=dictionary.pad(), |
|
0 commit comments