14_text2sql_agent
October 8, 2024
1 Text2SQL Agent to Interact with CSV Data
1.1 System Architecture
Think of it as an agent with a set of tools: search_cache(), generate_sql_query(),
and run_sql_query().
1.2 Data Ingestion Pipeline
1. Read the CSV file
2. Create the database schema
3. Create a table
4. Load the table with the CSV data
[1]: import pandas as pd
import sqlite3
def csv_to_sqlite(csv_file, db_name, table_name):
    """
    Loads a CSV file into a table of a SQLite database.

    The table schema is inferred from the pandas dtypes of the CSV columns
    (boolean/integer -> INTEGER, float -> REAL, anything else -> TEXT). A table
    with the same name that already exists is dropped and rebuilt, so running
    the function again replaces the data instead of duplicating rows.

    Args:
        csv_file (str): Path of the CSV file to read.
        db_name (str): The name of the SQLite database file (created if it
            doesn't exist).
        table_name (str): The name of the table to (re)create and fill.
    """
    def quote_identifier(name):
        # Double-quote the identifier and double any embedded quote so that
        # names with spaces, keywords or quotes cannot break the statement.
        return '"' + str(name).replace('"', '""') + '"'

    def sqlite_type(dtype):
        # Check dtype families rather than the exact 'int64'/'float64' names so
        # that int32, nullable Int64, float32, ... are not stored as TEXT.
        if pd.api.types.is_bool_dtype(dtype) or pd.api.types.is_integer_dtype(dtype):
            return 'INTEGER'
        if pd.api.types.is_float_dtype(dtype):
            return 'REAL'
        return 'TEXT'

    # Read the CSV file into a pandas DataFrame
    df = pd.read_csv(csv_file)

    # Infer the schema based on the DataFrame columns and data types
    col_definitions = ", ".join(
        f"{quote_identifier(col)} {sqlite_type(dtype)}"
        for col, dtype in df.dtypes.items()
    )
    quoted_table = quote_identifier(table_name)

    # Connect to the SQLite database (it will create the database file if it
    # doesn't exist)
    conn = sqlite3.connect(db_name)
    try:
        cursor = conn.cursor()

        # Rebuild the table from scratch so the schema declared here is the
        # one that is actually used (to_sql(if_exists='replace') would drop
        # the table again and substitute its own schema).
        cursor.execute(f"DROP TABLE IF EXISTS {quoted_table};")
        cursor.execute(f"CREATE TABLE {quoted_table} ({col_definitions});")
        print(f"Table '{table_name}' created with schema: {col_definitions}")

        # Insert CSV data into the freshly created SQLite table
        df.to_sql(table_name, conn, if_exists='append', index=False)
        conn.commit()
    finally:
        # Always release the connection, even if one of the steps above failed
        conn.close()

    print(f"Data loaded into '{table_name}' table in '{db_name}' SQLite database.")
# Ingest the movies CSV into a local SQLite file. db_name and table_name stay
# at module level because later cells reuse them.
csv_file, db_name, table_name = "movies.csv", "movies_db.db", "movies"
csv_to_sqlite(csv_file, db_name, table_name)
Table 'movies' created with schema: "Movie" TEXT, "LeadStudio" TEXT,
"RottenTomatoes" REAL, "AudienceScore" REAL, "Story" TEXT, "Genre" TEXT,
"TheatersOpenWeek" REAL, "OpeningWeekend" REAL, "BOAvgOpenWeekend" REAL,
"DomesticGross" REAL, "ForeignGross" REAL, "WorldGross" REAL, "Budget" REAL,
"Profitability" REAL, "OpenProfit" REAL, "Year" INTEGER
Data loaded into 'movies' table in 'movies_db.db' SQLite database.
# In [2]:
def run_sql_query(db_name, query):
    """
    Executes a SQL query on a SQLite database and returns the results.

    Nothing is committed, so data-modifying statements are discarded when the
    connection is closed.

    NOTE: the query text is executed as-is. When it comes from an LLM or a
    user it must be treated as untrusted input.

    Args:
        db_name (str): The name of the SQLite database file.
        query (str): The SQL query to run.
    Returns:
        list: Query result as a list of tuples, or an empty list if no results
        or error occurred.
    """
    conn = None
    try:
        # Connect to the SQLite database
        conn = sqlite3.connect(db_name)
        cursor = conn.cursor()
        # Execute the SQL query
        cursor.execute(query)
        # fetchall() already yields [] when there are no rows
        return cursor.fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is raised for multi-statement input on Python
        # versions before 3.12; treat it like any other query failure.
        print(f"An error occurred while executing the query: {e}")
        return []
    finally:
        # Close the connection on every path, including failed queries
        if conn is not None:
            conn.close()
# In [3]:
# Sanity check: count the rows that were loaded into the table.
query = f"SELECT count(*) FROM {table_name};"
results = run_sql_query(db_name, query)
for row in results or []:
    print(row)
(970,)
1.3 Ask Natural Language Questions
[24]: import openai
import faiss
import numpy as np
import os
from openai import OpenAI
from litellm import completion
from IPython.display import Markdown, display
# In [5]:
# Read the API key from the environment and build the OpenAI client with it.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)

# FAISS index over the question embeddings, compared by L2 distance.
dimension = 1536  # Dimension size for OpenAI embeddings (may vary by model)
index = faiss.IndexFlatL2(dimension)

# Each cache entry is a (user_question, sql_query, response) tuple. Entries are
# looked up by the position FAISS returns, so entry i belongs to vector i.
cache = []
# In [6]:
# Helper that turns text into an embedding vector via the OpenAI client.
def get_embeddings(text):
    """
    Embeds a text string with the OpenAI "text-embedding-3-small" model.

    Args:
        text (str): The text string to convert.
    Returns:
        np.array: The embedding of the first (only) item in the API response.
    """
    result = client.embeddings.create(model="text-embedding-3-small", input=text)
    vector = result.data[0].embedding
    return np.array(vector)
4
# In [31]:
def search_cache(question_embedding, threshold=0.1):
    """
    Searches the FAISS index for a previously answered, similar question.

    Args:
        question_embedding (np.array): The embedding of the user's question.
        threshold (float): Maximum distance reported by the index for the
            nearest neighbour to count as a hit (smaller means more similar).
            With IndexFlatL2 the reported value is the squared L2 distance.
    Returns:
        tuple: (sql_query, response) if a hit is found, otherwise None.
    """
    if index.ntotal == 0:
        return None

    # FAISS works on 2-D, C-contiguous float32 matrices; OpenAI embeddings
    # arrive as float64, so convert explicitly instead of relying on the
    # wrapper to do it.
    query = np.ascontiguousarray(question_embedding, dtype=np.float32).reshape(1, -1)
    distances, indices = index.search(query, k=1)

    best_distance = distances[0][0]
    cache_index = int(indices[0][0])

    # FAISS reports -1 when it has no neighbour to return; the upper bound
    # protects against the index and the cache list getting out of sync.
    if best_distance < threshold and 0 <= cache_index < len(cache):
        return cache[cache_index][1], cache[cache_index][2]
    return None
# In [16]:
def get_table_schema(db_name, table_name):
    """
    Retrieves the schema (columns and data types) for a given table in the
    SQLite database.

    Args:
        db_name (str): The name of the SQLite database file.
        table_name (str): The name of the table.
    Returns:
        list: One tuple per column as returned by PRAGMA table_info:
        (cid, name, type, notnull, dflt_value, pk). Empty if the table does
        not exist.
    """
    # PRAGMA arguments cannot be bound as parameters, so quote the identifier:
    # this keeps names with spaces/quotes working and blocks SQL injection.
    quoted_table = '"' + str(table_name).replace('"', '""') + '"'

    conn = sqlite3.connect(db_name)
    try:
        cursor = conn.cursor()
        # Use PRAGMA to get the table schema
        cursor.execute(f"PRAGMA table_info({quoted_table});")
        return cursor.fetchall()
    finally:
        # Release the connection even if the PRAGMA fails
        conn.close()
# Show the schema of the movies table, one PRAGMA table_info row per line.
table_name = 'movies'
schema = get_table_schema(db_name, table_name)

print(f"Schema for {table_name}:")
for col in schema:
    # Each row prints as a tuple, e.g. (0, 'Movie', 'TEXT', 0, None, 0)
    print(col)
Schema for movies:
(0, 'Movie', 'TEXT', 0, None, 0)
(1, 'LeadStudio', 'TEXT', 0, None, 0)
(2, 'RottenTomatoes', 'REAL', 0, None, 0)
(3, 'AudienceScore', 'REAL', 0, None, 0)
(4, 'Story', 'TEXT', 0, None, 0)
(5, 'Genre', 'TEXT', 0, None, 0)
(6, 'TheatersOpenWeek', 'REAL', 0, None, 0)
(7, 'OpeningWeekend', 'REAL', 0, None, 0)
(8, 'BOAvgOpenWeekend', 'REAL', 0, None, 0)
(9, 'DomesticGross', 'REAL', 0, None, 0)
(10, 'ForeignGross', 'REAL', 0, None, 0)
(11, 'WorldGross', 'REAL', 0, None, 0)
(12, 'Budget', 'REAL', 0, None, 0)
(13, 'Profitability', 'REAL', 0, None, 0)
(14, 'OpenProfit', 'REAL', 0, None, 0)
(15, 'Year', 'INTEGER', 0, None, 0)
# In [25]:
def generate_llm_prompt(table_name, table_schema):
    """
    Builds the system prompt that gives an LLM the table schema it needs to
    translate a natural language question into SQL.

    Args:
        table_name (str): The name of the table.
        table_schema (list): A list of tuples describing the columns; element 1
            of each tuple is used as the column name and element 2 as its type.
    Returns:
        str: The generated prompt to be used by the LLM.
    """
    intro = (
        "You are an expert in writing SQL queries for relational databases.\n"
        "You will be provided with a database schema and a natural\n"
        "language question, and your task is to generate an accurate SQL query.\n"
        f"The database has a table named '{table_name}' with the following schema: \n\n"
    )

    # One bullet per column: "- <name> (<type>)"
    column_lines = [f"- {col[1]} ({col[2]})\n" for col in table_schema]

    closing = (
        "\nPlease generate a SQL query based on the following natural "
        "language question. ONLY return the SQL query."
    )
    return intro + "Columns:\n" + "".join(column_lines) + closing
# Preview the system prompt that will be sent to the LLM for the movies table.
table_name = "movies"
schema = get_table_schema(db_name, table_name)

llm_prompt = generate_llm_prompt(table_name, schema)
print(llm_prompt)
You are an expert in writing SQL queries for relational databases.
You will be provided with a database schema and a natural
language question, and your task is to generate an accurate SQL query.
The database has a table named 'movies' with the following schema:
Columns:
- Movie (TEXT)
- LeadStudio (TEXT)
- RottenTomatoes (REAL)
- AudienceScore (REAL)
- Story (TEXT)
- Genre (TEXT)
- TheatersOpenWeek (REAL)
- OpeningWeekend (REAL)
- BOAvgOpenWeekend (REAL)
- DomesticGross (REAL)
- ForeignGross (REAL)
- WorldGross (REAL)
- Budget (REAL)
- Profitability (REAL)
- OpenProfit (REAL)
- Year (INTEGER)
Please generate a SQL query based on the following natural language question.
ONLY return the SQL query.
# In [26]:
def handle_user_question(user_question):
    """
    Handles the user's question by first searching the cache, and if there's no
    hit, generating a SQL query with the LLM, running it and caching the result.

    Args:
        user_question (str): The user's natural language question.
    Returns:
        list: The response to the user's question.
    """
    # Convert the user's question to an embedding
    question_embedding = get_embeddings(user_question)

    # Step 1: Search cache for similar questions
    cache_hit = search_cache(question_embedding)
    if cache_hit is not None:
        sql_query, response = cache_hit
        print(f"Cache hit! SQL Query: {sql_query}")
        return response

    # Step 2: No hit, go to LLM for SQL generation
    print("Cache miss! Generating SQL from LLM...")
    sql_query = generate_sql_query(user_question)

    # Step 3: Run the SQL query on the database
    # NOTE(review): run_sql_query returns [] both for "no rows" and for a
    # failed query, so a failed query is cached like a real answer.
    response = run_sql_query(db_name, sql_query)

    # Step 4: Store question, SQL, and response in cache.
    # Add the vector to FAISS first: if that raises (e.g. wrong dimension) the
    # cache list is left untouched, so cache positions keep matching the
    # positions inside the index. FAISS works on 2-D float32 matrices.
    vector = np.ascontiguousarray(question_embedding, dtype=np.float32).reshape(1, -1)
    index.add(vector)
    cache.append((user_question, sql_query, response))
    return response
# In [27]:
def generate_sql_query(question, table_name='movies', db_name='movies_db.db'):
    """
    Asks the LLM to translate a natural language question into a SQL query.

    The table schema is read from the database and sent as the system prompt;
    the question is sent as the user message. The raw answer is displayed as
    Markdown and the SQL text is extracted from it.

    Args:
        question (str): The user's natural language question.
        table_name (str): The table the query should target.
        db_name (str): The name of the SQLite database file holding the table.
    Returns:
        str: The SQL query with Markdown code fences and surrounding
        whitespace removed (empty string if the model returned no text).
    """
    import re  # local import: only needed for code-fence extraction

    table_schema = get_table_schema(db_name, table_name)
    llm_prompt = generate_llm_prompt(table_name, table_schema)

    response = completion(
        api_key=OPENAI_API_KEY,
        model="gpt-4o-mini",
        messages=[
            # The prompt is already fully rendered; it is passed verbatim
            # because str.format() would fail on a '{' or '}' in a column name.
            {"content": llm_prompt, "role": "system"},
            {"content": f"Question: {question}", "role": "user"},
        ],
        max_tokens=1000,
    )

    # The message content can be None; normalise to an empty string
    answer = response.choices[0].message.content or ""
    display(Markdown(answer))

    # Prefer the body of the first fenced block (any language tag, any case),
    # which also drops explanatory text placed around the fence.
    fenced = re.search(r"```[^\n`]*\n(.*?)```", answer, flags=re.DOTALL)
    if fenced:
        query = fenced.group(1)
    else:
        # No complete fenced block: fall back to stripping fence markers
        query = answer.replace("```sql", "").replace("```", "")
    return query.strip()
8
# In [37]:
# Other questions to try:
#   "total number of movies are made by Warner Bros company in year 2008?"
#   "how many movies have RottenTomatoes scores lower than 85?"
question = "how many movies with action genre are in the database"
handle_user_question(question)
Cache miss! Generating SQL from LLM…
SELECT COUNT(*) AS ActionMovieCount
FROM movies
WHERE Genre = 'Action';
[37]: [(166,)]
[38]: cache
[38]: [('total number of movies are made by Warner Bros company in year 2008?',
"SELECT COUNT(*) \nFROM movies \nWHERE LeadStudio = 'Warner Bros' AND Year =
2008;",
[(21,)]),
('how many movies have RottenTomatoes scores greater than 85?',
'SELECT COUNT(*) \nFROM movies \nWHERE RottenTomatoes > 85;',
[(120,)]),
('how many movies have RottenTomatoes scores lower than 85?',
'SELECT COUNT(*) \nFROM movies \nWHERE RottenTomatoes < 85;',
[(782,)]),
('how many movies with action genre are in the database',
"SELECT COUNT(*) AS ActionMovieCount\nFROM movies\nWHERE Genre = 'Action';",
[(166,)])]
[ ]:
[ ]: