diff --git a/docker/pythonpath_dev/superset_config_docker_light.py b/docker/pythonpath_dev/superset_config_docker_light.py index 9a5ae0ae67ab..1f053c2ce363 100644 --- a/docker/pythonpath_dev/superset_config_docker_light.py +++ b/docker/pythonpath_dev/superset_config_docker_light.py @@ -19,6 +19,7 @@ # Import all settings from the main config first from flask_caching.backends.filesystemcache import FileSystemCache + from superset_config import * # noqa: F403 # Override caching to use simple in-memory cache instead of Redis diff --git a/pyproject.toml b/pyproject.toml index 0f1aa6faaf1b..85b78489fbae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ + "anthropic>=0.6.1, <1.0", # no bounds for apache-superset-core until we have a stable version "apache-superset-core", "backoff>=1.8.0", @@ -57,6 +58,7 @@ dependencies = [ "flask-session>=0.4.0, <1.0", "flask-wtf>=1.1.0, <2.0", "geopy", + "google-genai>=1.28.0, <2.0", "greenlet>=3.0.3, <=3.1.1", "gunicorn>=22.0.0; sys_platform != 'win32'", "hashids>=1.3.1, <2", @@ -73,6 +75,7 @@ dependencies = [ "msgpack>=1.0.0, <1.1", "nh3>=0.2.11, <0.3", "numpy>1.23.5, <2.3", + "openai>=1.99.1, <2", "packaging", # -------------------------- # pandas and related (wanting pandas[performance] without numba as it's 100+MB and not needed) @@ -103,6 +106,8 @@ dependencies = [ "sqlglot>=27.3.0, <28", # newer pandas needs 0.9+ "tabulate>=0.9.0, <1.0", + "sqlparse>=0.5.3, <1.0", + "tiktoken>=0.10.0, <1", "typing-extensions>=4, <5", "waitress; sys_platform == 'win32'", "watchdog>=6.0.0", @@ -407,6 +412,7 @@ authorized_licenses = [ "apache software, bsd", "bsd", "bsd-3-clause", + "cnri-python", "isc license (iscl)", "isc license", "mit", @@ -428,6 +434,8 @@ authorized_licenses = [ polyline = "2" # Apache 2.0 https://github.com/hkwi/python-geohash python-geohash = "0" +# Apache-2.0 AND CNRI-Python (both licenses individually approved above) +regex = "2025.7.34" # -------------------------------------------------------------- # TODO REMOVE THESE DEPS FROM CODEBASE diff --git a/requirements/base.txt b/requirements/base.txt index 3ff7c38950de..51b3558352eb 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -8,6 +8,14 @@ amqp==5.3.1 # via kombu annotated-types==0.7.0 # via pydantic +anthropic==0.61.0 + # via apache-superset (pyproject.toml) +anyio==4.10.0 + # via + # anthropic + # google-genai + # httpx + # openai apispec==6.6.1 # via # -r requirements/base.in @@ -15,7 +23,9 @@ apispec==6.6.1 apsw==3.50.1.0 # via shillelagh async-timeout==4.0.3 - # via -r requirements/base.in + # via + # -r requirements/base.in + # redis attrs==25.3.0 # via # cattrs @@ -50,6 +60,8 @@ celery==5.5.2 # via apache-superset (pyproject.toml) certifi==2025.6.15 # via + # httpcore + # httpx # requests # selenium cffi==1.17.1 @@ -95,12 +107,22 @@ deprecated==1.2.18 # via limits deprecation==2.1.0 # via apache-superset (pyproject.toml) +distro==1.9.0 + # via + # anthropic + # openai dnspython==2.7.0 # via email-validator email-validator==2.2.0 # via flask-appbuilder et-xmlfile==2.0.0 # via openpyxl +exceptiongroup==1.3.0 + # via + # anyio + # cattrs + # trio + # trio-websocket flask==2.3.3 # via # apache-superset (pyproject.toml) @@ -155,25 +177,39 @@ geographiclib==2.0 geopy==2.4.1 # via apache-superset (pyproject.toml) google-auth==2.40.3 - # via shillelagh + # via + # google-genai + # shillelagh +google-genai==1.28.0 + # via apache-superset (pyproject.toml) greenlet==3.1.1 
# via # apache-superset (pyproject.toml) # shillelagh - # sqlalchemy gunicorn==23.0.0 # via apache-superset (pyproject.toml) h11==0.16.0 - # via wsproto + # via + # httpcore + # wsproto hashids==1.3.1 # via apache-superset (pyproject.toml) holidays==0.25 # via apache-superset (pyproject.toml) +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # anthropic + # google-genai + # openai humanize==4.12.3 # via apache-superset (pyproject.toml) idna==3.10 # via + # anyio # email-validator + # httpx # requests # trio # url-normalize @@ -187,6 +223,10 @@ jinja2==3.1.6 # via # flask # flask-babel +jiter==0.10.0 + # via + # anthropic + # openai jsonpath-ng==1.7.0 # via apache-superset (pyproject.toml) jsonschema==4.23.0 @@ -248,6 +288,8 @@ numpy==1.26.4 # pyarrow odfpy==1.4.1 # via pandas +openai==1.99.1 + # via apache-superset (pyproject.toml) openapi-schema-validator==0.6.3 # via -r requirements/base.in openpyxl==3.1.5 @@ -300,7 +342,11 @@ pyasn1-modules==0.4.2 pycparser==2.22 # via cffi pydantic==2.11.7 - # via apache-superset (pyproject.toml) + # via + # apache-superset (pyproject.toml) + # anthropic + # google-genai + # openai pydantic-core==2.33.2 # via pydantic pygments==2.19.1 @@ -348,10 +394,14 @@ referencing==0.36.2 # via # jsonschema # jsonschema-specifications +regex==2025.7.34 + # via tiktoken requests==2.32.4 # via + # google-genai # requests-cache # shillelagh + # tiktoken requests-cache==1.2.1 # via shillelagh rfc3339-validator==0.1.4 @@ -379,7 +429,11 @@ six==1.17.0 slack-sdk==3.35.0 # via apache-superset (pyproject.toml) sniffio==1.3.1 - # via trio + # via + # anthropic + # anyio + # openai + # trio sortedcontainers==2.4.0 # via trio sqlalchemy==1.4.54 @@ -397,10 +451,18 @@ sqlalchemy-utils==0.38.3 # flask-appbuilder sqlglot==27.3.0 # via apache-superset (pyproject.toml) +sqlparse==0.5.3 + # via apache-superset (pyproject.toml) sshtunnel==0.4.0 # via apache-superset (pyproject.toml) tabulate==0.9.0 # via apache-superset (pyproject.toml) +tenacity==8.5.0 + # via google-genai +tiktoken==0.10.0 + # via apache-superset (pyproject.toml) +tqdm==4.67.1 + # via openai trio==0.30.0 # via # selenium @@ -411,12 +473,18 @@ typing-extensions==4.14.0 # via # apache-superset (pyproject.toml) # alembic + # anthropic + # anyio # cattrs + # exceptiongroup + # google-genai # limits + # openai # pydantic # pydantic-core # pyopenssl # referencing + # rich # selenium # shillelagh # typing-inspection @@ -445,6 +513,8 @@ wcwidth==0.2.13 # via prompt-toolkit websocket-client==1.8.0 # via selenium +websockets==15.0.1 + # via google-genai werkzeug==3.1.3 # via # -r requirements/base.in diff --git a/requirements/development.txt b/requirements/development.txt index 317ef4e119a8..e321e97024d5 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -22,6 +22,17 @@ annotated-types==0.7.0 # via # -c requirements/base-constraint.txt # pydantic +anthropic==0.61.0 + # via + # -c requirements/base-constraint.txt + # apache-superset +anyio==4.10.0 + # via + # -c requirements/base-constraint.txt + # anthropic + # google-genai + # httpx + # openai apispec==6.6.1 # via # -c requirements/base-constraint.txt @@ -32,6 +43,10 @@ apsw==3.50.1.0 # shillelagh astroid==3.3.10 # via pylint +async-timeout==4.0.3 + # via + # -c requirements/base-constraint.txt + # redis attrs==25.3.0 # via # -c requirements/base-constraint.txt @@ -89,6 +104,8 @@ celery==5.5.2 certifi==2025.6.15 # via # -c requirements/base-constraint.txt + # httpcore + # httpx # requests # selenium cffi==1.17.1 @@ -175,6 +192,11 @@ 
dill==0.4.0 # via pylint distlib==0.3.8 # via virtualenv +distro==1.9.0 + # via + # -c requirements/base-constraint.txt + # anthropic + # openai dnspython==2.7.0 # via # -c requirements/base-constraint.txt @@ -193,6 +215,14 @@ et-xmlfile==2.0.0 # via # -c requirements/base-constraint.txt # openpyxl +exceptiongroup==1.3.0 + # via + # -c requirements/base-constraint.txt + # anyio + # cattrs + # pytest + # trio + # trio-websocket filelock==3.12.2 # via virtualenv flask==2.3.3 @@ -300,6 +330,7 @@ google-auth==2.40.3 # google-auth-oauthlib # google-cloud-bigquery # google-cloud-core + # google-genai # pandas-gbq # pydata-google-auth # shillelagh @@ -319,6 +350,10 @@ google-cloud-core==2.4.1 # via google-cloud-bigquery google-crc32c==1.6.0 # via google-resumable-media +google-genai==1.28.0 + # via + # -c requirements/base-constraint.txt + # apache-superset google-resumable-media==2.7.2 # via google-cloud-bigquery googleapis-common-protos==1.66.0 @@ -331,7 +366,6 @@ greenlet==3.1.1 # apache-superset # gevent # shillelagh - # sqlalchemy grpcio==1.71.0 # via # apache-superset @@ -346,6 +380,7 @@ gunicorn==23.0.0 h11==0.16.0 # via # -c requirements/base-constraint.txt + # httpcore # wsproto hashids==1.3.1 # via @@ -356,6 +391,16 @@ holidays==0.25 # -c requirements/base-constraint.txt # apache-superset # prophet +httpcore==1.0.9 + # via + # -c requirements/base-constraint.txt + # httpx +httpx==0.28.1 + # via + # -c requirements/base-constraint.txt + # anthropic + # google-genai + # openai humanize==4.12.3 # via # -c requirements/base-constraint.txt @@ -365,7 +410,9 @@ identify==2.5.36 idna==3.10 # via # -c requirements/base-constraint.txt + # anyio # email-validator + # httpx # requests # trio # url-normalize @@ -390,6 +437,11 @@ jinja2==3.1.6 # apache-superset-extensions-cli # flask # flask-babel +jiter==0.10.0 + # via + # -c requirements/base-constraint.txt + # anthropic + # openai jsonpath-ng==1.7.0 # via # -c requirements/base-constraint.txt @@ -501,6 +553,10 @@ odfpy==1.4.1 # via # -c requirements/base-constraint.txt # pandas +openai==1.99.1 + # via + # -c requirements/base-constraint.txt + # apache-superset openapi-schema-validator==0.6.3 # via # -c requirements/base-constraint.txt @@ -640,7 +696,10 @@ pycparser==2.22 pydantic==2.11.7 # via # -c requirements/base-constraint.txt + # anthropic # apache-superset + # google-genai + # openai pydantic-core==2.33.2 # via # -c requirements/base-constraint.txt @@ -751,18 +810,24 @@ referencing==0.36.2 # jsonschema # jsonschema-path # jsonschema-specifications +regex==2025.7.34 + # via + # -c requirements/base-constraint.txt + # tiktoken requests==2.32.4 # via # -c requirements/base-constraint.txt # docker # google-api-core # google-cloud-bigquery + # google-genai # jsonschema-path # pydruid # pyhive # requests-cache # requests-oauthlib # shillelagh + # tiktoken # trino requests-cache==1.2.1 # via @@ -824,6 +889,9 @@ slack-sdk==3.35.0 sniffio==1.3.1 # via # -c requirements/base-constraint.txt + # anthropic + # anyio + # openai # trio sortedcontainers==2.4.0 # via @@ -854,6 +922,10 @@ sqlglot==27.3.0 # apache-superset sqloxide==0.1.51 # via apache-superset +sqlparse==0.5.3 + # via + # -c requirements/base-constraint.txt + # apache-superset sshtunnel==0.4.0 # via # -c requirements/base-constraint.txt @@ -864,11 +936,27 @@ tabulate==0.9.0 # via # -c requirements/base-constraint.txt # apache-superset +tenacity==8.5.0 + # via + # -c requirements/base-constraint.txt + # google-genai +tiktoken==0.10.0 + # via + # -c requirements/base-constraint.txt + # 
apache-superset +tomli==2.2.1 + # via + # apache-superset-extensions-cli + # coverage + # pylint + # pytest tomlkit==0.13.3 # via pylint tqdm==4.67.1 # via + # -c requirements/base-constraint.txt # cmdstanpy + # openai # prophet trino==0.330.0 # via apache-superset @@ -885,13 +973,20 @@ typing-extensions==4.14.0 # via # -c requirements/base-constraint.txt # alembic + # anthropic + # anyio # apache-superset + # astroid # cattrs + # exceptiongroup + # google-genai # limits + # openai # pydantic # pydantic-core # pyopenssl # referencing + # rich # selenium # shillelagh # typing-inspection @@ -938,6 +1033,10 @@ websocket-client==1.8.0 # via # -c requirements/base-constraint.txt # selenium +websockets==15.0.1 + # via + # -c requirements/base-constraint.txt + # google-genai werkzeug==3.1.3 # via # -c requirements/base-constraint.txt diff --git a/superset-frontend/packages/superset-ui-core/src/components/Icons/AntdEnhanced.tsx b/superset-frontend/packages/superset-ui-core/src/components/Icons/AntdEnhanced.tsx index fea1eec7700c..bd8fe30f0e99 100644 --- a/superset-frontend/packages/superset-ui-core/src/components/Icons/AntdEnhanced.tsx +++ b/superset-frontend/packages/superset-ui-core/src/components/Icons/AntdEnhanced.tsx @@ -37,6 +37,7 @@ import { CalculatorOutlined, CaretUpOutlined, CaretDownOutlined, + CaretDownFilled, CaretLeftOutlined, CaretRightOutlined, CaretRightFilled, @@ -173,6 +174,7 @@ const AntdIcons = { CalculatorOutlined, CaretUpOutlined, CaretDownOutlined, + CaretDownFilled, CaretLeftOutlined, CaretRightOutlined, CaretRightFilled, diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 019a1fe3ef98..515b26a3301e 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -104,6 +104,11 @@ export const SET_EDITOR_TAB_LAST_UPDATE = 'SET_EDITOR_TAB_LAST_UPDATE'; export const SET_LAST_UPDATED_ACTIVE_TAB = 'SET_LAST_UPDATED_ACTIVE_TAB'; export const CLEAR_DESTROYED_QUERY_EDITOR = 'CLEAR_DESTROYED_QUERY_EDITOR'; +export const GENERATE_SQL = 'GENERATE_SQL'; +export const START_GENERATE_SQL = 'START_GENERATE_SQL'; +export const GENERATE_SQL_DONE = 'GENERATE_SQL_DONE'; +export const GENERATE_SQL_SET_PROMPT = 'GENERATE_SQL_SET_PROMPT'; + export const addInfoToast = addInfoToastAction; export const addSuccessToast = addSuccessToastAction; export const addDangerToast = addDangerToastAction; @@ -332,6 +337,111 @@ export function fetchQueryResults(query, displayLimit, timeoutInMs) { }; } +export function queryEditorSetSql(queryEditor, sql, queryId) { + return { type: QUERY_EDITOR_SET_SQL, queryEditor, sql, queryId }; +} + +function convertSqlToComment(sql) { + const oldSql = sql.split('\n'); + let commentedSql = ''; + + if (sql.trim() !== '') { + const contextBuilder = []; + for (let i = 0; i < oldSql.length; i += 1) { + if (oldSql[i].startsWith('--')) { + contextBuilder.push(oldSql[i]); + } else if (i === oldSql.length - 1 && oldSql[i].trim() === '') { + continue; + } else { + contextBuilder.push(`-- ${oldSql[i]}`); + } + } + commentedSql = `${contextBuilder.join('\n')}\n`; + } + return commentedSql; +} + +export function queryEditorSetAndSaveSql(targetQueryEditor, sql, queryId) { + return function (dispatch, getState) { + const queryEditor = getUpToDateQuery(getState(), targetQueryEditor); + // saved query and set tab state use this action + dispatch(queryEditorSetSql(queryEditor, sql, queryId)); + if (isFeatureEnabled(FeatureFlag.SqllabBackendPersistence)) { + return 
SupersetClient.put({ + endpoint: encodeURI(`/tabstateview/${queryEditor.id}`), + postPayload: { sql, latest_query_id: queryId }, + }).catch(() => + dispatch( + addDangerToast( + t( + 'An error occurred while storing your query in the backend. To ' + + 'avoid losing your changes, please save your query using the ' + + '"Save Query" button.', + ), + ), + ), + ); + } + return Promise.resolve(); + }; +} + +export function generateSql(databaseId, queryEditor, prompt) { + return function (dispatch, getState) { + dispatch({ + type: START_GENERATE_SQL, + queryEditorId: queryEditor.id, + prompt, + }); + const { sql } = getUpToDateQuery(getState(), queryEditor); + return SupersetClient.post({ + endpoint: '/api/v1/sqllab/generate_sql/', + body: JSON.stringify({ + database_id: databaseId, + user_prompt: prompt, + prior_context: sql, + schemas: [queryEditor.schema], + }), + headers: { 'Content-Type': 'application/json' }, + }) + .then(({ json }) => { + const oldContext = convertSqlToComment(sql); + const newQuestion = `-- ${prompt}\n\n`; + const newSql = [ + oldContext === '' ? '' : `${oldContext}\n`, + newQuestion, + json.sql, + ].join(''); + + // TODO(AW): Is it better to dispatch two events here, or have the DONE event dispatch its own event? + dispatch(queryEditorSetAndSaveSql(queryEditor, newSql)); + dispatch({ + type: GENERATE_SQL_DONE, + queryEditorId: queryEditor.id, + prompt: '', + }); + // TODO(AW): Formatting the query makes the response from the LLM easier to read + // but messes up the formatting of the question and previous query. + // dispatch(formatQuery(queryEditor)); + }) + .catch(() => { + // TODO(AW): Same question as above - should we try to combine these two events? + dispatch( + addDangerToast(t('An error occurred while generating the SQL')), + ); + dispatch({ + type: GENERATE_SQL_DONE, + queryEditorId: queryEditor.id, + prompt, + }); + }); + }; +} + +export function setGenerateSqlPrompt(queryEditorId, prompt) { + return { type: GENERATE_SQL_SET_PROMPT, queryEditorId, prompt }; +} + export function runQuery(query, runPreviewOnly) { return function (dispatch) { dispatch(startQuery(query, runPreviewOnly)); @@ -341,7 +451,11 @@ export function runQuery(query, runPreviewOnly) { json: true, runAsync: query.runAsync, catalog: query.catalog, - schema: query.schema, + schema: Array.isArray(query.schema) + ? query.schema.length > 0 + ? query.schema[0] + : '' + : query.schema, sql: query.sql, sql_editor_id: query.sqlEditorId, tab: query.tab, @@ -398,7 +512,11 @@ export function runQueryFromSqlEditor( immutableId: qe.immutableId, tab: qe.name, catalog: qe.catalog, - schema: qe.schema, + schema: Array.isArray(qe.schema) + ? qe.schema.length > 0 + ? 
qe.schema[0] + : '' + : qe.schema, tempTable, templateParams: qe.templateParams, queryLimit: qe.queryLimit || defaultQueryLimit, @@ -878,41 +996,10 @@ export function updateSavedQuery(query, clientId) { }) .then(() => dispatch(updateQueryEditor(query))); } - -export function queryEditorSetSql(queryEditor, sql, queryId) { - return { type: QUERY_EDITOR_SET_SQL, queryEditor, sql, queryId }; -} - export function queryEditorSetCursorPosition(queryEditor, position) { return { type: QUERY_EDITOR_SET_CURSOR_POSITION, queryEditor, position }; } -export function queryEditorSetAndSaveSql(targetQueryEditor, sql, queryId) { - return function (dispatch, getState) { - const queryEditor = getUpToDateQuery(getState(), targetQueryEditor); - // saved query and set tab state use this action - dispatch(queryEditorSetSql(queryEditor, sql, queryId)); - const queryEditorId = queryEditor.tabViewId ?? queryEditor.id; - if (isFeatureEnabled(FeatureFlag.SqllabBackendPersistence)) { - return SupersetClient.put({ - endpoint: encodeURI(`/tabstateview/${queryEditorId}`), - postPayload: { sql, latest_query_id: queryId }, - }).catch(() => - dispatch( - addDangerToast( - t( - 'An error occurred while storing your query in the backend. To ' + - 'avoid losing your changes, please save your query using the ' + - '"Save Query" button.', - ), - ), - ), - ); - } - return Promise.resolve(); - }; -} - export function formatQuery(queryEditor) { return function (dispatch, getState) { const { sql } = getUpToDateQuery(getState(), queryEditor); diff --git a/superset-frontend/src/SqlLab/components/AIAssistantEditor/index.tsx b/superset-frontend/src/SqlLab/components/AIAssistantEditor/index.tsx new file mode 100644 index 000000000000..d790b064ff18 --- /dev/null +++ b/superset-frontend/src/SqlLab/components/AIAssistantEditor/index.tsx @@ -0,0 +1,176 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { ChangeEvent } from 'react'; + +import { useDispatch, useSelector } from 'react-redux'; + +import { css, t, styled } from '@superset-ui/core'; + +import { Button, Icons, Input } from '@superset-ui/core/components'; +import { LOG_ACTIONS_AI_ASSISTANT_OPENED } from 'src/logger/LogUtils'; +import useLogAction from 'src/logger/useLogAction'; + +import { setGenerateSqlPrompt } from 'src/SqlLab/actions/sqlLab'; +import { SqlLabRootState } from 'src/SqlLab/types'; + +export interface AiAssistantEditorProps { + queryEditorId: string; + onGenerateSql: (prompt: string) => void; + isGeneratingSql: boolean; + schema?: string | string[]; + disabledMessage?: string; +} + +const StyledButton = styled.span` + button { + line-height: 20px; + } +`; + +const StyledToolbar = styled.div` + ${({ theme }) => css` + padding: ${theme.sizeUnit * 2}px; + background: ${theme.colors.grayscale.light5}; + border: 1px solid ${theme.colors.grayscale.light2}; + border-bottom: 0; + margin-bottom: 0; + + .assist-input { + display: flex; + justify-content: space-between; + align-items: center; + column-gap: ${theme.sizeUnit}px; + } + + form { + margin-block-end: 0; + } + + .label { + width: ${theme.sizeUnit * 25}px; + height: 100%; + color: ${theme.colors.grayscale.base}; + font-size: ${theme.fontSize}px; + } + `} +`; + +const DisabledMessage = styled.div` + ${({ theme }) => css` + color: ${theme.colors.error.base}; + margin-top: ${theme.sizeUnit * 2}px; + margin-left: ${theme.sizeUnit * 2}px; + font-size: ${theme.fontSizeSM}px; + padding: ${theme.sizeUnit * 2}px; + `} +`; + +const SelectedSchemaMessage = styled.div` + ${({ theme }) => css` + color: ${theme.colors.grayscale.base}; + margin-top: ${theme.sizeUnit * 2}px; + margin-left: ${theme.sizeUnit * 2}px; + font-size: ${theme.fontSizeSM}px; + padding: ${theme.sizeUnit * 2}px; + display: flex; + align-items: center; + column-gap: ${theme.sizeUnit}px; + `} +`; + +const onClick = ( + logAction: (name: string, payload: Record) => void, +): void => { + logAction(LOG_ACTIONS_AI_ASSISTANT_OPENED, { shortcut: false }); +}; + +const AiAssistantEditor = ({ + queryEditorId, + onGenerateSql, + isGeneratingSql = false, + schema = [], + disabledMessage, +}: AiAssistantEditorProps) => { + const dispatch = useDispatch(); + const logAction = useLogAction({ queryEditorId }); + + const changePrompt = (event: ChangeEvent) => { + dispatch(setGenerateSqlPrompt(queryEditorId, event.target.value)); + }; + + const prompt = useSelector((state: SqlLabRootState) => { + const queryEditor = state.sqlLab.queryEditors.find( + qe => qe.id === queryEditorId, + ); + if (queryEditor) { + return queryEditor.queryGenerator?.prompt || ''; + } + return ''; + }); + + const isDisabled = isGeneratingSql || !!disabledMessage; + + return ( + +
+ AI Assist + { + if (!isDisabled && e.key === 'Enter') { + e.preventDefault(); + onClick(logAction); + onGenerateSql(prompt); + } + }} + /> + + + +
+      {disabledMessage ? (
+        <DisabledMessage>{disabledMessage}</DisabledMessage>
+      ) : schema && schema.length > 0 ? (
+        <SelectedSchemaMessage>
+          {`Selecting a schema will restrict the AI to generating SQL for only the selected schema. This will increase costs because the AI cache is skipped. Currently selected: ${Array.isArray(schema) ? schema.join(', ') : schema}`}
+        </SelectedSchemaMessage>
+      ) : null}
+    
+ ); +}; + +export default AiAssistantEditor; diff --git a/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx b/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx index bfa615765b60..948e796da07d 100644 --- a/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx +++ b/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx @@ -161,7 +161,7 @@ const AceEditorWrapper = ({ const { data: annotations } = useAnnotations({ dbId: queryEditor.dbId, catalog: queryEditor.catalog, - schema: queryEditor.schema, + // schema: queryEditor.schema, sql: currentSql, templateParams: queryEditor.templateParams, }); diff --git a/superset-frontend/src/SqlLab/components/AceEditorWrapper/useAnnotations.ts b/superset-frontend/src/SqlLab/components/AceEditorWrapper/useAnnotations.ts index c64605c39583..eb5dc9220434 100644 --- a/superset-frontend/src/SqlLab/components/AceEditorWrapper/useAnnotations.ts +++ b/superset-frontend/src/SqlLab/components/AceEditorWrapper/useAnnotations.ts @@ -55,7 +55,7 @@ export function useAnnotations(params: FetchValidationQueryParams) { const errorObj = (error ?? {}) as ClientErrorObject; let message = errorObj?.error || errorObj?.statusText || t('Unknown error'); - if (message.includes('CSRF token')) { + if (typeof message === 'string' && message.includes('CSRF token')) { message = t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT); } return { diff --git a/superset-frontend/src/SqlLab/components/AceEditorWrapper/useKeywords.test.ts b/superset-frontend/src/SqlLab/components/AceEditorWrapper/useKeywords.test.ts index e0f49b70e316..563903d5d90f 100644 --- a/superset-frontend/src/SqlLab/components/AceEditorWrapper/useKeywords.test.ts +++ b/superset-frontend/src/SqlLab/components/AceEditorWrapper/useKeywords.test.ts @@ -47,12 +47,14 @@ const fakeTableApiResult = { value: 'fake api result1', label: 'fake api label1', type: 'table', + schema: 'main', }, { id: 2, value: 'fake api result2', label: 'fake api label2', type: 'table', + schema: 'main', }, ], }; diff --git a/superset-frontend/src/SqlLab/components/ColumnElement/index.tsx b/superset-frontend/src/SqlLab/components/ColumnElement/index.tsx index 6150c6e3a792..865dd569d8fc 100644 --- a/superset-frontend/src/SqlLab/components/ColumnElement/index.tsx +++ b/superset-frontend/src/SqlLab/components/ColumnElement/index.tsx @@ -40,6 +40,10 @@ const StyledTooltip = (props: any) => { color: ${theme.colorBgLayout}; font-size: ${theme.fontSizeXS}px; } + + p { + text-align: left; + } } `} {...props} @@ -57,12 +61,14 @@ const iconMap = { pk: 'fa-key', fk: 'fa-link', index: 'fa-bookmark', + comment: 'fa-comment', }; const tooltipTitleMap = { pk: t('Primary key'), fk: t('Foreign key'), index: t('Index'), + comment: t('Comment'), }; export type ColumnKeyTypeType = keyof typeof tooltipTitleMap; @@ -72,6 +78,7 @@ interface ColumnElementProps { name: string; keys?: { type: ColumnKeyTypeType }[]; type: string; + comment?: string; }; } @@ -81,7 +88,7 @@ const NowrapDiv = styled.div` const ColumnElement = ({ column }: ColumnElementProps) => { let columnName: ReactNode = column.name; - let icons; + let icons: ReactNode[] = []; if (column.keys && column.keys.length > 0) { columnName = {column.name}; icons = column.keys.map((key, i) => ( @@ -104,6 +111,23 @@ const ColumnElement = ({ column }: ColumnElementProps) => { )); } + if (column.comment) { + icons.push( + + {tooltipTitleMap.comment} +
+

{column.comment}

+ + } + > + +
, + ); + } return (
diff --git a/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx b/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx index 1a11eb90c56d..3e64dcb6a6d4 100644 --- a/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx +++ b/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx @@ -66,6 +66,10 @@ import { Skeleton } from '@superset-ui/core/components/Skeleton'; import { Switch } from '@superset-ui/core/components/Switch'; import { Menu, MenuItemType } from '@superset-ui/core/components/Menu'; import { Icons } from '@superset-ui/core/components/Icons'; +import { + SavedContextStatus, + useLlmContextStatus, +} from 'src/hooks/apiResources'; import { detectOS } from 'src/utils/common'; import { addNewQueryEditor, @@ -79,6 +83,7 @@ import { queryEditorSetAndSaveSql, queryEditorSetTemplateParams, runQueryFromSqlEditor, + generateSql, saveQuery, addSavedQueryToTabState, scheduleQuery, @@ -96,6 +101,7 @@ import { INITIAL_NORTH_PERCENT, SET_QUERY_EDITOR_SQL_DEBOUNCE_MS, } from 'src/SqlLab/constants'; +import useQueryEditor from 'src/SqlLab/hooks/useQueryEditor'; import { getItem, LocalStorageKeys, @@ -126,6 +132,7 @@ import SqlEditorLeftBar from '../SqlEditorLeftBar'; import AceEditorWrapper from '../AceEditorWrapper'; import RunQueryActionButton from '../RunQueryActionButton'; import QueryLimitSelect from '../QueryLimitSelect'; +import AIAssistantEditor from '../AIAssistantEditor'; import KeyboardShortcutButton, { KEY_MAP, KeyboardShortcut, @@ -226,6 +233,23 @@ const SqlEditor: FC = ({ }) => { const theme = useTheme(); const dispatch = useDispatch(); + const storedQueryEditor = useQueryEditor(queryEditor.id, [ + 'dbId', + 'catalog', + 'schema', + ]); + const [savedLlmContext, setSavedLlmContext] = + useState(null); + const [contextError, setContextError] = useState(null); + useLlmContextStatus({ + dbId: storedQueryEditor.dbId, + onSuccess: result => { + if (result.context) { + setSavedLlmContext(result.context); + } + setContextError(result.error ? result.error.build_time : null); + }, + }); const { database, @@ -330,6 +354,15 @@ const SqlEditor: FC = ({ } }; + const runAiAssistant = useCallback( + (prompt: string) => { + if (database) { + dispatch(generateSql(database.id, storedQueryEditor, prompt)); + } + }, + [database, storedQueryEditor], + ); + useEffect(() => { if (autorun) { setAutorun(false); @@ -877,6 +910,33 @@ const SqlEditor: FC = ({ ); }; + const renderAiAssistantEditor = () => { + const disabledMessage = + savedLlmContext && contextError + ? t('Context build error; falling back to an older context') + : !savedLlmContext && contextError + ? t('AI Assistant is unavailable due to a context build error') + : !savedLlmContext && !contextError + ? t( + 'AI Assistant is unavailable - please try again in a few minutes', + ) + : undefined; + + return ( + database?.llm_connection?.enabled && ( + + ) + ); + }; + const handleCursorPositionChange = (newPosition: CursorPosition) => { dispatch(queryEditorSetCursorPosition(queryEditor, newPosition)); }; @@ -972,6 +1032,7 @@ const SqlEditor: FC = ({ /> )} {queryEditor.isDataset && renderDatasetWarning()} + {renderAiAssistantEditor()}
{({ height }) => diff --git a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/index.tsx b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/index.tsx index a51f312b6085..c1e17bf41f64 100644 --- a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/index.tsx +++ b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/index.tsx @@ -35,7 +35,10 @@ import { import { Button, EmptyState, Icons } from '@superset-ui/core/components'; import { type DatabaseObject } from 'src/components'; import { t, styled, css } from '@superset-ui/core'; -import { TableSelectorMultiple } from 'src/components/TableSelector'; +import { + TableSelectorMultiple, + TableValue, +} from 'src/components/TableSelector'; import useQueryEditor from 'src/SqlLab/hooks/useQueryEditor'; import { getItem, @@ -94,7 +97,7 @@ const SqlEditorLeftBar = ({ const tables = useMemo( () => allSelectedTables.filter( - table => table.dbId === dbId && table.schema === schema, + table => table.dbId === dbId, // && table.schema === schema, ), [allSelectedTables, dbId, schema], ); @@ -130,7 +133,7 @@ const SqlEditorLeftBar = ({ ); const onTablesChange = ( - tableNames: string[], + tableValues: TableValue[], catalogName: string | null, schemaName: string, ) => { @@ -139,8 +142,8 @@ const SqlEditorLeftBar = ({ } const currentTables = [...tables]; - const tablesToAdd = tableNames.filter(name => { - const index = currentTables.findIndex(table => table.name === name); + const tablesToAdd = tableValues.filter(tv => { + const index = currentTables.findIndex(table => table.name === tv.value); if (index >= 0) { currentTables.splice(index, 1); return false; @@ -149,8 +152,10 @@ const SqlEditorLeftBar = ({ return true; }); - tablesToAdd.forEach(tableName => { - dispatch(addTable(queryEditor, tableName, catalogName, schemaName)); + tablesToAdd.forEach(tableValue => { + dispatch( + addTable(queryEditor, tableValue.value, catalogName, tableValue.schema), + ); }); dispatch(removeTables(currentTables)); diff --git a/superset-frontend/src/SqlLab/components/TableElement/index.tsx b/superset-frontend/src/SqlLab/components/TableElement/index.tsx index 1bef2037d981..8b5cf1232428 100644 --- a/superset-frontend/src/SqlLab/components/TableElement/index.tsx +++ b/superset-frontend/src/SqlLab/components/TableElement/index.tsx @@ -344,7 +344,9 @@ const TableElement = ({ table, ...props }: TableElementProps) => { ref={tableNameRef} className="table-name" > - {name} + + {schema}.{name} +
diff --git a/superset-frontend/src/SqlLab/reducers/getInitialState.ts b/superset-frontend/src/SqlLab/reducers/getInitialState.ts index 7d9f0fddacad..fef8a62424dc 100644 --- a/superset-frontend/src/SqlLab/reducers/getInitialState.ts +++ b/superset-frontend/src/SqlLab/reducers/getInitialState.ts @@ -278,6 +278,10 @@ export default function getInitialState({ unsavedQueryEditor, lastUpdatedActiveTab, destroyedQueryEditors, + queryGenerator: { + isGeneratingQuery: false, + prompt: '', + }, }, localStorageUsageInKilobytes: 0, common, diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js index a1e0210d7900..9c24e109fbe6 100644 --- a/superset-frontend/src/SqlLab/reducers/sqlLab.js +++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js @@ -420,6 +420,61 @@ export default function sqlLabReducer(state = {}, action) { }; return alterInObject(state, 'queries', action.query, alts); }, + [actions.START_GENERATE_SQL]() { + const queryEditors = state.queryEditors.map(qe => { + if (qe.id === action.queryEditorId) { + return { + ...qe, + queryGenerator: { + ...qe.queryGenerator, + isGeneratingQuery: true, + }, + }; + } + return qe; + }); + return { + ...state, + queryEditors, + }; + }, + [actions.GENERATE_SQL_DONE]() { + const queryEditors = state.queryEditors.map(qe => { + if (qe.id === action.queryEditorId) { + return { + ...qe, + queryGenerator: { + ...qe.queryGenerator, + isGeneratingQuery: false, + prompt: action.prompt, + }, + }; + } + return qe; + }); + return { + ...state, + queryEditors, + }; + }, + [actions.GENERATE_SQL_SET_PROMPT]() { + const queryEditors = state.queryEditors.map(qe => { + if (qe.id === action.queryEditorId) { + return { + ...qe, + queryGenerator: { + ...qe.queryGenerator, + prompt: action.prompt, + }, + }; + } + return qe; + }); + return { + ...state, + queryEditors, + }; + }, [actions.SET_ACTIVE_QUERY_EDITOR]() { const qeIds = state.queryEditors.map(qe => qe.id); if ( diff --git a/superset-frontend/src/SqlLab/types.ts b/superset-frontend/src/SqlLab/types.ts index 5532f155b746..886ebef48d28 100644 --- a/superset-frontend/src/SqlLab/types.ts +++ b/superset-frontend/src/SqlLab/types.ts @@ -71,6 +71,7 @@ export interface QueryEditor { updatedAt?: number; cursorPosition?: CursorPosition; isDataset?: boolean; + queryGenerator?: QueryGenerator; tabViewId?: string; } @@ -98,6 +99,11 @@ export interface Table { persistData?: TableMetaData; } +export interface QueryGenerator { + isGeneratingQuery: boolean; + prompt: string; +} + export type SqlLabRootState = { sqlLab: { activeSouthPaneTab: string | number; // default is string; action.newQuery.id is number diff --git a/superset-frontend/src/SqlLab/utils/reduxStateToLocalStorageHelper.ts b/superset-frontend/src/SqlLab/utils/reduxStateToLocalStorageHelper.ts index 683b082b38cb..4737147b5f49 100644 --- a/superset-frontend/src/SqlLab/utils/reduxStateToLocalStorageHelper.ts +++ b/superset-frontend/src/SqlLab/utils/reduxStateToLocalStorageHelper.ts @@ -22,12 +22,7 @@ import type { QueryResponse, QueryResults, } from '@superset-ui/core'; -import type { - CursorPosition, - QueryEditor, - SqlLabRootState, - Table, -} from 'src/SqlLab/types'; +import type { QueryEditor, SqlLabRootState, Table } from 'src/SqlLab/types'; import type { ThunkDispatch } from 'redux-thunk'; import { pick } from 'lodash'; import { tableApiUtil } from 'src/hooks/apiResources/tables'; @@ -119,17 +114,12 @@ export function clearQueryEditors(queryEditors: QueryEditor[]) { // only return selected keys Object.keys(editor) 
.filter(key => PERSISTENT_QUERY_EDITOR_KEYS.has(key)) - .reduce< - Record< - string, - string | number | boolean | CursorPosition | null | undefined - > - >( + .reduce>( (accumulator, key) => ({ ...accumulator, [key]: editor[key as keyof QueryEditor], }), - {}, + {} as Pick, ), ); } diff --git a/superset-frontend/src/components/DatabaseSelector/DatabaseSelector.test.tsx b/superset-frontend/src/components/DatabaseSelector/DatabaseSelector.test.tsx index f8d0425c5e68..54cf5b09f1e4 100644 --- a/superset-frontend/src/components/DatabaseSelector/DatabaseSelector.test.tsx +++ b/superset-frontend/src/components/DatabaseSelector/DatabaseSelector.test.tsx @@ -16,369 +16,452 @@ * specific language governing permissions and limitations * under the License. */ - -import fetchMock from 'fetch-mock'; import { - act, - defaultStore as store, - render, - screen, - userEvent, - waitFor, -} from 'spec/helpers/testing-library'; -import { api } from 'src/hooks/apiResources/queryApi'; -import { EmptyState } from '@superset-ui/core/components'; -import { DatabaseSelector } from '.'; -import type { DatabaseSelectorProps } from './types'; - -const createProps = (): DatabaseSelectorProps => ({ - db: { - id: 1, - database_name: 'test', - backend: 'test-postgresql', - }, - formMode: false, - isDatabaseSelectEnabled: true, - readOnly: false, - catalog: null, - schema: 'public', - sqlLabMode: true, - getDbList: jest.fn(), - handleError: jest.fn(), - onDbChange: jest.fn(), - onSchemaChange: jest.fn(), -}); - -const fakeDatabaseApiResult = { - count: 2, - description_columns: {}, - ids: [1, 2], - label_columns: { - allow_file_upload: 'Allow Csv Upload', - allow_ctas: 'Allow Ctas', - allow_cvas: 'Allow Cvas', - allow_dml: 'Allow DDL and DML', - allow_run_async: 'Allow Run Async', - allows_cost_estimate: 'Allows Cost Estimate', - allows_subquery: 'Allows Subquery', - allows_virtual_table_explore: 'Allows Virtual Table Explore', - disable_data_preview: 'Disables SQL Lab Data Preview', - disable_drill_to_detail: 'Disable Drill To Detail', - backend: 'Backend', - changed_on: 'Changed On', - changed_on_delta_humanized: 'Changed On Delta Humanized', - 'created_by.first_name': 'Created By First Name', - 'created_by.last_name': 'Created By Last Name', - database_name: 'Database Name', - explore_database_id: 'Explore Database Id', - expose_in_sqllab: 'Expose In Sqllab', - force_ctas_schema: 'Force Ctas Schema', - id: 'Id', - }, - list_columns: [ - 'allow_file_upload', - 'allow_ctas', - 'allow_cvas', - 'allow_dml', - 'allow_run_async', - 'allows_cost_estimate', - 'allows_subquery', - 'allows_virtual_table_explore', - 'disable_data_preview', - 'disable_drill_to_detail', - 'backend', - 'changed_on', - 'changed_on_delta_humanized', - 'created_by.first_name', - 'created_by.last_name', - 'database_name', - 'explore_database_id', - 'expose_in_sqllab', - 'force_ctas_schema', - 'id', - ], - list_title: 'List Database', - order_columns: [ - 'allow_file_upload', - 'allow_dml', - 'allow_run_async', - 'changed_on', - 'changed_on_delta_humanized', - 'created_by.first_name', - 'database_name', - 'expose_in_sqllab', - ], - result: [ - { - allow_file_upload: false, - allow_ctas: false, - allow_cvas: false, - allow_dml: false, - allow_run_async: false, - allows_cost_estimate: null, - allows_subquery: true, - allows_virtual_table_explore: true, - disable_data_preview: false, - disable_drill_to_detail: false, - backend: 'postgresql', - changed_on: '2021-03-09T19:02:07.141095', - changed_on_delta_humanized: 'a day ago', - created_by: null, - 
database_name: 'test-postgres', - explore_database_id: 1, - expose_in_sqllab: true, - force_ctas_schema: null, - id: 1, - }, - { - allow_csv_upload: false, - allow_ctas: false, - allow_cvas: false, - allow_dml: false, - allow_run_async: false, - allows_cost_estimate: null, - allows_subquery: true, - allows_virtual_table_explore: true, - disable_data_preview: false, - disable_drill_to_detail: false, - backend: 'mysql', - changed_on: '2021-03-09T19:02:07.141095', - changed_on_delta_humanized: 'a day ago', - created_by: null, - database_name: 'test-mysql', - explore_database_id: 1, - expose_in_sqllab: true, - force_ctas_schema: null, - id: 2, - }, - ], -}; - -const fakeDatabaseApiResultInReverseOrder = { - ...fakeDatabaseApiResult, - ids: [2, 1], - result: [...fakeDatabaseApiResult.result].reverse(), -}; - -const fakeSchemaApiResult = { - count: 2, - result: ['information_schema', 'public'], -}; - -const fakeCatalogApiResult = { - count: 0, - result: [], -}; - -const fakeFunctionNamesApiResult = { - function_names: [], -}; - -const databaseApiRoute = - 'glob:*/api/v1/database/?*order_column:database_name,order_direction*'; -const catalogApiRoute = 'glob:*/api/v1/database/*/catalogs/?*'; -const schemaApiRoute = 'glob:*/api/v1/database/*/schemas/?*'; -const tablesApiRoute = 'glob:*/api/v1/database/*/tables/*'; - -function setupFetchMock() { - fetchMock.get(databaseApiRoute, fakeDatabaseApiResult); - fetchMock.get(catalogApiRoute, fakeCatalogApiResult); - fetchMock.get(schemaApiRoute, fakeSchemaApiResult); - fetchMock.get(tablesApiRoute, fakeFunctionNamesApiResult); -} + ReactNode, + useState, + useMemo, + useEffect, + useRef, + useCallback, +} from 'react'; +import { styled, SupersetClient, SupersetError, t } from '@superset-ui/core'; +import rison from 'rison'; +import RefreshLabel from '@superset-ui/core/components/RefreshLabel'; +import { useToasts } from 'src/components/MessageToasts/withToasts'; +import { + useCatalogs, + CatalogOption, + useSchemas, + SchemaOption, +} from 'src/hooks/apiResources'; +import { + Select, + AsyncSelect, + Label, + FormLabel, + LabeledValue as AntdLabeledValue, +} from '@superset-ui/core/components'; -beforeEach(() => { - setupFetchMock(); -}); +import { ErrorMessageWithStackTrace } from 'src/components'; +import type { + DatabaseSelectorProps, + DatabaseValue, + DatabaseObject, +} from './types'; -afterEach(() => { - fetchMock.reset(); - act(() => { - store.dispatch(api.util.resetApiState()); - }); -}); +const DatabaseSelectorWrapper = styled.div` + ${({ theme }) => ` + .refresh { + display: flex; + align-items: center; + width: 30px; + margin-left: ${theme.sizeUnit}px; + margin-top: ${theme.sizeUnit * 5}px; + } -test('Should render', async () => { - const props = createProps(); - render(, { useRedux: true, store }); - expect(await screen.findByTestId('DatabaseSelector')).toBeInTheDocument(); -}); + .section { + display: flex; + flex-direction: row; + align-items: center; + } -test('Refresh should work', async () => { - const props = createProps(); + .select { + width: calc(100% - 30px - ${theme.sizeUnit}px); + flex: 1; + } - render(, { useRedux: true, store }); + & > div { + margin-bottom: ${theme.sizeUnit * 4}px; + } + `} +`; - expect(fetchMock.calls(schemaApiRoute).length).toBe(0); +const LabelStyle = styled.div` + display: flex; + flex-direction: row; + align-items: center; + margin-left: ${({ theme }) => theme.sizeUnit - 2}px; - const select = screen.getByRole('combobox', { - name: 'Select schema or type to search schemas: public', - }); - - await 
userEvent.click(select); + .backend { + overflow: visible; + } - await waitFor(() => { - expect(fetchMock.calls(databaseApiRoute).length).toBe(1); - expect(fetchMock.calls(schemaApiRoute).length).toBe(1); - expect(props.handleError).toHaveBeenCalledTimes(0); - expect(props.onDbChange).toHaveBeenCalledTimes(0); - expect(props.onSchemaChange).toHaveBeenCalledTimes(0); - }); + .name { + overflow: hidden; + text-overflow: ellipsis; + } +`; - // click schema reload - await userEvent.click(screen.getByRole('button', { name: 'sync' })); +const SelectLabel = ({ + backend, + databaseName, +}: { + backend?: string; + databaseName: string; +}) => ( + + + + {databaseName} + + +); - await waitFor(() => { - expect(fetchMock.calls(databaseApiRoute).length).toBe(1); - expect(fetchMock.calls(schemaApiRoute).length).toBe(2); - expect(props.handleError).toHaveBeenCalledTimes(0); - expect(props.onDbChange).toHaveBeenCalledTimes(0); - expect(props.onSchemaChange).toHaveBeenCalledTimes(0); - }); -}); +const EMPTY_CATALOG_OPTIONS: CatalogOption[] = []; +const EMPTY_SCHEMA_OPTIONS: SchemaOption[] = []; -test('Should database select display options', async () => { - const props = createProps(); - render(, { useRedux: true, store }); - const select = screen.getByRole('combobox', { - name: 'Select database or type to search databases', - }); - expect(select).toBeInTheDocument(); - await userEvent.click(select); - expect(await screen.findByText('test-mysql')).toBeInTheDocument(); -}); - -test('should display options in order of the api response', async () => { - fetchMock.get(databaseApiRoute, fakeDatabaseApiResultInReverseOrder, { - overwriteRoutes: true, - }); - const props = createProps(); - render(, { - useRedux: true, - store, - }); - const select = screen.getByRole('combobox', { - name: 'Select database or type to search databases', - }); - expect(select).toBeInTheDocument(); - await userEvent.click(select); - const options = await screen.findAllByRole('option'); +interface AntdLabeledValueWithOrder extends AntdLabeledValue { + order: number; +} - expect(options[0]).toHaveTextContent( - `${fakeDatabaseApiResultInReverseOrder.result[0].id}`, - ); - expect(options[1]).toHaveTextContent( - `${fakeDatabaseApiResultInReverseOrder.result[1].id}`, - ); -}); - -test('Should fetch the search keyword when total count exceeds initial options', async () => { - fetchMock.get( - databaseApiRoute, - { - ...fakeDatabaseApiResult, - count: fakeDatabaseApiResult.result.length + 1, - }, - { overwriteRoutes: true }, +export function DatabaseSelector({ + db, + formMode = false, + emptyState, + getDbList, + handleError, + isDatabaseSelectEnabled = true, + onDbChange, + onEmptyResults, + onCatalogChange, + catalog, + onSchemaChange, + schema, + readOnly = false, + sqlLabMode = false, + schemaSelectMode = 'single', +}: DatabaseSelectorProps) { + const showCatalogSelector = !!db?.allow_multi_catalog; + const [currentDb, setCurrentDb] = useState(); + const [errorPayload, setErrorPayload] = useState(); + const [currentCatalog, setCurrentCatalog] = useState< + CatalogOption | null | undefined + >(catalog ? 
{ label: catalog, value: catalog, title: catalog } : undefined); + const catalogRef = useRef(catalog); + catalogRef.current = catalog; + const [currentSchema, setCurrentSchema] = useState< + SchemaOption | SchemaOption[] | undefined + >(undefined); + const schemaRef = useRef(schema); + schemaRef.current = schema; + const { addSuccessToast } = useToasts(); + const sortComparator = useCallback( + (itemA: AntdLabeledValueWithOrder, itemB: AntdLabeledValueWithOrder) => + itemA.order - itemB.order, + [], ); - const props = createProps(); - render(, { useRedux: true, store }); - const select = screen.getByRole('combobox', { - name: 'Select database or type to search databases', - }); - await waitFor(() => - expect(fetchMock.calls(databaseApiRoute)).toHaveLength(1), - ); - expect(select).toBeInTheDocument(); - await userEvent.type(select, 'keywordtest'); - await waitFor(() => - expect(fetchMock.calls(databaseApiRoute)).toHaveLength(2), - ); - expect(fetchMock.calls(databaseApiRoute)[1][0]).toContain('keywordtest'); -}); - -test('should show empty state if there are no options', async () => { - fetchMock.reset(); - fetchMock.get(databaseApiRoute, { result: [] }); - fetchMock.get(schemaApiRoute, { result: [] }); - fetchMock.get(tablesApiRoute, { result: [] }); - const props = createProps(); - render( - } - />, - { useRedux: true, store }, - ); - const select = screen.getByRole('combobox', { - name: 'Select database or type to search databases', - }); - await userEvent.click(select); - const emptystate = await screen.findByText('empty'); - expect(emptystate).toBeInTheDocument(); - expect(screen.queryByText('test-mysql')).not.toBeInTheDocument(); -}); - -test('Should schema select display options', async () => { - const props = createProps(); - render(, { useRedux: true, store }); - const select = screen.getByRole('combobox', { - name: 'Select schema or type to search schemas: public', - }); - expect(select).toBeInTheDocument(); - await userEvent.click(select); - await waitFor(() => { - expect(screen.queryByText('Loading...')).not.toBeInTheDocument(); - }); - const publicOption = await screen.findByRole('option', { name: 'public' }); - expect(publicOption).toBeInTheDocument(); + useEffect(() => { + if (schemaSelectMode === 'single') { + setCurrentSchema( + schema && !Array.isArray(schema) + ? { label: schema, value: schema, title: schema } + : undefined, + ); + } else { + setCurrentSchema( + Array.isArray(schema) + ? schema.map(schema => ({ + label: schema, + value: schema, + title: schema, + })) + : typeof schema === 'string' && schema + ? 
[{ label: schema, value: schema, title: schema }] + : [], + ); + } + }, [schema]); - const infoSchemaOption = await screen.findByRole('option', { - name: 'information_schema', - }); - expect(infoSchemaOption).toBeInTheDocument(); -}); - -test('Sends the correct db when changing the database', async () => { - const props = createProps(); - render(, { useRedux: true, store }); - const select = screen.getByRole('combobox', { - name: 'Select database or type to search databases', - }); - expect(select).toBeInTheDocument(); - await userEvent.click(select); - await userEvent.click(await screen.findByText('test-mysql')); - await waitFor(() => - expect(props.onDbChange).toHaveBeenCalledWith( - expect.objectContaining({ - value: 2, - database_name: 'test-mysql', - backend: 'mysql', - }), - ), + const loadDatabases = useMemo( + () => + async ( + search: string, + page: number, + pageSize: number, + ): Promise<{ + data: DatabaseValue[]; + totalCount: number; + }> => { + const queryParams = rison.encode({ + order_column: 'database_name', + order_direction: 'asc', + page, + page_size: pageSize, + ...(formMode || !sqlLabMode + ? { filters: [{ col: 'database_name', opr: 'ct', value: search }] } + : { + filters: [ + { col: 'database_name', opr: 'ct', value: search }, + { + col: 'expose_in_sqllab', + opr: 'eq', + value: true, + }, + ], + }), + }); + const endpoint = `/api/v1/database/?q=${queryParams}`; + return SupersetClient.get({ endpoint }).then(({ json }) => { + const { result, count } = json; + if (getDbList) { + getDbList(result); + } + if (result.length === 0) { + if (onEmptyResults) onEmptyResults(search); + } + + const options = result.map((row: DatabaseObject, order: number) => ({ + label: ( + + ), + value: row.id, + id: `${row.backend}-${row.database_name}-${row.id}`, + database_name: row.database_name, + backend: row.backend, + allow_multi_catalog: row.allow_multi_catalog, + order, + })); + + return { + data: options, + totalCount: count ?? options.length, + }; + }); + }, + [formMode, getDbList, sqlLabMode, onEmptyResults], ); -}); -test('Sends the correct schema when changing the schema', async () => { - const props = createProps(); - const { rerender } = render(, { - useRedux: true, - store, + useEffect(() => { + setCurrentDb(current => + current?.id !== db?.id + ? db + ? 
{ + label: ( + + ), + value: db.id, + ...db, + } + : undefined + : current, + ); + }, [db]); + + function changeSchema(schema?: SchemaOption | SchemaOption[]) { + setCurrentSchema(schema); + if (Array.isArray(schema)) { + const schema_values = schema.map(schema => schema.value); + if (onSchemaChange && schema_values !== schemaRef.current) { + onSchemaChange(schema_values); + } + } else if (onSchemaChange) { + onSchemaChange(schema?.value); + } + } + + const { + currentData: schemaData, + isFetching: loadingSchemas, + refetch: refetchSchemas, + } = useSchemas({ + dbId: currentDb?.value, + catalog: currentCatalog?.value, + onSuccess: (schemas, isFetched) => { + setErrorPayload(null); + if (schemas.length === 1) { + changeSchema(schemas[0]); + } else if ( + !schemas.find(schemaOption => schemaRef.current === schemaOption.value) + ) { + changeSchema(undefined); + } + + if (isFetched) { + addSuccessToast('List refreshed'); + } + }, + onError: error => { + if (error?.errors) { + setErrorPayload(error?.errors?.[0]); + } else { + handleError(t('There was an error loading the schemas')); + } + }, }); - await waitFor(() => expect(fetchMock.calls(databaseApiRoute).length).toBe(1)); - rerender(); - expect(props.onSchemaChange).toHaveBeenCalledTimes(0); - const select = screen.getByRole('combobox', { - name: 'Select schema or type to search schemas: public', + + const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS; + + function changeCatalog(catalog: CatalogOption | null | undefined) { + setCurrentCatalog(catalog); + setCurrentSchema(schemaSelectMode === 'single' ? undefined : []); + if (onCatalogChange && catalog?.value !== catalogRef.current) { + onCatalogChange(catalog?.value); + } + } + + const { + data: catalogData, + isFetching: loadingCatalogs, + refetch: refetchCatalogs, + } = useCatalogs({ + dbId: showCatalogSelector ? currentDb?.value : undefined, + onSuccess: (catalogs, isFetched) => { + setErrorPayload(null); + if (!showCatalogSelector) { + changeCatalog(null); + } else if (catalogs.length === 1) { + changeCatalog(catalogs[0]); + } else if ( + !catalogs.find( + catalogOption => catalogRef.current === catalogOption.value, + ) + ) { + changeCatalog(undefined); + } + + if (showCatalogSelector && isFetched) { + addSuccessToast('List refreshed'); + } + }, + onError: error => { + if (showCatalogSelector) { + if (error?.errors) { + setErrorPayload(error?.errors?.[0]); + } else { + handleError(t('There was an error loading the catalogs')); + } + } + }, }); - expect(select).toBeInTheDocument(); - await userEvent.click(select); - const schemaOption = await screen.findByText('information_schema'); - await userEvent.click(schemaOption); - await waitFor(() => - expect(props.onSchemaChange).toHaveBeenCalledWith('information_schema'), + + const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS; + + function changeDatabase( + value: { label: string; value: number }, + database: DatabaseValue, + ) { + // the database id is actually stored in the value property; the ID is used + // for the DOM, so it can't be an integer + const databaseWithId = { ...database, id: database.value }; + setCurrentDb(databaseWithId); + setCurrentCatalog(undefined); + setCurrentSchema(schemaSelectMode === 'single' ? undefined : []); + if (onDbChange) { + onDbChange(databaseWithId); + } + if (onCatalogChange) { + onCatalogChange(undefined); + } + if (onSchemaChange) { + onSchemaChange(schemaSelectMode === 'single' ? undefined : []); + } + } + + function renderSelectRow(select: ReactNode, refreshBtn: ReactNode) { + return ( +
<div className="section">
        <span className="select">{select}</span>
        <span className="refresh">{refreshBtn}</span>
      </div>
+ ); + } + + function renderDatabaseSelect() { + return renderSelectRow( + {t('Database')}} + lazyLoading={false} + notFoundContent={emptyState} + onChange={changeDatabase} + value={currentDb} + placeholder={t('Select database or type to search databases')} + disabled={!isDatabaseSelectEnabled || readOnly} + options={loadDatabases} + sortComparator={sortComparator} + />, + null, + ); + } + + function renderCatalogSelect() { + const refreshIcon = !readOnly && ( + + ); + return renderSelectRow( + {t('Schema')}} + labelInValue + loading={loadingSchemas} + name="select-schema" + notFoundContent={t('No compatible schema found')} + placeholder={t('Select schema or type to search schemas')} + onChange={items => changeSchema(items as SchemaOption[])} + options={schemaOptions} + showSearch + value={currentSchema} + allowClear + mode={schemaSelectMode} + />, + refreshIcon, + ); + } + + function renderError() { + return errorPayload ? ( + + ) : null; + } + + return ( + + {renderDatabaseSelect()} + {renderError()} + {showCatalogSelector && renderCatalogSelect()} + {renderSchemaSelect()} + ); - expect(props.onSchemaChange).toHaveBeenCalledTimes(1); -}); +} + +export type { DatabaseObject }; diff --git a/superset-frontend/src/components/Datasource/components/DatasourceEditor/DatasourceEditor.jsx b/superset-frontend/src/components/Datasource/components/DatasourceEditor/DatasourceEditor.jsx index 30b499bed5c0..50f7cc765e85 100644 --- a/superset-frontend/src/components/Datasource/components/DatasourceEditor/DatasourceEditor.jsx +++ b/superset-frontend/src/components/Datasource/components/DatasourceEditor/DatasourceEditor.jsx @@ -1299,14 +1299,14 @@ class DatasourceEditor extends PureComponent { this.state.isEditMode && this.onDatasourcePropChange('catalog', catalog) } - onSchemaChange={schema => + onSchemaChange={schemas => this.state.isEditMode && - this.onDatasourcePropChange('schema', schema) + this.onDatasourcePropChange('schemas', schemas) } onDbChange={database => this.state.isEditMode && diff --git a/superset-frontend/src/components/TableSelector/TableSelector.test.tsx b/superset-frontend/src/components/TableSelector/TableSelector.test.tsx index a47c40731aaf..54e632159578 100644 --- a/superset-frontend/src/components/TableSelector/TableSelector.test.tsx +++ b/superset-frontend/src/components/TableSelector/TableSelector.test.tsx @@ -46,10 +46,10 @@ const getTableMockFunction = () => ({ count: 4, result: [ - { label: 'table_a', value: 'table_a' }, - { label: 'table_b', value: 'table_b' }, - { label: 'table_c', value: 'table_c' }, - { label: 'table_d', value: 'table_d' }, + { label: 'table_a', value: 'table_a', schema: 'test_schema' }, + { label: 'table_b', value: 'table_b', schema: 'test_schema' }, + { label: 'table_c', value: 'table_c', schema: 'test_schema' }, + { label: 'table_d', value: 'table_d', schema: 'test_schema' }, ], }) as any; @@ -118,8 +118,10 @@ test('skips select all options', async () => { const tableSelect = screen.getByRole('combobox', { name: 'Select table or type to search tables', }); - await userEvent.click(tableSelect); - expect(await screen.findByText('table_a')).toBeInTheDocument(); + userEvent.click(tableSelect); + expect( + await screen.findByRole('option', { name: 'test_schema.table_a' }), + ).toBeInTheDocument(); expect( screen.queryByRole('option', { name: /Select All/i }), ).not.toBeInTheDocument(); @@ -143,8 +145,8 @@ test('renders table options without Select All option', async () => { await waitFor( () => { - expect(screen.getByText('table_a')).toBeInTheDocument(); - 
expect(screen.getByText('table_b')).toBeInTheDocument(); + expect(screen.getByText('test_schema.table_a')).toBeInTheDocument(); + expect(screen.getByText('test_schema.table_b')).toBeInTheDocument(); }, { timeout: 10000 }, ); @@ -170,8 +172,16 @@ test('table select retain value if not in SQL Lab mode', async () => { expect(screen.queryByText('table_a')).not.toBeInTheDocument(); expect(getSelectItemContainer(tableSelect)).toHaveLength(0); - await act(async () => { - await userEvent.click(tableSelect); + userEvent.click(tableSelect); + + expect( + await screen.findByRole('option', { name: 'test_schema.table_a' }), + ).toBeInTheDocument(); + + await waitFor(() => { + const item = screen.getAllByText('table_a'); + userEvent.click(item[item.length - 1]); + // userEvent.click(screen.getAllByText('table_a')[1]); }); await waitFor( @@ -254,10 +264,15 @@ test('table multi select retain all the values selected', async () => { await userEvent.click(item[item.length - 1]); }); - const selections = await screen.findAllByRole('option', { selected: true }); - expect(selections).toHaveLength(2); - expect(selections[0]).toHaveTextContent('table_b'); - expect(selections[1]).toHaveTextContent('table_c'); + const selection1 = await screen.findByRole('option', { + name: 'test_schema.table_b', + }); + expect(selection1).toHaveAttribute('aria-selected', 'true'); + + const selection2 = await screen.findByRole('option', { + name: 'test_schema.table_c', + }); + expect(selection2).toHaveAttribute('aria-selected', 'true'); }); test('TableOption renders correct icons for different table types', () => { diff --git a/superset-frontend/src/components/TableSelector/index.tsx b/superset-frontend/src/components/TableSelector/index.tsx index 370eb3bc533d..34247db05cff 100644 --- a/superset-frontend/src/components/TableSelector/index.tsx +++ b/superset-frontend/src/components/TableSelector/index.tsx @@ -99,7 +99,7 @@ interface TableSelectorProps { isDatabaseSelectEnabled?: boolean; onDbChange?: (db: DatabaseObject) => void; onCatalogChange?: (catalog?: string | null) => void; - onSchemaChange?: (schema?: string) => void; + onSchemaChange?: (schema: string | string[]) => void; readOnly?: boolean; catalog?: string | null; schema?: string; @@ -107,14 +107,18 @@ interface TableSelectorProps { sqlLabMode?: boolean; tableValue?: string | string[]; onTableSelectChange?: ( - value?: string | string[], + value?: TableValue | TableValue[], catalog?: string | null, - schema?: string, ) => void; tableSelectMode?: 'single' | 'multiple'; customTableOptionLabelRenderer?: (table: Table) => JSX.Element; } +export interface TableValue { + value: string; + schema: string; +} + export interface TableOption { label: JSX.Element; text: string; @@ -122,9 +126,9 @@ export interface TableOption { } export const TableOption = ({ table }: { table: Table }) => { - const { value, type, extra } = table; + const { value, type, extra, schema } = table; return ( - + {type === 'view' ? ( ) : type === 'materialized_view' ? ( @@ -184,9 +188,9 @@ const TableSelector: FunctionComponent = ({ const [currentCatalog, setCurrentCatalog] = useState< string | null | undefined >(catalog); - const [currentSchema, setCurrentSchema] = useState( - schema, - ); + const [currentSchema, setCurrentSchema] = useState< + string | string[] | undefined + >(schema); const [tableSelectValue, setTableSelectValue] = useState< SelectValue | undefined >(undefined); @@ -219,7 +223,7 @@ const TableSelector: FunctionComponent = ({ () => data ? 
data.options.map(table => ({ - value: table.value, + value: `${table.schema}.${table.value}`, label: customTableOptionLabelRenderer ? ( customTableOptionLabelRenderer(table) ) : ( @@ -257,13 +261,18 @@ const TableSelector: FunctionComponent = ({ const internalTableChange = ( selectedOptions: TableOption | TableOption[] | undefined, ) => { + const parseOption = (option: TableOption): TableValue => { + const nameParts = option.value.split('.'); + return { value: nameParts[1], schema: nameParts[0] }; + }; if (currentSchema) { onTableSelectChange?.( Array.isArray(selectedOptions) - ? selectedOptions.map(option => option?.value) - : selectedOptions?.value, + ? selectedOptions.map(option => parseOption(option)) + : selectedOptions !== undefined + ? parseOption(selectedOptions) + : undefined, currentCatalog, - currentSchema, ); } else { setTableSelectValue(selectedOptions); @@ -292,7 +301,7 @@ const TableSelector: FunctionComponent = ({ setTableSelectValue(value); }; - const internalSchemaChange = (schema?: string) => { + const internalSchemaChange = (schema: string | string[]) => { setCurrentSchema(schema); if (onSchemaChange) { onSchemaChange(schema); @@ -369,6 +378,7 @@ const TableSelector: FunctionComponent = ({ sqlLabMode={sqlLabMode} isDatabaseSelectEnabled={isDatabaseSelectEnabled && !readOnly} readOnly={readOnly} + schemaSelectMode="multiple" /> {sqlLabMode && !formMode &&
} {renderTableSelect()} diff --git a/superset-frontend/src/explore/exploreUtils/index.js b/superset-frontend/src/explore/exploreUtils/index.js index d3bd62ef412a..45c7a63ac470 100644 --- a/superset-frontend/src/explore/exploreUtils/index.js +++ b/superset-frontend/src/explore/exploreUtils/index.js @@ -271,6 +271,7 @@ export const exportChart = async ({ parseMethod, }); } + console.log('exportChart payload', payload, ownState); SupersetClient.postForm(url, { form_data: safeStringify(payload) }); }; diff --git a/superset-frontend/src/features/databases/DatabaseModal/AIAssistantOptions.tsx b/superset-frontend/src/features/databases/DatabaseModal/AIAssistantOptions.tsx new file mode 100644 index 000000000000..8830f973d03d --- /dev/null +++ b/superset-frontend/src/features/databases/DatabaseModal/AIAssistantOptions.tsx @@ -0,0 +1,748 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { useEffect, useState, useMemo, useCallback, useRef } from 'react'; +import { t, SupersetClient, SupersetTheme } from '@superset-ui/core'; +import { + InfoTooltip, + Button, + Checkbox, + Collapse, + Select, + Switch, + CollapseLabelInModal, + Input, +} from '@superset-ui/core/components'; +import { + useDatabaseTables, + LlmDefaults, + SavedContextStatus, + useLlmContextStatus, + useLlmDefaults, + useCreateCustomLlmProviderMutation, + useTestCustomLlmProviderMutation, + CustomLlmProviderForm, +} from 'src/hooks/apiResources'; +import { + StyledContextError, + StyledContextWrapper, + StyledInputContainer, + StyledLlmSwitch, + StyledTokenEstimate, + StyledTopKForm, + wideButton, +} from './styles'; +import { DatabaseObject } from '../types'; +import SchemaSelector from './SchemaSelector'; + +const AIAssistantOptions = ({ + db, + onLlmConnectionChange, + onLlmContextOptionsChange, +}: { + db: DatabaseObject | null; + onLlmConnectionChange: Function; + onLlmContextOptionsChange: Function; +}) => { + const dbIdRef = useRef(null); + + // Update ref only when db exists and has an id + if (db?.id) { + dbIdRef.current = db.id; + } + + const [selectedProvider, setSelectedProvider] = useState( + db?.llm_connection?.provider || null, + ); + const [regenerating, setRegenerating] = useState(false); + const [savedContext, setSavedContext] = useState( + null, + ); + const [contextError, setContextError] = useState(null); + const [llmDefaults, setLlmDefaults] = useState(null); + const [selectedModelTokenLimit, setSelectedModelTokenLimit] = useState< + number | null + >(null); + const contextSettings = db?.llm_context_options; + const [activeKey, setActiveKey] = useState( + undefined, + ); + const [customProviderForm, setCustomProviderForm] = + useState({ + name: '', + endpoint_url: '', + request_template: '{"model": "{model}", "messages": {messages}}', + 
response_path: 'choices[0].message.content', + headers: + '{"Content-Type": "application/json", "Authorization": "Bearer {api_key}"}', + models: + '{"default": {"name": "Default Model", "input_token_limit": 100000}}', + system_instructions: '', + timeout: 30, + enabled: true, + }); + const [createCustomProvider] = useCreateCustomLlmProviderMutation(); + const [testCustomProvider] = useTestCustomLlmProviderMutation(); + + const tables = useDatabaseTables(dbIdRef.current || 0); + + const contextStatusOnSuccess = useCallback((result: any) => { + if (!result) return; + setRegenerating(result.status === 'building'); + if (result.context) { + setSavedContext(result.context); + } + setContextError(result.error ? result.error.build_time : null); + }, []); + + const contextStatus = useLlmContextStatus({ + dbId: dbIdRef.current, + onSuccess: contextStatusOnSuccess, + skip: !dbIdRef.current, + }); + + const llmDefaultsOnSuccess = useCallback((result: LlmDefaults) => { + if (!result) return; + setLlmDefaults(result); + }, []); + + useLlmDefaults({ + dbId: dbIdRef.current, + onSuccess: llmDefaultsOnSuccess, + skip: !dbIdRef.current, + }); + + useEffect(() => { + if (llmDefaults && selectedProvider && llmDefaults[selectedProvider]) { + const model = + db?.llm_connection?.model || + Object.keys(llmDefaults[selectedProvider].models)[0]; + setSelectedModelTokenLimit( + llmDefaults[selectedProvider].models[model]?.input_token_limit || null, + ); + } else { + setSelectedModelTokenLimit(null); + } + }, [llmDefaults, selectedProvider, db?.llm_connection?.model]); + + const handleProviderChange = useCallback( + (value: string) => { + setSelectedProvider(value); + if (value === 'add_custom') { + // Don't update the connection for the custom form + return; + } + + onLlmConnectionChange({ + ...db?.llm_connection, + provider: value, + model: llmDefaults?.[value]?.models + ? Object.keys(llmDefaults[value].models)[0] + : '', + }); + }, + [db?.llm_connection, llmDefaults, onLlmConnectionChange], + ); + + const handleLlmConnectionChange = useCallback( + (name: string, value: any) => { + onLlmConnectionChange({ ...db?.llm_connection, [name]: value }); + }, + [db?.llm_connection, onLlmConnectionChange], + ); + + const handleContextOptionsChange = useCallback( + (name: string, value: any) => { + onLlmContextOptionsChange({ ...db?.llm_context_options, [name]: value }); + }, + [db?.llm_context_options, onLlmContextOptionsChange], + ); + + const onSchemasChange = useCallback( + (value: string[]) => { + handleContextOptionsChange('schemas', JSON.stringify(value)); + }, + [handleContextOptionsChange], + ); + + const providerOptions = useMemo(() => { + if (!llmDefaults) return []; + + const options = Object.keys(llmDefaults).map(provider => { + const providerData = llmDefaults[provider]; + // For custom providers, use the display name if available + const label = providerData.name || provider; + return { + value: provider, + label, + }; + }); + + // Add "Add Custom Provider" option + options.push({ + value: 'add_custom', + label: '+ Add Custom Provider', + }); + + return options; + }, [llmDefaults]); + + const modelOptions = useMemo( + () => + llmDefaults && selectedProvider && selectedProvider in llmDefaults + ? Object.entries(llmDefaults[selectedProvider].models).map( + ([model, data]) => ({ + value: model, + label: data.name, + }), + ) + : [], + [llmDefaults, selectedProvider], + ); + + // Early return after all hooks if no database + if (!db) { + return null; + } + + return ( + <> + +
+ {t('Enable large language model support in SQL Lab')} +
+
+ + handleLlmConnectionChange('enabled', checked) + } + disabled={!db} + /> +
+
+ setActiveKey(key)} + items={[ + { + key: 'language_models', + label: ( + + ), + children: ( +
+ +
+ {t('Language model provider')} +
+
+ + setCustomProviderForm({ + ...customProviderForm, + name: e.target.value, + }) + } + /> +
+
+ + +
{t('Endpoint URL')}
+
+ + setCustomProviderForm({ + ...customProviderForm, + endpoint_url: e.target.value, + }) + } + /> +
+
+ + +
+ {t('Request Template (JSON)')} +
+
+ + setCustomProviderForm({ + ...customProviderForm, + request_template: e.target.value, + }) + } + rows={3} + /> +
+
+ {t( + 'Use {model}, {messages}, {api_key} as placeholders', + )} +
+
+ + +
{t('Response Path')}
+
+ + setCustomProviderForm({ + ...customProviderForm, + response_path: e.target.value, + }) + } + /> +
+
+ {t( + 'JSONPath to extract the generated text from the response', + )} +
+
+ + +
{t('Headers (JSON)')}
+
+ + setCustomProviderForm({ + ...customProviderForm, + headers: e.target.value, + }) + } + rows={2} + /> +
+
+ + +
{t('Models (JSON)')}
+
+ + setCustomProviderForm({ + ...customProviderForm, + models: e.target.value, + }) + } + rows={3} + /> +
+
+ +
+ + +
+ + ) : ( + selectedProvider && ( + <> + +
+ {t('Provider API key')} +
+
+ + handleLlmConnectionChange( + 'api_key', + e.target.value, + ) + } + disabled={!db} + /> +
+
+ + +
{t('Model')}
+
+ + handleContextOptionsChange( + 'refresh_interval', + e.target.value, + ) + } + disabled={!db} + /> +
+
+ {t( + 'Frequently updating the database context will consume more system resources' + + ' but will make changes to the database schema available to the AI Assistant sooner.', + )} +
+
+ +
+ {t('Select tables to include in the context')} +
+
+ {db && ( + + )} +
+
+ {t( + "Tables that aren't included will not be available for the AI Assistant to query.", + )} +
+
+ +
+ + handleContextOptionsChange( + 'include_indexes', + (e.target as HTMLInputElement).checked, + ) + } + disabled={!db} + > + {t('Include indexes in the database context')} + + +
+
+ +
+ {t( + 'Include up to k most common results from the first n rows', + )} +
+ +
+
{t('Results (k)')}
+ + handleContextOptionsChange('top_k', e.target.value) + } + disabled={!db} + /> +
+
+
{t('Row limit (n)')}
+ + handleContextOptionsChange( + 'top_k_limit', + e.target.value, + ) + } + disabled={!db} + /> +
+
+                      
+                        {t(
+                          'The "top k" most common values on text columns are included to ' +
+                            "increase the model's ability to perform text matching. Row " +
+                            'limit is the number of rows scanned to calculate the most common values.',
+                        )}
+                      
+
+
+ +
+
+ {t('LLM Instructions')} +
+ +
+
+ + handleContextOptionsChange( + 'instructions', + e.target.value, + ) + } + style={{ flex: 1 }} + disabled={!db} + /> +
+
+ + + ), + }, + ]} + /> + + ); +}; + +export default AIAssistantOptions; diff --git a/superset-frontend/src/features/databases/DatabaseModal/SchemaSelector.tsx b/superset-frontend/src/features/databases/DatabaseModal/SchemaSelector.tsx new file mode 100644 index 000000000000..7da0d219981f --- /dev/null +++ b/superset-frontend/src/features/databases/DatabaseModal/SchemaSelector.tsx @@ -0,0 +1,400 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { t, useTheme, css, styled } from '@superset-ui/core'; +import { useEffect, useState } from 'react'; +import { Checkbox, Icons, Input, Button } from '@superset-ui/core/components'; + +const Container = styled.div` + ${({ theme }) => css` + width: 100%; + margin: 0 auto; + padding: ${theme.sizeUnit * 1.5}px ${theme.sizeUnit * 2}px; + border-style: none; + border: 1px solid ${theme.colors.grayscale.light2}; + border-radius: ${theme.sizeUnit}px; + + input { + flex-grow: 0; + } + `} +`; + +const SchemaList = styled.div` + ${({ theme }) => css` + display: flex; + flex-direction: column; + gap: ${theme.sizeUnit * 2}px; + `} +`; + +const SchemaItem = styled.div` + ${({ theme }) => css` + border-bottom: 1px solid ${theme.colors.grayscale.light2}; + padding-bottom: ${theme.sizeUnit * 1.5}px; + `} +`; + +const SchemaHeader = styled.div` + display: flex; + align-items: center; +`; + +const CaretButton = styled.button` + ${({ theme }) => css` + margin-right: ${theme.sizeUnit * 2}px; + background: none; + border: none; + cursor: pointer; + padding: 0; + display: flex; + align-items: center; + justify-content: center; + + &:focus { + outline: none; + } + `} +`; + +const EmptyCaret = styled.div` + ${({ theme }) => css` + width: ${theme.sizeUnit * 3.5}px; + height: ${theme.sizeUnit * 3.5}px; + margin-right: ${theme.sizeUnit * 2}px; + padding: 0; + `} +`; + +const CheckboxContainer = styled.div` + display: flex; + align-items: center; + + label { + font-weight: 500; + cursor: pointer; + } + + label.disabled { + font-weight: 400; + color: ${props => props.theme.colors.grayscale.light1}; + } +`; + +const TablesList = styled.div` + ${({ theme }) => css` + margin-left: ${theme.sizeUnit * 8}px; + margin-top: ${theme.sizeUnit}px; + display: flex; + flex-direction: column; + gap: ${theme.sizeUnit}px; + `} +`; + +const TableItem = styled.div` + display: flex; + align-items: center; +`; + +const StatusBar = styled.div` + ${({ theme }) => css` + margin-top: ${theme.sizeUnit * 2}px; + font-size: 0.875rem; + color: ${theme.colors.grayscale.dark1}; + `} +`; + +const Header = styled.div` + ${({ theme }) => css` + margin-bottom: ${theme.sizeUnit * 2}px; + font-size: ${theme.fontSizeSM}px; + font-weight: 600; + color: ${theme.colors.grayscale.dark1}; + display: flex; + + div { + border-right: 1px solid 
${theme.colors.grayscale.light2}; + padding-right: ${theme.sizeUnit * 2}px; + padding-left: ${theme.sizeUnit * 2}px; + + &:first-child { + padding-left: 0; + } + + &:last-child { + border-right: none; + } + } + `} +`; + +const LoadingContainer = styled.div` + ${({ theme }) => css` + display: flex; + align-items: center; + gap: ${theme.sizeUnit}px; + + span { + margin-left: ${theme.sizeUnit * 2}px; + } + `} +`; + +const SchemaSelector = ({ + value, + options, + loading, + error, + onSchemasChange, + maxContentHeight = null, +}: { + value: string[]; + options: Record; + loading: boolean; + error: Error | null; + onSchemasChange: Function; + maxContentHeight?: number | null; +}) => { + const theme = useTheme(); + const [expandedSchema, setExpandedSchema] = useState(null); + const [selectedItems, setSelectedItems] = useState<{ + [key: string]: boolean; + }>({}); + const [filterText, setFilterText] = useState(''); + const [filteredOptions, setFilteredOptions] = useState< + Record + >({}); + useEffect(() => { + const filtered = filterText + ? Object.keys(options).reduce( + (acc, schema) => { + const filteredTables = options[schema].filter(table => + table.toLowerCase().includes(filterText.toLowerCase()), + ); + if (filteredTables.length > 0) { + acc[schema] = filteredTables; + } + return acc; + }, + {} as Record, + ) + : options; + setFilteredOptions(filtered); + }, [options, filterText]); + + useEffect(() => { + const initialSelections: { [key: string]: boolean } = {}; + Object.keys(options).forEach(schema => { + options[schema].forEach(table => { + initialSelections[`${schema}.${table}`] = + value.indexOf(`${schema}.${table}`) !== -1; + }); + }); + setSelectedItems(initialSelections); + }, [options, value]); + + const areAllChildrenSelected = (schema: string) => { + if (options[schema].length === 0) { + return false; + } + return options[schema].every(table => selectedItems[`${schema}.${table}`]); + }; + + const areSomeChildrenSelected = (schema: string) => + options[schema].some(table => selectedItems[`${schema}.${table}`]) && + !areAllChildrenSelected(schema); + + const toggleExpanded = (schema: string) => { + setExpandedSchema(expandedSchema === schema ? null : schema); + }; + + const handleSchemaCheckboxChange = (schema: string) => { + const newSelectedItems = { ...selectedItems }; + const newValue = !areAllChildrenSelected(schema); + + options[schema].forEach(table => { + newSelectedItems[`${schema}.${table}`] = newValue; + }); + + onSchemasChange(selectedItemsToValue(newSelectedItems)); + }; + + const handleChildCheckboxChange = (schema: string, table: string) => { + const newSelectedItems = { ...selectedItems }; + const key = `${schema}.${table}`; + + newSelectedItems[key] = !newSelectedItems[key]; + + onSchemasChange(selectedItemsToValue(newSelectedItems)); + }; + + const handleSelectAll = () => { + setAll(true); + }; + const handleUnselectAll = () => { + setAll(false); + }; + const setAll = (selected: boolean) => { + const newSelectedItems = { ...selectedItems }; + Object.keys(newSelectedItems).forEach(key => { + newSelectedItems[key] = selected; + }); + setSelectedItems(newSelectedItems); + onSchemasChange(selectedItemsToValue(newSelectedItems)); + }; + + const selectedItemsToValue = (selected: { [key: string]: boolean }) => { + const value = []; + for (const key in selected) { + if (selected[key]) { + value.push(key); + } + } + return value; + }; + + return ( + + {loading ? ( + + + {t('Loading schemas and tables...')} + + ) : error ? ( +

+ {t('An error occurred while retrieving schemas for this connection')} +

+ ) : ( + <> +
+
+ +
+
+ +
+
+ setFilterText(e.target.value)} + aria-label={t('Filter tables')} + /> +
+
+ + + {Object.keys(filteredOptions).map(schema => ( + + + {filteredOptions[schema].length > 0 ? ( + toggleExpanded(schema)} + aria-label={ + expandedSchema === schema ? 'Collapse' : 'Expand' + } + > + {expandedSchema === schema ? ( + + ) : ( + + )} + + ) : ( + + )} + + + handleSchemaCheckboxChange(schema)} + /> + + + + + {expandedSchema === schema && ( + + {filteredOptions[schema].map(table => ( + + + handleChildCheckboxChange(schema, table) + } + > + {table} + + + ))} + + )} + + ))} + + + + {Object.values(selectedItems).filter(Boolean).length} items selected + + + )} +
+ ); +}; + +export default SchemaSelector; diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx index a0dae6176279..eb20c6553a67 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ + import { t, styled, @@ -75,9 +76,11 @@ import { CatalogObject, Engines, ExtraJson, - CustomTextType, DatabaseParameters, + LlmConnection, + LlmContextOptions, } from '../types'; +import AIAssistantOptions from './AIAssistantOptions'; import ExtraOptions from './ExtraOptions'; import SqlAlchemyForm from './SqlAlchemyForm'; import DatabaseConnectionForm from './DatabaseConnectionForm'; @@ -113,6 +116,7 @@ const DEFAULT_EXTRA = JSON.stringify({ allows_virtual_table_explore: true }); const TABS_KEYS = { BASIC: 'basic', ADVANCED: 'advanced', + AI_ASSISTANT: 'ai_assistant', }; const engineSpecificAlertMapping = { @@ -174,10 +178,14 @@ export enum ActionType { QueryChange, RemoveTableCatalogSheet, Reset, + SelectChange, + SwitchChange, TextChange, ParametersSSHTunnelChange, SetSSHTunnelLoginMethod, RemoveSSHTunnelConfig, + LlmConnectionChange, + LlmContextOptionsChange, } export enum AuthType { @@ -202,6 +210,8 @@ export type DBReducerActionType = | ActionType.EncryptedExtraInputChange | ActionType.TextChange | ActionType.QueryChange + | ActionType.SelectChange + | ActionType.SwitchChange | ActionType.InputChange | ActionType.EditorChange | ActionType.ParametersChange @@ -248,6 +258,14 @@ export type DBReducerActionType = payload: { login_method: AuthType; }; + } + | { + type: ActionType.LlmConnectionChange; + payload: LlmConnection; + } + | { + type: ActionType.LlmContextOptionsChange; + payload: LlmContextOptions; }; const StyledBtns = styled.div` @@ -426,6 +444,16 @@ export function dbReducer( [action.payload.name]: action.payload.value, }, }; + case ActionType.SelectChange: + return { + ...trimmedState, + [action.payload.target || action.payload.name]: action.payload.value, + }; + case ActionType.SwitchChange: + return { + ...trimmedState, + [action.payload.name]: action.payload.checked, + }; case ActionType.SetSSHTunnelLoginMethod: { let ssh_tunnel = {}; if (trimmedState?.ssh_tunnel) { @@ -501,6 +529,19 @@ export function dbReducer( ...trimmedState, [action.payload.name]: action.payload.value, }; + case ActionType.LlmConnectionChange: + return { + ...trimmedState, + llm_connection: { ...trimmedState.llm_connection, ...action.payload }, + }; + case ActionType.LlmContextOptionsChange: + return { + ...trimmedState, + llm_context_options: { + ...trimmedState.llm_context_options, + ...action.payload, + }, + }; case ActionType.Fetched: // convert query to a string and store in query_input query = action.payload?.parameters?.query || {}; @@ -751,10 +792,7 @@ const DatabaseModal: FunctionComponent = ({ }; const onChange = useCallback( - ( - type: DBReducerActionType['type'], - payload: CustomTextType | DBReducerPayloadType, - ) => { + (type: DBReducerActionType['type'], payload: any) => { setDB({ type, payload } as DBReducerActionType); }, [], @@ -840,6 +878,12 @@ const DatabaseModal: FunctionComponent = ({ // Clone DB object const dbToUpdate = { ...(db || {}) }; + // If the database has no extra set, set it to an empty object to ensure the + // save request doesn't fail + if (!dbToUpdate.extra) { + dbToUpdate.extra = '{}'; + } + if 
(dbToUpdate.configuration_method === ConfigurationMethod.DynamicForm) { // Validate DB before saving if (dbToUpdate?.parameters?.catalog) { @@ -2064,6 +2108,21 @@ const DatabaseModal: FunctionComponent = ({ /> ), }, + { + key: TABS_KEYS.AI_ASSISTANT, + label: {t('AI Assistant')}, + children: ( + + onChange(ActionType.LlmConnectionChange, connection) + } + onLlmContextOptionsChange={(options: LlmContextOptions) => + onChange(ActionType.LlmContextOptionsChange, options) + } + /> + ), + }, ]} /> diff --git a/superset-frontend/src/features/databases/DatabaseModal/styles.ts b/superset-frontend/src/features/databases/DatabaseModal/styles.ts index e9bd68750f0c..06ed035b08b5 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/styles.ts +++ b/superset-frontend/src/features/databases/DatabaseModal/styles.ts @@ -524,3 +524,73 @@ export const StyledUploadWrapper = styled.div` display: none; } `; + +export const StyledContextWrapper = styled.div` + display: flex; + flex-direction: column; + margin-bottom: ${({ theme }) => theme.sizeUnit * 4}px; + gap: ${({ theme }) => theme.sizeUnit * 2}px; +`; + +export const StyledLlmSwitch = styled.div` + ${({ theme }) => css` + display: flex; + align-items: center; + margin-top: ${theme.sizeUnit * 6}px; + margin-left: ${theme.sizeUnit * 4}px; + margin-bottom: ${theme.sizeUnit * 6}px; + .control-label { + font-family: ${theme.fontFamily}; + font-size: ${theme.fontSize}px; + margin-right: ${theme.sizeUnit * 4}px; + } + .input-container { + display: flex; + align-items: center; + label { + margin-left: ${theme.sizeUnit * 2}px; + margin-top: ${theme.sizeUnit * 2}px; + } + } + `} +`; + +export const StyledTokenEstimate = styled.div` + border: 1px solid ${({ theme }) => theme.colors.grayscale.light2}; + border-radius: ${({ theme }) => theme.sizeUnit}px; + padding: ${({ theme }) => theme.sizeUnit * 3}px; + font-size: ${({ theme }) => theme.fontSizeSM}px; + background-color: ${({ theme }) => theme.colors.grayscale.light4}; + .warning { + color: ${({ theme }) => theme.colors.error.base}; + } +`; + +export const StyledContextError = styled.div` + border: 1px solid ${({ theme }) => theme.colors.error.base}; + border-radius: ${({ theme }) => theme.sizeUnit}px; + padding: ${({ theme }) => theme.sizeUnit * 3}px; + font-size: ${({ theme }) => theme.fontSizeSM}px; + background-color: ${({ theme }) => theme.colors.error.light2}; + color: ${({ theme }) => theme.colors.error.base}; +`; + +export const StyledTopKForm = styled.div` + display: flex; + flex-direction: column; + justify-content: space-between; + .input-container { + display: flex; + flex-direction: row; + align-items: center; + margin-bottom: ${({ theme }) => theme.sizeUnit * 4}px; + } + .control-label { + margin-top: ${({ theme }) => theme.sizeUnit * 2}px; + margin-right: ${({ theme }) => theme.sizeUnit * 2}px; + width: ${({ theme }) => theme.sizeUnit * 20}px; + } + .helper { + margin-top: 0; + } +`; diff --git a/superset-frontend/src/features/databases/types.ts b/superset-frontend/src/features/databases/types.ts index 57c76c987e7a..8ca7c6091fb7 100644 --- a/superset-frontend/src/features/databases/types.ts +++ b/superset-frontend/src/features/databases/types.ts @@ -1,7 +1,3 @@ -import { JsonObject } from '@superset-ui/core'; -import { InputProps } from '@superset-ui/core/components'; -import { ChangeEvent, EventHandler, FormEvent } from 'react'; - /** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file @@ -20,6 +16,11 @@ import { ChangeEvent, EventHandler, FormEvent } from 'react'; * specific language governing permissions and limitations * under the License. */ + +import { JsonObject } from '@superset-ui/core'; +import { InputProps } from '@superset-ui/core/components'; +import { ChangeEvent, EventHandler, FormEvent } from 'react'; + type DatabaseUser = { first_name: string; last_name: string; @@ -121,6 +122,26 @@ export type DatabaseObject = { // SSH Tunnel information ssh_tunnel?: SSHTunnelObject | null; + + // AI Assistant + llm_connection?: LlmConnection; + llm_context_options?: LlmContextOptions; +}; + +export type LlmConnection = { + provider?: string; + api_key?: string; + model?: string; + enabled?: boolean; +}; + +export type LlmContextOptions = { + schemas?: string; + include_indexes?: boolean; + refresh_interval?: number; + top_k?: number; + top_k_limit?: number; + instructions?: string; }; export type DatabaseForm = { @@ -265,6 +286,15 @@ export interface ExtraJson { version?: string; } +export interface LlmContextJson { + schemas: string[]; + include_indexes: boolean; + refresh_interval: number; + top_k: number; + top_k_limit: number; + instructions: string; +} + export type CustomTextType = { value?: string | boolean | number | object; type?: string | null; diff --git a/superset-frontend/src/features/datasets/AddDataset/LeftPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/LeftPanel/index.tsx index 8906c0915f90..5d9e42420362 100644 --- a/superset-frontend/src/features/datasets/AddDataset/LeftPanel/index.tsx +++ b/superset-frontend/src/features/datasets/AddDataset/LeftPanel/index.tsx @@ -18,7 +18,10 @@ */ import { useEffect, SetStateAction, Dispatch, useCallback } from 'react'; import { styled, t } from '@superset-ui/core'; -import TableSelector, { TableOption } from 'src/components/TableSelector'; +import TableSelector, { + TableOption, + TableValue, +} from 'src/components/TableSelector'; import { EmptyState } from '@superset-ui/core/components'; import { type DatabaseObject } from 'src/components'; import { useToasts } from 'src/components/MessageToasts/withToasts'; @@ -138,7 +141,7 @@ export default function LeftPanel({ }); } }; - const setSchema = (schema: string) => { + const setSchema = (schema: string | string[]) => { if (schema) { setDataset({ type: DatasetActionType.SelectSchema, @@ -146,10 +149,14 @@ export default function LeftPanel({ }); } }; - const setTable = (tableName: string) => { + const setTable = (tableValue: TableValue) => { + setDataset({ + type: DatasetActionType.SelectSchema, + payload: { name: 'schema', value: tableValue.schema }, + }); setDataset({ type: DatasetActionType.SelectTable, - payload: { name: 'table_name', value: tableName }, + payload: { name: 'table_name', value: tableValue.value }, }); }; useEffect(() => { diff --git a/superset-frontend/src/hooks/apiResources/index.ts b/superset-frontend/src/hooks/apiResources/index.ts index 53aa7aa113cb..8f6a483455e6 100644 --- a/superset-frontend/src/hooks/apiResources/index.ts +++ b/superset-frontend/src/hooks/apiResources/index.ts @@ -32,3 +32,5 @@ export * from './dashboards'; export * from './tables'; export * from './schemas'; export * from './queryValidations'; +export * from './settings'; +export * from './llms'; diff --git a/superset-frontend/src/hooks/apiResources/llms.ts b/superset-frontend/src/hooks/apiResources/llms.ts new file mode 100644 index 000000000000..94428abb7e6a --- /dev/null +++ b/superset-frontend/src/hooks/apiResources/llms.ts @@ -0,0 
+1,258 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { useEffect, useRef, useReducer } from 'react'; +import useEffectEvent from 'src/hooks/useEffectEvent'; +import { api, JsonResponse } from './queryApi'; + +export interface SavedContextStatus { + build_time: string; + status: string; + size?: number; + message?: string; +} + +export interface LlmContextStatus { + context?: SavedContextStatus; + error?: { build_time: string }; + status: 'waiting' | 'building'; +} + +export type FetchLlmContextStatusQueryParams = { + dbId?: string | number | null; + onSuccess?: (data: LlmContextStatus, isRefetched: boolean) => void; + onError?: (error: Response) => void; + skip?: boolean; +}; + +export interface LlmDefaults { + [provider: string]: { + models: { + [modelName: string]: { name: string; input_token_limit: number }; + }; + instructions: string; + name?: string; // Display name for custom providers + }; +} + +export type LlmDefaultsParams = { + dbId: number | null; + onSuccess?: (data: LlmDefaults, isRefetched: boolean) => void; + onError?: (error: Response) => void; + skip?: boolean; +}; + +export interface CustomLlmProvider { + id: number; + name: string; + endpoint_url: string; + request_template: string; + response_path: string; + headers: string | null; + models: string; + system_instructions: string | null; + timeout: number | null; + enabled: boolean; + created_on: string; + changed_on: string; +} + +export interface CustomLlmProviderForm { + name: string; + endpoint_url: string; + request_template: string; + response_path: string; + headers?: string; + models: string; + system_instructions?: string; + timeout?: number; + enabled?: boolean; +} + +const llmContextApi = api.injectEndpoints({ + endpoints: builder => ({ + contextStatus: builder.query< + LlmContextStatus, + FetchLlmContextStatusQueryParams + >({ + providesTags: ['LlmContextStatus'], + query: ({ dbId }) => ({ + endpoint: `api/v1/sqllab/db_context_status/`, + urlParams: { pk: dbId }, + transformResponse: ({ json }: JsonResponse) => json, + }), + }), + llmDefaults: builder.query({ + providesTags: ['LlmDefaults'], + query: dbId => ({ + endpoint: `api/v1/database/${dbId}/llm_defaults/`, + transformResponse: ({ json }: JsonResponse) => json, + }), + }), + customLlmProviders: builder.query({ + providesTags: ['CustomLlmProvider'], + query: () => ({ + endpoint: 'api/v1/custom_llm_provider/', + transformResponse: ({ json }: JsonResponse) => json.result, + }), + }), + createCustomLlmProvider: builder.mutation< + CustomLlmProvider, + CustomLlmProviderForm + >({ + invalidatesTags: ['CustomLlmProvider', 'LlmDefaults'], + query: provider => ({ + endpoint: 'api/v1/custom_llm_provider/', + method: 'POST', + body: provider, + transformResponse: ({ json }: JsonResponse) => 
json, + }), + }), + updateCustomLlmProvider: builder.mutation< + CustomLlmProvider, + { id: number; provider: Partial } + >({ + invalidatesTags: ['CustomLlmProvider', 'LlmDefaults'], + query: ({ id, provider }) => ({ + endpoint: `api/v1/custom_llm_provider/${id}`, + method: 'PUT', + body: provider, + transformResponse: ({ json }: JsonResponse) => json, + }), + }), + deleteCustomLlmProvider: builder.mutation({ + invalidatesTags: ['CustomLlmProvider', 'LlmDefaults'], + query: id => ({ + endpoint: `api/v1/custom_llm_provider/${id}`, + method: 'DELETE', + }), + }), + testCustomLlmProvider: builder.mutation< + { status: string; message: string }, + Partial + >({ + query: provider => ({ + endpoint: 'api/v1/custom_llm_provider/test', + method: 'POST', + body: provider, + transformResponse: ({ json }: JsonResponse) => json.result, + }), + }), + }), +}); + +export const { + useContextStatusQuery, + useLlmDefaultsQuery, + useCustomLlmProvidersQuery, + useCreateCustomLlmProviderMutation, + useUpdateCustomLlmProviderMutation, + useDeleteCustomLlmProviderMutation, + useTestCustomLlmProviderMutation, + endpoints: llmEndpoints, +} = llmContextApi; + +export function useLlmContextStatus(options: FetchLlmContextStatusQueryParams) { + const { dbId, onSuccess, onError, skip } = options || {}; + + const pollingInterval = useRef(30000); + const [, forceUpdate] = useReducer(x => x + 1, 0); + + const result = useContextStatusQuery( + { dbId: dbId || undefined }, + { + pollingInterval: pollingInterval.current, + refetchOnMountOrArgChange: true, + skip: skip || !dbId, + }, + ); + + // Adjust polling interval based on status + useEffect(() => { + const status = result?.data?.status; + const desiredInterval = status === 'building' ? 5000 : 30000; + if (pollingInterval.current !== desiredInterval) { + pollingInterval.current = desiredInterval; + forceUpdate(); + } + }, [result?.data?.status]); + + const handleOnSuccess = useEffectEvent( + (data: LlmContextStatus, isRefetched: boolean) => { + onSuccess?.(data, isRefetched); + }, + ); + + const handleOnError = useEffectEvent((error: Response) => { + onError?.(error); + }); + + useEffect(() => { + const { requestId, isSuccess, isError, isFetching, currentData, error } = + result; + if (requestId && !isFetching) { + if (isSuccess && currentData) { + handleOnSuccess(currentData, false); + } + if (isError) { + handleOnError(error as Response); + } + } + }, [result, handleOnSuccess, handleOnError]); + + return { + ...result, + }; +} + +export function useLlmDefaults(options: LlmDefaultsParams) { + const { dbId, onSuccess, onError, skip } = options || {}; + const result = useLlmDefaultsQuery(dbId || 0, { + skip: skip || !dbId, + refetchOnMountOrArgChange: true, + }); + + console.log('useLlmDefaults result:', result); + + const handleOnSuccess = useEffectEvent( + (data: LlmDefaults, isRefetched: boolean) => { + onSuccess?.(data, isRefetched); + }, + ); + + const handleOnError = useEffectEvent((error: Response) => { + onError?.(error); + }); + + useEffect(() => { + const { requestId, isSuccess, isError, isFetching, currentData, error } = + result; + if (requestId && !isFetching) { + if (isSuccess && currentData) { + handleOnSuccess(currentData, false); + } + if (isError) { + handleOnError(error as Response); + } + } + }, [result, handleOnSuccess, handleOnError]); + + return { + ...result, + }; +} diff --git a/superset-frontend/src/hooks/apiResources/queryApi.ts b/superset-frontend/src/hooks/apiResources/queryApi.ts index d09174c9b87d..acb4988023b0 100644 --- 
a/superset-frontend/src/hooks/apiResources/queryApi.ts +++ b/superset-frontend/src/hooks/apiResources/queryApi.ts @@ -81,6 +81,8 @@ export const api = createApi({ 'TableMetadatas', 'SqlLabInitialState', 'EditorQueries', + 'LlmContextStatus', + 'LlmDefaults', ], endpoints: () => ({}), baseQuery: supersetClientQuery, diff --git a/superset-frontend/src/hooks/apiResources/settings.ts b/superset-frontend/src/hooks/apiResources/settings.ts new file mode 100644 index 000000000000..69d2c83c9ff2 --- /dev/null +++ b/superset-frontend/src/hooks/apiResources/settings.ts @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { useApiV1Resource } from './apiResources'; + +export const useDatabaseTables = (id: string | number) => + useApiV1Resource<{ [schema: string]: string[] }>( + `/api/v1/database/${id}/schema_tables/`, + ); diff --git a/superset-frontend/src/hooks/apiResources/tables.ts b/superset-frontend/src/hooks/apiResources/tables.ts index 81792b4a522a..6da34820f729 100644 --- a/superset-frontend/src/hooks/apiResources/tables.ts +++ b/superset-frontend/src/hooks/apiResources/tables.ts @@ -27,6 +27,7 @@ export interface Table { label: string; value: string; type: string; + schema: string; extra?: { certification?: { certified_by: string; @@ -52,7 +53,7 @@ export type Data = { export type FetchTablesQueryParams = { dbId?: string | number; catalog?: string | null; - schema?: string; + schema?: string | string[]; forceRefresh?: boolean; onSuccess?: (data: Data, isRefetched: boolean) => void; onError?: (error: Response) => void; @@ -101,12 +102,32 @@ const tableApi = api.injectEndpoints({ endpoints: builder => ({ tables: builder.query({ providesTags: ['Tables'], + // query: ({ dbId, catalog, schema, forceRefresh }) => { + // return { + // endpoint: `/api/v1/database/${dbId ?? 'undefined'}/tables/`, + // // TODO: Would be nice to add pagination in a follow-up. Needs endpoint changes. + // urlParams: { + // force: forceRefresh, + // schema_name: schemas && schemas.length > 0 ? encodeURIComponent(schemas[0]) : '', + // schema_names: schemas, + // ...(catalog && { catalog_name: catalog }), + // }, + // transformResponse: ({ json }: QueryResponse) => ({ + // options: json.result, + // hasMore: json.count > json.result.length, + // }) + // }; + // }, query: ({ dbId, catalog, schema, forceRefresh }) => ({ endpoint: `/api/v1/database/${dbId ?? 'undefined'}/tables/`, // TODO: Would be nice to add pagination in a follow-up. Needs endpoint changes. urlParams: { force: forceRefresh, - schema_name: schema ? encodeURIComponent(schema) : '', + schema_name: Array.isArray(schema) + ? schema.map(s => encodeURIComponent(s)) + : schema + ? 
encodeURIComponent(schema) + : '', ...(catalog && { catalog_name: catalog }), }, transformResponse: ({ json }: QueryResponse) => ({ @@ -177,7 +198,12 @@ export function useTables(options: Params) { ); const enabled = Boolean( - dbId && schema && !isFetching && schemaOptionsMap.has(schema), + Array.isArray(schema) + ? dbId && + schema && + !isFetching && + schema.some(s => schemaOptionsMap.has(s)) + : dbId && schema && !isFetching && schemaOptionsMap.has(schema), ); const result = useTablesQuery( diff --git a/superset-frontend/src/logger/LogUtils.ts b/superset-frontend/src/logger/LogUtils.ts index 913a3d5af7bb..7aca76250102 100644 --- a/superset-frontend/src/logger/LogUtils.ts +++ b/superset-frontend/src/logger/LogUtils.ts @@ -85,6 +85,7 @@ export const LOG_ACTIONS_SQLLAB_COPY_RESULT_TO_CLIPBOARD = 'sqllab_copy_result_to_clipboard'; export const LOG_ACTIONS_SQLLAB_CREATE_CHART = 'sqllab_create_chart'; export const LOG_ACTIONS_SQLLAB_LOAD_TAB_STATE = 'sqllab_load_tab_state'; +export const LOG_ACTIONS_AI_ASSISTANT_OPENED = 'ai_assistant_opened'; // Log event types -------------------------------------------------------------- export const LOG_EVENT_TYPE_TIMING = new Set([ diff --git a/superset-frontend/src/pages/DatabaseList/index.tsx b/superset-frontend/src/pages/DatabaseList/index.tsx index e3f55a06f234..8a6c8cf78d4f 100644 --- a/superset-frontend/src/pages/DatabaseList/index.tsx +++ b/superset-frontend/src/pages/DatabaseList/index.tsx @@ -459,6 +459,16 @@ function DatabaseList({ size: 'md', id: 'expose_in_sqllab', }, + { + accessor: 'llm_available', + Header: t('AI Assistant'), + Cell: ({ + row: { + original: { llm_available: llmAvailable }, + }, + }: any) => , + size: 'md', + }, { Cell: ({ row: { diff --git a/superset/commands/database/bulk_schema_tables.py b/superset/commands/database/bulk_schema_tables.py new file mode 100644 index 000000000000..15201768a329 --- /dev/null +++ b/superset/commands/database/bulk_schema_tables.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from __future__ import annotations
+
+import logging
+from typing import Any, cast
+
+from sqlalchemy.orm import lazyload, load_only
+
+from superset.commands.base import BaseCommand
+from superset.commands.database.exceptions import (
+    DatabaseNotFoundError,
+    DatabaseTablesUnexpectedError,
+)
+from superset.connectors.sqla.models import SqlaTable
+from superset.daos.database import DatabaseDAO
+from superset.exceptions import SupersetException
+from superset.extensions import db, security_manager
+from superset.models.core import Database
+from superset.utils.core import DatasourceName
+
+logger = logging.getLogger(__name__)
+
+
+class BulkSchemaTablesDatabaseCommand(BaseCommand):
+    _model: Database
+
+    def __init__(
+        self,
+        db_id: int,
+        catalog_name: str | None,
+        schema_names: list[str],
+        force: bool,
+    ):
+        self._db_id = db_id
+        self._catalog_name = catalog_name
+        self._schema_names = schema_names
+        self._force = force
+
+    def run(self) -> dict[str, Any]:
+        self.validate()
+        try:
+            all_tables = []
+            all_views = []
+
+            for schema_name in self._schema_names:
+                tables = security_manager.get_datasources_accessible_by_user(
+                    database=self._model,
+                    catalog=self._catalog_name,
+                    schema=schema_name,
+                    datasource_names=sorted(
+                        DatasourceName(*datasource_name)
+                        for datasource_name in self._model.get_all_table_names_in_schema(
+                            catalog=self._catalog_name,
+                            schema=schema_name,
+                            force=self._force,
+                            cache=self._model.table_cache_enabled,
+                            cache_timeout=self._model.table_cache_timeout,
+                        )
+                    ),
+                )
+                all_tables.extend(tables)
+
+                views = security_manager.get_datasources_accessible_by_user(
+                    database=self._model,
+                    catalog=self._catalog_name,
+                    schema=schema_name,
+                    datasource_names=sorted(
+                        DatasourceName(*datasource_name)
+                        for datasource_name in self._model.get_all_view_names_in_schema(
+                            catalog=self._catalog_name,
+                            schema=schema_name,
+                            force=self._force,
+                            cache=self._model.table_cache_enabled,
+                            cache_timeout=self._model.table_cache_timeout,
+                        )
+                    ),
+                )
+                all_views.extend(views)
+
+            extra_dict_by_name = {
+                table.name: table.extra_dict
+                for table in (
+                    db.session.query(SqlaTable)
+                    .filter(
+                        SqlaTable.database_id == self._model.id,
+                        SqlaTable.catalog == self._catalog_name,
+                    )
+                    .options(
+                        load_only(
+                            SqlaTable.catalog,
+                            SqlaTable.table_name,
+                            SqlaTable.extra,
+                        ),
+                        lazyload(SqlaTable.columns),
+                        lazyload(SqlaTable.metrics),
+                    )
+                ).all()
+            }
+
+            options = sorted(
+                [
+                    {
+                        "value": table.table,
+                        "type": "table",
+                        "extra": extra_dict_by_name.get(table.table, None),
+                        "schema": table.schema,
+                    }
+                    for table in all_tables
+                ]
+                + [
+                    {
+                        "value": view.table,
+                        "type": "view",
+                        "schema": view.schema,
+                    }
+                    for view in all_views
+                ],
+                key=lambda item: item["value"],
+            )
+
+            payload = {"count": len(all_tables) + len(all_views), "result": options}
+            return payload
+        except SupersetException:
+            raise
+        except Exception as ex:
+            raise DatabaseTablesUnexpectedError(str(ex)) from ex
+
+    def validate(self) -> None:
+        self._model = cast(Database, DatabaseDAO.find_by_id(self._db_id))
+        if not self._model:
+            raise DatabaseNotFoundError()
diff --git a/superset/commands/database/tables.py b/superset/commands/database/tables.py
index 31b3fd354270..38af4ce8204f 100644
--- a/superset/commands/database/tables.py
+++ b/superset/commands/database/tables.py
@@ -141,6 +141,7 @@ def run(self) -> dict[str, Any]:
                     "value": table.table,
                     "type": "table",
                     "extra": extra_dict_by_name.get(table.table, None),
+                    "schema": table.schema,
                 }
                 for table in tables
             ]
@@ -148,6 +149,7 @@ def run(self)
-> dict[str, Any]: { "value": view.table, "type": "view", + "schema": view.schema, } for view in views ] diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 98e3ffac2b37..6e4030145f5a 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -65,6 +65,25 @@ def run(self) -> Model: if not self._model: raise DatabaseNotFoundError() + # Log the operation without sensitive data + safe_properties = { + k: v + for k, v in self._properties.items() + if k + not in { + "password", + "api_key", + "encrypted_extra", + "masked_encrypted_extra", + "sqlalchemy_uri", + "ssh_tunnel", + } + } + logger.info( + "Updating database %s with fields: %s", + self._model_id, + list(safe_properties.keys()), + ) self.validate() if "masked_encrypted_extra" in self._properties: @@ -97,7 +116,6 @@ def run(self) -> Model: database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) new_catalog = database.get_default_catalog() - # update assets when the database catalog changes, if the database was not # configured with multi-catalog support; if it was enabled or is enabled in the # update we don't update the assets diff --git a/superset/config.py b/superset/config.py index 662e79576e10..5a2cb227a84b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1144,6 +1144,7 @@ class CeleryConfig: # pylint: disable=too-few-public-methods "superset.tasks.thumbnails", "superset.tasks.cache", "superset.tasks.slack", + "superset.tasks.llm_context", ) result_backend = "db+sqlite:///celery_results.sqlite" worker_prefetch_multiplier = 1 @@ -1163,6 +1164,10 @@ class CeleryConfig: # pylint: disable=too-few-public-methods "task": "reports.prune_log", "schedule": crontab(minute=0, hour=0), }, + "check_for_expired_llm_context": { + "task": "check_for_expired_llm_context", + "schedule": crontab(minute="*/5"), + }, # Uncomment to enable pruning of the query table # "prune_query": { # "task": "prune_query", diff --git a/superset/constants.py b/superset/constants.py index d285fb1f9014..fef873709a55 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -132,6 +132,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "related_objects": "read", "tables": "read", "schemas": "read", + "schema_tables": "read", "catalogs": "read", "select_star": "read", "table_metadata": "read", @@ -174,6 +175,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "put_filters": "write", "put_colors": "write", "sync_permissions": "write", + "llm_defaults": "read", } EXTRA_FORM_DATA_APPEND_KEYS = { diff --git a/superset/daos/context_builder_task.py b/superset/daos/context_builder_task.py new file mode 100644 index 000000000000..dc460f71dd2d --- /dev/null +++ b/superset/daos/context_builder_task.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from superset.daos.base import BaseDAO +from superset.extensions import db +from superset.models.core import ContextBuilderTask + + +class ContextBuilderTaskDAO(BaseDAO[ContextBuilderTask]): + @staticmethod + def find_by_task_id(task_id: int) -> ContextBuilderTask | None: + return ( + db.session.query(ContextBuilderTask) + .filter_by(task_id=task_id) + .one_or_none() + ) + + @staticmethod + def get_latest_task_for_database(database_id: int) -> ContextBuilderTask: + return ( + db.session.query(ContextBuilderTask) + .filter(ContextBuilderTask.database_id == database_id) + .order_by(ContextBuilderTask.started_time.desc()) + .first() + ) + + @staticmethod + def get_last_successful_task_for_database(database_id: int) -> ContextBuilderTask: + return ( + db.session.query(ContextBuilderTask) + .filter( + ContextBuilderTask.database_id == database_id, + ContextBuilderTask.status == "SUCCESS", + ) + .order_by(ContextBuilderTask.started_time.desc()) + .first() + ) + + @staticmethod + def get_last_two_tasks_for_database(database_id: int) -> list[ContextBuilderTask]: + return ( + db.session.query(ContextBuilderTask) + .filter(ContextBuilderTask.database_id == database_id) + .order_by(ContextBuilderTask.started_time.desc()) + .limit(2) + .all() + ) diff --git a/superset/daos/database.py b/superset/daos/database.py index fa035534ee89..4dfbaec7d1eb 100644 --- a/superset/daos/database.py +++ b/superset/daos/database.py @@ -24,7 +24,12 @@ from superset.databases.filters import DatabaseFilter from superset.databases.ssh_tunnel.models import SSHTunnel from superset.extensions import db -from superset.models.core import Database, DatabaseUserOAuth2Tokens +from superset.models.core import ( + Database, + DatabaseUserOAuth2Tokens, + LlmConnection, + LlmContextOptions, +) from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import TabState @@ -59,6 +64,25 @@ def update( attributes["encrypted_extra"], ) + if attributes and "llm_connection" in attributes: + llm_conn_data = attributes.pop("llm_connection") + if llm_conn_data and item: + if item.llm_connection: + for k, v in llm_conn_data.items(): + setattr(item.llm_connection, k, v) + else: + # Use object.__setattr__ to bypass read-only restriction + object.__setattr__(item, 'llm_connection', LlmConnection(**llm_conn_data)) + + if attributes and "llm_context_options" in attributes: + llm_ctx_data = attributes.pop("llm_context_options") + if llm_ctx_data and item: + if item.llm_context_options: + for k, v in llm_ctx_data.items(): + setattr(item.llm_context_options, k, v) + else: + item.llm_context_options = LlmContextOptions(**llm_ctx_data) + return super().update(item, attributes) @staticmethod diff --git a/superset/databases/api.py b/superset/databases/api.py index c94060b28383..62dfcabaff25 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -39,6 +39,9 @@ from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError from superset import event_logger +from superset.commands.database.bulk_schema_tables import ( + BulkSchemaTablesDatabaseCommand, +) from superset.commands.database.create import CreateDatabaseCommand from superset.commands.database.delete import DeleteDatabaseCommand from superset.commands.database.exceptions import ( @@ -100,6 +103,7 @@ get_export_ids_schema, OAuth2ProviderResponseSchema, openapi_spec_methods_override, + 
QualifiedSchemaSchema, QualifiedTableSchema, SchemasResponseSchema, SelectStarResponseSchema, @@ -111,7 +115,7 @@ ValidateSQLRequest, ValidateSQLResponse, ) -from superset.databases.utils import get_table_metadata +from superset.databases.utils import get_database_metadata, get_table_metadata from superset.db_engine_specs import get_available_engine_specs from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -124,6 +128,7 @@ TableNotFoundException, ) from superset.extensions import security_manager +from superset.llms import dispatcher from superset.models.core import Database from superset.sql.parse import Table from superset.superset_typing import FlaskResponse @@ -165,6 +170,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "select_star", "catalogs", "schemas", + "schema_tables", "test_connection", "related_objects", "function_names", @@ -178,6 +184,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "upload", "oauth2", "sync_permissions", + "llm_schema", + "llm_defaults", } resource_name = "database" @@ -203,6 +211,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "impersonate_user", "is_managed_externally", "engine_information", + "llm_available", + "llm_connection", + "llm_context_options", ] list_columns = [ "allow_file_upload", @@ -231,6 +242,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "disable_drill_to_detail", "allow_multi_catalog", "engine_information", + "llm_available", + "llm_connection", + "llm_context_options", ] add_columns = [ "database_name", @@ -248,6 +262,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "extra", "encrypted_extra", "server_cert", + "llm_available", ] edit_columns = add_columns @@ -261,6 +276,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "database_name", "expose_in_sqllab", "uuid", + "llm_available", ] search_filters = {"allow_file_upload": [DatabaseUploadEnabledFilter]} allowed_rel_fields = {"changed_by", "created_by"} @@ -275,6 +291,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "created_by.first_name", "database_name", "expose_in_sqllab", + "llm_available", ] # Removes the local limit for the page size max_page_size = -1 @@ -879,10 +896,62 @@ def tables(self, pk: int, **kwargs: Any) -> FlaskResponse: catalog_name = kwargs["rison"].get("catalog_name") schema_name = kwargs["rison"].get("schema_name", "") - command = TablesDatabaseCommand(pk, catalog_name, schema_name, force) + if isinstance(schema_name, str): + command = TablesDatabaseCommand(pk, catalog_name, schema_name, force) + elif isinstance(schema_name, list): + command = BulkSchemaTablesDatabaseCommand( + pk, catalog_name, schema_name, force + ) + payload = command.run() return self.response(200, **payload) + @expose("/<int:pk>/schema_tables/") + @protect() + @rison(database_schemas_query_schema) + @statsd_metrics + @handle_api_exception + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.schema_tables", + log_to_statsd=False, + ) + def schema_tables(self, pk: int, **kwargs: Any) -> FlaskResponse: + # Retrieve all the schemas and tables for a given database and return them as a dict + # with schema names as keys and table names as values + database = self.datamodel.get(pk, self._base_filters) + if not database: + return self.response_404() + try: + catalog = kwargs["rison"].get("catalog") + schemas = database.get_all_schema_names( + cache=database.schema_cache_enabled, + cache_timeout=database.schema_cache_timeout or None, + ) + schemas =
security_manager.get_schemas_accessible_by_user( + database, + catalog, + schemas, + ) + + def get_tables(pk, catalog, schema, force): + tables_result = TablesDatabaseCommand(pk, catalog, schema, force).run()[ + "result" + ] + return [result["value"] for result in tables_result] + + schema_tables = { + schema: get_tables(pk, catalog, schema, False) for schema in schemas + } + return self.response(200, result=schema_tables) + except OperationalError: + return self.response( + 500, message="There was an error connecting to the database" + ) + except OAuth2RedirectError: + raise + except SupersetException as ex: + return self.response(ex.status, message=ex.message) + @expose("/<int:pk>/table/<table_name>/<schema_name>/", methods=("GET",)) @protect() @check_table_access @@ -1012,6 +1081,101 @@ def table_extra_metadata_deprecated( payload = database.db_engine_spec.get_extra_table_metadata(database, table) return self.response(200, **payload) + @expose("/<int:pk>/llm_schema/", methods=["GET"]) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.llm_schema", + log_to_statsd=False, + ) + def llm_schema(self, pk: int) -> FlaskResponse: + # Construct a JSON representation of the schema for the entire database and put it in this format: + # {[ + # { + # schema_name = + # schema_description = + # relations = [ + # { + # rel_name = + # rel_kind = + # rel_description = + # indexes = [ + # { + # index_name = + # is_unique = + # column_names = + # index_definition = + # }, + + # ] + # foreign_keys = [ + # { + # constraint_name = + # column_name = + # referenced_column = + # }, + + # ] + # columns = [ + # { + # column_name = + # data_type = + # is_nullable = + # column_description = + # most_common_values = + # }, + + # ] + # }, + # ] + # }, + # ]} + self.incr_stats("init", self.llm_schema.__name__) + + try: + parameters = QualifiedSchemaSchema().load(request.args) + except ValidationError as ex: + raise InvalidPayloadSchemaError(ex) from ex + + database = DatabaseDAO.find_by_id(pk) + if not database: + return self.response_404() + + context_settings = json.loads(database.llm_context_settings or "{}") + selected_schemas = context_settings.get("schemas", None) + include_indexes = context_settings.get("include_indexes", True) + top_k = context_settings.get("top_k", 10) + top_k_limit = context_settings.get("top_k_limit", 10000) + + schemas = get_database_metadata( + database, + parameters["catalog"], + tables=selected_schemas, + include_indexes=include_indexes, + top_k=top_k, + top_k_limit=top_k_limit, + ) + schema_response = None + + if parameters["minify"]: + + def reduce_json_token_count(data: str) -> str: + """ + Reduces the token count of a JSON string.
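+ Only the space following ':' and ',' is removed (including inside string values), so the result stays valid JSON while using fewer tokens.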
+ """ + data = data.replace(": ", ":").replace(", ", ",") + + return data + + schema_response = reduce_json_token_count( + json.dumps([schema.model_dump() for schema in schemas]) + ) + else: + schema_response = [schema.model_dump() for schema in schemas] + + return self.response(200, result=schema_response) + @expose("/<int:pk>/table_metadata/", methods=["GET"]) @protect() @statsd_metrics @@ -2105,3 +2269,16 @@ def schemas_access_for_file_upload(self, pk: int) -> Response: database, database.get_default_catalog(), schemas_allowed, True ) return self.response(200, schemas=schemas_allowed_processed) + + @expose("/<int:pk>/llm_defaults/") + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.llm_defaults", + log_to_statsd=False, + ) + def llm_defaults(self, pk: int) -> Response: + return self.response( + 200, + **dispatcher.get_default_options(pk), # type: ignore[operator] # noqa: E501 + ) diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 321eb621005e..6b54c4068443 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging from typing import Any from flask import current_app, g @@ -27,6 +28,8 @@ from superset.models.core import Database from superset.views.base import BaseFilter +logger = logging.getLogger(__name__) + + def can_access_databases(view_menu_name: str) -> set[str]: """ diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index ea24ba219b63..65eca5b179b8 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -78,7 +78,12 @@ "type": "object", "properties": { "force": {"type": "boolean"}, - "schema_name": {"type": "string"}, + "schema_name": { + "oneOf": [ + {"type": "string"}, + {"type": "array", "items": {"type": "string"}}, + ] + }, "catalog_name": {"type": "string"}, }, "required": ["schema_name"], @@ -451,6 +456,68 @@ class DatabaseSSHTunnel(Schema): private_key_password = fields.String(required=False) + +class LlmConnectionSchema(Schema): + provider = fields.String( + required=False, + allow_none=True, + metadata={"description": "The LLM provider"}, + ) + api_key = fields.String( + required=False, + allow_none=True, + metadata={"description": "The LLM API key"}, + ) + model = fields.String( + required=False, + allow_none=True, + metadata={"description": "The LLM model"}, + ) + enabled = fields.Boolean( + required=False, + allow_none=True, + metadata={"description": "Whether the LLM connection is enabled"}, + ) + + +class LlmContextOptionsSchema(Schema): + refresh_interval = fields.Integer( + required=False, + allow_none=True, + metadata={"description": "The interval in hours to refresh the LLM context"}, + ) + schemas = fields.String( + required=False, + allow_none=True, + metadata={"description": "A list of schemas to include in the LLM context"}, + ) + include_indexes = fields.Boolean( + required=False, + allow_none=True, + metadata={"description": "Whether to include indexes in the LLM context"}, + ) + top_k = fields.Integer( + required=False, + allow_none=True, + metadata={ + "description": "The number of top K results to include in the LLM context" + }, + ) + top_k_limit = fields.Integer( + required=False, + allow_none=True, + metadata={ + "description": "The limit for the top K results to include in the LLM context" + }, + ) + instructions =
fields.String( + required=False, + allow_none=True, + metadata={ + "description": "Instructions for the LLM to follow when generating SQL" + }, + ) + + class DatabasePostSchema(DatabaseParametersSchemaMixin, Schema): class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE @@ -506,6 +573,8 @@ class Meta: # pylint: disable=too-few-public-methods external_url = fields.String(allow_none=True) uuid = fields.String(required=False) ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) + llm_connection = fields.Nested(LlmConnectionSchema, allow_none=True) + llm_context_options = fields.Nested(LlmContextOptionsSchema, allow_none=True) class DatabasePutSchema(DatabaseParametersSchemaMixin, Schema): @@ -563,6 +632,8 @@ class Meta: # pylint: disable=too-few-public-methods external_url = fields.String(allow_none=True) ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) uuid = fields.String(required=False) + llm_connection = fields.Nested(LlmConnectionSchema, allow_none=True) + llm_context_options = fields.Nested(LlmContextOptionsSchema, allow_none=True) class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema): @@ -1063,6 +1134,8 @@ class DatabaseConnectionSchema(Schema): metadata={"description": sqlalchemy_uri_description}, validate=[Length(1, 1024), sqlalchemy_uri_validator], ) + llm_connection = fields.Nested(LlmConnectionSchema, allow_none=True) + llm_context_options = fields.Nested(LlmContextOptionsSchema, allow_none=True) class DelimitedListField(fields.List): @@ -1342,3 +1415,26 @@ class QualifiedTableSchema(Schema): load_default=None, metadata={"description": "The table catalog"}, ) + + +class QualifiedSchemaSchema(Schema): + """ + Schema for a qualified schema reference. + + Catalog can be omitted to fall back to the default value. + """ + + schema = fields.String( + required=True, + metadata={"description": "The schema"}, + ) + catalog = fields.String( + required=False, + load_default=None, + metadata={"description": "The catalog"}, + ) + minify = fields.Boolean( + required=False, + load_default=False, + metadata={"description": "Whether to minify the schema"}, + ) diff --git a/superset/databases/utils.py b/superset/databases/utils.py index e78500665ec7..f9dee7ed3279 100644 --- a/superset/databases/utils.py +++ b/superset/databases/utils.py @@ -17,12 +17,14 @@ from __future__ import annotations +import logging from typing import Any, TYPE_CHECKING +from typing import List, Optional from sqlalchemy.engine.url import make_url, URL from superset.commands.database.exceptions import DatabaseInvalidError from superset.sql.parse import Table if TYPE_CHECKING: from superset.databases.schemas import ( @@ -31,6 +33,8 @@ TableMetadataResponse, ) +logger = logging.getLogger(__name__) + + def get_foreign_keys_metadata( database: Any, @@ -128,3 +132,330 @@ def make_url_safe(raw_url: str | URL) -> URL: else: return raw_url + + +def get_database_metadata( + database: Any, + catalog: str | None = None, + include_indexes: bool = True, + tables: list[str] | None = None, + top_k: int = 10, + top_k_limit: int = 100000, +) -> List[Schema]: + """ + Get database metadata information, including schemas, tables, columns, indexes, and foreign keys.
+ :param database: The database model + :param catalog: Optional catalog name to inspect; None uses the default catalog + :param include_indexes: Whether to include index metadata for each table + :param tables: Optional list of 'schema.table' names used to restrict the schemas and tables inspected + :param top_k: Number of most common values to collect for text columns + :param top_k_limit: Maximum number of rows sampled when collecting the most common values + :return: Database metadata ready for API response + """ + logger.info("Getting metadata for database %s", database.database_name) + + # Build the list of selected schemas from the list of tables by extracting the schema name + schemas = set() + if tables: + for table in tables: + schema = table.split(".")[0] + schemas.add(schema) + + db_schemas = database.get_all_schema_names(catalog=catalog, cache=False) + logger.info("Found schemas: %s", db_schemas) + schemas_info = [] + + for schema in db_schemas: + if tables and (len(tables) > 0) and (schema not in schemas): + logger.info("Skipping schema %s not in schemas", schema) + continue + schema_info = get_schema_metadata( + database, + schema, + tables=tables, + include_indexes=include_indexes, + top_k=top_k, + top_k_limit=top_k_limit, + ) + schemas_info.append(schema_info) + + return schemas_info + + +def get_schema_metadata( + database: Any, + schema: str, + catalog: str | None = None, + tables: list[str] | None = None, + include_indexes: bool = True, + top_k: int = 10, + top_k_limit: int = 100000, +) -> Schema: + """ + Get schema metadata information, including tables, columns, indexes, and foreign keys. + :param database: The database model + :param schema: The schema name + :return: Schema metadata ready for API response + """ + db_tables = database.get_all_table_names_in_schema(catalog=catalog, schema=schema) + relations = [] + + for table_name, table_schema, table_catalog in db_tables: + if tables and len(tables) > 0 and f"{table_schema}.{table_name}" not in tables: + logger.info("Skipping table %s not in tables", table_name) + continue + t = Table(catalog=table_catalog, schema=table_schema, table=table_name) + table_metadata = get_table_relation_metadata( + database, + t, + include_indexes=include_indexes, + top_k=top_k, + top_k_limit=top_k_limit, + ) + relations.append(table_metadata) + + views = database.get_all_view_names_in_schema(catalog=catalog, schema=schema) + for view_name, view_schema, view_catalog in views: + v = Table(catalog=view_catalog, schema=view_schema, table=view_name) + view_metadata = get_view_relation_metadata(database, v) + relations.append(view_metadata) + + return Schema( + schema_name=schema, + relations=relations, + ) + + +def get_table_relation_metadata( + database: Any, + table: Table, + include_indexes: bool = True, + top_k: int = 10, + top_k_limit: int = 100000, +) -> Relation: + """ + Get table metadata information, including columns, indexes, and foreign keys. + This function raises SQLAlchemyError when a schema is not found.
+ + :param database: The database model + :param table: Table instance + :param include_indexes: Whether to collect index metadata + :param top_k: Number of most common values to collect for text columns + :param top_k_limit: Maximum number of rows sampled when collecting the most common values + :return: Relation metadata ready for API response + """ + columns = database.get_columns(table) + primary_key = database.get_pk_constraint(table) + if primary_key and primary_key.get("constrained_columns"): + primary_key["column_names"] = primary_key.pop("constrained_columns") + primary_key["type"] = "pk" + + foreign_keys = get_foreign_keys_relation_data(database, table) + + if include_indexes: + indexes = get_indexes_relation_data(database, table) + else: + indexes = [] + + payload_columns: list[Column] = [] + table_comment = database.get_table_comment(table) + + for col in columns: + dtype = get_col_type(col) + dtype = dtype.split("(")[0] if "(" in dtype else dtype + + top_k_values = None + if dtype in ["CHAR", "VARCHAR", "TEXT", "STRING", "NVARCHAR"]: + top_k_values = get_column_top_k_values( + database, + table, + col["column_name"], + table.schema, + top_k=top_k, + top_k_limit=top_k_limit, + ) + + column_metadata = Column( + column_name=col["column_name"], + data_type=dtype, + is_nullable=col["nullable"], + column_description=col.get("comment"), + most_common_values=top_k_values if top_k_values else None, + ) + + payload_columns.append(column_metadata) + + result = Relation( + rel_name=table.table, + rel_kind="table", + rel_description=table_comment, + foreign_keys=foreign_keys, + columns=payload_columns, + indexes=indexes if include_indexes else None, + ) + + return result + + +def get_column_top_k_values( + database: Any, + table: Table, + column_name: str, + schema: str | None, + top_k: int = 10, + top_k_limit: int = 100000, +) -> list[str]: + # Sample up to top_k_limit rows and count the most frequent non-NULL values for the column. + query = f""" + SELECT \"{column_name}\" AS value, COUNT(*) AS frequency + FROM (SELECT \"{column_name}\" FROM \"{table.table}\" LIMIT {top_k_limit}) AS subquery + WHERE \"{column_name}\" IS NOT NULL + GROUP BY \"{column_name}\" + ORDER BY frequency DESC + LIMIT {top_k}; + """ + + db_engine_spec = database.db_engine_spec + + with database.get_raw_connection(catalog="", schema=schema or "") as conn: + cursor = conn.cursor() + mutated_query = database.mutate_sql_based_on_config(query) + try: + db_engine_spec.execute(cursor, mutated_query, database) + result = db_engine_spec.fetch_data(cursor) + except Exception as e: + logger.error( + "Unable to retrieve top_k values on %s/%s, column %s: %s", + schema, + table, + column_name, + e, + ) + return [] + + return [value for (value, _) in result] + + +def get_view_relation_metadata( + database: Any, + table: Table, +) -> Relation: + relation = get_table_relation_metadata(database, table, include_indexes=False) + # Create a new Relation with rel_kind set to "view" + return Relation( + rel_name=relation.rel_name, + rel_kind="view", + rel_description=relation.rel_description, + foreign_keys=relation.foreign_keys, + columns=relation.columns, + indexes=None, # Views don't have indexes + ) + + +def get_foreign_keys_relation_data( + database: Any, + table: Table, +) -> List[FKey]: + foreign_keys = database.get_foreign_keys(table) + ret = [] + for fk in foreign_keys: + result = FKey( + column_name=fk.pop("constrained_columns")[0], + referenced_column=fk.pop("referred_columns")[0], + constraint_name=fk.pop("name"), + ) + ret.append(result) + return ret + + +def get_indexes_relation_data( + database: Any, + table: Table, +) -> List[Index]: + indexes =
database.get_indexes(table) + ret = [] + for idx in indexes: + result = Index( + column_names=idx.pop("column_names"), + is_unique=idx.pop("unique"), + index_name=idx.pop("name"), + index_definition=None, + ) + ret.append(result) + return ret + + +# Pydantic models describing the metadata payload assembled by the helpers above. +from pydantic import BaseModel, Field + + +class FKey(BaseModel): + """ + Contains information about a foreign key constraint. + """ + + constraint_name: str = Field(description="Name of the foreign key constraint.") + column_name: str = Field( + description="Name of the column to which the foreign key constraint is applied." + ) + referenced_column: str = Field( + description="Foreign column referenced by the constraint, expressed as 'foreign_schema_name.foreign_table_name.foreign_column_name'." + ) + + +class Index(BaseModel): + """ + Contains information about an index. + """ + + index_name: str = Field(description="Name of the index.") + is_unique: bool = Field(description="Whether the index is a unique constraint.") + column_names: List[str] = Field( + description="Name of the column(s) constituting the index." + ) + index_definition: Optional[str] = Field(description="CREATE INDEX statement.") + + +class Column(BaseModel): + """ + Contains information about a column. + """ + + column_name: str = Field(description="Name of the column.") + data_type: str = Field(description="Column data type.") + is_nullable: bool = Field( + description="Whether the column is nullable, i.e. has no NOT NULL constraint." + ) + column_description: Optional[str] = Field( + default=None, description="SQL comment associated with the column." + ) + most_common_values: Optional[List[str]] = Field( + default=None, description="Most common values observed in a sample of the table's records." + ) + + +class Relation(BaseModel): + """ + Contains information about a relation, which is a table, a view or a materialized view. This includes columns, indexes and foreign keys. + """ + + rel_name: str = Field(description="Name of the relation.") + rel_kind: str = Field(description="Type of relation, such as 'table' or 'view'.") + rel_description: Optional[str] = Field( + default=None, description="SQL comment associated with the relation." + ) + indexes: Optional[List[Index]] = Field( + default=None, description="Indexes associated with columns of the relation." + ) + foreign_keys: Optional[List[FKey]] = Field( + default=None, + description="Foreign keys associated with columns of the relation.", + ) + columns: List[Column] = Field( + default=[], description="Columns belonging to the relation." + ) + + +class Schema(BaseModel): + """ + Contains information about a schema, including its relations. + """ + + schema_name: str = Field(description="Name of the schema.") + schema_description: Optional[str] = Field( + default=None, description="SQL comment associated with the schema." + ) + relations: List[Relation] = Field( + default=[], description="Relations belonging to the schema."
+ ) diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 92589309e199..551b8c16b66e 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -93,6 +93,7 @@ ) from superset.views.error_handling import handle_api_exception from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners +from superset.views.utils import sanitize_datasource_data logger = logging.getLogger(__name__) @@ -361,7 +362,12 @@ def post(self) -> Response: try: new_model = CreateDatasetCommand(item).run() - return self.response(201, id=new_model.id, result=item, data=new_model.data) + return self.response( + 201, + id=new_model.id, + result=item, + data=sanitize_datasource_data(new_model.data), + ) except DatasetInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except DatasetCreateFailedError as ex: diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 133456f35a6d..e1c8278a216a 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -23,6 +23,7 @@ from typing import Any, Callable, TYPE_CHECKING import wtforms_json +from celery.signals import after_task_publish from colorama import Fore, Style from deprecation import deprecated from flask import abort, current_app, Flask, redirect, request, session, url_for @@ -76,6 +77,19 @@ logger = logging.getLogger(__name__) +@after_task_publish.connect +def update_sent_state(sender: str | None = None, headers: dict[str, str] | None = None, **kwargs: Any) -> None: + task = celery_app.tasks.get(sender) + backend = task.backend if task else celery_app.backend + + # For context worker tasks, set a special state so that we can tell the difference between + # tasks that might run and tasks that don't exist anymore. + logger.info("headers: %s", headers) + + if headers and headers["task"] == "generate_llm_context": + backend.store_result(headers["id"], None, "PUBLISHED") + + class SupersetAppInitializer: # pylint: disable=too-many-public-methods def __init__(self, app: SupersetApp) -> None: super().__init__() @@ -177,6 +191,7 @@ def init_views(self) -> None: from superset.explore.permalink.api import ExplorePermalinkRestApi from superset.extensions.view import ExtensionsView from superset.importexport.api import ImportExportRestApi + from superset.llms.api import CustomLlmProviderRestApi from superset.queries.api import QueryRestApi from superset.queries.saved_queries.api import SavedQueryRestApi from superset.reports.api import ReportScheduleRestApi @@ -255,6 +270,7 @@ def init_views(self) -> None: appbuilder.add_api(DashboardPermalinkRestApi) appbuilder.add_api(DashboardRestApi) appbuilder.add_api(DatabaseRestApi) + appbuilder.add_api(CustomLlmProviderRestApi) appbuilder.add_api(DatasetRestApi) appbuilder.add_api(DatasetColumnsRestApi) appbuilder.add_api(DatasetMetricRestApi) diff --git a/superset/llms/__init__.py b/superset/llms/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/llms/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/llms/anthropic.py b/superset/llms/anthropic.py new file mode 100644 index 000000000000..c48c76a8a676 --- /dev/null +++ b/superset/llms/anthropic.py @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import logging +from typing import List + +import anthropic + +from superset.daos.database import DatabaseDAO +from superset.llms.base_llm import BaseLlm + +logger = logging.getLogger(__name__) + + +class AnthropicLlm(BaseLlm): + llm_type = "Anthropic" + cached_context_size = None + + def _trim_markdown(self, text: str) -> str: + # Check the body for SQL wrapped in markdown ```{language}\n...\n``` + try: + sql_start = text.index("```") + # Skip past the opening fence line (e.g. ```sql) to the first newline + sql_body_start = text.index("\n", sql_start) + 1 + sql_end = text.index("\n```", sql_body_start) + sql = text[sql_body_start:sql_end] + except ValueError: + # There was no markdown, so assume for now the whole response is the SQL + sql = text + + return sql + + @classmethod + def get_system_instructions(cls, dialect: str) -> str: + system_instructions = f"""You are a {dialect} database expert. Given an input question, create a syntactically correct {dialect} query. You MUST only answer with the SQL query, nothing else. Unless the user specifies a specific number of results they wish to obtain, always limit your query to at most return {cls.max_results} results. You can order the results by relevant columns. You MUST check that the query doesn't contain syntax errors or incorrect table, views, column names or joins on wrong columns. Fix any error you might find before returning your answer. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. To construct your database query you MUST ALWAYS use the database metadata information provided to you as a JSON file. Do NOT skip this step. This JSON file specifies all the database schemas, for each schema all its relations (which are tables, and views) and for each table its columns, indexes, and foreign key constraints. The unique indexes are very useful to understand what differentiates one record to another in the same relation. The foreign key constraints are very useful to find the correct columns to join.
Do not include any markdown syntax in your response.""" + return system_instructions + + @staticmethod + def get_models(): + return { + "claude-3-5-haiku-latest": { + "name": "Claude Haiku 3.5", + "input_token_limit": 200000, + }, + "claude-sonnet-4-0": { + "name": "Claude Sonnet 4", + "input_token_limit": 200000, + }, + "claude-3-7-sonnet-latest": { + "name": "Claude Sonnet 3.7", + "input_token_limit": 200000, + }, + "claude-opus-4-0": {"name": "Claude Opus 4", "input_token_limit": 200000}, + } + + def generate_sql(self, prompt: str, history: str, schemas: List[str] | None) -> str: + """ + Generate SQL from a user prompt using the Anthropic SDK. + """ + db = DatabaseDAO.find_by_id(self.pk, True) + if not db: + logger.error("Database %s not found.", self.pk) + return + + if not db.llm_connection.enabled: + logger.error("LLM is not enabled for database %s.", self.pk) + return + + if not db.llm_connection.provider == self.llm_type: + logger.error( + "LLM provider is not %s for database %s.", self.llm_type, self.pk + ) + return + + llm_api_key = db.llm_connection.api_key + if not llm_api_key: + logger.error("API key not set for database %s.", self.pk) + return + + llm_model = db.llm_connection.model + if not llm_model: + logger.error("Model not set for database %s.", self.pk) + return + + logger.info("Using model %s for database %s", llm_model, self.pk) + + user_instructions = db.llm_context_options.instructions + client = anthropic.Anthropic(api_key=llm_api_key) + + # Compose system prompt and context + system_prompt = ( + user_instructions + if user_instructions + else self.get_system_instructions(self.dialect) + ) + context_json = json.dumps( + [ + schema + for schema in self.context + if not schemas or schema["schema_name"] in schemas + ] + ) + + message_parts = [system_prompt, "Database metadata:", context_json] + if history: + message_parts.append(history) + message_parts.append(prompt) + + try: + response = client.messages.create( + model=llm_model, + messages=[{"role": "user", "content": "\n".join(message_parts)}], + max_tokens=8192, + ) + except Exception as e: + logger.error("Anthropic API error: %s", e) + return f"-- Failed to generate SQL: {str(e)}" + + if not response or not response.content: + logger.error("No response from Anthropic API.") + return "-- Failed to generate SQL: No response from Anthropic API." + + reply = "\n".join( + [part.text for part in response.content if part.type == "text"] + ) + + sql = self._trim_markdown(reply) + if not sql: + return "-- Unable to find valid SQL in the LLM response" + + logger.info("Generated SQL: %s", sql) + return sql + + def get_context_size(self) -> int: + """ + Count the number of tokens in the prompt using the Anthropic SDK. + Cache the result in self.cached_size, which expires when self.cache_expiry changes. 
+ """ + db = DatabaseDAO.find_by_id(self.pk, True) + if not db: + logger.error("Database %s not found.", self.pk) + return + + if not db.llm_connection.provider == self.llm_type: + logger.error( + "LLM provider is not %s for database %s.", self.llm_type, self.pk + ) + return + + llm_api_key = db.llm_connection.api_key + if not llm_api_key: + logger.error("API key not set for database %s.", self.pk) + return + + llm_model = db.llm_connection.model + if not llm_model: + logger.error("Model not set for database %s.", self.pk) + return + + # If we have a cached size and a valid cache_expiry, return the cached size + if self.cached_context_size is not None: + logger.info("Using cached context size: %s", self.cached_context_size) + return self.cached_context_size + else: + # Invalidate any old cached size + self.cached_context_size = None + + user_instructions = db.llm_context_options.instructions + system_prompt = ( + user_instructions + if user_instructions + else self.get_system_instructions(self.dialect) + ) + context_json = json.dumps([schema for schema in self.context]) + + # Anthropic expects a list of messages, similar to OpenAI + messages = [ + { + "role": "user", + "content": "\n".join( + [system_prompt, "Database metadata:", context_json] + ), + }, + ] + + try: + client = anthropic.Anthropic(api_key=llm_api_key) + # Use the count_tokens method from the Anthropic SDK + response = client.messages.count_tokens( + model=llm_model, + messages=messages, + ) + total_tokens = response.input_tokens + logger.info("Calculated context size: %s", total_tokens) + except Exception as e: + logger.error("Anthropic API error: %s", e) + return + + # Cache the size until cache_expiry changes or is reached + self.cached_context_size = total_tokens + return self.cached_context_size diff --git a/superset/llms/api.py b/superset/llms/api.py new file mode 100644 index 000000000000..3b688d0c2b3f --- /dev/null +++ b/superset/llms/api.py @@ -0,0 +1,471 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json +import logging +from typing import Any + +from flask import request, Response +from flask_appbuilder.api import expose, protect, rison, safe +from flask_appbuilder.api.schemas import get_list_schema +from flask_appbuilder.models.sqla.interface import SQLAInterface +from marshmallow import fields, post_load, Schema, ValidationError + +from superset.models.core import CustomLlmProvider +from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics + +logger = logging.getLogger(__name__) + + +class CustomLlmProviderSchema(Schema): + id = fields.Integer(dump_only=True) + name = fields.String(required=True) + endpoint_url = fields.String(required=True) + request_template = fields.String(required=True) + response_path = fields.String(required=True) + headers = fields.String(allow_none=True) + models = fields.String(required=True) + system_instructions = fields.String(allow_none=True) + timeout = fields.Integer(allow_none=True, missing=30) + enabled = fields.Boolean(missing=True) + created_on = fields.DateTime(dump_only=True) + changed_on = fields.DateTime(dump_only=True) + + @post_load + def validate_json_fields(self, data, **kwargs): + """Validate JSON fields.""" + # Validate request_template + try: + json.loads(data["request_template"]) + except json.JSONDecodeError: + raise ValidationError("request_template must be valid JSON") + + # Validate headers if provided + if data.get("headers"): + try: + json.loads(data["headers"]) + except json.JSONDecodeError: + raise ValidationError("headers must be valid JSON") + + # Validate models + try: + models = json.loads(data["models"]) + if not isinstance(models, dict): + raise ValidationError("models must be a JSON object") + except json.JSONDecodeError: + raise ValidationError("models must be valid JSON") + + return data + + +class CustomLlmProviderRestApi(BaseSupersetModelRestApi): + datamodel = SQLAInterface(CustomLlmProvider) + resource_name = "custom_llm_provider" + allow_browser_login = True + + class_permission_name = "CustomLlmProvider" + method_permission_name = { + "get": "read", + "get_list": "read", + "post": "write", + "put": "write", + "delete": "write", + } + + add_columns = [ + "name", + "endpoint_url", + "request_template", + "response_path", + "headers", + "models", + "system_instructions", + "timeout", + "enabled", + ] + + edit_columns = add_columns + + list_columns = [ + "id", + "name", + "endpoint_url", + "enabled", + "created_on", + "changed_on", + ] + + show_columns = [ + "id", + "name", + "endpoint_url", + "request_template", + "response_path", + "headers", + "models", + "system_instructions", + "timeout", + "enabled", + "created_on", + "changed_on", + ] + + openapi_spec_tag = "Custom LLM Providers" + + add_model_schema = CustomLlmProviderSchema() + edit_model_schema = CustomLlmProviderSchema() + show_model_schema = CustomLlmProviderSchema() + + @expose("/test", methods=("POST",)) + @protect() + @safe + @statsd_metrics + def test_connection(self) -> Response: + """Test connection to a custom LLM provider. 
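+ Sends a single request built from the supplied template to the given endpoint and reports the resulting HTTP status; the generated content is not validated.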
+ --- + post: + summary: Test connection to a custom LLM provider + requestBody: + description: Custom LLM provider connection details + required: true + content: + application/json: + schema: + type: object + required: + - endpoint_url + - request_template + - response_path + properties: + endpoint_url: + type: string + description: The LLM provider endpoint URL + request_template: + type: string + description: JSON template for requests + response_path: + type: string + description: Path to extract response content + headers: + type: string + description: Additional headers as JSON + timeout: + type: integer + description: Request timeout in seconds + responses: + 200: + description: Connection test result + content: + application/json: + schema: + type: object + properties: + result: + type: object + properties: + status: + type: string + status_code: + type: integer + message: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + """ + try: + data = request.get_json() + + # Validate required fields + required_fields = ["endpoint_url", "request_template", "response_path"] + for field in required_fields: + if field not in data: + return self.response_400(f"Missing required field: {field}") + + # Validate JSON fields + try: + request_template = json.loads(data["request_template"]) + except json.JSONDecodeError: + return self.response_400("request_template must be valid JSON") + + headers = {"Content-Type": "application/json"} + if data.get("headers"): + try: + custom_headers = json.loads(data["headers"]) + headers.update(custom_headers) + except json.JSONDecodeError: + return self.response_400("headers must be valid JSON") + + # Create a simple test request + test_request = { + "model": "test", + "messages": [{"role": "user", "content": "SELECT 1"}], + } + + # Substitute template variables if needed + test_data = request_template.copy() + for key, value in test_data.items(): + if isinstance(value, str) and "{" in value: + test_data[key] = value.format( + model="test", messages=test_request["messages"], api_key="test" + ) + + import requests + + timeout = data.get("timeout", 30) + + try: + response = requests.post( + data["endpoint_url"], + json=test_data, + headers=headers, + timeout=timeout, + ) + + return self.response( + 200, + result={ + "status": "success", + "status_code": response.status_code, + "message": "Connection test completed", + }, + ) + + except requests.exceptions.RequestException as e: + return self.response( + 200, + result={ + "status": "error", + "message": f"Connection failed: {str(e)}", + }, + ) + + except Exception as e: + logger.exception("Error testing custom LLM provider connection") + return self.response_500(message=str(e)) + + @expose("/", methods=("GET",)) + @protect() + @safe + @statsd_metrics + @rison(get_list_schema) + def get_list(self, **kwargs: Any) -> Response: + """Get list of custom LLM providers. 
+ --- + get: + summary: Get a list of custom LLM providers + parameters: + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/get_list_schema' + responses: + 200: + description: List of custom LLM providers + content: + application/json: + schema: + type: object + properties: + ids: + description: >- + A list of custom LLM provider ids + type: array + items: + type: integer + count: + description: >- + The total record count on the backend + type: number + result: + description: >- + The result from the get list query + type: array + items: + $ref: '#/components/schemas/{{self.__class__.__name__}}.get_list' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + return super().get_list(**kwargs) + + @expose("/", methods=("GET",)) + @protect() + @safe + @statsd_metrics + def get(self, pk: int, **kwargs: Any) -> Response: + """Get a custom LLM provider by ID. + --- + get: + summary: Get a custom LLM provider + parameters: + - in: path + schema: + type: integer + name: pk + description: The custom LLM provider id + responses: + 200: + description: Custom LLM provider details + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.get' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + return super().get(pk, **kwargs) + + @expose("/", methods=("POST",)) + @protect() + @safe + @statsd_metrics + def post(self) -> Response: + """Create a new custom LLM provider. + --- + post: + summary: Create a custom LLM provider + requestBody: + description: Custom LLM provider details + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + responses: + 201: + description: Custom LLM provider created + content: + application/json: + schema: + type: object + properties: + id: + type: integer + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + return super().post() + + @expose("/", methods=("PUT",)) + @protect() + @safe + @statsd_metrics + def put(self, pk: int) -> Response: + """Update a custom LLM provider. + --- + put: + summary: Update a custom LLM provider + parameters: + - in: path + schema: + type: integer + name: pk + description: The custom LLM provider id + requestBody: + description: Custom LLM provider details + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + responses: + 200: + description: Custom LLM provider updated + content: + application/json: + schema: + type: object + properties: + id: + type: integer + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + return super().put(pk) + + @expose("/", methods=("DELETE",)) + @protect() + @safe + @statsd_metrics + def delete(self, pk: int) -> Response: + """Delete a custom LLM provider. 
+ --- + delete: + summary: Delete a custom LLM provider + parameters: + - in: path + schema: + type: integer + name: pk + description: The custom LLM provider id + responses: + 200: + description: Custom LLM provider deleted + content: + application/json: + schema: + type: object + properties: + message: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + return super().delete(pk) diff --git a/superset/llms/base_llm.py b/superset/llms/base_llm.py new file mode 100644 index 000000000000..4b801355f397 --- /dev/null +++ b/superset/llms/base_llm.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import datetime +from typing import Any, List + + +class BaseLlm: + llm_type = "Base" + max_results = 1000 + + def __init__(self, pk: int, dialect: str, context: Any) -> None: + self.pk = pk + self.dialect = dialect + self.context = context + self.created_at = datetime.datetime.now(datetime.timezone.utc) + + def __str__(self) -> str: + return f"{self.llm_type} ({self.dialect} DB {self.pk})" + + def __repr__(self) -> str: + return self.__str__() + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if not hasattr(cls, "llm_type"): + raise TypeError( + f"Can't instantiate abstract class {cls.__name__} without llm_type attribute" + ) + + def __eq__(self, other: object) -> bool: + return self.__str__() == other.__str__() + + def __hash__(self) -> int: + return hash(self.__str__()) + + def generate_sql( + self, pk: int, prompt: str, context: str, schemas: List[str] | None + ) -> str: + raise NotImplementedError + + def get_system_instructions(self) -> str: + raise NotImplementedError + + @staticmethod + def get_models() -> List[str]: + """ + Return a list of available models for the LLM. + """ + raise NotImplementedError + + def get_context_size(self) -> int: + """ + Return the size of the context in tokens. + """ + raise NotImplementedError diff --git a/superset/llms/custom.py b/superset/llms/custom.py new file mode 100644 index 000000000000..39b7bade5d31 --- /dev/null +++ b/superset/llms/custom.py @@ -0,0 +1,289 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import logging +import re +from typing import Any, Dict, List + +import requests + +from superset.daos.database import DatabaseDAO +from superset.llms.base_llm import BaseLlm +from superset.models.core import CustomLlmProvider + +logger = logging.getLogger(__name__) + + +class CustomLlm(BaseLlm): + llm_type = "Custom" + cached_context_size = None + + def __init__(self, pk: int, dialect: str, context: Any) -> None: + super().__init__(pk, dialect, context) + self._provider_config = None + self._load_provider_config() + + def _load_provider_config(self) -> None: + """Load custom provider configuration from database.""" + db = DatabaseDAO.find_by_id(self.pk, True) + if not db: + logger.error("Database %s not found.", self.pk) + return + + provider_name = db.llm_connection.provider + if not provider_name.startswith("custom_"): + logger.error("Invalid custom provider name: %s", provider_name) + return + + # Extract provider ID (remove 'custom_' prefix) + try: + provider_id = int(provider_name[7:]) # Remove 'custom_' prefix + except ValueError: + logger.error("Invalid custom provider ID in: %s", provider_name) + return + + # Load custom provider configuration from database + from superset.extensions import db as superset_db + + custom_provider = ( + superset_db.session.query(CustomLlmProvider) + .filter_by(id=provider_id, enabled=True) + .first() + ) + + if not custom_provider: + logger.error( + f"Custom provider with ID {provider_id} not found or disabled." 
+ ) + return + + self._provider_config = custom_provider + logger.info("Loaded custom provider config for '%s'", custom_provider.name) + + def _trim_markdown(self, text: str) -> str: + """Remove markdown code blocks from SQL response.""" + try: + sql_start = text.index("```") + sql_start_len = text.index("\n", sql_start) + 1 if sql_start != -1 else 0 + sql_end = text.index("\n```") + sql = text[sql_start + sql_start_len : sql_end] + except ValueError: + sql = text + return sql + + def _extract_response_content( + self, response_data: Dict[str, Any], path: str + ) -> str: + """Extract content from API response using JSONPath-like syntax.""" + try: + current: Any = response_data + + # Simple JSONPath parser for basic paths like "choices[0].message.content" + parts = re.split(r"[\.\[\]]", path) + parts = [p for p in parts if p] # Remove empty strings + + for part in parts: + if part.isdigit(): + current = current[int(part)] + else: + current = current[part] + + return str(current) + except (KeyError, IndexError, TypeError) as e: + logger.error("Failed to extract content using path '%s': %s", path, e) + return "" + + def _substitute_template_variables( + self, template: Any, variables: Dict[str, Any] + ) -> Any: + """Recursively substitute variables in template (string, dict, or list).""" + if isinstance(template, str): + # Replace {variable} placeholders + for key, value in variables.items(): + if isinstance(value, (dict, list)): + # For complex objects, convert to JSON string + template = template.replace(f"{{{key}}}", json.dumps(value)) + else: + template = template.replace(f"{{{key}}}", str(value)) + return template + elif isinstance(template, dict): + return { + k: self._substitute_template_variables(v, variables) + for k, v in template.items() + } + elif isinstance(template, list): + return [ + self._substitute_template_variables(item, variables) + for item in template + ] + else: + return template + + def get_system_instructions(self) -> str: + """Default system instructions for custom providers.""" + return f"""You are a {self.dialect} database expert. Given an input question, create a syntactically correct {self.dialect} query. You MUST only answer with the SQL query, nothing else. Unless the user specifies a specific number of results they wish to obtain, always limit your query to at most return {self.max_results} results. You can order the results by relevant columns. You MUST check that the query doesn't contain syntax errors or incorrect table, views, column names or joins on wrong columns. Fix any error you might find before returning your answer. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. To construct your database query you MUST ALWAYS use the database metadata information provided to you as a JSON file. Do NOT skip this step. This JSON file specifies all the database schemas, for each schema all its relations (which are tables, and views) and for each table its columns, indexes, and foreign key constraints. The unique indexes are very useful to understand what differentiates one record to another in the same relation. The foreign key constraints are very useful to find the correct columns to join. 
Do not include any markdown syntax in your response.""" + + @staticmethod + def get_models() -> List[str]: + """Return empty list - models are loaded dynamically from config.""" + return [] + + def get_provider_models(self) -> Dict[str, Dict[str, Any]]: + """Get models for this specific custom provider.""" + if not self._provider_config: + return {} + + try: + models = json.loads(self._provider_config.models) + return models + except (json.JSONDecodeError, AttributeError): + logger.error("Failed to parse models configuration for custom provider.") + return {} + + def generate_sql(self, prompt: str, context: str, schemas: List[str] | None) -> str: + """Generate SQL using the custom LLM provider.""" + db = DatabaseDAO.find_by_id(self.pk, True) + if not db: + logger.error("Database %s not found.", self.pk) + return "-- Error: Database not found" + + if not db.llm_connection.enabled: + logger.error("LLM is not enabled for database %s.", self.pk) + return "-- Error: LLM not enabled" + + if not self._provider_config: + logger.error( + "Custom provider configuration not found for database %s.", self.pk + ) + return "-- Error: Custom provider configuration not found" + + llm_api_key = db.llm_connection.api_key + if not llm_api_key: + logger.error("API key not set for database %s.", self.pk) + return "-- Error: API key not set" + + llm_model = db.llm_connection.model + if not llm_model: + logger.error("Model not set for database %s.", self.pk) + return "-- Error: Model not set" + + # Build messages + user_instructions = db.llm_context_options.instructions + system_prompt = ( + user_instructions + if user_instructions + else ( + self._provider_config.system_instructions + or self.get_system_instructions() + ) + ) + + context_json = json.dumps( + [ + schema + for schema in self.context + if not schemas or schema["schema_name"] in schemas + ] + ) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"Database metadata:\n{context_json}"}, + ] + if context: + messages.append({"role": "user", "content": context}) + messages.append({"role": "user", "content": prompt}) + + # Prepare request data using template + try: + request_template = json.loads(self._provider_config.request_template) + except json.JSONDecodeError: + logger.error( + "Invalid request_template JSON in custom provider configuration."
+ ) + return "-- Error: Invalid request template configuration" + + template_variables = { + "model": llm_model, + "messages": messages, + "api_key": llm_api_key, + } + + request_data = self._substitute_template_variables( + request_template, template_variables + ) + + # Prepare headers + headers = {"Content-Type": "application/json"} + if self._provider_config.headers: + try: + custom_headers = json.loads(self._provider_config.headers) + headers.update( + self._substitute_template_variables( + custom_headers, template_variables + ) + ) + except json.JSONDecodeError: + logger.error("Invalid headers JSON in custom provider configuration.") + + # Make API request + timeout = self._provider_config.timeout or 30 + try: + response = requests.post( + self._provider_config.endpoint_url, + json=request_data, + headers=headers, + timeout=timeout, + ) + response.raise_for_status() + response_data = response.json() + except requests.exceptions.RequestException as e: + logger.error("Custom LLM API request failed: %s", e) + return f"-- Failed to generate SQL: {str(e)}" + except json.JSONDecodeError as e: + logger.error("Invalid JSON response from custom LLM: %s", e) + return "-- Failed to generate SQL: Invalid JSON response" + + # Extract SQL from response + raw_content = self._extract_response_content( + response_data, self._provider_config.response_path + ) + + if not raw_content: + logger.error("No content extracted from custom LLM response.") + return "-- Failed to generate SQL: No content in response" + + sql = self._trim_markdown(raw_content) + if not sql: + return "-- Unable to find valid SQL in the LLM response" + + logger.info("Generated SQL: %s", sql) + return sql + + def get_context_size(self) -> int: + """Estimate context size for custom provider.""" + if self.cached_context_size is not None: + return self.cached_context_size + + # Simple token estimation (roughly 4 chars per token) + context_text = json.dumps(self.context) + estimated_tokens = len(context_text) // 4 + + self.cached_context_size = estimated_tokens + logger.info("Estimated context size: %s", estimated_tokens) + return self.cached_context_size diff --git a/superset/llms/dispatcher.py b/superset/llms/dispatcher.py new file mode 100644 index 000000000000..1a4089f324f3 --- /dev/null +++ b/superset/llms/dispatcher.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
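+"""
+Dispatches natural-language-to-SQL requests to the configured LLM provider.
+
+Provider instances are cached per database in ``llm_providers``. The schema
+context is read from the most recent successful ``ContextBuilderTask`` Celery
+result and handed to a built-in provider (matched by ``llm_type``) or to a
+``custom_<id>`` provider. Generated SQL is validated by running ``EXPLAIN``
+against the target database and regenerated up to ``VALIDATION_ATTEMPTS``
+times, with the previous attempt and its error appended to the prompt.
+"""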
+import datetime +import json +import logging +from typing import List, Tuple + +from celery.result import AsyncResult + +from superset.daos.context_builder_task import ContextBuilderTaskDAO +from superset.daos.database import DatabaseDAO +from superset.exceptions import DatabaseNotFoundException +from superset.extensions import security_manager +from superset.llms import custom +from superset.llms.base_llm import BaseLlm +from superset.llms.exceptions import NoContextError, NoProviderError +from superset.models.core import ContextBuilderTask +from superset.tasks.llm_context import initiate_context_generation +from superset.utils.core import override_user + +logger = logging.getLogger(__name__) + +llm_providers = {} +VALIDATION_ATTEMPTS = 3 +AVAILABLE_PROVIDERS = [ + cls for cls in BaseLlm.__subclasses__() if hasattr(cls, "llm_type") +] + + +def _get_last_successful_task_for_database( + pk: int, +) -> Tuple[ContextBuilderTask, AsyncResult]: + task = ContextBuilderTaskDAO.get_last_successful_task_for_database(pk) + + if not task: + raise NoContextError(f"No context builder task found for database {pk}.") + + context_builder_worker = AsyncResult(task.task_id) + if context_builder_worker.status == "SUCCESS": + return task, context_builder_worker + + return task, None + + +def _get_or_create_llm_provider(pk: int, dialect: str, provider_type: str) -> BaseLlm: + (context_builder_task, context_builder_worker) = ( + _get_last_successful_task_for_database(pk) + ) + + # At this point we will always have a context_builder_task but may not have a context_builder_worker + # if the task result has expired from the Celery backend. If we still have an llm_provider we can + # continue to use the context stored in memory in the provider. + + # See if we have a provider already for this database + llm_provider = llm_providers.get(pk, None) + if llm_provider: + started_time_utc = context_builder_task.started_time.replace( + tzinfo=datetime.timezone.utc + ) + # Check if cached provider type matches current database settings + if ( + started_time_utc < llm_provider.created_at + and llm_provider.llm_type == provider_type + ): + return llm_provider + else: + # Provider type changed or context is newer, remove old cached provider + del llm_providers[pk] + + try: + context = json.loads(context_builder_worker.result["result"]) + except Exception as e: + logger.error("Failed to parse context JSON: %s", str(e)) + raise NoContextError(f"Failed to parse context JSON for database {pk}.") + + # Handle custom providers + if provider_type.startswith("custom_"): + llm_provider = custom.CustomLlm(pk, dialect, context) + else: + # Handle built-in providers + for provider in AVAILABLE_PROVIDERS: + if provider.llm_type == provider_type: + llm_provider = provider(pk, dialect, context) + break + else: + raise NoProviderError(f"No LLM provider found for type {provider_type}.") + + llm_providers[pk] = llm_provider + return llm_provider + + +def generate_sql(pk: int, prompt: str, context: str, schemas: List[str] | None) -> str: + admin_user = security_manager.find_user(username="admin") + if not admin_user: + return {"status_code": 500, "message": "Unable to find admin user"} + with override_user(admin_user): + db = DatabaseDAO.find_by_id(pk) + if not db: + raise DatabaseNotFoundException(f"No such database: {pk}") + + provider = _get_or_create_llm_provider(pk, db.backend, db.llm_connection.provider) + if not provider: + return None + + prompt_with_errors = prompt + + for _ in range(VALIDATION_ATTEMPTS): + generated = 
provider.generate_sql(prompt_with_errors, context, schemas)
+
+            # Prepend 'EXPLAIN' command to the generated SQL to validate it
+            validation_sql = f"EXPLAIN {generated}"
+            error_text = None
+
+            try:
+                with db.get_raw_connection() as conn:
+                    cursor = conn.cursor()
+                    mutated_query = db.mutate_sql_based_on_config(validation_sql)
+                    db.db_engine_spec.execute(cursor, mutated_query, db)
+                    logger.info(
+                        "Validation SQL executed successfully: %s", validation_sql
+                    )
+                    return generated
+            except Exception as error:
+                logger.error("Validation SQL execution failed: %s", error)
+                error_text = str(error)
+
+            # Otherwise, append the generated SQL and error message to the prompt
+            # and try again
+            prompt_with_errors = (
+                f"{prompt_with_errors}\n\n{generated}\n\n-- Error: {error_text}\n"
+            )
+            logger.info(
+                "Generated SQL is invalid: %s\nRetrying with updated prompt: %s",
+                error_text,
+                prompt_with_errors,
+            )
+
+        logger.error(
+            "Failed to generate valid SQL after %s attempts.", VALIDATION_ATTEMPTS
+        )
+        return f"-- Failed to generate valid SQL after {VALIDATION_ATTEMPTS} attempts."
+
+
+def get_state(pk: int) -> dict:
+    """
+    Get the state of the LLM context.
+    """
+    # In total we're interested in knowing three things:
+    # - The last successful context build
+    # - Whether the last build finished with an error
+    # - Whether there is a build in progress right now
+
+    result = {
+        "status": "waiting",
+    }
+
+    admin_user = security_manager.find_user(username="admin")
+    if not admin_user:
+        return {"status_code": 500, "message": "Unable to find admin user"}
+    with override_user(admin_user):
+        db = DatabaseDAO.find_by_id(pk)
+        if not db:
+            raise DatabaseNotFoundException(f"No such database: {pk}")
+
+        successful_task = ContextBuilderTaskDAO.get_last_successful_task_for_database(pk)
+        if successful_task:
+            provider = _get_or_create_llm_provider(
+                pk, db.backend, db.llm_connection.provider
+            )
+            result["context"] = {
+                "build_time": successful_task.started_time,
+                "status": successful_task.status,
+                "size": provider.get_context_size(),
+            }
+
+        last_two_tasks = ContextBuilderTaskDAO.get_last_two_tasks_for_database(pk)
+        error_task = next(
+            (task for task in last_two_tasks if task.status == "ERROR"), None
+        )
+        if error_task and last_two_tasks and last_two_tasks[0].status != "SUCCESS":
+            result["error"] = {
+                "build_time": error_task.started_time,
+            }
+
+        latest_task = last_two_tasks[0] if last_two_tasks else None
+        if latest_task and latest_task.status == "PENDING":
+            result["status"] = "building"
+
+    return result
+
+
+def generate_context_for_db(pk: int) -> dict:
+    """
+    Generate the LLM context for a database.
+    """
+    # Check if we have a task for this already
+    task = ContextBuilderTaskDAO.get_latest_task_for_database(pk)
+    if task and task.status == "PENDING":
+        return {
+            "status": "Pending",
+            "task_id": task.task_id,
+        }
+
+    task = initiate_context_generation(pk)
+
+    return {
+        "status": "Started",
+        "task_id": task.task_id,
+    }
+
+
+def get_default_options(pk: int) -> dict:
+    """
+    Get the default options for the LLM context.
+ """ + admin_user = security_manager.find_user(username="admin") + if not admin_user: + return {"status_code": 500, "message": "Unable to find admin user"} + with override_user(admin_user): + db = DatabaseDAO.find_by_id(pk) + if not db: + raise DatabaseNotFoundException(f"No such database: {pk}") + + # Built-in providers + result = { + provider.llm_type: { + "models": provider.get_models(), + "instructions": provider.get_system_instructions(db.backend), + } + for provider in AVAILABLE_PROVIDERS + } + + # Add custom providers + from superset.extensions import db as superset_db + from superset.models.core import CustomLlmProvider + + custom_providers = ( + superset_db.session.query(CustomLlmProvider).filter_by(enabled=True).all() + ) + + for provider in custom_providers: + try: + models = json.loads(provider.models) + provider_key = f"custom_{provider.id}" + result[provider_key] = { + "models": models, + "instructions": provider.system_instructions + or custom.CustomLlm.get_system_instructions(db.backend), + "name": provider.name, + } + except (json.JSONDecodeError, AttributeError) as e: + logger.error( + f"Failed to parse models for custom provider {provider.name}: {e}" + ) + + return result diff --git a/superset/llms/exceptions.py b/superset/llms/exceptions.py new file mode 100644 index 000000000000..43904dcd04ae --- /dev/null +++ b/superset/llms/exceptions.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from superset.exceptions import SupersetException + + +class NoContextError(SupersetException): + """Exception raised when no context is provided to the LLM.""" + + pass + + +class NoProviderError(SupersetException): + """Exception raised when an appropriate LLM provider can't be found.""" + + pass diff --git a/superset/llms/gemini.py b/superset/llms/gemini.py new file mode 100644 index 000000000000..89b08ae78441 --- /dev/null +++ b/superset/llms/gemini.py @@ -0,0 +1,321 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
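The dispatcher above discovers built-in providers by scanning BaseLlm.__subclasses__() for classes that define llm_type, instantiates them as provider(pk, dialect, context), and then calls generate_sql(prompt, context, schemas) and get_context_size(). A minimal sketch of a provider that would be picked up by that discovery; BaseLlm itself is not part of this hunk, so storing pk, dialect, and context on the instance is an assumption:

    import json
    from typing import List

    from superset.llms.base_llm import BaseLlm


    class EchoLlm(BaseLlm):
        # llm_type is matched against llm_connection.provider by the dispatcher
        llm_type = "Echo"

        @staticmethod
        def get_models():
            return {"echo-1": {"name": "Echo", "input_token_limit": 1}}

        @classmethod
        def get_system_instructions(cls, dialect) -> str:
            return f"You are a {dialect} database expert."

        def generate_sql(self, prompt: str, history: str, schemas: List[str] | None) -> str:
            # A real provider would call an LLM here; this stub just returns a fixed query.
            return "SELECT 1"

        def get_context_size(self) -> int:
            # Rough estimate, mirroring the custom provider: ~4 characters per token.
            return len(json.dumps(self.context)) // 4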
+import datetime +import json +import logging +import time +from typing import List + +from google import genai +from google.genai import types + +from superset.daos.database import DatabaseDAO +from superset.llms.base_llm import BaseLlm + +logger = logging.getLogger(__name__) + + +class GeminiLlm(BaseLlm): + llm_type = "Gemini" + cache_name = None + cache_expiry = None + cache_model = None + cached_size = None + cached_size_cache_name = None + + def _create_schema_cache( + self, gemini_client, model, user_instructions + ) -> types.CachedContent: + cached_content = gemini_client.caches.create( + model=model, + config=types.CreateCachedContentConfig( + contents=[json.dumps(self.context)], + system_instruction=user_instructions + if user_instructions + else self.get_system_instructions(self.dialect), + display_name=f"DB({self.pk}) context for {model}", + ttl="86400s", + ), + ) + return cached_content + + def _get_cache_name(self, gemini_client, model, user_instructions) -> str: + """ + Get the cache name for the LLM. If we have a cache and we think it's valid, + just return it. Otherwise, generate a new cache. + """ + logger.info( + f"Current time is {datetime.datetime.now(tz=datetime.timezone.utc)}" + ) + logger.info("Cache expiry is %s", self.cache_expiry) + + # First check if the cache has expired + if self.cache_expiry and self.cache_expiry < datetime.datetime.now( + datetime.timezone.utc + ): + self.cache_name = None + self.cache_expiry = None + + # We'll also check if the model has changed + logger.info("Current model is %s, cache model is %s", model, self.cache_model) + if self.cache_model != model: + self.cache_name = None + self.cache_expiry = None + + if not self.cache_name: + logger.info("Creating new cache for model %s...", model) + start_time = time.perf_counter() + created_cache = self._create_schema_cache( + gemini_client, model, user_instructions + ) + end_time = time.perf_counter() + logger.info("Cache created in %.2f seconds", end_time - start_time) + self.cache_name = created_cache.name + self.cache_expiry = created_cache.expire_time + self.cache_model = model + + return self.cache_name + + def _get_response_error(self, response: types.GenerateContentResponse) -> str: + error = "-- Failed to generate SQL: " + match response.candidates[0].finish_reason: + case types.FinishReason.FINISH_REASON_UNSPECIFIED: + return error + "Gemini failed for an unspecified reason" + case types.FinishReason.MAX_TOKENS: + return error + "Gemini exceeded the maximum token limit" + case types.FinishReason.SAFETY: + return error + "Gemini detected unsafe content" + case types.FinishReason.RECITATION: + return error + "Gemini detected training data in the output" + case types.FinishReason.OTHER: + return error + "Gemini failed for an 'other' reason" + case types.FinishReason.BLOCKLIST: + return error + "Gemini detected blocklisted content" + case types.FinishReason.PROHIBITED_CONTENT: + return error + "Gemini detected prohibited content in the output" + case types.FinishReason.SPII: + return ( + error + + "Gemini detected personally identifiable information in the output" + ) + case types.FinishReason.MALFORMED_FUNCTION_CALL: + return error + "Gemini detected a malformed function call in the output" + + return error + "Gemini failed for an unknown reason" + + def _trim_markdown(self, text: str) -> bool: + # Check the body for SQL wrapped in markdown ```{language}\n...\n``` + try: + sql_start = text.index("```") + # Find the first newline after the start of the SQL block + sql_start_len = 
text.index("\n", sql_start) + 1 if sql_start != -1 else 0 + sql_end = text.index("\n```") + sql = text[sql_start + sql_start_len : sql_end] + except ValueError: + # There was no markdown, so assume for now the whole response is the SQL + sql = text + + return sql + + @classmethod + def get_system_instructions(cls, dialect) -> str: + system_instructions = f"""You are a {dialect} database expert. Given an input question, create a syntactically correct {dialect} query. You MUST only answer with the SQL query, nothing else. Unless the user specifies a specific number of results they wish to obtain, always limit your query to at most return {cls.max_results} results. You can order the results by relevant columns. You MUST check that the query doesn't contain syntax errors or incorrect table, views, column names or joins on wrong columns. Fix any error you might find before returning your answer. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. To construct your database query you MUST ALWAYS use the database metadata information provided to you as a JSON file. Do NOT skip this step. This JSON file specifies all the database schemas, for each schema all its relations (which are tables, and views) and for each table its columns, indexes, and foreign key constraints. The unique indexes are very useful to understand what differentiates one record to another in the same relation. The foreign key constraints are very useful to find the correct columns to join. Do not include any markdown syntax in your response.""" + return system_instructions + + @staticmethod + def get_models(): + return { + "models/gemini-1.5-flash-002": { + "name": "Gemini 1.5 Flash", + "input_token_limit": 1000000, + }, + "models/gemini-2.0-flash": { + "name": "Gemini 2.0 Flash", + "input_token_limit": 1048576, + }, + "models/gemini-2.0-flash-thinking-exp": { + "name": "Gemini 2.5 Flash Preview", + "input_token_limit": 1048576, + }, + "models/gemini-1.5-pro-002": { + "name": "Gemini 1.5 Pro", + "input_token_limit": 2000000, + }, + "models/gemini-2.0-pro-exp": { + "name": "Gemini 2.0 Pro Experimental", + "input_token_limit": 1048576, + }, + } + + def generate_sql(self, prompt: str, history: str, schemas: List[str] | None) -> str: + """ + Generate SQL from a user prompt. 
+ """ + db = DatabaseDAO.find_by_id(self.pk, True) + if not db: + logger.error("Database %s not found.", self.pk) + return + + # TODO(AW): We should throw here instead of returning None + if not db.llm_connection.enabled: + logger.error("LLM is not enabled for database %s.", self.pk) + return + + if not db.llm_connection.provider == self.llm_type: + logger.error( + "LLM provider is not %s for database %s.", self.llm_type, self.pk + ) + return + + llm_api_key = db.llm_connection.api_key + if not llm_api_key: + logger.error("API key not set for database %s.", self.pk) + return + + llm_model = db.llm_connection.model + if not llm_model: + logger.error("Model not set for database %s.", self.pk) + return + + user_instructions = db.llm_context_options.instructions + + logger.info("Using model %s for database %s", llm_model, self.pk) + + gemini_client = genai.Client(api_key=llm_api_key) + + if schemas: + # Check and see if all the schemas are in the context + for schema in schemas: + if not any( + schema == context_schema["schema_name"] + for context_schema in self.context + ): + logger.error("Schema %s not found in context", schema) + return + + context = json.dumps( + [schema for schema in self.context if schema["schema_name"] in schemas] + ) + instructions = ( + user_instructions + if user_instructions + else self.get_system_instructions(self.dialect) + ) + + contents = ( + [context, instructions, history, prompt] + if history + else [context, instructions, prompt] + ) + response = gemini_client.models.generate_content( + model=llm_model, + contents=contents, + ) + else: + cache_name = self._get_cache_name( + gemini_client, llm_model, user_instructions + ) + logger.info("Using cache %s", self.cache_name) + + contents = [history, prompt] if history else [prompt] + try: + response = gemini_client.models.generate_content( + model=llm_model, + contents=contents, + config=types.GenerateContentConfig( + cached_content=cache_name, + ), + ) + except genai.errors.ServerError as e: + logger.error("Server error: %s", e) + return f"-- Failed to generate SQL: {e.message}" + + # Check if the response is an error by looking at the finish reason of every candidate + for candidate in response.candidates: + if candidate.finish_reason != types.FinishReason.STOP: + logger.error("Failed to generate SQL: %s", candidate.finish_reason) + + success = any( + candidate.finish_reason == types.FinishReason.STOP + for candidate in response.candidates + ) + if not success or not response.text: + return self._get_response_error(response) + + sql = self._trim_markdown(response.text) + if not sql: + return "-- Unable to find valid SQL in the LLM response" + + logger.info("Generated SQL: %s", sql) + return sql + + def get_context_size(self) -> int: + """ + Count the number of tokens in a prompt. + Cache the result in self.cached_size, which expires when self.cache_expiry changes. 
+        """
+        db = DatabaseDAO.find_by_id(self.pk, True)
+        if not db:
+            logger.error("Database %s not found.", self.pk)
+            return
+
+        if not db.llm_connection.provider == self.llm_type:
+            logger.error(
+                "LLM provider is not %s for database %s.", self.llm_type, self.pk
+            )
+            return
+
+        llm_api_key = db.llm_connection.api_key
+        if not llm_api_key:
+            logger.error("API key not set for database %s.", self.pk)
+            return
+
+        llm_model = db.llm_connection.model
+        if not llm_model:
+            logger.error("Model not set for database %s.", self.pk)
+            return
+
+        # If we have a cached size and a valid cache_expiry, return the cached size
+        if (
+            self.cached_size is not None
+            and self.cache_name == self.cached_size_cache_name
+        ):
+            logger.info("Using cached context size: %s", self.cached_size)
+            return self.cached_size
+        else:
+            # Invalidate any old cached size
+            self.cached_size = None
+            self.cached_size_cache_name = None
+
+        user_instructions = db.llm_context_options.instructions
+
+        gemini_client = genai.Client(api_key=llm_api_key)
+        response = gemini_client.models.count_tokens(
+            model=llm_model,
+            contents=[
+                json.dumps(self.context),
+                user_instructions
+                if user_instructions
+                else self.get_system_instructions(self.dialect),
+            ],
+        )
+        logger.info("Calculated context size: %s", response.total_tokens)
+
+        # Cache the size until cache_expiry changes or is reached
+        self.cached_size = response.total_tokens
+        self.cached_size_cache_name = self.cache_name
+        return self.cached_size
diff --git a/superset/llms/openai.py b/superset/llms/openai.py
new file mode 100644
index 000000000000..e074745f96dd
--- /dev/null
+++ b/superset/llms/openai.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import logging
+from typing import List
+
+import openai
+import tiktoken
+
+from superset.daos.database import DatabaseDAO
+from superset.llms.base_llm import BaseLlm
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAiLlm(BaseLlm):
+    llm_type = "OpenAI"
+    cached_context_size = None
+
+    def _trim_markdown(self, text: str) -> str:
+        # Check the body for SQL wrapped in markdown ```{language}\n...\n```
+        try:
+            sql_start = text.index("```")
+            # The SQL body starts on the line after the opening fence
+            sql_body_start = text.index("\n", sql_start) + 1
+            sql_end = text.index("\n```", sql_body_start)
+            sql = text[sql_body_start:sql_end]
+        except ValueError:
+            # There was no markdown, so assume for now the whole response is the SQL
+            sql = text
+
+        return sql
+
+    @classmethod
+    def get_system_instructions(cls, dialect) -> str:
+        system_instructions = f"""You are a {dialect} database expert. Given an input question, create a syntactically correct {dialect} query. You MUST only answer with the SQL query, nothing else.
Unless the user specifies a specific number of results they wish to obtain, always limit your query to at most return {cls.max_results} results. You can order the results by relevant columns. You MUST check that the query doesn't contain syntax errors or incorrect table, views, column names or joins on wrong columns. Fix any error you might find before returning your answer. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. To construct your database query you MUST ALWAYS use the database metadata information provided to you as a JSON file. Do NOT skip this step. This JSON file specifies all the database schemas, for each schema all its relations (which are tables, and views) and for each table its columns, indexes, and foreign key constraints. The unique indexes are very useful to understand what differentiates one record to another in the same relation. The foreign key constraints are very useful to find the correct columns to join. Do not include any markdown syntax in your response.""" + return system_instructions + + @staticmethod + def get_models(): + return { + "gpt-4.1-nano": {"name": "GPT-4.1 nano", "input_token_limit": 1047576}, + "gpt-4.1-mini": {"name": "GPT-4.1 mini", "input_token_limit": 1047576}, + "o4-mini": {"name": "o4-mini", "input_token_limit": 200000}, + "o3": {"name": "o3", "input_token_limit": 200000}, + "gpt-4o-mini": {"name": "GPT-4o mini", "input_token_limit": 128000}, + } + + def generate_sql(self, prompt: str, history: str, schemas: List[str] | None) -> str: + """ + Generate SQL from a user prompt using the OpenAI SDK. + """ + db = DatabaseDAO.find_by_id(self.pk, True) + if not db: + logger.error("Database %s not found.", self.pk) + return + + if not db.llm_connection.enabled: + logger.error("LLM is not enabled for database %s.", self.pk) + return + + if not db.llm_connection.provider == self.llm_type: + logger.error( + "LLM provider is not %s for database %s.", self.llm_type, self.pk + ) + return + + llm_api_key = db.llm_connection.api_key + if not llm_api_key: + logger.error("API key not set for database %s.", self.pk) + return + + llm_model = db.llm_connection.model + if not llm_model: + logger.error("Model not set for database %s.", self.pk) + return + + logger.info("Using model %s for database %s", llm_model, self.pk) + + user_instructions = db.llm_context_options.instructions + client = openai.OpenAI(api_key=llm_api_key) + + # Compose system prompt and context + system_prompt = ( + user_instructions + if user_instructions + else self.get_system_instructions(self.dialect) + ) + context_json = json.dumps( + [ + schema + for schema in self.context + if not schemas or schema["schema_name"] in schemas + ] + ) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"Database metadata:\n{context_json}"}, + ] + if history: + messages.append({"role": "user", "content": history}) + messages.append({"role": "user", "content": prompt}) + + try: + response = client.chat.completions.create( + model=llm_model, + messages=messages, + ) + except Exception as e: + logger.error("OpenAI API error: %s", e) + return f"-- Failed to generate SQL: {str(e)}" + + if not response or not response.choices or len(response.choices) < 1: + logger.error("No response from OpenAI API.") + return "-- Failed to generate SQL: No response from OpenAI API." 
+ + reply = response.choices[0].message.content.strip() + sql = self._trim_markdown(reply) + if not sql: + return "-- Unable to find valid SQL in the LLM response" + + logger.info("Generated SQL: %s", sql) + return sql + + def get_context_size(self) -> int: + """ + Count the number of tokens in a prompt using the OpenAI SDK. + Cache the result in self.cached_context_size. + """ + + db = DatabaseDAO.find_by_id(self.pk, True) + if not db: + logger.error("Database %s not found.", self.pk) + return + + if not db.llm_connection.provider == self.llm_type: + logger.error( + "LLM provider is not %s for database %s.", self.llm_type, self.pk + ) + return + + llm_api_key = db.llm_connection.api_key + if not llm_api_key: + logger.error("API key not set for database %s.", self.pk) + return + + llm_model = db.llm_connection.model + if not llm_model: + logger.error("Model not set for database %s.", self.pk) + return + + if self.cached_context_size is not None: + logger.info("Using cached context size: %s", self.cached_context_size) + return self.cached_context_size + + user_instructions = db.llm_context_options.instructions + system_prompt = ( + user_instructions + if user_instructions + else self.get_system_instructions(self.dialect) + ) + context_json = json.dumps(self.context) + + try: + encoding = tiktoken.encoding_for_model(llm_model) + except Exception: + encoding = tiktoken.get_encoding("cl100k_base") + + # Compose the prompt as OpenAI would receive it + prompt_parts = [context_json, system_prompt] + prompt_text = "\n".join(prompt_parts) + tokens = encoding.encode(prompt_text) + total_tokens = len(tokens) + + logger.info("Calculated context size: %s", total_tokens) + self.cached_context_size = total_tokens + return self.cached_context_size diff --git a/superset/llms/schemas.py b/superset/llms/schemas.py new file mode 100644 index 000000000000..bce1a14ec10b --- /dev/null +++ b/superset/llms/schemas.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
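OpenAiLlm.get_context_size above estimates the context size by concatenating the schema JSON and system prompt and counting tokens with tiktoken, falling back to the cl100k_base encoding when the model name is unknown to the installed tiktoken version. A quick sanity check of that estimate, as a sketch (the model name and sample text are illustrative):

    import tiktoken

    # Mirror the fallback used above rather than letting encoding_for_model() raise
    # for model names that an older tiktoken does not know about.
    try:
        encoding = tiktoken.encoding_for_model("gpt-4o-mini")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    sample = '{"schema_name": "public", "tables": []}\nYou are a postgresql database expert.'
    print(len(encoding.encode(sample)))  # token count comparable to get_context_size()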
+ +import json +from typing import Any + +from marshmallow import fields, post_load, Schema, ValidationError + + +class CustomLlmProviderTestSchema(Schema): + """Schema for testing custom LLM provider connection.""" + + endpoint_url = fields.String(required=True) + request_template = fields.String(required=True) + response_path = fields.String(required=True) + headers = fields.String(allow_none=True) + timeout = fields.Integer(missing=30) + + @post_load + def validate_json_fields(self, data: dict[str, Any]) -> dict[str, Any]: + """Validate JSON fields.""" + # Validate request_template + try: + json.loads(data["request_template"]) + except json.JSONDecodeError: + raise ValidationError("request_template must be valid JSON") + + # Validate headers if provided + if data.get("headers"): + try: + json.loads(data["headers"]) + except json.JSONDecodeError: + raise ValidationError("headers must be valid JSON") + + return data + + +custom_llm_provider_test_schema = CustomLlmProviderTestSchema() diff --git a/superset/migrations/versions/2025-08-11_11-20_58200d37f074_add_llm_tables.py b/superset/migrations/versions/2025-08-11_11-20_58200d37f074_add_llm_tables.py new file mode 100644 index 000000000000..a7a25a74782c --- /dev/null +++ b/superset/migrations/versions/2025-08-11_11-20_58200d37f074_add_llm_tables.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add llm tables + +Revision ID: 58200d37f074 +Revises: c233f5365c9e +Create Date: 2025-08-11 11:20:44.248026 + +""" + +# revision identifiers, used by Alembic. 
+revision = "58200d37f074" +down_revision = "c233f5365c9e" + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + op.create_table( + "llm_connection", + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("database_id", sa.Integer(), nullable=False), + sa.Column("enabled", sa.Boolean(), default=False, nullable=False), + sa.Column("provider", sa.String(length=255), nullable=False), + sa.Column("model", sa.String(length=255), nullable=False), + sa.Column("api_key", sa.Text(), nullable=False), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["changed_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["created_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["database_id"], + ["dbs.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "llm_context_options", + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("database_id", sa.Integer(), nullable=True), + sa.Column("refresh_interval", sa.Integer(), nullable=True), + sa.Column( + "schemas", + sa.Text().with_variant(sa.dialects.mysql.MEDIUMTEXT(), "mysql"), + nullable=True, + ), + sa.Column("include_indexes", sa.Boolean(), default=True), + sa.Column("top_k", sa.Integer(), default=10, nullable=True), + sa.Column("top_k_limit", sa.Integer(), default=50000, nullable=True), + sa.Column( + "instructions", + sa.Text().with_variant(sa.dialects.mysql.MEDIUMTEXT(), "mysql"), + nullable=True, + ), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["changed_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["created_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["database_id"], + ["dbs.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + + op.create_table( + "context_builder_task", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("task_id", sa.String(length=255), nullable=True), + sa.Column("database_id", sa.Integer(), nullable=True), + sa.Column("started_time", sa.DateTime(), nullable=True), + sa.Column( + "params", + sa.Text().with_variant(sa.dialects.mysql.MEDIUMTEXT(), "mysql"), + nullable=True, + ), + sa.Column("ended_time", sa.DateTime(), nullable=True), + sa.Column("status", sa.String(length=255), nullable=True), + sa.Column("duration", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["database_id"], + ["dbs.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("task_id"), + ) + + op.create_table( + "custom_llm_providers", + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("endpoint_url", sa.String(length=1024), nullable=False), + sa.Column( + "request_template", + sa.Text().with_variant(sa.dialects.mysql.MEDIUMTEXT(), "mysql"), + nullable=False, + ), + sa.Column("response_path", sa.String(length=255), nullable=False), + sa.Column( + "headers", + sa.Text().with_variant(sa.dialects.mysql.MEDIUMTEXT(), "mysql"), + nullable=True, + ), + sa.Column( + "models", + sa.Text().with_variant(sa.dialects.mysql.MEDIUMTEXT(), "mysql"), + 
nullable=False, + ), + sa.Column( + "system_instructions", + sa.Text().with_variant(sa.dialects.mysql.MEDIUMTEXT(), "mysql"), + nullable=True, + ), + sa.Column("timeout", sa.Integer(), default=30, nullable=True), + sa.Column("enabled", sa.Boolean(), default=True, nullable=False), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["changed_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["created_by_fk"], + ["ab_user.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), + ) + # ### end Alembic commands ### + + +def downgrade(): + op.drop_table("custom_llm_providers") + op.drop_table("llm_context_options") + op.drop_table("llm_connection") + op.drop_table("context_builder_task") + # ### end Alembic commands ### diff --git a/superset/models/core.py b/superset/models/core.py index f6643d18ff2c..5ecb5da80ac8 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -27,7 +27,7 @@ from ast import literal_eval from contextlib import closing, contextmanager, nullcontext, suppress from copy import deepcopy -from datetime import datetime +from datetime import datetime, timezone from functools import lru_cache from inspect import signature from typing import Any, Callable, cast, Optional, TYPE_CHECKING @@ -56,7 +56,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchModuleError from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import relationship +from sqlalchemy.orm import backref, relationship from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import ColumnElement, expression, Select @@ -195,6 +195,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable "allow_file_upload", "extra", "impersonate_user", + "llm_available", + "llm_connection", + "llm_context_options", ] extra_import_fields = [ "password", @@ -285,6 +288,7 @@ def data(self) -> dict[str, Any]: "allow_multi_catalog": self.allow_multi_catalog, "parameters_schema": self.parameters_schema, "engine_information": self.engine_information, + "llm_connection": self.llm_connection, } @property @@ -380,6 +384,17 @@ def engine_information(self) -> dict[str, Any]: engine_information = {} return engine_information + @property + def llm_available(self) -> bool: + c = self.llm_connection + return ( + bool(c.provider) and bool(c.model) and bool(c.api_key) and bool(c.enabled) + ) + + @property + def llm_connection(self) -> LlmConnection: + return self.get_extra().get("llm_connection", {}) + @classmethod def get_password_masked_url_from_uri( # pylint: disable=invalid-name cls, uri: str @@ -1366,3 +1381,57 @@ class FavStar(UUIDMixin, Model): class_name = Column(String(50)) obj_id = Column(Integer) dttm = Column(DateTime, default=datetime.utcnow) + + +class ContextBuilderTask(Model): + __tablename__ = "context_builder_task" + id = Column(Integer, primary_key=True) + task_id = Column(String(255), unique=True) + database_id = Column(Integer, ForeignKey("dbs.id")) + started_time = Column(DateTime, default=datetime.now(timezone.utc)) + ended_time = Column(DateTime, nullable=True) + status = Column(String(255), nullable=True) + duration = Column(Integer, nullable=True) + params = Column(utils.MediumText()) + + +class LlmConnection(Model, AuditMixinNullable): + __tablename__ = "llm_connection" + id = Column(Integer, primary_key=True) + database_id = Column(Integer, ForeignKey("dbs.id")) + enabled = 
Column(Boolean, default=False) + provider = Column(String(255), nullable=False) + model = Column(String(255), nullable=False) + api_key = Column(Text, nullable=True) + database = relationship( + "Database", backref=backref("llm_connection", uselist=False), uselist=False + ) + + +class LlmContextOptions(Model, AuditMixinNullable): + __tablename__ = "llm_context_options" + id = Column(Integer, primary_key=True) + database_id = Column(Integer, ForeignKey("dbs.id")) + refresh_interval = Column(Integer, default=12) + schemas = Column(utils.MediumText(), nullable=True) + include_indexes = Column(Boolean, default=True) + top_k = Column(Integer, default=10) + top_k_limit = Column(Integer, default=10000) + instructions = Column(utils.MediumText(), nullable=True) + database = relationship( + "Database", backref=backref("llm_context_options", uselist=False) + ) + + +class CustomLlmProvider(Model, AuditMixinNullable): + __tablename__ = "custom_llm_providers" + id = Column(Integer, primary_key=True) + name = Column(String(255), nullable=False, unique=True) + endpoint_url = Column(String(1024), nullable=False) + request_template = Column(utils.MediumText(), nullable=False) + response_path = Column(String(255), nullable=False) + headers = Column(utils.MediumText(), nullable=True) + models = Column(utils.MediumText(), nullable=False) + system_instructions = Column(utils.MediumText(), nullable=True) + timeout = Column(Integer, default=30) + enabled = Column(Boolean, default=True) diff --git a/superset/sql/parse.py b/superset/sql/parse.py index a2b84376984c..91cbde9e7a72 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -1209,6 +1209,7 @@ def parse_predicate(self, predicate: str) -> str: return predicate + class SQLScript: """ A SQL script, with 0+ statements. diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index 906dd72bcaf2..2edcffe270a5 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -34,6 +34,7 @@ from superset.daos.query import QueryDAO from superset.extensions import event_logger from superset.jinja_context import get_template_processor +from superset.llms import dispatcher from superset.models.sql_lab import Query from superset.sql.parse import SQLScript from superset.sql_lab import get_sql_results @@ -48,7 +49,10 @@ EstimateQueryCostSchema, ExecutePayloadSchema, FormatQueryPayloadSchema, + GenerateDbContextSchema, + GenerateSqlSchema, QueryExecutionResponseSchema, + sql_lab_get_assistant_status_schema, sql_lab_get_results_schema, SQLLabBootstrapSchema, ) @@ -420,6 +424,142 @@ def execute_sql_query(self) -> FlaskResponse: ) return self.response(response_status, **payload) + @expose("/generate_sql/", methods=("POST",)) + @protect() + @statsd_metrics + @requires_json + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.generate_sql", + log_to_statsd=False, + ) + def generate_sql_with_ai(self) -> FlaskResponse: + """Generate a SQL query with AI. 
+ --- + post: + summary: Generate a SQL query with AI + requestBody: + description: User prompt and prior SQL query context + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GenerateSqlSchema' + responses: + 200: + description: A generated SQL query + content: + application/json: + schema: + $ref: '#/components/schemas/GenerateSqlResponseSchema' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + params = GenerateSqlSchema().load(request.json) + logging.info(f"Generating SQL with AI for database {params['database_id']}") + logging.info(f"User prompt: {params['user_prompt']}") + logging.info(f"Prior context: {params.get('prior_context')}") + logging.info(f"Schemas: {params.get('schemas')}") + generated = dispatcher.generate_sql( + params["database_id"], + params["user_prompt"], + params.get("prior_context"), + params.get("schemas"), + ) + return json_success(json.dumps({"sql": generated}), 200) + + @expose("/generate_db_context/", methods=("POST",)) + @protect() + @statsd_metrics + @requires_json + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".generate_db_context", + log_to_statsd=False, + ) + def generate_db_context(self) -> FlaskResponse: + """Generate database context information for generating queries with an LLM. + --- + post: + summary: Generate database context information for generating queries with an LLM + requestBody: + description: Database ID + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GenerateDbContextSchema' + responses: + 202: + description: Query execution result, query still running + content: + application/json: + schema: + $ref: '#/components/schemas/GenerateDbContextResponseSchema' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 500: + $ref: '#/components/responses/500' + """ + params = GenerateDbContextSchema().load(request.json) + result = dispatcher.generate_context_for_db(params["database_id"]) + return json_success(json.dumps(result), 200) + + @expose("/db_context_status/") + @protect() + @statsd_metrics + @rison(sql_lab_get_assistant_status_schema) + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".db_context_status", + log_to_statsd=False, + ) + def db_context_status(self, **kwargs: Any) -> FlaskResponse: + """Get the status of the AI assistant. + --- + get: + summary: Get the status of the AI assistant. 
+ parameters: + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/sql_lab_get_assistant_status_schema' + responses: + 200: + description: Current state of the AI assistant + content: + application/json: + schema: + $ref: '#/components/schemas/AiAssistantStatusResponseSchema' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 410: + $ref: '#/components/responses/410' + 500: + $ref: '#/components/responses/500' + """ + params = kwargs["rison"] + pk = params.get("pk") + result = dispatcher.get_state(pk) + return json_success(json.dumps(result), 200) + @staticmethod def _create_sql_json_command( execution_context: SqlJsonExecutionContext, log_params: Optional[dict[str, Any]] diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py index 1e1d492b7cf8..bada9ab01a42 100644 --- a/superset/sqllab/schemas.py +++ b/superset/sqllab/schemas.py @@ -26,6 +26,14 @@ "required": ["key"], } +sql_lab_get_assistant_status_schema = { + "type": "object", + "properties": { + "pk": {"type": "integer"}, + }, + "required": ["pk"], +} + class EstimateQueryCostSchema(Schema): database_id = fields.Integer( @@ -153,3 +161,29 @@ class SQLLabBootstrapSchema(Schema): values=fields.Nested(QueryResultSchema), ) tab_state_ids = fields.List(fields.String()) + + +class GenerateSqlSchema(Schema): + database_id = fields.Integer(required=True) + user_prompt = fields.String(required=True, allow_none=False) + prior_context = fields.String(allow_none=True) + schemas = fields.List(fields.String(), allow_none=True) + + +class GenerateSqlResponseSchema(Schema): + sql = fields.String(required=True) + + +class GenerateDbContextSchema(Schema): + database_id = fields.Integer(required=True) + + +class AiAssistantStatusResponseSchema(Schema): + context = fields.Dict( + allow_none=True, + status=fields.String(required=True), + build_time=fields.DateTime(required=True), + message=fields.String(allow_none=True), + size=fields.Integer(allow_none=True), + ) + status = fields.String(required=True) diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 0f28c070703b..14b387b73c24 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -257,7 +257,7 @@ def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]: url, data=bytes(data, "utf-8"), headers=headers, method="PUT" ) response = request.urlopen( # pylint: disable=consider-using-with # noqa: S310 - req, timeout=600 + req, timeout=30 ) logger.info( "Fetched %s with payload %s, status code: %s", url, data, response.code diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py index 2049246f0428..7130435c4915 100644 --- a/superset/tasks/celery_app.py +++ b/superset/tasks/celery_app.py @@ -34,7 +34,7 @@ # Need to import late, as the celery_app will have been setup by "create_app()" # ruff: noqa: E402, F401 # pylint: disable=wrong-import-position, unused-import -from . import cache, scheduler +from . import cache, llm_context, scheduler # Export the celery app globally for Celery (as run on the cmd line) to find app = celery_app diff --git a/superset/tasks/llm_context.py b/superset/tasks/llm_context.py new file mode 100644 index 000000000000..18c34b9a3486 --- /dev/null +++ b/superset/tasks/llm_context.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import json +import logging +import time +from typing import Any + +from celery.result import AsyncResult +from celery.signals import after_task_publish +from celery.utils.log import get_task_logger + +from superset import db +from superset.daos.context_builder_task import ContextBuilderTaskDAO +from superset.daos.database import DatabaseDAO +from superset.databases.utils import get_database_metadata +from superset.extensions import celery_app, security_manager +from superset.models.core import ContextBuilderTask +from superset.utils.core import override_user + +logger = get_task_logger(__name__) +logger.setLevel(logging.INFO) + + +@after_task_publish.connect +def update_sent_state(sender: str | None = None, headers: dict[str, str] | None = None, **kwargs: Any) -> None: + task = celery_app.tasks.get(sender) + backend = task.backend if task else celery_app.backend + + if headers and headers["task"] == "generate_llm_context": + backend.store_result(headers["id"], None, "PUBLISHED") + + +@celery_app.task(name="check_for_expired_llm_context") +def check_for_expired_llm_context() -> None: + admin_user = security_manager.find_user(username="admin") + if not admin_user: + logger.error("Unable to find admin user") + return + + with override_user(admin_user): + databases = DatabaseDAO.find_all() + databases = [ + database + for database in databases + if database.llm_connection and database.llm_connection.enabled + ] + + # For the list of candidate DBs, we need to determine if we need to generate a context for them. 
+ # We need to genereate one if any of the following are true: + # - There is no ContextBuilderTask for the database + # - The latest ContextBuilderTask for the database is in a failed state + # - The latest ContextBuilderTask for the database is in a success state, but the context is older than the configured refresh_interval + # - The latest ContextBuilderTask for the database is in a success state, but the context is empty + for database in databases: + latest_task = ContextBuilderTaskDAO.get_latest_task_for_database(database.id) + if not latest_task: + logger.info("No previous tasks for database %s", database.id) + initiate_context_generation(database.id) + continue + + task_result = AsyncResult(latest_task.task_id) + + if task_result.status == "PENDING" or task_result.status == "FAILURE": + logger.info("Old context failed - generating for database %s", database.id) + initiate_context_generation(database.id) + elif task_result.status == "SUCCESS": + refresh_interval = ( + int(database.llm_context_options.refresh_interval or 12) * 60 * 60 + ) + old_started_time = latest_task.started_time.replace( + tzinfo=datetime.timezone.utc + ) + if ( + datetime.datetime.now(datetime.timezone.utc) - old_started_time + ).total_seconds() > refresh_interval: + logger.info( + f"Old LLM context expired - generating for database {database.id}" + ) + initiate_context_generation(database.id) + elif not task_result.result: + logger.info( + f"Old LLM context missing - generating for database {database.id}" + ) + initiate_context_generation(database.id) + else: + logger.info("Nothing to be done for database %s", database.id) + + +def reduce_json_token_count(data: str) -> str: + """ + Reduces the token count of a JSON string. + """ + data = data.replace(": ", ":").replace(", ", ",") + + return data + + +def initiate_context_generation(pk: int) -> Any: + task = generate_llm_context.delay(pk) + + context_task = ContextBuilderTask( + database_id=pk, + task_id=task.id, + params=json.dumps({}), + started_time=datetime.datetime.now(datetime.timezone.utc), + status="PENDING", + ) + ContextBuilderTaskDAO.create(context_task) + db.session.commit() + logger.info("Task %s created for database %s", task.id, pk) + + return task + + +@celery_app.task(bind=True, name="generate_llm_context") +def generate_llm_context(self: Any, db_id: int) -> dict[str, Any] | None: + logger.info("Generating LLM context for database %s", db_id) + start_time = time.perf_counter() + task_status = "SUCCESS" + + try: + admin_user = security_manager.find_user(username="admin") + if not admin_user: + return {"status_code": 500, "message": "Unable to find admin user"} + + with override_user(admin_user): + database = DatabaseDAO.find_by_id(db_id) + + if not database: + return {"status_code": 404, "message": "Database not found"} + + settings = database.llm_context_options + selected_schemas = ( + json.loads(settings.schemas) if settings.schemas else None + ) + include_indexes = ( + settings.include_indexes if settings.include_indexes else True + ) + top_k = settings.top_k if settings.top_k else 10 + top_k_limit = settings.top_k_limit if settings.top_k_limit else 50000 + + schemas = get_database_metadata( + database, None, include_indexes, selected_schemas, top_k, top_k_limit + ) + logger.info("Done generating LLM context for database %s", db_id) + + schema_json = reduce_json_token_count( + json.dumps([schema.model_dump() for schema in schemas]) + ) + except Exception: + task_status = "ERROR" + raise + finally: + db_task = 
ContextBuilderTaskDAO.find_by_task_id(self.request.id) + if db_task: + db_task.ended_time = datetime.datetime.now(datetime.timezone.utc) + db_task.status = task_status + end_time = time.perf_counter() + db_task.duration = int(end_time * 1000 - start_time * 1000) + db.session.commit() + else: + logger.error("Task %s not found in database", self.request.id) + + return {"status_code": 200, "result": schema_json} diff --git a/superset/views/sql_lab/views.py b/superset/views/sql_lab/views.py index 4ac9b51fcc46..28f25c5442fb 100644 --- a/superset/views/sql_lab/views.py +++ b/superset/views/sql_lab/views.py @@ -154,6 +154,8 @@ def put(self, tab_state_id: int) -> FlaskResponse: try: fields = {k: json.loads(v) for k, v in request.form.to_dict().items()} + if "schema" in fields: + fields["schema"] = json.dumps(fields["schema"]) db.session.query(TabState).filter_by(id=tab_state_id).update(fields) db.session.commit() return json_success(json.dumps(tab_state_id)) diff --git a/superset/views/utils.py b/superset/views/utils.py index a431c307fb18..ce57cc400030 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -64,7 +64,8 @@ def sanitize_datasource_data(datasource_data: dict[str, Any]) -> dict[str, Any]: datasource_database = datasource_data.get("database") if datasource_database: datasource_database["parameters"] = {} - + datasource_database["llm_connection"] = {} + datasource_database["llm_context_options"] = {} return datasource_data diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 225e3f582dc5..7af79a4f2350 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -199,6 +199,8 @@ def test_get_items(self): "extra", "force_ctas_schema", "id", + "llm_connection", + "llm_context_options", "uuid", ] diff --git a/tests/unit_tests/commands/databases/tables_test.py b/tests/unit_tests/commands/databases/tables_test.py index 80590c7ada73..b092ddc09c80 100644 --- a/tests/unit_tests/commands/databases/tables_test.py +++ b/tests/unit_tests/commands/databases/tables_test.py @@ -105,9 +105,14 @@ def test_tables_with_catalog( assert payload == { "count": 3, "result": [ - {"value": "table1", "type": "table", "extra": {"foo": "bar"}}, - {"value": "table2", "type": "table", "extra": None}, - {"value": "view1", "type": "view"}, + { + "value": "table1", + "type": "table", + "extra": {"foo": "bar"}, + "schema": "schema1", + }, + {"value": "table2", "type": "table", "extra": None, "schema": "schema1"}, + {"value": "view1", "type": "view", "schema": "schema1"}, ], } @@ -178,9 +183,14 @@ def test_tables_without_catalog( assert payload == { "count": 3, "result": [ - {"value": "table1", "type": "table", "extra": {"foo": "bar"}}, - {"value": "table2", "type": "table", "extra": None}, - {"value": "view1", "type": "view"}, + { + "value": "table1", + "type": "table", + "extra": {"foo": "bar"}, + "schema": "schema1", + }, + {"value": "table2", "type": "table", "extra": None, "schema": "schema1"}, + {"value": "view1", "type": "view", "schema": "schema1"}, ], }
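Taken together, the new SQL Lab endpoints can be exercised end to end: POST /generate_db_context/ kicks off the Celery context build, GET /db_context_status/ reports its state, and POST /generate_sql/ returns the generated query. A rough client-side sketch, assuming the usual /api/v1/sqllab prefix and an already-authenticated requests session (authentication and CSRF handling are omitted):

    import requests

    BASE = "http://localhost:8088/api/v1/sqllab"  # prefix assumed from the SQL Lab API
    session = requests.Session()
    # session.headers.update({"Authorization": f"Bearer {access_token}"})  # auth omitted

    # 1. Build (or refresh) the LLM context for database 1
    session.post(f"{BASE}/generate_db_context/", json={"database_id": 1})

    # 2. Poll the assistant status; `q` is a Rison-encoded object per the new schema
    status = session.get(f"{BASE}/db_context_status/", params={"q": "(pk:1)"}).json()

    # 3. Ask for SQL once the context is ready
    payload = {
        "database_id": 1,
        "user_prompt": "Top 10 customers by total order value",
        "schemas": ["public"],
    }
    sql = session.post(f"{BASE}/generate_sql/", json=payload).json()["sql"]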