Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/core/langchain_core/document_loaders/langsmith.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
inline_s3_urls: bool = True,
offset: int = 0,
limit: int | None = None,
metadata: dict | None = None,
metadata: dict[str, Any] | None = None,
filter: str | None = None, # noqa: A002
content_key: str = "",
format_content: Callable[..., str] | None = None,
Expand Down
34 changes: 19 additions & 15 deletions libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from __future__ import annotations

import builtins
import contextlib
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import Mapping # noqa: TC003
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
Expand All @@ -17,14 +17,14 @@

from langchain_core.exceptions import ErrorCode, create_message
from langchain_core.load import dumpd
from langchain_core.output_parsers.base import BaseOutputParser # noqa: TC001
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
PromptValue,
StringPromptValue,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.config import RunnableConfig, ensure_config
from langchain_core.utils.pydantic import create_model_v2

if TYPE_CHECKING:
Expand All @@ -37,7 +37,7 @@


class BasePromptTemplate(
RunnableSerializable[dict, PromptValue], ABC, Generic[FormatOutputType]
RunnableSerializable[dict[str, Any], PromptValue], ABC, Generic[FormatOutputType]
):
"""Base class for all prompt templates, returning a prompt."""

Expand All @@ -51,7 +51,7 @@ class BasePromptTemplate(

These variables are auto inferred from the prompt and user need not provide them.
"""
input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
input_types: builtins.dict[str, Any] = Field(default_factory=dict, exclude=True)
"""A dictionary of the types of the variables the prompt template expects.

If not provided, all variables are assumed to be strings.
Expand All @@ -64,7 +64,7 @@ class BasePromptTemplate(
Partial variables populate the template so that you don't need to pass them in every
time you call the prompt.
"""
metadata: typing.Dict[str, Any] | None = None # noqa: UP006
metadata: builtins.dict[str, Any] | None = None
"""Metadata to be used for tracing."""
tags: list[str] | None = None
"""Tags to be used for tracing."""
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
field_definitions={**required_input_variables, **optional_input_variables},
)

def _validate_input(self, inner_input: Any) -> dict:
def _validate_input(self, inner_input: Any) -> dict[str, Any]:
if not isinstance(inner_input, dict):
if len(self.input_variables) == 1:
var_name = self.input_variables[0]
Expand Down Expand Up @@ -186,19 +186,21 @@ def _validate_input(self, inner_input: Any) -> dict:
)
return inner_input_

def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
def _format_prompt_with_error_handling(
self, inner_input: dict[str, Any]
) -> PromptValue:
inner_input_ = self._validate_input(inner_input)
return self.format_prompt(**inner_input_)

async def _aformat_prompt_with_error_handling(
self, inner_input: dict
self, inner_input: dict[str, Any]
) -> PromptValue:
inner_input_ = self._validate_input(inner_input)
return await self.aformat_prompt(**inner_input_)

@override
def invoke(
self, input: dict, config: RunnableConfig | None = None, **kwargs: Any
self, input: dict[str, Any], config: RunnableConfig | None = None, **kwargs: Any
) -> PromptValue:
"""Invoke the prompt.

Expand All @@ -224,7 +226,7 @@ def invoke(

@override
async def ainvoke(
self, input: dict, config: RunnableConfig | None = None, **kwargs: Any
self, input: dict[str, Any], config: RunnableConfig | None = None, **kwargs: Any
) -> PromptValue:
"""Async invoke the prompt.

Expand Down Expand Up @@ -330,7 +332,7 @@ def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError

def dict(self, **kwargs: Any) -> dict:
def dict(self, **kwargs: Any) -> builtins.dict[str, Any]:
"""Return dictionary representation of prompt.

Args:
Expand Down Expand Up @@ -387,7 +389,9 @@ def save(self, file_path: Path | str) -> None:
raise ValueError(msg)


def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> dict:
def _get_document_info(
doc: Document, prompt: BasePromptTemplate[str]
) -> dict[str, Any]:
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
Expand Down
46 changes: 39 additions & 7 deletions libs/core/langchain_core/prompts/structured.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
"""Structured prompt template for a language model."""

from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Iterator,
Mapping,
Sequence,
)
from typing import (
Any,
overload,
)

from pydantic import BaseModel, Field
from typing_extensions import override

from langchain_core._api.beta_decorator import beta
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.chat import (
ChatPromptTemplate,
MessageLikeRepresentation,
Expand Down Expand Up @@ -133,15 +142,38 @@ class OutputSchema(BaseModel):
"""
return cls(messages, schema, **kwargs)

@overload
def __or__(
self, other: Mapping[str, Any]
) -> RunnableSerializable[dict[str, Any], dict[str, Any]]: ...

@overload
def __or__(
self,
other: Callable[[PromptValue], Runnable[PromptValue, Other]]
| Callable[[PromptValue], Awaitable[Runnable[PromptValue, Other]]],
) -> RunnableSerializable[dict[str, Any], Other]: ...

@overload
def __or__(
self,
other: Runnable[PromptValue, Other]
| Callable[[Iterator[PromptValue]], Iterator[Other]]
| Callable[[AsyncIterator[PromptValue]], AsyncIterator[Other]]
| Callable[[PromptValue], Runnable[PromptValue, Other]]
| Callable[[PromptValue], Awaitable[Runnable[PromptValue, Other]]]
| Callable[[PromptValue], Other],
) -> RunnableSerializable[dict[str, Any], Other]: ...

@override
def __or__(
self,
other: Runnable[Any, Other]
| Callable[[Iterator[Any]], Iterator[Other]]
| Callable[[AsyncIterator[Any]], AsyncIterator[Other]]
| Callable[[Any], Other]
| Mapping[str, Runnable[Any, Other] | Callable[[Any], Other] | Any],
) -> RunnableSerializable[dict, Other]:
other: Runnable[PromptValue, Other]
| Callable[[Iterator[PromptValue]], Iterator[Other]]
| Callable[[AsyncIterator[PromptValue]], AsyncIterator[Other]]
| Callable[[PromptValue], Other]
| Mapping[str, Runnable[PromptValue, Any] | Callable[[PromptValue], Any] | Any],
) -> RunnableSerializable[dict[str, Any], Any]:
return self.pipe(other)

def pipe(
Expand Down
Loading