Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Couple fixes for RAG evaluation #979

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/base/v1alpha1/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ func (worker Worker) BuildEmbedder() *Embedder {
Type: embeddings.OpenAI,
Provider: Provider{
Worker: &TypedObjectReference{
APIGroup: pointer.String(GroupVersion.String()),
Kind: "Worker",
Namespace: &worker.Namespace,
Name: worker.Name,
Expand All @@ -188,6 +189,7 @@ func (worker Worker) BuildLLM() *LLM {
Type: llms.OpenAI,
Provider: Provider{
Worker: &TypedObjectReference{
APIGroup: pointer.String(GroupVersion.String()),
Kind: "Worker",
Namespace: &worker.Namespace,
Name: worker.Name,
Expand Down
85 changes: 49 additions & 36 deletions apiserver/pkg/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat
}

func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfigInput, hasMultiQueryRetriever, hasRerankRetriever bool) error {
app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile)
app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile)
app.Spec.Prologue = pointer.StringDeref(input.Prologue, app.Spec.Prologue)
app.Spec.ShowRespInfo = pointer.BoolDeref(input.ShowRespInfo, app.Spec.ShowRespInfo)
app.Spec.ShowRetrievalInfo = pointer.BoolDeref(input.ShowRetrievalInfo, app.Spec.ShowRetrievalInfo)
Expand All @@ -755,16 +755,18 @@ func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfi
return nil
}

func redefineNodes(knowledgebase *string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) {
// redefineNodes redefines the nodes of an application.
func redefineNodes(knowledgebase *string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) {
nodes = []v1alpha1.Node{
{
NodeConfig: v1alpha1.NodeConfig{
Name: "Input",
DisplayName: "用户输入",
Description: "用户输入节点,必须",
Ref: &v1alpha1.TypedObjectReference{
Kind: "Input",
Name: "Input",
Kind: "Input",
Name: "Input",
Namespace: &namespace,
},
},
NextNodeName: []string{"prompt-node"},
Expand All @@ -775,9 +777,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "prompt",
Description: "设定prompt,template中可以使用{{.}}来替换变量",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("prompt.arcadia.kubeagi.k8s.com.cn"),
Kind: "Prompt",
Name: name,
APIGroup: pointer.String("prompt.arcadia.kubeagi.k8s.com.cn"),
Kind: "Prompt",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -790,9 +793,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "documentloader",
Description: "文档加载,可选",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "DocumentLoader",
Name: name,
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "DocumentLoader",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -804,9 +808,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "llm",
Description: "设定大模型的访问信息",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "LLM",
Name: llmName,
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "LLM",
Name: llmName,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -821,9 +826,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "llm chain",
Description: "chain是langchain的核心概念,llmChain用于连接prompt和llm",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "LLMChain",
Name: name,
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "LLMChain",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"Output"},
Expand All @@ -836,9 +842,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "知识库",
Description: "连接知识库",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBase",
Name: pointer.StringDeref(knowledgebase, ""),
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBase",
Name: pointer.StringDeref(knowledgebase, ""),
Namespace: &namespace,
},
},
NextNodeName: []string{"retriever-node"},
Expand All @@ -849,9 +856,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "从知识库提取信息的retriever",
Description: "连接应用和知识库",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBaseRetriever",
Name: name,
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "KnowledgeBaseRetriever",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -878,9 +886,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "多查询retriever",
Description: "多查询retriever",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "MultiQueryRetriever",
Name: name,
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "MultiQueryRetriever",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{nextNodeName},
Expand All @@ -894,9 +903,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "rerank retriever",
Description: "rerank retriever",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "RerankRetriever",
Name: name,
APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"),
Kind: "RerankRetriever",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -909,9 +919,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "RetrievalQA chain",
Description: "chain是langchain的核心概念RetrievalQAChain用于从retriever中提取信息,供llm调用",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "RetrievalQAChain",
Name: name,
APIGroup: pointer.String("chain.arcadia.kubeagi.k8s.com.cn"),
Kind: "RetrievalQAChain",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"Output"},
Expand All @@ -924,9 +935,10 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "agent",
Description: "agent 调用复杂工具完成任务",
Ref: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "Agent",
Name: name,
APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"),
Kind: "Agent",
Name: name,
Namespace: &namespace,
},
},
NextNodeName: []string{"chain-node"},
Expand All @@ -938,8 +950,9 @@ func redefineNodes(knowledgebase *string, name string, llmName string, tools []*
DisplayName: "最终输出",
Description: "最终输出节点,必须",
Ref: &v1alpha1.TypedObjectReference{
Kind: "Output",
Name: "Output",
Kind: "Output",
Name: "Output",
Namespace: &namespace,
},
},
})
Expand Down
4 changes: 4 additions & 0 deletions apiserver/pkg/knowledgebase/knowledgebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func knowledgebase2model(ctx context.Context, c client.Client, knowledgebase *v1
}

source := &generated.TypedObjectReference{
APIGroup: fg.Source.APIGroup,
Kind: fg.Source.Kind,
Name: fg.Source.Name,
Namespace: new(string),
Expand Down Expand Up @@ -122,6 +123,7 @@ func knowledgebase2model(ctx context.Context, c client.Client, knowledgebase *v1

embedderResource := &v1alpha1.Embedder{}
embedder := generated.TypedObjectReference{
APIGroup: knowledgebase.Spec.Embedder.APIGroup,
Kind: knowledgebase.Spec.Embedder.Kind,
Name: knowledgebase.Spec.Embedder.Name,
Namespace: knowledgebase.Spec.Embedder.Namespace,
Expand Down Expand Up @@ -156,6 +158,7 @@ func knowledgebase2model(ctx context.Context, c client.Client, knowledgebase *v1
EmbedderType: &embedderType,
// Vector info
VectorStore: &generated.TypedObjectReference{
APIGroup: knowledgebase.Spec.VectorStore.APIGroup,
Kind: knowledgebase.Spec.VectorStore.Kind,
Name: knowledgebase.Spec.VectorStore.Name,
Namespace: knowledgebase.Spec.VectorStore.Namespace,
Expand Down Expand Up @@ -246,6 +249,7 @@ func CreateKnowledgeBase(ctx context.Context, c client.Client, input generated.C
Description: description,
},
Embedder: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Kind: "Embedder",
Name: embedder,
Namespace: &input.Namespace,
Expand Down
3 changes: 3 additions & 0 deletions apiserver/pkg/versioneddataset/versioned_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ func UpdateVersionedDataset(ctx context.Context, c client.Client, input *generat
for _, item := range input.FileGroups {
tmp := v1alpha1.FileGroup{
Source: &v1alpha1.TypedObjectReference{
APIGroup: item.Source.APIGroup,
Kind: item.Source.Kind,
Name: item.Source.Name,
Namespace: item.Source.Namespace,
Expand Down Expand Up @@ -312,6 +313,7 @@ func CreateVersionedDataset(ctx context.Context, c client.Client, input *generat
vds.Spec = v1alpha1.VersionedDatasetSpec{
Version: input.Version,
Dataset: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Kind: "Dataset",
Name: input.DatasetName,
Namespace: &input.Namespace,
Expand All @@ -328,6 +330,7 @@ func CreateVersionedDataset(ctx context.Context, c client.Client, input *generat
for _, item := range input.FileGrups {
tmp := v1alpha1.FileGroup{
Source: &v1alpha1.TypedObjectReference{
APIGroup: item.Source.APIGroup,
Kind: item.Source.Kind,
Name: item.Source.Name,
Namespace: item.Source.Namespace,
Expand Down
1 change: 1 addition & 0 deletions apiserver/pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ func CreateWorker(ctx context.Context, c client.Client, input generated.CreateWo
},
Type: workerType,
Model: &v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Name: input.Model.Name,
Namespace: &modelNs,
Kind: "Model",
Expand Down
5 changes: 5 additions & 0 deletions deploy/charts/arcadia/templates/rag-rbac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ rules:
- knowledgebases
- embedders
- vectorstores
- documentloaders
- agents
verbs:
- get
- list
Expand All @@ -36,6 +38,7 @@ rules:
resources:
- llmchains
- retrievalqachains
- apichains
verbs:
- get
- list
Expand All @@ -50,6 +53,8 @@ rules:
- retriever.arcadia.kubeagi.k8s.com.cn
resources:
- knowledgebaseretrievers
- multiqueryretrievers
- rerankretrievers
verbs:
- list
- get
Expand Down
5 changes: 5 additions & 0 deletions pkg/versioneddataset/versioneddataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/klog/v2"
"k8s.io/utils/pointer"

"github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/pkg/datasource"
Expand Down Expand Up @@ -61,6 +62,7 @@ func generateInheritedFileStatus(oss *datasource.OSS, instance *v1alpha1.Version
return []v1alpha1.FileStatus{
{
TypedObjectReference: v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Name: name,
Namespace: &instance.Namespace,
Kind: "VersionedDataset",
Expand Down Expand Up @@ -93,6 +95,7 @@ func generateDatasourceFileStatus(instance *v1alpha1.VersionedDataset) []v1alpha
_, _ = fmt.Sscanf(datasource, "%s %s", &namespace, &name)
item := v1alpha1.FileStatus{
TypedObjectReference: v1alpha1.TypedObjectReference{
APIGroup: pointer.String(v1alpha1.GroupVersion.String()),
Name: name,
Namespace: &namespace,
Kind: "Datasource",
Expand Down Expand Up @@ -168,6 +171,8 @@ func CopiedFileGroup2Status(oss *datasource.OSS, instance *v1alpha1.VersionedDat
if len(datasourceFiles) > 0 {
ds := v1alpha1.FileStatus{
TypedObjectReference: v1alpha1.TypedObjectReference{
APIGroup: item.APIGroup,
Kind: item.Kind,
Name: item.Name,
Namespace: item.Namespace,
},
Expand Down
Loading