Added SQL RAG example

This commit is contained in:
Helge-Mikael Nordgård 2025-01-24 01:04:19 +01:00
parent 8d74eac34d
commit 5303d7829d

195
02_oms_sql_pipeline.py Normal file
View File

@ -0,0 +1,195 @@
from typing import List, Union, Generator, Iterator, Dict
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy import text
import logging
import os
import requests
import json
logging.basicConfig(level=logging.DEBUG)
class Pipeline:
class Valves(BaseModel):
OR_MODEL: str
OR_URL: str
OR_KEY: str
DB_HOST: str
DB_PORT: str
DB_USER: str
DB_PASSWORD: str
DB_DATABASE: str
DB_TABLES: List[str]
def __init__(self):
self.name = "ØMS Membership Database"
self.engine = None
self.nlsql_response = ""
self.valves = self.Valves(
**{
"pipelines": ["*"],
"OR_MODEL": os.getenv("OR_MODEL", "anthropic/claude-3.5-haiku:beta"),
"OR_URL": os.getenv("OR_URL", "https://openrouter.ai/api/v1/chat/completions"),
"OR_KEY": os.getenv("OR_KEY", "OPENROUTER_API_KEY"),
"DB_HOST": os.getenv("DB_HOST", "elrond.outlands.lan"),
"DB_PORT": os.getenv("DB_PORT", "3306"),
"DB_USER": os.getenv("DB_USER", "polarpress_demo_dba"),
"DB_PASSWORD": os.getenv("DB_PASSWORD", "YOUR_PASSWORD"),
"DB_DATABASE": os.getenv("DB_DATABASE", "pp_polarpress_demo_prod"),
"DB_TABLES": ["users", "club_memberships", "stripe_transactions", "vipps_transactions"],
}
)
def init_db(self):
try:
self.engine = create_engine(f"mysql+mysqldb://{self.valves.DB_USER}:{self.valves.DB_PASSWORD}@{self.valves.DB_HOST}:{self.valves.DB_PORT}/{self.valves.DB_DATABASE}")
print(f"Connection to MariaDB database {self.valves.DB_DATABASE} on host {self.valves.DB_HOST} established")
except Exception as e:
print(f"Error connecting to MariaDB: {e}")
return self.engine
async def on_startup(self):
self.init_db()
async def on_shutdown(self):
pass
def run_llm_query(self, message: str):
try:
response = requests.post(
url = self.valves.OR_URL,
headers = {
"Authorization": f"Bearer {self.valves.OR_KEY}"
},
data = json.dumps({
"model": self.valves.OR_MODEL,
"messages": [
{
"role": "user",
"content": message
}
]
})
)
if response.status_code == 200:
response_data = response.json()
sql_statement = response_data.get("choices", [{}])[0].get("message", {}).get("content")
if sql_statement:
return {"success": True, "data": sql_statement}
else:
logging.error("Response did not contain SQL statement.")
return {"success": False, "data": "No SQL statement in response."}
else:
logging.error(f"Error response {response.status_code}: {response.text}")
return {"success": False, "data": f"Error: {response.status_code}"}
except requests.HTTPError as e:
logging.error(f"Clientresponse error: {e}")
return {"success": False, "data": "HTTP backend error"}
except Exception as e:
logging.error(f"Unexpected error: {e}")
return {"success": False, "data": f"Unexpected error: {e}"}
def run_mysql_query(self, query: str):
try:
with self.engine.connect() as connection:
result = connection.execute(text(query))
rows = result.fetchall()
return str(rows)
except Exception as e:
return {"error": str(e)}
def reformat_data(self, message: str, query: str, result: str):
llm_reformat_instructions = f"""
Given an input question, create a syntactically correct mysql query to run. You have 4 tables to work with:
1. users - the users table holds the user records and contain the columns id, name, email, created_at, updated_at
2. club_memberships - the user membership table containing the users club memberships and contain the columns id,
user_id, valid_from, valid_to, renew, cancelled, credited, price, payment_method, created_at, updated_at. The
valid_from and valid_to columns in the club_memberships table is a varchar with a date in the format 'dd.mm.YYYY'
and contains a timespan (from and to) when the membership is valid. The columns cancelled and credited are boolean
columns and if any of these are true, the membership is not valid. Cancelled means that their membership is
cancelled and credited means that the membership is refunded (and cancelled).
3. stripe_transactions - all the transactions of the users that has been processed by stripe. The table contains
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
processed is a boolean value and if this is set to false, the stripe_transaction is not valid (the user has not
been charged.) Amount is a double (8,2) containing the amount the user has been charged, phone_number contains
the users phone number. If the column event_signup_id is not null, the transaction does not pertain to a membership
payment
4. vipps_transactions - all the transactions of the users that has been processed by vipps. The table contains
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
are identical to stripe_transaction columns.
Always run queries with INNER JOIN on users.id and the other tables user_id in order to get which user that
belongs to the other tables. You should use DISTINCT statements to avoid returning duplicates wherever possible.
Pay attention to use only the column names that I have provided. Currency for the amount columns will always be
in NORWEGIAN (NOK). The input question can be in english or norwegian. Always reply in norwegian if the input
question is in norwegian or english if the input question is in english.
Input question: {message}
MySQLQuery: {query}
Query result: {result}
Excellent. We now have the result from the query. As you can see, the "Query result:" has been formatted in a Python
list of tuples. Take this string and reformat it in markdown. Please translate all the column names to Norwegian.
"""
return llm_reformat_instructions
def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Union[str, Generator, Iterator]:
llm_initial_instructions = f"""
Given an input question, create a syntactically correct mysql query to run. You have 4 tables to work with:
1. users - the users table holds the user records and contain the columns id, name, email, created_at, updated_at
2. club_memberships - the user membership table containing the users club memberships and contain the columns id,
user_id, valid_from, valid_to, renew, cancelled, credited, price, payment_method, created_at, updated_at. The
valid_from and valid_to columns in the club_memberships table is a varchar with a date in the format 'dd.mm.YYYY'
and contains a timespan (from and to) when the membership is valid. The columns cancelled and credited are boolean
columns and if any of these are true, the membership is not valid. Cancelled means that their membership is
cancelled and credited means that the membership is refunded (and cancelled).
3. stripe_transactions - all the transactions of the users that has been processed by stripe. The table contains
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
processed is a boolean value and if this is set to false, the stripe_transaction is not valid (the user has not
been charged.) Amount is a double (8,2) containing the amount the user has been charged, phone_number contains
the users phone number. If the column event_signup_id is not null, the transaction does not pertain to a membership
payment
4. vipps_transactions - all the transactions of the users that has been processed by vipps. The table contains
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
are identical to stripe_transaction columns.
Always run queries with INNER JOIN on users.id and the other tables user_id in order to get which user that
belongs to the other tables. You should use DISTINCT statements to avoid returning duplicates wherever possible.
Pay attention to use only the column names that I have provided. Currency for the amount columns will always be
in NORWEGIAN (NOK). The input question can be in english or norwegian. Always reply in norwegian if the input
question is in norwegian or english if the input question is in english.
IMPORTANT: ONLY respond with a syntactically correct mysql query. DO NOT respond with a description of what
you have done. ONLY reply with the SQL query. DO NOT format the query with markup. ONLY provide a string
containing the SQL query you assemble from the Input question. DO NOT translate any of the column names
in the table for the query.
Input question: {user_message}
MySQLQuery:
"""
initial = self.run_llm_query(llm_initial_instructions)
if initial["success"]:
query = initial["data"]
query_result = self.run_mysql_query(query)
if isinstance(query_result, dict) and "error" in query_result:
return f"Error occurred: {query_result['error']}. Initial data: {initial['data']}"
formatted = self.reformat_data(user_message, query, query_result)
formatted_result = self.run_llm_query(formatted)
data = formatted_result["data"]
if formatted_result["success"]:
return data
return f"Error occured: {data}"
else:
data = initial["data"]
return f"Error occured: {data}"