Skip to content

Commit d17a49a

Browse files
committed
wip: openai-like providers
1 parent d5914fd commit d17a49a

File tree

9 files changed

+289
-12
lines changed

9 files changed

+289
-12
lines changed

.github/workflows/sphinx.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
push:
55
branches: [master]
66
paths:
7-
- 'src/**'
7+
- 'agentle/**'
88
- 'docs/**'
99
workflow_dispatch:
1010

agentle/generations/providers/openai/__init__.py

Whitespace-only changes.

agentle/generations/providers/openai/adapters/__init__.py

Whitespace-only changes.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from rsb.adapters.adapter import Adapter
6+
7+
from agentle.generations.models.messages.assistant_message import AssistantMessage
8+
from agentle.generations.models.messages.developer_message import DeveloperMessage
9+
from agentle.generations.models.messages.user_message import UserMessage
10+
11+
if TYPE_CHECKING:
12+
from openai.types.chat.chat_completion_assistant_message_param import (
13+
ChatCompletionAssistantMessageParam,
14+
)
15+
from openai.types.chat.chat_completion_developer_message_param import (
16+
ChatCompletionDeveloperMessageParam,
17+
)
18+
from openai.types.chat.chat_completion_user_message_param import (
19+
ChatCompletionUserMessageParam,
20+
)
21+
22+
23+
class AgentleMessageToOpenaiMessageAdapter(
24+
Adapter[
25+
AssistantMessage | DeveloperMessage | UserMessage,
26+
"ChatCompletionDeveloperMessageParam | ChatCompletionUserMessageParam | ChatCompletionAssistantMessageParam",
27+
]
28+
):
29+
def adapt(
30+
self, _f: AssistantMessage | DeveloperMessage | UserMessage
31+
) -> (
32+
ChatCompletionAssistantMessageParam
33+
| ChatCompletionDeveloperMessageParam
34+
| ChatCompletionUserMessageParam
35+
):
36+
message = _f
37+
38+
match message:
39+
case AssistantMessage():
40+
...
41+
case DeveloperMessage():
42+
...
43+
case UserMessage():
44+
...
45+
46+
raise NotImplementedError
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
from typing import TYPE_CHECKING, override
5+
6+
from rsb.adapters.adapter import Adapter
7+
8+
from agentle.generations.models.message_parts.file import FilePart
9+
from agentle.generations.models.message_parts.text import TextPart
10+
11+
if TYPE_CHECKING:
12+
from openai.types.chat.chat_completion_content_part_image_param import (
13+
ChatCompletionContentPartImageParam,
14+
)
15+
from openai.types.chat.chat_completion_content_part_input_audio_param import (
16+
ChatCompletionContentPartInputAudioParam,
17+
)
18+
from openai.types.chat.chat_completion_content_part_text_param import (
19+
ChatCompletionContentPartTextParam,
20+
)
21+
22+
23+
class AgentlePartToOpenaiPartAdapter(
24+
Adapter[
25+
TextPart | FilePart,
26+
"ChatCompletionContentPartImageParam | ChatCompletionContentPartInputAudioParam | ChatCompletionContentPartTextParam",
27+
]
28+
):
29+
@override
30+
def adapt(
31+
self, _f: TextPart | FilePart
32+
) -> (
33+
ChatCompletionContentPartTextParam
34+
| ChatCompletionContentPartImageParam
35+
| ChatCompletionContentPartInputAudioParam
36+
):
37+
part = _f
38+
39+
match part:
40+
case TextPart():
41+
return ChatCompletionContentPartTextParam(text=part.text, type="text")
42+
case FilePart():
43+
mime_type = part.mime_type
44+
if mime_type.startswith("image/"):
45+
return ChatCompletionContentPartImageParam(
46+
image_url={
47+
"url": base64.b64encode(part.data).decode(),
48+
"detail": "auto",
49+
},
50+
type="image_url",
51+
)
52+
elif mime_type.startswith("audio/"):
53+
return ChatCompletionContentPartInputAudioParam(
54+
input_audio={
55+
"data": base64.b64encode(part.data).decode(),
56+
"format": "mp3",
57+
},
58+
type="input_audio",
59+
)
60+
else:
61+
raise ValueError(f"Unsupported file type: {mime_type}")
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Mapping
4+
from typing import TYPE_CHECKING, Any, Literal, Sequence, cast, override
5+
6+
import httpx
7+
8+
from agentle.generations.models.generation.generation import Generation
9+
from agentle.generations.models.generation.generation_config import GenerationConfig
10+
from agentle.generations.models.messages.assistant_message import AssistantMessage
11+
from agentle.generations.models.messages.developer_message import DeveloperMessage
12+
from agentle.generations.models.messages.user_message import UserMessage
13+
from agentle.generations.pricing.price_retrievable import PriceRetrievable
14+
from agentle.generations.providers.base.generation_provider import GenerationProvider
15+
from agentle.generations.providers.openai.adapters.agentle_message_to_openai_message_adapter import (
16+
AgentleMessageToOpenaiMessageAdapter,
17+
)
18+
from agentle.generations.tools.tool import Tool
19+
20+
# `None` as the structured-output type parameter means "no schema": the
# generation yields plain text rather than a parsed model instance.
type WithoutStructuredOutput = None
21+
22+
23+
if TYPE_CHECKING:
24+
from openai._types import NotGiven
25+
26+
27+
class NotGivenSentinel:
    """Module-local falsy sentinel meaning "argument was not supplied".

    Lets ``timeout`` distinguish an omitted value from an explicit ``None``
    without importing the ``openai`` package at module import time.
    """

    def __bool__(self) -> Literal[False]:
        # Always falsy, mirroring the semantics of openai's own NotGiven.
        return False


# Single shared instance; compared by identity (``is``) where it is consumed.
NOT_GIVEN = NotGivenSentinel()
33+
34+
35+
class OpenaiGenerationProvider(GenerationProvider, PriceRetrievable):
36+
"""
37+
OpenAI generation provider.
38+
"""
39+
40+
api_key: str | None
41+
organization_name: str | None
42+
project_name: str | None
43+
base_url: str | httpx.URL | None
44+
websocket_base_url: str | httpx.URL | None
45+
timeout: float | httpx.Timeout | None | NotGiven
46+
max_retries: int
47+
default_headers: Mapping[str, str] | None
48+
default_query: Mapping[str, object] | None
49+
http_client: httpx.AsyncClient | None
50+
51+
def __init__(
52+
self,
53+
api_key: str,
54+
*,
55+
organization_name: str | None = None,
56+
project_name: str | None = None,
57+
base_url: str | httpx.URL | None = None,
58+
websocket_base_url: str | httpx.URL | None = None,
59+
timeout: float | httpx.Timeout | None | NotGiven | NotGivenSentinel = NOT_GIVEN,
60+
max_retries: int = 2,
61+
default_headers: Mapping[str, str] | None = None,
62+
default_query: Mapping[str, object] | None = None,
63+
http_client: httpx.AsyncClient | None = None,
64+
) -> None:
65+
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
66+
67+
if timeout is NOT_GIVEN:
68+
timeout = OPENAI_NOT_GIVEN
69+
70+
self.api_key = api_key
71+
self.organization_name = organization_name
72+
self.project_name = project_name
73+
self.base_url = base_url
74+
self.websocket_base_url = websocket_base_url
75+
self.timeout = cast(float | httpx.Timeout | None | NotGiven, timeout)
76+
self.max_retries = max_retries
77+
self.default_headers = default_headers
78+
self.default_query = default_query
79+
self.http_client = http_client
80+
81+
@override
82+
async def create_generation_async[T = WithoutStructuredOutput](
83+
self,
84+
*,
85+
model: str | None = None,
86+
messages: Sequence[AssistantMessage | DeveloperMessage | UserMessage],
87+
response_schema: type[T] | None = None,
88+
generation_config: GenerationConfig | None = None,
89+
tools: Sequence[Tool[Any]] | None = None,
90+
) -> Generation[T]:
91+
from openai import AsyncOpenAI
92+
93+
client = AsyncOpenAI(
94+
api_key=self.api_key,
95+
base_url=self.base_url,
96+
websocket_base_url=self.websocket_base_url,
97+
timeout=self.timeout,
98+
max_retries=self.max_retries,
99+
default_headers=self.default_headers,
100+
default_query=self.default_query,
101+
http_client=self.http_client,
102+
organization=self.organization_name,
103+
project=self.project_name,
104+
)
105+
106+
message_adapter = AgentleMessageToOpenaiMessageAdapter()
107+
108+
c = client.chat.completions.create( # type: ignore
109+
messages=[message_adapter.adapt(message) for message in messages],
110+
model=model or self.default_model,
111+
)
112+
113+
raise NotImplementedError("Not implemented yet.")
114+
115+
@property
116+
@override
117+
def default_model(self) -> str:
118+
return "gpt-4o"
119+
120+
@override
121+
def price_per_million_tokens_input(
122+
self, model: str, estimate_tokens: int | None = None
123+
) -> float:
124+
"""
125+
Get the price per million tokens for input/prompt tokens.
126+
"""
127+
return 0.0
128+
129+
@override
130+
def price_per_million_tokens_output(
131+
self, model: str, estimate_tokens: int | None = None
132+
) -> float:
133+
"""
134+
Get the price per million tokens for output/completion tokens.
135+
"""
136+
return 0.0

examples/tool_calling_and_structured_outputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from pydantic import BaseModel
9-
from typing import Any, List, Optional
9+
from typing import Any
1010
from agentle.agents.agent import Agent
1111
from agentle.generations.providers.google.google_genai_generation_provider import (
1212
GoogleGenaiGenerationProvider,
@@ -61,10 +61,10 @@ class TravelRecommendation(BaseModel):
6161
country: str
6262
population: int
6363
local_time: str # Agent will need to calculate this based on timezone
64-
attractions: List[str]
64+
attractions: list[str]
6565
best_time_to_visit: str
6666
estimated_daily_budget: float
67-
safety_rating: Optional[int] = None # 1-10 scale
67+
safety_rating: int | None = None # 1-10 scale
6868

6969

7070
# Create an agent with both tools and a structured output schema

pyproject.toml

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ version = "v0.1.0"
44
description = "..."
55
readme = "README.md"
66
requires-python = ">=3.13"
7-
dependencies = ["mcp[cli]>=1.6.0", "rock-solid-base>=0.0.28"]
7+
dependencies = [
8+
"mcp[cli]>=1.6.0",
9+
"rock-solid-base>=0.0.28",
10+
]
811

912
[[project.authors]]
1013
name = "Arthur Brenno"
@@ -21,4 +24,35 @@ packages = ["agentle"]
2124
packages = ["agentle"]
2225

2326
[dependency-groups]
24-
dev = ["aiocache>=0.12.3", "aiofiles>=24.1.0", "aiohttp>=3.11.18", "blacksheep>=2.2.0", "cerebras-cloud-sdk>=1.29.0", "furo>=2024.8.6", "google-genai>=1.11.0", "ipykernel>=6.29.5", "langfuse>=2.60.3", "markdownify>=1.1.0", "mypy>=1.15.0", "openai>=1.75.0", "openai-agents>=0.0.11", "openpyxl>=3.1.5", "pandas>=2.2.3", "pillow>=11.2.1", "playwright>=1.52.0", "pydantic>=2.11.3", "pydantic-ai>=0.1.3", "pymupdf>=1.25.5", "pypdf>=5.4.0", "pytest>=8.3.5", "python-docx>=1.1.2", "python-dotenv>=1.1.0", "python-pptx>=1.0.2", "r2r>=3.5.14", "rarfile>=4.2", "sphinx>=8.2.3", "streamlit>=1.45.0", "uvicorn>=0.34.1"]
27+
dev = [
28+
"aiocache>=0.12.3",
29+
"aiofiles>=24.1.0",
30+
"aiohttp>=3.11.18",
31+
"blacksheep>=2.2.0",
32+
"cerebras-cloud-sdk>=1.29.0",
33+
"furo>=2024.8.6",
34+
"google-genai>=1.11.0",
35+
"ipykernel>=6.29.5",
36+
"langfuse>=2.60.3",
37+
"markdownify>=1.1.0",
38+
"mypy>=1.15.0",
39+
"openai>=1.77.0",
40+
"openai-agents>=0.0.11",
41+
"openpyxl>=3.1.5",
42+
"pandas>=2.2.3",
43+
"pillow>=11.2.1",
44+
"playwright>=1.52.0",
45+
"pydantic>=2.11.3",
46+
"pydantic-ai>=0.1.3",
47+
"pymupdf>=1.25.5",
48+
"pypdf>=5.4.0",
49+
"pytest>=8.3.5",
50+
"python-docx>=1.1.2",
51+
"python-dotenv>=1.1.0",
52+
"python-pptx>=1.0.2",
53+
"r2r>=3.5.14",
54+
"rarfile>=4.2",
55+
"sphinx>=8.2.3",
56+
"streamlit>=1.45.0",
57+
"uvicorn>=0.34.1",
58+
]

uv.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)