def _has_properly_escaped_backticks(self, identifier: str) -> bool:
    """
    Check that every backtick in the content of a quoted identifier is escaped.

    Inside a backtick-quoted identifier a literal backtick is written as two
    consecutive backticks, so the content is well formed exactly when every
    run of backticks has even length.

    Parameters:
    - identifier (str): content of a quoted identifier, WITHOUT the
      surrounding backticks.

    Returns:
    - bool: True if no lone (unescaped) backtick is present, False otherwise.
    """
    # Removing doubled pairs left-to-right leaves exactly one backtick behind
    # for every odd-length run, i.e. for every unescaped backtick.
    return '`' not in identifier.replace('``', '')


def _quote_identifier(self, identifier: str) -> str:
    """
    Quote an identifier so it can be interpolated safely into SQL text.

    A well-formed quoted identifier (non-empty content between backticks,
    every inner backtick doubled) is returned unchanged. Anything else is
    treated as a raw name: inner backticks are doubled and the result is
    wrapped in backticks.

    Merely starting and ending with a backtick is NOT enough to be passed
    through: a raw name such as "`a`; DROP TABLE b; SELECT `c`" would
    otherwise reach the SQL text unescaped.

    Parameters:
    - identifier (str): the identifier to quote (raw or already quoted).

    Returns:
    - str: backtick-quoted identifier.

    Raises:
    - ValueError: if identifier is None.
    """
    if identifier is None:
        raise ValueError("Identifier cannot be None")

    already_quoted = (
        len(identifier) > 2
        and identifier.startswith('`')
        and identifier.endswith('`')
        and self._has_properly_escaped_backticks(identifier[1:-1])
    )
    if already_quoted:
        return identifier

    # Raw name: escape embedded backticks, then wrap.
    return '`' + identifier.replace('`', '``') + '`'


def _is_valid_identifier(self, identifier: str) -> bool:
    """
    Validate a MariaDB identifier that will be quoted when used in SQL.

    Accepts both quoted and unquoted input. The checks run on the decoded
    (unquoted, unescaped) name:
    - it must be a non-empty string of at most 64 characters,
    - it must not end with a space,
    - it must not contain a NUL character,
    - it must not itself look like a well-formed quoted identifier, because
      _quote_identifier would pass such a name through unchanged and the
      statement would silently address a different object.

    Parameters:
    - identifier (str): identifier (quoted or unquoted).

    Returns:
    - bool: True if the identifier is acceptable, False otherwise.
    """
    if not isinstance(identifier, str) or not identifier:
        return False

    if identifier.startswith('`') and identifier.endswith('`'):
        # Quoted form: needs real content and correctly doubled backticks.
        if len(identifier) <= 2:
            return False
        content = identifier[1:-1]
        if not self._has_properly_escaped_backticks(content):
            return False
        name = content.replace('``', '`')
    else:
        # Unquoted form: the text is the name itself.
        name = identifier

    if len(name) > 64:
        return False
    if name.endswith(' '):
        return False
    if '\x00' in name:
        return False

    # Quoting always adds backticks to a raw name, so equality means the
    # decoded name would be mistaken for an already-quoted identifier.
    if self._quote_identifier(name) == name:
        return False

    return True


def _normalize_identifier(self, identifier: str, method_name: str) -> str:
    """
    Validate an identifier for an MCP method and return its unquoted form.

    Parameters:
    - identifier (str): identifier (quoted or unquoted).
    - method_name (str): name of the calling method, used in the error message.

    Returns:
    - str: the unquoted, unescaped identifier, suitable as a bound query
      parameter or as input to _quote_identifier.

    Raises:
    - ValueError: if the identifier is invalid.
    """
    if not self._is_valid_identifier(identifier):
        error_msg = f"Invalid identifier '{identifier}' in {method_name}"
        logger.error(error_msg)
        raise ValueError(error_msg)

    # Strip the surrounding quotes and undo backtick doubling if present.
    if identifier.startswith('`') and identifier.endswith('`'):
        return identifier[1:-1].replace('``', '`')
    return identifier
Expects normalized (unquoted) database name.""" sql = "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = %s" try: results = await self._execute_query(sql, params=(database_name,), database='information_schema') @@ -176,12 +282,7 @@ async def _database_exists(self, database_name: str) -> bool: return False async def _table_exists(self, database_name: str, table_name: str) -> bool: - """Checks if a table exists in the given database.""" - if not database_name or not database_name.isidentifier() or \ - not table_name or not table_name.isidentifier(): - logger.warning(f"_table_exists called with invalid names: db='{database_name}', table='{table_name}'") - return False - + """Checks if a table exists in the given database. Expects normalized (unquoted) names.""" sql = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s" try: results = await self._execute_query(sql, params=(database_name, table_name), database='information_schema') @@ -205,10 +306,7 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool: """ logger.debug(f"Checking if '{database_name}.{table_name}' is a vector store.") - if not database_name or not database_name.isidentifier() or \ - not table_name or not table_name.isidentifier(): - logger.warning(f"_is_vector_store called with invalid names: db='{database_name}', table='{table_name}'") - return False + # Expects normalized (unquoted) names # SQL query to verify vector store criteria sql_query = """ @@ -254,9 +352,8 @@ async def list_databases(self) -> List[str]: async def list_tables(self, database_name: str) -> List[str]: """Lists all tables within the specified database.""" logger.info(f"TOOL START: list_tables called. 
database_name={database_name}") - if not database_name or not database_name.isidentifier(): - logger.warning(f"TOOL WARNING: list_tables called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") + database_name = self._normalize_identifier(database_name, "list_tables") + sql = "SHOW TABLES" try: results = await self._execute_query(sql, database=database_name) @@ -273,18 +370,15 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st for a specific table in a database. """ logger.info(f"TOOL START: get_table_schema called. database_name={database_name}, table_name={table_name}") - if not database_name or not database_name.isidentifier(): - logger.warning(f"TOOL WARNING: get_table_schema called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") - if not table_name or not table_name.isidentifier(): - logger.warning(f"TOOL WARNING: get_table_schema called with invalid table_name: {table_name}") - raise ValueError(f"Invalid table name provided: {table_name}") - - sql = f"DESCRIBE `{database_name}`.`{table_name}`" + database_name = self._normalize_identifier(database_name, "get_table_schema") + table_name = self._normalize_identifier(table_name, "get_table_schema") + + sql = f"DESCRIBE {self._quote_identifier(database_name)}.{self._quote_identifier(table_name)}" try: schema_results = await self._execute_query(sql) schema_info = {} if not schema_results: + # Use normalized names for information_schema query exists_sql = "SELECT COUNT(*) as count FROM information_schema.tables WHERE table_schema = %s AND table_name = %s" exists_result = await self._execute_query(exists_sql, params=(database_name, table_name)) if not exists_result or exists_result[0]['count'] == 0: @@ -318,12 +412,8 @@ async def get_table_schema_with_relations(self, database_name: str, table_name: Includes all basic schema info plus foreign key 
relationships and referenced tables. """ logger.info(f"TOOL START: get_table_schema_with_relations called. database_name={database_name}, table_name={table_name}") - if not database_name or not database_name.isidentifier(): - logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") - if not table_name or not table_name.isidentifier(): - logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid table_name: {table_name}") - raise ValueError(f"Invalid table name provided: {table_name}") + database_name = self._normalize_identifier(database_name, "get_table_schema_with_relations") + table_name = self._normalize_identifier(table_name, "get_table_schema_with_relations") try: # 1. Get basic schema information @@ -348,6 +438,7 @@ async def get_table_schema_with_relations(self, database_name: str, table_name: ORDER BY kcu.CONSTRAINT_NAME, kcu.ORDINAL_POSITION """ + # Use normalized names for information_schema query fk_results = await self._execute_query(fk_sql, params=(database_name, table_name)) # 3. Add foreign key information to the basic schema @@ -389,9 +480,7 @@ async def execute_sql(self, sql_query: str, database_name: str, parameters: Opti Example `parameters`: ["value1", 123] corresponding to %s placeholders in `sql_query`. """ logger.info(f"TOOL START: execute_sql called. 
database_name={database_name}, sql_query={sql_query[:100]}, parameters={parameters}") - if database_name and not database_name.isidentifier(): - logger.warning(f"TOOL WARNING: execute_sql called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") + database_name = self._normalize_identifier(database_name, "execute_sql") param_tuple = tuple(parameters) if parameters is not None else None try: results = await self._execute_query(sql_query, params=param_tuple, database=database_name) @@ -406,9 +495,7 @@ async def create_database(self, database_name: str) -> Dict[str, Any]: Creates a new database if it doesn't exist. """ logger.info(f"TOOL START: create_database called for database: '{database_name}'") - if not database_name or not database_name.isidentifier(): - logger.error(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "create_database") # Check existence first to provide a clear message, though CREATE DATABASE IF NOT EXISTS is idempotent if await self._database_exists(database_name): @@ -416,7 +503,7 @@ async def create_database(self, database_name: str) -> Dict[str, Any]: logger.info(f"TOOL END: create_database. {message}") return {"status": "exists", "message": message, "database_name": database_name} - sql = f"CREATE DATABASE IF NOT EXISTS `{database_name}`;" + sql = f"CREATE DATABASE IF NOT EXISTS {self._quote_identifier(database_name)};" try: await self._execute_query(sql, database=None) @@ -455,12 +542,8 @@ async def create_vector_store_tool(self, logger.info(f"TOOL START: create_vector_store called. 
DB: '{database_name}', Store: '{vector_store_name}', Model: '{model_name}', Embedding_Length: {embedding_length}, Distance_Requested: '{distance_function}'") # --- Input Validation --- - if not database_name or not database_name.isidentifier(): - logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not vector_store_name or not vector_store_name.isidentifier(): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "create_vector_store_tool") + vector_store_name = self._normalize_identifier(vector_store_name, "create_vector_store_tool") if not isinstance(embedding_length, int) or embedding_length <= 0: logger.error(f"Invalid embedding_length: {embedding_length}. Must be a positive integer.") @@ -504,7 +587,7 @@ async def create_vector_store_tool(self, # --- SQL Query for Vector Store Table Creation --- schema_query = f""" - CREATE TABLE IF NOT EXISTS `{vector_store_name}` ( + CREATE TABLE IF NOT EXISTS {self._quote_identifier(vector_store_name)} ( id VARCHAR(36) NOT NULL DEFAULT UUID_v7() PRIMARY KEY, document TEXT NOT NULL, embedding VECTOR({embedding_length}) NOT NULL, @@ -550,9 +633,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]: logger.info(f"TOOL START: list_vector_stores called for database: '{database_name}'") # --- Input Validation --- - if not database_name or not database_name.isidentifier(): - logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name: '{database_name}'. 
Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "list_vector_stores") if not await self._database_exists(database_name): logger.warning(f"Database '{database_name}' does not exist. Cannot list vector stores.") @@ -577,6 +658,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]: """ try: + # Use normalized name for information_schema query results = await self._execute_query(sql_query, params=(database_name,), database='information_schema') store_list = [row['TABLE_NAME'] for row in results if 'TABLE_NAME' in row] @@ -614,12 +696,8 @@ async def delete_vector_store(self, logger.info(f"TOOL START: delete_vector_store called for: '{database_name}.{vector_store_name}'") # --- Input Validation for names --- - if not database_name or not database_name.isidentifier(): - logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not vector_store_name or not vector_store_name.isidentifier(): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. 
Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "delete_vector_store") + vector_store_name = self._normalize_identifier(vector_store_name, "delete_vector_store") # --- Database Existence Check --- if not await self._database_exists(database_name): @@ -640,7 +718,7 @@ async def delete_vector_store(self, return {"status": "not_vector_store", "message": message} # --- SQL Query for Deletion --- - drop_query = f"DROP TABLE IF EXISTS `{vector_store_name}`;" + drop_query = f"DROP TABLE IF EXISTS {self._quote_identifier(vector_store_name)};" try: await self._execute_query(drop_query, database=database_name) @@ -670,12 +748,8 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name: If metadata is not provided, an empty dict will be used for each document. """ import json - if not database_name or not database_name.isidentifier(): - logger.error(f"Invalid database_name: '{database_name}'") - raise ValueError(f"Invalid database_name: '{database_name}'") - if not vector_store_name or not vector_store_name.isidentifier(): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") + database_name = self._normalize_identifier(database_name, "insert_docs_vector_store") + vector_store_name = self._normalize_identifier(vector_store_name, "insert_docs_vector_store") if not isinstance(documents, list) or not documents or not all(isinstance(doc, str) and doc for doc in documents): logger.error("'documents' must be a non-empty list of non-empty strings.") raise ValueError("'documents' must be a non-empty list of non-empty strings.") @@ -690,7 +764,7 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name: # Prepare metadata JSON metadata_json = [json.dumps(m) for m in metadata] # Prepare values for batch insert - insert_query = f"INSERT INTO `{database_name}`.`{vector_store_name}` (document, embedding, metadata) 
VALUES (%s, VEC_FromText(%s), %s)" + insert_query = f"INSERT INTO {self._quote_identifier(database_name)}.{self._quote_identifier(vector_store_name)} (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)" inserted = 0 errors = [] for doc, emb, meta in zip(documents, embeddings, metadata_json): @@ -723,12 +797,8 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ if not user_query or not isinstance(user_query, str): logger.error("user_query must be a non-empty string.") raise ValueError("user_query must be a non-empty string.") - if not database_name or not database_name.isidentifier(): - logger.error(f"Invalid database_name: '{database_name}'") - raise ValueError(f"Invalid database_name: '{database_name}'") - if not vector_store_name or not vector_store_name.isidentifier(): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") + database_name = self._normalize_identifier(database_name, "search_vector_store") + vector_store_name = self._normalize_identifier(vector_store_name, "search_vector_store") if not isinstance(k, int) or k <= 0: logger.error("k must be a positive integer.") raise ValueError("k must be a positive integer.") @@ -741,7 +811,7 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ document, metadata, VEC_DISTANCE_COSINE(embedding, VEC_FromText(%s)) AS distance - FROM `{database_name}`.`{vector_store_name}` + FROM {self._quote_identifier(database_name)}.{self._quote_identifier(vector_store_name)} ORDER BY distance ASC LIMIT %s """