|
8 | 8 | from pydantic import BaseModel |
9 | 9 |
|
10 | 10 | from crewai.utilities.agent_utils import is_context_length_exceeded |
| 11 | +from crewai.utilities.converter import generate_model_description |
11 | 12 | from crewai.utilities.exceptions.context_window_exceeding_exception import ( |
12 | 13 | LLMContextLengthExceededError, |
13 | 14 | ) |
|
26 | 27 | from azure.ai.inference.models import ( |
27 | 28 | ChatCompletions, |
28 | 29 | ChatCompletionsToolCall, |
29 | | - StreamingChatCompletionsUpdate, |
30 | 30 | JsonSchemaFormat, |
| 31 | + StreamingChatCompletionsUpdate, |
31 | 32 | ) |
32 | 33 | from azure.core.credentials import ( |
33 | 34 | AzureKeyCredential, |
@@ -279,14 +280,15 @@ def _prepare_completion_params( |
279 | 280 | } |
280 | 281 |
|
281 | 282 | if response_model and self.is_openai_model: |
282 | | - response_model_json_schema = response_model.model_json_schema() |
283 | | - response_model_json_schema['additionalProperties'] = False |
284 | | - |
| 283 | + model_description = generate_model_description(response_model) |
| 284 | + json_schema_info = model_description["json_schema"] |
| 285 | + json_schema_name = json_schema_info["name"] |
| 286 | + |
285 | 287 | params["response_format"] = JsonSchemaFormat( |
286 | | - name="Tasks_Response", |
287 | | - schema=response_model_json_schema, |
288 | | - description="Describes the task expected response", |
289 | | - strict=True, |
| 288 | + name=json_schema_name, |
| 289 | + schema=json_schema_info["schema"], |
| 290 | + description=f"Schema for {json_schema_name}", |
| 291 | + strict=json_schema_info["strict"], |
290 | 292 | ) |
291 | 293 |
|
292 | 294 | # Only include model parameter for non-Azure OpenAI endpoints |
@@ -314,8 +316,8 @@ def _prepare_completion_params( |
314 | 316 | params["tool_choice"] = "auto" |
315 | 317 |
|
316 | 318 | additional_params = self.additional_params |
317 | | - additional_drop_params = additional_params.get('additional_drop_params') |
318 | | - drop_params = additional_params.get('drop_params') |
| 319 | + additional_drop_params = additional_params.get("additional_drop_params") |
| 320 | + drop_params = additional_params.get("drop_params") |
319 | 321 |
|
320 | 322 | if drop_params and isinstance(additional_drop_params, list): |
321 | 323 | for drop_param in additional_drop_params: |
|
0 commit comments