Skip to content

Commit 22549ee

Browse files
authored
修复vertexai模型无法通过配置的http代理正确访问的问题 (#844)
1 parent a292eac commit 22549ee

File tree

1 file changed

+78
-13
lines changed

1 file changed

+78
-13
lines changed

providers/vertexai/base.go

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,20 @@ type VertexAIProviderFactory struct{}
3333

3434
// 创建 VertexAIProvider
3535
func (f VertexAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
36+
proxyAddr := ""
37+
if channel.Proxy != nil {
38+
proxyAddr = *channel.Proxy
39+
}
40+
3641
vertexAIProvider := &VertexAIProvider{
3742
BaseProvider: base.BaseProvider{
3843
Config: getConfig(),
3944
Channel: channel,
40-
Requester: requester.NewHTTPRequester(*channel.Proxy, nil),
45+
Requester: requester.NewHTTPRequester(proxyAddr, nil),
4146
},
4247
}
4348

4449
getKeyConfig(vertexAIProvider)
45-
4650
return vertexAIProvider
4751
}
4852

@@ -109,16 +113,19 @@ func (p *VertexAIProvider) GetToken() (string, error) {
109113
return "", fmt.Errorf("failed to unmarshal credentials: %w", err)
110114
}
111115

112-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
116+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
113117
defer cancel()
114118

115119
proxyAddr := ""
116120
if p.Channel.Proxy != nil && *p.Channel.Proxy != "" {
117121
proxyAddr = *p.Channel.Proxy
122+
logger.SysLog(fmt.Sprintf("Vertex AI proxy: %s", proxyAddr))
118123
}
119124

125+
// 尝试使用gRPC客户端获取token
120126
client, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(p.Channel.Key)), option.WithGRPCDialOption(grpc.WithContextDialer(customDialer(proxyAddr))))
121127
if err != nil {
128+
logger.SysError(fmt.Sprintf("Failed to create IAM credentials client: %v", err))
122129
return "", fmt.Errorf("failed to create IAM credentials client: %w", err)
123130
}
124131
defer client.Close()
@@ -172,8 +179,6 @@ func errorHandle(vertexaiError *VertexaiError) *types.OpenAIError {
172179
return nil
173180
}
174181

175-
logger.SysError(fmt.Sprintf("VertexAI error: %s", vertexaiError.Error.Message))
176-
177182
return &types.OpenAIError{
178183
Message: "VertexAI错误",
179184
Type: "gemini_error",
@@ -183,24 +188,84 @@ func errorHandle(vertexaiError *VertexaiError) *types.OpenAIError {
183188
}
184189

185190
func customDialer(proxyAddr string) func(context.Context, string) (net.Conn, error) {
186-
187191
return func(ctx context.Context, addr string) (net.Conn, error) {
192+
// 创建统一的dialer配置
193+
dialer := &net.Dialer{
194+
Timeout: 20 * time.Second,
195+
KeepAlive: 30 * time.Second,
196+
}
197+
198+
// 无代理直接连接
188199
if proxyAddr == "" {
189-
return net.Dial("tcp", addr)
200+
return dialer.DialContext(ctx, "tcp", addr)
190201
}
191202

192203
proxyURL, err := url.Parse(proxyAddr)
193204
if err != nil {
194205
return nil, fmt.Errorf("error parsing proxy address: %w", err)
195206
}
196207

197-
dialer := &net.Dialer{}
198-
199-
dialerProxy, err := proxy.FromURL(proxyURL, dialer)
200-
if err != nil {
201-
return nil, fmt.Errorf("failed to create HTTP dialer: %v", err)
208+
// 根据代理类型选择连接方式
209+
switch proxyURL.Scheme {
210+
case "http":
211+
return connectViaHTTPProxy(ctx, proxyURL, addr)
212+
case "https":
213+
logger.SysError("Warning: HTTPS proxy not compatible with gRPC, using direct connection")
214+
return dialer.DialContext(ctx, "tcp", addr)
215+
case "socks5", "socks5h":
216+
return connectViaSOCKS5Proxy(ctx, dialer, proxyURL, addr)
217+
default:
218+
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
202219
}
220+
}
221+
}
222+
223+
// connectViaHTTPProxy 通过HTTP代理建立连接
224+
func connectViaHTTPProxy(ctx context.Context, proxyURL *url.URL, targetAddr string) (net.Conn, error) {
225+
dialer := &net.Dialer{
226+
Timeout: 20 * time.Second,
227+
KeepAlive: 30 * time.Second,
228+
}
229+
230+
proxyConn, err := dialer.DialContext(ctx, "tcp", proxyURL.Host)
231+
if err != nil {
232+
return nil, fmt.Errorf("failed to connect to HTTP proxy: %w", err)
233+
}
234+
235+
// 发送HTTP CONNECT请求
236+
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr)
237+
if _, err = proxyConn.Write([]byte(connectReq)); err != nil {
238+
proxyConn.Close()
239+
return nil, fmt.Errorf("failed to send CONNECT request: %w", err)
240+
}
203241

204-
return dialerProxy.Dial("tcp", addr)
242+
// 读取代理响应
243+
response := make([]byte, 1024)
244+
n, err := proxyConn.Read(response)
245+
if err != nil {
246+
proxyConn.Close()
247+
return nil, fmt.Errorf("failed to read proxy response: %w", err)
248+
}
249+
250+
responseStr := string(response[:n])
251+
if !strings.Contains(responseStr, "200 Connection established") && !strings.Contains(responseStr, "200 OK") {
252+
proxyConn.Close()
253+
return nil, fmt.Errorf("HTTP proxy CONNECT failed: %s", responseStr)
205254
}
255+
256+
return proxyConn, nil
257+
}
258+
259+
// connectViaSOCKS5Proxy 通过SOCKS5代理建立连接
260+
func connectViaSOCKS5Proxy(ctx context.Context, dialer *net.Dialer, proxyURL *url.URL, addr string) (net.Conn, error) {
261+
dialerProxy, err := proxy.FromURL(proxyURL, dialer)
262+
if err != nil {
263+
return nil, fmt.Errorf("failed to create SOCKS5 proxy dialer: %v", err)
264+
}
265+
266+
if contextDialer, ok := dialerProxy.(proxy.ContextDialer); ok {
267+
return contextDialer.DialContext(ctx, "tcp", addr)
268+
}
269+
270+
return dialerProxy.Dial("tcp", addr)
206271
}

0 commit comments

Comments
 (0)