Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ tests:
uv run python -m pytest --disable-socket --allow-unix-socket -vv --durations=10 tests/unit_tests

tests_watch:
uv run ptw --now . -- -vv -x tests/unit_tests
uv run --with-editable . ptw . -- -vv -x --ff tests/unit_tests

integration_tests:
uv run python -m pytest -v --durations=10 --cov=trustcall --cov-report=term-missing --cov-report=html --cov-config=.coveragerc tests/integration_tests
Expand Down
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ requires-python = "<4.0,>=3.10"
dependencies = [
"langgraph>=0.2.25",
"dydantic<1.0.0,>=0.0.8",
"jsonpatch<2.0,>=1.33",
"langgraph-prebuilt>=0.1.2",
"jsonpatch<2.0,>=1.33"
]
name = "trustcall"
version = "0.0.38"
version = "0.0.39"
description = "Tenacious & trustworthy tool calling built on LangGraph."
readme = "README.md"

Expand All @@ -20,7 +19,7 @@ readme = "README.md"
dev = [
"ruff<1.0.0,>=0.4.10",
"mypy<2.0.0,>=1.10.1",
"pytest<9.0.0,>=8.2.2",
"pytest>=8.2.2,<9.0.0",
"pytest-socket<1.0.0,>=0.7.0",
"langchain<1.0,>=0.3",
"langchain-openai<1.0,>=0.2",
Expand Down
2 changes: 1 addition & 1 deletion tests/evals/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def score_run(run: Run, example: Example) -> dict: # type: ignore
]
)
return {"results": results}
schema = create_model_from_schema(example.inputs["tool_def"]["parameters"])
schema = create_model_from_schema(example.inputs["tool_def"]["parameters"]) # type: ignore
try:
schema.model_validate(predicted)
results.append(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def foo() -> None:
"""bar"""
...

with pytest.raises(ValueError, match="At least one of"):
with pytest.raises(Exception):
create_extractor(
llm="openai:foo",
tools=[foo],
Expand Down
27 changes: 26 additions & 1 deletion tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import ValidationError
from typing_extensions import Annotated, TypedDict

from trustcall._base import _convert_any_typed_dicts_to_pydantic
from trustcall._base import _apply_patch, _convert_any_typed_dicts_to_pydantic


def test_convert_any_typed_dicts_to_pydantic():
Expand Down Expand Up @@ -89,3 +89,28 @@ class RecursiveType(TypedDict):
cyclic["next"] = cyclic
with pytest.raises(ValueError): # or RecursionError, depending on implementation
model(**cyclic)


@pytest.mark.parametrize(
    "document,operations,want",
    [
        # Standard RFC 6902 operations against a flat mapping.
        ({"a": 1}, [{"op": "add", "path": "/b", "value": 2}], {"a": 1, "b": 2}),
        ({"a": 1}, [{"op": "remove", "path": "/a"}], {}),
        ({"a": 1}, [{"op": "replace", "path": "/a", "value": 2}], {"a": 2}),
        # "/-" on a *string* target means "append to the string" — trustcall's
        # extension mirroring the JSONPatch array-append syntax.
        (
            {"a": {"b": "hello"}},
            [{"op": "add", "path": "/a/b/-", "value": " world"}],
            {"a": {"b": "hello world"}},
        ),
        # The same string-append extension works at arbitrary nesting depth.
        (
            {"a": [{"b": "foo"}, {"b": {"c": "hello"}}]},
            [{"op": "add", "path": "/a/1/b/c/-", "value": " world"}],
            {"a": [{"b": "foo"}, {"b": {"c": "hello world"}}]},
        ),
    ],
)
def test_apply_patch_concat(document, operations, want):
    """_apply_patch handles plain JSONPatch ops plus string concatenation."""
    assert _apply_patch(document, operations) == want
78 changes: 60 additions & 18 deletions trustcall/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

import jsonpatch # type: ignore[import-untyped]
import jsonpointer # type: ignore[import-untyped]
import langsmith as ls
from dydantic import create_model_from_schema
from langchain_core.language_models import BaseChatModel
Expand All @@ -40,10 +41,10 @@
)
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.tools import BaseTool, InjectedToolArg, create_schema_from_function
from langgraph.constants import Send
from langgraph.graph import StateGraph, add_messages
from langgraph.prebuilt.tool_validator import ValidationNode, get_executor_for_config
from langgraph.types import Command
from langgraph.utils.runnable import RunnableCallable
from pydantic import (
Expand All @@ -58,6 +59,8 @@
)
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict

from trustcall._validation_node import ValidationNode

logger = logging.getLogger("extraction")


Expand Down Expand Up @@ -812,7 +815,7 @@ def _teardown(
ToolCall(
id=tc["id"],
name=tool_name,
args=jsonpatch.apply_patch(target, patches),
args=_apply_patch(target, patches),
)
)
updated_docs[tc["id"]] = str(json_doc_id)
Expand Down Expand Up @@ -1221,11 +1224,12 @@ class PatchFunctionErrors(BaseModel):

json_doc_id: str = Field(
...,
description="The ID of the function you are patching.",
description="First, identify the json_doc_id of the function you are patching.",
)
planned_edits: str = Field(
...,
description="Write a bullet-point list of each ValidationError you encountered"
description="Second, write a bullet-point list of each ValidationError "
"you encountered"
" and the corresponding JSONPatch operation needed to heal it."
" For each operation, write why your initial guess was incorrect, "
" citing the corresponding types(s) from the JSONSchema"
Expand All @@ -1234,8 +1238,8 @@ class PatchFunctionErrors(BaseModel):
)
patches: list[JsonPatch] = Field(
...,
description="A list of JSONPatch operations to be applied to the"
" previous tool call's response arguments. If none are required, return"
description="Finally, provide a list of JSONPatch operations to be applied to"
" the previous tool call's response arguments. If none are required, return"
" an empty list. This field is REQUIRED."
" Multiple patches in the list are applied sequentially in the order provided,"
" with each patch building upon the result of the previous one.",
Expand All @@ -1254,17 +1258,19 @@ class PatchFunctionName(BaseModel):

json_doc_id: str = Field(
...,
description="The ID of the function you are patching.",
description="First, identify the json_doc_id of the function"
" you are patching.",
)
reasoning: list[str] = Field(
...,
description="At least 2 logical reasons why this action ought to be taken."
description="Seconds, provide at least 2 logical reasons why this"
" action ought to be taken."
"Cite the specific error(s) mentioned to motivate the fix.",
)
fixed_name: Optional[str] = Field(
...,
description="If you need to change the name of the function (e.g., "
f'from an "Unrecognized tool name" error), do so here.{vname}',
description="Finally, if you need to change the name of the function (e.g.,"
f' from an "Unrecognized tool name" error), do so here.{vname}',
)

return PatchFunctionName
Expand All @@ -1276,11 +1282,11 @@ class PatchDoc(BaseModel):

json_doc_id: str = Field(
...,
description="The json_doc_id of the document you are patching.",
description="First, identify the json_doc_id of the document you are patching.",
)
planned_edits: str = Field(
...,
description="Think step-by-step, reasoning over each required"
description="Seconds, think step-by-step, reasoning over each required"
" update and the corresponding JSONPatch operation to accomplish it."
" Cite the fields in the JSONSchema you referenced in developing this plan."
" Address each path as a group; don't switch between paths.\n"
Expand All @@ -1294,8 +1300,8 @@ class PatchDoc(BaseModel):
)
patches: list[JsonPatch] = Field(
...,
description="A list of JSONPatch operations to be applied to the"
" previous tool call's response arguments. If none are required, return"
description="Finally, provide a list of JSONPatch operations to be applied to"
" the previous tool call's response arguments. If none are required, return"
" an empty list. This field is REQUIRED."
" Multiple patches in the list are applied sequentially in the order provided,"
" with each patch building upon the result of the previous one."
Expand Down Expand Up @@ -1453,9 +1459,7 @@ def _get_message_op(
try:
patches = _ensure_patches(tool_call)
if patches:
patched_args = jsonpatch.apply_patch(
tc["args"], patches
)
patched_args = _apply_patch(tc["args"], patches)
msg_ops.append(
{
"op": "update_tool_call",
Expand Down Expand Up @@ -1620,7 +1624,7 @@ def _strip_injected(fn: Callable) -> Callable:
return _curry(fn, **{k: None for k in injected})


def _ensure_patches(args: dict) -> list[JsonPatch]:
def _ensure_patches(args: dict) -> list[jsonpatch.JsonPatch]:
patches = args.get("patches")
if isinstance(patches, list):
return patches
Expand Down Expand Up @@ -1656,6 +1660,44 @@ def _ensure_patches(args: dict) -> list[JsonPatch]:
return []


def _fix_string_concat(
    doc: dict, patch: list[dict]
) -> Optional[list[dict]]:
    """Rewrite "append to string" patch ops so jsonpatch can apply them.

    LLMs sometimes emit ``{"op": "add", "path": ".../-", "value": s}`` against
    a *string* target, borrowing the JSONPatch array-append syntax to mean
    string concatenation. jsonpatch rejects that with a conflict, so each such
    op is rewritten here as a ``replace`` whose value is the existing string
    with ``s`` appended.

    Args:
        doc: The document the patches will be applied to (not mutated here).
        patch: Raw patch operations (dicts with "op"/"path"/"value" keys).

    Returns:
        A rewritten operation list if at least one concat op was healed,
        otherwise ``None`` — signalling the caller should re-raise the
        original conflict.
    """
    fixed = False
    result: list[dict] = []
    for p in patch:
        # .get() tolerates malformed ops missing "path" (LLM output).
        path = p.get("path")
        if path and path.endswith("/-"):
            parent_path = path[:-2]
            # Resolve with a None default (instead of raising
            # JsonPointerException) so a dangling pointer cannot mask the
            # original JsonPatchConflict being handled in the caller.
            existing = jsonpointer.resolve_pointer(doc, parent_path, None)
            value = p.get("value")
            # Only heal when both sides are strings; otherwise keep the op
            # unchanged and let the caller re-raise the original conflict.
            if isinstance(existing, str) and isinstance(value, str):
                fixed = True
                result.append(
                    {
                        "path": parent_path,
                        "op": "replace",
                        "value": existing + value,
                    }
                )
                continue
        result.append(p)
    return result if fixed else None


def _apply_patch(doc: dict, patches: list[jsonpatch.JsonPatch]) -> dict:
    """Apply JSONPatch operations to ``doc``, healing string-concat ops.

    First attempts a plain ``jsonpatch.apply_patch``. On a conflict, tries to
    rewrite "append to string" ops via ``_fix_string_concat`` and re-applies;
    if nothing could be healed, the original conflict propagates.
    """
    try:
        return jsonpatch.apply_patch(doc, patches)
    except jsonpatch.JsonPatchConflict:
        healed = _fix_string_concat(doc, patches)
        if healed is None:
            # Nothing to heal — surface the original conflict to the caller.
            raise
        return jsonpatch.apply_patch(doc, healed)


__all__ = [
"create_extractor",
"ensure_tools",
Expand Down
Loading