Kshitijk20 commited on
Commit
4db8795
·
1 Parent(s): 2080a31

agent version2

Browse files
app/agents/sql_agent.py CHANGED
@@ -1,64 +1,384 @@
1
 
2
- from langchain_community.utilities import SQLDatabase
3
- from langchain_groq import ChatGroq
4
- from langgraph.graph import StateGraph, END, START
5
- from langchain_core.messages import AIMessage, ToolMessage, AnyMessage, HumanMessage
6
- from langgraph.graph.message import AnyMessage, add_messages
7
- from langchain_core.tools import tool
8
- from typing import Annotated, Literal, TypedDict, Any
9
- from pydantic import BaseModel, Field
10
- from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
11
- from langgraph.prebuilt import ToolNode
12
- from langchain_core.prompts import ChatPromptTemplate
13
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
14
- from dotenv import load_dotenv
15
- import os
16
- from IPython.display import display
17
- import PIL
18
- from langgraph.errors import GraphRecursionError
19
- import os
20
- import io
21
- from typing import Annotated, Any, TypedDict
22
- from langgraph.graph import StateGraph, END, MessagesState
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- from IPython.display import Image, display
25
- from langchain_core.runnables.graph import MermaidDrawMethod
26
- from typing import Optional, Dict
27
 
28
- from langchain_community.utilities import SQLDatabase
29
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
30
  from langchain_groq import ChatGroq
31
- from langchain_core.messages import HumanMessage, AIMessage
32
- from langchain_core.prompts import ChatPromptTemplate
33
- # from langchain_core.pydantic_v1 import BaseModel, Field
34
- from langgraph.graph import StateGraph, END, MessagesState
35
- from typing import TypedDict, Annotated, List, Literal, Dict, Any
36
  from langchain_google_genai import ChatGoogleGenerativeAI
37
- from app.schemas.agent_state import DBQuery, SQLAgentState
38
- from app.tools.database_tools import DatabaseTools
39
  from app.utils.database_connection import DatabaseConnection
40
- from dotenv import load_dotenv
 
41
  load_dotenv()
42
  import os
43
  os.environ["GROQ_API_KEY"]=os.getenv("GROQ_API_KEY")
44
  os.environ["GEMINI_API_KEY"]=os.getenv("GEMINI_API_KEY")
45
 
46
 
 
 
 
 
 
47
  class SQLAgent:
 
 
48
  def __init__(self):
49
 
50
  # Initialize instance variables
51
  self.db = None
52
- self.toolkit = None
53
- self.tools = None
54
- self.list_tables_tool = None
55
- self.sql_db_query = None
56
- self.get_schema_tool = None
57
- self.app = None
58
-
59
  # Setting up LLM
60
- # self.llm = ChatGroq(model=model,api_key = os.getenv("GROQ_API_KEY"))
61
- self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=os.environ["GEMINI_API_KEY"])
62
  # Register the tool method
63
  # self.query_to_database = self._create_query_tool()
64
 
@@ -67,13 +387,16 @@ class SQLAgent:
67
  def setup_database_connection(self, connection_string: str):
68
  """Set up database connection and initialize tools"""
69
  try:
70
- # Initialize database connection
71
- # self.db = SQLDatabase.from_uri(connection_string)
72
- # print("Database connection successful!")
73
  self.db = DatabaseConnection(connection_string).db
74
  print("Database connection successful!")
75
- # Initialize toolkit and tools class
76
- self.db_tools = DatabaseTools(db=self.db, llm=self.llm)
 
 
 
 
 
77
 
78
  try:
79
  self.initialize_workflow()
@@ -93,252 +416,72 @@ class SQLAgent:
93
  except Exception as e:
94
  print(f"Unexpected error during database connection: {str(e)}")
95
  raise ValueError(f"Failed to establish database connection: {str(e)}")
96
-
97
- def initialize_workflow(self):
98
- """Initialize the workflow graph"""
99
 
100
- print("Intializing Workflow....")
101
-
102
- def creating_sql_agent_chain():
103
  """Creating a sql agent chain"""
104
-
105
- # 4. check_query - Check if the query is correct
106
- # - Query checked: {check_query}
107
- # If query generated but not checked, respond with 'check_query'.
108
- # If query checked but not executed, respond with 'execute_query'.
109
  print("Creating a sql agent chain")
110
- sql_agent_prompt = ChatPromptTemplate.from_messages([
111
- ("system", """You are a supervisor SQL agent managing tools to get the answer to the user's query.
 
112
 
113
- Based on the current state, decide which tool should be called next:
114
- 1. list_table_tools - List all tables from the database
115
  2. get_schema - Get the schema of required tables
116
- 3. generate_query - Generate a SQL query
117
 
118
- 4. execute_query - Execute the query
119
- 5. response - Create response for the user
120
-
121
- Current state:
122
- - Tables listed: {tables_list}
123
- - Schema retrieved: {schema_of_table}
124
- - Query generated: {query_gen}
125
- - Query executed: {execute_query}
126
- - Response created: {response_to_user}
 
127
 
128
- If no tables are listed, respond with 'list_table_tools'.
129
- If tables are listed but no schema, respond with 'get_schema'.
130
- If schema exists but no query generated, respond with 'generate_query'.
131
- If query generated but not executed, respond with 'execute_query'.
132
- If query executed but no response, respond with 'response'.
133
- If everything is complete, respond with 'DONE'.
134
-
135
- Respond with ONLY the tool name or 'DONE'.
136
- """),
137
- ("human", "{task}")
138
- ])
139
- return sql_agent_prompt | self.llm
140
-
141
- def sql_agent(state: SQLAgentState) -> Dict:
142
- """Agent decides which tool to call next"""
143
- messages = state["messages"]
144
- task = messages[-1].content if messages else "No task"
145
-
146
- # Store the original query in state if not already stored
147
- if not state.get("query"):
148
- state["query"] = task
149
-
150
- # Check what's been completed (convert to boolean properly)
151
- tables_list = bool(state.get("tables_list", "").strip())
152
- schema_of_table = bool(state.get("schema_of_table", "").strip())
153
- query_gen = bool(state.get("query_gen", "").strip())
154
- # check_query = bool(state.get("check_query", "").strip())
155
- execute_query = bool(state.get("execute_query", "").strip())
156
- response_to_user = bool(state.get("response_to_user", "").strip())
157
-
158
- # print(f"State check - Tables: {tables_list}, Schema: {schema_of_table}, Query: {query_gen}, Check: {check_query}, Execute: {execute_query}, Response: {response_to_user}")
159
 
160
- chain = creating_sql_agent_chain()
161
- decision = chain.invoke({
162
- "task": task,
163
- "tables_list": tables_list,
164
- "schema_of_table": schema_of_table,
165
- "query_gen": query_gen,
166
- # "check_query": check_query,
167
- "execute_query": execute_query,
168
- "response_to_user": response_to_user
169
- })
170
- decision_text = decision.content.strip().lower()
171
- print(f"Agent decision: {decision_text}")
172
-
173
- if "done" in decision_text:
174
- next_tool = "end"
175
- agent_msg = "✅ SQL Agent: All tasks complete!"
176
- elif "list_table_tools" in decision_text:
177
- next_tool = "list_table_tools"
178
- agent_msg = "📋 SQL Agent: Listing all tables in database."
179
- elif "get_schema" in decision_text:
180
- next_tool = "get_schema"
181
- agent_msg = "📋 SQL Agent: Getting schema of tables."
182
- elif "generate_query" in decision_text:
183
- next_tool = "generate_query"
184
- agent_msg = "📋 SQL Agent: Generating SQL query."
185
- # elif "check_query" in decision_text:
186
- # next_tool = "check_query"
187
- # agent_msg = "📋 SQL Agent: Checking SQL query."
188
- elif "execute_query" in decision_text:
189
- next_tool = "execute_query"
190
- agent_msg = "📋 SQL Agent: Executing query."
191
- elif "response" in decision_text:
192
- next_tool = "response"
193
- agent_msg = "📋 SQL Agent: Creating response."
194
- else:
195
- next_tool = "end"
196
- agent_msg = "✅ SQL Agent: Task complete."
197
-
198
- return {
199
- "messages": [AIMessage(content=agent_msg)],
200
- "next_tool": next_tool,
201
- "current_task": task
202
- }
203
 
204
- def router(state: SQLAgentState):
205
- """Route to the next node"""
206
- print("🔁 Entering router...")
207
- next_tool = state.get("next_tool", "")
208
- print(f"➡️ Next tool: {next_tool}")
209
-
210
- if next_tool == "end" or state.get("task_complete", False):
211
- return END
212
-
213
- # valid_tools = [
214
- # "sql_agent", "list_table_tools", "get_schema", "generate_query",
215
- # "check_query", "execute_query", "response"
216
- # ]
217
- valid_tools = [
218
- "sql_agent", "list_table_tools", "get_schema", "generate_query",
219
- "execute_query", "response"
220
- ]
221
-
222
- return next_tool if next_tool in valid_tools else "sql_agent"
223
 
 
224
  # Create workflow
225
- workflow = StateGraph(SQLAgentState)
226
 
227
  # Add nodes
228
- workflow.add_node("sql_agent", sql_agent)
229
- workflow.add_node("list_table_tools", self.db_tools.list_table_tools)
230
- workflow.add_node("get_schema", self.db_tools.get_schema)
231
- workflow.add_node("generate_query", self.db_tools.generate_query)
232
- # workflow.add_node("check_query", self.db_tools.check_query)
233
- workflow.add_node("execute_query", self.db_tools.execute_query)
234
- workflow.add_node("response", self.db_tools.create_response)
235
 
236
  # Set entry point
237
- workflow.set_entry_point("sql_agent")
238
-
239
- # Add routing
240
- # for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "check_query", "execute_query", "response"]:
241
- for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "execute_query", "response"]:
242
- workflow.add_conditional_edges(
243
- node,
244
- router,
245
- {
246
- "sql_agent": "sql_agent",
247
- "list_table_tools": "list_table_tools",
248
- "get_schema": "get_schema",
249
- "generate_query": "generate_query",
250
- # "check_query": "check_query",
251
- "execute_query": "execute_query",
252
- "response": "response",
253
- END: END
254
- }
255
- )
256
 
 
257
  # Compile the graph
258
- self.app = workflow.compile()
259
- # self.app.get_graph().draw_mermaid_png(output_file_path="sql_agent_workflow.png", draw_method=MermaidDrawMethod.API)
260
 
261
-
262
-
263
- def is_query_relevant(self, query: str) -> bool:
264
- """Check if the query is relevant to the database using the LLM."""
265
-
266
- # Retrieve the schema of the relevant tables
267
- if self.db_tools.list_tables_tool:
268
- relevant_tables = self.db_tools.list_tables_tool.invoke("")
269
- # print(relevant_tables)
270
- table_list= relevant_tables.split(", ")
271
- print(table_list)
272
- # print(agent.get_schema_tool.invoke(table_list[0]))
273
- schema = ""
274
- for table in table_list:
275
- schema+= self.db_tools.get_schema_tool.invoke(table)
276
-
277
- print(schema)
278
-
279
- # if self.get_schema_tool:
280
- # schema_response = self.get_schema_tool.invoke({})
281
- # table_schema = schema_response.content # Assuming this returns the schema as a string
282
-
283
- relevance_check_prompt = (
284
- """You are an expert SQL agent which takes user query in Natural language and find out it have releavnce with the given schema or not. Please determine if the following query is related to a database.Here is the schema of the tables present in database:\n{schema}\n\n. If the query related to given schema respond with 'yes'. Here is the query: {query}. Answer with only 'yes' or 'no'."""
285
- ).format(schema=relevant_tables, query=query)
286
-
287
- response = self.llm.invoke([{"role": "user", "content": relevance_check_prompt}])
288
-
289
- # Assuming the LLM returns a simple 'yes' or 'no'
290
- return response.content == "yes"
291
-
292
  ## called from the fastapi endpoint
293
- def execute_query(self, query: str):
294
  """Execute a query through the workflow"""
295
  if self.db is None:
296
  raise ValueError("Database connection not established. Please set up the connection first.")
297
  if self.app is None:
298
  raise ValueError("Workflow not initialized. Please set up the connection first.")
299
- # First, handle simple queries like "list tables" directly
300
- query_lower = query.lower()
301
- if any(phrase in query_lower for phrase in ["list all the tables", "show tables", "name of tables",
302
- "which tables are present", "how many tables", "list all tables"]):
303
- if self.db_tools.list_tables_tool:
304
- tables = self.db_tools.list_tables_tool.invoke("")
305
- return f"The tables in the database are: {tables}"
306
- else:
307
- return "Error: Unable to list tables. The list_tables_tool is not initialized."
308
-
309
- # Check if the query is relevant to the database
310
- if not self.is_query_relevant(query):
311
- print("Not relevent to database.")
312
- # If not relevant, let the LLM answer the question directly
313
- non_relevant_prompt = (
314
- """You are an expert SQL agent created by Kshitij Kumrawat. You can only assist with questions related to databases so repond the user with the following example resonse and Do not answer any questions that are not related to databases.:
315
- Please ask a question that pertains to database operations, such as querying tables, retrieving data, or understanding the database schema. """
316
- )
317
-
318
- # Invoke the LLM with the non-relevant prompt
319
- response = self.llm.invoke([{"role": "user", "content": non_relevant_prompt}])
320
- # print(response.content)
321
- return response.content
322
-
323
- # If relevant, proceed with the SQL workflow
324
- # response = self.app.invoke({"messages": [HumanMessage(content=query, role="user")]})
325
  response = self.app.invoke({
326
- "messages": [HumanMessage(content=query)],
327
- "query": query
328
- })
329
 
330
  return response["messages"][-1].content
331
-
332
- # # More robust final answer extraction
333
- # if (
334
- # response
335
- # and response["messages"]
336
- # and response["messages"][-1].tool_calls
337
- # and len(response["messages"][-1].tool_calls) > 0
338
- # and "args" in response["messages"][-1].tool_calls[0]
339
- # and "final_answer" in response["messages"][-1].tool_calls[0]["args"]
340
- # ):
341
- # return response["messages"][-1].tool_calls[0]["args"]["final_answer"]
342
- # else:
343
- # return "Error: Could not extract final answer."
344
-
 
1
 
2
+ # from langchain_community.utilities import SQLDatabase
3
+ # from langchain_groq import ChatGroq
4
+ # from langgraph.graph import StateGraph, END, START
5
+ # from langchain_core.messages import AIMessage, ToolMessage, AnyMessage, HumanMessage
6
+ # from langgraph.graph.message import AnyMessage, add_messages
7
+ # from langchain_core.tools import tool
8
+ # from typing import Annotated, Literal, TypedDict, Any
9
+ # from pydantic import BaseModel, Field
10
+ # from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
11
+ # from langgraph.prebuilt import ToolNode
12
+ # from langchain_core.prompts import ChatPromptTemplate
13
+ # from langchain_community.agent_toolkits import SQLDatabaseToolkit
14
+ # from dotenv import load_dotenv
15
+ # import os
16
+ # from IPython.display import display
17
+ # import PIL
18
+ # from langgraph.errors import GraphRecursionError
19
+ # import os
20
+ # import io
21
+ # from typing import Annotated, Any, TypedDict
22
+ # from langgraph.graph import StateGraph, END, MessagesState
23
+
24
+ # from IPython.display import Image, display
25
+ # from langchain_core.runnables.graph import MermaidDrawMethod
26
+ # from typing import Optional, Dict
27
+
28
+ # from langchain_community.utilities import SQLDatabase
29
+ # from langchain_community.agent_toolkits import SQLDatabaseToolkit
30
+ # from langchain_groq import ChatGroq
31
+ # from langchain_core.messages import HumanMessage, AIMessage
32
+ # from langchain_core.prompts import ChatPromptTemplate
33
+ # # from langchain_core.pydantic_v1 import BaseModel, Field
34
+ # from langgraph.graph import StateGraph, END, MessagesState
35
+ # from typing import TypedDict, Annotated, List, Literal, Dict, Any
36
+ # from langchain_google_genai import ChatGoogleGenerativeAI
37
+ # from app.schemas.agent_state import DBQuery, SQLAgentState
38
+ # from app.tools.database_tools import DatabaseTools
39
+ # from app.utils.database_connection import DatabaseConnection
40
+ # from dotenv import load_dotenv
41
+ # load_dotenv()
42
+ # import os
43
+ # os.environ["GROQ_API_KEY"]=os.getenv("GROQ_API_KEY")
44
+ # os.environ["GEMINI_API_KEY"]=os.getenv("GEMINI_API_KEY")
45
+
46
+
47
+ # class SQLAgent:
48
+ # def __init__(self):
49
+
50
+ # # Initialize instance variables
51
+ # self.db = None
52
+ # self.toolkit = None
53
+ # self.tools = None
54
+ # self.list_tables_tool = None
55
+ # self.sql_db_query = None
56
+ # self.get_schema_tool = None
57
+ # self.app = None
58
+
59
+ # # Setting up LLM
60
+ # # self.llm = ChatGroq(model=model,api_key = os.getenv("GROQ_API_KEY"))
61
+ # self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=os.environ["GEMINI_API_KEY"])
62
+ # # Register the tool method
63
+ # # self.query_to_database = self._create_query_tool()
64
+
65
+
66
+
67
+ # def setup_database_connection(self, connection_string: str):
68
+ # """Set up database connection and initialize tools"""
69
+ # try:
70
+ # # Initialize database connection
71
+ # # self.db = SQLDatabase.from_uri(connection_string)
72
+ # # print("Database connection successful!")
73
+ # self.db = DatabaseConnection(connection_string).db
74
+ # print("Database connection successful!")
75
+ # # Initialize toolkit and tools class
76
+ # self.db_tools = DatabaseTools(db=self.db, llm=self.llm)
77
+
78
+ # try:
79
+ # self.initialize_workflow()
80
+
81
+ # return self.db
82
+
83
+ # except Exception as e:
84
+ # print(f"Error initializing tools and workflow: {str(e)}")
85
+ # raise ValueError(f"Failed to initialize database tools: {str(e)}")
86
+
87
+ # except ImportError as e:
88
+ # print(f"Database driver import error: {str(e)}")
89
+ # raise ValueError(f"Missing database driver or invalid database type: {str(e)}")
90
+ # except ValueError as e:
91
+ # print(f"Invalid connection string or configuration: {str(e)}")
92
+ # raise
93
+ # except Exception as e:
94
+ # print(f"Unexpected error during database connection: {str(e)}")
95
+ # raise ValueError(f"Failed to establish database connection: {str(e)}")
96
+
97
+ # def initialize_workflow(self):
98
+ # """Initialize the workflow graph"""
99
+
100
+ # print("Intializing Workflow....")
101
+
102
+ # def creating_sql_agent_chain():
103
+ # """Creating a sql agent chain"""
104
+
105
+ # # 4. check_query - Check if the query is correct
106
+ # # - Query checked: {check_query}
107
+ # # If query generated but not checked, respond with 'check_query'.
108
+ # # If query checked but not executed, respond with 'execute_query'.
109
+ # print("Creating a sql agent chain")
110
+ # sql_agent_prompt = ChatPromptTemplate.from_messages([
111
+ # ("system", """You are a supervisor SQL agent managing tools to get the answer to the user's query.
112
+
113
+ # Based on the current state, decide which tool should be called next:
114
+ # 1. list_table_tools - List all tables from the database
115
+ # 2. get_schema - Get the schema of required tables
116
+ # 3. generate_query - Generate a SQL query
117
+
118
+ # 4. execute_query - Execute the query
119
+ # 5. response - Create response for the user
120
+
121
+ # Current state:
122
+ # - Tables listed: {tables_list}
123
+ # - Schema retrieved: {schema_of_table}
124
+ # - Query generated: {query_gen}
125
+ # - Query executed: {execute_query}
126
+ # - Response created: {response_to_user}
127
+
128
+ # If no tables are listed, respond with 'list_table_tools'.
129
+ # If tables are listed but no schema, respond with 'get_schema'.
130
+ # If schema exists but no query generated, respond with 'generate_query'.
131
+ # If query generated but not executed, respond with 'execute_query'.
132
+ # If query executed but no response, respond with 'response'.
133
+ # If everything is complete, respond with 'DONE'.
134
+
135
+ # Respond with ONLY the tool name or 'DONE'.
136
+ # """),
137
+ # ("human", "{task}")
138
+ # ])
139
+ # return sql_agent_prompt | self.llm
140
+
141
+ # def sql_agent(state: SQLAgentState) -> Dict:
142
+ # """Agent decides which tool to call next"""
143
+ # messages = state["messages"]
144
+ # task = messages[-1].content if messages else "No task"
145
+
146
+ # # Store the original query in state if not already stored
147
+ # if not state.get("query"):
148
+ # state["query"] = task
149
+
150
+ # # Check what's been completed (convert to boolean properly)
151
+ # tables_list = bool(state.get("tables_list", "").strip())
152
+ # schema_of_table = bool(state.get("schema_of_table", "").strip())
153
+ # query_gen = bool(state.get("query_gen", "").strip())
154
+ # # check_query = bool(state.get("check_query", "").strip())
155
+ # execute_query = bool(state.get("execute_query", "").strip())
156
+ # response_to_user = bool(state.get("response_to_user", "").strip())
157
+
158
+ # # print(f"State check - Tables: {tables_list}, Schema: {schema_of_table}, Query: {query_gen}, Check: {check_query}, Execute: {execute_query}, Response: {response_to_user}")
159
+
160
+ # chain = creating_sql_agent_chain()
161
+ # decision = chain.invoke({
162
+ # "task": task,
163
+ # "tables_list": tables_list,
164
+ # "schema_of_table": schema_of_table,
165
+ # "query_gen": query_gen,
166
+ # # "check_query": check_query,
167
+ # "execute_query": execute_query,
168
+ # "response_to_user": response_to_user
169
+ # })
170
+ # decision_text = decision.content.strip().lower()
171
+ # print(f"Agent decision: {decision_text}")
172
+
173
+ # if "done" in decision_text:
174
+ # next_tool = "end"
175
+ # agent_msg = "✅ SQL Agent: All tasks complete!"
176
+ # elif "list_table_tools" in decision_text:
177
+ # next_tool = "list_table_tools"
178
+ # agent_msg = "📋 SQL Agent: Listing all tables in database."
179
+ # elif "get_schema" in decision_text:
180
+ # next_tool = "get_schema"
181
+ # agent_msg = "📋 SQL Agent: Getting schema of tables."
182
+ # elif "generate_query" in decision_text:
183
+ # next_tool = "generate_query"
184
+ # agent_msg = "📋 SQL Agent: Generating SQL query."
185
+ # # elif "check_query" in decision_text:
186
+ # # next_tool = "check_query"
187
+ # # agent_msg = "📋 SQL Agent: Checking SQL query."
188
+ # elif "execute_query" in decision_text:
189
+ # next_tool = "execute_query"
190
+ # agent_msg = "📋 SQL Agent: Executing query."
191
+ # elif "response" in decision_text:
192
+ # next_tool = "response"
193
+ # agent_msg = "📋 SQL Agent: Creating response."
194
+ # else:
195
+ # next_tool = "end"
196
+ # agent_msg = "✅ SQL Agent: Task complete."
197
+
198
+ # return {
199
+ # "messages": [AIMessage(content=agent_msg)],
200
+ # "next_tool": next_tool,
201
+ # "current_task": task
202
+ # }
203
+
204
+ # def router(state: SQLAgentState):
205
+ # """Route to the next node"""
206
+ # print("🔁 Entering router...")
207
+ # next_tool = state.get("next_tool", "")
208
+ # print(f"➡️ Next tool: {next_tool}")
209
+
210
+ # if next_tool == "end" or state.get("task_complete", False):
211
+ # return END
212
+
213
+ # # valid_tools = [
214
+ # # "sql_agent", "list_table_tools", "get_schema", "generate_query",
215
+ # # "check_query", "execute_query", "response"
216
+ # # ]
217
+ # valid_tools = [
218
+ # "sql_agent", "list_table_tools", "get_schema", "generate_query",
219
+ # "execute_query", "response"
220
+ # ]
221
+
222
+ # return next_tool if next_tool in valid_tools else "sql_agent"
223
+
224
+ # # Create workflow
225
+ # workflow = StateGraph(SQLAgentState)
226
+
227
+ # # Add nodes
228
+ # workflow.add_node("sql_agent", sql_agent)
229
+ # workflow.add_node("list_table_tools", self.db_tools.list_table_tools)
230
+ # workflow.add_node("get_schema", self.db_tools.get_schema)
231
+ # workflow.add_node("generate_query", self.db_tools.generate_query)
232
+ # # workflow.add_node("check_query", self.db_tools.check_query)
233
+ # workflow.add_node("execute_query", self.db_tools.execute_query)
234
+ # workflow.add_node("response", self.db_tools.create_response)
235
+
236
+ # # Set entry point
237
+ # workflow.set_entry_point("sql_agent")
238
+
239
+ # # Add routing
240
+ # # for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "check_query", "execute_query", "response"]:
241
+ # for node in ["sql_agent", "list_table_tools", "get_schema", "generate_query", "execute_query", "response"]:
242
+ # workflow.add_conditional_edges(
243
+ # node,
244
+ # router,
245
+ # {
246
+ # "sql_agent": "sql_agent",
247
+ # "list_table_tools": "list_table_tools",
248
+ # "get_schema": "get_schema",
249
+ # "generate_query": "generate_query",
250
+ # # "check_query": "check_query",
251
+ # "execute_query": "execute_query",
252
+ # "response": "response",
253
+ # END: END
254
+ # }
255
+ # )
256
+
257
+ # # Compile the graph
258
+ # self.app = workflow.compile()
259
+ # # self.app.get_graph().draw_mermaid_png(output_file_path="sql_agent_workflow.png", draw_method=MermaidDrawMethod.API)
260
+
261
+
262
+
263
+ # def is_query_relevant(self, query: str) -> bool:
264
+ # """Check if the query is relevant to the database using the LLM."""
265
+
266
+ # # Retrieve the schema of the relevant tables
267
+ # if self.db_tools.list_tables_tool:
268
+ # relevant_tables = self.db_tools.list_tables_tool.invoke("")
269
+ # # print(relevant_tables)
270
+ # table_list= relevant_tables.split(", ")
271
+ # print(table_list)
272
+ # # print(agent.get_schema_tool.invoke(table_list[0]))
273
+ # schema = ""
274
+ # for table in table_list:
275
+ # schema+= self.db_tools.get_schema_tool.invoke(table)
276
+
277
+ # print(schema)
278
+
279
+ # # if self.get_schema_tool:
280
+ # # schema_response = self.get_schema_tool.invoke({})
281
+ # # table_schema = schema_response.content # Assuming this returns the schema as a string
282
+
283
+ # relevance_check_prompt = (
284
+ # """You are an expert SQL agent which takes user query in Natural language and find out it have releavnce with the given schema or not. Please determine if the following query is related to a database.Here is the schema of the tables present in database:\n{schema}\n\n. If the query related to given schema respond with 'yes'. Here is the query: {query}. Answer with only 'yes' or 'no'."""
285
+ # ).format(schema=relevant_tables, query=query)
286
+
287
+ # response = self.llm.invoke([{"role": "user", "content": relevance_check_prompt}])
288
+
289
+ # # Assuming the LLM returns a simple 'yes' or 'no'
290
+ # return response.content == "yes"
291
+
292
+ # ## called from the fastapi endpoint
293
+ # def execute_query(self, query: str):
294
+ # """Execute a query through the workflow"""
295
+ # if self.db is None:
296
+ # raise ValueError("Database connection not established. Please set up the connection first.")
297
+ # if self.app is None:
298
+ # raise ValueError("Workflow not initialized. Please set up the connection first.")
299
+ # # First, handle simple queries like "list tables" directly
300
+ # query_lower = query.lower()
301
+ # if any(phrase in query_lower for phrase in ["list all the tables", "show tables", "name of tables",
302
+ # "which tables are present", "how many tables", "list all tables"]):
303
+ # if self.db_tools.list_tables_tool:
304
+ # tables = self.db_tools.list_tables_tool.invoke("")
305
+ # return f"The tables in the database are: {tables}"
306
+ # else:
307
+ # return "Error: Unable to list tables. The list_tables_tool is not initialized."
308
+
309
+ # # Check if the query is relevant to the database
310
+ # if not self.is_query_relevant(query):
311
+ # print("Not relevent to database.")
312
+ # # If not relevant, let the LLM answer the question directly
313
+ # non_relevant_prompt = (
314
+ # """You are an expert SQL agent created by Kshitij Kumrawat. You can only assist with questions related to databases so repond the user with the following example resonse and Do not answer any questions that are not related to databases.:
315
+ # Please ask a question that pertains to database operations, such as querying tables, retrieving data, or understanding the database schema. """
316
+ # )
317
+
318
+ # # Invoke the LLM with the non-relevant prompt
319
+ # response = self.llm.invoke([{"role": "user", "content": non_relevant_prompt}])
320
+ # # print(response.content)
321
+ # return response.content
322
+
323
+ # # If relevant, proceed with the SQL workflow
324
+ # # response = self.app.invoke({"messages": [HumanMessage(content=query, role="user")]})
325
+ # response = self.app.invoke({
326
+ # "messages": [HumanMessage(content=query)],
327
+ # "query": query
328
+ # })
329
+
330
+ # return response["messages"][-1].content
331
+
332
+ # # # More robust final answer extraction
333
+ # # if (
334
+ # # response
335
+ # # and response["messages"]
336
+ # # and response["messages"][-1].tool_calls
337
+ # # and len(response["messages"][-1].tool_calls) > 0
338
+ # # and "args" in response["messages"][-1].tool_calls[0]
339
+ # # and "final_answer" in response["messages"][-1].tool_calls[0]["args"]
340
+ # # ):
341
+ # # return response["messages"][-1].tool_calls[0]["args"]["final_answer"]
342
+ # # else:
343
+ # # return "Error: Could not extract final answer."
344
+
345
 
 
 
 
346
 
 
 
347
  from langchain_groq import ChatGroq
348
+ from langgraph.graph import StateGraph, END, START, MessagesState
349
+ from langchain_core.messages import AIMessage, ToolMessage, AnyMessage, HumanMessage, SystemMessage
350
+ from dotenv import load_dotenv
351
+ import os
352
+ from IPython.display import display, Image
353
  from langchain_google_genai import ChatGoogleGenerativeAI
354
+ # from app.tools.database_tools import DatabaseTools
 
355
  from app.utils.database_connection import DatabaseConnection
356
+ from app.tools.database_tools_v2 import DatabaseTools
357
+
358
  load_dotenv()
359
  import os
360
  os.environ["GROQ_API_KEY"]=os.getenv("GROQ_API_KEY")
361
  os.environ["GEMINI_API_KEY"]=os.getenv("GEMINI_API_KEY")
362
 
363
 
364
+ from langgraph.graph import MessagesState
365
+ from langgraph.prebuilt import tools_condition, ToolNode
366
+ from langgraph.checkpoint.memory import MemorySaver
367
+
368
+
369
  class SQLAgent:
370
+
371
+
372
  def __init__(self):
373
 
374
  # Initialize instance variables
375
  self.db = None
376
+ # self.repl = PythonREPL()
377
+ # self.code = None
378
+
 
 
 
 
379
  # Setting up LLM
380
+ self.llm = ChatGroq(model="openai/gpt-oss-120b",api_key = os.getenv("GROQ_API_KEY"))
381
+ # self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=os.environ["GEMINI_API_KEY"])
382
  # Register the tool method
383
  # self.query_to_database = self._create_query_tool()
384
 
 
387
  def setup_database_connection(self, connection_string: str):
388
  """Set up database connection and initialize tools"""
389
  try:
390
+
 
 
391
  self.db = DatabaseConnection(connection_string).db
392
  print("Database connection successful!")
393
+ self.db_tools = DatabaseTools(db=self.db, llm=self.llm)
394
+ self.list_tables_tool = self.db_tools.list_tables
395
+ self.schema_tool = self.db_tools.get_schema
396
+ self.execute_query_tools = self.db_tools.execute_query
397
+ self.tools_list = [self.list_tables_tool, self.schema_tool, self.execute_query_tools]
398
+
399
+
400
 
401
  try:
402
  self.initialize_workflow()
 
416
  except Exception as e:
417
  print(f"Unexpected error during database connection: {str(e)}")
418
  raise ValueError(f"Failed to establish database connection: {str(e)}")
 
 
 
419
 
420
+ def sql_agent(self, state: MessagesState):
 
 
421
  """Creating a sql agent chain"""
422
+
 
 
 
 
423
  print("Creating a sql agent chain")
424
+ self.llm_with_tools = self.llm.bind_tools(self.tools_list)
425
+
426
+ sys_msg = SystemMessage(content = f"""You are a supervisor SQL agent managing tools to get the answer to the user's query.
427
 
428
+ You posses the following tools :
429
+ 1. list_tables - List all tables from the database
430
  2. get_schema - Get the schema of required tables
431
+ 3. execute_query - Execute the SQL query
432
 
433
+ The following are instructions to help you decide which tool to use next:
434
+ - Always breakdown the user query into smaller sub-tasks and decide which tool should be called next to accomplish each sub-task.
435
+ - Always list down the tables, never assume any table names or believe on users assuming table names because they can be incorrect.
436
+ - Dont make any schema assumptions, always get the schema using the get_schema tool before generating any query of the required table.
437
+ - Use the execute_query tool to run the final query and get results.
438
+ - If a query execution fails, analyze the error message, adjust the query accordingly, and try executing it again.
439
+
440
+ Dont do :
441
+ - Dont go off topic, always stick to the user query.
442
+ - Dont answer any unwanted queries of user, stick to the database related queries only.
443
 
444
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
+ return {"messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"])]}
447
+
448
+ def initialize_workflow(self):
449
+ """Initialize the workflow graph"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
 
451
+ memory = MemorySaver()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
+ print("Intializing Workflow....")
454
  # Create workflow
455
+ workflow = StateGraph(MessagesState)
456
 
457
  # Add nodes
458
+ workflow.add_node("sql_agent", self.sql_agent)
459
+ workflow.add_node("tools", ToolNode(tools=self.tools_list))
 
 
 
 
 
460
 
461
  # Set entry point
462
+ workflow.add_edge(START, "sql_agent")
463
+ workflow.add_conditional_edges(
464
+ "sql_agent",
465
+ # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
466
+ # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
467
+ tools_condition,
468
+ )
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
+ workflow.add_edge("tools", "sql_agent")
471
  # Compile the graph
472
+ self.app = workflow.compile(checkpointer = memory)
473
+ # display(Image(self.app.get_graph(xray=True).draw_mermaid_png()))
474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  ## called from the fastapi endpoint
476
+ def execute_query(self, query: str, config: dict):
477
  """Execute a query through the workflow"""
478
  if self.db is None:
479
  raise ValueError("Database connection not established. Please set up the connection first.")
480
  if self.app is None:
481
  raise ValueError("Workflow not initialized. Please set up the connection first.")
482
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  response = self.app.invoke({
484
+ "messages": [HumanMessage(content=query)]
485
+ }, config=config)
 
486
 
487
  return response["messages"][-1].content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/api/v1/endpoints/sql_query.py CHANGED
@@ -1,37 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # from fastapi import APIRouter, HTTPException
2
- # from app.models import SQLQueryRequest, SQLQueryResponse
3
- # from app.services.sql_agent import execute_query
4
 
5
  # router = APIRouter()
6
 
 
 
 
 
 
 
7
  # @router.post("/query", response_model=SQLQueryResponse)
8
  # async def query_database(request: SQLQueryRequest):
9
  # try:
10
- # result = execute_query(request.query)
11
  # return SQLQueryResponse(result=result)
12
  # except ValueError as e:
13
  # raise HTTPException(status_code=400, detail=str(e))
14
  # except Exception as e:
15
  # raise HTTPException(status_code=500, detail=str(e))
16
 
17
- # app/api/v1/endpoints/sql_query.py
18
  from fastapi import APIRouter, HTTPException
19
  from pydantic import BaseModel
20
  from app.services.sql_agent_instance import sql_agent
21
-
 
22
  router = APIRouter()
23
 
24
  class SQLQueryRequest(BaseModel):
25
  query: str
 
26
 
27
  class SQLQueryResponse(BaseModel):
28
  result: str
 
29
 
30
  @router.post("/query", response_model=SQLQueryResponse)
31
  async def query_database(request: SQLQueryRequest):
32
  try:
33
- result = sql_agent.execute_query(request.query)
34
- return SQLQueryResponse(result=result)
 
 
 
 
 
35
  except ValueError as e:
36
  raise HTTPException(status_code=400, detail=str(e))
37
  except Exception as e:
 
1
+ # # from fastapi import APIRouter, HTTPException
2
+ # # from app.models import SQLQueryRequest, SQLQueryResponse
3
+ # # from app.services.sql_agent import execute_query
4
+
5
+ # # router = APIRouter()
6
+
7
+ # # @router.post("/query", response_model=SQLQueryResponse)
8
+ # # async def query_database(request: SQLQueryRequest):
9
+ # # try:
10
+ # # result = execute_query(request.query)
11
+ # # return SQLQueryResponse(result=result)
12
+ # # except ValueError as e:
13
+ # # raise HTTPException(status_code=400, detail=str(e))
14
+ # # except Exception as e:
15
+ # # raise HTTPException(status_code=500, detail=str(e))
16
+
17
+ # # app/api/v1/endpoints/sql_query.py
18
  # from fastapi import APIRouter, HTTPException
19
+ # from pydantic import BaseModel
20
+ # from app.services.sql_agent_instance import sql_agent
21
 
22
  # router = APIRouter()
23
 
24
+ # class SQLQueryRequest(BaseModel):
25
+ # query: str
26
+
27
+ # class SQLQueryResponse(BaseModel):
28
+ # result: str
29
+
30
  # @router.post("/query", response_model=SQLQueryResponse)
31
  # async def query_database(request: SQLQueryRequest):
32
  # try:
33
+ # result = sql_agent.execute_query(request.query)
34
  # return SQLQueryResponse(result=result)
35
  # except ValueError as e:
36
  # raise HTTPException(status_code=400, detail=str(e))
37
  # except Exception as e:
38
  # raise HTTPException(status_code=500, detail=str(e))
39
 
 
40
  from fastapi import APIRouter, HTTPException
41
  from pydantic import BaseModel
42
  from app.services.sql_agent_instance import sql_agent
43
+ from typing import Optional
44
+ import uuid
45
  router = APIRouter()
46
 
47
  class SQLQueryRequest(BaseModel):
48
  query: str
49
+ thread_id: Optional[str] = None
50
 
51
  class SQLQueryResponse(BaseModel):
52
  result: str
53
+ thread_id: str ## client can use this to continue the conversation
54
 
55
  @router.post("/query", response_model=SQLQueryResponse)
56
  async def query_database(request: SQLQueryRequest):
57
  try:
58
+ ## generate if not provided thread id
59
+ thread_id = request.thread_id or str(uuid.uuid4())
60
+ ## add debug
61
+ print(f"Thread ID: {thread_id}, Query: {request.query}")
62
+ result = sql_agent.execute_query(request.query, config={"configurable": {"thread_id": thread_id}})
63
+ print(f"Result: {result}")
64
+ return SQLQueryResponse(result=result, thread_id=thread_id)
65
  except ValueError as e:
66
  raise HTTPException(status_code=400, detail=str(e))
67
  except Exception as e:
app/tools/database_tools.py CHANGED
@@ -1,9 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ## creating database tools
2
  from langchain_core.tools import tool
3
  from app.schemas.agent_state import SQLAgentState
4
  from typing import Dict
5
- from langchain_core.messages import AIMessage, HumanMessage
6
- from app.utils.database_connection import DatabaseConnection
7
  from langchain_community.agent_toolkits import SQLDatabaseToolkit
8
  from app.schemas.agent_state import DBQuery
9
  from langchain_core.prompts import ChatPromptTemplate
@@ -12,7 +248,8 @@ class DatabaseTools:
12
  def __init__(self,db = None, llm = None):
13
  self.db = db
14
  self.llm = llm
15
- self._create_query_tool = self._create_query_tool()
 
16
  try:
17
  # Initialize toolkit and tools
18
  self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
@@ -34,60 +271,51 @@ class DatabaseTools:
34
  except Exception as e:
35
  print(f"Error initializing tools and workflow: {str(e)}")
36
  raise ValueError(f"Failed to initialize database tools: {str(e)}")
37
-
38
- def _create_query_tool(self):
39
- """Create the query tool bound to this instance"""
40
- print("creating _create_query_tool")
41
- @tool
42
- def query_to_database(query: str) -> str:
43
- """
44
- Execute a SQL query against the database and return the result.
45
- If the query is invalid or returns no result, an error message will be returned.
46
- In case of an error, the user is advised to rewrite the query and try again.
47
- """
48
- if self.db is None:
49
- return "Error: Database connection not established. Please set up the connection first."
50
- result = self.db.run_no_throw(query)
51
- if not result:
52
- return "Error: Query failed. Please rewrite your query and try again."
53
- return result
54
-
55
- return query_to_database
56
-
57
- def list_table_tools(self, state: SQLAgentState = None) -> Dict:
58
  """List all the tables"""
59
  tables_list = self.list_tables_tool.invoke("")
60
  print(f"Tables found: {tables_list}")
61
- return {
62
- "messages": [AIMessage(content=f"Tables found: {tables_list}")],
63
- "tables_list": tables_list,
64
- "next_tool": "sql_agent"
65
- }
66
 
67
- def get_schema(self,state: SQLAgentState) -> Dict:
68
  """Get the schema of required tables"""
69
  print("📘 Getting schema...")
70
- tables_list = state.get("tables_list", "")
71
- if not tables_list:
72
- tables_list = self.list_tables_tool.invoke("")
73
 
74
  tables = [table.strip() for table in tables_list.split(",")]
75
- full_schema = ""
76
 
77
  for table in tables:
78
  try:
79
  schema = self.get_schema_tool.invoke(table)
80
- full_schema += f"\nTable: {table}\n{schema}\n"
81
  except Exception as e:
82
  print(f"Error getting schema for {table}: {e}")
83
 
84
- print(f"📘 Schema collected for tables: {tables}")
85
- return {
86
- "messages": [AIMessage(content=f"Schema retrieved: {full_schema}")],
87
- "schema_of_table": full_schema,
88
- "tables_list": tables_list,
89
- "next_tool": "sql_agent"
90
- }
91
  def generate_query(self, state: SQLAgentState) -> Dict:
92
  """Generate a SQL Query according to the user query"""
93
  schema = state.get("schema_of_table", "")
@@ -141,95 +369,25 @@ class DatabaseTools:
141
  "next_tool": "sql_agent"
142
  }
143
 
144
- def check_query(self,state: SQLAgentState) -> Dict:
145
- """Check if the query is correct"""
146
- query = state.get("query_gen", "")
147
- print(f"Checking query: {query}")
148
-
149
- if not query:
150
- return {
151
- "messages": [AIMessage(content="No query to check")],
152
- "check_query": "",
153
- "next_tool": "sql_agent"
154
- }
155
-
156
- try:
157
- checked_query = self.query_checker_tool.invoke(query)
158
- ## if checked query contains ``` anywhere remove it
159
- if "```" in checked_query:
160
- checked_query = checked_query.replace("```", "")
161
- print(f"Query checked: {checked_query}")
162
- return {
163
- "messages": [AIMessage(content=f"Query checked: {checked_query}")],
164
- "check_query": checked_query if checked_query else query,
165
- "next_tool": "sql_agent"
166
- }
167
- except Exception as e:
168
- print(f"Error checking query: {e}")
169
- return {
170
- "messages": [AIMessage(content="Query check failed, using original query")],
171
- "check_query": query,
172
- "next_tool": "sql_agent"
173
- }
174
-
175
- def execute_query(self,state: SQLAgentState) -> Dict:
176
- """Execute the SQL query"""
177
- query = state.get("check_query", "") or state.get("query_gen", "")
178
- print(f"Executing query: {query}")
179
 
180
- if not query:
181
- return {
182
- "messages": [AIMessage(content="No query to execute")],
183
- "execute_query": "",
184
- "next_tool": "sql_agent"
185
- }
186
 
187
  try:
188
  results = self.query_tool.invoke(query)
189
  print(f"Query results: {results}")
190
- return {
191
- "messages": [AIMessage(content=f"Query executed successfully: {results}")],
192
- "execute_query": results,
193
- "next_tool": "sql_agent"
194
- }
195
  except Exception as e:
196
  print(f"Error executing query: {e}")
197
- return {
198
- "messages": [AIMessage(content=f"Query execution failed: {e}")],
199
- "execute_query": "",
200
- "next_tool": "sql_agent"
201
- }
202
- def create_response(self,state: SQLAgentState) -> Dict:
203
- """Create a final response for the user"""
204
- print("Creating final response...")
205
-
206
- query = state.get("check_query", "") or state.get("query_gen", "")
207
- result = state.get("execute_query", "")
208
- human_query = state.get("query", "")
209
-
210
- response_prompt = f"""Create a clear, concise response for the user based on:
211
 
212
- User Question: {human_query}
213
- SQL Query: {query}
214
- Query Result: {result}
215
-
216
- Provide a natural language answer that directly addresses the user's question. Make sure to provide only answer to human question, no any internal process results and explaination, just answer related to the human query."""
217
-
218
- try:
219
- response = self.llm.invoke([HumanMessage(content=response_prompt)])
220
- print(f"Response created: {response.content}")
221
-
222
- return {
223
- "messages": [response],
224
- "response_to_user": response.content,
225
- "next_tool": "sql_agent",
226
- "task_complete": True
227
- }
228
- except Exception as e:
229
- print(f"Error creating response: {e}")
230
- return {
231
- "messages": [AIMessage(content="Failed to create response")],
232
- "response_to_user": "",
233
- "next_tool": "sql_agent",
234
- "task_complete": True
235
- }
 
1
+ # ## creating database tools
2
+ # from langchain_core.tools import tool
3
+ # from app.schemas.agent_state import SQLAgentState
4
+ # from typing import Dict
5
+ # from langchain_core.messages import AIMessage, HumanMessage
6
+ # from app.utils.database_connection import DatabaseConnection
7
+ # from langchain_community.agent_toolkits import SQLDatabaseToolkit
8
+ # from app.schemas.agent_state import DBQuery
9
+ # from langchain_core.prompts import ChatPromptTemplate
10
+
11
+ # class DatabaseTools:
12
+ # def __init__(self,db = None, llm = None):
13
+ # self.db = db
14
+ # self.llm = llm
15
+ # self._create_query_tool = self._create_query_tool()
16
+ # try:
17
+ # # Initialize toolkit and tools
18
+ # self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
19
+ # self.tools = self.toolkit.get_tools()
20
+ # for tool in self.tools:
21
+ # print(f"Initialized tool: {tool.name}")
22
+
23
+ # # Create instances of the tools
24
+ # self.list_tables_tool = next((tool for tool in self.tools if tool.name == "sql_db_list_tables"), None)
25
+ # self.query_tool = next((tool for tool in self.tools if tool.name == "sql_db_query"), None)
26
+ # self.get_schema_tool = next((tool for tool in self.tools if tool.name == "sql_db_schema"), None)
27
+ # self.query_checker_tool = next((tool for tool in self.tools if tool.name == "sql_db_query_checker"), None)
28
+ # if not all([self.list_tables_tool, self.query_tool, self.get_schema_tool, self.query_checker_tool]):
29
+ # raise ValueError("Failed to initialize one or more required database tools")
30
+
31
+ # # # Initialize workflow and compile it into an app
32
+ # # self.initialize_workflow()
33
+
34
+ # except Exception as e:
35
+ # print(f"Error initializing tools and workflow: {str(e)}")
36
+ # raise ValueError(f"Failed to initialize database tools: {str(e)}")
37
+
38
+ # def _create_query_tool(self):
39
+ # """Create the query tool bound to this instance"""
40
+ # print("creating _create_query_tool")
41
+ # @tool
42
+ # def query_to_database(query: str) -> str:
43
+ # """
44
+ # Execute a SQL query against the database and return the result.
45
+ # If the query is invalid or returns no result, an error message will be returned.
46
+ # In case of an error, the user is advised to rewrite the query and try again.
47
+ # """
48
+ # if self.db is None:
49
+ # return "Error: Database connection not established. Please set up the connection first."
50
+ # result = self.db.run_no_throw(query)
51
+ # if not result:
52
+ # return "Error: Query failed. Please rewrite your query and try again."
53
+ # return result
54
+
55
+ # return query_to_database
56
+
57
+ # def list_table_tools(self, state: SQLAgentState = None) -> Dict:
58
+ # """List all the tables"""
59
+ # tables_list = self.list_tables_tool.invoke("")
60
+ # print(f"Tables found: {tables_list}")
61
+ # return {
62
+ # "messages": [AIMessage(content=f"Tables found: {tables_list}")],
63
+ # "tables_list": tables_list,
64
+ # "next_tool": "sql_agent"
65
+ # }
66
+
67
+ # def get_schema(self,state: SQLAgentState) -> Dict:
68
+ # """Get the schema of required tables"""
69
+ # print("📘 Getting schema...")
70
+ # tables_list = state.get("tables_list", "")
71
+ # if not tables_list:
72
+ # tables_list = self.list_tables_tool.invoke("")
73
+
74
+ # tables = [table.strip() for table in tables_list.split(",")]
75
+ # full_schema = ""
76
+
77
+ # for table in tables:
78
+ # try:
79
+ # schema = self.get_schema_tool.invoke(table)
80
+ # full_schema += f"\nTable: {table}\n{schema}\n"
81
+ # except Exception as e:
82
+ # print(f"Error getting schema for {table}: {e}")
83
+
84
+ # print(f"📘 Schema collected for tables: {tables}")
85
+ # return {
86
+ # "messages": [AIMessage(content=f"Schema retrieved: {full_schema}")],
87
+ # "schema_of_table": full_schema,
88
+ # "tables_list": tables_list,
89
+ # "next_tool": "sql_agent"
90
+ # }
91
+ # def generate_query(self, state: SQLAgentState) -> Dict:
92
+ # """Generate a SQL Query according to the user query"""
93
+ # schema = state.get("schema_of_table", "")
94
+ # human_query = state.get("query", "")
95
+ # tables = state.get("tables_list", "")
96
+
97
+ # print(f"Generating query for: {human_query}")
98
+
99
+ # generate_query_system_prompt = """You are a SQL expert that generates precise SQL queries based on user questions.
100
+
101
+ # You will be provided with:
102
+ # - User's question
103
+ # - Available tables
104
+ # - Complete schema information
105
+
106
+ # Generate a SQL query that:
107
+ # - Uses correct column names from schema
108
+ # - Properly joins tables if needed
109
+ # - Includes appropriate WHERE clauses
110
+ # - Uses proper aggregation functions when needed
111
+
112
+ # Respond ONLY with the SQL query. Do not explain."""
113
+
114
+ # combined_input = f"""
115
+ # User Question: {human_query}
116
+ # Tables: {tables}
117
+ # Schema: {schema}
118
+ # """
119
+
120
+ # generate_query_prompt = ChatPromptTemplate.from_messages([
121
+ # ("system", generate_query_system_prompt),
122
+ # ("human", "{input}")
123
+ # ])
124
+
125
+ # try:
126
+ # formatted_prompt = generate_query_prompt.invoke({"input": combined_input})
127
+ # generate_query_llm = self.llm.with_structured_output(DBQuery)
128
+ # result = generate_query_llm.invoke(formatted_prompt)
129
+
130
+ # print(f"✅ Query generated: {result.query}")
131
+ # return {
132
+ # "messages": [AIMessage(content=f"Query generated: {result.query}")],
133
+ # "query_gen": result.query,
134
+ # "next_tool": "sql_agent"
135
+ # }
136
+ # except Exception as e:
137
+ # print(f"❌ Failed to generate query: {e}")
138
+ # return {
139
+ # "messages": [AIMessage(content="⚠️ Failed to generate SQL query.")],
140
+ # "query_gen": "",
141
+ # "next_tool": "sql_agent"
142
+ # }
143
+
144
+ # def check_query(self,state: SQLAgentState) -> Dict:
145
+ # """Check if the query is correct"""
146
+ # query = state.get("query_gen", "")
147
+ # print(f"Checking query: {query}")
148
+
149
+ # if not query:
150
+ # return {
151
+ # "messages": [AIMessage(content="No query to check")],
152
+ # "check_query": "",
153
+ # "next_tool": "sql_agent"
154
+ # }
155
+
156
+ # try:
157
+ # checked_query = self.query_checker_tool.invoke(query)
158
+ # ## if checked query contains ``` anywhere remove it
159
+ # if "```" in checked_query:
160
+ # checked_query = checked_query.replace("```", "")
161
+ # print(f"Query checked: {checked_query}")
162
+ # return {
163
+ # "messages": [AIMessage(content=f"Query checked: {checked_query}")],
164
+ # "check_query": checked_query if checked_query else query,
165
+ # "next_tool": "sql_agent"
166
+ # }
167
+ # except Exception as e:
168
+ # print(f"Error checking query: {e}")
169
+ # return {
170
+ # "messages": [AIMessage(content="Query check failed, using original query")],
171
+ # "check_query": query,
172
+ # "next_tool": "sql_agent"
173
+ # }
174
+
175
+ # def execute_query(self,state: SQLAgentState) -> Dict:
176
+ # """Execute the SQL query"""
177
+ # query = state.get("check_query", "") or state.get("query_gen", "")
178
+ # print(f"Executing query: {query}")
179
+
180
+ # if not query:
181
+ # return {
182
+ # "messages": [AIMessage(content="No query to execute")],
183
+ # "execute_query": "",
184
+ # "next_tool": "sql_agent"
185
+ # }
186
+
187
+ # try:
188
+ # results = self.query_tool.invoke(query)
189
+ # print(f"Query results: {results}")
190
+ # return {
191
+ # "messages": [AIMessage(content=f"Query executed successfully: {results}")],
192
+ # "execute_query": results,
193
+ # "next_tool": "sql_agent"
194
+ # }
195
+ # except Exception as e:
196
+ # print(f"Error executing query: {e}")
197
+ # return {
198
+ # "messages": [AIMessage(content=f"Query execution failed: {e}")],
199
+ # "execute_query": "",
200
+ # "next_tool": "sql_agent"
201
+ # }
202
+ # def create_response(self,state: SQLAgentState) -> Dict:
203
+ # """Create a final response for the user"""
204
+ # print("Creating final response...")
205
+
206
+ # query = state.get("check_query", "") or state.get("query_gen", "")
207
+ # result = state.get("execute_query", "")
208
+ # human_query = state.get("query", "")
209
+
210
+ # response_prompt = f"""Create a clear, concise response for the user based on:
211
+
212
+ # User Question: {human_query}
213
+ # SQL Query: {query}
214
+ # Query Result: {result}
215
+
216
+ # Provide a natural language answer that directly addresses the user's question. Make sure to provide only answer to human question, no any internal process results and explaination, just answer related to the human query."""
217
+
218
+ # try:
219
+ # response = self.llm.invoke([HumanMessage(content=response_prompt)])
220
+ # print(f"Response created: {response.content}")
221
+
222
+ # return {
223
+ # "messages": [response],
224
+ # "response_to_user": response.content,
225
+ # "next_tool": "sql_agent",
226
+ # "task_complete": True
227
+ # }
228
+ # except Exception as e:
229
+ # print(f"Error creating response: {e}")
230
+ # return {
231
+ # "messages": [AIMessage(content="Failed to create response")],
232
+ # "response_to_user": "",
233
+ # "next_tool": "sql_agent",
234
+ # "task_complete": True
235
+ # }
236
+
237
+
238
  ## creating database tools
239
  from langchain_core.tools import tool
240
  from app.schemas.agent_state import SQLAgentState
241
  from typing import Dict
242
+ from langchain_core.messages import AIMessage
 
243
  from langchain_community.agent_toolkits import SQLDatabaseToolkit
244
  from app.schemas.agent_state import DBQuery
245
  from langchain_core.prompts import ChatPromptTemplate
 
248
  def __init__(self,db = None, llm = None):
249
  self.db = db
250
  self.llm = llm
251
+ # self._create_query_tool = self._create_query_tool()
252
+ self.tools = self.get_all_tools()
253
  try:
254
  # Initialize toolkit and tools
255
  self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
 
271
  except Exception as e:
272
  print(f"Error initializing tools and workflow: {str(e)}")
273
  raise ValueError(f"Failed to initialize database tools: {str(e)}")
274
+ # @tool
275
+ # def _create_query_tool(self):
276
+ # """Create the query tool bound to this instance"""
277
+ # print("creating _create_query_tool")
278
+ # @tool
279
+ # def query_to_database(query: str) -> str:
280
+ # """
281
+ # Execute a SQL query against the database and return the result.
282
+ # If the query is invalid or returns no result, an error message will be returned.
283
+ # In case of an error, the user is advised to rewrite the query and try again.
284
+ # """
285
+ # if self.db is None:
286
+ # return "Error: Database connection not established. Please set up the connection first."
287
+ # result = self.db.run_no_throw(query)
288
+ # if not result:
289
+ # return "Error: Query failed. Please rewrite your query and try again."
290
+ # return result
291
+
292
+ # return query_to_database
293
+ def list_tables(self) -> Dict:
 
294
  """List all the tables"""
295
  tables_list = self.list_tables_tool.invoke("")
296
  print(f"Tables found: {tables_list}")
297
+ return tables_list
 
 
 
 
298
 
299
+ def get_schema(self, table_name: list[str]) -> Dict:
300
  """Get the schema of required tables"""
301
  print("📘 Getting schema...")
302
+ tables_list = self.list_tables_tool.invoke("")
303
+ if any(table not in tables_list for table in table_name):
304
+ return "Table not exits in database"
305
 
306
  tables = [table.strip() for table in tables_list.split(",")]
307
+ required_schema = ""
308
 
309
  for table in tables:
310
  try:
311
  schema = self.get_schema_tool.invoke(table)
312
+ required_schema += f"\nTable: {table}\n{schema}\n"
313
  except Exception as e:
314
  print(f"Error getting schema for {table}: {e}")
315
 
316
+ return required_schema
317
+
318
+
 
 
 
 
319
  def generate_query(self, state: SQLAgentState) -> Dict:
320
  """Generate a SQL Query according to the user query"""
321
  schema = state.get("schema_of_table", "")
 
369
  "next_tool": "sql_agent"
370
  }
371
 
372
+
373
+ def execute_query(self,query: str) -> Dict:
374
+ """Execute the SQL query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
+ Arguments:
377
+ query -- The SQL query to execute
378
+
379
+ returns:
380
+ execution results
381
+ """
382
 
383
  try:
384
  results = self.query_tool.invoke(query)
385
  print(f"Query results: {results}")
386
+ return results
 
 
 
 
387
  except Exception as e:
388
  print(f"Error executing query: {e}")
389
+ return "Query execution failed."
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
+ def get_all_tools(self):
392
+ return [self.list_tables, self.get_schema, self.execute_query]
393
+