Skip to content

Commit 87312f0

Browse files
committed
fix: Torch nightly version 2
1 parent 45cbcd9 commit 87312f0

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/modules/custom_models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from typing import Dict, List, Tuple
2+
13
import torch
24
import torch.nn as nn
3-
from transformers import BertModel, BertTokenizer, BertConfig
45
import torch.nn.functional as F
5-
from typing import Tuple, List, Dict
6+
from transformers import BertConfig, BertModel, BertTokenizer
67

78

89
# Sample Pool Model (for testing plugin serialization)
@@ -182,8 +183,13 @@ def BertModule():
182183
intermediate_size=3072,
183184
torchscript=True,
184185
)
185-
model = BertModel(config)
186+
model_kwargs = {
187+
"use_cache": False,
188+
"output_attentions": False,
189+
"output_hidden_states": False,
190+
"torchscript": True,
191+
}
192+
model = BertModel.from_pretrained(model_name, config=config, **model_kwargs)
186193
model.eval()
187-
model = BertModel.from_pretrained(model_name, torchscript=True)
188194
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
189195
return traced_model

0 commit comments

Comments
 (0)