-
Notifications
You must be signed in to change notification settings. Forks: 4
Text-to-SQL: Help agents turn natural language into SQL queries #762
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,26 +1,9 @@ | ||
| import logging | ||
|
|
||
| import click | ||
| from click_aliases import ClickAliasedGroup | ||
|
|
||
| from ..util.cli import boot_click | ||
| from ..util.app import make_cli | ||
| from .convert.cli import convert_query | ||
| from .llm.cli import llm_cli | ||
| from .mcp.cli import cli as mcp_cli | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @click.group(cls=ClickAliasedGroup) | ||
| @click.option("--verbose", is_flag=True, required=False, help="Turn on logging") | ||
| @click.option("--debug", is_flag=True, required=False, help="Turn on logging with debug level") | ||
| @click.version_option() | ||
| @click.pass_context | ||
| def cli(ctx: click.Context, verbose: bool, debug: bool): | ||
| """ | ||
| Query utilities. | ||
| """ | ||
| return boot_click(ctx, verbose, debug) | ||
|
|
||
|
|
||
| cli = make_cli() | ||
| cli.add_command(convert_query, name="convert") | ||
| cli.add_command(llm_cli, name="llm") | ||
| cli.add_command(mcp_cli, name="mcp") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| """ | ||
| Use an LLM to query a database in human language via NLSQLTableQueryEngine. | ||
| Example code using LlamaIndex with vanilla Open AI, Azure Open AI, or Ollama. | ||
| """ | ||
|
|
||
| import dataclasses | ||
| import logging | ||
| import os | ||
| from typing import Optional | ||
|
|
||
| import sqlalchemy as sa | ||
|
|
||
| from cratedb_toolkit.query.llm.model import DatabaseInfo, ModelInfo | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| try: | ||
| from llama_index.core.base.embeddings.base import BaseEmbedding | ||
| from llama_index.core.base.response.schema import RESPONSE_TYPE | ||
| from llama_index.core.llms import LLM | ||
| from llama_index.core.query_engine import NLSQLTableQueryEngine | ||
| from llama_index.core.utilities.sql_wrapper import SQLDatabase | ||
| except ImportError: | ||
| pass | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class DataQuery: | ||
| """ | ||
| DataQuery helps agents turn natural language into SQL queries. | ||
| It's the little sister of Google's QueryData product. [1] | ||
|
|
||
| We recommend evaluating the Text-to-SQL interface using the Gemma models if you are | ||
| looking at non-frontier variants that need less resources for inference. However, | ||
| depending on the complexity of your problem, you may also want to use cutting-edge | ||
| models with your provider of choice at the cost of higher resource usage. | ||
|
|
||
| Attention: Any natural language SQL table query engine and Text-to-SQL application | ||
| should be aware that executing arbitrary SQL queries can be a security risk. | ||
| It is recommended to take precautions as needed, such as using restricted roles, | ||
| read-only databases, sandboxing, etc. | ||
|
|
||
| [1] https://cloud.google.com/blog/products/databases/introducing-querydata-for-near-100-percent-accurate-data-agents | ||
| [2] https://github.com/kupp0/multi-db-property-search-data-agents | ||
| """ | ||
|
|
||
| db: DatabaseInfo | ||
| model: ModelInfo | ||
| query_engine: Optional["NLSQLTableQueryEngine"] = None | ||
|
|
||
| def __post_init__(self): | ||
| self.setup() | ||
|
|
||
| def setup(self): | ||
| # Configure database connection and query engine. | ||
| logger.info("Connecting to CrateDB") | ||
| engine_crate = sa.create_engine(os.getenv("CRATEDB_SQLALCHEMY_URL", "crate://")) | ||
| engine_crate.connect() | ||
|
|
||
| # Configure model. | ||
| logger.info("Configuring LLM model") | ||
| llm: LLM | ||
| embed_model: BaseEmbedding | ||
| from cratedb_toolkit.query.llm.util import configure_llm | ||
|
|
||
| llm, embed_model = configure_llm(self.model) | ||
|
|
||
| # Configure query engine. | ||
| logger.info("Creating query engine") | ||
| sql_database = SQLDatabase( | ||
| engine_crate, | ||
| ignore_tables=self.db.ignore_tables, | ||
| include_tables=self.db.include_tables, | ||
| ) | ||
|
amotl marked this conversation as resolved.
|
||
| self.query_engine = NLSQLTableQueryEngine( | ||
| sql_database=sql_database, | ||
| llm=llm, | ||
| embed_model=embed_model, | ||
| ) | ||
|
|
||
| def ask(self, question: str) -> "RESPONSE_TYPE": | ||
| """Invoke an inquiry to the LLM.""" | ||
| if not self.query_engine: | ||
| raise ValueError("Query engine not configured") | ||
| logger.debug("Running query: %s", question) | ||
| return self.query_engine.query(question) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| import logging | ||
| import os | ||
| from typing import Optional | ||
|
|
||
| import click | ||
| from dotenv import load_dotenv | ||
|
|
||
| from cratedb_toolkit import DatabaseCluster | ||
| from cratedb_toolkit.query.llm.api import DataQuery | ||
| from cratedb_toolkit.query.llm.model import DatabaseInfo, ModelInfo, ModelProvider | ||
| from cratedb_toolkit.util.common import setup_logging | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def help_llm(): | ||
| """ | ||
| Use an LLM to query the database in human language. | ||
|
|
||
| Synopsis | ||
| ======== | ||
|
|
||
| export CRATEDB_CLUSTER_URL=crate://localhost/ | ||
| ctk query llm "What is the average value for sensor 1?" | ||
|
|
||
| """ # noqa: E501 | ||
|
|
||
|
|
||
| @click.command() | ||
| @click.argument("question") | ||
| @click.option("--schema", envvar="CRATEDB_SCHEMA", type=str, required=False, help="Schema where to operate on") | ||
| @click.option("--llm-provider", envvar="LLM_PROVIDER", type=str, required=True, help="LLM provider name") | ||
| @click.option("--llm-name", envvar="LLM_NAME", type=str, required=False, help="LLM model name for completions") | ||
| @click.option( | ||
| "--llm-embedding-name", envvar="LLM_EMBEDDING_NAME", type=str, required=False, help="LLM model name for embeddings" | ||
| ) | ||
| @click.option("--llm-api-key", envvar="LLM_API_KEY", type=str, required=False, help="LLM API key") | ||
| @click.pass_context | ||
| def llm_cli( | ||
| ctx: click.Context, | ||
| question: str, | ||
| schema: Optional[str], | ||
| llm_provider: str, | ||
| llm_name: Optional[str], | ||
| llm_embedding_name: Optional[str], | ||
| llm_api_key: Optional[str], | ||
| ): | ||
| """ | ||
| Use an LLM to query a database in human language. | ||
| """ | ||
| setup_logging() | ||
| load_dotenv() | ||
|
|
||
| # Connect to database. | ||
| dc = DatabaseCluster.from_options(ctx.meta["address"]) | ||
| engine = dc.adapter.engine | ||
| schema = os.getenv("CRATEDB_SCHEMA", "doc") | ||
|
|
||
| provider = ModelProvider(llm_provider) | ||
|
|
||
| # Parameter sanity checks and heuristics. | ||
| if not llm_name: | ||
| if provider in [ModelProvider.OPENAI, ModelProvider.AZURE]: | ||
| llm_name = "gpt-4.1" | ||
| elif provider in [ModelProvider.OLLAMA]: | ||
| llm_name = "gemma3:1b" | ||
| else: | ||
| raise ValueError("LLM completion model not selected") | ||
| if not llm_embedding_name: | ||
| if provider in [ModelProvider.OPENAI, ModelProvider.AZURE]: | ||
| llm_embedding_name = "text-embedding-3-large" | ||
| elif provider in [ModelProvider.OLLAMA]: | ||
| llm_embedding_name = "local" | ||
| else: | ||
| raise ValueError("LLM embedding model not selected") | ||
| if not llm_api_key: | ||
| if provider in [ModelProvider.OPENAI, ModelProvider.AZURE]: | ||
| llm_api_key = os.getenv("OPENAI_API_KEY") | ||
|
|
||
| logger.info("Selected LLM: completion=%s, embedding=%s", llm_name, llm_embedding_name) | ||
|
|
||
| # Submit query. | ||
| dq = DataQuery( | ||
| db=DatabaseInfo( | ||
| engine=engine, | ||
| schema=schema, | ||
| ), | ||
| model=ModelInfo( | ||
| provider=provider, | ||
| completion=llm_name, | ||
| embedding=llm_embedding_name, | ||
| api_key=llm_api_key, | ||
| ), | ||
| ) | ||
| response = dq.ask(question) | ||
|
|
||
| logger.info("Query was: %s", question) | ||
| logger.info("Answer was: %s", response) | ||
| logger.info("More (metadata, formatted sources):") | ||
| logger.info(response.get_formatted_sources()) | ||
| logger.info(response.metadata) | ||
| return response | ||
|
|
||
| # assert "Answer was: The average value for sensor 1 is approximately 17.03." in out # noqa: ERA001 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| import dataclasses | ||
| from enum import Enum | ||
| from typing import List, Optional | ||
|
|
||
| import sqlalchemy as sa | ||
|
|
||
|
|
||
| class ModelProvider(Enum): | ||
| """Model provider choices.""" | ||
|
|
||
| OPENAI = "openai" | ||
| AZURE = "azure" | ||
| OLLAMA = "ollama" | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class ModelInfo: | ||
| """Information about the model.""" | ||
|
|
||
| provider: ModelProvider | ||
| completion: str | ||
| embedding: str | ||
| endpoint: Optional[str] = None | ||
| instance: Optional[str] = None | ||
| api_key: Optional[str] = None | ||
| api_version: Optional[str] = None | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class DatabaseInfo: | ||
| """Information about the database.""" | ||
|
|
||
| engine: sa.engine.Engine | ||
| schema: Optional[str] = None | ||
| ignore_tables: Optional[List[str]] = None | ||
| include_tables: Optional[List[str]] = None |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # ty: ignore[unresolved-import] | ||
| from typing import Tuple | ||
|
|
||
| import llama_index.core | ||
| from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings | ||
| from llama_index.core.base.embeddings.base import BaseEmbedding | ||
| from llama_index.core.llms import LLM | ||
| from llama_index.embeddings.langchain import LangchainEmbedding | ||
| from llama_index.llms.azure_openai import AzureOpenAI | ||
| from llama_index.llms.ollama import Ollama | ||
| from llama_index.llms.openai import OpenAI | ||
|
|
||
| from cratedb_toolkit.query.llm.model import ModelInfo, ModelProvider | ||
|
|
||
|
|
||
| def configure_llm(info: ModelInfo, debug: bool = False) -> Tuple[LLM, BaseEmbedding]: | ||
| """ | ||
| Configure LLM access and model types. Use either vanilla Open AI, Azure Open AI, or Ollama. | ||
|
|
||
| TODO: What about Hugging Face, Runpod, vLLM, and others? | ||
|
|
||
| Notes about text embedding models: | ||
|
|
||
| > The new model, `text-embedding-ada-002`, replaces five separate models for text search, | ||
| > text similarity, and code search, and outperforms our previous most capable model, | ||
| > Davinci, at most tasks, while being priced 99.8% lower. | ||
|
|
||
| - https://openai.com/index/new-and-improved-embedding-model/ | ||
| - https://community.openai.com/t/models-embedding-vs-similarity-vs-search-models/291265 | ||
| """ | ||
|
|
||
| completion_model = info.completion | ||
| embedding_model = info.embedding or "text-embedding-3-large" | ||
|
|
||
| if not info.provider: | ||
| raise ValueError("LLM model type not defined") | ||
| if not completion_model: | ||
| raise ValueError("LLM model name not defined") | ||
|
|
||
| # https://docs.llamaindex.ai/en/stable/understanding/tracing_and_debugging/tracing_and_debugging/ | ||
| if debug: | ||
| llama_index.core.set_global_handler("simple") | ||
|
|
||
| if info.provider is ModelProvider.OPENAI: | ||
| llm = OpenAI( | ||
| model=completion_model, | ||
| temperature=0.0, | ||
| api_key=info.api_key, | ||
| api_version=info.api_version, | ||
| ) | ||
| elif info.provider is ModelProvider.AZURE: | ||
| llm = AzureOpenAI( | ||
| model=completion_model, | ||
| temperature=0.0, | ||
| engine=info.instance, | ||
| azure_endpoint=info.endpoint, | ||
| api_key=info.api_key, | ||
| api_version=info.api_version, | ||
| ) | ||
|
Comment on lines
+45
to
+59
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: cat -n cratedb_toolkit/query/llm/util.py | head -100Repository: crate/cratedb-toolkit Length of output: 3759 🏁 Script executed: cat -n cratedb_toolkit/query/llm/model.pyRepository: crate/cratedb-toolkit Length of output: 906 Honor
Proposed fix if info.type is ModelType.OPENAI:
llm = OpenAI(
model=completion_model,
temperature=0.0,
- api_key=os.getenv("OPENAI_API_KEY"),
+ api_key=info.api_key or os.getenv("OPENAI_API_KEY"),
)
@@
llm = AzureOpenAI(
model=completion_model,
temperature=0.0,
- engine=os.getenv("LLM_INSTANCE"),
- azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT"),
- api_key=os.getenv("OPENAI_API_KEY"),
- api_version=os.getenv("OPENAI_AZURE_API_VERSION"),
+ engine=info.instance or os.getenv("LLM_INSTANCE"),
+ azure_endpoint=info.endpoint or os.getenv("OPENAI_AZURE_ENDPOINT"),
+ api_key=info.api_key or os.getenv("OPENAI_API_KEY"),
+ api_version=info.api_version or os.getenv("OPENAI_AZURE_API_VERSION"),
)
@@
AzureOpenAIEmbeddings(
- azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT"),
+ azure_endpoint=info.endpoint or os.getenv("OPENAI_AZURE_ENDPOINT"),
model=embedding_model,
)🤖 Prompt for AI Agents |
||
| elif info.provider is ModelProvider.OLLAMA: | ||
| # https://docs.llamaindex.ai/en/stable/api_reference/llms/ollama/ | ||
| llm = Ollama( | ||
| base_url=info.endpoint or "http://localhost:11434", | ||
| model=completion_model, | ||
| temperature=0.0, | ||
| request_timeout=120.0, | ||
| keep_alive=-1, | ||
| ) | ||
| else: | ||
| raise ValueError("LLM model type invalid: %s", info.provider) | ||
|
|
||
| if info.provider is ModelProvider.OPENAI: | ||
| embed_model = LangchainEmbedding(OpenAIEmbeddings(model=embedding_model)) | ||
| elif info.provider is ModelProvider.AZURE: | ||
| embed_model = LangchainEmbedding( | ||
| AzureOpenAIEmbeddings( | ||
| azure_endpoint=info.endpoint, | ||
| model=embedding_model, | ||
| ) | ||
| ) | ||
| else: | ||
| embed_model = "local" | ||
|
|
||
| return llm, embed_model # ty: ignore[invalid-return-type] | ||
Uh oh!
There was an error while loading. Please reload this page.