Skip to content

Commit 8b5a951

Browse files
committed
Support for azure openai via the standard openai provider
closes #112 Signed-off-by: Christopher Petito <[email protected]>
1 parent 0f23870 commit 8b5a951

File tree

4 files changed

+175
-4
lines changed

4 files changed

+175
-4
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,18 @@ all theses keys are optional, you will likely need at least one of these, though
123123
# For OpenAI models
124124
export OPENAI_API_KEY=your_api_key_here
125125

126+
# For Azure OpenAI (via OpenAI provider)
127+
export AZURE_OPENAI_API_KEY=your_api_key_here
128+
126129
# For Anthropic models
127130
export ANTHROPIC_API_KEY=your_api_key_here
128131

129132
# For Gemini models
130133
export GOOGLE_API_KEY=your_api_key_here
131134
```
132135

136+
> Azure OpenAI configuration details are in the [USAGE](./docs/USAGE.md#azure-openai-via-the-openai-provider) docs
137+
133138
### Run Agents!
134139

135140
```bash

docs/USAGE.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,29 @@ models:
162162
model: ai/qwen3
163163
```
164164

165+
#### Azure OpenAI (via the OpenAI provider)
166+
167+
When using Azure OpenAI:
168+
- **base_url**: use your resource root only, without the `/openai/...` path. Example: `https://YOUR_RESOURCE.openai.azure.com/`.
169+
- **model**: set to your desired model name. If `provider_opts.azure_deployment_name` is not specified, this will also be used as your Azure deployment name.
170+
- **auth**: by default we use `AZURE_OPENAI_API_KEY` when the `base_url` points to Azure. You can also set `token_key: SOME_CUSTOM_NAME` explicitly.
171+
- **API version**: set `provider_opts.azure_api_version` to the model's API version you are using (e.g., `2024-10-21`).
172+
- **deployment_name**: If you deployment name differs from the model name, set `provider_opts.azure_deployment_name`.
173+
174+
```yaml
175+
models:
176+
azure-model:
177+
provider: openai
178+
model: gpt-4.1-mini
179+
base_url: https://YOUR_RESOURCE.openai.azure.com/
180+
token_key: CUSTOM_AZURE_OPENAI_API_KEY # only needed if you store the api key in an env var that's not AZURE_OPENAI_API_KEY
181+
provider_opts:
182+
azure_api_version: 2024-12-01-preview # required
183+
# Optional: if your deployment name differs from the model name, set the deployment name here instead
184+
azure_deployment_name: my-gpt-4-1-mini-deployment
185+
186+
```
187+
165188
#### DMR (Docker Model Runner) provider usage
166189

167190
If `base_url` is omitted, cagent will use `http://localhost:12434/engines/llama.cpp/v1` by default

pkg/model/provider/openai/client.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7+
"fmt"
78
"log/slog"
89
"strings"
910

@@ -49,16 +50,53 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
4950
if gateway := globalOptions.Gateway(); gateway == "" {
5051
key := cfg.TokenKey
5152
if key == "" {
52-
key = "OPENAI_API_KEY"
53+
if isAzureBaseURL(cfg.BaseURL) {
54+
key = "AZURE_OPENAI_API_KEY"
55+
} else {
56+
key = "OPENAI_API_KEY"
57+
}
5358
}
5459
authToken := env.Get(ctx, key)
5560
if authToken == "" {
61+
if key == cfg.TokenKey {
62+
errMsg := fmt.Sprintf("%s key configured in model config is required", key)
63+
return nil, errors.New(errMsg)
64+
}
65+
if isAzureBaseURL(cfg.BaseURL) {
66+
return nil, errors.New("AZURE_OPENAI_API_KEY environment variable is required")
67+
}
5668
return nil, errors.New("OPENAI_API_KEY environment variable is required")
5769
}
5870

59-
openaiConfig = openai.DefaultConfig(authToken)
60-
if cfg.BaseURL != "" {
61-
openaiConfig.BaseURL = cfg.BaseURL
71+
// Configure Azure vs standard OpenAI
72+
if cfg.BaseURL != "" && isAzureBaseURL(cfg.BaseURL) {
73+
azureBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
74+
openaiConfig = openai.DefaultAzureConfig(authToken, azureBase)
75+
// api-version is required by Azure; enforce presence via provider_opts.azure_api_version
76+
if cfg.ProviderOpts == nil {
77+
return nil, errors.New("provider_opts.azure_api_version is required for Azure OpenAI")
78+
}
79+
if v, ok := cfg.ProviderOpts["azure_api_version"]; !ok {
80+
return nil, errors.New("provider_opts.azure_api_version is required for Azure OpenAI")
81+
} else {
82+
if s, ok := v.(string); !ok || s == "" {
83+
return nil, errors.New("provider_opts.azure_api_version is required for Azure OpenAI")
84+
} else {
85+
openaiConfig.APIVersion = s
86+
}
87+
}
88+
// allow overriding deployment name (Azure uses deployment name in URL path)
89+
if v, ok := cfg.ProviderOpts["azure_deployment_name"]; ok {
90+
if s, ok := v.(string); ok && s != "" {
91+
name := s
92+
openaiConfig.AzureModelMapperFunc = func(string) string { return name }
93+
}
94+
}
95+
} else {
96+
openaiConfig = openai.DefaultConfig(authToken)
97+
if cfg.BaseURL != "" {
98+
openaiConfig.BaseURL = cfg.BaseURL
99+
}
62100
}
63101
} else {
64102
authToken := desktop.GetToken(ctx)
@@ -102,6 +140,14 @@ func (c *Client) newGatewayClient(ctx context.Context) *openai.Client {
102140
return openai.NewClientWithConfig(cfg)
103141
}
104142

143+
// isAzureBaseURL returns true if the base URL points to Azure OpenAI
144+
func isAzureBaseURL(u string) bool {
145+
if u == "" {
146+
return false
147+
}
148+
return strings.Contains(strings.ToLower(u), ".openai.azure.com")
149+
}
150+
105151
func convertMultiContent(multiContent []chat.MessagePart) []openai.ChatMessagePart {
106152
openaiMultiContent := make([]openai.ChatMessagePart, len(multiContent))
107153
for i, part := range multiContent {
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
latest "github.com/docker/cagent/pkg/config/v2"
8+
"github.com/docker/cagent/pkg/environment"
9+
)
10+
11+
type mapEnv struct{ m map[string]string }
12+
13+
func (e mapEnv) Get(ctx context.Context, name string) string { return e.m[name] }
14+
15+
func TestIsAzureBaseURL(t *testing.T) {
16+
cases := []struct {
17+
in string
18+
want bool
19+
}{
20+
{"", false},
21+
{"https://api.openai.com/v1", false},
22+
{"https://myres.openai.azure.com/", true},
23+
{"HTTPs://MYRES.OpenAI.AZURE.COM/openai/v1/", true},
24+
{"http://localhost:8080/v1", false},
25+
}
26+
for _, tc := range cases {
27+
if got := isAzureBaseURL(tc.in); got != tc.want {
28+
t.Fatalf("isAzureBaseURL(%q)=%v, want %v", tc.in, got, tc.want)
29+
}
30+
}
31+
}
32+
33+
func TestNewClient_AzureRequiresAPIVersion(t *testing.T) {
34+
ctx := context.Background()
35+
cfg := &latest.ModelConfig{
36+
Provider: "openai",
37+
Model: "my-deployment", // deployment name for Azure
38+
BaseURL: "https://myres.openai.azure.com/",
39+
}
40+
env := mapEnv{m: map[string]string{
41+
"AZURE_OPENAI_API_KEY": "azure-key",
42+
// OPENAI_API_KEY intentionally missing to exercise fallback
43+
}}
44+
c, err := NewClient(ctx, cfg, environment.Provider(env))
45+
if err == nil || c != nil {
46+
t.Fatalf("expected error due to missing azure_api_version, got client=%v, err=%v", c, err)
47+
}
48+
}
49+
50+
func TestNewClient_AzureWithAPIVersionSucceeds(t *testing.T) {
51+
ctx := context.Background()
52+
cfg := &latest.ModelConfig{
53+
Provider: "openai",
54+
Model: "my-deployment",
55+
BaseURL: "https://myres.openai.azure.com/",
56+
ProviderOpts: map[string]any{
57+
"azure_api_version": "2024-10-21",
58+
},
59+
}
60+
env := mapEnv{m: map[string]string{
61+
"AZURE_OPENAI_API_KEY": "azure-key",
62+
}}
63+
c, err := NewClient(ctx, cfg, environment.Provider(env))
64+
if err != nil || c == nil {
65+
t.Fatalf("expected client without error, got client=%v, err=%v", c, err)
66+
}
67+
}
68+
69+
func TestNewClient_AzureMissingKeysReturnsError(t *testing.T) {
70+
ctx := context.Background()
71+
cfg := &latest.ModelConfig{
72+
Provider: "openai",
73+
Model: "my-deployment",
74+
BaseURL: "https://myres.openai.azure.com/",
75+
}
76+
env := mapEnv{m: map[string]string{}}
77+
c, err := NewClient(ctx, cfg, environment.Provider(env))
78+
if err == nil || c != nil {
79+
t.Fatalf("expected error due to missing keys, got client=%v, err=%v", c, err)
80+
}
81+
}
82+
83+
func TestNewClient_NonAzureUsesOpenAIKey(t *testing.T) {
84+
ctx := context.Background()
85+
cfg := &latest.ModelConfig{
86+
Provider: "openai",
87+
Model: "gpt-4o",
88+
// BaseURL left empty to use default OpenAI
89+
}
90+
env := mapEnv{m: map[string]string{
91+
"OPENAI_API_KEY": "openai-key",
92+
}}
93+
c, err := NewClient(ctx, cfg, environment.Provider(env))
94+
if err != nil || c == nil {
95+
t.Fatalf("expected client without error, got client=%v, err=%v", c, err)
96+
}
97+
}

0 commit comments

Comments
 (0)