Skip to content

Commit

Permalink
Support RDS postgres vector store (#134)
Browse files Browse the repository at this point in the history
* support rds postgres for store engine

* Format

* support table

* Make format

---------

Co-authored-by: Yue Fei <[email protected]>
  • Loading branch information
zt2645802240 and moria97 authored Jul 31, 2024
1 parent f0fc85c commit 43b1c17
Show file tree
Hide file tree
Showing 8 changed files with 982 additions and 5 deletions.
81 changes: 76 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ protobuf = "3.20.0"
modelscope = "^1.16.0"
llama-index-multi-modal-llms-dashscope = "^0.1.2"
llama-index-vector-stores-alibabacloud-opensearch = "^0.1.0"
asyncpg = "^0.29.0"
pgvector = "^0.3.2"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
2 changes: 2 additions & 0 deletions pyproject_gpu.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ protobuf = "3.20.0"
modelscope = "^1.16.0"
llama-index-multi-modal-llms-dashscope = "^0.1.2"
llama-index-vector-stores-alibabacloud-opensearch = "^0.1.0"
asyncpg = "^0.29.0"
pgvector = "^0.3.2"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
46 changes: 46 additions & 0 deletions src/pai_rag/app/web/tabs/vector_db_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def create_vector_db_panel(
"AnalyticDB",
"FAISS",
"OpenSearch",
"PostgreSQL",
],
label="Which VectorStore do you want to use?",
elem_id="vectordb_type",
Expand Down Expand Up @@ -255,6 +256,40 @@ def create_vector_db_panel(
outputs=con_state_opensearch,
api_name="connect_faiss",
)
# Connection form for the PostgreSQL vector store. change_vectordb_conn
# toggles this column's visibility whenever the selected store type changes.
with gr.Column(visible=(vectordb_type == "PostgreSQL")) as postgresql_col:
    postgresql_host = gr.Textbox(label="Host", elem_id="postgresql_host")
    postgresql_port = gr.Textbox(label="Port", elem_id="postgresql_port")
    postgresql_database = gr.Textbox(
        label="Database", elem_id="postgresql_database"
    )
    postgresql_table_name = gr.Textbox(
        label="TableName", elem_id="postgresql_table_name"
    )
    # Username is placed before password to follow the usual credential order.
    postgresql_username = gr.Textbox(
        label="UserName", elem_id="postgresql_username"
    )
    # Mask the password so the credential is not rendered in clear text.
    postgresql_password = gr.Textbox(
        label="Password", type="password", elem_id="postgresql_password"
    )
    connect_btn_pg = gr.Button("Connect PostgreSQL", variant="primary")
    con_state_pg = gr.Textbox(label="Connection Info: ")
    # The connect handler receives the shared input components plus every
    # PostgreSQL-specific field; this is a set, so ordering is irrelevant.
    inputs_pg = input_elements.union(
        {
            vectordb_type,
            postgresql_host,
            postgresql_port,
            postgresql_database,
            postgresql_table_name,
            postgresql_username,
            postgresql_password,
        }
    )
    connect_btn_pg.click(
        fn=connect_vector_func,
        inputs=inputs_pg,
        outputs=con_state_pg,
        api_name="connect_pg",
    )

def change_vectordb_conn(vectordb_type):
adb_visible = False
Expand All @@ -263,6 +298,7 @@ def change_vectordb_conn(vectordb_type):
es_visible = False
milvus_visible = False
opensearch_visible = False
postgresql_visible = False
if vectordb_type == "AnalyticDB":
adb_visible = True
elif vectordb_type == "Hologres":
Expand All @@ -275,6 +311,8 @@ def change_vectordb_conn(vectordb_type):
faiss_visible = True
elif vectordb_type == "OpenSearch":
opensearch_visible = True
elif vectordb_type == "PostgreSQL":
postgresql_visible = True

return {
adb_col: gr.update(visible=adb_visible),
Expand All @@ -283,6 +321,7 @@ def change_vectordb_conn(vectordb_type):
faiss_col: gr.update(visible=faiss_visible),
milvus_col: gr.update(visible=milvus_visible),
opensearch_col: gr.update(visible=opensearch_visible),
postgresql_col: gr.update(visible=postgresql_visible),
}

vectordb_type.change(
Expand All @@ -295,6 +334,7 @@ def change_vectordb_conn(vectordb_type):
es_col,
milvus_col,
opensearch_col,
postgresql_col,
],
)

Expand Down Expand Up @@ -332,6 +372,12 @@ def change_vectordb_conn(vectordb_type):
opensearch_username,
opensearch_password,
opensearch_table_name,
postgresql_host,
postgresql_port,
postgresql_database,
postgresql_table_name,
postgresql_username,
postgresql_password,
]
)

Expand Down
34 changes: 34 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ class ViewModel(BaseModel):
opensearch_password: str = None
opensearch_table_name: str = "pairag"

# PostgreSQL
postgresql_host: str = None
postgresql_port: int = 5432
postgresql_database: str = None
postgresql_table_name: str = "pairag"
postgresql_username: str = None
postgresql_password: str = None

# retriever
similarity_top_k: int = 5
retrieval_mode: str = "hybrid" # hybrid / embedding / keyword
Expand Down Expand Up @@ -218,6 +226,16 @@ def from_app_config(config):
"table_name"
]

elif view_model.vectordb_type.lower() == "postgresql":
view_model.postgresql_host = config["index"]["vector_store"]["host"]
view_model.postgresql_port = config["index"]["vector_store"]["port"]
view_model.postgresql_database = config["index"]["vector_store"]["database"]
view_model.postgresql_table_name = config["index"]["vector_store"][
"table_name"
]
view_model.postgresql_username = config["index"]["vector_store"]["username"]
view_model.postgresql_password = config["index"]["vector_store"]["password"]

view_model.parser_type = config["node_parser"]["type"]
view_model.chunk_size = config["node_parser"]["chunk_size"]
view_model.chunk_overlap = config["node_parser"]["chunk_overlap"]
Expand Down Expand Up @@ -356,6 +374,14 @@ def to_app_config(self):
config["index"]["vector_store"]["password"] = self.opensearch_password
config["index"]["vector_store"]["table_name"] = self.opensearch_table_name

elif self.vectordb_type.lower() == "postgresql":
config["index"]["vector_store"]["host"] = self.postgresql_host
config["index"]["vector_store"]["port"] = self.postgresql_port
config["index"]["vector_store"]["database"] = self.postgresql_database
config["index"]["vector_store"]["table_name"] = self.postgresql_table_name
config["index"]["vector_store"]["username"] = self.postgresql_username
config["index"]["vector_store"]["password"] = self.postgresql_password

config["retriever"]["similarity_top_k"] = self.similarity_top_k
if self.retrieval_mode == "Hybrid":
config["retriever"]["retrieval_mode"] = "hybrid"
Expand Down Expand Up @@ -536,6 +562,14 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
settings["opensearch_password"] = {"value": self.opensearch_password}
settings["opensearch_table_name"] = {"value": self.opensearch_table_name}

# postgresql
settings["postgresql_host"] = {"value": self.postgresql_host}
settings["postgresql_port"] = {"value": self.postgresql_port}
settings["postgresql_database"] = {"value": self.postgresql_database}
settings["postgresql_table_name"] = {"value": self.postgresql_table_name}
settings["postgresql_username"] = {"value": self.postgresql_username}
settings["postgresql_password"] = {"value": self.postgresql_password}

# evaluation
if self.vectordb_type == "FAISS":
qa_dataset_path, qa_dataset_res = self.get_local_generated_qa_file()
Expand Down
Loading

0 comments on commit 43b1c17

Please sign in to comment.