-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathknowledge_base.py
30 lines (23 loc) · 984 Bytes
/
knowledge_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import logging
from typing import List, Tuple

import torch
from sentence_transformers import SentenceTransformer, util
class KnowledgeBase:
    """Retrieve the stored answer whose embedding best matches a question.

    All answers are embedded once at construction time with a
    SentenceTransformer model. Each lookup embeds the question and picks the
    answer with the highest dot-product score.
    """

    # Named logger rather than the root logger, so applications can filter
    # this module's debug output independently.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        answers: List[str],
        model_name: str = "multi-qa-MiniLM-L6-cos-v1",
    ):
        """Load the embedding model and pre-compute the answer embeddings.

        Args:
            answers: Candidate answers. They are copied into a new list so
                that later mutation by the caller cannot desynchronise them
                from the pre-computed embeddings.
            model_name: Model name or path passed to ``SentenceTransformer``.
        """
        # Copy: self.answers and self.answer_embeddings must stay index-aligned.
        self.answers = list(answers)
        self._logger.debug("Loading the sentence embedding model")
        self.embedding_model = SentenceTransformer(model_name)
        self._logger.debug("Embedding knowledge base answers")
        self.answer_embeddings = self.embedding_model.encode(
            self.answers, show_progress_bar=False
        )

    def look_up(self, question: str) -> Tuple[str, float]:
        """Return the best-matching answer for *question* together with its score.

        Args:
            question: Free-text query to embed and compare against the answers.

        Returns:
            ``(answer, score)``, where *score* is the dot product between the
            question embedding and the chosen answer's embedding.

        Raises:
            ValueError: If the knowledge base holds no answers.
        """
        # Fail early with a clear message; with no answers the scoring /
        # top-k step below would otherwise fail with a less informative error.
        if not self.answers:
            raise ValueError("KnowledgeBase has no answers to search")
        question_embedding = self.embedding_model.encode(
            question, show_progress_bar=False
        )
        # dot_score yields a (1, n_answers) matrix for a single query; row 0
        # is the flat score vector. NOTE(review): the default "-cos-" model
        # presumably emits unit-length embeddings, making this equal to cosine
        # similarity — confirm if a different model_name is used.
        scores = util.dot_score(question_embedding, self.answer_embeddings)[0]
        top_scores, top_indices = torch.topk(scores, k=1)
        # Convert tensors to plain Python values before handing them back.
        return self.answers[int(top_indices[0])], top_scores[0].item()