Added SQL RAG example
This commit is contained in:
parent
8d74eac34d
commit
5303d7829d
195
02_oms_sql_pipeline.py
Normal file
195
02_oms_sql_pipeline.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
from typing import List, Union, Generator, Iterator, Dict
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
class Pipeline:
|
||||
class Valves(BaseModel):
|
||||
OR_MODEL: str
|
||||
OR_URL: str
|
||||
OR_KEY: str
|
||||
|
||||
DB_HOST: str
|
||||
DB_PORT: str
|
||||
DB_USER: str
|
||||
DB_PASSWORD: str
|
||||
DB_DATABASE: str
|
||||
DB_TABLES: List[str]
|
||||
|
||||
def __init__(self):
|
||||
self.name = "ØMS Membership Database"
|
||||
self.engine = None
|
||||
self.nlsql_response = ""
|
||||
self.valves = self.Valves(
|
||||
**{
|
||||
"pipelines": ["*"],
|
||||
"OR_MODEL": os.getenv("OR_MODEL", "anthropic/claude-3.5-haiku:beta"),
|
||||
"OR_URL": os.getenv("OR_URL", "https://openrouter.ai/api/v1/chat/completions"),
|
||||
"OR_KEY": os.getenv("OR_KEY", "OPENROUTER_API_KEY"),
|
||||
"DB_HOST": os.getenv("DB_HOST", "elrond.outlands.lan"),
|
||||
"DB_PORT": os.getenv("DB_PORT", "3306"),
|
||||
"DB_USER": os.getenv("DB_USER", "polarpress_demo_dba"),
|
||||
"DB_PASSWORD": os.getenv("DB_PASSWORD", "YOUR_PASSWORD"),
|
||||
"DB_DATABASE": os.getenv("DB_DATABASE", "pp_polarpress_demo_prod"),
|
||||
"DB_TABLES": ["users", "club_memberships", "stripe_transactions", "vipps_transactions"],
|
||||
}
|
||||
)
|
||||
|
||||
def init_db(self):
|
||||
try:
|
||||
self.engine = create_engine(f"mysql+mysqldb://{self.valves.DB_USER}:{self.valves.DB_PASSWORD}@{self.valves.DB_HOST}:{self.valves.DB_PORT}/{self.valves.DB_DATABASE}")
|
||||
print(f"Connection to MariaDB database {self.valves.DB_DATABASE} on host {self.valves.DB_HOST} established")
|
||||
except Exception as e:
|
||||
print(f"Error connecting to MariaDB: {e}")
|
||||
|
||||
return self.engine
|
||||
|
||||
async def on_startup(self):
|
||||
self.init_db()
|
||||
|
||||
async def on_shutdown(self):
|
||||
pass
|
||||
|
||||
def run_llm_query(self, message: str):
|
||||
try:
|
||||
response = requests.post(
|
||||
url = self.valves.OR_URL,
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.valves.OR_KEY}"
|
||||
},
|
||||
|
||||
data = json.dumps({
|
||||
"model": self.valves.OR_MODEL,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": message
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
sql_statement = response_data.get("choices", [{}])[0].get("message", {}).get("content")
|
||||
|
||||
if sql_statement:
|
||||
return {"success": True, "data": sql_statement}
|
||||
else:
|
||||
logging.error("Response did not contain SQL statement.")
|
||||
return {"success": False, "data": "No SQL statement in response."}
|
||||
else:
|
||||
logging.error(f"Error response {response.status_code}: {response.text}")
|
||||
return {"success": False, "data": f"Error: {response.status_code}"}
|
||||
except requests.HTTPError as e:
|
||||
logging.error(f"Clientresponse error: {e}")
|
||||
return {"success": False, "data": "HTTP backend error"}
|
||||
except Exception as e:
|
||||
logging.error(f"Unexpected error: {e}")
|
||||
return {"success": False, "data": f"Unexpected error: {e}"}
|
||||
|
||||
def run_mysql_query(self, query: str):
|
||||
try:
|
||||
with self.engine.connect() as connection:
|
||||
result = connection.execute(text(query))
|
||||
rows = result.fetchall()
|
||||
return str(rows)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def reformat_data(self, message: str, query: str, result: str):
|
||||
llm_reformat_instructions = f"""
|
||||
Given an input question, create a syntactically correct mysql query to run. You have 4 tables to work with:
|
||||
1. users - the users table holds the user records and contain the columns id, name, email, created_at, updated_at
|
||||
2. club_memberships - the user membership table containing the users club memberships and contain the columns id,
|
||||
user_id, valid_from, valid_to, renew, cancelled, credited, price, payment_method, created_at, updated_at. The
|
||||
valid_from and valid_to columns in the club_memberships table is a varchar with a date in the format 'dd.mm.YYYY'
|
||||
and contains a timespan (from and to) when the membership is valid. The columns cancelled and credited are boolean
|
||||
columns and if any of these are true, the membership is not valid. Cancelled means that their membership is
|
||||
cancelled and credited means that the membership is refunded (and cancelled).
|
||||
3. stripe_transactions - all the transactions of the users that has been processed by stripe. The table contains
|
||||
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
|
||||
processed is a boolean value and if this is set to false, the stripe_transaction is not valid (the user has not
|
||||
been charged.) Amount is a double (8,2) containing the amount the user has been charged, phone_number contains
|
||||
the users phone number. If the column event_signup_id is not null, the transaction does not pertain to a membership
|
||||
payment
|
||||
4. vipps_transactions - all the transactions of the users that has been processed by vipps. The table contains
|
||||
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
|
||||
are identical to stripe_transaction columns.
|
||||
|
||||
Always run queries with INNER JOIN on users.id and the other tables user_id in order to get which user that
|
||||
belongs to the other tables. You should use DISTINCT statements to avoid returning duplicates wherever possible.
|
||||
Pay attention to use only the column names that I have provided. Currency for the amount columns will always be
|
||||
in NORWEGIAN (NOK). The input question can be in english or norwegian. Always reply in norwegian if the input
|
||||
question is in norwegian or english if the input question is in english.
|
||||
|
||||
Input question: {message}
|
||||
MySQLQuery: {query}
|
||||
Query result: {result}
|
||||
|
||||
Excellent. We now have the result from the query. As you can see, the "Query result:" has been formatted in a Python
|
||||
list of tuples. Take this string and reformat it in markdown. Please translate all the column names to Norwegian.
|
||||
"""
|
||||
|
||||
return llm_reformat_instructions
|
||||
|
||||
def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Union[str, Generator, Iterator]:
|
||||
llm_initial_instructions = f"""
|
||||
Given an input question, create a syntactically correct mysql query to run. You have 4 tables to work with:
|
||||
1. users - the users table holds the user records and contain the columns id, name, email, created_at, updated_at
|
||||
2. club_memberships - the user membership table containing the users club memberships and contain the columns id,
|
||||
user_id, valid_from, valid_to, renew, cancelled, credited, price, payment_method, created_at, updated_at. The
|
||||
valid_from and valid_to columns in the club_memberships table is a varchar with a date in the format 'dd.mm.YYYY'
|
||||
and contains a timespan (from and to) when the membership is valid. The columns cancelled and credited are boolean
|
||||
columns and if any of these are true, the membership is not valid. Cancelled means that their membership is
|
||||
cancelled and credited means that the membership is refunded (and cancelled).
|
||||
3. stripe_transactions - all the transactions of the users that has been processed by stripe. The table contains
|
||||
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
|
||||
processed is a boolean value and if this is set to false, the stripe_transaction is not valid (the user has not
|
||||
been charged.) Amount is a double (8,2) containing the amount the user has been charged, phone_number contains
|
||||
the users phone number. If the column event_signup_id is not null, the transaction does not pertain to a membership
|
||||
payment
|
||||
4. vipps_transactions - all the transactions of the users that has been processed by vipps. The table contains
|
||||
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
|
||||
are identical to stripe_transaction columns.
|
||||
|
||||
Always run queries with INNER JOIN on users.id and the other tables user_id in order to get which user that
|
||||
belongs to the other tables. You should use DISTINCT statements to avoid returning duplicates wherever possible.
|
||||
Pay attention to use only the column names that I have provided. Currency for the amount columns will always be
|
||||
in NORWEGIAN (NOK). The input question can be in english or norwegian. Always reply in norwegian if the input
|
||||
question is in norwegian or english if the input question is in english.
|
||||
|
||||
IMPORTANT: ONLY respond with a syntactically correct mysql query. DO NOT respond with a description of what
|
||||
you have done. ONLY reply with the SQL query. DO NOT format the query with markup. ONLY provide a string
|
||||
containing the SQL query you assemble from the Input question. DO NOT translate any of the column names
|
||||
in the table for the query.
|
||||
|
||||
Input question: {user_message}
|
||||
MySQLQuery:
|
||||
"""
|
||||
|
||||
initial = self.run_llm_query(llm_initial_instructions)
|
||||
if initial["success"]:
|
||||
query = initial["data"]
|
||||
query_result = self.run_mysql_query(query)
|
||||
|
||||
if isinstance(query_result, dict) and "error" in query_result:
|
||||
return f"Error occurred: {query_result['error']}. Initial data: {initial['data']}"
|
||||
|
||||
formatted = self.reformat_data(user_message, query, query_result)
|
||||
formatted_result = self.run_llm_query(formatted)
|
||||
|
||||
data = formatted_result["data"]
|
||||
if formatted_result["success"]:
|
||||
return data
|
||||
|
||||
return f"Error occured: {data}"
|
||||
else:
|
||||
data = initial["data"]
|
||||
return f"Error occured: {data}"
|
||||
Loading…
Reference in New Issue
Block a user