Skip to content

Commit 6b60681

Browse files
committed
Add the NL2SQL tool to the MySQL MCP server. This tool internally uses semantic models to convert the user's natural language query into a SQL query and then executes it
1 parent 70932b0 commit 6b60681

File tree

2 files changed

+145
-22
lines changed

2 files changed

+145
-22
lines changed

src/mysql-mcp-server/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ A Python-based MCP (Model Context Protocol) server that provides a suite of tool
2121
- `ragify_column`: Create/populate vector columns for embeddings
2222
- `ask_ml_rag`: Retrieval-augmented generation from vector stores
2323
- `heatwave_ask_help`: Answers questions about how to use HeatWave ML
24+
- `ask_nl_sql`: Convert natural language questions into SQL queries and execute them automatically
2425

2526
- **Vector Store Management**
2627
- List files in `secure_file_priv` (local mode)
@@ -213,6 +214,7 @@ python mysql_mcp_server.py
213214
11. `list_all_compartments()`: List OCI compartments
214215
12. `object_storage_list_buckets(compartment_name | compartment_id)`: List buckets in a compartment
215216
13. `object_storage_list_objects(namespace, bucket_name)`: List objects in a bucket
217+
14. `ask_nl_sql(connection_id, question)`: Convert natural language questions into SQL queries and execute them automatically
216218

217219
## Security
218220

@@ -236,6 +238,7 @@ Here are example prompts you can use to interact with the MCP server, note that
236238
```
237239
"Generate a summary of error logs"
238240
"Ask ml_rag: Show me refund policy from the vector store"
241+
"What is the average delay incurred by flights?"
239242
```
240243

241244
### 3. Object Storage

src/mysql-mcp-server/mysql_mcp_server.py

Lines changed: 142 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@
1212
from fastmcp import FastMCP
1313
from mysql import connector
1414
from mysql.connector.abstracts import MySQLConnectionAbstract
15-
16-
from utils import DatabaseConnectionError, get_ssh_command, load_mysql_config, Mode, OciInfo
15+
from utils import (
16+
DatabaseConnectionError,
17+
Mode,
18+
OciInfo,
19+
get_ssh_command,
20+
load_mysql_config,
21+
)
1722

1823
MIN_CONTEXT_SIZE = 10
1924
DEFAULT_CONTEXT_SIZE = 20
@@ -29,20 +34,26 @@
2934
try:
3035
config = load_mysql_config()
3136
except Exception as e:
32-
config_error_msg = json.dumps({
33-
"error" : f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}."
34-
})
37+
config_error_msg = json.dumps(
38+
{
39+
"error": f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}."
40+
}
41+
)
3542

3643
# Setup oci connection if applicable
3744
oci_info: Optional[OciInfo] = None # None if not available, otherwise OCI config info
38-
oci_error_msg: Optional[str] = None # None if OCI available, otherwise a json formatted string
45+
oci_error_msg: Optional[str] = (
46+
None # None if OCI available, otherwise a json formatted string
47+
)
3948
try:
4049
oci_info = OciInfo()
4150
except Exception as e:
42-
oci_error_msg = json.dumps({
43-
"error" : "object store unavailable. If object store is required, the MCP server must be restarted with a valid"
44-
f" OCI config. OCI connection attempt yielded error {str(e)}."
45-
})
51+
oci_error_msg = json.dumps(
52+
{
53+
"error": "object store unavailable. If object store is required, the MCP server must be restarted with a valid"
54+
f" OCI config. OCI connection attempt yielded error {str(e)}."
55+
}
56+
)
4657

4758
# Create mcp server
4859
mcp = FastMCP("MySQL")
@@ -51,6 +62,7 @@
5162
# Finish setup
5263
###############################################################
5364

65+
5466
def _validate_name(name: str) -> str:
5567
"""
5668
Validate that the string is a legal SQL identifier (letters, digits, underscores).
@@ -81,9 +93,7 @@ def _get_mode(connection_id: str) -> Mode:
8193
Returns:
8294
Mode: The resolved provider mode.
8395
"""
84-
provider_result = _execute_sql_tool(
85-
connection_id, "SELECT @@rapid_cloud_provider;"
86-
)
96+
provider_result = _execute_sql_tool(connection_id, "SELECT @@rapid_cloud_provider;")
8797
if check_error(provider_result):
8898
raise Exception(
8999
f"Exception occurred while fetching cloud provider {str(provider_result)}"
@@ -230,7 +240,7 @@ def list_all_connections() -> str:
230240
{
231241
"key": connection_id,
232242
"error": str(e),
233-
"hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}"
243+
"hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}",
234244
}
235245
)
236246
return json.dumps({"valid keys": valid_keys, "invalid keys": invalid_keys})
@@ -258,6 +268,19 @@ def execute_sql_tool_by_connection_id(
258268
return _execute_sql_tool(connection_id, sql_script, params=params)
259269

260270

271+
from datetime import date, datetime
272+
from decimal import Decimal
273+
274+
275+
class CustomJSONEncoder(json.JSONEncoder):
276+
def default(self, o):
277+
if isinstance(o, Decimal):
278+
return str(o)
279+
if isinstance(o, (date, datetime)):
280+
return o.isoformat()
281+
return super().default(o)
282+
283+
261284
def _execute_sql_tool(
262285
connection: Union[str, MySQLConnectionAbstract],
263286
sql_script: str,
@@ -309,7 +332,7 @@ def _execute_sql_tool(
309332

310333
db_connection.commit()
311334

312-
return json.dumps(results)
335+
return json.dumps(results, cls=CustomJSONEncoder)
313336

314337
except Exception as e:
315338
return json.dumps(
@@ -565,7 +588,9 @@ def load_vector_store_oci(
565588

566589

567590
@mcp.tool()
568-
def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE) -> str:
591+
def ask_ml_rag_vector_store(
592+
connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE
593+
) -> str:
569594
"""
570595
[MCP Tool] Retrieve segments from the default vector store (skip_generate=true).
571596
@@ -586,16 +611,26 @@ def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int
586611
arguments: {"connection_id": "example_local_server", "question": "Find information about refunds."}
587612
"""
588613
if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size:
589-
return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"})
614+
return json.dumps(
615+
{
616+
"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"
617+
}
618+
)
590619

591620
return _ask_ml_rag_helper(
592-
connection_id, question, f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})"
621+
connection_id,
622+
question,
623+
f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})",
593624
)
594625

595626

596627
@mcp.tool()
597628
def ask_ml_rag_innodb(
598-
connection_id: str, question: str, segment_col: str, embedding_col: str, context_size: int = DEFAULT_CONTEXT_SIZE
629+
connection_id: str,
630+
question: str,
631+
segment_col: str,
632+
embedding_col: str,
633+
context_size: int = DEFAULT_CONTEXT_SIZE,
599634
) -> str:
600635
"""
601636
[MCP Tool] Retrieve segments from InnoDB tables using specified segment and embedding columns.
@@ -626,7 +661,11 @@ def ask_ml_rag_innodb(
626661
arguments: {"connection_id": "example_local_server", "question": "Search product docs", "segment_col": "body", "embedding_col": "embedding"}
627662
"""
628663
if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size:
629-
return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"})
664+
return json.dumps(
665+
{
666+
"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"
667+
}
668+
)
630669

631670
try:
632671
# prevent possible injection
@@ -732,6 +771,84 @@ def heatwave_ask_help(connection_id: str, question: str) -> str:
732771
return json.dumps({"error": f"Error with NL2ML: {str(e)}"})
733772

734773

774+
@mcp.tool()
775+
def ask_nl_sql(connection_id: str, question: str) -> str:
776+
"""
777+
[MCP Tool] Convert natural language questions into SQL queries and execute them automatically.
778+
779+
This tool is ideal for database exploration using plain English questions like:
780+
- "What tables are available?"
781+
- "Show me the average price by category"
782+
- "How many users registered last month?"
783+
- "What are the column names in the customers table?"
784+
785+
Args:
786+
connection_id (str): MySQL connection key.
787+
question (str): Natural language query.
788+
789+
Returns:
790+
JSON object containing:
791+
792+
sql_response(str): The response from executing the generated SQL query.
793+
sql_query(str): The generated SQL query
794+
schemas(json): The schemas where metadata was retrieved
795+
tables(json): The tables where metadata was retrieved
796+
is_sql_valid(bool): Whether the generated SQL statement is valid
797+
model_id(str): The LLM used for generation
798+
799+
800+
MCP usage example:
801+
- name: ask_nl_sql
802+
arguments: {"connection_id": "example_local_server", "question": "How many singers are there?"}
803+
804+
Here is the what part of the return JSON looks like;
805+
{
806+
"tables": [
807+
"singer.singer",
808+
"singer.song",
809+
"concert_singer.singer",
810+
"concert_singer.stadium",
811+
"music_2.Songs",
812+
"music_2.Instruments",
813+
"music_2.Band",
814+
"music_2.Vocals",
815+
"music_2.Tracklists"
816+
],
817+
"schemas": [
818+
"concert_singer",
819+
"music_2",
820+
"singer"
821+
],
822+
"sql_query": "SELECT COUNT(`Singer_ID`) FROM `concert_singer`.`singer`;",
823+
"is_sql_valid": 1
824+
}
825+
"""
826+
with _get_database_connection_cm(connection_id) as db_connection:
827+
# Execute the heatwave chat query
828+
set_response = _execute_sql_tool(db_connection, "SET @response = NULL;")
829+
if check_error(set_response):
830+
return json.dumps({"error": f"Error with NL_SQL: {set_response}"})
831+
832+
nl2sql_response = _execute_sql_tool(
833+
db_connection,
834+
f"CALL sys.NL_SQL(%s, @response, NULL)",
835+
params=[question],
836+
)
837+
if check_error(nl2sql_response):
838+
return json.dumps({"error": f"Error with NL_SQL: {nl2sql_response}"})
839+
840+
fetch_response = _execute_sql_tool(db_connection, "SELECT @response;")
841+
if check_error(fetch_response):
842+
return json.dumps({"error": f"Error with ML_RAG: {fetch_response}"})
843+
844+
try:
845+
response = json.loads(fetch_one(fetch_response))
846+
response["sql_response"] = nl2sql_response
847+
return json.dumps(response)
848+
except:
849+
return json.dumps({"error": "Unexpected response format from NL_SQL"})
850+
851+
735852
"""
736853
Object store
737854
"""
@@ -745,7 +862,7 @@ def verify_compartment_access(compartments):
745862
"compartment_id": compartment.id,
746863
"object_storage": False,
747864
"databases": False,
748-
"errors": []
865+
"errors": [],
749866
}
750867

751868
# Test Object Storage
@@ -756,10 +873,13 @@ def verify_compartment_access(compartments):
756873
)
757874
access_report[compartment.name]["object_storage"] = True
758875
except Exception as e:
759-
access_report[compartment.name]["errors"].append(f"Object Storage: {str(e)}")
876+
access_report[compartment.name]["errors"].append(
877+
f"Object Storage: {str(e)}"
878+
)
760879

761880
return access_report
762881

882+
763883
@mcp.tool()
764884
def list_all_compartments() -> str:
765885
"""

0 commit comments

Comments
 (0)