diff --git a/src/mysql_mcp_server/server.py b/src/mysql_mcp_server/server.py index dc8d87e..651c50d 100644 --- a/src/mysql_mcp_server/server.py +++ b/src/mysql_mcp_server/server.py @@ -110,40 +110,64 @@ async def list_tools() -> list[Tool]: @app.call_tool() async def call_tool(name: str, arguments: dict) -> list[TextContent]: """Execute SQL commands.""" - config = get_db_config() logger.info(f"Calling tool: {name} with arguments: {arguments}") + # Verify tool name first, before checking DB config if name != "execute_sql": raise ValueError(f"Unknown tool: {name}") + # Then check if query is provided query = arguments.get("query") if not query: raise ValueError("Query is required") + # Now get DB config - this allows the above validation tests to pass + # without requiring actual DB credentials + config = get_db_config() + try: with connect(**config) as conn: with conn.cursor() as cursor: + # Always ensure we consume all results to avoid "Unread result found" errors cursor.execute(query) + # Check the query type by normalizing and checking the first word + query_upper = query.strip().upper() + # Special handling for SHOW TABLES - if query.strip().upper().startswith("SHOW TABLES"): + if query_upper.startswith("SHOW TABLES"): tables = cursor.fetchall() result = ["Tables_in_" + config["database"]] # Header result.extend([table[0] for table in tables]) return [TextContent(type="text", text="\n".join(result))] - # Handle all other queries that return result sets (SELECT, SHOW, DESCRIBE etc.) + # Special handling for DESCRIBE and SHOW COLUMNS + elif query_upper.startswith("DESCRIBE ") or query_upper.startswith("DESC ") or query_upper.startswith("SHOW COLUMNS FROM ") or query_upper.startswith("SHOW FIELDS FROM "): + columns = [desc[0] for desc in cursor.description] + rows = cursor.fetchall() + + # Format the results in a more readable way + results = [] + results.append(",".join(columns)) + for row in rows: + # Convert None values to "NULL" for better readability + formatted_row = [str(val) if val is not None else "NULL" for val in row] + results.append(",".join(formatted_row)) + + return [TextContent(type="text", text="\n".join(results))] + + # Handle all other queries that return result sets (SELECT, SHOW, etc.) elif cursor.description is not None: columns = [desc[0] for desc in cursor.description] - try: - rows = cursor.fetchall() - result = [",".join(map(str, row)) for row in rows] - return [TextContent(type="text", text="\n".join([",".join(columns)] + result))] - except Error as e: - logger.warning(f"Error fetching results: {str(e)}") - return [TextContent(type="text", text=f"Query executed but error fetching results: {str(e)}")] + rows = cursor.fetchall() + + if not rows: + return [TextContent(type="text", text="Query executed successfully. No results returned.")] + + result = [",".join(map(str, row)) for row in rows] + return [TextContent(type="text", text="\n".join([",".join(columns)] + result))] - # Non-SELECT queries + # Non-SELECT queries (INSERT, UPDATE, DELETE, etc.) else: conn.commit() return [TextContent(type="text", text=f"Query executed successfully. Rows affected: {cursor.rowcount}")]