# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

+import asyncio
import os
-from typing import List, Union

-import torch
-import torch.nn as nn
-from einops import rearrange
-from transformers import AutoProcessor, AutoTokenizer, CLIPModel
+import requests

from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry, ServiceType
from comps.cores.proto.api_protocol import EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData

logflag = os.getenv("LOGFLAG", False)


-model_name = "openai/clip-vit-base-patch32"
-
-clip = CLIPModel.from_pretrained(model_name)
-processor = AutoProcessor.from_pretrained(model_name)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-
-class vCLIP(nn.Module):
-    def __init__(self, cfg):
-        super().__init__()
-
-        self.num_frm = cfg["num_frm"]
-        self.model_name = cfg["model_name"]
-
-    def embed_query(self, texts):
-        """Input is list of texts."""
-        text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
-        text_features = clip.get_text_features(**text_inputs)
-        return text_features
-
-    def get_embedding_length(self):
-        text_features = self.embed_query("sample_text")
-        return text_features.shape[1]
-
-    def get_image_embeddings(self, images):
-        """Input is list of images."""
-        image_inputs = processor(images=images, return_tensors="pt")
-        image_features = clip.get_image_features(**image_inputs)
-        return image_features
-
-    def get_video_embeddings(self, frames_batch):
-        """Input is list of list of frames in video."""
-        self.batch_size = len(frames_batch)
-        vid_embs = []
-        for frames in frames_batch:
-            frame_embeddings = self.get_image_embeddings(frames)
-            frame_embeddings = rearrange(frame_embeddings, "(b n) d -> b n d", b=len(frames_batch))
-            # Normalize, mean aggregate and return normalized video_embeddings
-            frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True)
-            video_embeddings = frame_embeddings.mean(dim=1)
-            video_embeddings = video_embeddings / video_embeddings.norm(dim=-1, keepdim=True)
-            vid_embs.append(video_embeddings)
-        return torch.cat(vid_embs, dim=0)
-
-

@OpeaComponentRegistry.register("OPEA_CLIP_EMBEDDING")
class OpeaClipEmbedding(OpeaComponent):
    """A specialized embedding component derived from OpeaComponent for CLIP embedding services.
@@ -74,7 +26,7 @@ class OpeaClipEmbedding(OpeaComponent):

    def __init__(self, name: str, description: str, config: dict = None):
        super().__init__(name, ServiceType.EMBEDDING.name.lower(), description, config)
-        self.embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 4})
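+        # Base URL of the remote CLIP embedding service; override via the CLIP_EMBEDDING_ENDPOINT environment variable.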
+        self.base_url = os.getenv("CLIP_EMBEDDING_ENDPOINT", "http://localhost:6990")

        health_status = self.check_health()
        if not health_status:
@@ -89,46 +41,38 @@ async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
        Returns:
            EmbeddingResponse: The response in OpenAI embedding format, including embeddings, model, and usage information.
        """
-        # Parse input according to the EmbeddingRequest format
-        if isinstance(input.input, str):
-            texts = [input.input.replace("\n", " ")]
-        elif isinstance(input.input, list):
-            if all(isinstance(item, str) for item in input.input):
-                texts = [text.replace("\n", " ") for text in input.input]
-            else:
-                raise ValueError("Invalid input format: Only string or list of strings are supported.")
-        else:
-            raise TypeError("Unsupported input type: input must be a string or list of strings.")
-        embed_vector = self.get_embeddings(texts)
-        if input.dimensions is not None:
-            embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]
-
-        # for standard openai embedding format
-        res = EmbeddingResponse(
-            data=[EmbeddingResponseData(index=i, embedding=embed_vector[i]) for i in range(len(embed_vector))]
-        )
-        return res
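+        # Serialize the incoming EmbeddingRequest and forward it to the remote service unchanged.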
+        json_payload = input.model_dump()
+        try:
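+            # requests.post is blocking; asyncio.to_thread runs it in a worker thread so the event loop stays free.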
+            response = await asyncio.to_thread(
+                requests.post,
+                f"{self.base_url}/v1/embeddings",
+                headers={"Content-Type": "application/json"},
+                json=json_payload,
+            )
+            response.raise_for_status()
+            response_json = response.json()
+
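+            # Re-wrap the service's JSON reply in the OpenAI-style EmbeddingResponse schema.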
+            return EmbeddingResponse(
+                data=[EmbeddingResponseData(**item) for item in response_json.get("data", [])],
+                model=response_json.get("model", input.model),
+                usage=response_json.get("usage", {}),
+            )
+        except requests.RequestException as e:
+            raise RuntimeError(f"Failed to invoke embedding service: {str(e)}")

    def check_health(self) -> bool:
-        """Checks if the embedding model is healthy.
+        """Checks if the embedding service is healthy.

        Returns:
-            bool: True if the embedding model is initialized, False otherwise.
+            bool: True if the embedding service is reachable, False otherwise.
        """
-        if self.embeddings:
+        try:
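+            # Probe the live endpoint with a minimal payload; any transport error marks the service unhealthy.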
+            _ = requests.post(
+                f"{self.base_url}/v1/embeddings",
+                headers={"Content-Type": "application/json"},
+                json={"input": "health check"},
+            )
+
            return True
-        else:
+        except requests.RequestException as e:
            return False
-
-    def get_embeddings(self, text: Union[str, List[str]]) -> List[List[float]]:
-        """Generates embeddings for input text.
-
-        Args:
-            text (Union[str, List[str]]): Input text or list of texts.
-
-        Returns:
-            List[List[float]]: List of embedding vectors.
-        """
-        texts = [text] if isinstance(text, str) else text
-        embed_vector = self.embeddings.embed_query(texts).tolist()
-        return embed_vector
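
For context, a minimal sketch of a client call against the endpoint this component now proxies; the port, path, and payload shape follow the defaults visible in the diff above, and the snippet is illustrative rather than part of the commit:

import requests

# Hypothetical smoke test; assumes the CLIP embedding server targeted by this
# change is running locally on the default port 6990.
resp = requests.post(
    "http://localhost:6990/v1/embeddings",
    headers={"Content-Type": "application/json"},
    json={"input": "a photo of a cat", "model": "openai/clip-vit-base-patch32"},
)
resp.raise_for_status()
vector = resp.json()["data"][0]["embedding"]
print(len(vector))  # clip-vit-base-patch32 projects to 512 dimensions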