Skip to content

Commit

Permalink
add "enable_ocr" and "enable_table_summary" (#138)
Browse files Browse the repository at this point in the history
* add "enable_ocr" and "enable_table_summary"

* add "enable_ocr" and "enable_table_summary"

* add "enable_ocr" and "enable_table_summary"
  • Loading branch information
Ceceliachenen authored Aug 1, 2024
1 parent 714bc03 commit b313fe8
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 12 deletions.
5 changes: 4 additions & 1 deletion src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ def query_vector(self, text: str):
yield response

def add_knowledge(
self, input_files: str, enable_qa_extraction: bool, enable_raptor: bool
self,
input_files: str,
enable_qa_extraction: bool,
enable_raptor: bool,
):
files = []
file_obj_list = []
Expand Down
31 changes: 29 additions & 2 deletions src/pai_rag/app/web/tabs/upload_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,25 @@


def upload_knowledge(
upload_files, chunk_size, chunk_overlap, enable_qa_extraction, enable_raptor
upload_files,
chunk_size,
chunk_overlap,
enable_qa_extraction,
enable_raptor,
enable_ocr,
enable_table_summary,
):
if not upload_files:
return

try:
rag_client.patch_config(
{"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
{
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"enable_ocr": enable_ocr,
"enable_table_summary": enable_table_summary,
}
)
except RagApiError as api_error:
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
Expand Down Expand Up @@ -105,6 +116,16 @@ def create_upload_tab() -> Dict[str, Any]:
info="Process with Raptor Node Enhancement",
elem_id="enable_raptor",
)
enable_ocr = gr.Checkbox(
label="Yes",
info="Process with OCR",
elem_id="enable_ocr",
)
enable_table_summary = gr.Checkbox(
label="Yes",
info="Process with Table Summary ",
elem_id="enable_table_summary",
)
with gr.Column(scale=8):
with gr.Tab("Files"):
upload_file = gr.File(
Expand All @@ -131,6 +152,8 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap,
enable_qa_extraction,
enable_raptor,
enable_ocr,
enable_table_summary,
],
outputs=[upload_file_state_df, upload_file_state],
api_name="upload_knowledge",
Expand All @@ -149,6 +172,8 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap,
enable_qa_extraction,
enable_raptor,
enable_ocr,
enable_table_summary,
],
outputs=[upload_dir_state_df, upload_dir_state],
api_name="upload_knowledge_dir",
Expand All @@ -164,4 +189,6 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap.elem_id: chunk_overlap,
enable_qa_extraction.elem_id: enable_qa_extraction,
enable_raptor.elem_id: enable_raptor,
enable_ocr.elem_id: enable_ocr,
enable_table_summary.elem_id: enable_table_summary,
}
12 changes: 12 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class ViewModel(BaseModel):
reader_type: str = "SimpleDirectoryReader"
enable_qa_extraction: bool = False
enable_raptor: bool = False
enable_ocr: bool = False
enable_table_summary: bool = False

config_file: str = None

Expand Down Expand Up @@ -249,6 +251,12 @@ def from_app_config(config):
view_model.enable_raptor = config["data_reader"].get(
"enable_raptor", view_model.enable_raptor
)
view_model.enable_ocr = config["data_reader"].get(
"enable_ocr", view_model.enable_ocr
)
view_model.enable_table_summary = config["data_reader"].get(
"enable_table_summary", view_model.enable_table_summary
)

view_model.similarity_top_k = config["retriever"].get("similarity_top_k", 5)
if config["retriever"]["retrieval_mode"] == "hybrid":
Expand Down Expand Up @@ -323,6 +331,8 @@ def to_app_config(self):

config["data_reader"]["enable_qa_extraction"] = self.enable_qa_extraction
config["data_reader"]["enable_raptor"] = self.enable_raptor
config["data_reader"]["enable_ocr"] = self.enable_ocr
config["data_reader"]["enable_table_summary"] = self.enable_table_summary
config["data_reader"]["type"] = self.reader_type

if self.vectordb_type == "Hologres":
Expand Down Expand Up @@ -503,6 +513,8 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
settings["chunk_overlap"] = {"value": self.chunk_overlap}
settings["enable_qa_extraction"] = {"value": self.enable_qa_extraction}
settings["enable_raptor"] = {"value": self.enable_raptor}
settings["enable_ocr"] = {"value": self.enable_ocr}
settings["enable_table_summary"] = {"value": self.enable_table_summary}

# retrieval and rerank
settings["retrieval_mode"] = {"value": self.retrieval_mode}
Expand Down
19 changes: 16 additions & 3 deletions src/pai_rag/data/rag_datapipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ async def ingest_from_input_path(
else:
input_paths = [file.strip() for file in input_path.split(",")]
self.data_loader.load(
input_paths, pattern, enable_qa_extraction, enable_raptor
input_paths,
pattern,
enable_qa_extraction,
enable_raptor,
)
else:
self.data_loader.load_eval_data(name)
Expand Down Expand Up @@ -103,11 +106,21 @@ def __init_data_pipeline(config_file, use_local_qa_model):
default=None,
)
def run(
config, data_path, pattern, extract_qa, use_local_qa_model, enable_raptor, name
config,
data_path,
pattern,
extract_qa,
use_local_qa_model,
enable_raptor,
name,
):
data_pipeline = __init_data_pipeline(config, use_local_qa_model)
asyncio.run(
data_pipeline.ingest_from_input_path(
data_path, pattern, extract_qa, enable_raptor, name
data_path,
pattern,
extract_qa,
enable_raptor,
name,
)
)
19 changes: 14 additions & 5 deletions src/pai_rag/integrations/readers/pai_pdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,15 @@ class PaiPDFReader(BaseReader):
"""

def __init__(
self, enable_image_ocr: bool = False, model_dir: str = DEFAULT_MODEL_DIR
self,
enable_image_ocr: bool = False,
enable_table_summary: bool = False,
model_dir: str = DEFAULT_MODEL_DIR,
) -> None:
self.enable_image_ocr = enable_image_ocr
self.enable_table_summary = enable_table_summary
if self.enable_table_summary:
logger.info("process with table summary")
if self.enable_image_ocr:
self.model_dir = model_dir or os.path.join(DEFAULT_MODEL_DIR, "easyocr")
logger.info("start loading ocr model")
Expand Down Expand Up @@ -446,11 +452,14 @@ def load(
for table in total_tables:
# If the page number matches
if pagenum == table["page_number"]:
summarized_table_text = PaiPDFReader.tables_summarize(table["text"])
if self.enable_table_summary:
summarized_table_text = PaiPDFReader.tables_summarize(
table["text"]
)
page_tables_summaries.append(
summarized_table_text.text[:TABLE_SUMMARY_MAX_TOKEN]
)
json_data = PaiPDFReader.table_to_json(table["text"])
page_tables_summaries.append(
summarized_table_text.text[:TABLE_SUMMARY_MAX_TOKEN]
)
page_tables_json.append(json_data)
page_table_summary = "\n".join(page_tables_summaries)
page_table_json = "\n".join(page_tables_json)
Expand Down
5 changes: 4 additions & 1 deletion src/pai_rag/modules/datareader/datareader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
".html": HtmlReader(),
".htm": HtmlReader(),
".pdf": PaiPDFReader(
enable_image_ocr=self.reader_config.get("enable_image_ocr", False),
enable_image_ocr=self.reader_config.get("enable_ocr", False),
enable_table_summary=self.reader_config.get(
"enable_table_summary", False
),
model_dir=self.reader_config.get("easyocr_model_dir", None),
),
".csv": PaiPandasCSVReader(
Expand Down

0 comments on commit b313fe8

Please sign in to comment.