Kshitijk20 commited on
Commit
2080a31
Β·
1 Parent(s): 34a5262

fixing check query

Browse files
app/agents/sql_agent.py CHANGED
@@ -102,6 +102,10 @@ class SQLAgent:
102
  def creating_sql_agent_chain():
103
  """Creating a sql agent chain"""
104
 
 
 
 
 
105
  print("Creating a sql agent chain")
106
  sql_agent_prompt = ChatPromptTemplate.from_messages([
107
  ("system", """You are a supervisor SQL agent managing tools to get the answer to the user's query.
@@ -110,23 +114,21 @@ class SQLAgent:
110
  1. list_table_tools - List all tables from the database
111
  2. get_schema - Get the schema of required tables
112
  3. generate_query - Generate a SQL query
113
- 4. check_query - Check if the query is correct
114
- 5. execute_query - Execute the query
115
- 6. response - Create response for the user
116
 
 
 
 
117
  Current state:
118
  - Tables listed: {tables_list}
119
  - Schema retrieved: {schema_of_table}
120
  - Query generated: {query_gen}
121
- - Query checked: {check_query}
122
  - Query executed: {execute_query}
123
  - Response created: {response_to_user}
124
 
125
  If no tables are listed, respond with 'list_table_tools'.
126
  If tables are listed but no schema, respond with 'get_schema'.
127
  If schema exists but no query generated, respond with 'generate_query'.
128
- If query generated but not checked, respond with 'check_query'.
129
- If query checked but not executed, respond with 'execute_query'.
130
  If query executed but no response, respond with 'response'.
131
  If everything is complete, respond with 'DONE'.
132
 
@@ -149,11 +151,11 @@ class SQLAgent:
149
  tables_list = bool(state.get("tables_list", "").strip())
150
  schema_of_table = bool(state.get("schema_of_table", "").strip())
151
  query_gen = bool(state.get("query_gen", "").strip())
152
- check_query = bool(state.get("check_query", "").strip())
153
  execute_query = bool(state.get("execute_query", "").strip())
154
  response_to_user = bool(state.get("response_to_user", "").strip())
155
 
156
- print(f"State check - Tables: {tables_list}, Schema: {schema_of_table}, Query: {query_gen}, Check: {check_query}, Execute: {execute_query}, Response: {response_to_user}")
157
 
158
  chain = creating_sql_agent_chain()
159
  decision = chain.invoke({
@@ -161,7 +163,7 @@ class SQLAgent:
161
  "tables_list": tables_list,
162
  "schema_of_table": schema_of_table,
163
  "query_gen": query_gen,
164
- "check_query": check_query,
165
  "execute_query": execute_query,
166
  "response_to_user": response_to_user
167
  })
@@ -180,9 +182,9 @@ class SQLAgent:
180
  elif "generate_query" in decision_text:
181
  next_tool = "generate_query"
182
  agent_msg = "πŸ“‹ SQL Agent: Generating SQL query."
183
- elif "check_query" in decision_text:
184
- next_tool = "check_query"
185
- agent_msg = "πŸ“‹ SQL Agent: Checking SQL query."
186
  elif "execute_query" in decision_text:
187
  next_tool = "execute_query"
188
  agent_msg = "πŸ“‹ SQL Agent: Executing query."
@@ -208,9 +210,13 @@ class SQLAgent:
208
  if next_tool == "end" or state.get("task_complete", False):
209
  return END
210
 
 
 
 
 
211
  valid_tools = [
212
  "sql_agent", "list_table_tools", "get_schema", "generate_query",
213
- "check_query", "execute_query", "response"
214
  ]
215
 
216
  return next_tool if next_tool in valid_tools else "sql_agent"
@@ -223,7 +229,7 @@ class SQLAgent:
223
  workflow.add_node("list_table_tools", self.db_tools.list_table_tools)
224
  workflow.add_node("get_schema", self.db_tools.get_schema)
225
  workflow.add_node("generate_query", self.db_tools.generate_query)
226
- workflow.add_node("check_query", self.db_tools.check_query)
227
  workflow.add_node("execute_query", self.db_tools.execute_query)
228
  workflow.add_node("response", self.db_tools.create_response)
229
 
@@ -231,7 +237,8 @@ class SQLAgent:
231
  workflow.set_entry_point("sql_agent")
232
 
233
  # Add routing
234
- for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "check_query", "execute_query", "response"]:
 
235
  workflow.add_conditional_edges(
236
  node,
237
  router,
@@ -240,7 +247,7 @@ class SQLAgent:
240
  "list_table_tools": "list_table_tools",
241
  "get_schema": "get_schema",
242
  "generate_query": "generate_query",
243
- "check_query": "check_query",
244
  "execute_query": "execute_query",
245
  "response": "response",
246
  END: END
 
102
  def creating_sql_agent_chain():
103
  """Creating a sql agent chain"""
104
 
105
+ # 4. check_query - Check if the query is correct
106
+ # - Query checked: {check_query}
107
+ # If query generated but not checked, respond with 'check_query'.
108
+ # If query checked but not executed, respond with 'execute_query'.
109
  print("Creating a sql agent chain")
110
  sql_agent_prompt = ChatPromptTemplate.from_messages([
111
  ("system", """You are a supervisor SQL agent managing tools to get the answer to the user's query.
 
114
  1. list_table_tools - List all tables from the database
115
  2. get_schema - Get the schema of required tables
116
  3. generate_query - Generate a SQL query
 
 
 
117
 
118
+ 4. execute_query - Execute the query
119
+ 5. response - Create response for the user
120
+
121
  Current state:
122
  - Tables listed: {tables_list}
123
  - Schema retrieved: {schema_of_table}
124
  - Query generated: {query_gen}
 
125
  - Query executed: {execute_query}
126
  - Response created: {response_to_user}
127
 
128
  If no tables are listed, respond with 'list_table_tools'.
129
  If tables are listed but no schema, respond with 'get_schema'.
130
  If schema exists but no query generated, respond with 'generate_query'.
131
+ If query generated but not executed, respond with 'execute_query'.
 
132
  If query executed but no response, respond with 'response'.
133
  If everything is complete, respond with 'DONE'.
134
 
 
151
  tables_list = bool(state.get("tables_list", "").strip())
152
  schema_of_table = bool(state.get("schema_of_table", "").strip())
153
  query_gen = bool(state.get("query_gen", "").strip())
154
+ # check_query = bool(state.get("check_query", "").strip())
155
  execute_query = bool(state.get("execute_query", "").strip())
156
  response_to_user = bool(state.get("response_to_user", "").strip())
157
 
158
+ # print(f"State check - Tables: {tables_list}, Schema: {schema_of_table}, Query: {query_gen}, Check: {check_query}, Execute: {execute_query}, Response: {response_to_user}")
159
 
160
  chain = creating_sql_agent_chain()
161
  decision = chain.invoke({
 
163
  "tables_list": tables_list,
164
  "schema_of_table": schema_of_table,
165
  "query_gen": query_gen,
166
+ # "check_query": check_query,
167
  "execute_query": execute_query,
168
  "response_to_user": response_to_user
169
  })
 
182
  elif "generate_query" in decision_text:
183
  next_tool = "generate_query"
184
  agent_msg = "πŸ“‹ SQL Agent: Generating SQL query."
185
+ # elif "check_query" in decision_text:
186
+ # next_tool = "check_query"
187
+ # agent_msg = "πŸ“‹ SQL Agent: Checking SQL query."
188
  elif "execute_query" in decision_text:
189
  next_tool = "execute_query"
190
  agent_msg = "πŸ“‹ SQL Agent: Executing query."
 
210
  if next_tool == "end" or state.get("task_complete", False):
211
  return END
212
 
213
+ # valid_tools = [
214
+ # "sql_agent", "list_table_tools", "get_schema", "generate_query",
215
+ # "check_query", "execute_query", "response"
216
+ # ]
217
  valid_tools = [
218
  "sql_agent", "list_table_tools", "get_schema", "generate_query",
219
+ "execute_query", "response"
220
  ]
221
 
222
  return next_tool if next_tool in valid_tools else "sql_agent"
 
229
  workflow.add_node("list_table_tools", self.db_tools.list_table_tools)
230
  workflow.add_node("get_schema", self.db_tools.get_schema)
231
  workflow.add_node("generate_query", self.db_tools.generate_query)
232
+ # workflow.add_node("check_query", self.db_tools.check_query)
233
  workflow.add_node("execute_query", self.db_tools.execute_query)
234
  workflow.add_node("response", self.db_tools.create_response)
235
 
 
237
  workflow.set_entry_point("sql_agent")
238
 
239
  # Add routing
240
+ # for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "check_query", "execute_query", "response"]:
241
+ for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "execute_query", "response"]:
242
  workflow.add_conditional_edges(
243
  node,
244
  router,
 
247
  "list_table_tools": "list_table_tools",
248
  "get_schema": "get_schema",
249
  "generate_query": "generate_query",
250
+ # "check_query": "check_query",
251
  "execute_query": "execute_query",
252
  "response": "response",
253
  END: END
app/schemas/agent_state.py CHANGED
@@ -7,7 +7,7 @@ class SQLAgentState(MessagesState):
7
  tables_list: str = ""
8
  schema_of_table: str = ""
9
  query_gen : str= ""
10
- check_query: str = ""
11
  execute_query : str = ""
12
  task_complete: bool = False
13
  response_to_user: str= ""
 
7
  tables_list: str = ""
8
  schema_of_table: str = ""
9
  query_gen : str= ""
10
+ # check_query: str = ""
11
  execute_query : str = ""
12
  task_complete: bool = False
13
  response_to_user: str= ""