Skip to content

Commit 6e208a8

Browse files
kaghatimBV-Venky
andauthored
feat(bedrock): add strict_tools config with auto-inject of additional… (#2213)
Co-authored-by: Venkatesh Bhukya <venkateshcjjc@gmail.com>
1 parent 771a86a commit 6e208a8

5 files changed

Lines changed: 665 additions & 1 deletion

File tree

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Strict JSON schema transformation for tool definitions.
2+
3+
When model providers require `strict: true` on tool definitions, they also require
4+
`"additionalProperties": false` on every `object` type in the input schema. This module
5+
provides a utility to recursively apply that constraint.
6+
7+
Modeled after OpenAI's `_ensure_strict_json_schema`:
8+
https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py
9+
"""
10+
11+
import copy
12+
import logging
13+
from typing import Any
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def ensure_strict_json_schema(
19+
schema: dict[str, Any],
20+
*,
21+
require_all_properties: bool = False,
22+
) -> dict[str, Any]:
23+
"""Ensure a JSON schema conforms to strict tool use requirements.
24+
25+
Creates a deep copy of the schema and recursively:
26+
1. Adds ``"additionalProperties": false`` to all ``object`` types that do not already define it
27+
2. Optionally adds all properties to the ``required`` array (needed for OpenAI)
28+
3. Handles ``$defs``, ``definitions``, ``anyOf``, ``allOf``, ``items``, and ``$ref``
29+
30+
Args:
31+
schema: The JSON schema to process. A deep copy is made internally so the original is not mutated.
32+
require_all_properties: If True, set ``required`` to include all property keys. OpenAI strict mode
33+
requires this; Bedrock and Anthropic do not.
34+
35+
Returns:
36+
A new schema dict with strict-mode constraints applied.
37+
"""
38+
schema_copy = copy.deepcopy(schema)
39+
_apply_strict(schema_copy, root=schema_copy, require_all_properties=require_all_properties)
40+
return schema_copy
41+
42+
43+
def _apply_strict(
44+
schema: dict[str, Any],
45+
*,
46+
root: dict[str, Any],
47+
require_all_properties: bool,
48+
) -> None:
49+
"""Recursively apply strict-mode constraints to a JSON schema in place.
50+
51+
Args:
52+
schema: The schema node to process (modified in place).
53+
root: The root schema, used for resolving ``$ref`` pointers.
54+
require_all_properties: If True, add all properties to ``required``.
55+
"""
56+
# Process $defs / definitions blocks
57+
for defs_key in ("$defs", "definitions"):
58+
defs = schema.get(defs_key)
59+
if isinstance(defs, dict):
60+
for def_schema in defs.values():
61+
if isinstance(def_schema, dict):
62+
_apply_strict(def_schema, root=root, require_all_properties=require_all_properties)
63+
64+
# Add additionalProperties: false to object types that lack it
65+
if schema.get("type") == "object" and "additionalProperties" not in schema:
66+
schema["additionalProperties"] = False
67+
68+
# Process properties and optionally enforce required
69+
properties = schema.get("properties")
70+
if isinstance(properties, dict):
71+
if require_all_properties:
72+
schema["required"] = list(properties.keys())
73+
74+
for prop_schema in properties.values():
75+
if isinstance(prop_schema, dict):
76+
_apply_strict(prop_schema, root=root, require_all_properties=require_all_properties)
77+
78+
# Process array items
79+
items = schema.get("items")
80+
if isinstance(items, dict):
81+
_apply_strict(items, root=root, require_all_properties=require_all_properties)
82+
83+
# Process anyOf variants
84+
any_of = schema.get("anyOf")
85+
if isinstance(any_of, list):
86+
for variant in any_of:
87+
if isinstance(variant, dict):
88+
_apply_strict(variant, root=root, require_all_properties=require_all_properties)
89+
90+
# Process allOf variants
91+
all_of = schema.get("allOf")
92+
if isinstance(all_of, list):
93+
for entry in all_of:
94+
if isinstance(entry, dict):
95+
_apply_strict(entry, root=root, require_all_properties=require_all_properties)
96+
97+
# Process oneOf variants
98+
one_of = schema.get("oneOf")
99+
if isinstance(one_of, list):
100+
for variant in one_of:
101+
if isinstance(variant, dict):
102+
_apply_strict(variant, root=root, require_all_properties=require_all_properties)
103+
104+
# Resolve $ref combined with other keys by inlining the referenced schema
105+
ref = schema.get("$ref")
106+
if isinstance(ref, str) and len(schema) > 1:
107+
resolved = _resolve_ref(root, ref)
108+
if isinstance(resolved, dict):
109+
# Inline the resolved schema, giving priority to existing keys
110+
merged = {**copy.deepcopy(resolved), **schema}
111+
merged.pop("$ref", None)
112+
schema.clear()
113+
schema.update(merged)
114+
# Re-apply strict to the inlined schema
115+
_apply_strict(schema, root=root, require_all_properties=require_all_properties)
116+
117+
118+
def _resolve_ref(root: dict[str, Any], ref: str) -> dict[str, Any] | None:
119+
"""Resolve a JSON Schema ``$ref`` pointer against the root schema.
120+
121+
Args:
122+
root: The root schema containing definitions.
123+
ref: A JSON pointer string (e.g., ``#/$defs/MyModel``).
124+
125+
Returns:
126+
The resolved schema dict, or None if resolution fails.
127+
"""
128+
if not ref.startswith("#/"):
129+
logger.warning("ref=<%s> | unexpected $ref format, skipping resolution", ref)
130+
return None
131+
132+
path = ref[2:].split("/")
133+
current: Any = root
134+
for key in path:
135+
if not isinstance(current, dict) or key not in current:
136+
logger.warning("ref=<%s> | failed to resolve $ref path", ref)
137+
return None
138+
current = current[key]
139+
140+
if not isinstance(current, dict):
141+
logger.warning("ref=<%s> | resolved to non-dict value", ref)
142+
return None
143+
144+
return current

src/strands/models/bedrock.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from ..types.streaming import CitationsDelta, StreamEvent
3333
from ..types.tools import ToolChoice, ToolSpec
34+
from ._strict_schema import ensure_strict_json_schema
3435
from ._validation import validate_config_keys
3536
from .model import BaseModelConfig, CacheConfig, Model
3637

@@ -100,6 +101,10 @@ class BedrockConfig(BaseModelConfig, total=False):
100101
supported service tiers, models, and regions
101102
stop_sequences: List of sequences that will stop generation when encountered
102103
streaming: Flag to enable/disable streaming. Defaults to True.
104+
strict_tools: Flag to enable structured output enforcement on tool definitions.
105+
When True, adds strict: true to each tool spec and automatically injects
106+
"additionalProperties": false into all object types in tool input schemas.
107+
See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html
103108
temperature: Controls randomness in generation (higher = more random)
104109
top_p: Controls diversity via nucleus sampling (alternative to temperature)
105110
"""
@@ -125,6 +130,7 @@ class BedrockConfig(BaseModelConfig, total=False):
125130
service_tier: str | None
126131
stop_sequences: list[str] | None
127132
streaming: bool | None
133+
strict_tools: bool | None
128134
temperature: float | None
129135
top_p: float | None
130136

@@ -240,6 +246,7 @@ def _format_request(
240246

241247
# Use system_prompt_content directly (copy for mutability)
242248
system_blocks: list[SystemContentBlock] = system_prompt_content.copy() if system_prompt_content else []
249+
243250
# Add cache point if configured (backwards compatibility)
244251
if cache_prompt := self.config.get("cache_prompt"):
245252
warnings.warn(
@@ -261,7 +268,12 @@ def _format_request(
261268
"toolSpec": {
262269
"name": tool_spec["name"],
263270
"description": tool_spec["description"],
264-
"inputSchema": tool_spec["inputSchema"],
271+
"inputSchema": (
272+
{"json": ensure_strict_json_schema(tool_spec["inputSchema"]["json"])}
273+
if self.config.get("strict_tools")
274+
else tool_spec["inputSchema"]
275+
),
276+
**({"strict": True} if self.config.get("strict_tools") else {}),
265277
}
266278
}
267279
for tool_spec in tool_specs

tests/strands/models/test_bedrock.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,188 @@ def test_format_request_tool_specs(model, messages, model_id, tool_spec):
493493
assert tru_request == exp_request
494494

495495

496+
def test_format_request_strict_tools_injects_strict_and_closes_schema(bedrock_client, model_id, messages):
497+
tool_specs = [
498+
{
499+
"name": "my_tool",
500+
"description": "A tool",
501+
"inputSchema": {
502+
"json": {
503+
"type": "object",
504+
"properties": {"param": {"type": "string"}},
505+
"required": ["param"],
506+
}
507+
},
508+
}
509+
]
510+
model = BedrockModel(model_id=model_id, strict_tools=True)
511+
request = model._format_request(messages, tool_specs=tool_specs)
512+
tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"]
513+
514+
assert tool_spec_result == {
515+
"name": "my_tool",
516+
"description": "A tool",
517+
"inputSchema": {
518+
"json": {
519+
"type": "object",
520+
"properties": {"param": {"type": "string"}},
521+
"required": ["param"],
522+
"additionalProperties": False,
523+
}
524+
},
525+
"strict": True,
526+
}
527+
528+
529+
def test_format_request_strict_tools_does_not_mutate_original(bedrock_client, model_id, messages):
530+
tool_specs = [
531+
{
532+
"name": "my_tool",
533+
"description": "A tool",
534+
"inputSchema": {
535+
"json": {
536+
"type": "object",
537+
"properties": {"param": {"type": "string"}},
538+
"required": ["param"],
539+
}
540+
},
541+
}
542+
]
543+
model = BedrockModel(model_id=model_id, strict_tools=True)
544+
model._format_request(messages, tool_specs=tool_specs)
545+
546+
assert "additionalProperties" not in tool_specs[0]["inputSchema"]["json"]
547+
548+
549+
def test_format_request_strict_tools_preserves_additional_properties_true(bedrock_client, model_id, messages):
550+
tool_specs = [
551+
{
552+
"name": "my_tool",
553+
"description": "A tool",
554+
"inputSchema": {
555+
"json": {
556+
"type": "object",
557+
"properties": {"param": {"type": "string"}},
558+
"required": ["param"],
559+
"additionalProperties": True,
560+
}
561+
},
562+
}
563+
]
564+
model = BedrockModel(model_id=model_id, strict_tools=True)
565+
request = model._format_request(messages, tool_specs=tool_specs)
566+
schema = request["toolConfig"]["tools"][0]["toolSpec"]["inputSchema"]["json"]
567+
568+
assert schema["additionalProperties"] is True
569+
570+
571+
def test_format_request_strict_tools_nested_objects(bedrock_client, model_id, messages):
572+
tool_specs = [
573+
{
574+
"name": "my_tool",
575+
"description": "A tool",
576+
"inputSchema": {
577+
"json": {
578+
"type": "object",
579+
"properties": {
580+
"config": {
581+
"type": "object",
582+
"properties": {"value": {"type": "integer"}},
583+
}
584+
},
585+
"required": ["config"],
586+
}
587+
},
588+
}
589+
]
590+
model = BedrockModel(model_id=model_id, strict_tools=True)
591+
request = model._format_request(messages, tool_specs=tool_specs)
592+
schema = request["toolConfig"]["tools"][0]["toolSpec"]["inputSchema"]["json"]
593+
594+
assert schema == {
595+
"type": "object",
596+
"properties": {
597+
"config": {
598+
"type": "object",
599+
"properties": {"value": {"type": "integer"}},
600+
"additionalProperties": False,
601+
}
602+
},
603+
"required": ["config"],
604+
"additionalProperties": False,
605+
}
606+
607+
608+
def test_format_request_strict_tools_default_no_strict(bedrock_client, model_id, messages):
609+
tool_specs = [
610+
{
611+
"name": "my_tool",
612+
"description": "A tool",
613+
"inputSchema": {
614+
"json": {
615+
"type": "object",
616+
"properties": {"param": {"type": "string"}},
617+
"required": ["param"],
618+
}
619+
},
620+
}
621+
]
622+
model = BedrockModel(model_id=model_id)
623+
request = model._format_request(messages, tool_specs=tool_specs)
624+
tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"]
625+
626+
assert "strict" not in tool_spec_result
627+
assert tool_spec_result["inputSchema"]["json"] == {
628+
"type": "object",
629+
"properties": {"param": {"type": "string"}},
630+
"required": ["param"],
631+
}
632+
633+
634+
def test_format_request_strict_tools_false_no_strict(bedrock_client, model_id, messages):
635+
tool_specs = [
636+
{
637+
"name": "my_tool",
638+
"description": "A tool",
639+
"inputSchema": {"json": {"type": "object", "properties": {"x": {"type": "string"}}}},
640+
}
641+
]
642+
model = BedrockModel(model_id=model_id, strict_tools=False)
643+
request = model._format_request(messages, tool_specs=tool_specs)
644+
tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"]
645+
646+
assert "strict" not in tool_spec_result
647+
648+
649+
def test_format_request_strict_tools_none_no_strict(bedrock_client, model_id, messages):
650+
tool_specs = [
651+
{
652+
"name": "my_tool",
653+
"description": "A tool",
654+
"inputSchema": {"json": {"type": "object", "properties": {"x": {"type": "string"}}}},
655+
}
656+
]
657+
model = BedrockModel(model_id=model_id, strict_tools=None)
658+
request = model._format_request(messages, tool_specs=tool_specs)
659+
tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"]
660+
661+
assert "strict" not in tool_spec_result
662+
663+
664+
def test_format_request_strict_tools_applies_to_all_tools(bedrock_client, model_id, messages):
665+
tool_specs = [
666+
{"name": "tool_a", "description": "Tool A", "inputSchema": {"json": {"type": "object", "properties": {}}}},
667+
{"name": "tool_b", "description": "Tool B", "inputSchema": {"json": {"type": "object", "properties": {}}}},
668+
]
669+
model = BedrockModel(model_id=model_id, strict_tools=True)
670+
request = model._format_request(messages, tool_specs=tool_specs)
671+
672+
for tool in request["toolConfig"]["tools"]:
673+
if "toolSpec" in tool:
674+
assert tool["toolSpec"]["strict"] is True
675+
assert tool["toolSpec"]["inputSchema"]["json"]["additionalProperties"] is False
676+
677+
496678
def test_format_request_tool_choice_auto(model, messages, model_id, tool_spec):
497679
tool_choice = {"auto": {}}
498680
tru_request = model._format_request(messages, [tool_spec], tool_choice=tool_choice)

0 commit comments

Comments
 (0)