Commit f164f0d

embeddings: adaptive detect embedding model arguments in mosec (#296)
* embeddings: adaptive detect embedding model arguments in mosec

  Signed-off-by: Jincheng Miao <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Signed-off-by: Jincheng Miao <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6b091c6 commit f164f0d

1 file changed: +7 -2 lines changed

comps/embeddings/langchain-mosec/mosec-docker/server-ipex.py

Lines changed: 7 additions & 2 deletions
@@ -34,9 +34,14 @@ def __init__(self):
         d = torch.randint(vocab_size, size=[batch_size, seq_length])
         t = torch.randint(0, 1, size=[batch_size, seq_length])
         m = torch.randint(1, 2, size=[batch_size, seq_length])
-        self.model = torch.jit.trace(self.model, [d, t, m], check_trace=False, strict=False)
+        model_inputs = [d]
+        if "token_type_ids" in self.tokenizer.model_input_names:
+            model_inputs.append(t)
+        if "attention_mask" in self.tokenizer.model_input_names:
+            model_inputs.append(m)
+        self.model = torch.jit.trace(self.model, model_inputs, check_trace=False, strict=False)
         self.model = torch.jit.freeze(self.model)
-        self.model(d, t, m)
+        self.model(*model_inputs)

     def get_embedding_with_token_count(self, sentences: Union[str, List[Union[str, List[int]]]]):
         # Mean Pooling - Take attention mask into account for correct averaging
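For context (not part of the commit), below is a minimal standalone sketch of the same adaptive-input idea: a tokenizer reports which tensors its model expects via model_input_names, so the trace inputs are built from that list instead of always passing input_ids, token_type_ids, and attention_mask. It assumes a Hugging Face tokenizer/model pair; the model name is purely illustrative and the dummy tensors are only used to record shapes during tracing.

from typing import List

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative model name (assumption); any HF embedding model would do.
model_name = "BAAI/bge-large-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

batch_size, seq_length = 2, 16
vocab_size = tokenizer.vocab_size

# Dummy tensors for the three possible model inputs (shapes only matter here).
d = torch.randint(vocab_size, size=[batch_size, seq_length])  # input_ids
t = torch.zeros([batch_size, seq_length], dtype=torch.long)   # token_type_ids
m = torch.ones([batch_size, seq_length], dtype=torch.long)    # attention_mask

# Only pass the tensors this tokenizer actually lists; e.g. DistilBERT-style
# tokenizers do not include "token_type_ids" in model_input_names.
model_inputs: List[torch.Tensor] = [d]
if "token_type_ids" in tokenizer.model_input_names:
    model_inputs.append(t)
if "attention_mask" in tokenizer.model_input_names:
    model_inputs.append(m)

with torch.no_grad():
    traced = torch.jit.trace(model, model_inputs, check_trace=False, strict=False)
    traced = torch.jit.freeze(traced)
    traced(*model_inputs)  # warm-up call with the same adaptive input list

The point of keying on tokenizer.model_input_names is that models which take fewer inputs (no token_type_ids, for example) can be traced and warmed up without a signature mismatch, while models that take all three still receive them.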
