Spaces:
Sleeping
Sleeping
| import { HfInference } from "@huggingface/inference"; | |
| import { NextApiRequest, NextApiResponse } from "next"; | |
| import { createConnection, executeQuery } from "@/utils/database"; | |
/**
 * Next.js API route: turns a natural-language request into SQL with an LLM,
 * runs it against the caller-supplied database and returns the rows together
 * with an optional chart configuration.
 *
 * Request (POST, JSON body):
 * - `dbUri`      value passed to `createConnection` (required)
 * - `userPrompt` natural-language question (required, non-empty string)
 *
 * Responses:
 * - 200 `{ results, query, visualization }`. `visualization` is null when no
 *   chart was requested or the rows do not fit the requested chart.
 * - 400 missing/invalid fields, or the generated SQL is not a single
 *   read-only statement.
 * - 405 any method other than POST.
 * - 500 missing `API_TOKEN`, model failure, unparseable model output,
 *   connection failure or query failure.
 *
 * NOTE(review): `dbUri` is caller-controlled, so this server will open
 * connections to whatever host the caller names; consider an allow-list.
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse
) {
  if (req.method !== "POST") {
    return res.status(405).json({ message: "Method not allowed" });
  }

  // `req.body` is undefined when the request carried no parseable body.
  const { dbUri, userPrompt } = req.body ?? {};
  // The prompt is interpolated into the model prompt, so it must be real text.
  const hasPrompt = typeof userPrompt === "string" && userPrompt.trim() !== "";
  if (!dbUri || !hasPrompt) {
    return res.status(400).json({
      message: "Missing required fields",
      details: {
        dbUri: !dbUri ? "Database URI is required" : null,
        userPrompt: !hasPrompt ? "Query prompt is required" : null,
      },
    });
  }

  try {
    const apiKey = process.env.API_TOKEN;
    if (!apiKey) {
      return res.status(500).json({
        message: "Server configuration error",
        details: "API key is not configured",
      });
    }

    let hf: HfInference;
    try {
      hf = new HfInference(apiKey);
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Failed to initialize AI model",
        details: errorMessage(error),
      });
    }

    let content: string;
    try {
      const response = await hf.chatCompletion({
        model: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
        messages: [{ role: "user", content: buildSqlPrompt(userPrompt) }],
        temperature: 0.1,
        max_tokens: 500,
      });
      content = response?.choices?.[0]?.message?.content?.trim() ?? "";
    } catch (error: unknown) {
      console.error("AI model error:", error);
      return res.status(500).json({
        message: "AI model error",
        details: errorMessage(error, "Failed to generate SQL query"),
      });
    }

    let parsed: ParsedModelOutput;
    try {
      parsed = parseModelOutput(content);
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Failed to parse AI response",
        details: errorMessage(error),
      });
    }
    const { query: sqlQuery, chartType: requestedChartType } = parsed;

    // Guard against the model emitting writes (e.g. for "remove duplicate
    // users") before any connection is opened. This is a heuristic, not a
    // sandbox: the database account behind dbUri should itself be read-only.
    const unsafeReason = findUnsafeSqlReason(sqlQuery);
    if (unsafeReason) {
      return res.status(400).json({
        message: "Generated query is not allowed",
        details: unsafeReason,
        query: sqlQuery,
      });
    }

    let connection;
    try {
      connection = await createConnection(dbUri);
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Database connection error",
        details: errorMessage(error),
      });
    }

    let results: unknown;
    let visualization: Visualization | null = null;
    try {
      results = await executeQuery(connection, sqlQuery);
      visualization = buildVisualization(results, requestedChartType);
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Query execution error",
        details: errorMessage(error),
      });
    } finally {
      // Runs on success and failure alike, so the connection is closed exactly
      // once; a failed close is logged instead of replacing the response.
      await closeConnection(connection);
    }

    return res.status(200).json({ results, query: sqlQuery, visualization });
  } catch (error: unknown) {
    console.error("Unexpected API Error:", error);
    return res.status(500).json({
      message: "Unexpected error occurred",
      details: errorMessage(error),
    });
  }
}

/** Chart types the model is asked to choose from (see `buildSqlPrompt`). */
type ChartType = "bar" | "pie" | "line" | "doughnut";

/** One result row, keyed by column name. */
type Row = Record<string, unknown>;

/** SQL and chart type extracted from the model's reply. */
interface ParsedModelOutput {
  query: string;
  chartType: ChartType | null;
}

/** Chart description returned to the client; presumably rendered with Chart.js (labels/datasets shape). */
interface Visualization {
  type: ChartType;
  config: {
    labels: string[];
    datasets: Record<string, unknown>[];
  };
}

/** Column names grouped by the JS type of their value in a sample row. */
interface ColumnRoles {
  numericColumns: string[];
  dateColumns: string[];
  stringColumns: string[];
}

/** Builds the instruction prompt asking the model for `{"query", "chartType"}` JSON. */
function buildSqlPrompt(userPrompt: string): string {
  // NOTE(review): userPrompt is inserted verbatim and can override these rules
  // (prompt injection); findUnsafeSqlReason is the backstop for writes.
  return `You are a SQL expert. Convert the following text to a SQL query.
Rules:
- Return a JSON object with exactly this format: {"query": "YOUR SQL QUERY HERE", "chartType": "CHART TYPE HERE"}
- For chartType use one of: "bar", "pie", "line", "doughnut", or null
- The query should be safe and only return the requested data
- Keep table names exactly as provided
- Do not include any explanations or comments
Example input: "Show me sales data as a pie chart"
Example output: {"query": "SELECT * FROM sales LIMIT 10", "chartType": "pie"}
Text: ${userPrompt}`;
}

/**
 * Extracts the SQL and chart type from the model's reply.
 *
 * Prefers the `{"query": ..., "chartType": ...}` JSON the prompt asks for.
 * Otherwise treats the whole reply as bare SQL, removing Markdown code fences
 * and a leading "Query:" / "SQL query:" label.
 *
 * @throws SyntaxError when the `{...}` part of the reply is not valid JSON.
 * @throws Error when no non-empty SQL string can be found.
 */
function parseModelOutput(content: string): ParsedModelOutput {
  const jsonMatch = content.match(/\{[\s\S]*\}/);
  let query: string;
  let chartType: ChartType | null = null;

  if (jsonMatch) {
    const parsed: unknown = JSON.parse(jsonMatch[0]);
    const record: Record<string, unknown> =
      typeof parsed === "object" && parsed !== null
        ? (parsed as Record<string, unknown>)
        : {};
    query = typeof record.query === "string" ? record.query.trim() : "";
    chartType = normalizeChartType(record.chartType);
  } else {
    query = content
      .replace(/```(?:sql)?/gi, "")
      // Anchored: stripping "query" anywhere would also mangle identifiers
      // inside the SQL (e.g. `query_count` -> `_count`).
      .replace(/^\s*(?:sql\s+)?query\s*:?\s*/i, "")
      .trim();
  }

  if (!query) {
    throw new Error("No valid SQL query found in response");
  }
  return { query, chartType };
}

/**
 * Maps the model's `chartType` to a supported chart, case-insensitively.
 * Falsy values and the string "null" mean "no chart"; any other unrecognised
 * value falls back to "bar".
 */
function normalizeChartType(raw: unknown): ChartType | null {
  if (!raw) {
    return null;
  }
  const candidate = typeof raw === "string" ? raw.trim().toLowerCase() : "";
  if (candidate === "null") {
    return null;
  }
  const supported: readonly ChartType[] = ["bar", "pie", "line", "doughnut"];
  return supported.find((type) => type === candidate) ?? "bar";
}

/**
 * Heuristic check that `sql` is a single read-only statement.
 *
 * Quoted literals and identifiers are blanked first, scanning left to right
 * like a SQL lexer, so keywords inside strings do not cause false positives.
 * Comments are NOT stripped, so a keyword or `;` inside a comment causes a
 * (safe) rejection.
 *
 * @returns a human-readable rejection reason, or null when the query may run.
 */
function findUnsafeSqlReason(sql: string): string | null {
  const code = sql
    .replace(/'[^']*'|"[^"]*"|`[^`]*`/g, "''")
    .trim()
    // A single trailing semicolon is harmless.
    .replace(/;\s*$/, "");

  if (!/^\(*\s*(?:select|with|show|describe|desc|explain|values)\b/i.test(code)) {
    return "Only read-only queries (SELECT, WITH, SHOW, DESCRIBE, EXPLAIN, VALUES) are allowed";
  }
  if (code.includes(";")) {
    return "Multiple SQL statements are not allowed";
  }
  // These can write even inside a read-looking statement, e.g. a
  // data-modifying CTE `WITH d AS (DELETE ...)` or `SELECT ... INTO`.
  const writeKeyword = code.match(
    /\b(?:insert|update|delete|merge|drop|alter|create|truncate|grant|revoke|into)\b/i
  );
  if (writeKeyword) {
    return `Query contains a data-modifying keyword: ${writeKeyword[0].toUpperCase()}`;
  }
  return null;
}

/**
 * Builds a chart configuration for `chartType` from query rows.
 *
 * Column roles are inferred from the first row only. Returns null when no
 * chart was requested, there are no object rows, or (for pie/doughnut and
 * line) the rows have no usable columns. Colours are random per call.
 */
function buildVisualization(
  results: unknown,
  chartType: ChartType | null
): Visualization | null {
  if (!chartType || !Array.isArray(results) || results.length === 0) {
    return null;
  }
  const firstRow: unknown = results[0];
  if (typeof firstRow !== "object" || firstRow === null) {
    return null;
  }
  const rows = results as Row[];
  const { numericColumns, dateColumns, stringColumns } = analyzeColumns(
    firstRow as Row
  );
  const labelColumn = stringColumns[0];
  const labelFor = (row: Row, idx: number): string =>
    labelColumn !== undefined ? String(row[labelColumn]) : `Row ${idx + 1}`;

  if (chartType === "pie" || chartType === "doughnut") {
    // A pie shows one series; prefer a column whose name looks like a measure.
    const preferred = numericColumns.filter((col) =>
      /status|count|amount|total/.test(col.toLowerCase())
    );
    const valueColumn = (preferred.length > 0 ? preferred : numericColumns)[0];
    if (valueColumn === undefined) {
      return null;
    }
    return {
      type: chartType,
      config: {
        labels: rows.map(labelFor),
        datasets: [
          {
            data: rows.map((row) => row[valueColumn]),
            backgroundColor: rows.map(() => randomFill()),
          },
        ],
      },
    };
  }

  if (chartType === "line") {
    if (dateColumns.length === 0 && numericColumns.length === 0) {
      return null;
    }
    const dateColumn = dateColumns[0];
    return {
      type: "line",
      config: {
        labels: rows.map((row, idx) =>
          dateColumn !== undefined
            ? new Date(row[dateColumn] as string | number | Date).toLocaleDateString()
            : `Point ${idx + 1}`
        ),
        datasets: numericColumns.map((col) => ({
          label: col,
          data: rows.map((row) => row[col]),
          borderColor: randomStroke(),
          tension: 0.1,
        })),
      },
    };
  }

  // "bar": also what normalizeChartType maps unrecognised chart types to.
  return {
    type: "bar",
    config: {
      labels: rows.map(labelFor),
      datasets: numericColumns.map((col) => ({
        label: col,
        data: rows.map((row) => row[col]),
        backgroundColor: randomFill(),
        borderColor: randomStroke(),
        borderWidth: 1,
      })),
    },
  };
}

/**
 * Groups the columns of `sample` by value type. Numeric id columns are left
 * out of `numericColumns`; columns whose name contains "name"/"title" count as
 * label (string) columns whatever their type.
 */
function analyzeColumns(sample: Row): ColumnRoles {
  const columns = Object.keys(sample);
  return {
    numericColumns: columns.filter(
      (col) => typeof sample[col] === "number" && !isIdColumn(col)
    ),
    dateColumns: columns.filter((col) => sample[col] instanceof Date),
    stringColumns: columns.filter((col) => {
      const lower = col.toLowerCase();
      return (
        typeof sample[col] === "string" ||
        lower.includes("name") ||
        lower.includes("title")
      );
    }),
  };
}

/**
 * True for identifier columns (`id`, `user_id`, `id_user`, `userId`, `userID`).
 * Matching "id" anywhere in the name would also discard real measures such as
 * `amount_paid`, `width` or `valid_count`.
 */
function isIdColumn(column: string): boolean {
  const lower = column.toLowerCase();
  return (
    lower === "id" ||
    lower.endsWith("_id") ||
    lower.startsWith("id_") ||
    /[a-z0-9](?:Id|ID)$/.test(column)
  );
}

/** Closes a database connection, logging rather than throwing on failure. */
async function closeConnection(connection: { end(): unknown }): Promise<void> {
  try {
    await connection.end();
  } catch (error: unknown) {
    console.error("Failed to close database connection:", error);
  }
}

/** Message of an `Error`, or `fallback` for empty messages and non-Error throws. */
function errorMessage(error: unknown, fallback = "Unknown error"): string {
  return error instanceof Error && error.message ? error.message : fallback;
}

/** Random semi-transparent fill colour. */
function randomFill(): string {
  return `hsla(${Math.random() * 360}, 70%, 50%, 0.6)`;
}

/** Random opaque stroke colour. */
function randomStroke(): string {
  return `hsl(${Math.random() * 360}, 70%, 50%)`;
}