nl2sql refactoring (#194)
* change insert to be sync

* add nl2sql

* nl2sql setting

* nl2sql setting

* fix test bug

* fix bugs

* data analysis retriever and synthesizer

* fix test bugs

* add data_analysis ui

* update poetry.lock

* remove unnecessary comment

* add fault tolerance if no file provided

* add minor fault tolerance

* add upload_datasheet

* nl2sql refactor and add db ui

* restore retriever & synthesizer

* update poetry.lock

* Fix list merge

* bug fix

* add default display

---------

Co-authored-by: 陆逊 <[email protected]>
aero-xi and moria97 authored Sep 4, 2024
1 parent 08c15e0 commit 90c4301
Showing 29 changed files with 3,092 additions and 22 deletions.
44 changes: 30 additions & 14 deletions poetry.lock


2 changes: 2 additions & 0 deletions pyproject.toml
@@ -84,6 +84,8 @@ pgvector = "^0.3.2"
pre-commit = "^3.8.0"
cn-clip = "^1.5.1"
llama-index-llms-paieas = "^0.1.0"
pymysql = "^1.1.1"
llama-index-experimental = "^0.2.0"
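# Assumed rationale, not stated in the diff: pymysql supplies the MySQL driver
# for the nl2sql connection, and llama-index-experimental provides the pandas
# query engine used by the data-analysis flow.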
llama-index-readers-web = "^0.1.23"
rapidocr-onnxruntime = "^1.3.24"
rapid-table = "^0.1.3"
2 changes: 2 additions & 0 deletions pyproject_gpu.toml
@@ -78,6 +78,8 @@ pgvector = "^0.3.2"
pre-commit = "^3.8.0"
cn-clip = "^1.5.1"
llama-index-llms-paieas = "^0.1.0"
pymysql = "^1.1.1"
llama-index-experimental = "^0.2.0"
llama-index-readers-web = "^0.1.23"
rapidocr-onnxruntime = "^1.3.24"
rapid-table = "^0.1.3"
65 changes: 65 additions & 0 deletions src/pai_rag/app/api/query.py
@@ -4,13 +4,18 @@
import hashlib
import os
import tempfile
import shutil
import pandas as pd
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api.models import (
RagQuery,
RetrievalQuery,
LlmResponse,
)
from fastapi.responses import StreamingResponse, JSONResponse
import logging

logger = logging.getLogger(__name__)

router = APIRouter()

@@ -180,3 +185,63 @@ async def upload_oss_data(
)

return {"task_id": task_id}


@router.post("/upload_datasheet")
async def upload_datasheet(
file: UploadFile,
):
task_id = uuid.uuid4().hex
if not file:
return None

persist_path = "./localdata/data_analysis"

os.makedirs(name=persist_path, exist_ok=True)

    # Clear any files left over in the target directory
for filename in os.listdir(persist_path):
file_path = os.path.join(persist_path, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
except Exception as e:
            logger.warning(f"Failed to delete {file_path}. Reason: {e}")

    # Build the destination path for persistent storage
    file_name = os.path.basename(file.filename)  # keep only the base file name
    destination_path = os.path.join(persist_path, file_name)
    # Write the uploaded file to disk
    try:
with open(destination_path, "wb") as f:
shutil.copyfileobj(file.file, f)
logger.info("data analysis file saved successfully")

if destination_path.endswith(".csv"):
df = pd.read_csv(destination_path)
elif destination_path.endswith(".xlsx"):
df = pd.read_excel(destination_path)
else:
raise TypeError("Unsupported file type.")

except Exception as e:
        return JSONResponse(status_code=500, content={"message": str(e)})

return {
"task_id": task_id,
"destination_path": destination_path,
"data_preview": df.head(10).to_json(orient="records", lines=False),
}
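
For reference, a minimal client-side sketch of exercising this endpoint. The host, port, and the sales.csv file name are assumptions, not part of this change; the service/ prefix matches load_datasheet_url in the web client below.

import requests

# Hypothetical deployment address; adjust to wherever the PAI-RAG service runs.
url = "http://localhost:8000/service/upload_datasheet"

# Upload a local CSV as multipart form data; the server persists it under
# ./localdata/data_analysis and returns a task_id plus a 10-row JSON preview.
with open("sales.csv", "rb") as f:
    r = requests.post(url, files={"file": ("sales.csv", f, "text/csv")})
r.raise_for_status()
print(r.json()["data_preview"])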


@router.post("/query/data_analysis")
async def aquery_analysis(query: RagQuery):
response = await rag_service.aquery_analysis(query)
if not query.stream:
return response
else:
return StreamingResponse(
response,
media_type="text/event-stream",
)
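
Likewise, a hedged sketch of calling the new route directly over HTTP. The payload fields mirror the dict that query_data_analysis sends below; the host and the question are assumptions.

import json
import requests

# Hypothetical deployment address.
url = "http://localhost:8000/service/query/data_analysis"

# Streaming request: the endpoint answers with text/event-stream chunks,
# each a JSON object carrying a "delta" field (see the client code below).
payload = {"question": "What is the average value per column?", "stream": True}
with requests.post(url, json=payload, stream=True) as r:
    r.raise_for_status()
    for line in r.iter_lines(decode_unicode=True):
        if line:
            print(json.loads(line).get("delta", ""), end="")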
104 changes: 104 additions & 0 deletions src/pai_rag/app/web/rag_client.py
@@ -43,6 +43,10 @@ def query_url(self):
def search_url(self):
return f"{self.endpoint}service/query/search"

@property
def data_analysis_url(self):
return f"{self.endpoint}service/query/data_analysis"

@property
def llm_url(self):
return f"{self.endpoint}service/query/llm"
@@ -59,6 +63,10 @@ def config_url(self):
def load_data_url(self):
return f"{self.endpoint}service/upload_data"

@property
def load_datasheet_url(self):
return f"{self.endpoint}service/upload_datasheet"

@property
def load_agent_cfg_url(self):
return f"{self.endpoint}service/config/agent"
@@ -126,6 +134,8 @@ def _format_rag_response(
</span>
<br>
"""
else:
content = ""
content_list.append(content)
referenced_docs = "".join(content_list)

@@ -193,6 +203,35 @@ def query_search(
text, chunk_response, session_id=session_id, stream=stream
)

def query_data_analysis(
    self,
    text: str,
    session_id: str = None,
    stream: bool = False,
):
    q = dict(
        question=text,
        session_id=session_id,
        stream=stream,
    )
    r = requests.post(self.data_analysis_url, json=q, stream=True)
    if r.status_code != HTTPStatus.OK:
        raise RagApiError(code=r.status_code, msg=r.text)
    if not stream:
        response = dotdict(json.loads(r.text))
        yield self._format_rag_response(
            text, response, session_id=session_id, stream=stream
        )
    else:
        full_content = ""
        for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
            chunk_response = dotdict(json.loads(chunk))
            full_content += chunk_response.delta
            chunk_response.delta = full_content
            yield self._format_rag_response(
                text, chunk_response, session_id=session_id, stream=stream
            )
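
Note the accumulation pattern in the streaming branch: each chunk's delta is appended to full_content and the running total is written back into chunk_response.delta, so every yielded response carries the full answer so far and the UI can simply re-render it.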

def query_llm(
self,
text: str,
@@ -298,6 +337,30 @@ def add_knowledge(
response = dotdict(json.loads(r.text))
return response

def add_datasheet(
    self,
    input_file: str,
):
    file_obj = open(input_file, "rb")
    mimetype = mimetypes.guess_type(input_file)[0]
    files = {"file": (input_file, file_obj, mimetype)}
    try:
        r = requests.post(
            self.load_datasheet_url,
            files=files,
            timeout=DEFAULT_CLIENT_TIME_OUT,
        )
        response = dotdict(json.loads(r.text))
        if r.status_code != HTTPStatus.OK:
            raise RagApiError(code=r.status_code, msg=response.message)
    finally:
        file_obj.close()

    return response

async def get_knowledge_state(self, task_id: str):
async with httpx.AsyncClient(timeout=DEFAULT_CLIENT_TIME_OUT) as client:
r = await client.get(self.get_load_state_url, params={"task_id": task_id})
@@ -376,5 +439,46 @@ def evaluate_for_response_stage(self):
raise RagApiError(code=r.status_code, msg=response.message)
print("evaluate_for_response_stage response", response)

def _format_data_analysis_rag_response(
    self, question, response, session_id: str = None, stream: bool = False
):
    if stream:
        text = response["delta"]
    else:
        text = response["answer"]

    docs = response.get("docs", []) or []
    is_finished = response.get("is_finished", True)

    referenced_docs = ""
    if is_finished and len(docs) == 0 and not text:
        response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=question)
        return response
    elif is_finished:
        seen_filenames = set()
        file_idx = 1
        for doc in docs:
            filename = doc["metadata"].get("file_name", None)
            if filename and filename not in seen_filenames:
                seen_filenames.add(filename)
                formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
                title = doc["metadata"].get("title")
                if not title:
                    referenced_docs += f'[{file_idx}]: {formatted_file_name} Score:{doc["score"]} \n'
                else:
                    referenced_docs += f'[{file_idx}]: [{title}]({formatted_file_name}) Score:{doc["score"]} \n'
                file_idx += 1

    formatted_answer = ""
    if session_id:
        new_query = response["new_query"]
        formatted_answer += f"**Query Transformation**: {new_query} \n\n"
    formatted_answer += f"**Answer**: {text} \n\n"
    if referenced_docs:
        formatted_answer += f"**Reference**:\n {referenced_docs}"

    response["result"] = formatted_answer
    return response


rag_client = RagWebClient()
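
Putting the new client methods together, a minimal usage sketch, assuming rag_client already points at a running service and that the formatted responses expose a "result" field as _format_data_analysis_rag_response does; the file name and question are illustrative:

# Upload a datasheet, then query it in streaming mode (illustrative values).
resp = rag_client.add_datasheet("sales.csv")
print(resp.destination_path)

for formatted in rag_client.query_data_analysis("average price per region?", stream=True):
    print(formatted["result"])  # each yield carries the accumulated answer so far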
4 changes: 4 additions & 0 deletions src/pai_rag/app/web/tabs/chat_tab.py
@@ -28,6 +28,10 @@ def respond(input_elements: List[Any]):
    for element, value in input_elements.items():
        update_dict[element.elem_id] = value

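    # Data analysis mode rides on the regular chat pipeline: swap in hybrid
    # retrieval and the SimpleSummarize synthesizer before dispatching.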
    if update_dict["retrieval_mode"] == "data_analysis":
        update_dict["retrieval_mode"] = "hybrid"
        update_dict["synthesizer_type"] = "SimpleSummarize"

    # empty input.
    if not update_dict["question"]:
        yield "", update_dict["chatbot"], 0
