195 lines
10 KiB
Python
195 lines
10 KiB
Python
import json
import logging
import os
import re
from typing import List, Union, Generator, Iterator, Dict
from urllib.parse import quote

from pydantic import BaseModel
import requests
from sqlalchemy import create_engine
from sqlalchemy import text
|
|
|
|
# Process-wide logging setup performed at import time: the root logger's
# threshold is lowered to DEBUG (basicConfig accepts the level by name and
# does nothing if the root logger already has handlers).
logging.basicConfig(level="DEBUG")
|
|
|
|
# Schema description and querying/reply-language rules shared by both LLM
# prompts (SQL generation in Pipeline.pipe and result formatting in
# Pipeline.reformat_data).
_SCHEMA_PROMPT = """Given an input question, create a syntactically correct mysql query to run. You have 4 tables to work with:
1. users - the users table holds the user records and contain the columns id, name, email, created_at, updated_at
2. club_memberships - the user membership table containing the users club memberships and contain the columns id,
user_id, valid_from, valid_to, renew, cancelled, credited, price, payment_method, created_at, updated_at. The
valid_from and valid_to columns in the club_memberships table is a varchar with a date in the format 'dd.mm.YYYY'
and contains a timespan (from and to) when the membership is valid. The columns cancelled and credited are boolean
columns and if any of these are true, the membership is not valid. Cancelled means that their membership is
cancelled and credited means that the membership is refunded (and cancelled).
3. stripe_transactions - all the transactions of the users that has been processed by stripe. The table contains
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
processed is a boolean value and if this is set to false, the stripe_transaction is not valid (the user has not
been charged.) Amount is a double (8,2) containing the amount the user has been charged, phone_number contains
the users phone number. If the column event_signup_id is not null, the transaction does not pertain to a membership
payment
4. vipps_transactions - all the transactions of the users that has been processed by vipps. The table contains
the columns id, user_id, event_signup_id, processed, amount, phone_number, created_at and updated_at. The columns
are identical to stripe_transaction columns.

Always run queries with INNER JOIN on users.id and the other tables user_id in order to get which user that
belongs to the other tables. You should use DISTINCT statements to avoid returning duplicates wherever possible.
Pay attention to use only the column names that I have provided. Currency for the amount columns will always be
in NORWEGIAN (NOK). The input question can be in english or norwegian. Always reply in norwegian if the input
question is in norwegian or english if the input question is in english."""

# Extra instructions for the SQL-generation prompt only.
_SQL_ONLY_RULES = """IMPORTANT: ONLY respond with a syntactically correct mysql query. DO NOT respond with a description of what
you have done. ONLY reply with the SQL query. DO NOT format the query with markup. ONLY provide a string
containing the SQL query you assemble from the Input question. DO NOT translate any of the column names
in the table for the query."""

# Extra instructions for the result-formatting prompt only.
_REFORMAT_RULES = """Excellent. We now have the result from the query. As you can see, the "Query result:" has been formatted in a Python
list of tuples. Take this string and reformat it in markdown. Please translate all the column names to Norwegian."""

# First keywords accepted for LLM-generated SQL (read-only statements only).
_READ_ONLY_KEYWORDS = frozenset({"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"})

# Data-modifying / file-writing keywords that must never appear anywhere in the
# generated SQL (word-bounded, so columns like `updated_at` do not match).
_FORBIDDEN_SQL = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|RENAME|LOAD|CALL|OUTFILE|DUMPFILE)\b",
    re.IGNORECASE,
)

# A reply wrapped in a markdown code fence, e.g. "```sql\nSELECT ...\n```".
_FENCED_BLOCK = re.compile(r"```[A-Za-z]*[ \t]*\n(.*?)\s*```", re.DOTALL)


class Pipeline:
    """Answer natural-language questions about the ØMS membership database.

    ``pipe`` asks an LLM (an OpenRouter-style chat-completions endpoint, see
    ``Valves.OR_URL``) to turn the question into a MySQL query, runs the query
    against MariaDB through SQLAlchemy, then asks the LLM to render the rows
    as markdown.
    """

    # (connect, read) timeout in seconds for LLM HTTP calls. requests never
    # times out unless told to, which could hang the pipeline indefinitely.
    REQUEST_TIMEOUT = (10, 120)

    class Valves(BaseModel):
        """Runtime configuration; populated from environment variables in ``__init__``."""

        # LLM endpoint: model name, chat-completions URL and bearer API key.
        OR_MODEL: str
        OR_URL: str
        OR_KEY: str

        # MariaDB/MySQL connection settings.
        DB_HOST: str
        DB_PORT: str
        DB_USER: str
        DB_PASSWORD: str
        DB_DATABASE: str
        # Table names; not currently read anywhere in this class.
        DB_TABLES: List[str]

    def __init__(self):
        self.name = "ØMS Membership Database"
        # SQLAlchemy engine; created by init_db() (at startup or lazily on the
        # first query) and stays None if creation fails.
        self.engine = None
        # Not read anywhere in this class; kept for compatibility.
        self.nlsql_response = ""
        self.valves = self.Valves(
            **{
                # Not a declared Valves field, so pydantic's default config
                # ignores it — presumably meant for the pipelines host; confirm.
                "pipelines": ["*"],
                "OR_MODEL": os.getenv("OR_MODEL", "anthropic/claude-3.5-haiku:beta"),
                "OR_URL": os.getenv("OR_URL", "https://openrouter.ai/api/v1/chat/completions"),
                "OR_KEY": os.getenv("OR_KEY", "OPENROUTER_API_KEY"),
                "DB_HOST": os.getenv("DB_HOST", "elrond.outlands.lan"),
                "DB_PORT": os.getenv("DB_PORT", "3306"),
                "DB_USER": os.getenv("DB_USER", "polarpress_demo_dba"),
                "DB_PASSWORD": os.getenv("DB_PASSWORD", "YOUR_PASSWORD"),
                "DB_DATABASE": os.getenv("DB_DATABASE", "pp_polarpress_demo_prod"),
                "DB_TABLES": ["users", "club_memberships", "stripe_transactions", "vipps_transactions"],
            }
        )

    def init_db(self):
        """Create the SQLAlchemy engine for the configured MariaDB/MySQL database.

        ``create_engine`` is lazy: no connection is opened here, so connection
        problems surface on the first query instead.

        Returns:
            The engine, or the previous value of ``self.engine`` (``None`` on a
            first call) if engine creation failed. Errors are logged, not raised.
        """
        v = self.valves
        try:
            # Percent-encode credentials so characters such as '@', ':' or '/'
            # in the user name or password cannot corrupt the URL.
            url = (
                f"mysql+mysqldb://{quote(v.DB_USER, safe='')}:{quote(v.DB_PASSWORD, safe='')}"
                f"@{v.DB_HOST}:{v.DB_PORT}/{v.DB_DATABASE}"
            )
            # pool_pre_ping revalidates pooled connections before use, so a
            # long-running pipeline survives the server dropping idle ones.
            self.engine = create_engine(url, pool_pre_ping=True)
            logging.info(f"Engine for MariaDB database {v.DB_DATABASE} on host {v.DB_HOST} created")
        except Exception as e:
            logging.error(f"Error connecting to MariaDB: {e}")

        return self.engine

    async def on_startup(self):
        """Create the database engine (presumably awaited by the host at startup)."""
        self.init_db()

    async def on_shutdown(self):
        """Release pooled database connections, if an engine was created."""
        if self.engine is not None:
            self.engine.dispose()
            self.engine = None

    def run_llm_query(self, message: str):
        """Send ``message`` as a single user turn to the configured LLM.

        Args:
            message: The full prompt text.

        Returns:
            ``{"success": True, "data": <assistant reply text>}`` on success,
            otherwise ``{"success": False, "data": <short error description>}``.
            Errors are logged and reported in the dict rather than raised.
        """
        try:
            response = requests.post(
                url=self.valves.OR_URL,
                headers={"Authorization": f"Bearer {self.valves.OR_KEY}"},
                # json= serialises the payload and sets Content-Type: application/json.
                json={
                    "model": self.valves.OR_MODEL,
                    "messages": [{"role": "user", "content": message}],
                },
                timeout=self.REQUEST_TIMEOUT,
            )
            if response.status_code != 200:
                logging.error(f"Error response {response.status_code}: {response.text}")
                return {"success": False, "data": f"Error: {response.status_code}"}

            # `or` also covers an empty/null "choices" list, which would
            # otherwise raise IndexError.
            choices = response.json().get("choices") or [{}]
            content = (choices[0].get("message") or {}).get("content")
        except requests.RequestException as e:
            # Connection errors and timeouts (and, on requests >= 2.27,
            # bodies that are not valid JSON).
            logging.error(f"Request to LLM backend failed: {e}")
            return {"success": False, "data": "HTTP backend error"}
        except Exception as e:
            logging.error(f"Unexpected error: {e}")
            return {"success": False, "data": f"Unexpected error: {e}"}

        if not content:
            logging.error("Response did not contain any message content.")
            return {"success": False, "data": "No content in LLM response."}
        return {"success": True, "data": content}

    @staticmethod
    def _strip_code_fences(reply: str) -> str:
        """Return ``reply`` stripped of whitespace and any surrounding ``` markdown fence."""
        stripped = reply.strip()
        match = _FENCED_BLOCK.fullmatch(stripped)
        return match.group(1) if match else stripped

    @staticmethod
    def _is_read_only_query(query: str) -> bool:
        """Heuristically check that ``query`` is a single read-only statement.

        Defence in depth only: the SQL is produced by an LLM from user text,
        and the configured account (default ``polarpress_demo_dba``) may have
        write privileges. A read-only database user is the real safeguard.
        """
        statement = query.strip().rstrip(";").rstrip()
        if not statement or ";" in statement:
            # Empty, or several stacked statements.
            return False
        first_keyword = statement.lstrip("(").split(None, 1)[0].upper()
        if first_keyword not in _READ_ONLY_KEYWORDS:
            return False
        return _FORBIDDEN_SQL.search(statement) is None

    def run_mysql_query(self, query: str):
        """Execute one read-only SQL statement and return its rows.

        Returns:
            ``str(rows)`` (the string form of the fetched row list) on success,
            or ``{"error": <message>}`` if the query is rejected, the engine
            cannot be created, or execution fails.
        """
        if not self._is_read_only_query(query):
            return {"error": "Only a single read-only SELECT statement is allowed."}
        # Create the engine lazily if on_startup was not run or failed.
        if self.engine is None and self.init_db() is None:
            return {"error": "Database engine is not initialised."}
        try:
            with self.engine.connect() as connection:
                rows = connection.execute(text(query)).fetchall()
            return str(rows)
        except Exception as e:
            return {"error": str(e)}

    def _sql_prompt(self, user_message: str) -> str:
        """Build the prompt asking the LLM for a bare MySQL query answering ``user_message``."""
        return (
            f"{_SCHEMA_PROMPT}\n\n"
            f"{_SQL_ONLY_RULES}\n\n"
            f"Input question: {user_message}\n"
            "MySQLQuery:\n"
        )

    def reformat_data(self, message: str, query: str, result: str):
        """Build the prompt asking the LLM to render a query result as markdown.

        Args:
            message: The user's original question.
            query: The SQL that was executed.
            result: The stringified rows returned by ``run_mysql_query``.

        Returns:
            The prompt text; no LLM call is made here.
        """
        return (
            f"{_SCHEMA_PROMPT}\n\n"
            f"Input question: {message}\n"
            f"MySQLQuery: {query}\n"
            f"Query result: {result}\n\n"
            f"{_REFORMAT_RULES}\n"
        )

    def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Union[str, Generator, Iterator]:
        """Answer ``user_message`` by generating SQL, running it and formatting the rows.

        ``model_id``, ``messages`` and ``body`` belong to the pipeline
        signature but are not used here.

        Returns:
            The LLM's markdown rendering of the query result, or an
            ``"Error occurred: ..."`` string describing the step that failed.
        """
        initial = self.run_llm_query(self._sql_prompt(user_message))
        if not initial["success"]:
            return f"Error occurred: {initial['data']}"

        # LLMs sometimes wrap the SQL in a markdown fence despite instructions.
        query = self._strip_code_fences(initial["data"])
        query_result = self.run_mysql_query(query)
        if isinstance(query_result, dict) and "error" in query_result:
            return f"Error occurred: {query_result['error']}. Initial data: {initial['data']}"

        formatted_result = self.run_llm_query(self.reformat_data(user_message, query, query_result))
        if formatted_result["success"]:
            return formatted_result["data"]
        return f"Error occurred: {formatted_result['data']}"