Skip to content

Commit 4ccbe2e

Browse files
committed
Add support for array-like string concat in patching
1 parent 3290211 commit 4ccbe2e

File tree

8 files changed

+1184
-734
lines changed

8 files changed

+1184
-734
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tests:
44
uv run python -m pytest --disable-socket --allow-unix-socket -vv --durations=10 tests/unit_tests
55

66
tests_watch:
7-
uv run ptw --now . -- -vv -x tests/unit_tests
7+
uv run --with-editable . ptw . -- -vv -x --ff tests/unit_tests
88

99
integration_tests:
1010
uv run python -m pytest -v --durations=10 --cov=trustcall --cov-report=term-missing --cov-report=html --cov-config=.coveragerc tests/integration_tests

pyproject.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ requires-python = "<4.0,>=3.10"
77
dependencies = [
88
"langgraph>=0.2.25",
99
"dydantic<1.0.0,>=0.0.8",
10-
"jsonpatch<2.0,>=1.33",
11-
"langgraph-prebuilt>=0.1.2",
10+
"jsonpatch<2.0,>=1.33"
1211
]
1312
name = "trustcall"
14-
version = "0.0.38"
13+
version = "0.0.39"
1514
description = "Tenacious & trustworthy tool calling built on LangGraph."
1615
readme = "README.md"
1716

@@ -20,7 +19,7 @@ readme = "README.md"
2019
dev = [
2120
"ruff<1.0.0,>=0.4.10",
2221
"mypy<2.0.0,>=1.10.1",
23-
"pytest<9.0.0,>=8.2.2",
22+
"pytest>=8.2.2,<9.0.0",
2423
"pytest-socket<1.0.0,>=0.7.0",
2524
"langchain<1.0,>=0.3",
2625
"langchain-openai<1.0,>=0.2",

tests/evals/test_evals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def score_run(run: Run, example: Example) -> dict: # type: ignore
101101
]
102102
)
103103
return {"results": results}
104-
schema = create_model_from_schema(example.inputs["tool_def"]["parameters"])
104+
schema = create_model_from_schema(example.inputs["tool_def"]["parameters"]) # type: ignore
105105
try:
106106
schema.model_validate(predicted)
107107
results.append(

tests/unit_tests/test_extraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ def foo() -> None:
716716
"""bar"""
717717
...
718718

719-
with pytest.raises(ValueError, match="At least one of"):
719+
with pytest.raises(Exception):
720720
create_extractor(
721721
llm="openai:foo",
722722
tools=[foo],

tests/unit_tests/test_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import ValidationError
66
from typing_extensions import Annotated, TypedDict
77

8-
from trustcall._base import _convert_any_typed_dicts_to_pydantic
8+
from trustcall._base import _apply_patch, _convert_any_typed_dicts_to_pydantic
99

1010

1111
def test_convert_any_typed_dicts_to_pydantic():
@@ -89,3 +89,28 @@ class RecursiveType(TypedDict):
8989
cyclic["next"] = cyclic
9090
with pytest.raises(ValueError): # or RecursionError, depending on implementation
9191
model(**cyclic)
92+
93+
94+
@pytest.mark.parametrize(
95+
"doc,patches,expected",
96+
[
97+
({"a": 1}, [{"op": "add", "path": "/b", "value": 2}], {"a": 1, "b": 2}),
98+
({"a": 1}, [{"op": "remove", "path": "/a"}], {}),
99+
({"a": 1}, [{"op": "replace", "path": "/a", "value": 2}], {"a": 2}),
100+
# Expanded syntax of concatenation for strings (similar to arrays)
101+
(
102+
{"a": {"b": "hello"}},
103+
[{"op": "add", "path": "/a/b/-", "value": " world"}],
104+
{"a": {"b": "hello world"}},
105+
),
106+
# Similar, but within nested list/dict/etc.
107+
(
108+
{"a": [{"b": "foo"}, {"b": {"c": "hello"}}]},
109+
[{"op": "add", "path": "/a/1/b/c/-", "value": " world"}],
110+
{"a": [{"b": "foo"}, {"b": {"c": "hello world"}}]},
111+
),
112+
],
113+
)
114+
def test_apply_patch_concat(doc, patches, expected):
115+
result = _apply_patch(doc, patches)
116+
assert result == expected

trustcall/_base.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
import jsonpatch # type: ignore[import-untyped]
28+
import jsonpointer # type: ignore[import-untyped]
2829
import langsmith as ls
2930
from dydantic import create_model_from_schema
3031
from langchain_core.language_models import BaseChatModel
@@ -40,10 +41,10 @@
4041
)
4142
from langchain_core.prompt_values import PromptValue
4243
from langchain_core.runnables import Runnable, RunnableConfig
44+
from langchain_core.runnables.config import get_executor_for_config
4345
from langchain_core.tools import BaseTool, InjectedToolArg, create_schema_from_function
4446
from langgraph.constants import Send
4547
from langgraph.graph import StateGraph, add_messages
46-
from langgraph.prebuilt.tool_validator import ValidationNode, get_executor_for_config
4748
from langgraph.types import Command
4849
from langgraph.utils.runnable import RunnableCallable
4950
from pydantic import (
@@ -58,6 +59,8 @@
5859
)
5960
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
6061

62+
from trustcall._validation_node import ValidationNode
63+
6164
logger = logging.getLogger("extraction")
6265

6366

@@ -812,7 +815,7 @@ def _teardown(
812815
ToolCall(
813816
id=tc["id"],
814817
name=tool_name,
815-
args=jsonpatch.apply_patch(target, patches),
818+
args=_apply_patch(target, patches),
816819
)
817820
)
818821
updated_docs[tc["id"]] = str(json_doc_id)
@@ -1221,11 +1224,12 @@ class PatchFunctionErrors(BaseModel):
12211224

12221225
json_doc_id: str = Field(
12231226
...,
1224-
description="The ID of the function you are patching.",
1227+
description="First, identify the json_doc_id of the function you are patching.",
12251228
)
12261229
planned_edits: str = Field(
12271230
...,
1228-
description="Write a bullet-point list of each ValidationError you encountered"
1231+
description="Second, write a bullet-point list of each ValidationError "
1232+
"you encountered"
12291233
" and the corresponding JSONPatch operation needed to heal it."
12301234
" For each operation, write why your initial guess was incorrect, "
12311235
" citing the corresponding types(s) from the JSONSchema"
@@ -1234,8 +1238,8 @@ class PatchFunctionErrors(BaseModel):
12341238
)
12351239
patches: list[JsonPatch] = Field(
12361240
...,
1237-
description="A list of JSONPatch operations to be applied to the"
1238-
" previous tool call's response arguments. If none are required, return"
1241+
description="Finally, provide a list of JSONPatch operations to be applied to"
1242+
" the previous tool call's response arguments. If none are required, return"
12391243
" an empty list. This field is REQUIRED."
12401244
" Multiple patches in the list are applied sequentially in the order provided,"
12411245
" with each patch building upon the result of the previous one.",
@@ -1254,17 +1258,19 @@ class PatchFunctionName(BaseModel):
12541258

12551259
json_doc_id: str = Field(
12561260
...,
1257-
description="The ID of the function you are patching.",
1261+
description="First, identify the json_doc_id of the function"
1262+
" you are patching.",
12581263
)
12591264
reasoning: list[str] = Field(
12601265
...,
1261-
description="At least 2 logical reasons why this action ought to be taken."
1266+
description="Seconds, provide at least 2 logical reasons why this"
1267+
" action ought to be taken."
12621268
"Cite the specific error(s) mentioned to motivate the fix.",
12631269
)
12641270
fixed_name: Optional[str] = Field(
12651271
...,
1266-
description="If you need to change the name of the function (e.g., "
1267-
f'from an "Unrecognized tool name" error), do so here.{vname}',
1272+
description="Finally, if you need to change the name of the function (e.g.,"
1273+
f' from an "Unrecognized tool name" error), do so here.{vname}',
12681274
)
12691275

12701276
return PatchFunctionName
@@ -1276,11 +1282,11 @@ class PatchDoc(BaseModel):
12761282

12771283
json_doc_id: str = Field(
12781284
...,
1279-
description="The json_doc_id of the document you are patching.",
1285+
description="First, identify the json_doc_id of the document you are patching.",
12801286
)
12811287
planned_edits: str = Field(
12821288
...,
1283-
description="Think step-by-step, reasoning over each required"
1289+
description="Seconds, think step-by-step, reasoning over each required"
12841290
" update and the corresponding JSONPatch operation to accomplish it."
12851291
" Cite the fields in the JSONSchema you referenced in developing this plan."
12861292
" Address each path as a group; don't switch between paths.\n"
@@ -1294,8 +1300,8 @@ class PatchDoc(BaseModel):
12941300
)
12951301
patches: list[JsonPatch] = Field(
12961302
...,
1297-
description="A list of JSONPatch operations to be applied to the"
1298-
" previous tool call's response arguments. If none are required, return"
1303+
description="Finally, provide a list of JSONPatch operations to be applied to"
1304+
" the previous tool call's response arguments. If none are required, return"
12991305
" an empty list. This field is REQUIRED."
13001306
" Multiple patches in the list are applied sequentially in the order provided,"
13011307
" with each patch building upon the result of the previous one."
@@ -1453,9 +1459,7 @@ def _get_message_op(
14531459
try:
14541460
patches = _ensure_patches(tool_call)
14551461
if patches:
1456-
patched_args = jsonpatch.apply_patch(
1457-
tc["args"], patches
1458-
)
1462+
patched_args = _apply_patch(tc["args"], patches)
14591463
msg_ops.append(
14601464
{
14611465
"op": "update_tool_call",
@@ -1620,7 +1624,7 @@ def _strip_injected(fn: Callable) -> Callable:
16201624
return _curry(fn, **{k: None for k in injected})
16211625

16221626

1623-
def _ensure_patches(args: dict) -> list[JsonPatch]:
1627+
def _ensure_patches(args: dict) -> list[jsonpatch.JsonPatch]:
16241628
patches = args.get("patches")
16251629
if isinstance(patches, list):
16261630
return patches
@@ -1656,6 +1660,44 @@ def _ensure_patches(args: dict) -> list[JsonPatch]:
16561660
return []
16571661

16581662

1663+
def _fix_string_concat(
1664+
doc: dict, patch: list[jsonpatch.JsonPatch]
1665+
) -> Optional[list[jsonpatch.JsonPatch]] | None:
1666+
fixed = False
1667+
result = []
1668+
for p in patch:
1669+
if p["path"] and p["path"].endswith("/-"):
1670+
new_path = p["path"][:-2]
1671+
pointer = jsonpointer.JsonPointer(new_path)
1672+
existing = pointer.resolve(doc)
1673+
if existing is not None and isinstance(existing, str):
1674+
fixed = True
1675+
result.append(
1676+
{
1677+
"path": new_path,
1678+
"op": "replace",
1679+
"value": existing + p["value"],
1680+
}
1681+
)
1682+
else:
1683+
result.append(p)
1684+
else:
1685+
result.append(p)
1686+
if not fixed:
1687+
return None
1688+
return result
1689+
1690+
1691+
def _apply_patch(doc: dict, patches: list[jsonpatch.JsonPatch]) -> dict:
1692+
try:
1693+
return jsonpatch.apply_patch(doc, patches)
1694+
except jsonpatch.JsonPatchConflict:
1695+
fixed = _fix_string_concat(doc, patches)
1696+
if fixed is not None:
1697+
return jsonpatch.apply_patch(doc, fixed)
1698+
raise
1699+
1700+
16591701
__all__ = [
16601702
"create_extractor",
16611703
"ensure_tools",

0 commit comments

Comments
 (0)