5 changes: 5 additions & 0 deletions README.md
@@ -123,13 +123,18 @@ all these keys are optional, you will likely need at least one of these, though
# For OpenAI models
export OPENAI_API_KEY=your_api_key_here

# For Azure OpenAI (via OpenAI provider)
export AZURE_OPENAI_API_KEY=your_api_key_here

# For Anthropic models
export ANTHROPIC_API_KEY=your_api_key_here

# For Gemini models
export GOOGLE_API_KEY=your_api_key_here
```

> Azure OpenAI configuration details are in the [USAGE](./docs/USAGE.md#azure-openai-via-the-openai-provider) docs

### Run Agents!

```bash
# ...
```
23 changes: 23 additions & 0 deletions docs/USAGE.md
@@ -162,6 +162,29 @@ models:
model: ai/qwen3
```

#### Azure OpenAI (via the OpenAI provider)

When using Azure OpenAI:
- **base_url**: use your resource root only, without the `/openai/...` path. Example: `https://YOUR_RESOURCE.openai.azure.com/`.
- **model**: set to your desired model name. If `provider_opts.azure_deployment_name` is not specified, this will also be used as your Azure deployment name.
- **auth**: by default we use `AZURE_OPENAI_API_KEY` when the `base_url` points to Azure. You can also set `token_key: SOME_CUSTOM_NAME` explicitly.
- **API version**: set `provider_opts.azure_api_version` to the API version you want to use (e.g., `2024-10-21`).
- **deployment_name**: if your deployment name differs from the model name, set `provider_opts.azure_deployment_name`.

```yaml
models:
  azure-model:
    provider: openai
    model: gpt-4.1-mini
    base_url: https://YOUR_RESOURCE.openai.azure.com/
    token_key: CUSTOM_AZURE_OPENAI_API_KEY # only needed if you store the API key in an env var other than AZURE_OPENAI_API_KEY
    provider_opts:
      azure_api_version: 2024-12-01-preview # required
      # Optional: set this if your deployment name differs from the model name
      azure_deployment_name: my-gpt-4-1-mini-deployment
```
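
For orientation, these options map onto the underlying OpenAI Go client's Azure configuration roughly as sketched below. This is a simplified, hypothetical example (assuming the `github.com/sashabaranov/go-openai` client that the provider code appears to use), not cagent's actual implementation; the literal values are the ones from the YAML above.

```go
package main

import (
	"fmt"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// base_url: the resource root only, with no /openai/... path appended.
	cfg := openai.DefaultAzureConfig(
		os.Getenv("AZURE_OPENAI_API_KEY"), // or the env var named by token_key
		"https://YOUR_RESOURCE.openai.azure.com",
	)

	// provider_opts.azure_api_version (required for Azure).
	cfg.APIVersion = "2024-12-01-preview"

	// provider_opts.azure_deployment_name: Azure routes requests by deployment,
	// so the model name is mapped to the deployment name; if omitted, the model
	// name itself is used.
	cfg.AzureModelMapperFunc = func(model string) string {
		return "my-gpt-4-1-mini-deployment"
	}

	client := openai.NewClientWithConfig(cfg)
	fmt.Printf("Azure OpenAI client configured: %T\n", client)
}
```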

#### DMR (Docker Model Runner) provider usage

If `base_url` is omitted, cagent will use `http://localhost:12434/engines/llama.cpp/v1` by default
54 changes: 50 additions & 4 deletions pkg/model/provider/openai/client.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"

@@ -49,16 +50,53 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
	if gateway := globalOptions.Gateway(); gateway == "" {
		key := cfg.TokenKey
		if key == "" {
			if isAzureBaseURL(cfg.BaseURL) {
				key = "AZURE_OPENAI_API_KEY"
			} else {
				key = "OPENAI_API_KEY"
			}
		}
		authToken := env.Get(ctx, key)
		if authToken == "" {
			if key == cfg.TokenKey {
				return nil, fmt.Errorf("%s key configured in model config is required", key)
			}
			if isAzureBaseURL(cfg.BaseURL) {
				return nil, errors.New("AZURE_OPENAI_API_KEY environment variable is required")
			}
			return nil, errors.New("OPENAI_API_KEY environment variable is required")
		}

		// Configure Azure vs standard OpenAI
		if cfg.BaseURL != "" && isAzureBaseURL(cfg.BaseURL) {
			azureBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"))
			openaiConfig = openai.DefaultAzureConfig(authToken, azureBase)

			// api-version is required by Azure; enforce presence via provider_opts.azure_api_version
			apiVersion, ok := cfg.ProviderOpts["azure_api_version"].(string)
			if !ok || apiVersion == "" {
				return nil, errors.New("provider_opts.azure_api_version is required for Azure OpenAI")
			}
			openaiConfig.APIVersion = apiVersion

			// allow overriding the deployment name (Azure uses the deployment name in the URL path)
			if name, ok := cfg.ProviderOpts["azure_deployment_name"].(string); ok && name != "" {
				openaiConfig.AzureModelMapperFunc = func(string) string { return name }
			}
		} else {
			openaiConfig = openai.DefaultConfig(authToken)
			if cfg.BaseURL != "" {
				openaiConfig.BaseURL = cfg.BaseURL
			}
		}
	} else {
		authToken := desktop.GetToken(ctx)
@@ -102,6 +140,14 @@ func (c *Client) newGatewayClient(ctx context.Context) *openai.Client {
	return openai.NewClientWithConfig(cfg)
}
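
As context for the Azure configuration in NewClient above: with `DefaultAzureConfig`, requests are sent to a deployment-scoped path with `api-version` as a query parameter, which is why both `azure_deployment_name` and `azure_api_version` matter. A standalone sketch of that URL shape (illustrative only; `azureChatURL` is a made-up helper, and the exact URL construction lives inside the OpenAI client library):

```go
package main

import (
	"fmt"
	"strings"
)

// Illustrative only (not cagent code): the URL shape an Azure-configured client
// ends up calling — deployment name in the path, api-version as a query parameter.
func azureChatURL(resourceRoot, deployment, apiVersion string) string {
	return fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s",
		strings.TrimRight(resourceRoot, "/"), deployment, apiVersion)
}

func main() {
	fmt.Println(azureChatURL(
		"https://YOUR_RESOURCE.openai.azure.com",
		"my-gpt-4-1-mini-deployment",
		"2024-12-01-preview",
	))
	// => https://YOUR_RESOURCE.openai.azure.com/openai/deployments/my-gpt-4-1-mini-deployment/chat/completions?api-version=2024-12-01-preview
}
```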

// isAzureBaseURL returns true if the base URL points to Azure OpenAI
func isAzureBaseURL(u string) bool {
	if u == "" {
		return false
	}
	return strings.Contains(strings.ToLower(u), ".openai.azure.com")
}

Review discussion on isAzureBaseURL:

Contributor (author): Alternatively to this, we could maybe use another provider_opts key such as `azure: true`:

    models:
      azure:
        provider: openai
        ...
        provider_opts:
          azure: true
          ...

Reply: I could imagine that at some point Azure OpenAI would support custom domains. Then it would be useful to have the option key. Or maybe react to `azure_api_version`? If that is set, it has to be Azure OpenAI, and vice versa.

Reply: Testing a bit more, I now indeed got an endpoint that ends with `.cognitiveservices.azure.com`. That doesn't work with the currently proposed implementation.
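
A minimal sketch of the alternative detection discussed in the comments above — not part of this change. It assumes a hypothetical `azure: true` key in `provider_opts` and also accepts `.cognitiveservices.azure.com` hosts; `isAzureModel` is a made-up name, and the snippet presumes it lives in this package (which already has `strings` and the `latest` config types in scope):

```go
// Hypothetical alternative to isAzureBaseURL (not part of this PR): honor an
// explicit provider_opts flag and accept both Azure endpoint flavors.
func isAzureModel(cfg *latest.ModelConfig) bool {
	if enabled, ok := cfg.ProviderOpts["azure"].(bool); ok {
		return enabled
	}
	u := strings.ToLower(cfg.BaseURL)
	return strings.Contains(u, ".openai.azure.com") ||
		strings.Contains(u, ".cognitiveservices.azure.com")
}
```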

func convertMultiContent(multiContent []chat.MessagePart) []openai.ChatMessagePart {
	openaiMultiContent := make([]openai.ChatMessagePart, len(multiContent))
	for i, part := range multiContent {
		// ...
97 changes: 97 additions & 0 deletions pkg/model/provider/openai/client_test.go
@@ -0,0 +1,97 @@
package openai

import (
	"context"
	"testing"

	latest "github.com/docker/cagent/pkg/config/v2"
	"github.com/docker/cagent/pkg/environment"
)

type mapEnv struct{ m map[string]string }

func (e mapEnv) Get(ctx context.Context, name string) string { return e.m[name] }

func TestIsAzureBaseURL(t *testing.T) {
	cases := []struct {
		in   string
		want bool
	}{
		{"", false},
		{"https://api.openai.com/v1", false},
		{"https://myres.openai.azure.com/", true},
		{"HTTPs://MYRES.OpenAI.AZURE.COM/openai/v1/", true},
		{"http://localhost:8080/v1", false},
	}
	for _, tc := range cases {
		if got := isAzureBaseURL(tc.in); got != tc.want {
			t.Fatalf("isAzureBaseURL(%q)=%v, want %v", tc.in, got, tc.want)
		}
	}
}

func TestNewClient_AzureRequiresAPIVersion(t *testing.T) {
	ctx := context.Background()
	cfg := &latest.ModelConfig{
		Provider: "openai",
		Model:    "my-deployment", // deployment name for Azure
		BaseURL:  "https://myres.openai.azure.com/",
	}
	env := mapEnv{m: map[string]string{
		"AZURE_OPENAI_API_KEY": "azure-key",
		// OPENAI_API_KEY intentionally missing to exercise fallback
	}}
	c, err := NewClient(ctx, cfg, environment.Provider(env))
	if err == nil || c != nil {
		t.Fatalf("expected error due to missing azure_api_version, got client=%v, err=%v", c, err)
	}
}

func TestNewClient_AzureWithAPIVersionSucceeds(t *testing.T) {
	ctx := context.Background()
	cfg := &latest.ModelConfig{
		Provider: "openai",
		Model:    "my-deployment",
		BaseURL:  "https://myres.openai.azure.com/",
		ProviderOpts: map[string]any{
			"azure_api_version": "2024-10-21",
		},
	}
	env := mapEnv{m: map[string]string{
		"AZURE_OPENAI_API_KEY": "azure-key",
	}}
	c, err := NewClient(ctx, cfg, environment.Provider(env))
	if err != nil || c == nil {
		t.Fatalf("expected client without error, got client=%v, err=%v", c, err)
	}
}

func TestNewClient_AzureMissingKeysReturnsError(t *testing.T) {
	ctx := context.Background()
	cfg := &latest.ModelConfig{
		Provider: "openai",
		Model:    "my-deployment",
		BaseURL:  "https://myres.openai.azure.com/",
	}
	env := mapEnv{m: map[string]string{}}
	c, err := NewClient(ctx, cfg, environment.Provider(env))
	if err == nil || c != nil {
		t.Fatalf("expected error due to missing keys, got client=%v, err=%v", c, err)
	}
}

func TestNewClient_NonAzureUsesOpenAIKey(t *testing.T) {
	ctx := context.Background()
	cfg := &latest.ModelConfig{
		Provider: "openai",
		Model:    "gpt-4o",
		// BaseURL left empty to use default OpenAI
	}
	env := mapEnv{m: map[string]string{
		"OPENAI_API_KEY": "openai-key",
	}}
	c, err := NewClient(ctx, cfg, environment.Provider(env))
	if err != nil || c == nil {
		t.Fatalf("expected client without error, got client=%v, err=%v", c, err)
	}
}
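
One branch the tests above don't cover is a custom `token_key` whose environment variable is unset. A sketch of such a test, reusing the same `mapEnv` helper (hypothetical addition, not part of this PR; `MY_CUSTOM_OPENAI_KEY` is a made-up name):

```go
func TestNewClient_CustomTokenKeyMissingReturnsError(t *testing.T) {
	ctx := context.Background()
	cfg := &latest.ModelConfig{
		Provider: "openai",
		Model:    "gpt-4o",
		TokenKey: "MY_CUSTOM_OPENAI_KEY", // explicitly configured, but not set in the env
	}
	env := mapEnv{m: map[string]string{}}
	c, err := NewClient(ctx, cfg, environment.Provider(env))
	if err == nil || c != nil {
		t.Fatalf("expected error for missing MY_CUSTOM_OPENAI_KEY, got client=%v, err=%v", c, err)
	}
}
```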