Skip to content

Commit

Permalink
docx_reader (#250)
Browse files Browse the repository at this point in the history
* docx_reader

* docx_reader

* docx_reader

* docx_reader

* docx_reader
  • Loading branch information
Ceceliachenen authored Oct 23, 2024
1 parent c74428a commit b4464f7
Show file tree
Hide file tree
Showing 8 changed files with 401 additions and 50 deletions.
29 changes: 28 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ pai-llm-trace = {url = "https://pai-llm-trace.oss-cn-zhangjiakou.aliyuncs.com/sd
llama-index-callbacks-arize-phoenix = "0.1.6"
peft = "^0.12.0"
duckduckgo-search = "6.2.12"
python-docx = "^1.1"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
2 changes: 2 additions & 0 deletions src/pai_rag/core/rag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def validate_case_insensitive(value: Dict) -> Dict:
if key in value:
value[key] = value[key].lower()
# fix old config
if value[key] == "simple-weighted-reranker":
value[key] = "no-reranker"
if value[key] == "pai-llm-trace":
value[key] = "pai_trace"
return value
Expand Down
3 changes: 2 additions & 1 deletion src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class NodeParserConfig(BaseModel):
DOC_TYPES_DO_NOT_NEED_CHUNKING = set(
[".csv", ".xlsx", ".xls", ".htm", ".html", ".jsonl"]
)
DOC_TYPES_CONVERT_TO_MD = set([".md", ".pdf", ".docx"])
IMAGE_FILE_TYPES = set([".jpg", ".jpeg", ".png"])

IMAGE_URL_REGEX = re.compile(
Expand Down Expand Up @@ -160,7 +161,7 @@ def get_nodes_from_documents(
)
)
else:
if doc_type == ".md" or doc_type == ".pdf":
if doc_type in DOC_TYPES_CONVERT_TO_MD:
md_node_parser = MarkdownNodeParser(
id_func=node_id_hash,
enable_multimodal=self._parser_config.enable_multimodal,
Expand Down
5 changes: 5 additions & 0 deletions src/pai_rag/integrations/readers/pai/pai_data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pai_rag.integrations.readers.pai_csv_reader import PaiPandasCSVReader
from pai_rag.integrations.readers.pai_excel_reader import PaiPandasExcelReader
from pai_rag.integrations.readers.pai_jsonl_reader import PaiJsonLReader
from pai_rag.integrations.readers.pai_docx_reader import PaiDocxReader

from llama_index.core.readers.base import BaseReader
from llama_index.core.readers import SimpleDirectoryReader
Expand All @@ -33,6 +34,10 @@ def get_file_readers(reader_config: BaseDataReaderConfig = None, oss_store: Any
file_readers = {
".html": HtmlReader(),
".htm": HtmlReader(),
".docx": PaiDocxReader(
enable_table_summary=reader_config.enable_table_summary,
oss_cache=oss_store, # Storing docx images
),
".pdf": PaiPDFReader(
enable_table_summary=reader_config.enable_table_summary,
oss_cache=oss_store, # Storing pdf images
Expand Down
267 changes: 267 additions & 0 deletions src/pai_rag/integrations/readers/pai_docx_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
"""Docs parser.
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union, Any
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
from pai_rag.utils.markdown_utils import (
transform_local_to_oss,
convert_table_to_markdown,
PaiTable,
)
from docx import Document as DocxDocument
import re
import os
from PIL import Image
import time
from io import BytesIO

logger = logging.getLogger(__name__)
IMAGE_MAX_PIXELS = 512 * 512


class PaiDocxReader(BaseReader):
    """Read .docx files including texts, tables and images.

    The document body is flattened into one markdown string: "Heading N"
    styles become ``#`` prefixes, list paragraphs become markdown list items,
    tables become markdown tables, and embedded images are uploaded through
    the OSS cache and referenced by URL.

    Args:
        enable_table_summary (bool): whether to use table_summary to process tables.
        oss_cache (Any): OSS store used to persist images extracted from the docx.
    """

    def __init__(
        self,
        enable_table_summary: bool = False,
        oss_cache: Any = None,
    ) -> None:
        self.enable_table_summary = enable_table_summary
        self._oss_cache = oss_cache
        logger.info(
            f"PaiDocxReader created with enable_table_summary : {self.enable_table_summary}"
        )

    def _transform_local_to_oss(
        self, image_blob: bytes, image_filename: str, doc_name: str
    ) -> Optional[str]:
        """Upload one embedded image to OSS and return its URL, or None.

        Windows metafiles (.emf/.wmf) are skipped for now because PIL
        cannot decode them; callers must handle the None return.
        """
        if image_filename.lower().endswith((".emf", ".wmf")):
            return None
        image = Image.open(BytesIO(image_blob))
        return transform_local_to_oss(self._oss_cache, image, doc_name)

    def _convert_paragraph(self, paragraph) -> str:
        """Convert a plain or heading paragraph to a markdown fragment."""
        text = paragraph.text.strip()
        if not text:
            return ""

        # Headings: "Heading N" maps to N leading '#', capped at 6 (the
        # deepest level markdown supports).
        if paragraph.style.name.startswith("Heading"):
            match = re.search(r"Heading (\d)", paragraph.style.name)
            # Guard: custom styles like "HeadingTitle" carry no digit; fall
            # through and treat them as a regular paragraph instead of
            # crashing on `None.group(1)`.
            if match:
                heading_level = min(int(match.group(1)), 6)
                return f"{'#' * heading_level} {text}\n\n"

        # Regular paragraph.
        return f"{text}\n\n"

    def _get_list_level(self, paragraph) -> int:
        """Map a list paragraph's style name to its nesting level (default 0)."""
        indent_levels = {
            "List Paragraph": 0,
            "List Bullet": 1,
            "List Number": 1,
            "List Bullet 2": 2,
            "List Number 2": 2,
            "List Bullet 3": 3,
            "List Number 3": 3,
        }
        return indent_levels.get(paragraph.style.name, 0)

    def _convert_list(self, paragraph, level: int = 0) -> str:
        """Convert a list paragraph to one markdown list item.

        Nesting is expressed with two-space indentation per level (the
        markdown convention) rather than by repeating the bullet character
        ("--" is not a valid nested bullet). "Number" styles produce an
        ordered item; every other "List" style produces a bullet. The
        original had two identical `startswith("List")` branches, which made
        the ordered-list branch unreachable.
        """
        text = paragraph.text.strip()
        if not text:
            return ""

        if not paragraph.style.name.startswith("List"):
            return ""

        indent = "  " * max(level - 1, 0)
        if "Number" in paragraph.style.name:
            return f"{indent}1. {text}\n"
        return f"{indent}- {text}\n"

    def _convert_table_to_markdown(self, table, doc_name) -> str:
        """Convert a docx table to markdown; the first row is the header."""
        total_cols = max(len(row.cells) for row in table.rows)

        headers = self._parse_row(table.rows[0], doc_name, total_cols)
        rows = [self._parse_row(row, doc_name, total_cols) for row in table.rows[1:]]
        pai_table = PaiTable(headers=[headers], rows=rows)
        return convert_table_to_markdown(pai_table, total_cols)

    def _parse_row(self, row, doc_name, total_cols: int) -> List[str]:
        """Parse one table row into a fixed-width list of cell strings.

        Merged cells repeat in ``row.cells``; the scan skips columns that are
        already filled so each cell lands in its own slot.
        """
        row_cells = [""] * total_cols
        col_index = 0
        for cell in row.cells:
            while col_index < total_cols and row_cells[col_index] != "":
                col_index += 1
            if col_index >= total_cols:
                break
            row_cells[col_index] = self._parse_cell(cell, doc_name).strip()
            # Advance explicitly: an empty cell leaves its slot "" and the
            # while-skip above would otherwise let the next cell overwrite it.
            col_index += 1
        return row_cells

    def _parse_cell(self, cell, doc_name) -> str:
        """Parse one table cell: join its paragraphs, de-duplicated in order."""
        cell_content = []
        for paragraph in cell.paragraphs:
            parsed_paragraph = self._parse_cell_paragraph(paragraph, doc_name)
            if parsed_paragraph:
                cell_content.append(parsed_paragraph)
        # dict.fromkeys keeps first occurrence while preserving order.
        unique_content = list(dict.fromkeys(cell_content))
        return " ".join(unique_content)

    def _parse_cell_paragraph(self, paragraph, doc_name) -> str:
        """Parse one paragraph inside a table cell, inlining embedded images."""
        paragraph_content = []
        for run in paragraph.runs:
            blips = run.element.xpath(".//a:blip")
            if blips:
                for blip in blips:
                    image_id = blip.get(
                        "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
                    )
                    if not image_id:
                        continue
                    # related_parts maps rId -> Part (which has .blob /
                    # .partname); rels.get returned a _Relationship, which
                    # has neither. Also check the resolved part, not the id,
                    # so a dangling relationship is skipped instead of
                    # raising AttributeError.
                    image_part = paragraph.part.related_parts.get(image_id)
                    if image_part is None:
                        continue
                    image_url = self._transform_local_to_oss(
                        image_part.blob,
                        os.path.basename(image_part.partname),
                        doc_name,
                    )
                    if not image_url:
                        # Unsupported format (.emf/.wmf) - do not emit
                        # a broken "![...](None)" link.
                        continue
                    time_tag = int(time.time())
                    alt_text = f"pai_rag_image_{time_tag}_"
                    paragraph_content.append(f"![{alt_text}]({image_url})")
            else:
                paragraph_content.append(run.text)
        return "".join(paragraph_content).strip()

    def _append_paragraph_images(self, document, paragraph, doc_name, markdown) -> None:
        """Append markdown image links for all inline images of *paragraph*.

        Resolves each <a:blip> embed id through the package's related parts,
        uploads the blob to OSS and appends an image reference to *markdown*.
        """
        for run in paragraph.runs:
            run_tag = getattr(run.element, "tag", None)
            if not (isinstance(run_tag, str) and run_tag.endswith("r")):
                continue
            drawing_elements = run.element.findall(
                ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
            )
            for drawing in drawing_elements:
                blip_elements = drawing.findall(
                    ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
                )
                for blip in blip_elements:
                    embed_id = blip.get(
                        "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
                    )
                    if not embed_id:
                        continue
                    image_part = document.part.related_parts.get(embed_id)
                    if image_part is None:
                        continue
                    image_url = self._transform_local_to_oss(
                        image_part.blob,
                        os.path.basename(image_part.partname),
                        doc_name,
                    )
                    if not image_url:
                        # .emf/.wmf are skipped; avoid "![...](None)".
                        continue
                    time_tag = int(time.time())
                    alt_text = f"pai_rag_image_{time_tag}_"
                    markdown.append(f"![{alt_text}]({image_url})\n\n")

    def convert_document_to_markdown(self, doc_path) -> str:
        """Convert a whole .docx document to a single markdown string.

        Walks ``document.element.body`` in document order, pairing each
        ``<w:p>``/``<w:tbl>`` element with the next parsed paragraph/table.
        """
        doc_name = os.path.basename(doc_path).split(".")[0].replace(" ", "_")
        document = DocxDocument(doc_path)
        markdown = []

        # Iterators give O(1) "next element"; list.pop(0) was O(n) each time.
        paragraphs = iter(document.paragraphs)
        tables = iter(document.tables)

        for element in document.element.body:
            if not isinstance(element.tag, str):
                continue
            if element.tag.endswith("p"):  # paragraph
                paragraph = next(paragraphs)
                # (The original tested the same "List" prefix twice here.)
                if paragraph.style.name.startswith("List"):
                    current_list_level = self._get_list_level(paragraph)
                    markdown.append(self._convert_list(paragraph, current_list_level))
                else:
                    self._append_paragraph_images(
                        document, paragraph, doc_name, markdown
                    )
                    markdown.append(self._convert_paragraph(paragraph))
            elif element.tag.endswith("tbl"):  # table
                table = next(tables)
                markdown.append(self._convert_table_to_markdown(table, None))
                markdown.append("\n\n")

        return "".join(markdown)

    def load_data(
        self,
        file_path: Union[Path, str],
        metadata: bool = True,
        extra_info: Optional[Dict] = None,
    ) -> List[Document]:
        """Loads list of documents from a Docx file and also accepts extra information in dict format."""
        return self.load(file_path, metadata=metadata, extra_info=extra_info)

    def load(
        self,
        file_path: Union[Path, str],
        metadata: bool = True,
        extra_info: Optional[Dict] = None,
    ) -> List[Document]:
        """Loads list of documents from Docx file and also accepts extra information in dict format.
        Args:
            file_path (Union[Path, str]): file path of Docx file (accepts string or Path).
            metadata (bool, optional): if metadata to be included or not. Defaults to True.
            extra_info (Optional[Dict], optional): extra information related to each document in dict format. Defaults to None.
        Raises:
            TypeError: if extra_info is not a dictionary.
        Returns:
            List[Document]: list of documents.
        """
        # Enforce the TypeError documented above (the original promised it
        # but never checked).
        if extra_info is not None and not isinstance(extra_info, dict):
            raise TypeError("extra_info must be a dictionary.")

        md_content = self.convert_document_to_markdown(file_path)
        logger.info(f"[PaiDocxReader] successfully processed docx file {file_path}.")
        if metadata:
            doc_extra_info = extra_info or {}
        else:
            doc_extra_info = dict()
            logger.info(f"processed doc file {file_path} without metadata")
        docs = [Document(text=md_content, extra_info=doc_extra_info)]
        # Use the module logger instead of a stray print().
        logger.info(f"[PaiDocxReader] successfully loaded {len(docs)} nodes.")
        return docs
Loading

0 comments on commit b4464f7

Please sign in to comment.