diff --git a/docs/agentic_rag.md b/docs/agentic_rag.md
index 2aa027b9..8246e2e5 100644
--- a/docs/agentic_rag.md
+++ b/docs/agentic_rag.md
@@ -118,9 +118,7 @@ Python代码工具分为两个部分:代码和函数定义。
 ```python
 import requests
 import os
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 def get_place_weather(city: str) -> str:
diff --git a/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py b/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py
index 47b56423..9f85ffa1 100644
--- a/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py
+++ b/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py
@@ -1,8 +1,6 @@
 import requests
 import os
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 def get_place_weather(city: str) -> str:
diff --git a/src/pai_rag/app/api/agent_demo.py b/src/pai_rag/app/api/agent_demo.py
index e2677718..8e5b5e5c 100644
--- a/src/pai_rag/app/api/agent_demo.py
+++ b/src/pai_rag/app/api/agent_demo.py
@@ -1,11 +1,8 @@
 from datetime import datetime
 from fastapi import APIRouter
-import logging
 from pydantic import BaseModel
 
-logger = logging.getLogger(__name__)
-
 demo_router = APIRouter()
diff --git a/src/pai_rag/app/api/middleware.py b/src/pai_rag/app/api/middleware.py
index d8780617..d935d99d 100644
--- a/src/pai_rag/app/api/middleware.py
+++ b/src/pai_rag/app/api/middleware.py
@@ -3,9 +3,7 @@
 from starlette.middleware.base import BaseHTTPMiddleware
 from asgi_correlation_id import CorrelationIdMiddleware
 import time
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class CustomMiddleWare(BaseHTTPMiddleware):
diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py
index adc4df97..e42bbafc 100644
--- a/src/pai_rag/app/api/query.py
+++ b/src/pai_rag/app/api/query.py
@@ -15,14 +15,12 @@
     RetrievalQuery,
 )
 from fastapi.responses import StreamingResponse
-import logging
+from loguru import logger
 from pai_rag.integrations.nodeparsers.pai.pai_node_parser import (
     COMMON_FILE_PATH_FODER_NAME,
 )
 
-logger = logging.getLogger(__name__)
-
 router = APIRouter()
diff --git a/src/pai_rag/app/web/event_listeners.py b/src/pai_rag/app/web/event_listeners.py
index 1fd96b21..943b3302 100644
--- a/src/pai_rag/app/web/event_listeners.py
+++ b/src/pai_rag/app/web/event_listeners.py
@@ -19,6 +19,7 @@
     HuggingFaceEmbeddingConfig,
 )
 from pai_rag.integrations.index.pai.vector_store_config import FaissVectorStoreConfig
+from loguru import logger
 
 
 def add_index(*components):
@@ -26,7 +27,7 @@ def add_index(*components):
     index_entry = components_to_index(**component_args)
     rag_client.add_index(index_entry)
     index_map = get_index_map()
-    print(f"Add index {index_entry.index_name} successfully")
+    logger.info(f"Add index {index_entry.index_name} successfully")
     return [
         gr.update(
             choices=list(index_map.indexes.keys()) + ["NEW"],
@@ -44,7 +45,7 @@ def update_index(*components):
    index_entry = components_to_index(**component_args)
     rag_client.update_index(index_entry)
     index_map = get_index_map()
-    print(f"Update index {index_entry.index_name} successfully")
+    logger.info(f"Update index {index_entry.index_name} successfully")
     return [
         gr.update(
             choices=list(index_map.indexes.keys()) + ["NEW"],
diff --git a/src/pai_rag/app/web/index_utils.py b/src/pai_rag/app/web/index_utils.py
index 8e98d30d..7e010013 100644
--- a/src/pai_rag/app/web/index_utils.py
+++ b/src/pai_rag/app/web/index_utils.py
@@ -301,7 +301,6 @@ def index_to_components(
     component_settings = index_to_components_settings(
         index_entry, index_list, is_new_index
     )
-    print("+++", index_entry.index_name)
     return [gr.update(**setting) for setting in component_settings.values()] + [
         gr.update(choices=index_list, value=index_entry.index_name),
         gr.update(choices=index_list, value=index_entry.index_name),
diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index a0a76451..e6946c72 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -8,6 +8,7 @@
 import html
 import mimetypes
 from http import HTTPStatus
+from loguru import logger
 from pai_rag.app.web.view_model import ViewModel
 from pai_rag.app.web.ui_constants import EMPTY_KNOWLEDGEBASE_MESSAGE
 from pai_rag.core.rag_config import RagConfig
@@ -425,7 +426,7 @@ def add_datasheet(
             if r.status_code != HTTPStatus.OK:
                 raise RagApiError(code=r.status_code, msg=response.message)
         except Exception as e:
-            print(f"add_datasheet failed: {e}")
+            logger.exception(f"add_datasheet failed: {e}")
         finally:
             file_obj.close()
diff --git a/src/pai_rag/app/web/tabs/settings_tab.py b/src/pai_rag/app/web/tabs/settings_tab.py
index 1e4576ce..a6d2ad44 100644
--- a/src/pai_rag/app/web/tabs/settings_tab.py
+++ b/src/pai_rag/app/web/tabs/settings_tab.py
@@ -4,12 +4,9 @@
 from pai_rag.app.web.utils import components_to_dict
 from pai_rag.app.web.index_utils import index_related_component_keys
 from pai_rag.app.web.tabs.vector_db_panel import create_vector_db_panel
-import logging
 import os
 import pai_rag.app.web.event_listeners as ev_listeners
 
-logger = logging.getLogger(__name__)
-
 DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")
diff --git a/src/pai_rag/app/web/webui.py b/src/pai_rag/app/web/webui.py
index 9730110a..166301cb 100644
--- a/src/pai_rag/app/web/webui.py
+++ b/src/pai_rag/app/web/webui.py
@@ -20,12 +20,10 @@
 )
 from pai_rag.app.web.tabs.model.index_info import get_index_map
 
-import logging
+from loguru import logger
 
 DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")
 
-logger = logging.getLogger("WebUILogger")
-
 
 def resume_ui():
     outputs = {}
@@ -131,6 +129,6 @@ def configure_webapp(app: FastAPI, web_url, rag_url=DEFAULT_LOCAL_URL) -> gr.Blo
     home = make_homepage()
     home.queue(concurrency_count=1, max_size=64)
     home._queue.set_url(web_url)
-    print(web_url)
+    logger.info(f"web_url: {web_url}")
     gr.mount_gradio_app(app, home, path="")
     return home
diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index 9d2a06a4..ac280a24 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -25,14 +25,13 @@
     ImageNode,
 )
 import json
-import logging
+from loguru import logger
 import os
 from enum import Enum
 from uuid import uuid4
 
 DEFAULT_EMPTY_RESPONSE_GEN = "Empty Response"
 DEFAULT_RAG_INDEX_FILE = "localdata/default_rag_indexes.json"
-logger = logging.getLogger(__name__)
 
 
 def uuid_generator() -> str:
@@ -74,7 +73,6 @@ async def event_generator_async(
 class RagApplication:
     def __init__(self, config: RagConfig):
         self.name = "RagApplication"
-        self.logger = logging.getLogger(__name__)
         self.config = config
         index_manager.add_default_index(self.config)
@@ -147,7 +145,7 @@ async def aquery(self, query: RagQuery, chat_type: RagChatType = RagChatType.RAG
         session_id = query.session_id or uuid_generator()
-        self.logger.debug(f"Get session ID: {session_id}.")
+        logger.debug(f"Get session ID: {session_id}.")
         session_config = self.config.model_copy()
         index_entry = index_manager.get_index_by_name(query.index_name)
         session_config.embedding = index_entry.embedding_config
@@ -168,12 +166,12 @@ async def aquery(self, query: RagQuery, chat_type: RagChatType = RagChatType.RAG
             chat_history=query.chat_history,
         )
         new_question = new_query_bundle.query_str
-        self.logger.info(f"Querying with question '{new_question}'.")
+        logger.info(f"Querying with question '{new_question}'.")
 
         if query.with_intent:
             intent_router = resolve_intent_router(session_config)
             intent = await intent_router.aselect(str_or_query_bundle=new_question)
-            self.logger.info(f"[IntentDetection] Routing query to {intent}.")
+            logger.info(f"[IntentDetection] Routing query to {intent}.")
             if intent == Intents.TOOL:
                 return await self.aquery_agent(query)
             elif intent == Intents.WEBSEARCH:
@@ -295,7 +293,7 @@ async def aquery_analysis(self, query: RagQuery):
             RagResponse
         """
         session_id = query.session_id or uuid_generator()
-        self.logger.debug(f"Get session ID: {session_id}.")
+        logger.debug(f"Get session ID: {session_id}.")
         if not query.question:
             return RagResponse(
                 answer="Empty query. Please input your question.", session_id=session_id
diff --git a/src/pai_rag/core/rag_config_manager.py b/src/pai_rag/core/rag_config_manager.py
index 5747cc70..8fb72f72 100644
--- a/src/pai_rag/core/rag_config_manager.py
+++ b/src/pai_rag/core/rag_config_manager.py
@@ -1,10 +1,11 @@
 from dynaconf import Dynaconf, loaders
 from dynaconf.utils.boxing import DynaBox
-import logging
+from loguru import logger
 import os
 
 from pai_rag.core.rag_config import RagConfig
+from pai_rag.utils.oss_utils import check_and_set_oss_auth
 
 # store config file generated from ui.
 GENERATED_CONFIG_FILE_NAME = "localdata/settings.snapshot.toml"
@@ -26,7 +27,7 @@ def from_snapshot(cls):
             )
             return cls(config)
         except Exception as error:
-            logging.critical("Read config file failed.")
+            logger.critical("Read config file failed.")
             raise error
 
     @classmethod
@@ -51,7 +52,7 @@ def from_file(cls, config_file):
             # `envvar_prefix` = export envvars with `export PAIRAG_FOO=bar`.
             # `settings_files` = Load these files in the order.
         except Exception as error:
-            logging.critical("Read config file failed.")
+            logger.critical("Read config file failed.")
             raise error
 
     def get_value(self) -> RagConfig:
@@ -61,6 +62,7 @@ def get_value(self) -> RagConfig:
     def update(self, new_value: Dynaconf):
         if self.config.get("rag", None):
             self.config.rag.update(new_value, merge=True)
+            check_and_set_oss_auth(self.config.rag)
 
     def persist(self):
         """Save configuration to file."""
@@ -72,5 +74,5 @@ def get_config_mtime(self):
         try:
             return os.path.getmtime(GENERATED_CONFIG_FILE_NAME)
         except Exception as ex:
-            print(f"Fail to read config mtime {ex}")
+            logger.critical(f"Failed to read config mtime: {ex}")
             return -1
diff --git a/src/pai_rag/core/rag_data_loader.py b/src/pai_rag/core/rag_data_loader.py
index 92dca649..487f075f 100644
--- a/src/pai_rag/core/rag_data_loader.py
+++ b/src/pai_rag/core/rag_data_loader.py
@@ -4,9 +4,7 @@
 from llama_index.core.ingestion import IngestionPipeline
 from pai_rag.integrations.nodeparsers.pai.pai_node_parser import PaiNodeParser
 from pai_rag.integrations.readers.pai.pai_data_reader import PaiDataReader
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class RagDataLoader:
diff --git a/src/pai_rag/core/rag_index_manager.py b/src/pai_rag/core/rag_index_manager.py
index 50e5cfb8..85195969 100644
--- a/src/pai_rag/core/rag_index_manager.py
+++ b/src/pai_rag/core/rag_index_manager.py
@@ -7,9 +7,7 @@
     PaiBaseEmbeddingConfig,
 )
 from pai_rag.integrations.index.pai.vector_store_config import BaseVectorStoreConfig
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 DEFAULT_INDEX_FILE = "localdata/default__rag__index.json"
 DEFAULT_INDEX_NAME = "default_index"
diff --git a/src/pai_rag/core/rag_module.py b/src/pai_rag/core/rag_module.py
index b6ad18ee..a2f4f1cd 100644
--- a/src/pai_rag/core/rag_module.py
+++ b/src/pai_rag/core/rag_module.py
@@ -29,9 +29,6 @@
 from pai_rag.integrations.llms.pai.pai_llm import PaiLlm
 from pai_rag.integrations.llms.pai.pai_multi_modal_llm import PaiMultiModalLlm
 from pai_rag.utils.oss_client import OssClient
-import logging
-
-logger = logging.getLogger(__name__)
 
 cls_cache = {}
diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py
index ed07dd1d..6e561d44 100644
--- a/src/pai_rag/core/rag_service.py
+++ b/src/pai_rag/core/rag_service.py
@@ -13,10 +13,9 @@
 )
 from openinference.instrumentation import using_attributes
 from typing import Dict, List
-import logging
+from loguru import logger
 
 TASK_STATUS_FILE = "__upload_task_status.tmp"
-logger = logging.getLogger(__name__)
 
 
 def trace_correlation_id(function):
diff --git a/src/pai_rag/data/open_dataset.py b/src/pai_rag/data/open_dataset.py
index f9b3c0b4..afd83b6e 100644
--- a/src/pai_rag/data/open_dataset.py
+++ b/src/pai_rag/data/open_dataset.py
@@ -6,6 +6,7 @@
 import urllib.request
 import tarfile
 from datasets import load_dataset
+from loguru import logger
 
 DEFAULT_DATASET_DIR = "datasets"
 
@@ -40,7 +41,7 @@ def __init__(
                 dataset_url, file_path, self.dataset_path
             )
         else:
-            print(
+            logger.info(
                 f"[MiraclOpenDataSet] Dataset file already exists at {self.dataset_path}."
             )
         if not os.path.exists(self.corpus_path):
@@ -48,7 +49,7 @@ def __init__(
             file_path = os.path.join(DEFAULT_DATASET_DIR, "miracl-corpus.tar.gz")
             self._extract_and_download_dataset(dataset_url, file_path, self.corpus_path)
         else:
-            print(
+            logger.info(
                 f"[MiraclOpenDataSet] Corpus file already exists at {self.corpus_path}."
             )
@@ -57,22 +58,24 @@ def _extract_and_download_dataset(self, url, file_path, extract_path):
         if not os.path.exists(file_path_dir):
             os.makedirs(file_path_dir)
         with urllib.request.urlopen(url) as response, open(file_path, "wb") as out_file:
-            print(f"[MiraclOpenDataSet] Start downloading file {file_path} from {url}.")
+            logger.info(
+                f"[MiraclOpenDataSet] Start downloading file {file_path} from {url}."
+            )
             data = response.read()
             out_file.write(data)
-            print("[MiraclOpenDataSet] Finish downloading.")
+            logger.info("[MiraclOpenDataSet] Finish downloading.")
         if not os.path.exists(extract_path):
             os.makedirs(extract_path)
         with tarfile.open(file_path, "r:gz") as tar:
-            print(
+            logger.info(
                 f"[MiraclOpenDataSet] Start extracting file {file_path} to {extract_path}."
             )
             tar.extractall(path=extract_path)
-            print("[MiraclOpenDataSet] Finish extracting.")
+            logger.info("[MiraclOpenDataSet] Finish extracting.")
 
     def load_qrels(self, type: str):
         file = f"{self.dataset_path}/miracl/miracl-v1.0-{self.lang}/qrels/qrels.miracl-v1.0-{self.lang}-{type}.tsv"
-        print(
+        logger.info(
             f"[MiraclOpenDataSet] Loading qrels for MiraclDataSet with type {type} from {file}..."
         )
         qrels = defaultdict(dict)
@@ -82,14 +85,14 @@ def load_qrels(self, type: str):
             qid, _, docid, rel = line.strip().split("\t")
             qrels[qid][docid] = int(rel)
             docids.add(docid)
-        print(
+        logger.info(
             f"[MiraclOpenDataSet] Loaded qrels {len(qrels)}, docids {len(docids)} with type {type}"
         )
         return qrels, docids
 
     def load_topic(self, type: str):
         file = f"{self.dataset_path}/miracl/miracl-v1.0-{self.lang}/topics/topics.miracl-v1.0-{self.lang}-{type}.tsv"
-        print(
+        logger.info(
             f"[MiraclOpenDataSet] Loading topic for MiraclDataSet with type {type} from {file}..."
         )
         qid2topic = {}
@@ -97,7 +100,9 @@ def load_topic(self, type: str):
         for line in f:
             qid, topic = line.strip().split("\t")
             qid2topic[qid] = topic
-        print(f"[MiraclOpenDataSet] Loaded qid2topic {len(qid2topic)} with type {type}")
+        logger.info(
+            f"[MiraclOpenDataSet] Loaded qid2topic {len(qid2topic)} with type {type}"
+        )
         return qid2topic
 
     def load_related_corpus(self):
@@ -120,11 +125,11 @@ def load_related_corpus(self):
                     )
                 )
                 docid2doc[json_obj["docid"]] = json_obj["text"]
-        print(
+        logger.info(
             f"[MiraclOpenDataSet] Loaded nodes {len(nodes)} from file_path {file_path}"
         )
 
-        print(f"[MiraclOpenDataSet] Loaded all nodes {len(nodes)}")
+        logger.info(f"[MiraclOpenDataSet] Loaded all nodes {len(nodes)}")
         return nodes, docid2doc
 
     def load_related_corpus_for_type(self, type: str):
@@ -149,11 +154,13 @@ def load_related_corpus_for_type(self, type: str):
                     )
                 )
                 docid2doc[json_obj["docid"]] = json_obj["text"]
-            print(
+            logger.info(
                 f"[MiraclOpenDataSet] Loaded nodes {len(nodes)} with type {type} from file_path {file_path}"
             )
 
-        print(f"[MiraclOpenDataSet] Loaded all nodes {len(nodes)} with type {type}")
+        logger.info(
+            f"[MiraclOpenDataSet] Loaded all nodes {len(nodes)} with type {type}"
+        )
         return nodes, docid2doc
 
 
@@ -172,7 +179,7 @@ def __init__(self, dataset_path: str = None, corpus_path: str = None):
                 dataset_url, file_path, self.dataset_path
             )
         else:
-            print(
+            logger.info(
                 f"[DuRetrievalDataSet] Dataset file already exists at {self.dataset_path}."
             )
         if not os.path.exists(self.corpus_path):
@@ -180,7 +187,7 @@ def __init__(self, dataset_path: str = None, corpus_path: str = None):
             file_path = os.path.join(DEFAULT_DATASET_DIR, "DuRetrieval.tar.gz")
             self._extract_and_download_dataset(dataset_url, file_path, self.corpus_path)
         else:
-            print(
+            logger.info(
                 f"[DuRetrievalDataSet] Corpus file already exists at {self.corpus_path}."
             )
@@ -189,23 +196,23 @@ def _extract_and_download_dataset(self, url, file_path, extract_path):
         if not os.path.exists(file_path_dir):
             os.makedirs(file_path_dir)
         with urllib.request.urlopen(url) as response, open(file_path, "wb") as out_file:
-            print(
+            logger.info(
                 f"[DuRetrievalDataSet] Start downloading file {file_path} from {url}."
             )
             data = response.read()
             out_file.write(data)
-            print("[DuRetrievalDataSet] Finish downloading.")
+            logger.info("[DuRetrievalDataSet] Finish downloading.")
         if not os.path.exists(extract_path):
             os.makedirs(extract_path)
         with tarfile.open(file_path, "r:gz") as tar:
-            print(
+            logger.info(
                 f"[DuRetrievalDataSet] Start extracting file {file_path} to {extract_path}."
             )
             tar.extractall(path=extract_path)
-            print("[DuRetrievalDataSet] Finish extracting.")
+            logger.info("[DuRetrievalDataSet] Finish extracting.")
 
     def load_qrels(self, type: str = "dev"):
-        print(
+        logger.info(
             f"[DuRetrievalDataSet] Loading qrels for DuRetrievalDataSet with type {type} from {self.dataset_path}..."
         )
         qrels_path = f"{self.dataset_path}/DuRetrieval-qrels"
@@ -216,7 +223,7 @@ def load_qrels(self, type: str = "dev"):
             docid = sample["pid"]
             rel = sample["score"]
             qrels[qid][docid] = int(rel)
-        print(f"[DuRetrievalDataSet] Loaded qrels {len(qrels)} with type {type}")
+        logger.info(f"[DuRetrievalDataSet] Loaded qrels {len(qrels)} with type {type}")
         return qrels
 
     def load_related_corpus(self):
@@ -234,15 +241,15 @@ def load_related_corpus(self):
                 )
             )
             docid2doc[sample["id"]] = sample["text"]
-        print(
+        logger.info(
             f"[DuRetrievalDataSet] Loaded nodes {len(nodes)} from file_path {self.corpus_path}"
         )
         for sample in du_dataset["queries"]:
             qid2query[sample["id"]] = sample["text"]
-        print(
+        logger.info(
             f"[DuRetrievalDataSet] Loaded queries {len(nodes)} from file_path {self.corpus_path}"
         )
-        print(f"[DuRetrievalDataSet] Loaded all nodes {len(nodes)}")
+        logger.info(f"[DuRetrievalDataSet] Loaded all nodes {len(nodes)}")
         return nodes, docid2doc, qid2query
diff --git a/src/pai_rag/evaluation/dataset/rag_eval_dataset.py b/src/pai_rag/evaluation/dataset/rag_eval_dataset.py
index 8d8d5330..68a97619 100644
--- a/src/pai_rag/evaluation/dataset/rag_eval_dataset.py
+++ b/src/pai_rag/evaluation/dataset/rag_eval_dataset.py
@@ -4,6 +4,7 @@
 from llama_index.core.bridge.pydantic import BaseModel
 from pai_rag.evaluation.dataset.rag_qca_dataset import RagQcaSample
 from llama_index.core.llama_dataset import CreatedBy
+from loguru import logger
 
 
 class EvaluationSample(RagQcaSample):
@@ -100,7 +101,7 @@ def save_json(self, path: str) -> None:
             }
             json.dump(data, f, indent=4, ensure_ascii=False)
 
-        print(f"Saved dataset to {path}.")
+        logger.info(f"Saved dataset to {path}.")
 
     @classmethod
     def from_json(cls, path: str) -> "PaiRagEvalDataset":
diff --git a/src/pai_rag/evaluation/dataset/rag_qca_dataset.py b/src/pai_rag/evaluation/dataset/rag_qca_dataset.py
index 5af59983..651d5f55 100644
--- a/src/pai_rag/evaluation/dataset/rag_qca_dataset.py
+++ b/src/pai_rag/evaluation/dataset/rag_qca_dataset.py
@@ -4,6 +4,7 @@
 from llama_index.core.llama_dataset import CreatedBy
 import json
 from llama_index.core.bridge.pydantic import BaseModel
+from loguru import logger
 
 
 class RagQcaSample(BaseLlamaDataExample):
@@ -91,7 +92,7 @@ def save_json(self, path: str) -> None:
             }
             json.dump(data, f, indent=4, ensure_ascii=False)
 
-        print(f"Saved PaiRagQcaDataset to {path}.")
+        logger.info(f"Saved PaiRagQcaDataset to {path}.")
 
     @classmethod
     def from_json(cls, path: str) -> "PaiRagQcaDataset":
diff --git a/src/pai_rag/evaluation/evaluator/base_evaluator.py b/src/pai_rag/evaluation/evaluator/base_evaluator.py
index 6a71762b..70d62a44 100644
--- a/src/pai_rag/evaluation/evaluator/base_evaluator.py
+++ b/src/pai_rag/evaluation/evaluator/base_evaluator.py
@@ -14,6 +14,7 @@
     CreatedByType,
 )
 from pai_rag.evaluation.dataset.rag_qca_dataset import PaiRagQcaDataset
+from loguru import logger
 
 
 class BaseEvaluator:
@@ -48,7 +49,7 @@ def load_qca_dataset(self) -> None:
         if os.path.exists(self.qca_dataset_path):
             rag_qca_dataset = PaiRagQcaDataset.from_json(self.qca_dataset_path)
             if rag_qca_dataset.labelled and rag_qca_dataset.predicted:
-                print(
+                logger.info(
                     f"Labelled QCA dataset already exists at {self.qca_dataset_path}."
                 )
                 return rag_qca_dataset
@@ -58,12 +59,14 @@ def load_qca_dataset(self) -> None:
                     "Please either label it or provide a new one."
                 )
         else:
-            print("No existing QCA dataset found. You can proceed to create a new one.")
+            logger.info(
+                "No existing QCA dataset found. You can proceed to create a new one."
+            )
             return None
 
     def load_evaluation_dataset(self) -> None:
         if os.path.exists(self.evaluation_dataset_path):
-            print(
+            logger.info(
                 f"A evaluation dataset already exists at {self.evaluation_dataset_path}."
             )
             evaluation_dataset = PaiRagEvalDataset.from_json(
@@ -71,7 +74,7 @@ def load_evaluation_dataset(self) -> None:
                 self.evaluation_dataset_path
             )
             return evaluation_dataset
         else:
-            print(
+            logger.info(
                 "No existing evaluation dataset found. You can proceed to create a new one."
             )
             return None
@@ -130,7 +133,7 @@ async def aevaluation(self, stage):
         evaluation_dataset = self.load_evaluation_dataset()
         qca_dataset = self.load_qca_dataset()
         if evaluation_dataset:
-            print(
+            logger.info(
                 f"A evaluation dataset already exists with status: [[{evaluation_dataset.status}]]"
             )
             _status = evaluation_dataset.status
@@ -139,7 +142,9 @@ async def aevaluation(self, stage):
         else:
             qca_dataset = evaluation_dataset
         if qca_dataset:
-            print(f"Starting to generate evaluation dataset for stage: [[{stage}]]...")
+            logger.info(
+                f"Starting to generate evaluation dataset for stage: [[{stage}]]..."
+            )
             eval_tasks = []
             for qca in qca_dataset.examples:
                 if stage == "retrieval":
diff --git a/src/pai_rag/evaluation/generator/rag_qca_generator.py b/src/pai_rag/evaluation/generator/rag_qca_generator.py
index 9e401fba..942f7b37 100644
--- a/src/pai_rag/evaluation/generator/rag_qca_generator.py
+++ b/src/pai_rag/evaluation/generator/rag_qca_generator.py
@@ -18,7 +18,7 @@
 from pai_rag.integrations.synthesizer.pai_synthesizer import PaiQueryBundle
 
 import os
-import logging
+from loguru import logger
 from pai_rag.integrations.query_engine.pai_retriever_query_engine import (
     PaiRetrieverQueryEngine,
 )
@@ -27,9 +27,6 @@
 from llama_index.core.multi_modal_llms.generic_utils import load_image_urls
 
 
-logger = logging.getLogger(__name__)
-
-
 class RagQcaGenerator:
     def __init__(
         self,
@@ -62,23 +59,27 @@ def __init__(
     def load_qca_dataset(self) -> None:
         if os.path.exists(self.qca_dataset_path):
             rag_qca_dataset = PaiRagQcaDataset.from_json(self.qca_dataset_path)
-            print(
+            logger.info(
                 f"A RAG QCA dataset already exists at {self.qca_dataset_path} with status: [labelled: {rag_qca_dataset.labelled}, predicted: {rag_qca_dataset.predicted}]."
             )
             return rag_qca_dataset
         else:
-            print("No existing QCA dataset found. You can proceed to create a new one.")
+            logger.info(
+                "No existing QCA dataset found. You can proceed to create a new one."
+            )
             return None
 
     async def agenerate_qca_dataset(self, stage):
         rag_qca_dataset = self.load_qca_dataset()
         if rag_qca_dataset and rag_qca_dataset.labelled:
             if stage == "labelled":
-                print("Labelled QCA dataset already exists. Skipping labelled stage.")
+                logger.info(
+                    "Labelled QCA dataset already exists. Skipping labelled stage."
+                )
                 return rag_qca_dataset.examples
             elif stage == "predicted":
                 if rag_qca_dataset.predicted:
-                    print(
+                    logger.info(
                         "Predicted QCA dataset already exists. Skipping predicted stage."
                     )
                     return rag_qca_dataset.examples
@@ -190,7 +191,7 @@ async def agenerate_labelled_qca_sample(self, node):
     async def agenerate_labelled_qca_dataset(
         self,
     ):
-        print("Starting to generate QCA dataset for [[labelled]].")
+        logger.info("Starting to generate QCA dataset for [[labelled]].")
         docs = self._vector_index._docstore.docs
         nodes = list(docs.values())
         tasks = []
@@ -251,7 +252,7 @@ async def agenerate_predicted_qca_sample(self, qca_sample):
         return qca_sample
 
     async def agenerate_predicted_qca_dataset(self, rag_qca_dataset):
-        print("Starting to generate QCA dataset for [[predicted]].")
+        logger.info("Starting to generate QCA dataset for [[predicted]].")
         tasks = []
         for qca_sample in rag_qca_dataset.examples:
             if self.enable_multi_modal:
diff --git a/src/pai_rag/evaluation/run_evaluation_experiments.py b/src/pai_rag/evaluation/run_evaluation_experiments.py
index 4c43f902..47d35371 100644
--- a/src/pai_rag/evaluation/run_evaluation_experiments.py
+++ b/src/pai_rag/evaluation/run_evaluation_experiments.py
@@ -1,6 +1,6 @@
 import yaml
 import click
-import logging
+from loguru import logger
 import time
 import json
 import hashlib
@@ -27,7 +27,7 @@ def calculate_md5_from_json(data):
 
 def run_experiment(exp_params):
     exp_name = exp_params["name"]
-    logging.info(f"Running experiment with name={exp_name}, exp_params={exp_params}")
+    logger.info(f"Running experiment with name={exp_name}, exp_params={exp_params}")
     try:
         # 运行实验并获取结果
         result = run_evaluation_pipeline(
@@ -37,9 +37,9 @@ def run_experiment(exp_params):
             eval_model_source=exp_params["eval_model_source"],
             eval_model_name=exp_params["eval_model_name"],
         )
-        logging.info(f"Finished experiment with name={exp_name}")
+        logger.info(f"Finished experiment with name={exp_name}")
     except Exception as e:
-        logging.error(f"Error running experiment {exp_name}: {e}")
+        logger.error(f"Error running experiment {exp_name}: {e}")
 
     return {"name": exp_params["name"], "parameters": exp_params, "result": result}
 
@@ -63,4 +63,4 @@ def run(input_exp_config=None, output_path=None):
     with open(output_path, "w") as result_file:
         json.dump(results, result_file, ensure_ascii=False, indent=4)
 
-    logging.info(f"Results saved to {output_path}")
+    logger.info(f"Results saved to {output_path}")
diff --git a/src/pai_rag/evaluation/run_evaluation_pipeline.py b/src/pai_rag/evaluation/run_evaluation_pipeline.py
index fd799d39..a7560f40 100644
--- a/src/pai_rag/evaluation/run_evaluation_pipeline.py
+++ b/src/pai_rag/evaluation/run_evaluation_pipeline.py
@@ -14,14 +14,12 @@
 )
 from pai_rag.integrations.llms.pai.pai_llm import PaiLlm
 from pai_rag.evaluation.evaluator.base_evaluator import BaseEvaluator
-import logging
+from loguru import logger
 from pai_rag.integrations.llms.pai.llm_config import parse_llm_config
 from pai_rag.integrations.llms.pai.llm_utils import create_llm, create_multi_modal_llm
 
-logger = logging.getLogger(__name__)
-
 _BASE_DIR = Path(__file__).parent.parent
 DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(
     _BASE_DIR, "evaluation/settings_eval_for_text.toml"
@@ -36,7 +34,7 @@ def _create_components(
     mode = "image" if config.retriever.search_image else "text"
     config.synthesizer.use_multimodal_llm = True if mode == "image" else False
 
-    print(f"Creating RAG evaluation components for mode: {mode}...")
+    logger.info(f"Creating RAG evaluation components for mode: {mode}...")
 
     config.index.vector_store.persist_path = (
         f"{config.index.vector_store.persist_path}__{exp_name}"
@@ -105,5 +103,7 @@ def run_evaluation_pipeline(
     _ = asyncio.run(qca_generator.agenerate_qca_dataset(stage="predicted"))
     retrieval_result = asyncio.run(evaluator.aevaluation(stage="retrieval"))
     response_result = asyncio.run(evaluator.aevaluation(stage="response"))
-    print("retrieval_result", retrieval_result, "response_result", response_result)
+    logger.info(
+        f"retrieval_result: {retrieval_result}, response_result: {response_result}"
+    )
     return {"retrieval": retrieval_result, "response": response_result}
diff --git a/src/pai_rag/integrations/agent/pai/base_tool.py b/src/pai_rag/integrations/agent/pai/base_tool.py
index 44b9e30c..3cb042f6 100644
--- a/src/pai_rag/integrations/agent/pai/base_tool.py
+++ b/src/pai_rag/integrations/agent/pai/base_tool.py
@@ -1,9 +1,7 @@
 from typing import Dict, List, Self
 from pydantic import BaseModel, model_validator
 from openai.types.beta.function_tool import FunctionTool
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 DEFAULT_TOOL_DEFINITION_FILE = "./example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.json"
 DEFAULT_PYTHONSCRIPT_FILE = "./example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py"
diff --git a/src/pai_rag/integrations/agent/pai/pai_agent.py b/src/pai_rag/integrations/agent/pai/pai_agent.py
index 1903e047..d81ce8f9 100644
--- a/src/pai_rag/integrations/agent/pai/pai_agent.py
+++ b/src/pai_rag/integrations/agent/pai/pai_agent.py
@@ -22,9 +22,7 @@
 from pai_rag.integrations.agent.pai.utils.tool_utils import (
     get_customized_tools,
 )
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 DEFAULT_MAX_FUNCTION_CALLS = 10
diff --git a/src/pai_rag/integrations/agent/pai/utils/tool_utils.py b/src/pai_rag/integrations/agent/pai/utils/tool_utils.py
index 1e70d0e5..b786af82 100644
--- a/src/pai_rag/integrations/agent/pai/utils/tool_utils.py
+++ b/src/pai_rag/integrations/agent/pai/utils/tool_utils.py
@@ -8,9 +8,7 @@
     DEFAULT_CALCULATE_SUBTRACT,
     DEFAULT_GET_DATETIME_TOOL,
 )
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 def get_time_tools():
@@ -128,7 +126,7 @@ def get_customized_tools(agent_definition: PaiAgentDefinition):
             description=api_tool.description,
             fn=api_func,
         )
-        print(f"Loaded api tool definition {tool.metadata}")
+        logger.info(f"Loaded api tool definition {tool.metadata}")
         tools.append(tool)
 
     for func_tool in agent_definition.function_tools:
@@ -137,6 +135,6 @@ def get_customized_tools(agent_definition: PaiAgentDefinition):
             description=func_tool.function.description,
             fn=globals()[func_tool.function.name],
         )
-        print(f"Loaded function tool definition {tool.metadata}")
+        logger.info(f"Loaded function tool definition {tool.metadata}")
         tools.append(tool)
     return tools
diff --git a/src/pai_rag/integrations/chat_store/pai/pai_chat_store.py b/src/pai_rag/integrations/chat_store/pai/pai_chat_store.py
index 80d41f31..5b861caa 100644
--- a/src/pai_rag/integrations/chat_store/pai/pai_chat_store.py
+++ b/src/pai_rag/integrations/chat_store/pai/pai_chat_store.py
@@ -7,9 +7,7 @@
 from llama_index.core.storage.chat_store.base import BaseChatStore
 from pydantic import BaseModel
 from llama_index.core.bridge.pydantic import PrivateAttr
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 CHAT_STORE_FILE = "chat_store.json"
 DEFAULT_LOCAL_STORAGE_PATH = "./localdata/storage/"
diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py b/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py
index 85da44b0..0e201457 100644
--- a/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py
+++ b/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py
@@ -1,4 +1,3 @@
-import logging
 from typing import Any, List, Generator, Optional, Sequence, cast, AsyncGenerator
 
 from llama_index.core.callbacks.base import CallbackManager
@@ -23,8 +22,7 @@
 )
 from llama_index.core.callbacks.schema import CBEventType, EventPayload
 import llama_index.core.instrumentation as instrument
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 dispatcher = instrument.get_dispatcher(__name__)
diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py
index 581988dc..cd109b56 100644
--- a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py
+++ b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py
@@ -1,4 +1,3 @@
-import logging
 from typing import Optional, List
 
 from llama_index.core.prompts import PromptTemplate
@@ -24,7 +23,6 @@
     DataAnalysisSynthesizer,
 )
 
-logger = logging.getLogger(__name__)
 dispatcher = instrument.get_dispatcher(__name__)
 
 DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate(
diff --git a/src/pai_rag/integrations/data_analysis/nl2pandas_retriever.py b/src/pai_rag/integrations/data_analysis/nl2pandas_retriever.py
index 9d912211..8c5f0ced 100644
--- a/src/pai_rag/integrations/data_analysis/nl2pandas_retriever.py
+++ b/src/pai_rag/integrations/data_analysis/nl2pandas_retriever.py
@@ -1,5 +1,5 @@
 import glob
-import logging
+from loguru import logger
 from typing import Any, Dict, List, Optional
 import os
 import pandas as pd
@@ -17,9 +17,6 @@
 from pai_rag.integrations.data_analysis.data_analysis_config import PandasAnalysisConfig
 
-logger = logging.getLogger(__name__)
-
-
 DEFAULT_INSTRUCTION_STR = (
     "1. Convert the query to executable Python code using Pandas.\n"
     "2. The final line of code should be a Python expression that can be called with the `eval()` function.\n"
diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
index 8ac1ecdd..cc205166 100644
--- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
+++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
@@ -8,7 +8,7 @@
 """
 
 import functools
-import logging
+from loguru import logger
 import os
 import re
 import signal
@@ -48,8 +48,6 @@
     SqliteAnalysisConfig,
 )
 
-logger = logging.getLogger(__name__)
-
 DEFAULT_TEXT_TO_SQL_TMPL = PromptTemplate(
     "Given an input question, first create a syntactically correct {dialect} "
     "query to run, then look at the results of the query and return the answer. "
diff --git a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py
index e88ecc79..9c2faa84 100644
--- a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py
+++ b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py
@@ -12,9 +12,7 @@
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from pai_rag.integrations.embeddings.clip.cnclip_embedding import CnClipEmbedding
 import os
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 def create_embedding(embed_config: PaiBaseEmbeddingConfig):
diff --git a/src/pai_rag/integrations/embeddings/pai/pai_embedding.py b/src/pai_rag/integrations/embeddings/pai/pai_embedding.py
index fe70525d..30270302 100644
--- a/src/pai_rag/integrations/embeddings/pai/pai_embedding.py
+++ b/src/pai_rag/integrations/embeddings/pai/pai_embedding.py
@@ -4,15 +4,13 @@
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.schema import BaseNode, MetadataMode, ImageNode
 
-import logging
+from loguru import logger
 
 from pai_rag.integrations.embeddings.pai.embedding_utils import create_embedding
 from pai_rag.integrations.embeddings.pai.pai_embedding_config import (
     PaiBaseEmbeddingConfig,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class PaiEmbedding(BaseEmbedding):
     """PAI embedding model"""
diff --git a/src/pai_rag/integrations/embeddings/pai/pai_multimodal_embedding.py b/src/pai_rag/integrations/embeddings/pai/pai_multimodal_embedding.py
index 1ce0cbbc..2cc36bf0 100644
--- a/src/pai_rag/integrations/embeddings/pai/pai_multimodal_embedding.py
+++ b/src/pai_rag/integrations/embeddings/pai/pai_multimodal_embedding.py
@@ -5,14 +5,12 @@
 from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.schema import ImageType, BaseNode, ImageNode
-import logging
+from loguru import logger
 from pai_rag.integrations.embeddings.pai.pai_embedding_config import (
     PaiBaseEmbeddingConfig,
 )
 from pai_rag.integrations.embeddings.pai.embedding_utils import create_embedding
 
-logger = logging.getLogger(__name__)
-
 
 class PaiMultiModalEmbedding(MultiModalEmbedding):
     """PAI multimodal embedding model"""
diff --git a/src/pai_rag/integrations/extractors/html_qa_extractor.py b/src/pai_rag/integrations/extractors/html_qa_extractor.py
index b339f593..b93eb24e 100644
--- a/src/pai_rag/integrations/extractors/html_qa_extractor.py
+++ b/src/pai_rag/integrations/extractors/html_qa_extractor.py
@@ -6,12 +6,10 @@
 from llama_index.core.prompts import PromptTemplate
 from llama_index.core.async_utils import run_jobs
 import re
-import logging
+from loguru import logger
 
 CHINESE_PUNKTUATION = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·.!?。。"
 
-logger = logging.getLogger(__name__)
-
 
 class HtmlQAExtractor(BaseExtractor):
     llm: Optional[LLM] = Field(description="The LLM to use for generation.")
@@ -147,9 +145,9 @@ def _extract_qa_dict(self, text):
         Q_index = [obj.span() for obj in list(partten_question.finditer(text))]
         A_index = [obj.span() for obj in list(partten_answer.finditer(text))]
         if len(Q_index) != len(A_index) or len(Q_index) == 0 or len(A_index) == 0:
-            print("text: ", text)
-            print("Q_index: ", Q_index)
-            print("A_index: ", A_index)
+            logger.debug(f"text: {text}")
+            logger.debug(f"Q_index: {Q_index}")
+            logger.debug(f"A_index: {A_index}")
             raise IndexError("[To Dict Error]提取出的问题和答案的数量不一致")
         QA_i = 0
         QA_list = [[], []]
diff --git a/src/pai_rag/integrations/extractors/text_qa_extractor.py b/src/pai_rag/integrations/extractors/text_qa_extractor.py
index 2ece09a4..e7f4f83a 100644
--- a/src/pai_rag/integrations/extractors/text_qa_extractor.py
+++ b/src/pai_rag/integrations/extractors/text_qa_extractor.py
@@ -7,9 +7,7 @@
 from llama_index.core.async_utils import run_jobs
 import os
 import re
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class TextQAExtractor(BaseExtractor):
@@ -95,9 +93,9 @@ def _extract_qa_dict(self, text):
         Q_index = [obj.span() for obj in list(partten_question.finditer(text))]
         A_index = [obj.span() for obj in list(partten_answer.finditer(text))]
         if len(Q_index) != len(A_index) or len(Q_index) == 0 or len(A_index) == 0:
-            print("text: ", text)
-            print("Q_index: ", Q_index)
-            print("A_index: ", A_index)
+            logger.debug(f"text: {text}")
+            logger.debug(f"Q_index: {Q_index}")
+            logger.debug(f"A_index: {A_index}")
             if len(Q_index) == len(A_index) + 1:
                 # 截断的情况
                 Q_index = Q_index[:-1]
@@ -114,7 +112,7 @@ def _extract_qa_dict(self, text):
             QA_list[1].append(text[A_index[-1][1] :].strip())
         QA_dict = {QA_list[0][i]: QA_list[1][i] for i in range(len(Q_index))}
         # {Q1: A1, Q2: A2}
-        print(QA_dict)
+        logger.debug(QA_dict)
         return QA_dict
 
     def _get_prompt_template(self):
diff --git a/src/pai_rag/integrations/index/pai/local/local_bm25_index.py b/src/pai_rag/integrations/index/pai/local/local_bm25_index.py
index 3f1fcceb..11ee63fd 100644
--- a/src/pai_rag/integrations/index/pai/local/local_bm25_index.py
+++ b/src/pai_rag/integrations/index/pai/local/local_bm25_index.py
@@ -1,4 +1,3 @@
-import logging
 import os
 import pickle
 import json
@@ -8,7 +7,7 @@
 from pai_rag.utils.tokenizer import jieba_tokenizer
 from scipy.sparse import csr_matrix
 
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 MAX_DOC_LIMIT = 9000000000  # 9B
diff --git a/src/pai_rag/integrations/index/pai/local/local_bm25_retriever.py b/src/pai_rag/integrations/index/pai/local/local_bm25_retriever.py
index 10d38416..7f936a14 100644
--- a/src/pai_rag/integrations/index/pai/local/local_bm25_retriever.py
+++ b/src/pai_rag/integrations/index/pai/local/local_bm25_retriever.py
@@ -1,4 +1,3 @@
-import logging
 from typing import List, Optional
 
 from llama_index.core.base.base_retriever import BaseRetriever
@@ -7,8 +6,6 @@
 from llama_index.core.schema import IndexNode, NodeWithScore, QueryBundle
 from pai_rag.integrations.index.pai.local.local_bm25_index import LocalBm25IndexStore
 
-logger = logging.getLogger(__name__)
-
 
 class LocalBM25Retriever(BaseRetriever):
     def __init__(
diff --git a/src/pai_rag/integrations/index/pai/multimodal/multimodal_index.py b/src/pai_rag/integrations/index/pai/multimodal/multimodal_index.py
index 3ba6297b..0c9c2856 100644
--- a/src/pai_rag/integrations/index/pai/multimodal/multimodal_index.py
+++ b/src/pai_rag/integrations/index/pai/multimodal/multimodal_index.py
@@ -4,7 +4,6 @@
 """
 
-import logging
 from typing import Any, List, Optional, Sequence, cast
 
 from llama_index.core.base.embeddings.base import BaseEmbedding
@@ -37,8 +36,7 @@
 from pai_rag.integrations.index.pai.multimodal.multimodal_retriever import (
     PaiMultiModalVectorIndexRetriever,
 )
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class PaiMultiModalVectorStoreIndex(VectorStoreIndex):
@@ -217,7 +215,6 @@ def _get_node_with_embedding(
             else:
                 result = node.copy()
 
-            # print("===", is_image, node.node_id, len(result.embedding))
             results.append(result)
 
         return results
diff --git a/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py b/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py
index 20dafc28..05da7dac 100644
--- a/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py
+++ b/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py
@@ -35,11 +35,10 @@
     VectorStoreQueryResult,
 )
 from pai_rag.integrations.index.pai.local.local_bm25_index import LocalBm25IndexStore
-import logging
+from loguru import logger
 import llama_index.core.instrumentation as instrument
 
 dispatcher = instrument.get_dispatcher(__name__)
-logger = logging.getLogger(__name__)
 
 DEFAULT_IMAGE_STORE = "image"
 
@@ -291,17 +290,11 @@ def _fusion_nodes(
         keyword_nodes: List[NodeWithScore],
         similarity_top_k: int,
     ):
-        # print("Fusion weights: ", self._hybrid_fusion_weights)
-
         for node_with_score in vector_nodes:
-            # print("vector score 0", node_with_score.node_id, node_with_score.score)
             node_with_score.score *= self._hybrid_fusion_weights[0]
-            # print("vector score 1", node_with_score.node_id, node_with_score.score)
 
         for node_with_score in keyword_nodes:
-            # print("keyword score 0", node_with_score.node_id, node_with_score.score)
             node_with_score.score *= self._hybrid_fusion_weights[1]
-            # print("keyword score 1", node_with_score.node_id, node_with_score.score)
 
         # Use a dict to de-duplicate nodes
         all_nodes: Dict[str, NodeWithScore] = {}
@@ -495,8 +488,8 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         task_results = await asyncio.gather(*tasks)
 
         text_nodes, image_nodes = task_results[0], task_results[1]
-        logger.info(f"Retrieved text nodes: {text_nodes}")
-        logger.info(f"Retrieved image nodes: {image_nodes}")
+        logger.debug(f"Retrieved text nodes: {text_nodes}")
+        logger.debug(f"Retrieved image nodes: {image_nodes}")
         seen_images = set([node.node.image_url for node in image_nodes])
 
         # 从文本中召回图片
diff --git a/src/pai_rag/integrations/index/pai/pai_vector_index.py b/src/pai_rag/integrations/index/pai/pai_vector_index.py
index 4773c122..37adfc73 100644
--- a/src/pai_rag/integrations/index/pai/pai_vector_index.py
+++ b/src/pai_rag/integrations/index/pai/pai_vector_index.py
@@ -1,4 +1,3 @@
-import logging
 import os
 from typing import Coroutine, List, Any, Sequence
 from llama_index.core.base.base_query_engine import BaseQueryEngine
@@ -32,7 +31,7 @@
 from llama_index.core.vector_stores.types import VectorStoreQueryMode
 from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
 
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 def retrieval_type_to_search_mode(retrieval_type: VectorIndexRetrievalType):
diff --git a/src/pai_rag/integrations/index/pai/utils/index_utils.py b/src/pai_rag/integrations/index/pai/utils/index_utils.py
index 1bb4f5b9..cf7b58cd 100644
--- a/src/pai_rag/integrations/index/pai/utils/index_utils.py
+++ b/src/pai_rag/integrations/index/pai/utils/index_utils.py
@@ -1,4 +1,3 @@
-import logging
 from typing import Any, List, Optional, Sequence
 
 from llama_index.core.indices.base import BaseIndex
@@ -17,8 +16,6 @@
     IndexStructType.MULTIMODAL_VECTOR_STORE
 ] = PaiMultiModalVectorStoreIndex
 
-logger = logging.getLogger(__name__)
-
 
 def load_index_from_storage(
     storage_context: StorageContext,
diff --git a/src/pai_rag/integrations/index/pai/utils/sparse_embed_function.py b/src/pai_rag/integrations/index/pai/utils/sparse_embed_function.py
index ae12a7ca..60540651 100644
--- a/src/pai_rag/integrations/index/pai/utils/sparse_embed_function.py
+++ b/src/pai_rag/integrations/index/pai/utils/sparse_embed_function.py
@@ -1,9 +1,8 @@
 import os
-import logging
 from typing import List
 from pai_rag.utils.constants import DEFAULT_MODEL_DIR
 
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 MODEL_NAME = "bge-m3"
diff --git a/src/pai_rag/integrations/index/pai/utils/vector_store_utils.py b/src/pai_rag/integrations/index/pai/utils/vector_store_utils.py
index 342b120e..5a647836 100644
--- a/src/pai_rag/integrations/index/pai/utils/vector_store_utils.py
+++ b/src/pai_rag/integrations/index/pai/utils/vector_store_utils.py
@@ -1,6 +1,5 @@
 import hashlib
 import faiss
-import logging
 import os
 import json
 from llama_index.core.vector_stores.simple import DEFAULT_VECTOR_STORE, NAMESPACE_SEP
@@ -31,8 +30,6 @@
 )
 
 
-logger = logging.getLogger(__name__)
-
 DEFAULT_PERSIST_IMAGE_NAMESPACE = "image"
diff --git a/src/pai_rag/integrations/llms/pai/llm_utils.py b/src/pai_rag/integrations/llms/pai/llm_utils.py
index 966b34a2..70e1c4b8 100644
--- a/src/pai_rag/integrations/llms/pai/llm_utils.py
+++ b/src/pai_rag/integrations/llms/pai/llm_utils.py
@@ -1,4 +1,3 @@
-import logging
 import os
 from urllib.parse import urljoin
 from llama_index.llms.openai import OpenAI
@@ -14,7 +13,7 @@
     OpenAIAlikeMultiModal,
 )
 
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 def create_llm(llm_config: PaiBaseLlmConfig):
diff --git a/src/pai_rag/integrations/llms/pai/pai_llm.py b/src/pai_rag/integrations/llms/pai/pai_llm.py
index bbdb65ca..bfd41154 100644
--- a/src/pai_rag/integrations/llms/pai/pai_llm.py
+++ b/src/pai_rag/integrations/llms/pai/pai_llm.py
@@ -23,9 +23,6 @@
     DEFAULT_MAX_TOKENS,
     PaiBaseLlmConfig,
 )
-import logging
-
-logger = logging.getLogger(__name__)
 
 
 class PaiLlm(OpenAILike):
diff --git a/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py b/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
index 1a508dce..1d5f816c 100644
--- a/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
+++ b/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
@@ -23,9 +23,7 @@
     DEFAULT_BREAKPOINT,
     DEFAULT_BUFFER_SIZE,
 )
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class NodeParserConfig(BaseModel):
diff --git a/src/pai_rag/integrations/nodes/raptor_nodes_enhance.py b/src/pai_rag/integrations/nodes/raptor_nodes_enhance.py
index 35a1b690..149a5dfe 100644
--- a/src/pai_rag/integrations/nodes/raptor_nodes_enhance.py
+++ b/src/pai_rag/integrations/nodes/raptor_nodes_enhance.py
@@ -16,9 +16,7 @@
 from pai_rag.integrations.nodes.raptor_clusters import get_clusters
 from pai_rag.utils.prompt_template import DEFAULT_SUMMARY_PROMPT
 
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class RaptorProcessor(TransformComponent):
diff --git a/src/pai_rag/integrations/postprocessor/pai/pai_postprocessor.py b/src/pai_rag/integrations/postprocessor/pai/pai_postprocessor.py
index dca7233c..0f40abbf 100644
--- a/src/pai_rag/integrations/postprocessor/pai/pai_postprocessor.py
+++ b/src/pai_rag/integrations/postprocessor/pai/pai_postprocessor.py
@@ -10,9 +10,7 @@
 from llama_index.core.postprocessor import SimilarityPostprocessor
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from llama_index.core.schema import NodeWithScore, QueryBundle
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 # rerank constants
 DEFAULT_RERANK_MODEL = "bge-reranker-base"
diff --git a/src/pai_rag/integrations/query_engine/pai_retriever_query_engine.py b/src/pai_rag/integrations/query_engine/pai_retriever_query_engine.py
index 1073441c..8c18d7d9 100644
--- a/src/pai_rag/integrations/query_engine/pai_retriever_query_engine.py
+++ b/src/pai_rag/integrations/query_engine/pai_retriever_query_engine.py
@@ -10,10 +10,8 @@
 from llama_index.core.callbacks.base import CallbackManager
 import llama_index.core.instrumentation as instrument
 from llama_index.core.response_synthesizers import BaseSynthesizer
-import logging
 
 dispatcher = instrument.get_dispatcher(__name__)
-logger = logging.getLogger(__name__)
 
 
 @dataclass
diff --git a/src/pai_rag/integrations/query_transform/pai_query_transform.py b/src/pai_rag/integrations/query_transform/pai_query_transform.py
index b9577051..b8a7232f 100644
--- a/src/pai_rag/integrations/query_transform/pai_query_transform.py
+++ b/src/pai_rag/integrations/query_transform/pai_query_transform.py
@@ -18,6 +18,7 @@
 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.prompts import PromptTemplate
 from pai_rag.utils.messages_utils import parse_chat_messages
+from loguru import logger
 
 DEFAULT_FUSION_NUM_QUERIES = 4
 
@@ -86,7 +87,7 @@ def _run(self, query_bundle: QueryBundle, metadata: Dict) -> List[QueryBundle]:
         queries = [q.strip() for q in queries if q.strip()]
         if self._verbose:
             queries_str = "\n".join(queries)
-            print(f"Generated queries:\n{queries_str}")
+            logger.info(f"Generated queries:\n{queries_str}")
 
         # The LLM often returns more queries than we asked for, so trim the list.
         return [
@@ -113,7 +114,7 @@ async def _arun(
         queries = [q.strip() for q in queries if q.strip()]
         if self._verbose:
             queries_str = "\n".join(queries)
-            print(f"Generated queries:\n{queries_str}")
+            logger.info(f"Generated queries:\n{queries_str}")
 
         # The LLM often returns more queries than we asked for, so trim the list.
         return [
diff --git a/src/pai_rag/integrations/readers/llama_parse_reader.py b/src/pai_rag/integrations/readers/llama_parse_reader.py
index 59deba79..87137072 100644
--- a/src/pai_rag/integrations/readers/llama_parse_reader.py
+++ b/src/pai_rag/integrations/readers/llama_parse_reader.py
@@ -1,7 +1,7 @@
 import nest_asyncio
 from llama_index.core import SimpleDirectoryReader
 from llama_index.core import Document
-import logging
+from loguru import logger
 import os
 import json
 import asyncio
@@ -21,8 +21,6 @@
 
 nest_asyncio.apply()
 
-logger = logging.getLogger(__name__)
-
 
 def is_default_fs(fs: fsspec.AbstractFileSystem) -> bool:
     return isinstance(fs, LocalFileSystem) and not fs.auto_mkdir
@@ -171,9 +169,8 @@ def load_file(
             if raise_on_error:
                 raise Exception("Error loading file") from e
             # otherwise, just skip the file and report the error
-            print(
-                f"Failed to load file {input_file} with error: {e}. Skipping...",
-                flush=True,
+            logger.exception(
+                f"Failed to load file {input_file} with error: {e}. Skipping..."
+            )
             return []
diff --git a/src/pai_rag/integrations/readers/markdown_reader.py b/src/pai_rag/integrations/readers/markdown_reader.py
index aeaa79bb..2dd12b68 100644
--- a/src/pai_rag/integrations/readers/markdown_reader.py
+++ b/src/pai_rag/integrations/readers/markdown_reader.py
@@ -9,9 +9,7 @@
 import re
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.schema import Document
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 REGEX_H1 = "=+"
 REGEX_H2 = "-+"
diff --git a/src/pai_rag/integrations/readers/pai/pai_data_reader.py b/src/pai_rag/integrations/readers/pai/pai_data_reader.py
index 49bcef7d..0348d194 100644
--- a/src/pai_rag/integrations/readers/pai/pai_data_reader.py
+++ b/src/pai_rag/integrations/readers/pai/pai_data_reader.py
@@ -13,9 +13,7 @@
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.readers import SimpleDirectoryReader
 from llama_index.core.schema import Document
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 COMMON_FILE_PATH_FODER_NAME = "__pairag__knowledgebase__"
 
@@ -104,7 +102,6 @@ def get_oss_files(oss_path: str, filter_pattern: str = None, oss_store: Any = No
     if not files:
         raise ValueError(f"No file found at OSS path '{oss_path}'.")
 
-    print(files)
     return files
diff --git a/src/pai_rag/integrations/readers/pai_csv_reader.py b/src/pai_rag/integrations/readers/pai_csv_reader.py
index 132dcc27..44178e13 100644
--- a/src/pai_rag/integrations/readers/pai_csv_reader.py
+++ b/src/pai_rag/integrations/readers/pai_csv_reader.py
@@ -12,9 +12,7 @@
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.schema import Document
 import chardet
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class PaiCSVReader(BaseReader):
diff --git a/src/pai_rag/integrations/readers/pai_docx_reader.py b/src/pai_rag/integrations/readers/pai_docx_reader.py
index 06eefb18..9375abfe 100644
--- a/src/pai_rag/integrations/readers/pai_docx_reader.py
+++ b/src/pai_rag/integrations/readers/pai_docx_reader.py
@@ -1,7 +1,6 @@
 """Docs parser.
""" -import logging from pathlib import Path from typing import Dict, List, Optional, Union, Any from llama_index.core.readers.base import BaseReader @@ -18,8 +17,9 @@ from PIL import Image import time from io import BytesIO +from loguru import logger + -logger = logging.getLogger(__name__) IMAGE_MAX_PIXELS = 512 * 512 @@ -265,5 +265,5 @@ def load( ) docs.append(doc) logger.info(f"processed doc file {file_path} without metadata") - print(f"[PaiDocxReader] successfully loaded {len(docs)} nodes.") + logger.info(f"[PaiDocxReader] successfully loaded {len(docs)} nodes.") return docs diff --git a/src/pai_rag/integrations/readers/pai_html_reader.py b/src/pai_rag/integrations/readers/pai_html_reader.py index 5ca4f8ed..0b70ea07 100644 --- a/src/pai_rag/integrations/readers/pai_html_reader.py +++ b/src/pai_rag/integrations/readers/pai_html_reader.py @@ -2,7 +2,6 @@ """ import html2text -import logging from bs4 import BeautifulSoup import requests from typing import Dict, List, Optional, Union, Any @@ -19,8 +18,8 @@ from PIL import Image from llama_index.core.readers.base import BaseReader from llama_index.core.schema import Document +from loguru import logger -logger = logging.getLogger(__name__) IMAGE_URL_PATTERN = ( r"!\[(?P.*?)\]\((https?://[^\s]+?[\s\w.-]*\.(jpg|jpeg|png|gif|bmp))\)" @@ -195,7 +194,7 @@ def convert_html_to_markdown(self, html_path): return markdown_content except Exception as e: - logger(e) + logger.exception(e) return None def load_data( @@ -238,5 +237,5 @@ def load( logger.info(f"processed html file {file_path} without metadata") doc = Document(text=md_content, extra_info=extra_info) docs.append(doc) - print(f"[PaiHtmlReader] successfully loaded {len(docs)} nodes.") + logger.info(f"[PaiHtmlReader] successfully loaded {len(docs)} nodes.") return docs diff --git a/src/pai_rag/integrations/readers/pai_image_reader.py b/src/pai_rag/integrations/readers/pai_image_reader.py index 7430e069..dec26308 100644 --- a/src/pai_rag/integrations/readers/pai_image_reader.py +++ b/src/pai_rag/integrations/readers/pai_image_reader.py @@ -10,9 +10,6 @@ import os from llama_index.core.readers.base import BaseReader from llama_index.core.schema import Document, ImageDocument -import logging - -logger = logging.getLogger(__name__) class PaiImageReader(BaseReader): diff --git a/src/pai_rag/integrations/readers/pai_pdf_reader.py b/src/pai_rag/integrations/readers/pai_pdf_reader.py index 20fc1262..2bfce9c2 100644 --- a/src/pai_rag/integrations/readers/pai_pdf_reader.py +++ b/src/pai_rag/integrations/readers/pai_pdf_reader.py @@ -23,15 +23,13 @@ import tempfile import re from PIL import Image - - -import logging import os import json +from loguru import logger + model_config.__use_inside_model__ = True -logger = logging.getLogger(__name__) IMAGE_MAX_PIXELS = 512 * 512 TABLE_SUMMARY_MAX_ROW_NUM = 5 @@ -188,7 +186,7 @@ def process_table(self, markdown_content, json_data): markdown_content, item["img_path"], ocr_content ) else: - print(f"警告:图片文件不存在 {img_path}") + logger.warning(f"警告:图片文件不存在 {img_path}") return markdown_content def post_process_multi_level_headings(self, json_data, md_content): @@ -298,7 +296,7 @@ def parse_pdf( elif parse_method == "ocr": pipe = OCRPipe(pdf_bytes, model_json, image_writer) else: - logger("unknown parse method, only auto, ocr, txt allowed") + logger.error("unknown parse method, only auto, ocr, txt allowed") exit(1) # 执行分类 @@ -309,12 +307,9 @@ def parse_pdf( if model_config.__use_inside_model__: pipe.pipe_analyze() # 解析 else: - logger("need model list input") + 
logger.error("need model list input") exit(1) - # Some dirty code from mineru modified log level to warning - logging.getLogger().setLevel(logging.INFO) - # 执行解析 pipe.pipe_parse() content_list = pipe.pipe_mk_uni_format(temp_file_path, drop_mode="none") @@ -328,7 +323,7 @@ def parse_pdf( return new_md_content except Exception as e: - logger(e) + logger.error(e) return None def load_data( @@ -377,5 +372,5 @@ def load( ) docs.append(doc) logger.info(f"processed pdf file {file_path} without metadata") - print(f"[PaiPDFReader] successfully loaded {len(docs)} nodes.") + logger.info(f"[PaiPDFReader] successfully loaded {len(docs)} nodes.") return docs diff --git a/src/pai_rag/integrations/search/bing_search.py b/src/pai_rag/integrations/search/bing_search.py index 6c3c1928..67874411 100644 --- a/src/pai_rag/integrations/search/bing_search.py +++ b/src/pai_rag/integrations/search/bing_search.py @@ -7,11 +7,10 @@ from llama_index.core.response_synthesizers import BaseSynthesizer from llama_index.core.schema import QueryBundle import faiss -import logging +from loguru import logger from pai_rag.integrations.search.bs4_reader import ParallelBeautifulSoupWebReader -logger = logging.getLogger(__name__) DEFAULT_ENDPOINT_BASE_URL = "https://api.bing.microsoft.com/v7.0/search" DEFAULT_SEARCH_COUNT = 10 diff --git a/src/pai_rag/integrations/search/bs4_reader.py b/src/pai_rag/integrations/search/bs4_reader.py index 9d5d1e0f..a3df3bd8 100644 --- a/src/pai_rag/integrations/search/bs4_reader.py +++ b/src/pai_rag/integrations/search/bs4_reader.py @@ -2,7 +2,7 @@ import asyncio import nest_asyncio -import logging +from loguru import logger from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.parse import urljoin @@ -11,8 +11,6 @@ from llama_index.core.readers.base import BasePydanticReader from llama_index.core.schema import Document -logger = logging.getLogger(__name__) - def _substack_reader(soup: Any, **kwargs) -> Tuple[str, Dict[str, Any]]: """Extract text from Substack blog post.""" @@ -157,7 +155,6 @@ def fetch_multiple(urls): nest_asyncio.apply() tasks = [fetch_url(url) for url in urls] results = asyncio.run(asyncio.gather(*tasks)) - print(results) return results diff --git a/src/pai_rag/integrations/synthesizer/pai_synthesizer.py b/src/pai_rag/integrations/synthesizer/pai_synthesizer.py index 3933514d..68adcfad 100644 --- a/src/pai_rag/integrations/synthesizer/pai_synthesizer.py +++ b/src/pai_rag/integrations/synthesizer/pai_synthesizer.py @@ -37,11 +37,10 @@ astream_completion_response_to_tokens, ) from llama_index.core.prompts import PromptTemplate -import logging +from loguru import logger dispatcher = instrument.get_dispatcher(__name__) -logger = logging.getLogger(__name__) DEFAULT_LLM_CHAT_TMPL = ( "You are a helpful assistant." diff --git a/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py b/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py index fa1dc4ea..db2b7bfb 100644 --- a/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py +++ b/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
             # run parsing
             pipe.pipe_parse()
             content_list = pipe.pipe_mk_uni_format(temp_file_path, drop_mode="none")
@@ -328,7 +323,7 @@
             return new_md_content
         except Exception as e:
-            logger(e)
+            logger.exception(e)
             return None
 
     def load_data(
@@ -377,5 +372,5 @@ def load(
             )
             docs.append(doc)
             logger.info(f"processed pdf file {file_path} without metadata")
-        print(f"[PaiPDFReader] successfully loaded {len(docs)} nodes.")
+        logger.info(f"[PaiPDFReader] successfully loaded {len(docs)} nodes.")
         return docs
diff --git a/src/pai_rag/integrations/search/bing_search.py b/src/pai_rag/integrations/search/bing_search.py
index 6c3c1928..67874411 100644
--- a/src/pai_rag/integrations/search/bing_search.py
+++ b/src/pai_rag/integrations/search/bing_search.py
@@ -7,11 +7,10 @@
 from llama_index.core.response_synthesizers import BaseSynthesizer
 from llama_index.core.schema import QueryBundle
 import faiss
-import logging
+from loguru import logger
 
 from pai_rag.integrations.search.bs4_reader import ParallelBeautifulSoupWebReader
 
-logger = logging.getLogger(__name__)
 DEFAULT_ENDPOINT_BASE_URL = "https://api.bing.microsoft.com/v7.0/search"
 DEFAULT_SEARCH_COUNT = 10
diff --git a/src/pai_rag/integrations/search/bs4_reader.py b/src/pai_rag/integrations/search/bs4_reader.py
index 9d5d1e0f..a3df3bd8 100644
--- a/src/pai_rag/integrations/search/bs4_reader.py
+++ b/src/pai_rag/integrations/search/bs4_reader.py
@@ -2,7 +2,7 @@
 
 import asyncio
 import nest_asyncio
-import logging
+from loguru import logger
 from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.parse import urljoin
 
@@ -11,8 +11,6 @@
 from llama_index.core.readers.base import BasePydanticReader
 from llama_index.core.schema import Document
 
-logger = logging.getLogger(__name__)
-
 
 def _substack_reader(soup: Any, **kwargs) -> Tuple[str, Dict[str, Any]]:
     """Extract text from Substack blog post."""
@@ -157,7 +155,6 @@ def fetch_multiple(urls):
     nest_asyncio.apply()
     tasks = [fetch_url(url) for url in urls]
     results = asyncio.run(asyncio.gather(*tasks))
-    print(results)
     return results
diff --git a/src/pai_rag/integrations/synthesizer/pai_synthesizer.py b/src/pai_rag/integrations/synthesizer/pai_synthesizer.py
index 3933514d..68adcfad 100644
--- a/src/pai_rag/integrations/synthesizer/pai_synthesizer.py
+++ b/src/pai_rag/integrations/synthesizer/pai_synthesizer.py
@@ -37,11 +37,10 @@
     astream_completion_response_to_tokens,
 )
 from llama_index.core.prompts import PromptTemplate
-import logging
+from loguru import logger
 
 dispatcher = instrument.get_dispatcher(__name__)
 
-logger = logging.getLogger(__name__)
 
 DEFAULT_LLM_CHAT_TMPL = (
     "You are a helpful assistant."
diff --git a/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py b/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py
index fa1dc4ea..db2b7bfb 100644
--- a/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py
+++ b/src/pai_rag/integrations/vector_stores/elasticsearch/my_async_vector_store.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import logging
 import uuid
 from typing import Any, Callable, Dict, List, Optional
@@ -27,8 +26,7 @@
     AsyncRetrievalStrategy,
 )
 from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class AsyncVectorStore:
diff --git a/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py b/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py
index b9f171ea..90f0e362 100644
--- a/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py
+++ b/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py
@@ -1,7 +1,7 @@
 """Elasticsearch vector store."""
 
 import asyncio
-from logging import getLogger
+from loguru import logger
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 import nest_asyncio
@@ -36,8 +36,6 @@
     get_user_agent,
 )
 
-logger = getLogger(__name__)
-
 DISTANCE_STRATEGIES = Literal[
     "COSINE",
     "DOT_PRODUCT",
diff --git a/src/pai_rag/integrations/vector_stores/faiss/my_faiss.py b/src/pai_rag/integrations/vector_stores/faiss/my_faiss.py
index 201d5938..3e719f15 100644
--- a/src/pai_rag/integrations/vector_stores/faiss/my_faiss.py
+++ b/src/pai_rag/integrations/vector_stores/faiss/my_faiss.py
@@ -4,7 +4,6 @@
 
 """
 
-import logging
 import os
 from typing import Any, List, cast
@@ -19,7 +18,7 @@
 from llama_index.vector_stores.faiss import FaissVectorStore
 from pai_rag.utils.score_utils import normalize_cosine_similarity_score
 
-logger = logging.getLogger()
+from loguru import logger
 
 DEFAULT_PERSIST_PATH = os.path.join(
     DEFAULT_PERSIST_DIR, f"{DEFAULT_VECTOR_STORE}{NAMESPACE_SEP}{DEFAULT_PERSIST_FNAME}"
diff --git a/src/pai_rag/integrations/vector_stores/hologres/hologres.py b/src/pai_rag/integrations/vector_stores/hologres/hologres.py
index f6b3b8a4..d971f041 100644
--- a/src/pai_rag/integrations/vector_stores/hologres/hologres.py
+++ b/src/pai_rag/integrations/vector_stores/hologres/hologres.py
@@ -3,7 +3,6 @@
 Vector store using hologres back end.
 """
 
-import logging
 from typing import Any, List, cast, Dict
 from hologres_vector import HologresVector
 from llama_index.core.bridge.pydantic import PrivateAttr
@@ -15,8 +14,6 @@
 from llama_index.core.vector_stores.types import BasePydanticVectorStore
 from pai_rag.utils.score_utils import normalize_cosine_similarity_score
 
-logger = logging.getLogger()
-
 
 class HologresVectorStore(BasePydanticVectorStore):
     """Hologres Vector Store.
diff --git a/src/pai_rag/integrations/vector_stores/milvus/my_milvus.py b/src/pai_rag/integrations/vector_stores/milvus/my_milvus.py
index 78f9ddae..00bf3105 100644
--- a/src/pai_rag/integrations/vector_stores/milvus/my_milvus.py
+++ b/src/pai_rag/integrations/vector_stores/milvus/my_milvus.py
@@ -4,7 +4,7 @@
 
 """
 
-import logging
+from loguru import logger
 from typing import Any, Dict, List, Optional, Union
 
 import pymilvus  # noqa
@@ -34,8 +34,6 @@
 from pymilvus import Collection, MilvusClient, DataType, AnnSearchRequest
 from pai_rag.utils.score_utils import normalize_cosine_similarity_score
 
-logger = logging.getLogger(__name__)
-
 DEFAULT_BATCH_SIZE = 100
 MILVUS_ID_FIELD = "id"
diff --git a/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py b/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py
index e1f286ea..f8b13f65 100644
--- a/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py
+++ b/src/pai_rag/integrations/vector_stores/postgresql/postgresql.py
@@ -1,4 +1,4 @@
-import logging
+from loguru import logger
 import re
 from typing import Any, List, NamedTuple, Optional, Type, Union
 from urllib.parse import quote_plus
@@ -32,9 +32,6 @@ class DBEmbeddingRow(NamedTuple):
     similarity: float
 
-_logger = logging.getLogger(__name__)
-
-
 def get_data_model(
     base: Type,
     index_name: str,
@@ -333,7 +330,7 @@ def _create_extension(self) -> None:
             statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_jieba")
             session.execute(statement)
         except Exception:
-            _logger.warning("create extension pg_jieba failed")
+            logger.warning("create extension pg_jieba failed")
         session.commit()
 
     def _initialize(self) -> None:
@@ -357,12 +354,12 @@ def _extension_load(self) -> None:
                     session.execute(
                         sqlalchemy.text("SELECT jieba_load_user_dict(0,0)")
                     )
-                    _logger.info("session load jieba_load_user_dict success!")
+                    logger.info("session load jieba_load_user_dict success!")
                     session.commit()
                 except Exception as e:
-                    _logger.warning(e)
+                    logger.warning(e)
             self._is_extension_load = True
-            _logger.info("load extension done!")
+            logger.info("load extension done!")
 
     async def _async_extension_load(self) -> None:
         if not self._is_async_extension_load:
@@ -376,11 +373,11 @@ async def _async_extension_load(self) -> None:
                     await async_session.execute(
                         sqlalchemy.text("SELECT jieba_load_user_dict(0,0)")
                     )
-                    _logger.info("async_session load jieba_load_user_dict success!")
+                    logger.info("async_session load jieba_load_user_dict success!")
                 except Exception as e:
-                    _logger.warning(e)
+                    logger.warning(e)
             self._is_async_extension_load = True
-            _logger.info("async load extension done!")
+            logger.info("async load extension done!")
 
     def _node_to_table_row(self, node: BaseNode) -> Any:
         return self._table_class(
@@ -442,7 +439,7 @@ def _to_postgres_operator(self, operator: FilterOperator) -> str:
         elif operator == FilterOperator.CONTAINS:
             return "@>"
         else:
-            _logger.warning(f"Unknown operator: {operator}, fallback to '='")
+            logger.warning(f"Unknown operator: {operator}, fallback to '='")
             return "="
 
     def _build_filter_clause(self, filter_: MetadataFilter) -> Any:
@@ -707,7 +704,7 @@ async def _async_hybrid_query(
         import asyncio
 
         if query.alpha is not None:
-            _logger.warning("postgres hybrid search does not support alpha parameter.")
+            logger.warning("postgres hybrid search does not support alpha parameter.")
 
         sparse_top_k = query.sparse_top_k or query.similarity_top_k
 
@@ -731,7 +728,7 @@ def _hybrid_query(
         self, query: VectorStoreQuery, **kwargs: Any
     ) -> List[DBEmbeddingRow]:
         if query.alpha is not None:
-            _logger.warning("postgres hybrid search does not support alpha parameter.")
+            logger.warning("postgres hybrid search does not support alpha parameter.")
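+            # In llama-index hybrid queries, `alpha` weights dense vs. sparse
+            # scores; this PostgreSQL backend runs both searches and ignores it.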
_logger.warning("postgres hybrid search does not support alpha parameter.") + logger.warning("postgres hybrid search does not support alpha parameter.") sparse_top_k = query.sparse_top_k or query.similarity_top_k diff --git a/src/pai_rag/main.py b/src/pai_rag/main.py index 84f62719..b4cad535 100644 --- a/src/pai_rag/main.py +++ b/src/pai_rag/main.py @@ -4,14 +4,9 @@ import uvicorn from fastapi import FastAPI from pai_rag.core.rag_config_manager import RagConfigManager -from pai_rag.utils.constants import DEFAULT_MODEL_DIR, EAS_DEFAULT_MODEL_DIR -from logging.config import dictConfig +from pai_rag.utils.constants import DEFAULT_MODEL_DIR import os from pathlib import Path -import logging - -logger = logging.getLogger(__name__) - _BASE_DIR = Path(__file__).parent _ROOT_BASE_DIR = Path(__file__).parent.parent.parent @@ -25,45 +20,6 @@ DEFAULT_GRADIO_PORT = 8002 -def init_log(): - log_config = { - "version": 1, - "disable_existing_loggers": False, - "filters": { - "correlation_id": { - "()": "asgi_correlation_id.CorrelationIdFilter", - "uuid_length": 32, - "default_value": "-", - }, - }, - "formatters": { - "sample": { - "format": "%(asctime)s %(levelname)s [%(correlation_id)s] %(message)s" - }, - "verbose": { - "format": "%(asctime)s %(levelname)s [%(correlation_id)s] %(name)s %(process)d %(thread)d %(message)s" - }, - "access": { - "()": "uvicorn.logging.AccessFormatter", - "fmt": '%(asctime)s %(levelprefix)s %(client_addr)s [%(correlation_id)s] - "%(request_line)s" %(status_code)s', - }, - }, - "handlers": { - "console": { - "formatter": "verbose", - "level": "DEBUG", - "filters": ["correlation_id"], - "class": "logging.StreamHandler", - }, - }, - "loggers": { - "": {"level": "INFO", "handlers": ["console"]}, - }, - } - dictConfig(log_config) - - -init_log() app = FastAPI() @@ -163,16 +119,7 @@ def serve(host, port, config_file, workers, enable_example, skip_download_models rag_configuration = RagConfigManager.from_file(config_file) rag_configuration.persist() - - if not skip_download_models and DEFAULT_MODEL_DIR != EAS_DEFAULT_MODEL_DIR: - logger.info("Start to download models.") - ModelScopeDownloader().load_basic_models() - ModelScopeDownloader().load_mineru_config() - logger.info("Finished downloading models.") - else: - logger.info("Start to loading minerU config file.") - ModelScopeDownloader().load_mineru_config() - logger.info("Finished loading minerU config file.") + ModelScopeDownloader().load_rag_models(skip_download_models) os.environ["PAI_RAG_MODEL_DIR"] = DEFAULT_MODEL_DIR app = FastAPI() diff --git a/src/pai_rag/tools/agent_tool.py b/src/pai_rag/tools/agent_tool.py index 122903c3..07f8ff75 100644 --- a/src/pai_rag/tools/agent_tool.py +++ b/src/pai_rag/tools/agent_tool.py @@ -3,9 +3,7 @@ from pathlib import Path from pai_rag.core.rag_config_manager import RagConfigManager from pai_rag.core.rag_module import resolve_agent -import logging - -logger = logging.getLogger(__name__) +from loguru import logger _BASE_DIR = Path(__file__).parent.parent DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml") @@ -56,8 +54,6 @@ def run( tool_definition_file=None, python_script_file=None, ): - logging.basicConfig(level=logging.DEBUG) - config = RagConfigManager.from_file(config_file).get_value() if tool_definition_file: config.agent.tool_definition_file = tool_definition_file @@ -66,9 +62,9 @@ def run( agent = resolve_agent(config) - print("**Question**: ", question) + logger.info("**Question**: ", question) response = agent.chat(question) - print("**Answer**: ", 
-    print("**Question**: ", question)
+    logger.info(f"**Question**: {question}")
     response = agent.chat(question)
-    print("**Answer**: ", response.response)
+    logger.info(f"**Answer**: {response.response}")
 
 
 if __name__ == "__main__":
diff --git a/src/pai_rag/tools/load_data_tool.py b/src/pai_rag/tools/load_data_tool.py
index d3dd94cd..e9e1c26f 100644
--- a/src/pai_rag/tools/load_data_tool.py
+++ b/src/pai_rag/tools/load_data_tool.py
@@ -2,13 +2,11 @@
 import os
 from pathlib import Path
 from pai_rag.core.rag_config_manager import RagConfigManager
+from pai_rag.utils.download_models import ModelScopeDownloader
 from pai_rag.core.rag_module import resolve_data_loader
-import logging
+from loguru import logger
 
-logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO)
-
 _BASE_DIR = Path(__file__).parent.parent
 DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml")
@@ -71,7 +69,7 @@ def run(
     ), f"Can not provide both local path '{data_path}' and oss path '{oss_path}'."
 
     config = RagConfigManager.from_file(config_file).get_value()
-
+    ModelScopeDownloader().load_rag_models()
     data_loader = resolve_data_loader(config)
     data_loader.load_data(
         file_path_or_directory=data_path,
@@ -80,3 +78,4 @@
         from_oss=oss_path is not None,
         enable_raptor=enable_raptor,
     )
+    logger.info("Load data tool finished.")
diff --git a/src/pai_rag/tools/query_tool.py b/src/pai_rag/tools/query_tool.py
index d4dd81e7..59ac019c 100644
--- a/src/pai_rag/tools/query_tool.py
+++ b/src/pai_rag/tools/query_tool.py
@@ -4,9 +4,7 @@
 from pai_rag.core.rag_config_manager import RagConfigManager
 from pai_rag.core.rag_module import resolve_query_engine
 from pai_rag.integrations.synthesizer.pai_synthesizer import PaiQueryBundle
-import logging
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 _BASE_DIR = Path(__file__).parent.parent
 DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml")
@@ -41,24 +39,22 @@ def run(
     question=None,
     stream=False,
 ):
-    logging.basicConfig(level=logging.INFO)
-
     config = RagConfigManager.from_file(config_file).get_value()
 
     query_engine = resolve_query_engine(config)
-    print("**Question**: ", question)
+    logger.info(f"**Question**: {question}")
 
     if not stream:
         query_bundle = PaiQueryBundle(query_str=question, stream=False)
         response = query_engine.query(query_bundle)
-        print("**Answer**: ", response.response)
+        logger.info(f"**Answer**: {response.response}")
     else:
         query_bundle = PaiQueryBundle(query_str=question, stream=True)
         response = query_engine.query(query_bundle)
-        print("**Answer**: ", end="")
-        for chunk in response.response_gen:
-            print(chunk, end="")
+        # loguru has no print-style `end=""` parameter, so drain the stream
+        # and emit the full answer as a single log record.
+        answer = "".join(response.response_gen)
+        logger.info(f"**Answer**: {answer}")
 
 
 if __name__ == "__main__":
diff --git a/src/pai_rag/utils/download_models.py b/src/pai_rag/utils/download_models.py
index cf75859b..dc246077 100644
--- a/src/pai_rag/utils/download_models.py
+++ b/src/pai_rag/utils/download_models.py
@@ -1,4 +1,4 @@
-from pai_rag.utils.constants import DEFAULT_MODEL_DIR, OSS_URL
+from pai_rag.utils.constants import DEFAULT_MODEL_DIR, EAS_DEFAULT_MODEL_DIR, OSS_URL
 from modelscope.hub.snapshot_download import snapshot_download
 from tempfile import TemporaryDirectory
 from pathlib import Path
@@ -6,12 +6,10 @@
 import shutil
 import os
 import time
-import logging
+from loguru import logger
 import click
 import json
 
-logger = logging.getLogger(__name__)
-
 
 class ModelScopeDownloader:
     def __init__(self, fetch_config: bool = False):
@@ -47,20 +45,31 @@ def load_model(self, model):
             f"Finished downloading model {model} to {model_path}, took {duration:.2f} seconds."
         )
 
+    def load_rag_models(self, skip_download_models: bool = False):
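+        # Download the basic models unless the caller opts out or we are in
+        # an EAS-like environment where models are pre-mounted; the minerU
+        # config must be ensured in either case.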
+        if not skip_download_models and DEFAULT_MODEL_DIR != EAS_DEFAULT_MODEL_DIR:
+            logger.info("Not in an EAS-like environment; downloading models.")
+            self.load_basic_models()
+        self.load_mineru_config()
+
     def load_basic_models(self):
+        logger.info("Start downloading basic models.")
         if not hasattr(self, "model_info"):
             response = requests.get(OSS_URL)
             response.raise_for_status()
             self.model_info = response.json()
         for model in self.model_info["basic_models"].keys():
             self.load_model(model)
+        logger.info("Finished downloading basic models.")
 
     def load_mineru_config(self):
+        logger.info("Start loading the minerU config file.")
         source_path = "magic-pdf.template.json"
         destination_path = os.path.expanduser("~/magic-pdf.json")  # target path
         if os.path.exists(destination_path):
-            print("magic-pdf.json already exists, skip modifying ~/magic-pdf.json.")
+            logger.info(
+                "magic-pdf.json already exists, skip modifying ~/magic-pdf.json."
+            )
             return
 
         # read the contents of the source_path file
@@ -76,7 +85,7 @@
         with open(destination_path, "w") as destination_file:
             json.dump(data, destination_file, indent=4)
 
-        print(
+        logger.info(
             "Copied magic-pdf.template.json to ~/magic-pdf.json and set models-dir to the model path."
         )
diff --git a/src/pai_rag/utils/embed_utils.py b/src/pai_rag/utils/embed_utils.py
index 7ba74958..fdf006ae 100644
--- a/src/pai_rag/utils/embed_utils.py
+++ b/src/pai_rag/utils/embed_utils.py
@@ -1,5 +1,5 @@
 from io import BytesIO
-import logging
+from loguru import logger
 import httpx
 import asyncio
 import numpy as np
@@ -8,8 +8,6 @@
 from llama_index.core.schema import BaseNode, MetadataMode, ImageNode
 from typing import Dict, List, Sequence
 
-logger = logging.getLogger(__name__)
-
 
 async def download_url(url):
     if not url:
diff --git a/src/pai_rag/utils/markdown_utils.py b/src/pai_rag/utils/markdown_utils.py
index 4746a367..24d89eda 100644
--- a/src/pai_rag/utils/markdown_utils.py
+++ b/src/pai_rag/utils/markdown_utils.py
@@ -4,6 +4,7 @@
 from typing import Any, List, Optional
 from llama_index.core.bridge.pydantic import Field, BaseModel
 import math
+from loguru import logger
 
 IMAGE_MAX_PIXELS = 512 * 512
@@ -84,12 +85,12 @@ def transform_local_to_oss(oss_cache: Any, image: PngImageFile, doc_name: str) -> str:
             },  # set public read to make image accessible
             path_prefix=f"pairag/doc_images/{doc_name.strip()}/",
         )
-        print(
+        logger.info(
             f"Cropped image {image_url} with width={image.width}, height={image.height}."
         )
         return image_url
     except Exception as e:
-        print(f"无法打开图片 '{image}': {e}")
+        logger.warning(f"Failed to open image '{image}': {e}")
 
 
 def _table_to_markdown(self, table, doc_name):
diff --git a/src/pai_rag/utils/oss_client.py b/src/pai_rag/utils/oss_client.py
index bd1704ba..d0f280c7 100644
--- a/src/pai_rag/utils/oss_client.py
+++ b/src/pai_rag/utils/oss_client.py
@@ -1,10 +1,8 @@
-import logging
 import hashlib
 import oss2
 import os
 
 from oss2.credentials import EnvironmentVariableCredentialsProvider
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 
 class OssClient:
diff --git a/src/pai_rag/utils/oss_utils.py b/src/pai_rag/utils/oss_utils.py
index b9151d93..9452283b 100644
--- a/src/pai_rag/utils/oss_utils.py
+++ b/src/pai_rag/utils/oss_utils.py
@@ -1,6 +1,5 @@
 import oss2
 import os
-
 from pai_rag.core.rag_config import RagConfig
diff --git a/src/pai_rag/utils/tokenization_qwen.py b/src/pai_rag/utils/tokenization_qwen.py
index a036a20f..c5ffa669 100644
--- a/src/pai_rag/utils/tokenization_qwen.py
+++ b/src/pai_rag/utils/tokenization_qwen.py
@@ -6,15 +6,13 @@
 """Tokenization classes for QWen."""
 
 import base64
-import logging
 import os
 import unicodedata
 from typing import Collection, Dict, List, Set, Tuple, Union
 
 import tiktoken
 from transformers import PreTrainedTokenizer, AddedToken
-
-logger = logging.getLogger(__name__)
+from loguru import logger
 
 VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
diff --git a/tests/data_readers/test_pdf_reader.py b/tests/data_readers/test_pdf_reader.py
index 93ebbd0d..06baac20 100644
--- a/tests/data_readers/test_pdf_reader.py
+++ b/tests/data_readers/test_pdf_reader.py
@@ -18,8 +18,7 @@ def test_pai_pdf_reader():
         reader_config=config.data_reader,
     )
     input_dir = "tests/testdata/data/pdf_data"
-    ModelScopeDownloader().load_basic_models()
-    ModelScopeDownloader().load_mineru_config()
+    ModelScopeDownloader().load_rag_models()
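+    # load_rag_models() bundles the basic-model download and the minerU
+    # config setup that PaiPDFReader depends on.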
 
     directory_reader.file_readers[".pdf"] = PaiPDFReader()
diff --git a/tests/integrations/test_markdown_mode_parser.py b/tests/integrations/test_markdown_mode_parser.py
index f074dec5..ab0cc55d 100644
--- a/tests/integrations/test_markdown_mode_parser.py
+++ b/tests/integrations/test_markdown_mode_parser.py
@@ -17,8 +17,7 @@ def test_markdown_parser():
         reader_config=config.data_reader,
     )
     input_dir = "tests/testdata/data/pdf_data"
-    ModelScopeDownloader().load_basic_models()
-    ModelScopeDownloader().load_mineru_config()
+    ModelScopeDownloader().load_rag_models()
 
     documents = directory_reader.load_data(file_path_or_directory=input_dir)
     md_node_parser = MarkdownNodeParser(enable_multimodal=False)
     splitted_nodes = []