Spaces:
Sleeping
Sleeping
Commit
Β·
2080a31
1
Parent(s):
34a5262
fixing check query
Browse files- app/agents/sql_agent.py +23 -16
- app/schemas/agent_state.py +1 -1
app/agents/sql_agent.py
CHANGED
|
@@ -102,6 +102,10 @@ class SQLAgent:
|
|
| 102 |
def creating_sql_agent_chain():
|
| 103 |
"""Creating a sql agent chain"""
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
print("Creating a sql agent chain")
|
| 106 |
sql_agent_prompt = ChatPromptTemplate.from_messages([
|
| 107 |
("system", """You are a supervisor SQL agent managing tools to get the answer to the user's query.
|
|
@@ -110,23 +114,21 @@ class SQLAgent:
|
|
| 110 |
1. list_table_tools - List all tables from the database
|
| 111 |
2. get_schema - Get the schema of required tables
|
| 112 |
3. generate_query - Generate a SQL query
|
| 113 |
-
4. check_query - Check if the query is correct
|
| 114 |
-
5. execute_query - Execute the query
|
| 115 |
-
6. response - Create response for the user
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
Current state:
|
| 118 |
- Tables listed: {tables_list}
|
| 119 |
- Schema retrieved: {schema_of_table}
|
| 120 |
- Query generated: {query_gen}
|
| 121 |
-
- Query checked: {check_query}
|
| 122 |
- Query executed: {execute_query}
|
| 123 |
- Response created: {response_to_user}
|
| 124 |
|
| 125 |
If no tables are listed, respond with 'list_table_tools'.
|
| 126 |
If tables are listed but no schema, respond with 'get_schema'.
|
| 127 |
If schema exists but no query generated, respond with 'generate_query'.
|
| 128 |
-
If query generated but not
|
| 129 |
-
If query checked but not executed, respond with 'execute_query'.
|
| 130 |
If query executed but no response, respond with 'response'.
|
| 131 |
If everything is complete, respond with 'DONE'.
|
| 132 |
|
|
@@ -149,11 +151,11 @@ class SQLAgent:
|
|
| 149 |
tables_list = bool(state.get("tables_list", "").strip())
|
| 150 |
schema_of_table = bool(state.get("schema_of_table", "").strip())
|
| 151 |
query_gen = bool(state.get("query_gen", "").strip())
|
| 152 |
-
check_query = bool(state.get("check_query", "").strip())
|
| 153 |
execute_query = bool(state.get("execute_query", "").strip())
|
| 154 |
response_to_user = bool(state.get("response_to_user", "").strip())
|
| 155 |
|
| 156 |
-
print(f"State check - Tables: {tables_list}, Schema: {schema_of_table}, Query: {query_gen}, Check: {check_query}, Execute: {execute_query}, Response: {response_to_user}")
|
| 157 |
|
| 158 |
chain = creating_sql_agent_chain()
|
| 159 |
decision = chain.invoke({
|
|
@@ -161,7 +163,7 @@ class SQLAgent:
|
|
| 161 |
"tables_list": tables_list,
|
| 162 |
"schema_of_table": schema_of_table,
|
| 163 |
"query_gen": query_gen,
|
| 164 |
-
"check_query": check_query,
|
| 165 |
"execute_query": execute_query,
|
| 166 |
"response_to_user": response_to_user
|
| 167 |
})
|
|
@@ -180,9 +182,9 @@ class SQLAgent:
|
|
| 180 |
elif "generate_query" in decision_text:
|
| 181 |
next_tool = "generate_query"
|
| 182 |
agent_msg = "π SQL Agent: Generating SQL query."
|
| 183 |
-
elif "check_query" in decision_text:
|
| 184 |
-
|
| 185 |
-
|
| 186 |
elif "execute_query" in decision_text:
|
| 187 |
next_tool = "execute_query"
|
| 188 |
agent_msg = "π SQL Agent: Executing query."
|
|
@@ -208,9 +210,13 @@ class SQLAgent:
|
|
| 208 |
if next_tool == "end" or state.get("task_complete", False):
|
| 209 |
return END
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
valid_tools = [
|
| 212 |
"sql_agent", "list_table_tools", "get_schema", "generate_query",
|
| 213 |
-
"
|
| 214 |
]
|
| 215 |
|
| 216 |
return next_tool if next_tool in valid_tools else "sql_agent"
|
|
@@ -223,7 +229,7 @@ class SQLAgent:
|
|
| 223 |
workflow.add_node("list_table_tools", self.db_tools.list_table_tools)
|
| 224 |
workflow.add_node("get_schema", self.db_tools.get_schema)
|
| 225 |
workflow.add_node("generate_query", self.db_tools.generate_query)
|
| 226 |
-
workflow.add_node("check_query", self.db_tools.check_query)
|
| 227 |
workflow.add_node("execute_query", self.db_tools.execute_query)
|
| 228 |
workflow.add_node("response", self.db_tools.create_response)
|
| 229 |
|
|
@@ -231,7 +237,8 @@ class SQLAgent:
|
|
| 231 |
workflow.set_entry_point("sql_agent")
|
| 232 |
|
| 233 |
# Add routing
|
| 234 |
-
for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "check_query", "execute_query", "response"]:
|
|
|
|
| 235 |
workflow.add_conditional_edges(
|
| 236 |
node,
|
| 237 |
router,
|
|
@@ -240,7 +247,7 @@ class SQLAgent:
|
|
| 240 |
"list_table_tools": "list_table_tools",
|
| 241 |
"get_schema": "get_schema",
|
| 242 |
"generate_query": "generate_query",
|
| 243 |
-
"check_query": "check_query",
|
| 244 |
"execute_query": "execute_query",
|
| 245 |
"response": "response",
|
| 246 |
END: END
|
|
|
|
| 102 |
def creating_sql_agent_chain():
|
| 103 |
"""Creating a sql agent chain"""
|
| 104 |
|
| 105 |
+
# 4. check_query - Check if the query is correct
|
| 106 |
+
# - Query checked: {check_query}
|
| 107 |
+
# If query generated but not checked, respond with 'check_query'.
|
| 108 |
+
# If query checked but not executed, respond with 'execute_query'.
|
| 109 |
print("Creating a sql agent chain")
|
| 110 |
sql_agent_prompt = ChatPromptTemplate.from_messages([
|
| 111 |
("system", """You are a supervisor SQL agent managing tools to get the answer to the user's query.
|
|
|
|
| 114 |
1. list_table_tools - List all tables from the database
|
| 115 |
2. get_schema - Get the schema of required tables
|
| 116 |
3. generate_query - Generate a SQL query
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
4. execute_query - Execute the query
|
| 119 |
+
5. response - Create response for the user
|
| 120 |
+
|
| 121 |
Current state:
|
| 122 |
- Tables listed: {tables_list}
|
| 123 |
- Schema retrieved: {schema_of_table}
|
| 124 |
- Query generated: {query_gen}
|
|
|
|
| 125 |
- Query executed: {execute_query}
|
| 126 |
- Response created: {response_to_user}
|
| 127 |
|
| 128 |
If no tables are listed, respond with 'list_table_tools'.
|
| 129 |
If tables are listed but no schema, respond with 'get_schema'.
|
| 130 |
If schema exists but no query generated, respond with 'generate_query'.
|
| 131 |
+
If query generated but not executed, respond with 'execute_query'.
|
|
|
|
| 132 |
If query executed but no response, respond with 'response'.
|
| 133 |
If everything is complete, respond with 'DONE'.
|
| 134 |
|
|
|
|
| 151 |
tables_list = bool(state.get("tables_list", "").strip())
|
| 152 |
schema_of_table = bool(state.get("schema_of_table", "").strip())
|
| 153 |
query_gen = bool(state.get("query_gen", "").strip())
|
| 154 |
+
# check_query = bool(state.get("check_query", "").strip())
|
| 155 |
execute_query = bool(state.get("execute_query", "").strip())
|
| 156 |
response_to_user = bool(state.get("response_to_user", "").strip())
|
| 157 |
|
| 158 |
+
# print(f"State check - Tables: {tables_list}, Schema: {schema_of_table}, Query: {query_gen}, Check: {check_query}, Execute: {execute_query}, Response: {response_to_user}")
|
| 159 |
|
| 160 |
chain = creating_sql_agent_chain()
|
| 161 |
decision = chain.invoke({
|
|
|
|
| 163 |
"tables_list": tables_list,
|
| 164 |
"schema_of_table": schema_of_table,
|
| 165 |
"query_gen": query_gen,
|
| 166 |
+
# "check_query": check_query,
|
| 167 |
"execute_query": execute_query,
|
| 168 |
"response_to_user": response_to_user
|
| 169 |
})
|
|
|
|
| 182 |
elif "generate_query" in decision_text:
|
| 183 |
next_tool = "generate_query"
|
| 184 |
agent_msg = "π SQL Agent: Generating SQL query."
|
| 185 |
+
# elif "check_query" in decision_text:
|
| 186 |
+
# next_tool = "check_query"
|
| 187 |
+
# agent_msg = "π SQL Agent: Checking SQL query."
|
| 188 |
elif "execute_query" in decision_text:
|
| 189 |
next_tool = "execute_query"
|
| 190 |
agent_msg = "π SQL Agent: Executing query."
|
|
|
|
| 210 |
if next_tool == "end" or state.get("task_complete", False):
|
| 211 |
return END
|
| 212 |
|
| 213 |
+
# valid_tools = [
|
| 214 |
+
# "sql_agent", "list_table_tools", "get_schema", "generate_query",
|
| 215 |
+
# "check_query", "execute_query", "response"
|
| 216 |
+
# ]
|
| 217 |
valid_tools = [
|
| 218 |
"sql_agent", "list_table_tools", "get_schema", "generate_query",
|
| 219 |
+
"execute_query", "response"
|
| 220 |
]
|
| 221 |
|
| 222 |
return next_tool if next_tool in valid_tools else "sql_agent"
|
|
|
|
| 229 |
workflow.add_node("list_table_tools", self.db_tools.list_table_tools)
|
| 230 |
workflow.add_node("get_schema", self.db_tools.get_schema)
|
| 231 |
workflow.add_node("generate_query", self.db_tools.generate_query)
|
| 232 |
+
# workflow.add_node("check_query", self.db_tools.check_query)
|
| 233 |
workflow.add_node("execute_query", self.db_tools.execute_query)
|
| 234 |
workflow.add_node("response", self.db_tools.create_response)
|
| 235 |
|
|
|
|
| 237 |
workflow.set_entry_point("sql_agent")
|
| 238 |
|
| 239 |
# Add routing
|
| 240 |
+
# for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "check_query", "execute_query", "response"]:
|
| 241 |
+
for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "execute_query", "response"]:
|
| 242 |
workflow.add_conditional_edges(
|
| 243 |
node,
|
| 244 |
router,
|
|
|
|
| 247 |
"list_table_tools": "list_table_tools",
|
| 248 |
"get_schema": "get_schema",
|
| 249 |
"generate_query": "generate_query",
|
| 250 |
+
# "check_query": "check_query",
|
| 251 |
"execute_query": "execute_query",
|
| 252 |
"response": "response",
|
| 253 |
END: END
|
app/schemas/agent_state.py
CHANGED
|
@@ -7,7 +7,7 @@ class SQLAgentState(MessagesState):
|
|
| 7 |
tables_list: str = ""
|
| 8 |
schema_of_table: str = ""
|
| 9 |
query_gen : str= ""
|
| 10 |
-
check_query: str = ""
|
| 11 |
execute_query : str = ""
|
| 12 |
task_complete: bool = False
|
| 13 |
response_to_user: str= ""
|
|
|
|
| 7 |
tables_list: str = ""
|
| 8 |
schema_of_table: str = ""
|
| 9 |
query_gen : str= ""
|
| 10 |
+
# check_query: str = ""
|
| 11 |
execute_query : str = ""
|
| 12 |
task_complete: bool = False
|
| 13 |
response_to_user: str= ""
|