File tree Expand file tree Collapse file tree 1 file changed +10
-4
lines changed Expand file tree Collapse file tree 1 file changed +10
-4
lines changed Original file line number Diff line number Diff line change
1
+ from typing import Dict , List , Tuple
2
+
1
3
import torch
2
4
import torch .nn as nn
3
- from transformers import BertModel , BertTokenizer , BertConfig
4
5
import torch .nn .functional as F
5
- from typing import Tuple , List , Dict
6
+ from transformers import BertConfig , BertModel , BertTokenizer
6
7
7
8
8
9
# Sample Pool Model (for testing plugin serialization)
@@ -182,8 +183,13 @@ def BertModule():
182
183
intermediate_size = 3072 ,
183
184
torchscript = True ,
184
185
)
185
- model = BertModel (config )
186
+ model_kwargs = {
187
+ "use_cache" : False ,
188
+ "output_attentions" : False ,
189
+ "output_hidden_states" : False ,
190
+ "torchscript" : True ,
191
+ }
192
+ model = BertModel .from_pretrained (model_name , config = config , ** model_kwargs )
186
193
model .eval ()
187
- model = BertModel .from_pretrained (model_name , torchscript = True )
188
194
traced_model = torch .jit .trace (model , [tokens_tensor , segments_tensors ])
189
195
return traced_model
You can’t perform that action at this time.
0 commit comments