Skip to content

Commit

Permalink
Add pdf ocr choice (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceceliachenen authored Dec 10, 2024
1 parent e750673 commit beeaf52
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
16 changes: 16 additions & 0 deletions src/pai_rag/app/web/tabs/upload_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def upload_oss_knowledge(
chunk_overlap,
enable_raptor,
enable_multimodal,
enable_mandatory_ocr,
enable_table_summary,
upload_index,
):
Expand All @@ -35,6 +36,7 @@ def upload_oss_knowledge(
chunk_overlap=chunk_overlap,
enable_raptor=enable_raptor,
enable_multimodal=enable_multimodal,
enable_mandatory_ocr=enable_mandatory_ocr,
enable_table_summary=enable_table_summary,
index_name=upload_index,
from_oss=True,
Expand All @@ -48,6 +50,7 @@ def upload_files(
chunk_overlap,
enable_raptor,
enable_multimodal,
enable_mandatory_ocr,
enable_table_summary,
upload_index,
):
Expand All @@ -67,6 +70,7 @@ def upload_files(
chunk_overlap=chunk_overlap,
enable_raptor=enable_raptor,
enable_multimodal=enable_multimodal,
enable_mandatory_ocr=enable_mandatory_ocr,
enable_table_summary=enable_table_summary,
index_name=upload_index,
):
Expand All @@ -80,6 +84,7 @@ def upload_knowledge(
chunk_overlap,
enable_raptor,
enable_multimodal,
enable_mandatory_ocr,
enable_table_summary,
index_name,
from_oss: bool = False,
Expand All @@ -89,6 +94,7 @@ def upload_knowledge(
{
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"enable_mandatory_ocr": enable_mandatory_ocr,
"enable_table_summary": enable_table_summary,
}
)
Expand Down Expand Up @@ -188,6 +194,12 @@ def create_upload_tab() -> Dict[str, Any]:
elem_id="enable_multimodal",
visible=True,
)
enable_mandatory_ocr = gr.Checkbox(
label="Yes",
info="Process PDF with OCR",
elem_id="enable_mandatory_ocr",
visible=True,
)
enable_table_summary = gr.Checkbox(
label="Yes",
info="Process with Table Summary ",
Expand Down Expand Up @@ -232,6 +244,7 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap,
enable_raptor,
enable_multimodal,
enable_mandatory_ocr,
enable_table_summary,
upload_index,
],
Expand All @@ -247,6 +260,7 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap,
enable_raptor,
enable_multimodal,
enable_mandatory_ocr,
enable_table_summary,
upload_index,
],
Expand All @@ -269,6 +283,7 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap,
enable_raptor,
enable_multimodal,
enable_mandatory_ocr,
enable_table_summary,
upload_index,
],
Expand All @@ -287,5 +302,6 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap.elem_id: chunk_overlap,
enable_raptor.elem_id: enable_raptor,
enable_multimodal.elem_id: enable_multimodal,
enable_mandatory_ocr.elem_id: enable_mandatory_ocr,
enable_table_summary.elem_id: enable_table_summary,
}
4 changes: 4 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ViewModel(BaseModel):
# reader
reader_type: str = "SimpleDirectoryReader"
enable_raptor: bool = False
enable_mandatory_ocr: bool = False
enable_table_summary: bool = False

config_file: str = None
Expand Down Expand Up @@ -174,6 +175,7 @@ def from_app_config(config: RagConfig):
view_model.chunk_overlap = config.node_parser.chunk_overlap
view_model.chunk_size = config.node_parser.chunk_size

view_model.enable_mandatory_ocr = config.data_reader.enable_mandatory_ocr
view_model.enable_table_summary = config.data_reader.enable_table_summary

view_model.similarity_top_k = config.retriever.similarity_top_k
Expand Down Expand Up @@ -282,6 +284,7 @@ def to_app_config(self):
config["node_parser"]["chunk_size"] = int(self.chunk_size)
config["node_parser"]["chunk_overlap"] = int(self.chunk_overlap)

config["data_reader"]["enable_mandatory_ocr"] = self.enable_mandatory_ocr
config["data_reader"]["enable_table_summary"] = self.enable_table_summary

config["retriever"]["similarity_top_k"] = self.similarity_top_k
Expand Down Expand Up @@ -506,6 +509,7 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
settings["chunk_overlap"] = {"value": self.chunk_overlap}
settings["enable_raptor"] = {"value": self.enable_raptor}
settings["enable_multimodal"] = {"value": self.enable_multimodal}
settings["enable_mandatory_ocr"] = {"value": self.enable_mandatory_ocr}
settings["enable_table_summary"] = {"value": self.enable_table_summary}

# retrieval and rerank
Expand Down
2 changes: 2 additions & 0 deletions src/pai_rag/integrations/readers/pai/pai_data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

class BaseDataReaderConfig(BaseModel):
concat_csv_rows: bool = False
enable_mandatory_ocr: bool = False
enable_table_summary: bool = False
format_sheet_data_to_json: bool = False
sheet_column_filters: List[str] | None = None
Expand All @@ -45,6 +46,7 @@ def get_file_readers(reader_config: BaseDataReaderConfig = None, oss_store: Any
oss_cache=oss_store, # Storing docx images
),
".pdf": PaiPDFReader(
enable_mandatory_ocr=reader_config.enable_mandatory_ocr,
enable_table_summary=reader_config.enable_table_summary,
oss_cache=oss_store, # Storing pdf images
),
Expand Down
17 changes: 11 additions & 6 deletions src/pai_rag/integrations/readers/pai_pdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
import magic_pdf.model as model_config
from rapidocr_onnxruntime import RapidOCR
from rapid_table import RapidTable
Expand Down Expand Up @@ -52,14 +51,19 @@ class PaiPDFReader(BaseReader):

def __init__(
self,
enable_mandatory_ocr: bool = False,
enable_table_summary: bool = False,
oss_cache: Any = None,
) -> None:
self.enable_table_summary = enable_table_summary
self.enable_mandatory_ocr = enable_mandatory_ocr
self._oss_cache = oss_cache
logger.info(
f"PaiPdfReader created with enable_table_summary : {self.enable_table_summary}"
)
logger.info(
f"PaiPdfReader created with enable_mandatory_ocr : {self.enable_mandatory_ocr}"
)

def _transform_local_to_oss(self, pdf_name: str, local_url: str):
image = Image.open(local_url)
Expand Down Expand Up @@ -270,7 +274,7 @@ def parse_pdf(
执行从 pdf 转换到 json、md 的过程,输出 md 和 json 文件到 pdf 文件所在的目录
:param pdf_path: .pdf 文件的路径,可以是相对路径,也可以是绝对路径
:param parse_method: 解析方法, 共 auto、ocr、txt 三种,默认 auto,如果效果不好,可以尝试 ocr
:param parse_method: 解析方法, 共 auto、ocr两种,默认 auto。auto会根据文件类型选择TXT模式或者OCR模式解析。ocr会直接使用OCR模式。
:param model_json_path: 已经存在的模型数据文件,如果为空则使用内置模型,pdf 和 model_json 务必对应
"""
try:
Expand All @@ -294,8 +298,6 @@ def parse_pdf(
if parse_method == "auto":
jso_useful_key = {"_pdf_type": "", "model_list": model_json}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
elif parse_method == "txt":
pipe = TXTPipe(pdf_bytes, model_json, image_writer)
elif parse_method == "ocr":
pipe = OCRPipe(pdf_bytes, model_json, image_writer)
else:
Expand Down Expand Up @@ -358,8 +360,11 @@ def load(
Returns:
List[Document]: list of documents.
"""

md_content = self.parse_pdf(file_path, "auto")
if self.enable_mandatory_ocr:
parse_method = "ocr"
else:
parse_method = "auto"
md_content = self.parse_pdf(file_path, parse_method)
logger.info(f"[PaiPDFReader] successfully processed pdf file {file_path}.")
docs = []
if metadata:
Expand Down

0 comments on commit beeaf52

Please sign in to comment.