diff --git a/PROJECT b/PROJECT index 5b5224bac..cddf8a2ef 100644 --- a/PROJECT +++ b/PROJECT @@ -151,6 +151,22 @@ resources: kind: RerankRetriever path: github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1 version: v1alpha1 +- api: + crdVersion: v1 + controller: true + domain: arcadia.kubeagi.k8s.com.cn + group: retriever + kind: MergerRetriever + path: github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1 + version: v1alpha1 +- api: + crdVersion: v1 + controller: true + domain: arcadia.kubeagi.k8s.com.cn + group: retriever + kind: MultiQueryRetriever + path: github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1 + version: v1alpha1 - api: crdVersion: v1 controller: true diff --git a/api/app-node/common_type.go b/api/app-node/common_type.go index 18493f19b..185a4fcf1 100644 --- a/api/app-node/common_type.go +++ b/api/app-node/common_type.go @@ -25,8 +25,15 @@ import ( const ( InputLengthAnnotationKey = v1alpha1.Group + `/input-rules` OutputLengthAnnotationKey = v1alpha1.Group + `/output-rules` + + // ConversationKnowledgebaseName is the placeholder name of the conversation knowledgebase + ConversationKnowledgebaseName = "conversation-knowledgebase-placeholder" ) +func IsPlaceholderConversationKnowledgebase(name string) bool { + return name == ConversationKnowledgebaseName +} + type Ref struct { Kind string `json:"kind,omitempty"` Group string `json:"group,omitempty"` diff --git a/api/app-node/retriever/v1alpha1/mergerretriever_types.go b/api/app-node/retriever/v1alpha1/mergerretriever_types.go new file mode 100644 index 000000000..cb8991ea9 --- /dev/null +++ b/api/app-node/retriever/v1alpha1/mergerretriever_types.go @@ -0,0 +1,76 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + node "github.com/kubeagi/arcadia/api/app-node" + "github.com/kubeagi/arcadia/api/base/v1alpha1" +) + +// MergerRetrieverSpec defines the desired state of MergerRetriever +type MergerRetrieverSpec struct { + v1alpha1.CommonSpec `json:",inline"` +} + +// MergerRetrieverStatus defines the observed state of MergerRetriever +type MergerRetrieverStatus struct { + // ObservedGeneration is the last observed generation. + // +optional + ObservedGeneration int64 `json:"observedGeneration,omitempty"` + + // ConditionedStatus is the current status + v1alpha1.ConditionedStatus `json:",inline"` +} + +//+kubebuilder:object:root=true +//+kubebuilder:subresource:status + +// MergerRetriever is the Schema for the MergerRetriever API +type MergerRetriever struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec MergerRetrieverSpec `json:"spec,omitempty"` + Status MergerRetrieverStatus `json:"status,omitempty"` +} + +//+kubebuilder:object:root=true + +// MergerRetrieverList contains a list of MergerRetriever +type MergerRetrieverList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []MergerRetriever `json:"items"` +} + +func init() { + SchemeBuilder.Register(&MergerRetriever{}, &MergerRetrieverList{}) +} + +var _ node.Node = (*MergerRetriever)(nil) + +func (c *MergerRetriever) SetRef() { + annotations := node.SetRefAnnotations(c.GetAnnotations(), []node.Ref{node.RetrieverRef.Len(1)}, 
[]node.Ref{node.RetrievalQAChainRef.Len(1)}) + if c.GetAnnotations() == nil { + c.SetAnnotations(annotations) + } + for k, v := range annotations { + c.Annotations[k] = v + } +} diff --git a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go index 648d0d494..be0faddb5 100644 --- a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go @@ -138,6 +138,97 @@ func (in *KnowledgeBaseRetrieverStatus) DeepCopy() *KnowledgeBaseRetrieverStatus return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetriever) DeepCopyInto(out *MergerRetriever) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + out.Spec = in.Spec + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetriever. +func (in *MergerRetriever) DeepCopy() *MergerRetriever { + if in == nil { + return nil + } + out := new(MergerRetriever) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetriever) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverList) DeepCopyInto(out *MergerRetrieverList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]MergerRetriever, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverList. 
+func (in *MergerRetrieverList) DeepCopy() *MergerRetrieverList { + if in == nil { + return nil + } + out := new(MergerRetrieverList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetrieverList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverSpec) DeepCopyInto(out *MergerRetrieverSpec) { + *out = *in + out.CommonSpec = in.CommonSpec +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverSpec. +func (in *MergerRetrieverSpec) DeepCopy() *MergerRetrieverSpec { + if in == nil { + return nil + } + out := new(MergerRetrieverSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverStatus) DeepCopyInto(out *MergerRetrieverStatus) { + *out = *in + in.ConditionedStatus.DeepCopyInto(&out.ConditionedStatus) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverStatus. +func (in *MergerRetrieverStatus) DeepCopy() *MergerRetrieverStatus { + if in == nil { + return nil + } + out := new(MergerRetrieverStatus) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *MultiQueryRetriever) DeepCopyInto(out *MultiQueryRetriever) { *out = *in diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 656022a90..d56d00231 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -87,6 +87,7 @@ type ComplexityRoot struct { EnableRerank func(childComplexity int) int EnableUploadFile func(childComplexity int) int Knowledgebase func(childComplexity int) int + Knowledgebases func(childComplexity int) int Llm func(childComplexity int) int MaxLength func(childComplexity int) int MaxTokens func(childComplexity int) int @@ -1097,6 +1098,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Application.Knowledgebase(childComplexity), true + case "Application.knowledgebases": + if e.complexity.Application.Knowledgebases == nil { + break + } + + return e.complexity.Application.Knowledgebases(childComplexity), true + case "Application.llm": if e.complexity.Application.Llm == nil { break @@ -5180,10 +5188,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -5504,7 +5517,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 @@ -10008,6 +10026,47 @@ func (ec *executionContext) fieldContext_Application_conversionWindowSize(ctx co return fc, nil } +func (ec 
*executionContext) _Application_knowledgebases(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Application_knowledgebases(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Knowledgebases, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.([]*string) + fc.Result = res + return ec.marshalOString2ᚕᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Application_knowledgebases(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Application", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Application_knowledgebase(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Application_knowledgebase(ctx, field) if err != nil { @@ -11699,6 +11758,8 @@ func (ec *executionContext) fieldContext_ApplicationMutation_updateApplicationCo return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ 
-11808,6 +11869,8 @@ func (ec *executionContext) fieldContext_ApplicationQuery_getApplication(ctx con return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -28890,6 +28953,8 @@ func (ec *executionContext) fieldContext_RAG_application(ctx context.Context, fi return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -39224,7 +39289,7 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} + fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "knowledgebases", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -39301,6 +39366,13 @@ 
func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte return it, err } it.Knowledgebase = data + case "knowledgebases": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("knowledgebases")) + data, err := ec.unmarshalOString2ᚕᚖstring(ctx, v) + if err != nil { + return it, err + } + it.Knowledgebases = data case "scoreThreshold": ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scoreThreshold")) data, err := ec.unmarshalOFloat2ᚖfloat64(ctx, v) @@ -40622,6 +40694,8 @@ func (ec *executionContext) _Application(ctx context.Context, sel ast.SelectionS out.Values[i] = ec._Application_maxTokens(ctx, field, obj) case "conversionWindowSize": out.Values[i] = ec._Application_conversionWindowSize(ctx, field, obj) + case "knowledgebases": + out.Values[i] = ec._Application_knowledgebases(ctx, field, obj) case "knowledgebase": out.Values[i] = ec._Application_knowledgebase(ctx, field, obj) case "scoreThreshold": diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 86765d4ff..e90b6c531 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -54,6 +54,8 @@ type Application struct { MaxTokens *int `json:"maxTokens,omitempty"` // conversionWindowSize 对话轮次 ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` // scoreThreshold 最终返回结果的最低相似度 @@ -1709,6 +1711,8 @@ type UpdateApplicationConfigInput struct { ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 
CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // scoreThreshold 最终返回结果的最低相似度 ScoreThreshold *float64 `json:"scoreThreshold,omitempty"` // numDocuments 最终返回结果的引用上限 diff --git a/apiserver/graph/schema/application.gql b/apiserver/graph/schema/application.gql index 82804af47..ee68e3d33 100644 --- a/apiserver/graph/schema/application.gql +++ b/apiserver/graph/schema/application.gql @@ -78,6 +78,7 @@ mutation updateApplicationConfig($input: UpdateApplicationConfigInput!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn @@ -131,6 +132,7 @@ query getApplication($name: String!, $namespace: String!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn diff --git a/apiserver/graph/schema/application.graphqls b/apiserver/graph/schema/application.graphqls index f6c04afe7..4637a3e31 100644 --- a/apiserver/graph/schema/application.graphqls +++ b/apiserver/graph/schema/application.graphqls @@ -59,10 +59,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -383,7 +388,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 diff --git a/apiserver/pkg/application/application.go b/apiserver/pkg/application/application.go index 4fe3bca10..10f49049b 100644 --- a/apiserver/pkg/application/application.go +++ 
b/apiserver/pkg/application/application.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + appnode "github.com/kubeagi/arcadia/api/app-node" apiagent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" apichain "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" apidocumentloader "github.com/kubeagi/arcadia/api/app-node/documentloader/v1alpha1" @@ -65,7 +66,7 @@ func addDefaultValue(gApp *generated.Application, app *v1alpha1.Application) { if len(app.Spec.Nodes) > 0 { return } - gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") + // gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") gApp.NumDocuments = pointer.Int(5) gApp.ScoreThreshold = pointer.Float64(0.3) gApp.Temperature = pointer.Float64(0.7) @@ -136,9 +137,10 @@ func cr2app(prompt *apiprompt.Prompt, chainConfig *apichain.CommonChainConfig, r case "llm": gApp.Llm = node.Ref.Name case "knowledgebase": - gApp.Knowledgebase = pointer.String(node.Ref.Name) + gApp.Knowledgebases = append(gApp.Knowledgebases, pointer.String(node.Ref.Name)) } } + gApp.Knowledgebase = ConvertKnowledgebases2Knowledgebase(gApp.Knowledgebases) // nolint:staticcheck if retriever != nil { gApp.ScoreThreshold = pointer.Float64(float64(pointer.Float32Deref(retriever.ScoreThreshold, 0.0))) gApp.NumDocuments = pointer.Int(retriever.NumDocuments) @@ -429,6 +431,9 @@ func ListApplicationMeatadatas(ctx context.Context, c client.Client, input gener } func UpdateApplicationConfig(ctx context.Context, c client.Client, input generated.UpdateApplicationConfigInput) (*generated.Application, error) { + input.Knowledgebases = ConvertKnowledgebase2Knowledgebases(input.Knowledgebase, input.Knowledgebases) // nolint:staticcheck + hasKnowledgebaseOrEnableUpload := utils.HasValues(input.Knowledgebases) || pointer.BoolDeref(input.EnableUploadFile, false) // input has knowledgebases or enable upload file(may have conversation knowledgebase) + // 
check tool name not duplicated if len(input.Tools) != 0 { key := make(map[string]bool, len(input.Tools)) for _, tool := range input.Tools { @@ -438,6 +443,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat key[tool.Name] = true } } + key := types.NamespacedName{Namespace: input.Namespace, Name: input.Name} // get application cr, if not exist, return error app := &v1alpha1.Application{} @@ -510,12 +516,53 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } } + // create or update placeholder conversation knowledgebase, if enable upload file + var ( + conversationKnowledgebaseRetriever *apiretriever.KnowledgeBaseRetriever + ) + // note: no need to create placeholder conversation knowledgebase cr, it will be replaced by true conversation knowledgebase in appruntime + if pointer.BoolDeref(input.EnableUploadFile, false) { + conversationKnowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name + "-" + appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + Spec: apiretriever.KnowledgeBaseRetrieverSpec{ + CommonSpec: v1alpha1.CommonSpec{ + DisplayName: "retriever", + Description: "retriever", + }, + CommonRetrieverConfig: apiretriever.CommonRetrieverConfig{ + ScoreThreshold: pointer.Float32(float32(pointer.Float64Deref(input.ScoreThreshold, apiretriever.DefaultScoreThreshold))), + NumDocuments: pointer.IntDeref(input.NumDocuments, apiretriever.DefaultNumDocuments), + }, + }, + } + if _, err = controllerutil.CreateOrUpdate(ctx, c, conversationKnowledgebaseRetriever, func() error { + if input.ScoreThreshold != nil { + conversationKnowledgebaseRetriever.Spec.ScoreThreshold = pointer.Float32(float32(*input.ScoreThreshold)) + } + conversationKnowledgebaseRetriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, conversationKnowledgebaseRetriever.Spec.NumDocuments) + return nil + }); err != nil { + return nil, err + } + } 
else { + conversationKnowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name + "-" + appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, conversationKnowledgebaseRetriever) + } + // create or update chain var ( chainConfig *apichain.CommonChainConfig retriever *apiretriever.CommonRetrieverConfig ) - if utils.HasValue(input.Knowledgebase) { + if hasKnowledgebaseOrEnableUpload { qachain := &apichain.RetrievalQAChain{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -547,6 +594,13 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + llmchain := &apichain.LLMChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, llmchain) chainConfig = &qachain.Spec.CommonChainConfig } else { llmchain := &apichain.LLMChain{ @@ -580,43 +634,88 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + qachain := &apichain.RetrievalQAChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, qachain) chainConfig = &llmchain.Spec.CommonChainConfig } // create or update retrievers - // knowledgebaseRetriever (must have) -> multiQueryRetriever (optional) -> rerankRetriever (optional) -> Output - hasKnowledgebaseRetriever := utils.HasValue(input.Knowledgebase) + // (must have) (must have) (optional) (optional) + // knowledgebase1 -> knowledgebaseRetriever1 + // knowledgebase2 -> knowledgebaseRetriever2 --> mergerRetriever --> multiQueryRetriever --> rerankRetriever --> Output + // conversation knowledgebase (placeholder or has true data) -> knowledgebaseRetriever3 + hasKnowledgebaseRetriever := hasKnowledgebaseOrEnableUpload hasMultiQueryRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableMultiQuery, false) 
hasRerankRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableRerank, false) rerankModel := "" var knowledgebaseRetriever *apiretriever.KnowledgeBaseRetriever if hasKnowledgebaseRetriever { - knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + for i := range input.Knowledgebases { + indexStr := fmt.Sprintf("-%d", i) + if i == 0 { + // Compatible with previous versions when there is only one knowledgebase + indexStr = "" + } + knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name + indexStr, + Namespace: input.Namespace, + Labels: map[string]string{ + "application": input.Name, + }, + }, + Spec: apiretriever.KnowledgeBaseRetrieverSpec{ + CommonSpec: v1alpha1.CommonSpec{ + DisplayName: "retriever", + Description: "retriever", + }, + CommonRetrieverConfig: apiretriever.CommonRetrieverConfig{ + ScoreThreshold: pointer.Float32(float32(pointer.Float64Deref(input.ScoreThreshold, apiretriever.DefaultScoreThreshold))), + NumDocuments: pointer.IntDeref(input.NumDocuments, apiretriever.DefaultNumDocuments), + }, + }, + } + if _, err = controllerutil.CreateOrUpdate(ctx, c, knowledgebaseRetriever, func() error { + if input.ScoreThreshold != nil { + knowledgebaseRetriever.Spec.ScoreThreshold = pointer.Float32(float32(*input.ScoreThreshold)) + } + knowledgebaseRetriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, knowledgebaseRetriever.Spec.NumDocuments) + return nil + }); err != nil { + return nil, err + } + retriever = &knowledgebaseRetriever.Spec.CommonRetrieverConfig + } + } else { + knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{} + _ = c.DeleteAllOf(ctx, knowledgebaseRetriever, client.InNamespace(input.Namespace), client.MatchingLabels{"application": input.Name}) + } + + if hasKnowledgebaseRetriever { + mergerRetriever := &apiretriever.MergerRetriever{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, Namespace: input.Namespace, }, - Spec: 
apiretriever.KnowledgeBaseRetrieverSpec{ - CommonSpec: v1alpha1.CommonSpec{ - DisplayName: "retriever", - Description: "retriever", - }, - CommonRetrieverConfig: apiretriever.CommonRetrieverConfig{ - ScoreThreshold: pointer.Float32(float32(pointer.Float64Deref(input.ScoreThreshold, apiretriever.DefaultScoreThreshold))), - NumDocuments: pointer.IntDeref(input.NumDocuments, apiretriever.DefaultNumDocuments), - }, - }, } - if _, err = controllerutil.CreateOrUpdate(ctx, c, knowledgebaseRetriever, func() error { - if input.ScoreThreshold != nil { - knowledgebaseRetriever.Spec.ScoreThreshold = pointer.Float32(float32(*input.ScoreThreshold)) - } - knowledgebaseRetriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, knowledgebaseRetriever.Spec.NumDocuments) + if _, err = controllerutil.CreateOrUpdate(ctx, c, mergerRetriever, func() error { return nil }); err != nil { return nil, err } - retriever = &knowledgebaseRetriever.Spec.CommonRetrieverConfig + } else { + mergerRetriever := &apiretriever.MergerRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, mergerRetriever) } if hasMultiQueryRetriever { @@ -653,7 +752,16 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat return nil, err } retriever = &multiQueryRetriever.Spec.CommonRetrieverConfig + } else { + multiQueryRetriever := &apiretriever.MultiQueryRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, multiQueryRetriever) } + if hasRerankRetriever { rerankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ @@ -695,17 +803,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat if rerankRetriever.Spec.Model != nil { rerankModel = rerankRetriever.Spec.Model.Name } - } - if !hasMultiQueryRetriever { - multiQueryRetriever := &apiretriever.MultiQueryRetriever{ - ObjectMeta: metav1.ObjectMeta{ - 
Name: input.Name, - Namespace: input.Namespace, - }, - } - _ = c.Delete(ctx, multiQueryRetriever) - } - if !hasRerankRetriever { + } else { reRankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -714,6 +812,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } _ = c.Delete(ctx, reRankRetriever) } + // create or update agent for tools var agent *apiagent.Agent if len(input.Tools) != 0 { @@ -742,6 +841,14 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + } else { + agent = &apiagent.Agent{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, agent) } // update application @@ -765,7 +872,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfigInput, hasMultiQueryRetriever, hasRerankRetriever bool) error { - app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) + app.Spec.Nodes = redefineNodes(input.Knowledgebases, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) app.Spec.Prologue = pointer.StringDeref(input.Prologue, app.Spec.Prologue) app.Spec.ShowRespInfo = pointer.BoolDeref(input.ShowRespInfo, app.Spec.ShowRespInfo) app.Spec.ShowRetrievalInfo = pointer.BoolDeref(input.ShowRetrievalInfo, app.Spec.ShowRetrievalInfo) @@ -779,7 +886,7 @@ func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfi } // redefineNodes redefine nodes in application -func redefineNodes(knowledgebase *string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes 
[]v1alpha1.Node) { +func redefineNodes(knowledgebases []*string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) { nodes = []v1alpha1.Node{ { NodeConfig: v1alpha1.NodeConfig{ @@ -808,41 +915,25 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName }, NextNodeName: []string{"chain-node"}, }, - } - if pointer.BoolDeref(enableUploadFile, false) { - nodes = append(nodes, v1alpha1.Node{ + { NodeConfig: v1alpha1.NodeConfig{ - Name: "documentloader-node", - DisplayName: "documentloader", - Description: "文档加载,可选", + Name: "llm-node", + DisplayName: "llm", + Description: "设定大模型的访问信息", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "DocumentLoader", - Name: name, + Kind: "LLM", + Name: llmName, Namespace: &namespace, }, }, NextNodeName: []string{"chain-node"}, - }) - } - nodes = append(nodes, v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "llm-node", - DisplayName: "llm", - Description: "设定大模型的访问信息", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "LLM", - Name: llmName, - Namespace: &namespace, - }, }, - NextNodeName: []string{"chain-node"}, - }) + } if len(tools) != 0 { nodes[len(nodes)-1].NextNodeName = []string{"chain-node", "agent-node"} } - if knowledgebase == nil { + if !utils.HasValues(knowledgebases) && !pointer.BoolDeref(enableUploadFile, false) { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "chain-node", @@ -858,50 +949,102 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } else { - nodes = append(nodes, - v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "knowledgebase-node", - DisplayName: "知识库", - Description: "连接知识库", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: 
pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBase", - Name: pointer.StringDeref(knowledgebase, ""), - Namespace: &namespace, + for i, knowledgebase := range knowledgebases { + indexStr := fmt.Sprintf("-%d", i) + if i == 0 { + // Compatible with previous versions when there is only one knowledgebase + indexStr = "" + } + nodes = append(nodes, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: fmt.Sprintf("knowledgebase%s-node", indexStr), + DisplayName: "知识库", + Description: "连接知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBase", + Name: pointer.StringDeref(knowledgebase, ""), + Namespace: &namespace, + }, }, + NextNodeName: []string{fmt.Sprintf("retriever%s-node", indexStr)}, }, - NextNodeName: []string{"retriever-node"}, - }, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: fmt.Sprintf("retriever%s-node", indexStr), + DisplayName: "从知识库提取信息的retriever", + Description: "连接应用和知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBaseRetriever", + Name: fmt.Sprintf("%s%s", name, indexStr), + Namespace: &namespace, + }, + }, + NextNodeName: []string{"mergerretriever-node"}, + }) + } + if pointer.BoolDeref(enableUploadFile, false) { + nodes = append(nodes, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "conversation-knowledgebase-node", + DisplayName: "会话知识库", + Description: "连接会话知识库,包含会话中上传的文件", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBase", + Name: appnode.ConversationKnowledgebaseName, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"conversation-knowledgebase-retriever-node"}, + }, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "conversation-knowledgebase-retriever-node", + DisplayName: "从会话知识库提取信息的retriever", + Description: "连接应用和知识库", + Ref: 
&v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBaseRetriever", + Name: name + "-" + appnode.ConversationKnowledgebaseName, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"mergerretriever-node"}, + }) + } + mergerRetrierNextNodeName := "chain-node" + multiqueryRetrieverNodeName := "chain-node" + switch { + case hasMultiQueryRetriever: + mergerRetrierNextNodeName = "multiqueryretriever-node" + if hasRerankRetriever { + multiqueryRetrieverNodeName = "rerankretriever-node" + } + case !hasMultiQueryRetriever && hasRerankRetriever: + mergerRetrierNextNodeName = "rerankretriever-node" + case !hasMultiQueryRetriever && !hasRerankRetriever: + mergerRetrierNextNodeName = "chain-node" + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ - Name: "retriever-node", - DisplayName: "从知识库提取信息的retriever", - Description: "连接应用和知识库", + Name: "mergerretriever-node", + DisplayName: "知识库合并retriever", + Description: "知识库合并retriever", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBaseRetriever", + Kind: "MergerRetriever", Name: name, Namespace: &namespace, }, }, - NextNodeName: []string{"chain-node"}, + NextNodeName: []string{mergerRetrierNextNodeName}, }) - knowledgebaseRetrierNextNodeName := "chain-node" - switch { - case hasMultiQueryRetriever: - knowledgebaseRetrierNextNodeName = "multiqueryretriever-node" - case !hasMultiQueryRetriever && hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "rerankretriever-node" - case !hasMultiQueryRetriever && !hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "chain-node" - } - nodes[len(nodes)-1].NextNodeName = []string{knowledgebaseRetrierNextNodeName} if hasMultiQueryRetriever { - nextNodeName := "chain-node" - if hasRerankRetriever { - nextNodeName = "rerankretriever-node" - } nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ 
-915,7 +1058,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName Namespace: &namespace, }, }, - NextNodeName: []string{nextNodeName}, + NextNodeName: []string{multiqueryRetrieverNodeName}, }) } if hasRerankRetriever { @@ -951,6 +1094,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } + if len(tools) != 0 { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ -967,6 +1111,24 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"chain-node"}, }) } + + if pointer.BoolDeref(enableUploadFile, false) { + nodes = append(nodes, v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "documentloader-node", + DisplayName: "documentloader", + Description: "文档加载,可选", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "DocumentLoader", + Name: name, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"chain-node"}, + }) + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "Output", @@ -982,6 +1144,25 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName return nodes } +// Deprecated: just for backward compatibility, should remove after new version +func ConvertKnowledgebase2Knowledgebases(knowledgebase *string, knowledgebases []*string) []*string { + if knowledgebases != nil { + return knowledgebases + } + if knowledgebase == nil { + return nil + } + return []*string{knowledgebase} +} + +// Deprecated: just for backward compatibility, should remove after new version +func ConvertKnowledgebases2Knowledgebase(knowledgebases []*string) *string { + if len(knowledgebases) == 0 { + return nil + } + return knowledgebases[0] +} + func UploadIcon(ctx context.Context, client client.Client, icon, appName, namespace string) (string, error) { if strings.HasPrefix(icon, "data:image") { 
imgBytes, err := pkgutils.ParseBase64ImageBytes(icon) diff --git a/apiserver/pkg/chat/chat_server.go b/apiserver/pkg/chat/chat_server.go index 4cecb65f5..66fee6547 100644 --- a/apiserver/pkg/chat/chat_server.go +++ b/apiserver/pkg/chat/chat_server.go @@ -29,7 +29,6 @@ import ( langchainllms "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/memory" "github.com/tmc/langchaingo/prompts" - langchainschema "github.com/tmc/langchaingo/schema" "golang.org/x/sync/errgroup" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -302,32 +301,29 @@ func (cs *ChatServer) ListPromptStarters(ctx context.Context, req APPMetadata, l if finish != nil { defer finish() } - v, ok := outArg[base.LangchaingoRetrieverKeyInArg] - if ok { - r, ok := v.(langchainschema.Retriever) - if ok { - doc, err := r.GetRelevantDocuments(ctx, "") - if err != nil { - return nil, err - } - for _, d := range doc { - hasAnswer := false - // has answer, means qa.csv, just return the question - v, ok := d.Metadata[documentloaders.AnswerCol] - if ok { - answer, ok := v.(string) - if ok && answer != "" { - question := strings.TrimSuffix(d.PageContent, "\na: "+answer) - promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) - hasAnswer = true - if len(promptStarters) == limit { - break - } + retrievers, err := base.GetRetrieversFromArg(outArg) + if err == nil && len(retrievers) > 0 { + doc, err := retrievers[0].GetRelevantDocuments(ctx, "") + if err != nil { + return nil, err + } + for _, d := range doc { + hasAnswer := false + // has answer, means qa.csv, just return the question + v, ok := d.Metadata[documentloaders.AnswerCol] + if ok { + answer, ok := v.(string) + if ok && answer != "" { + question := strings.TrimSuffix(d.PageContent, "\na: "+answer) + promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) + hasAnswer = true + if len(promptStarters) == limit { + break } } - if !hasAnswer { - content.WriteString(d.PageContent + "\n") - } + } + if !hasAnswer 
{ + content.WriteString(d.PageContent + "\n") } } } diff --git a/apiserver/pkg/utils/structured.go b/apiserver/pkg/utils/structured.go index 20d1916dc..f72c319e0 100644 --- a/apiserver/pkg/utils/structured.go +++ b/apiserver/pkg/utils/structured.go @@ -38,3 +38,7 @@ func MapStr2Any(input map[string]string) map[string]any { func HasValue(s *string) bool { return s != nil && strings.TrimSpace(*s) != "" } + +func HasValues(s []*string) bool { + return len(s) > 0 && strings.TrimSpace(*s[0]) != "" +} diff --git a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. 
More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. 
+ format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index 7d0493d62..16b90d5b7 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -653,6 +653,32 @@ rules: - get - patch - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerRetrievers/finalizers + verbs: + - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers/status + verbs: + - get + - patch + - update - apiGroups: - retriever.arcadia.kubeagi.k8s.com.cn resources: diff --git a/config/samples/app_llmchain_abstract.yaml b/config/samples/app_llmchain_abstract.yaml index 84df0e4f1..825701ae0 100644 --- a/config/samples/app_llmchain_abstract.yaml +++ b/config/samples/app_llmchain_abstract.yaml @@ -40,11 +40,11 @@ spec: name: app-shared-llm-service nextNodeName: ["chain-node"] - name: chain-node - displayName: "llm chain" + displayName: "chain" description: "chain是langchain的核心概念,llmChain用于连接prompt和llm" ref: apiGroup: chain.arcadia.kubeagi.k8s.com.cn - kind: LLMChain + kind: RetrievalQAChain name: base-chat-document-assistant nextNodeName: ["Output"] - name: Output @@ -53,6 +53,30 @@ spec: ref: kind: Output name: Output + - name: conversation-knowledgebase-node + displayName: "对话知识库" + description: "对话知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: conversation-knowledgebase-placeholder + nextNodeName: ["conversation-knowledgebase-retriever-node"] + - name: conversation-knowledgebase-retriever-node + displayName: "从对话知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: 
base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + - name: mergerretriever-node + displayName: "整合多个retriever的结果" + description: "整合多个retriever的结果" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MergerRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["chain-node"] --- apiVersion: prompt.arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Prompt @@ -81,7 +105,7 @@ spec: chunkOverlap: 100 --- apiVersion: chain.arcadia.kubeagi.k8s.com.cn/v1alpha1 -kind: LLMChain +kind: RetrievalQAChain metadata: name: base-chat-document-assistant namespace: arcadia diff --git a/config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml b/config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml new file mode 100644 index 000000000..9feb4a951 --- /dev/null +++ b/config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml @@ -0,0 +1,96 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Application +metadata: + name: base-chat-with-knowledgebase-pgvector + namespace: arcadia +spec: + displayName: "知识库应用" + description: "最简单的和知识库对话的应用" + prologue: "Welcome to talk to the KnowledgeBase!🤖" + docNullReturn: "未找到您询问的内容,请详细描述您的问题,以便我们为您提供更好的服务" + nodes: + - name: Input + displayName: "用户输入" + description: "用户输入节点,必须" + ref: + kind: Input + name: Input + nextNodeName: ["prompt-node"] + - name: prompt-node + displayName: "prompt" + description: "设定prompt,template中可以使用{{xx}}来替换变量" + ref: + apiGroup: prompt.arcadia.kubeagi.k8s.com.cn + kind: Prompt + name: base-chat-with-knowledgebase + nextNodeName: ["chain-node"] + - name: llm-node + displayName: "zhipu大模型服务" + description: "设定大模型的访问信息" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: LLM + name: app-shared-llm-service + nextNodeName: ["chain-node"] + - name: knowledgebase-node + displayName: "使用的知识库" + description: "要用哪个知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + 
kind: KnowledgeBase + name: knowledgebase-sample-pgvector + nextNodeName: ["retriever-node"] + - name: retriever-node + displayName: "从知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-knowledgebase + nextNodeName: ["chain-node"] + - name: chain-node + displayName: "RetrievalQA chain" + description: "chain是langchain的核心概念,RetrievalQAChain用于从 retriever 中提取信息,供llm调用" + ref: + apiGroup: chain.arcadia.kubeagi.k8s.com.cn + kind: RetrievalQAChain + name: base-chat-with-knowledgebase + nextNodeName: ["Output"] + - name: Output + displayName: "最终输出" + description: "最终输出节点,必须" + ref: + kind: Output + name: Output + - name: documentloader-node + displayName: "documentloader" + description: "可选,如果需要对话中上传解析文件,需要添加这个节点" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: DocumentLoader + name: common + nextNodeName: [ "chain-node" ] + - name: conversation-knowledgebase-node + displayName: "对话知识库" + description: "对话知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: conversation-knowledgebase-placeholder + nextNodeName: ["conversation-knowledgebase-retriever-node"] + - name: conversation-knowledgebase-retriever-node + displayName: "从对话知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + - name: mergerretriever-node + displayName: "整合多个retriever的结果" + description: "整合多个retriever的结果" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MergerRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["chain-node"] diff --git a/config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml b/config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml new file 
mode 100644 index 000000000..328cc138c --- /dev/null +++ b/config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml @@ -0,0 +1,145 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Application +metadata: + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + namespace: arcadia +spec: + displayName: "知识库应用" + description: "最简单的和知识库对话的应用" + prologue: "Welcome to talk to the KnowledgeBase!🤖" + docNullReturn: "未找到您询问的内容,请详细描述您的问题,以便我们为您提供更好的服务" + nodes: + - name: Input + displayName: "用户输入" + description: "用户输入节点,必须" + ref: + kind: Input + name: Input + nextNodeName: ["prompt-node"] + - name: prompt-node + displayName: "prompt" + description: "设定prompt,template中可以使用{{xx}}来替换变量" + ref: + apiGroup: prompt.arcadia.kubeagi.k8s.com.cn + kind: Prompt + name: base-chat-with-knowledgebase + nextNodeName: ["chain-node"] + - name: llm-node + displayName: "大模型服务" + description: "设定大模型的访问信息" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: LLM + name: app-shared-llm-service + nextNodeName: ["chain-node"] + - name: knowledgebase-node + displayName: "使用的知识库" + description: "要用哪个知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: knowledgebase-sample-pgvector + nextNodeName: ["retriever-node"] + - name: retriever-node + displayName: "从知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-knowledgebase-pgvector-1 + nextNodeName: ["mergerretriever-node"] + - name: knowledgebase-1-node + displayName: "使用的知识库1" + description: "要用哪个知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: knowledgebase-sample-pgvector2 + nextNodeName: ["retriever-1-node"] + - name: retriever-1-node + displayName: "从知识库1提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: 
base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + - name: mergerretriever-node + displayName: "整合多个retriever的结果" + description: "整合多个retriever的结果" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MergerRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["multiquery-node"] + - name: multiquery-node + displayName: "让LLM多角度提问的Retriever" + description: "让LLM多角度提问的Retriever" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MultiQueryRetriever + name: base-chat-with-knowledgebase-pgvector-2 + nextNodeName: ["rerank-retriever-node"] + - name: rerank-retriever-node + displayName: "rerank retriever" + description: "重排" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: RerankRetriever + name: base-chat-with-knowledgebase-pgvector-rerank + nextNodeName: ["chain-node"] + - name: chain-node + displayName: "RetrievalQA chain" + description: "chain是langchain的核心概念,RetrievalQAChain用于从 retriever 中提取信息,供llm调用" + ref: + apiGroup: chain.arcadia.kubeagi.k8s.com.cn + kind: RetrievalQAChain + name: base-chat-with-knowledgebase + nextNodeName: ["Output"] + - name: Output + displayName: "最终输出" + description: "最终输出节点,必须" + ref: + kind: Output + name: Output + - name: conversation-knowledgebase-node + displayName: "对话知识库" + description: "对话知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: conversation-knowledgebase-placeholder + nextNodeName: ["conversation-knowledgebase-retriever-node"] + - name: conversation-knowledgebase-retriever-node + displayName: "从对话知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + +--- +apiVersion: retriever.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: KnowledgeBaseRetriever +metadata: + name: 
base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"KnowledgeBase","group":"arcadia.kubeagi.k8s.com.cn","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"kind":"RetrievalQAChain","group":"chain.arcadia.kubeagi.k8s.com.cn","length":1}]' +spec: + displayName: "从知识库获取信息的Retriever" + scoreThreshold: 0.3 + numDocuments: 50 +--- +apiVersion: retriever.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: MergerRetriever +metadata: + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"KnowledgeBase","group":"arcadia.kubeagi.k8s.com.cn","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"kind":"RetrievalQAChain","group":"chain.arcadia.kubeagi.k8s.com.cn","length":1}]' +spec: + displayName: "整合多个Retriever" diff --git a/config/samples/app_shared_llm_service_zhipu.yaml b/config/samples/app_shared_llm_service_zhipu.yaml index 9762e60a7..e44aa08ff 100644 --- a/config/samples/app_shared_llm_service_zhipu.yaml +++ b/config/samples/app_shared_llm_service_zhipu.yaml @@ -5,7 +5,7 @@ metadata: namespace: arcadia type: Opaque data: - apiKey: "MTZlZDcxYzcwMDE0NGFiMjIyMmI5YmEwZDFhMTBhZTUuUTljWVZtWWxmdjlnZGtDeQ==" # replace this with your API key + apiKey: "YTgyNTlhNjFmN2EwZGYzNmQ5N2Q3ZDIwOGVlMTQ0NTUuODc5OGJyeldwaGUzWUlCOA==" # replace this with your API key --- apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: LLM diff --git a/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml b/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml index 5e41e3f3f..2f67fce5e 100644 --- a/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml +++ b/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml @@ -22,3 +22,27 @@ spec: files: - path: qa.csv - path: chunk.csv +--- +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: KnowledgeBase +metadata: + name: 
knowledgebase-sample-pgvector2 + namespace: arcadia +spec: + displayName: "测试 KnowledgeBase 2" + description: "测试 KnowledgeBase 2" + embedder: + kind: Embedders + name: embedders-sample + namespace: arcadia + vectorStore: + kind: VectorStores + name: pgvector-sample + namespace: arcadia + fileGroups: + - source: + kind: VersionedDataset + name: dataset-playground-v1 + namespace: arcadia + files: + - path: CODE_OF_CONDUCT.md diff --git a/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml b/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml new file mode 100644 index 000000000..221beee1b --- /dev/null +++ b/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml @@ -0,0 +1,16 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Model +metadata: + name: bce-reranker + namespace: arcadia +spec: + displayName: "bce-reranker-large" + description: | + BCEmbedding是由网易有道开发的中英双语和跨语种语义表征算法模型库,其中包含 EmbeddingModel和 RerankerModel两类基础模型。 + EmbeddingModel专门用于生成语义向量,在语义搜索和问答中起着关键作用,而 RerankerModel擅长优化语义搜索结果和语义相关顺序精排。 + + Github: https://github.com/netease-youdao/BCEmbedding + HuggingFace: https://huggingface.co/maidalun1020/bce-reranker-base_v1 + types: "reranking" + modelScopeRepo: maidalun/bce-reranker-base_v1 + modelSource: modelscope diff --git a/config/samples/arcadia_v1alpha1_versioneddataset.yaml b/config/samples/arcadia_v1alpha1_versioneddataset.yaml index 156a2ef31..d4e26a340 100644 --- a/config/samples/arcadia_v1alpha1_versioneddataset.yaml +++ b/config/samples/arcadia_v1alpha1_versioneddataset.yaml @@ -16,5 +16,6 @@ spec: files: - path: qa.csv - path: chunk.csv + - path: CODE_OF_CONDUCT.md released: 0 version: v1 diff --git a/controllers/app-node/retriever/merger_retriever_controller.go b/controllers/app-node/retriever/merger_retriever_controller.go new file mode 100644 index 000000000..6e1221e4c --- /dev/null +++ b/controllers/app-node/retriever/merger_retriever_controller.go @@ -0,0 +1,150 @@ +/* +Copyright 2024 
KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package chain + +import ( + "context" + "reflect" + + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + + api "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" + arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" + appnode "github.com/kubeagi/arcadia/controllers/app-node" +) + +// MergerRetrieverReconciler reconciles a MergerRetriever object +type MergerRetrieverReconciler struct { + client.Client + Scheme *runtime.Scheme +} + +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=mergerretrievers,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=mergerretrievers/status,verbs=get;update;patch +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=mergerRetrievers/finalizers,verbs=update + +// Reconcile is part of the main kubernetes reconciliation loop which aims to +// move the current state of the cluster closer to the desired state. 
+// For more details, check Reconcile and its Result here: +// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.12.2/pkg/reconcile +func (r *MergerRetrieverReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := ctrl.LoggerFrom(ctx) + log.V(5).Info("Start MergerRetriever Reconcile") + instance := &api.MergerRetriever{} + if err := r.Get(ctx, req.NamespacedName, instance); err != nil { + // There's no need to requeue if the resource no longer exists. + // Otherwise, we'll be requeued implicitly because we return an error. + log.V(1).Info("Failed to get MergerRetriever") + return ctrl.Result{}, client.IgnoreNotFound(err) + } + log = log.WithValues("Generation", instance.GetGeneration(), "ObservedGeneration", instance.Status.ObservedGeneration, "creator", instance.Spec.Creator) + log.V(5).Info("Get MergerRetriever instance") + + // Add a finalizer.Then, we can define some operations which should + // occur before the MergerRetriever to be deleted. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/finalizers + if newAdded := controllerutil.AddFinalizer(instance, arcadiav1alpha1.Finalizer); newAdded { + log.Info("Try to add Finalizer for MergerRetriever") + if err := r.Update(ctx, instance); err != nil { + log.Error(err, "Failed to update MergerRetriever to add finalizer, will try again later") + return ctrl.Result{}, err + } + log.Info("Adding Finalizer for MergerRetriever done") + return ctrl.Result{}, nil + } + + // Check if the MergerRetriever instance is marked to be deleted, which is + // indicated by the deletion timestamp being set. + if instance.GetDeletionTimestamp() != nil && controllerutil.ContainsFinalizer(instance, arcadiav1alpha1.Finalizer) { + log.Info("Performing Finalizer Operations for MergerRetriever before delete CR") + // TODO perform the finalizer operations here, for example: remove vectorstore data? 
+ log.Info("Removing Finalizer for MergerRetriever after successfully performing the operations") + controllerutil.RemoveFinalizer(instance, arcadiav1alpha1.Finalizer) + if err := r.Update(ctx, instance); err != nil { + log.Error(err, "Failed to remove the finalizer for MergerRetriever") + return ctrl.Result{}, err + } + log.Info("Remove MergerRetriever done") + return ctrl.Result{}, nil + } + + instance, result, err := r.reconcile(ctx, log, instance) + + // Update status after reconciliation. + if updateStatusErr := r.patchStatus(ctx, instance); updateStatusErr != nil { + log.Error(updateStatusErr, "unable to update status after reconciliation") + return ctrl.Result{Requeue: true}, updateStatusErr + } + + return result, err +} + +func (r *MergerRetrieverReconciler) reconcile(ctx context.Context, log logr.Logger, instance *api.MergerRetriever) (*api.MergerRetriever, ctrl.Result, error) { + // Observe generation change + if instance.Status.ObservedGeneration != instance.Generation { + instance.Status.ObservedGeneration = instance.Generation + r.setCondition(instance, instance.Status.WaitingCompleteCondition()...) + if updateStatusErr := r.patchStatus(ctx, instance); updateStatusErr != nil { + log.Error(updateStatusErr, "unable to update status after generation update") + return instance, ctrl.Result{Requeue: true}, updateStatusErr + } + } + + if instance.Status.IsReady() { + return instance, ctrl.Result{}, nil + } + // Note: should change here + // TODO: we should do more checks later.For example: + // LLM status + // Prompt status + if err := appnode.CheckAndUpdateAnnotation(ctx, log, r.Client, instance); err != nil { + instance.Status.SetConditions(instance.Status.ErrorCondition(err.Error())...) + } else { + instance.Status.SetConditions(instance.Status.ReadyCondition()...) 
+ } + + return instance, ctrl.Result{}, nil +} + +func (r *MergerRetrieverReconciler) patchStatus(ctx context.Context, instance *api.MergerRetriever) error { + latest := &api.MergerRetriever{} + if err := r.Client.Get(ctx, client.ObjectKeyFromObject(instance), latest); err != nil { + return err + } + if reflect.DeepEqual(instance.Status, latest.Status) { + return nil + } + patch := client.MergeFrom(latest.DeepCopy()) + latest.Status = instance.Status + return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("MergerRetriever-controller")) +} + +// SetupWithManager sets up the controller with the Manager. +func (r *MergerRetrieverReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&api.MergerRetriever{}). + Complete(r) +} + +func (r *MergerRetrieverReconciler) setCondition(instance *api.MergerRetriever, condition ...arcadiav1alpha1.Condition) *api.MergerRetriever { + instance.Status.SetConditions(condition...) + return instance +} diff --git a/controllers/base/application_controller.go b/controllers/base/application_controller.go index 6c230184e..797ff589a 100644 --- a/controllers/base/application_controller.go +++ b/controllers/base/application_controller.go @@ -55,6 +55,7 @@ const ( KnowledgebaseRetrieverIndexKey = "metadata.knowledgebaseretriever" RerankRetrieverIndexKey = "metadata.rerankretriever" MultiQueryRetrieverIndexKey = "metadata.multiqueryretriever" + MergerRetrieverIndexKey = "metadata.mergerretriever" AgentIndexKey = "metadata.agent" DocumentLoaderIndexKey = "metadata.documentloader" ) @@ -368,6 +369,7 @@ func (r *ApplicationReconciler) SetupWithManager(ctx context.Context, mgr ctrl.M {KnowledgebaseRetrieverIndexKey, "retriever", "knowledgebaseretriever"}, {RerankRetrieverIndexKey, "retriever", "rerankretriever"}, {MultiQueryRetrieverIndexKey, "retriever", "multiqueryretriever"}, + {MergerRetrieverIndexKey, "retriever", "mergerretriever"}, {AgentIndexKey, "", "agent"}, 
{DocumentLoaderIndexKey, "", "documentloader"}, } @@ -424,6 +426,7 @@ func (r *ApplicationReconciler) SetupWithManager(ctx context.Context, mgr ctrl.M Watches(&source.Kind{Type: &retrieveralpha1.KnowledgeBaseRetriever{}}, getEventHandler(KnowledgebaseRetrieverIndexKey)). Watches(&source.Kind{Type: &retrieveralpha1.RerankRetriever{}}, getEventHandler(RerankRetrieverIndexKey)). Watches(&source.Kind{Type: &retrieveralpha1.MultiQueryRetriever{}}, getEventHandler(MultiQueryRetrieverIndexKey)). + Watches(&source.Kind{Type: &retrieveralpha1.MergerRetriever{}}, getEventHandler(MergerRetrieverIndexKey)). Watches(&source.Kind{Type: &agentv1alpha1.Agent{}}, getEventHandler(AgentIndexKey)). Watches(&source.Kind{Type: &documentloaderv1alpha1.DocumentLoader{}}, getEventHandler(DocumentLoaderIndexKey)). Complete(r) diff --git a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. 
More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. 
+ type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/deploy/charts/arcadia/templates/rbac.yaml b/deploy/charts/arcadia/templates/rbac.yaml index 1f74678ab..0119f0a1e 100644 --- a/deploy/charts/arcadia/templates/rbac.yaml +++ b/deploy/charts/arcadia/templates/rbac.yaml @@ -671,6 +671,32 @@ rules: - get - patch - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers/finalizers + verbs: + - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers/status + verbs: + - get + - patch + - update - apiGroups: - retriever.arcadia.kubeagi.k8s.com.cn resources: diff --git a/deploy/charts/arcadia/templates/role-templates.yaml b/deploy/charts/arcadia/templates/role-templates.yaml index cbb4f8362..218c7947e 100644 --- a/deploy/charts/arcadia/templates/role-templates.yaml +++ b/deploy/charts/arcadia/templates/role-templates.yaml @@ -171,6 +171,7 @@ spec: - knowledgebaseretrievers - multiqueryretrievers - rerankretrievers + - mergerretrievers verbs: - create - delete @@ -188,6 +189,7 @@ spec: - knowledgebaseretrievers/status - multiqueryretrievers/status - rerankretrievers/status + - mergerretrievers/status verbs: - get - patch diff --git a/go.mod b/go.mod index 8f8600ff6..22c661473 100644 --- a/go.mod +++ b/go.mod @@ -215,4 +215,4 @@ require ( sigs.k8s.io/yaml v1.3.0 // indirect ) -replace github.com/tmc/langchaingo => github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 // branch dev +replace github.com/tmc/langchaingo => 
github.com/kubeagi/langchaingo v0.0.0-20240416092403-dd907a8798bd // branch dev diff --git a/go.sum b/go.sum index b8718c1b0..5c1b36485 100644 --- a/go.sum +++ b/go.sum @@ -549,8 +549,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 h1:4VbKHgpTrG/EPWIn4n3FMIedvz3z2aYL2E0Z7RIEvik= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= +github.com/kubeagi/langchaingo v0.0.0-20240416092403-dd907a8798bd h1:3+lIBh7HtAqvbeBTpnuy79zvgBl4dCQz5YkhXKeocOc= +github.com/kubeagi/langchaingo v0.0.0-20240416092403-dd907a8798bd/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= diff --git a/main.go b/main.go index 09ac62956..40f1737bb 100644 --- a/main.go +++ b/main.go @@ -278,6 +278,13 @@ func main() { setupLog.Error(err, "unable to create controller", "controller", "MultiQueryRetriever") os.Exit(1) } + if err = (&retrievertrollers.MergerRetrieverReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + }).SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "MergerRetriever") + os.Exit(1) + } if err = (&promptcontrollers.PromptReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), diff --git a/pkg/appruntime/agent/executor.go b/pkg/appruntime/agent/executor.go index a657cfe71..64b1c478a 100644 --- 
a/pkg/appruntime/agent/executor.go +++ b/pkg/appruntime/agent/executor.go @@ -23,6 +23,7 @@ import ( "github.com/tmc/langchaingo/agents" "github.com/tmc/langchaingo/callbacks" + "github.com/tmc/langchaingo/chains" "github.com/tmc/langchaingo/llms" langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" @@ -79,15 +80,21 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a agents.WithCallbacksHandler(streamHandler)(o) } } - agents.WithMemory(chain.GetMemory(llm, instance.Spec.AgentConfig.Options.Memory, history, "", ""))(o) + agents.WithMemory(chain.GetMemory(llm, instance.Spec.AgentConfig.Options.Memory, history, "input", ""))(o) } - executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, executorOptions) + executor, err := agents.Initialize(llm, allowedTools, agents.ConversationalReactDescription, executorOptions) if err != nil { return args, fmt.Errorf("failed to initialize executor: %w", err) } + executor.CallbacksHandler = log.KLogHandler{LogLevel: 3} input := make(map[string]any) - input["input"] = fmt.Sprintf("%s, %s", instance.Spec.Prompt, args["question"]) - response, err := executor.Call(ctx, input) + if instance.Spec.Prompt != "" { + input["input"] = fmt.Sprintf("%s, %s", instance.Spec.Prompt, args["question"]) + } else { + input["input"] = args["question"] + } + // chains.Call will add history to args + response, err := chains.Call(ctx, executor, input) if err != nil { klog.FromContext(ctx).Error(err, "error when call agent") // return args, fmt.Errorf("error when call agent: %w", err) diff --git a/pkg/appruntime/app_runtime.go b/pkg/appruntime/app_runtime.go index 82e3d5595..6f4483e23 100644 --- a/pkg/appruntime/app_runtime.go +++ b/pkg/appruntime/app_runtime.go @@ -25,8 +25,6 @@ import ( "strings" langchaingoschema "github.com/tmc/langchaingo/schema" - apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" 
"k8s.io/utils/strings/slices" "sigs.k8s.io/controller-runtime/pkg/client" @@ -149,32 +147,12 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha base.InputIsNeedStreamKeyInArg: input.NeedStream, base.LangchaingoChatMessageHistoryKeyInArg: input.History, // Use an empty context before run - "context": "", + "context": "", + base.ConversationIDInArg: input.ConversationID, } if a.Spec.DocNullReturn != "" { out[base.APPDocNullReturn] = a.Spec.DocNullReturn } - if input.ConversationID != "" { // means this is not a new conversation - conversationKnowledgebaseExist := true - kb := &arcadiav1alpha1.KnowledgeBase{} - err := cli.Get(ctx, types.NamespacedName{Namespace: a.Namespace, Name: input.ConversationID}, kb) - if err != nil { - if apierrors.IsNotFound(err) { - conversationKnowledgebaseExist = false - // TODO We can search for whether there should be a conversation knowledgebase from the pg - klog.FromContext(ctx).V(5).Info("conversation knowledgebase not exist", "ConversationID", input.ConversationID) - } else { - return output, err - } - } - if conversationKnowledgebaseExist { - if kb.Status.IsReady() { - out[base.ConversationKnowledgeBaseInArg] = kb - } else { - klog.FromContext(ctx).V(3).Info("conversation knowledgebase not ready", "ConversationID", input.ConversationID) - } - } - } visited := make(map[string]bool) waitRunningNodes := list.New() for _, v := range a.StartingNodes { @@ -194,7 +172,6 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha waitRunningNodes.PushBack(e) continue } - klog.FromContext(ctx).V(3).Info(fmt.Sprintf("try to run node:%s", e.Name())) defer func() { if r := recover(); r != nil { klog.FromContext(ctx).Info(fmt.Sprintf("Recovered from node:%s error:%s stack:%s", e.Name(), r, string(debug.Stack()))) @@ -253,7 +230,7 @@ func InitNode(ctx context.Context, appNamespace, name string, ref arcadiav1alpha } }() baseNode := base.NewBaseNode(appNamespace, name, ref) - err = 
fmt.Errorf("unknown kind %s:%v", name, ref) + err = fmt.Errorf("unknown kind %s:%v, get group:%s kind:%s", name, ref, baseNode.Group(), baseNode.Kind()) switch baseNode.Group() { case "chain": switch baseNode.Kind() { @@ -280,6 +257,9 @@ func InitNode(ctx context.Context, appNamespace, name string, ref arcadiav1alpha case "multiqueryretriever": logger.V(3).Info("initnode multiqueryretriever") return retriever.NewMultiQueryRetriever(baseNode), nil + case "mergerretriever": + logger.V(3).Info("initnode mergerretriever") + return retriever.NewMergerRetriever(baseNode), nil default: return nil, err } diff --git a/pkg/appruntime/base/keyword.go b/pkg/appruntime/base/keyword.go index 2171c84bc..47d9fa171 100644 --- a/pkg/appruntime/base/keyword.go +++ b/pkg/appruntime/base/keyword.go @@ -16,6 +16,12 @@ limitations under the License. package base +import ( + "errors" + + langchainschema "github.com/tmc/langchaingo/schema" +) + const ( InputQuestionKeyInArg = "question" InputIsNeedStreamKeyInArg = "_need_stream" @@ -25,9 +31,63 @@ const ( MapReduceDocumentOutputInArg = "_mapreduce_document_answer" OutputAnswerStreamChanKeyInArg = "_answer_stream" RuntimeRetrieverReferencesKeyInArg = "_references" - LangchaingoRetrieverKeyInArg = "retriever" + LangchaingoRetrieversKeyInArg = "retrievers" LangchaingoLLMKeyInArg = "llm" LangchaingoPromptKeyInArg = "prompt" APPDocNullReturn = "_app_doc_null_return" ConversationKnowledgeBaseInArg = "_conversation_knowledgebase" // the conversation Knowledgebase cr in args, status has ready + ConversationIDInArg = "_conversation_id" +) + +var ( + ErrNoQuestion = errors.New("no question in args") + ErrNoRetrievers = errors.New("no retrievers in args") ) + +func GetInputQuestionFromArg(args map[string]any) (string, error) { + q, ok := args[InputQuestionKeyInArg] + if !ok { + return "", ErrNoQuestion + } + query, ok := q.(string) + if !ok || len(query) == 0 { + return "", errors.New("empty question") + } + return query, nil +} + +func 
GetRetrieversFromArg(args map[string]any) ([]langchainschema.Retriever, error) { + v, ok := args[LangchaingoRetrieversKeyInArg] + if !ok { + return nil, ErrNoRetrievers + } + retrievers, ok := v.([]langchainschema.Retriever) + if !ok { + return nil, errors.New("retrievers not []schema.Retriever") + } + return retrievers, nil +} + +func GetAPPDocNullReturnFromArg(args map[string]any) (string, error) { + v, ok := args[APPDocNullReturn] + if !ok { + return "", nil + } + docNullReturn, ok := v.(string) + if !ok { + return "", errors.New("app doc null return not string type") + } + return docNullReturn, nil +} + +// AddKnowledgebaseRetrieverToArg adds a knowledgebase retriever to args +// Note: only knowledgebase retrievers are appended here; other components such as qachain use only the first retriever in args +func AddKnowledgebaseRetrieverToArg(args map[string]any, retriever langchainschema.Retriever) map[string]any { + if _, exist := args[LangchaingoRetrieversKeyInArg]; !exist { + args[LangchaingoRetrieversKeyInArg] = make([]langchainschema.Retriever, 0) + } + retrievers := args[LangchaingoRetrieversKeyInArg].([]langchainschema.Retriever) + retrievers = append(retrievers, retriever) + args[LangchaingoRetrieversKeyInArg] = retrievers + return args +} diff --git a/pkg/appruntime/base/node.go b/pkg/appruntime/base/node.go index 7122c608b..c401fc639 100644 --- a/pkg/appruntime/base/node.go +++ b/pkg/appruntime/base/node.go @@ -111,8 +111,8 @@ func (c *BaseNode) Init(_ context.Context, _ client.Client, _ map[string]any) er return nil } -func (c *BaseNode) Run(_ context.Context, _ client.Client, _ map[string]any) (map[string]any, error) { - return nil, nil +func (c *BaseNode) Run(_ context.Context, _ client.Client, args map[string]any) (map[string]any, error) { + return args, nil } func (c *BaseNode) Ready() (bool, string) { diff --git a/pkg/appruntime/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go index 9883693bc..bb84cac1e 100644 --- 
a/pkg/appruntime/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -27,7 +27,6 @@ import ( langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" - "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" @@ -57,19 +56,6 @@ func (l *LLMChain) Init(ctx context.Context, cli client.Client, args map[string] } func (l *LLMChain) Run(ctx context.Context, cli client.Client, args map[string]any) (outArgs map[string]any, err error) { - if args[base.ConversationKnowledgeBaseInArg] != nil { - chain := NewRetrievalQAChain(l.BaseNode) - chain.Instance = &v1alpha1.RetrievalQAChain{ - Spec: v1alpha1.RetrievalQAChainSpec{ - CommonChainConfig: v1alpha1.CommonChainConfig{ - Memory: v1alpha1.Memory{ - ConversionWindowSize: pointer.Int(v1alpha1.DefaultConversionWindowSize), - }, - }, - }, - } - return chain.Run(ctx, cli, args) - } v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return args, errors.New("no llm") @@ -122,7 +108,7 @@ func (l *LLMChain) Run(ctx context.Context, cli client.Client, args map[string]a // Add the agent output to the context if it's not empty if args[base.AgentOutputInArg] != nil { klog.FromContext(ctx).V(5).Info(fmt.Sprintf("get answer from upstream: %s", args[base.AgentOutputInArg])) - args["context"] = fmt.Sprintf("%s\n%s", args["context"], args[base.AgentOutputInArg]) + args["context"] = fmt.Sprintf("%s\n %s", args["context"], fmt.Sprintf("Tool Output of question \"%s\" is: %s", args["question"].(string), args[base.AgentOutputInArg].(string))) } // Add the mapReduceDocument output to the context if it's not empty if args[base.MapReduceDocumentOutputInArg] != nil { diff --git a/pkg/appruntime/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go index e14a2eece..2c44aa618 100644 --- a/pkg/appruntime/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -27,12 +27,9 @@ import ( langchainschema 
"github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" - "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" - apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" - arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/appruntime/log" appruntimeretriever "github.com/kubeagi/arcadia/pkg/appruntime/retriever" @@ -77,27 +74,19 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ if !ok { return args, errors.New("prompt not prompts.FormatPrompter") } - vkb, ok := args[base.ConversationKnowledgeBaseInArg] - if ok { - kb, ok := vkb.(*arcadiav1alpha1.KnowledgeBase) - if !ok { - return args, errors.New("knowledgebase not arcadiav1alpha1.KnowledgeBase") - } - args, finish, err := appruntimeretriever.GenerateKnowledgebaseRetriever(ctx, cli, kb.Name, kb.Namespace, apiretriever.CommonRetrieverConfig{ScoreThreshold: pointer.Float32(apiretriever.DefaultScoreThreshold), NumDocuments: 20}, args) - if err != nil { - return args, err - } - defer finish() - } - v3, ok := args["retriever"] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v3.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrieversInArg, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + klog.FromContext(ctx).Info("no retrievers found in chain, roll back to LLMChain") + llmchain := NewLLMChain(l.BaseNode) + llmchain.Instance = &v1alpha1.LLMChain{} + llmchain.Instance.Spec.CommonChainConfig = l.Instance.Spec.CommonChainConfig + return llmchain.Run(ctx, cli, args) + } + return args, err } + retriever := retrieversInArg[0] v4, ok := args[base.LangchaingoChatMessageHistoryKeyInArg] if !ok { @@ -130,6 +119,10 @@ func (l *RetrievalQAChain) Run(ctx 
context.Context, cli client.Client, args map[ } } + docs, err := retriever.GetRelevantDocuments(ctx, args["question"].(string)) + if err != nil { + return args, fmt.Errorf("can't get doc from retriever: %w", err) + } /* Note: we can only add context to retriever's documents not args["context"] if we use ConversationalRetrievalQA chain see https://github.com/kubeagi/langchaingo/blob/ca2f549e8d91788fd76a9f8706afcaae617275c5/chains/stuff_documents.go#L54-L69 @@ -139,14 +132,17 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ // Add the agent output to qachain context if args[base.AgentOutputInArg] != nil { klog.FromContext(ctx).V(5).Info(fmt.Sprintf("get context from agent: %s", args[base.AgentOutputInArg])) - docs, err := retriever.GetRelevantDocuments(ctx, args["question"].(string)) - if err != nil { - return args, fmt.Errorf("can't get doc from retriever: %w", err) - } doc := langchainschema.Document{PageContent: fmt.Sprintf("Tool Output of %s is: %s", args["question"].(string), args[base.AgentOutputInArg].(string))} docs = append(docs, doc) retriever = &appruntimeretriever.Fakeretriever{Docs: docs, Name: "AddAgentOutputRetriever"} } + if len(docs) == 0 { + docNullReturn, err := base.GetAPPDocNullReturnFromArg(args) + if err == nil && len(docNullReturn) > 0 { + return args, &base.RetrieverGetNullDocError{Msg: docNullReturn} + } + } + // Add the mapReduceDocument output to the context if it's not empty if args[base.MapReduceDocumentOutputInArg] != nil { // Note: Now args[base.MapReduceDocumentOutputInArg] will not be null only for the first "summarize content of uploaded document" Chat request @@ -167,28 +163,42 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ condenseQustionGenerator.CallbacksHandler = log.KLogHandler{LogLevel: 3} chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), condenseQustionGenerator, retriever, GetMemory(llm, instance.Spec.Memory, history, "", 
"")) chain.RephraseQuestion = false + chain.ReturnSourceDocuments = true l.ConversationalRetrievalQA = chain args["query"] = args["question"] - var out string + var ( + out string + outputValues map[string]any + ) needStream := false needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool) if ok && needStream { options = append(options, chains.WithStreamingFunc(stream(args))) - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args, options...) } else { if len(options) > 0 { - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args, options...) } else { - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args) } } + // _llmChainDefaultOutputKey + out, _ = outputValues["text"].(string) out, err = handleNoErrNoOut(ctx, needStream, out, err, l.ConversationalRetrievalQA, args, options) klog.FromContext(ctx).V(5).Info("use retrievalqachain, blocking out:" + out) if err == nil { args[base.OutputAnswerKeyInArg] = out + // _conversationalRetrievalQADefaultSourceDocumentKey + doc, ok := outputValues["source_documents"].([]langchainschema.Document) + if ok { + _, refs := appruntimeretriever.ConvertDocuments(ctx, doc, "retrievalqachain") + // note: the references in args will be replaced, not append + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + } return args, nil } + return args, fmt.Errorf("retrievalqachain run error: %w", err) } diff --git a/pkg/appruntime/knowledgebase/knowledgebase.go b/pkg/appruntime/knowledgebase/knowledgebase.go index e57ee21d1..8448f7a36 100644 --- a/pkg/appruntime/knowledgebase/knowledgebase.go +++ b/pkg/appruntime/knowledgebase/knowledgebase.go @@ -23,6 +23,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + appnode 
"github.com/kubeagi/arcadia/api/app-node" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" ) @@ -39,6 +40,9 @@ func NewKnowledgebase(baseNode base.BaseNode) *Knowledgebase { } func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return nil + } instance := &v1alpha1.KnowledgeBase{} if err := cli.Get(ctx, types.NamespacedName{Namespace: k.RefNamespace(), Name: k.Ref.Name}, instance); err != nil { return fmt.Errorf("can't find the knowledgebase in cluster: %w", err) @@ -47,10 +51,9 @@ func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[strin return nil } -func (k *Knowledgebase) Run(_ context.Context, _ client.Client, args map[string]any) (map[string]any, error) { - return args, nil -} - func (k *Knowledgebase) Ready() (isReady bool, msg string) { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return true, "" + } return k.Instance.Status.IsReadyOrGetReadyMessage() } diff --git a/pkg/appruntime/log/log.go b/pkg/appruntime/log/log.go index e1c4ba7a3..5eb23c5ce 100644 --- a/pkg/appruntime/log/log.go +++ b/pkg/appruntime/log/log.go @@ -82,7 +82,8 @@ func (l KLogHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) { logger := klog.FromContext(ctx) logger.WithValues("logger", "arcadia") // Lower level for log streaming - logger.V(l.LogLevel + 1).Info("log streaming: " + string(chunk)) + // maybe 6 is enough + logger.V(6).Info("log streaming: " + string(chunk)) } func (l KLogHandler) HandleText(ctx context.Context, text string) { diff --git a/pkg/appruntime/retriever/common.go b/pkg/appruntime/retriever/common.go index 50bf2e2f9..4bd1f7a2f 100644 --- a/pkg/appruntime/retriever/common.go +++ b/pkg/appruntime/retriever/common.go @@ -56,6 +56,8 @@ type Reference struct { Metadata map[string]any `json:"-"` } +const RerankScoreCol string = "rerank_score" + func (reference 
Reference) String() string { bytes, err := json.Marshal(&reference) if err != nil { @@ -79,6 +81,7 @@ func AddReferencesToArgs(args map[string]any, refs []Reference) map[string]any { return args } +// ConvertDocuments converts raw documents into the form we want; for example, a knowledgebase document gets its answer appended to its page content func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, retrieverName string) (newDocs []langchaingoschema.Document, refs []Reference) { logger := klog.FromContext(ctx) docLen := len(docs) @@ -108,7 +111,7 @@ func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, re doc.PageContent = doc.PageContent + joinStr + answer } } - if retrieverName == "multiquery" { + if retrieverName == "multiquery" || retrieverName == "retrievalqachain" { // pageContent may have the answer in previous steps, and we want this field only has question in reference output pageContent = strings.TrimSuffix(pageContent, joinStr+answer) } @@ -145,6 +148,7 @@ func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, re content = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") } } + rerankScore, _ := doc.Metadata[RerankScoreCol].(float32) refs = append(refs, Reference{ Question: pageContent, Answer: answer, @@ -155,6 +159,7 @@ func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, re PageNumber: page, Content: content, Metadata: doc.Metadata, + RerankScore: rerankScore, }) docs[k] = doc } diff --git a/pkg/appruntime/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go index 28450a6fd..e1c207da1 100644 --- a/pkg/appruntime/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -18,16 +18,16 @@ package retriever import ( "context" - "errors" "fmt" - langchaingoschema "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" + apierrors "k8s.io/apimachinery/pkg/api/errors" 
"k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" + appnode "github.com/kubeagi/arcadia/api/app-node" apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" @@ -101,7 +101,21 @@ func (l *KnowledgeBaseRetriever) Cleanup() { func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, knowledgebaseName, knowledgebaseNamespace string, retrieverConfig apiretriever.CommonRetrieverConfig, args map[string]any) (outArg map[string]any, finish func(), err error) { knowledgebase := &v1alpha1.KnowledgeBase{} + isConversationKnowledgebase := appnode.IsPlaceholderConversationKnowledgebase(knowledgebaseName) + if isConversationKnowledgebase { + v, ok := args[base.ConversationIDInArg] + if ok { + conversationID, ok := v.(string) + if ok && conversationID != "" { + knowledgebaseName = conversationID + } + } + } if err := cli.Get(ctx, types.NamespacedName{Namespace: knowledgebaseNamespace, Name: knowledgebaseName}, knowledgebase); err != nil { + if isConversationKnowledgebase && apierrors.IsNotFound(err) { // When there is a conversationID, look for the corresponding conversation knowledgebase. This knowledgebase may not exist. 
This is not an error + // TODO We can search for whether there should be a conversation knowledgebase from the pg + return args, nil, nil + } return nil, nil, fmt.Errorf("can't find the knowledgebase in cluster: %w", err) } @@ -138,40 +152,14 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } retriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} - question, ok := args["question"] - if !ok { - return nil, finish, errors.New("no question in args") - } - query, ok := question.(string) - if !ok { - return nil, finish, errors.New("question not string") + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return nil, finish, err } docs, err := retriever.GetRelevantDocuments(ctx, query) if err != nil { return nil, finish, fmt.Errorf("can't get relevant documents: %w", err) } - oldDocs := make([]langchaingoschema.Document, 0) - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if ok { - // may exist other retriever, like conversation retriever - oldRetriever, ok := v.(langchaingoschema.Retriever) - if ok { - oldDocs, err = oldRetriever.GetRelevantDocuments(ctx, query) - if err != nil { - return nil, finish, fmt.Errorf("can't get old doc: %w", err) - } - } - } - if len(docs) == 0 && len(oldDocs) == 0 { - // FIXME: 需要决定当知识库找不到相关内容,但是conversation知识库存在文档时如何处理 - v, exist := args[base.APPDocNullReturn] - if exist { - docNullReturn, ok := v.(string) - if ok && len(docNullReturn) > 0 { - return nil, finish, &base.RetrieverGetNullDocError{Msg: docNullReturn} - } - } - } // pgvector get score means vector distance, similarity = 1 - vector distance // chroma get score means similarity // we want similarity finally. 
@@ -181,7 +169,7 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } } docs, refs := ConvertDocuments(ctx, docs, "knowledgebase") - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: append(docs, oldDocs...), Name: "KnowledgebaseRetriever"} - AddReferencesToArgs(args, refs) + args = AddReferencesToArgs(args, refs) + args = base.AddKnowledgebaseRetrieverToArg(args, &Fakeretriever{Docs: docs, Name: "KnowledgebaseRetriever"}) return args, finish, nil } diff --git a/pkg/appruntime/retriever/mergerretriever.go b/pkg/appruntime/retriever/mergerretriever.go new file mode 100644 index 000000000..26330da95 --- /dev/null +++ b/pkg/appruntime/retriever/mergerretriever.go @@ -0,0 +1,79 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package retriever + +import ( + "context" + "errors" + "fmt" + + langchainretrievers "github.com/tmc/langchaingo/retrievers" + "github.com/tmc/langchaingo/schema" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime/base" +) + +type MergerRetriever struct { + base.BaseNode + Instance *apiretriever.MergerRetriever +} + +func NewMergerRetriever(baseNode base.BaseNode) *MergerRetriever { + return &MergerRetriever{ + BaseNode: baseNode, + } +} + +func (l *MergerRetriever) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + instance := &apiretriever.MergerRetriever{} + if err := cli.Get(ctx, types.NamespacedName{Namespace: l.RefNamespace(), Name: l.BaseNode.Ref.Name}, instance); err != nil { + return fmt.Errorf("can't find the merger retriever in cluster: %w", err) + } + l.Instance = instance + return nil +} + +func (l *MergerRetriever) Run(ctx context.Context, _ client.Client, args map[string]any) (map[string]any, error) { + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + return args, nil + } + return args, err + } + r := langchainretrievers.NewMergerRetriever(retrievers) + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return args, err + } + docs, err := r.GetRelevantDocuments(ctx, query) + if err != nil { + return args, err + } + docs, refs := ConvertDocuments(ctx, docs, "mergerretriever") + // note: the references in args will be replaced, not append + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + args[base.LangchaingoRetrieversKeyInArg] = []schema.Retriever{&Fakeretriever{Docs: docs, Name: "mergerRetriever"}} + return args, nil +} + +func (l *MergerRetriever) Ready() (isReady bool, msg string) { + return l.Instance.Status.IsReadyOrGetReadyMessage() +} diff --git 
a/pkg/appruntime/retriever/multiqueryretriever.go b/pkg/appruntime/retriever/multiqueryretriever.go index 610ec9c1f..e2bca6d4f 100644 --- a/pkg/appruntime/retriever/multiqueryretriever.go +++ b/pkg/appruntime/retriever/multiqueryretriever.go @@ -69,13 +69,12 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m return args, errors.New("empty question") } - v1, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v1.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrieversInArg, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + return args, nil + } + return args, err } v2, ok := args[base.LangchaingoLLMKeyInArg] @@ -88,7 +87,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m } prompt := prompts.NewPromptTemplate(_defaultQueryTemplate, []string{"question"}) llmchain := chains.NewLLMChain(llm, prompt, chains.WithCallback(log.KLogHandler{LogLevel: 3})) - multiqueryRetriever := retrievers.NewMultiQueryRetriever(retriever, llmchain, true) + multiqueryRetriever := retrievers.NewMultiQueryRetriever(retrieversInArg[0], llmchain, true) multiqueryRetriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} docs, err := multiqueryRetriever.GetRelevantDocuments(ctx, query) if err != nil { @@ -107,7 +106,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m newDocs, newRef := ConvertDocuments(ctx, newDocs, "multiquery") // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"}} return args, nil } diff --git 
a/pkg/appruntime/retriever/rerankretriever.go b/pkg/appruntime/retriever/rerankretriever.go index 4ae0c1fe6..f7343357e 100644 --- a/pkg/appruntime/retriever/rerankretriever.go +++ b/pkg/appruntime/retriever/rerankretriever.go @@ -66,14 +66,8 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s return args, errors.New("empty references") } if len(references) == 0 { - v, exist := args[base.APPDocNullReturn] - if exist { - docNullReturn, ok := v.(string) - if ok && len(docNullReturn) > 0 { - return nil, &base.RetrieverGetNullDocError{Msg: docNullReturn} - } - } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: nil, Name: "RerankRetriever"} + klog.FromContext(ctx).V(3).Info("rerank retriever get no references, skip rerank") + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: nil, Name: "RerankRetriever"}} return args, nil } q, ok := args[base.InputQuestionKeyInArg] @@ -143,15 +137,14 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + return args, nil + } + return args, err } - docs, err := retriever.GetRelevantDocuments(ctx, query) + docs, err := retrievers[0].GetRelevantDocuments(ctx, query) if err != nil { return args, fmt.Errorf("get relevant documents failed: %w", err) } @@ -159,11 +152,15 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s for i := range newRef { for j := range docs { if newRef[i].Score == docs[j].Score && reflect.DeepEqual(newRef[i].Metadata, 
docs[j].Metadata) { + if v := docs[j].Metadata; len(v) == 0 { + docs[j].Metadata = make(map[string]any, 1) + } + docs[j].Metadata[RerankScoreCol] = newRef[i].RerankScore newDocs = append(newDocs, docs[j]) } } } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "RerankRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: newDocs, Name: "RerankRetriever"}} return args, nil } diff --git a/tests/deploy-values.yaml b/tests/deploy-values.yaml index 3b7e906a5..53fcad628 100644 --- a/tests/deploy-values.yaml +++ b/tests/deploy-values.yaml @@ -140,3 +140,8 @@ postgresql: repository: kubeagi/postgresql tag: 16.1.0-debian-11-r18-pgvector-v0.5.1 pullPolicy: IfNotPresent + +config: + embedder: + enabled: true + model: "arcadia-embedder" diff --git a/tests/example-test.sh b/tests/example-test.sh index 6e80dd971..05ea3a0fe 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -40,314 +40,11 @@ Timeout="${TimeoutSeconds}s" mkdir ${TempFilePath} || true env -function debugInfo { - if [[ $? -eq 0 ]]; then - exit 0 - fi - if [[ $debug -ne 0 ]]; then - exit 1 - fi - if [[ $GITHUB_ACTIONS == "true" ]]; then - warning "debugInfo start 🧐" - mkdir -p $LOG_DIR || true - df -h - - warning "1. Try to get all resources " - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl get -A --ignore-not-found=true --show-kind=true >$LOG_DIR/get-all-resources-list.log - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl get -A -oyaml --ignore-not-found=true --show-kind=true >$LOG_DIR/get-all-resources-yaml.log - - warning "2. Try to describe all resources " - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl describe -A >$LOG_DIR/describe-all-resources.log - - warning "3. Try to export kind logs to $LOG_DIR..." - kind export logs --name=${KindName} $LOG_DIR - sudo chown -R $USER:$USER $LOG_DIR - - warning "debugInfo finished ! 
" - warning "This means that some tests have failed. Please check the log. 🌚" - debug=1 - fi - exit 1 -} +source ./tests/scripts/utils.sh trap 'debugInfo $LINENO' ERR trap 'debugInfo $LINENO' EXIT debug=0 -function cecho() { - declare -A colors - colors=( - ['black']='\E[0;47m' - ['red']='\E[0;31m' - ['green']='\E[0;32m' - ['yellow']='\E[0;33m' - ['blue']='\E[0;34m' - ['magenta']='\E[0;35m' - ['cyan']='\E[0;36m' - ['white']='\E[0;37m' - ) - local defaultMSG="No message passed." - local defaultColor="black" - local defaultNewLine=true - while [[ $# -gt 1 ]]; do - key="$1" - case $key in - -c | --color) - color="$2" - shift - ;; - -n | --noline) - newLine=false - ;; - *) - # unknown option - ;; - esac - shift - done - message=${1:-$defaultMSG} # Defaults to default message. - color=${color:-$defaultColor} # Defaults to default color, if not specified. - newLine=${newLine:-$defaultNewLine} - echo -en "${colors[$color]}" - echo -en "$message" - if [ "$newLine" = true ]; then - echo - fi - tput sgr0 # Reset text attributes to normal without clearing screen. 
- return -} - -function warning() { - cecho -c 'yellow' "$@" -} - -function error() { - cecho -c 'red' "$@" -} - -function info() { - cecho -c 'blue' "$@" -} - -function waitPodReady() { - namespace=$1 - podLabel=$2 - START_TIME=$(date +%s) - while true; do - readStatus=$(kubectl -n${namespace} get po -l ${podLabel} --ignore-not-found=true -o json | jq -r '.items[0].status.conditions[] | select(."type"=="Ready") | .status') - if [[ $readStatus == "True" ]]; then - info "Pod:${podLabel} ready" - break - fi - kubectl -n${namespace} get po -l ${podLabel} - - CURRENT_TIME=$(date +%s) - ELAPSED_TIME=$((CURRENT_TIME - START_TIME)) - if [ $ELAPSED_TIME -gt $TimeoutSeconds ]; then - error "Timeout reached" - kubectl describe po -n${namespace} -l ${podLabel} - kubectl get po -n${namespace} --show-labels - exit 1 - fi - sleep 5 - done -} - -function EnableAPIServerPortForward() { - waitPodReady "arcadia" "app=arcadia-apiserver" - if [ $portal_pid -ne 0 ]; then - kill $portal_pid >/dev/null 2>&1 - fi - echo "re port-forward apiserver..." - kubectl port-forward svc/arcadia-apiserver -n arcadia 8081:8081 >/dev/null 2>&1 & - portal_pid=$! 
- sleep 3 - info "port-forward apiserver in pid: $portal_pid" -} - -function waitCRDStatusReady() { - source=$1 - namespace=$2 - name=$3 - START_TIME=$(date +%s) - while true; do - readStatus=$(kubectl -n${namespace} get ${source} ${name} --ignore-not-found=true -o json | jq -r '.status.conditions[0].status') - message=$(kubectl -n${namespace} get ${source} ${name} --ignore-not-found=true -o json | jq -r '.status.conditions[0].message') - if [[ $readStatus == "True" ]]; then - info $message - if [[ ${source} == "KnowledgeBase" ]]; then - fileStatus=$(kubectl get knowledgebase -n $namespace $name -o json | jq -r '.status.fileGroupDetail[0].fileDetails[0].phase') - if [[ $fileStatus != "Succeeded" ]]; then - kubectl get knowledgebase -n $namespace $name -o json | jq -r '.status.fileGroupDetail[0].fileDetails' - exit 1 - fi - fi - break - fi - - CURRENT_TIME=$(date +%s) - ELAPSED_TIME=$((CURRENT_TIME - START_TIME)) - if [[ ${source} == "Worker" ]]; then - if [ $ELAPSED_TIME -gt 1800 ]; then - error "Timeout reached" - exit 1 - fi - else - if [ $ELAPSED_TIME -gt $TimeoutSeconds ]; then - error "Timeout reached" - exit 1 - fi - fi - sleep 5 - done -} - -function getRespInAppChat() { - appname=$1 - namespace=$2 - query=$3 - conversationID=$4 - testStream=$5 - attempt=0 - while true; do - info "sleep 3 seconds" - sleep 3 - data=$(jq -n --arg appname "$appname" --arg query "$query" --arg conversationID "$conversationID" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname}') - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}") - ai_data=$(echo $resp | jq -r '.message') - references=$(echo $resp | jq -r '.references') - if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 
1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: ${query}" - echo "🤖: ${ai_data}" - echo "🔗: ${references}" - break - done - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - - if [ $testStream == "true" ]; then - attempt=0 - while true; do - info "sleep 5 seconds" - sleep 5 - info "just test stream mode" - data=$(jq -n --arg appname "$appname" --arg query "$query" --arg conversationID "$conversationID" '{"query":$query,"response_mode":"streaming","conversation_id":$conversationID,"app_name":$appname}') - curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}" - if [[ $? -ne 0 ]]; then - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - EnableAPIServerPortForward - echo "and wait 60s for google api error" - sleep 60 - continue - fi - break - done - fi -} - -function fileUploadSummarise() { - appname=$1 - namespace=$2 - filename=$3 - attempt=0 - while true; do - info "sleep 3 seconds" - sleep 3 - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST --form file=@$filename --form app_name=$appname -H "namespace: ${namespace}" -H "Content-Type: multipart/form-data" http://127.0.0.1:8081/chat/conversations/file) - doc_data=$(echo $resp | jq -r '.document') - if [ -z "$doc_data" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. 
Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: ${filename}" - echo "🤖: ${doc_data}" - break - done - file_id=$(echo $resp | jq -r '.document.object') - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - attempt=0 - while true; do - info "sleep 3 seconds to sumerize doc" - sleep 3 - data=$(jq -n --arg fileid "$file_id" --arg appname "$appname" --arg query "总结一下" --arg conversationID "$resp_conversation_id" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname, "files": [$fileid]}') - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}") - ai_data=$(echo $resp | jq -r '.message') - references=$(echo $resp | jq -r '.references') - if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: 总结一下" - echo "🤖: ${ai_data}" - echo "🔗: ${references}" - break - done - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - - if [ $testStream == "true" ]; then - attempt=0 - while true; do - info "sleep 5 seconds" - sleep 5 - info "just test stream mode" - data=$(jq -n --arg fileid "$file_id" --arg appname "$appname" --arg query "总结一下" --arg conversationID "$resp_conversation_id" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname, "files": [$fileid]}') - curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}" - if [[ $? -ne 0 ]]; then - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." 
- exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - EnableAPIServerPortForward - echo "and wait 60s for google api error" - sleep 60 - continue - fi - break - done - fi -} - info "1. create kind cluster" make kind df -h @@ -410,6 +107,7 @@ s3_secret=$(kubectl get secrets -n arcadia datasource-sample-authsecret -o json export MC_HOST_arcadiatest=http://${s3_key}:${s3_secret}@127.0.0.1:9000 mc cp pkg/documentloaders/testdata/qa.csv arcadiatest/${bucket}/qa.csv mc cp pkg/documentloaders/testdata/chunk.csv arcadiatest/${bucket}/chunk.csv +mc cp CODE_OF_CONDUCT.md arcadiatest/${bucket}/CODE_OF_CONDUCT.md info "add tags to these files" mc tag set arcadiatest/${bucket}/qa.csv "object_type=QA" mc tag set arcadiatest/${bucket}/chunk.csv "object_type=QA" @@ -445,6 +143,7 @@ sleep 3 info "7.4.2 create knowledgebase based on pgvector and wait it ready" kubectl apply -f config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml waitCRDStatusReady "KnowledgeBase" "arcadia" "knowledgebase-sample-pgvector" +waitCRDStatusReady "KnowledgeBase" "arcadia" "knowledgebase-sample-pgvector2" info "7.5 check vectorstore has data" info "7.5.1 check chroma vectorstore has data" @@ -590,7 +289,14 @@ kubectl patch applications -n arcadia base-chat-with-knowledgebase-pgvector -p ' getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "飞天的主演是谁?" 
"" "true" info "8.2.3 QA app using knowledgebase base on pgvector and rerank" -kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce.yaml +if [[ $GITHUB_ACTIONS == "true" ]]; then + info "in github action, download model from huggingface" + kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce.yaml +else + # https://github.com/kubeagi/core-library/issues/54 + info "in local, download model from modelscope" + kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml +fi waitCRDStatusReady "Model" "arcadia" "bce-reranker" kubectl apply -f config/samples/arcadia_v1alpha1_worker_reranking_bce.yaml waitCRDStatusReady "Worker" "arcadia" "bce-reranker" @@ -668,6 +374,37 @@ info "8.2.5.3 When no related doc is found and application.spec.docNullReturn is kubectl patch applications -n arcadia base-chat-with-knowledgebase-pgvector-multiquery -p '{"spec":{"docNullReturn":""}}' --type='merge' getRespInAppChat "base-chat-with-knowledgebase-pgvector-multiquery" "arcadia" "飞天的主演是谁?" "" "true" +info "8.2.6 QA app using multiple knowledgebase base on pgvector and multiquery" +kubectl apply -f config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml +waitCRDStatusReady "Application" "arcadia" "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" +sleep 3 +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "公司的考勤管理制度适用于哪些人员?" "" "true" +if [[ $ai_data != *"全体正式员工及实习生"* ]]; then + echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp + exit 1 +fi +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "怀孕9个月以上每月可以享受几天假期?" "" "true" +if [[ $ai_data != *"4"* ]]; then + echo "resp should contains '4', but resp is:"$resp + exit 1 +fi +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "arcadia follows which Conduct?" 
"" "true" +# FIXME: how to change cnfc to cncf +if [[ $ai_data != *"cncf"* ]] && [[ $ai_data != *"CNCF"* ]] && [[ $ai_data != *"CNFC"* ]]; then + echo "resp should contains 'cncf', but resp is:"$resp + exit 1 +fi +info "8.2.5.2 When no related doc is found, return application.spec.docNullReturn info, if set" +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "飞天的主演是谁?" "" "true" +expected=$(kubectl get applications -n arcadia base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery -o json | jq -r .spec.docNullReturn) +if [[ $ai_data != $expected ]]; then + echo "when no related doc is found, return application.spec.docNullReturn info should be:"$expected ", but resp:"$resp + exit 1 +fi +info "8.2.5.3 When no related doc is found and application.spec.docNullReturn is not set" +kubectl patch applications -n arcadia base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery -p '{"spec":{"docNullReturn":""}}' --type='merge' +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "飞天的主演是谁?" "" "true" + info "8.3 conversation chat app" kubectl apply -f config/samples/app_llmchain_chat_with_bot.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-with-bot" @@ -747,17 +484,40 @@ kubectl apply -f config/samples/app_llmchain_abstract.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-document-assistant" fileUploadSummarise "base-chat-document-assistant" "arcadia" "./pkg/documentloaders/testdata/arcadia-readme.pdf" getRespInAppChat "base-chat-document-assistant" "arcadia" "what is arcadia?" ${resp_conversation_id} "false" +if [[ $ai_data != *"mythology"* ]] && [[ $ai_data != *"LLMOps"* ]] && [[ $ai_data != *"Kubernetes"* ]]; then + echo "resp should contains 'mythology' or 'LLMOps' or 'Kubernetes', but resp is:"$ai_data + exit 1 +fi getRespInAppChat "base-chat-document-assistant" "arcadia" "Does your model based on gpt-3.5?" 
${resp_conversation_id} "false" info "8.4.7 chat with document with knowledgebase" +kubectl apply -f config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml fileUploadSummarise "base-chat-with-knowledgebase-pgvector" "arcadia" "./pkg/documentloaders/testdata/arcadia-readme.pdf" getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "what is arcadia?" ${resp_conversation_id} "false" +if [[ $ai_data != *"mythology"* ]] && [[ $ai_data != *"LLMOps"* ]] && [[ $ai_data != *"Kubernetes"* ]]; then + echo "resp should contains 'mythology' or 'LLMOps' or 'Kubernetes', but resp is:"$ai_data + exit 1 +fi getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "公司的考勤管理制度适用于哪些人员?" ${resp_conversation_id} "false" # FIXME According to the log, we returned the corresponding document to the big model, but the big model probably returned unknown -#if [[ $ai_data != *"全体正式员工及实习生"* ]]; then -# echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp -# exit 1 -#fi +if [[ $ai_data != *"全体正式员工及实习生"* ]]; then + echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp + exit 1 +fi + +info "8.4.8 chat with document with multi knowledgebase" +fileUploadSummarise "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "./pkg/documentloaders/testdata/arcadia-readme.pdf" +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "what is arcadia?" ${resp_conversation_id} "false" +if [[ $ai_data != *"mythology"* ]] && [[ $ai_data != *"LLMOps"* ]] && [[ $ai_data != *"Kubernetes"* ]]; then + echo "resp should contains 'mythology' or 'LLMOps' or 'Kubernetes', but resp is:"$ai_data + exit 1 +fi +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "公司的考勤管理制度适用于哪些人员?" 
${resp_conversation_id} "false" +# FIXME According to the log, we returned the corresponding document to the big model, but the big model probably returned unknown +if [[ $ai_data != *"全体正式员工及实习生"* ]]; then + echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp + exit 1 +fi # There is uncertainty in the AI replies, most of the time, it will pass the test, a small percentage of the time, the AI will call names in each reply, causing the test to fail, therefore, temporarily disable the following tests #getRespInAppChat "base-chat-with-bot" "arcadia" "What is your model?" ${resp_conversation_id} "false" @@ -871,3 +631,4 @@ kubectl logs --tail=100 -n arcadia -l app=arcadia-apiserver >/tmp/apiserver.log cat /tmp/apiserver.log info "all finished! ✅" +exit 1