diff --git a/PROJECT b/PROJECT index 5b5224bac..cddf8a2ef 100644 --- a/PROJECT +++ b/PROJECT @@ -151,6 +151,22 @@ resources: kind: RerankRetriever path: github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1 version: v1alpha1 +- api: + crdVersion: v1 + controller: true + domain: arcadia.kubeagi.k8s.com.cn + group: retriever + kind: MergerRetriever + path: github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1 + version: v1alpha1 +- api: + crdVersion: v1 + controller: true + domain: arcadia.kubeagi.k8s.com.cn + group: retriever + kind: MultiQueryRetriever + path: github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1 + version: v1alpha1 - api: crdVersion: v1 controller: true diff --git a/api/app-node/common_type.go b/api/app-node/common_type.go index 18493f19b..185a4fcf1 100644 --- a/api/app-node/common_type.go +++ b/api/app-node/common_type.go @@ -25,8 +25,15 @@ import ( const ( InputLengthAnnotationKey = v1alpha1.Group + `/input-rules` OutputLengthAnnotationKey = v1alpha1.Group + `/output-rules` + + // ConversationKnowledgebaseName is the placeholder name of the conversation knowledgebase + ConversationKnowledgebaseName = "conversation-knowledgebase-placeholder" ) +func IsPlaceholderConversationKnowledgebase(name string) bool { + return name == ConversationKnowledgebaseName +} + type Ref struct { Kind string `json:"kind,omitempty"` Group string `json:"group,omitempty"` diff --git a/api/app-node/retriever/v1alpha1/mergerretriever_types.go b/api/app-node/retriever/v1alpha1/mergerretriever_types.go new file mode 100644 index 000000000..cb8991ea9 --- /dev/null +++ b/api/app-node/retriever/v1alpha1/mergerretriever_types.go @@ -0,0 +1,76 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + node "github.com/kubeagi/arcadia/api/app-node" + "github.com/kubeagi/arcadia/api/base/v1alpha1" +) + +// MergerRetrieverSpec defines the desired state of MergerRetriever +type MergerRetrieverSpec struct { + v1alpha1.CommonSpec `json:",inline"` +} + +// MergerRetrieverStatus defines the observed state of MergerRetriever +type MergerRetrieverStatus struct { + // ObservedGeneration is the last observed generation. + // +optional + ObservedGeneration int64 `json:"observedGeneration,omitempty"` + + // ConditionedStatus is the current status + v1alpha1.ConditionedStatus `json:",inline"` +} + +//+kubebuilder:object:root=true +//+kubebuilder:subresource:status + +// MergerRetriever is the Schema for the MergerRetriever API +type MergerRetriever struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec MergerRetrieverSpec `json:"spec,omitempty"` + Status MergerRetrieverStatus `json:"status,omitempty"` +} + +//+kubebuilder:object:root=true + +// MergerRetrieverList contains a list of MergerRetriever +type MergerRetrieverList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []MergerRetriever `json:"items"` +} + +func init() { + SchemeBuilder.Register(&MergerRetriever{}, &MergerRetrieverList{}) +} + +var _ node.Node = (*MergerRetriever)(nil) + +func (c *MergerRetriever) SetRef() { + annotations := node.SetRefAnnotations(c.GetAnnotations(), []node.Ref{node.RetrieverRef.Len(1)}, []node.Ref{node.RetrievalQAChainRef.Len(1)}) + if c.GetAnnotations() == nil { + c.SetAnnotations(annotations) + } + for k, v := range annotations { + c.Annotations[k] = v + } +} diff --git a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go index 648d0d494..be0faddb5 100644 --- a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go @@ -138,6 +138,97 @@ func (in *KnowledgeBaseRetrieverStatus) DeepCopy() *KnowledgeBaseRetrieverStatus return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetriever) DeepCopyInto(out *MergerRetriever) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + out.Spec = in.Spec + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetriever. +func (in *MergerRetriever) DeepCopy() *MergerRetriever { + if in == nil { + return nil + } + out := new(MergerRetriever) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetriever) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverList) DeepCopyInto(out *MergerRetrieverList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]MergerRetriever, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverList. +func (in *MergerRetrieverList) DeepCopy() *MergerRetrieverList { + if in == nil { + return nil + } + out := new(MergerRetrieverList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetrieverList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverSpec) DeepCopyInto(out *MergerRetrieverSpec) { + *out = *in + out.CommonSpec = in.CommonSpec +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverSpec. +func (in *MergerRetrieverSpec) DeepCopy() *MergerRetrieverSpec { + if in == nil { + return nil + } + out := new(MergerRetrieverSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverStatus) DeepCopyInto(out *MergerRetrieverStatus) { + *out = *in + in.ConditionedStatus.DeepCopyInto(&out.ConditionedStatus) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverStatus. +func (in *MergerRetrieverStatus) DeepCopy() *MergerRetrieverStatus { + if in == nil { + return nil + } + out := new(MergerRetrieverStatus) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MultiQueryRetriever) DeepCopyInto(out *MultiQueryRetriever) { *out = *in diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 656022a90..d56d00231 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -87,6 +87,7 @@ type ComplexityRoot struct { EnableRerank func(childComplexity int) int EnableUploadFile func(childComplexity int) int Knowledgebase func(childComplexity int) int + Knowledgebases func(childComplexity int) int Llm func(childComplexity int) int MaxLength func(childComplexity int) int MaxTokens func(childComplexity int) int @@ -1097,6 +1098,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Application.Knowledgebase(childComplexity), true + case "Application.knowledgebases": + if e.complexity.Application.Knowledgebases == nil { + break + } + + return e.complexity.Application.Knowledgebases(childComplexity), true + case "Application.llm": if e.complexity.Application.Llm == nil { break @@ -5180,10 +5188,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -5504,7 +5517,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 @@ -10008,6 +10026,47 @@ func (ec *executionContext) fieldContext_Application_conversionWindowSize(ctx co return fc, nil } +func (ec *executionContext) _Application_knowledgebases(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Application_knowledgebases(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Knowledgebases, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.([]*string) + fc.Result = res + return ec.marshalOString2ᚕᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Application_knowledgebases(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Application", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Application_knowledgebase(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Application_knowledgebase(ctx, field) if err != nil { @@ -11699,6 +11758,8 @@ func (ec *executionContext) fieldContext_ApplicationMutation_updateApplicationCo return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -11808,6 +11869,8 @@ func (ec *executionContext) fieldContext_ApplicationQuery_getApplication(ctx con return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -28890,6 +28953,8 @@ func (ec *executionContext) fieldContext_RAG_application(ctx context.Context, fi return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -39224,7 +39289,7 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} + fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "knowledgebases", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -39301,6 +39366,13 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte return it, err } it.Knowledgebase = data + case "knowledgebases": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("knowledgebases")) + data, err := ec.unmarshalOString2ᚕᚖstring(ctx, v) + if err != nil { + return it, err + } + it.Knowledgebases = data case "scoreThreshold": ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scoreThreshold")) data, err := ec.unmarshalOFloat2ᚖfloat64(ctx, v) @@ -40622,6 +40694,8 @@ func (ec *executionContext) _Application(ctx context.Context, sel ast.SelectionS out.Values[i] = ec._Application_maxTokens(ctx, field, obj) case "conversionWindowSize": out.Values[i] = ec._Application_conversionWindowSize(ctx, field, obj) + case "knowledgebases": + out.Values[i] = ec._Application_knowledgebases(ctx, field, obj) case "knowledgebase": out.Values[i] = ec._Application_knowledgebase(ctx, field, obj) case "scoreThreshold": diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 86765d4ff..e90b6c531 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -54,6 +54,8 @@ type Application struct { MaxTokens *int `json:"maxTokens,omitempty"` // conversionWindowSize 对话轮次 ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` // scoreThreshold 最终返回结果的最低相似度 @@ -1709,6 +1711,8 @@ type UpdateApplicationConfigInput struct { ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // scoreThreshold 最终返回结果的最低相似度 ScoreThreshold *float64 `json:"scoreThreshold,omitempty"` // numDocuments 最终返回结果的引用上限 diff --git a/apiserver/graph/schema/application.gql b/apiserver/graph/schema/application.gql index 82804af47..ee68e3d33 100644 --- a/apiserver/graph/schema/application.gql +++ b/apiserver/graph/schema/application.gql @@ -78,6 +78,7 @@ mutation updateApplicationConfig($input: UpdateApplicationConfigInput!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn @@ -131,6 +132,7 @@ query getApplication($name: String!, $namespace: String!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn diff --git a/apiserver/graph/schema/application.graphqls b/apiserver/graph/schema/application.graphqls index f6c04afe7..4637a3e31 100644 --- a/apiserver/graph/schema/application.graphqls +++ b/apiserver/graph/schema/application.graphqls @@ -59,10 +59,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -383,7 +388,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 diff --git a/apiserver/pkg/application/application.go b/apiserver/pkg/application/application.go index 4fe3bca10..10f49049b 100644 --- a/apiserver/pkg/application/application.go +++ b/apiserver/pkg/application/application.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + appnode "github.com/kubeagi/arcadia/api/app-node" apiagent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" apichain "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" apidocumentloader "github.com/kubeagi/arcadia/api/app-node/documentloader/v1alpha1" @@ -65,7 +66,7 @@ func addDefaultValue(gApp *generated.Application, app *v1alpha1.Application) { if len(app.Spec.Nodes) > 0 { return } - gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") + // gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") gApp.NumDocuments = pointer.Int(5) gApp.ScoreThreshold = pointer.Float64(0.3) gApp.Temperature = pointer.Float64(0.7) @@ -136,9 +137,10 @@ func cr2app(prompt *apiprompt.Prompt, chainConfig *apichain.CommonChainConfig, r case "llm": gApp.Llm = node.Ref.Name case "knowledgebase": - gApp.Knowledgebase = pointer.String(node.Ref.Name) + gApp.Knowledgebases = append(gApp.Knowledgebases, pointer.String(node.Ref.Name)) } } + gApp.Knowledgebase = ConvertKnowledgebases2Knowledgebase(gApp.Knowledgebases) // nolint:staticcheck if retriever != nil { gApp.ScoreThreshold = pointer.Float64(float64(pointer.Float32Deref(retriever.ScoreThreshold, 0.0))) gApp.NumDocuments = pointer.Int(retriever.NumDocuments) @@ -429,6 +431,9 @@ func ListApplicationMeatadatas(ctx context.Context, c client.Client, input gener } func UpdateApplicationConfig(ctx context.Context, c client.Client, input generated.UpdateApplicationConfigInput) (*generated.Application, error) { + input.Knowledgebases = ConvertKnowledgebase2Knowledgebases(input.Knowledgebase, input.Knowledgebases) // nolint:staticcheck + hasKnowledgebaseOrEnableUpload := utils.HasValues(input.Knowledgebases) || pointer.BoolDeref(input.EnableUploadFile, false) // input has knowledgebases or enable upload file(may have conversation knowledgebase) + // check tool name not duplicated if len(input.Tools) != 0 { key := make(map[string]bool, len(input.Tools)) for _, tool := range input.Tools { @@ -438,6 +443,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat key[tool.Name] = true } } + key := types.NamespacedName{Namespace: input.Namespace, Name: input.Name} // get application cr, if not exist, return error app := &v1alpha1.Application{} @@ -510,12 +516,53 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } } + // create or update placeholder conversation knowledgebase, if enable upload file + var ( + conversationKnowledgebaseRetriever *apiretriever.KnowledgeBaseRetriever + ) + // note: no need to create placeholder conversation knowledgebase cr, it will be replaced by true conversation knowledgebase in appruntime + if pointer.BoolDeref(input.EnableUploadFile, false) { + conversationKnowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name + "-" + appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + Spec: apiretriever.KnowledgeBaseRetrieverSpec{ + CommonSpec: v1alpha1.CommonSpec{ + DisplayName: "retriever", + Description: "retriever", + }, + CommonRetrieverConfig: apiretriever.CommonRetrieverConfig{ + ScoreThreshold: pointer.Float32(float32(pointer.Float64Deref(input.ScoreThreshold, apiretriever.DefaultScoreThreshold))), + NumDocuments: pointer.IntDeref(input.NumDocuments, apiretriever.DefaultNumDocuments), + }, + }, + } + if _, err = controllerutil.CreateOrUpdate(ctx, c, conversationKnowledgebaseRetriever, func() error { + if input.ScoreThreshold != nil { + conversationKnowledgebaseRetriever.Spec.ScoreThreshold = pointer.Float32(float32(*input.ScoreThreshold)) + } + conversationKnowledgebaseRetriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, conversationKnowledgebaseRetriever.Spec.NumDocuments) + return nil + }); err != nil { + return nil, err + } + } else { + conversationKnowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name + "-" + appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, conversationKnowledgebaseRetriever) + } + // create or update chain var ( chainConfig *apichain.CommonChainConfig retriever *apiretriever.CommonRetrieverConfig ) - if utils.HasValue(input.Knowledgebase) { + if hasKnowledgebaseOrEnableUpload { qachain := &apichain.RetrievalQAChain{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -547,6 +594,13 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + llmchain := &apichain.LLMChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, llmchain) chainConfig = &qachain.Spec.CommonChainConfig } else { llmchain := &apichain.LLMChain{ @@ -580,43 +634,88 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + qachain := &apichain.RetrievalQAChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, qachain) chainConfig = &llmchain.Spec.CommonChainConfig } // create or update retrievers - // knowledgebaseRetriever (must have) -> multiQueryRetriever (optional) -> rerankRetriever (optional) -> Output - hasKnowledgebaseRetriever := utils.HasValue(input.Knowledgebase) + // (must have) (must have) (optional) (optional) + // knowledgebase1 -> knowledgebaseRetriever1 + // knowledgebase2 -> knowledgebaseRetriever2 --> mergerRetriever --> multiQueryRetriever --> rerankRetriever --> Output + // conversation knowledgebase (placeholder or has true data) -> knowledgebaseRetriever3 + hasKnowledgebaseRetriever := hasKnowledgebaseOrEnableUpload hasMultiQueryRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableMultiQuery, false) hasRerankRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableRerank, false) rerankModel := "" var knowledgebaseRetriever *apiretriever.KnowledgeBaseRetriever if hasKnowledgebaseRetriever { - knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + for i := range input.Knowledgebases { + indexStr := fmt.Sprintf("-%d", i) + if i == 0 { + // Compatible with previous versions when there is only one knowledgebase + indexStr = "" + } + knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name + indexStr, + Namespace: input.Namespace, + Labels: map[string]string{ + "application": input.Name, + }, + }, + Spec: apiretriever.KnowledgeBaseRetrieverSpec{ + CommonSpec: v1alpha1.CommonSpec{ + DisplayName: "retriever", + Description: "retriever", + }, + CommonRetrieverConfig: apiretriever.CommonRetrieverConfig{ + ScoreThreshold: pointer.Float32(float32(pointer.Float64Deref(input.ScoreThreshold, apiretriever.DefaultScoreThreshold))), + NumDocuments: pointer.IntDeref(input.NumDocuments, apiretriever.DefaultNumDocuments), + }, + }, + } + if _, err = controllerutil.CreateOrUpdate(ctx, c, knowledgebaseRetriever, func() error { + if input.ScoreThreshold != nil { + knowledgebaseRetriever.Spec.ScoreThreshold = pointer.Float32(float32(*input.ScoreThreshold)) + } + knowledgebaseRetriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, knowledgebaseRetriever.Spec.NumDocuments) + return nil + }); err != nil { + return nil, err + } + retriever = &knowledgebaseRetriever.Spec.CommonRetrieverConfig + } + } else { + knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{} + _ = c.DeleteAllOf(ctx, knowledgebaseRetriever, client.InNamespace(input.Namespace), client.MatchingLabels{"application": input.Name}) + } + + if hasKnowledgebaseRetriever { + mergerRetriever := &apiretriever.MergerRetriever{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, Namespace: input.Namespace, }, - Spec: apiretriever.KnowledgeBaseRetrieverSpec{ - CommonSpec: v1alpha1.CommonSpec{ - DisplayName: "retriever", - Description: "retriever", - }, - CommonRetrieverConfig: apiretriever.CommonRetrieverConfig{ - ScoreThreshold: pointer.Float32(float32(pointer.Float64Deref(input.ScoreThreshold, apiretriever.DefaultScoreThreshold))), - NumDocuments: pointer.IntDeref(input.NumDocuments, apiretriever.DefaultNumDocuments), - }, - }, } - if _, err = controllerutil.CreateOrUpdate(ctx, c, knowledgebaseRetriever, func() error { - if input.ScoreThreshold != nil { - knowledgebaseRetriever.Spec.ScoreThreshold = pointer.Float32(float32(*input.ScoreThreshold)) - } - knowledgebaseRetriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, knowledgebaseRetriever.Spec.NumDocuments) + if _, err = controllerutil.CreateOrUpdate(ctx, c, mergerRetriever, func() error { return nil }); err != nil { return nil, err } - retriever = &knowledgebaseRetriever.Spec.CommonRetrieverConfig + } else { + mergerRetriever := &apiretriever.MergerRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, mergerRetriever) } if hasMultiQueryRetriever { @@ -653,7 +752,16 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat return nil, err } retriever = &multiQueryRetriever.Spec.CommonRetrieverConfig + } else { + multiQueryRetriever := &apiretriever.MultiQueryRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, multiQueryRetriever) } + if hasRerankRetriever { rerankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ @@ -695,17 +803,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat if rerankRetriever.Spec.Model != nil { rerankModel = rerankRetriever.Spec.Model.Name } - } - if !hasMultiQueryRetriever { - multiQueryRetriever := &apiretriever.MultiQueryRetriever{ - ObjectMeta: metav1.ObjectMeta{ - Name: input.Name, - Namespace: input.Namespace, - }, - } - _ = c.Delete(ctx, multiQueryRetriever) - } - if !hasRerankRetriever { + } else { reRankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -714,6 +812,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } _ = c.Delete(ctx, reRankRetriever) } + // create or update agent for tools var agent *apiagent.Agent if len(input.Tools) != 0 { @@ -742,6 +841,14 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + } else { + agent = &apiagent.Agent{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, agent) } // update application @@ -765,7 +872,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfigInput, hasMultiQueryRetriever, hasRerankRetriever bool) error { - app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) + app.Spec.Nodes = redefineNodes(input.Knowledgebases, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) app.Spec.Prologue = pointer.StringDeref(input.Prologue, app.Spec.Prologue) app.Spec.ShowRespInfo = pointer.BoolDeref(input.ShowRespInfo, app.Spec.ShowRespInfo) app.Spec.ShowRetrievalInfo = pointer.BoolDeref(input.ShowRetrievalInfo, app.Spec.ShowRetrievalInfo) @@ -779,7 +886,7 @@ func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfi } // redefineNodes redefine nodes in application -func redefineNodes(knowledgebase *string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) { +func redefineNodes(knowledgebases []*string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) { nodes = []v1alpha1.Node{ { NodeConfig: v1alpha1.NodeConfig{ @@ -808,41 +915,25 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName }, NextNodeName: []string{"chain-node"}, }, - } - if pointer.BoolDeref(enableUploadFile, false) { - nodes = append(nodes, v1alpha1.Node{ + { NodeConfig: v1alpha1.NodeConfig{ - Name: "documentloader-node", - DisplayName: "documentloader", - Description: "文档加载,可选", + Name: "llm-node", + DisplayName: "llm", + Description: "设定大模型的访问信息", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "DocumentLoader", - Name: name, + Kind: "LLM", + Name: llmName, Namespace: &namespace, }, }, NextNodeName: []string{"chain-node"}, - }) - } - nodes = append(nodes, v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "llm-node", - DisplayName: "llm", - Description: "设定大模型的访问信息", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "LLM", - Name: llmName, - Namespace: &namespace, - }, }, - NextNodeName: []string{"chain-node"}, - }) + } if len(tools) != 0 { nodes[len(nodes)-1].NextNodeName = []string{"chain-node", "agent-node"} } - if knowledgebase == nil { + if !utils.HasValues(knowledgebases) && !pointer.BoolDeref(enableUploadFile, false) { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "chain-node", @@ -858,50 +949,102 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } else { - nodes = append(nodes, - v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "knowledgebase-node", - DisplayName: "知识库", - Description: "连接知识库", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBase", - Name: pointer.StringDeref(knowledgebase, ""), - Namespace: &namespace, + for i, knowledgebase := range knowledgebases { + indexStr := fmt.Sprintf("-%d", i) + if i == 0 { + // Compatible with previous versions when there is only one knowledgebase + indexStr = "" + } + nodes = append(nodes, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: fmt.Sprintf("knowledgebase%s-node", indexStr), + DisplayName: "知识库", + Description: "连接知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBase", + Name: pointer.StringDeref(knowledgebase, ""), + Namespace: &namespace, + }, }, + NextNodeName: []string{fmt.Sprintf("retriever%s-node", indexStr)}, }, - NextNodeName: []string{"retriever-node"}, - }, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: fmt.Sprintf("retriever%s-node", indexStr), + DisplayName: "从知识库提取信息的retriever", + Description: "连接应用和知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBaseRetriever", + Name: fmt.Sprintf("%s%s", name, indexStr), + Namespace: &namespace, + }, + }, + NextNodeName: []string{"mergerretriever-node"}, + }) + } + if pointer.BoolDeref(enableUploadFile, false) { + nodes = append(nodes, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "conversation-knowledgebase-node", + DisplayName: "会话知识库", + Description: "连接会话知识库,包含会话中上传的文件", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBase", + Name: appnode.ConversationKnowledgebaseName, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"conversation-knowledgebase-retriever-node"}, + }, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "conversation-knowledgebase-retriever-node", + DisplayName: "从会话知识库提取信息的retriever", + Description: "连接应用和知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBaseRetriever", + Name: name + "-" + appnode.ConversationKnowledgebaseName, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"mergerretriever-node"}, + }) + } + mergerRetrierNextNodeName := "chain-node" + multiqueryRetrieverNodeName := "chain-node" + switch { + case hasMultiQueryRetriever: + mergerRetrierNextNodeName = "multiqueryretriever-node" + if hasRerankRetriever { + multiqueryRetrieverNodeName = "rerankretriever-node" + } + case !hasMultiQueryRetriever && hasRerankRetriever: + mergerRetrierNextNodeName = "rerankretriever-node" + case !hasMultiQueryRetriever && !hasRerankRetriever: + mergerRetrierNextNodeName = "chain-node" + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ - Name: "retriever-node", - DisplayName: "从知识库提取信息的retriever", - Description: "连接应用和知识库", + Name: "mergerretriever-node", + DisplayName: "知识库合并retriever", + Description: "知识库合并retriever", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBaseRetriever", + Kind: "MergerRetriever", Name: name, Namespace: &namespace, }, }, - NextNodeName: []string{"chain-node"}, + NextNodeName: []string{mergerRetrierNextNodeName}, }) - knowledgebaseRetrierNextNodeName := "chain-node" - switch { - case hasMultiQueryRetriever: - knowledgebaseRetrierNextNodeName = "multiqueryretriever-node" - case !hasMultiQueryRetriever && hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "rerankretriever-node" - case !hasMultiQueryRetriever && !hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "chain-node" - } - nodes[len(nodes)-1].NextNodeName = []string{knowledgebaseRetrierNextNodeName} if hasMultiQueryRetriever { - nextNodeName := "chain-node" - if hasRerankRetriever { - nextNodeName = "rerankretriever-node" - } nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ -915,7 +1058,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName Namespace: &namespace, }, }, - NextNodeName: []string{nextNodeName}, + NextNodeName: []string{multiqueryRetrieverNodeName}, }) } if hasRerankRetriever { @@ -951,6 +1094,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } + if len(tools) != 0 { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ -967,6 +1111,24 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"chain-node"}, }) } + + if pointer.BoolDeref(enableUploadFile, false) { + nodes = append(nodes, v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "documentloader-node", + DisplayName: "documentloader", + Description: "文档加载,可选", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "DocumentLoader", + Name: name, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"chain-node"}, + }) + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "Output", @@ -982,6 +1144,25 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName return nodes } +// Deprecated: just for backward compatibility, should remove after new version +func ConvertKnowledgebase2Knowledgebases(knowledgebase *string, knowledgebases []*string) []*string { + if knowledgebases != nil { + return knowledgebases + } + if knowledgebase == nil { + return nil + } + return []*string{knowledgebase} +} + +// Deprecated: just for backward compatibility, should remove after new version +func ConvertKnowledgebases2Knowledgebase(knowledgebases []*string) *string { + if len(knowledgebases) == 0 { + return nil + } + return knowledgebases[0] +} + func UploadIcon(ctx context.Context, client client.Client, icon, appName, namespace string) (string, error) { if strings.HasPrefix(icon, "data:image") { imgBytes, err := pkgutils.ParseBase64ImageBytes(icon) diff --git a/apiserver/pkg/chat/chat_server.go b/apiserver/pkg/chat/chat_server.go index 4cecb65f5..66fee6547 100644 --- a/apiserver/pkg/chat/chat_server.go +++ b/apiserver/pkg/chat/chat_server.go @@ -29,7 +29,6 @@ import ( langchainllms "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/memory" "github.com/tmc/langchaingo/prompts" - langchainschema "github.com/tmc/langchaingo/schema" "golang.org/x/sync/errgroup" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -302,32 +301,29 @@ func (cs *ChatServer) ListPromptStarters(ctx context.Context, req APPMetadata, l if finish != nil { defer finish() } - v, ok := outArg[base.LangchaingoRetrieverKeyInArg] - if ok { - r, ok := v.(langchainschema.Retriever) - if ok { - doc, err := r.GetRelevantDocuments(ctx, "") - if err != nil { - return nil, err - } - for _, d := range doc { - hasAnswer := false - // has answer, means qa.csv, just return the question - v, ok := d.Metadata[documentloaders.AnswerCol] - if ok { - answer, ok := v.(string) - if ok && answer != "" { - question := strings.TrimSuffix(d.PageContent, "\na: "+answer) - promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) - hasAnswer = true - if len(promptStarters) == limit { - break - } + retrievers, err := base.GetRetrieversFromArg(outArg) + if err == nil && len(retrievers) > 0 { + doc, err := retrievers[0].GetRelevantDocuments(ctx, "") + if err != nil { + return nil, err + } + for _, d := range doc { + hasAnswer := false + // has answer, means qa.csv, just return the question + v, ok := d.Metadata[documentloaders.AnswerCol] + if ok { + answer, ok := v.(string) + if ok && answer != "" { + question := strings.TrimSuffix(d.PageContent, "\na: "+answer) + promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) + hasAnswer = true + if len(promptStarters) == limit { + break } } - if !hasAnswer { - content.WriteString(d.PageContent + "\n") - } + } + if !hasAnswer { + content.WriteString(d.PageContent + "\n") } } } diff --git a/apiserver/pkg/utils/structured.go b/apiserver/pkg/utils/structured.go index 20d1916dc..f72c319e0 100644 --- a/apiserver/pkg/utils/structured.go +++ b/apiserver/pkg/utils/structured.go @@ -38,3 +38,7 @@ func MapStr2Any(input map[string]string) map[string]any { func HasValue(s *string) bool { return s != nil && strings.TrimSpace(*s) != "" } + +func HasValues(s []*string) bool { + return len(s) > 0 && strings.TrimSpace(*s[0]) != "" +} diff --git a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index 7d0493d62..16b90d5b7 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -653,6 +653,32 @@ rules: - get - patch - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerRetrievers/finalizers + verbs: + - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers/status + verbs: + - get + - patch + - update - apiGroups: - retriever.arcadia.kubeagi.k8s.com.cn resources: diff --git a/config/samples/app_llmchain_abstract.yaml b/config/samples/app_llmchain_abstract.yaml index 84df0e4f1..825701ae0 100644 --- a/config/samples/app_llmchain_abstract.yaml +++ b/config/samples/app_llmchain_abstract.yaml @@ -40,11 +40,11 @@ spec: name: app-shared-llm-service nextNodeName: ["chain-node"] - name: chain-node - displayName: "llm chain" + displayName: "chain" description: "chain是langchain的核心概念,llmChain用于连接prompt和llm" ref: apiGroup: chain.arcadia.kubeagi.k8s.com.cn - kind: LLMChain + kind: RetrievalQAChain name: base-chat-document-assistant nextNodeName: ["Output"] - name: Output @@ -53,6 +53,30 @@ spec: ref: kind: Output name: Output + - name: conversation-knowledgebase-node + displayName: "对话知识库" + description: "对话知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: conversation-knowledgebase-placeholder + nextNodeName: ["conversation-knowledgebase-retriever-node"] + - name: conversation-knowledgebase-retriever-node + displayName: "从对话知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + - name: mergerretriever-node + displayName: "整合多个retriever的结果" + description: "整合多个retriever的结果" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MergerRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["chain-node"] --- apiVersion: prompt.arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Prompt @@ -81,7 +105,7 @@ spec: chunkOverlap: 100 --- apiVersion: chain.arcadia.kubeagi.k8s.com.cn/v1alpha1 -kind: LLMChain +kind: RetrievalQAChain metadata: name: base-chat-document-assistant namespace: arcadia diff --git a/config/samples/app_retrievalqachain_knowledgebase.yaml b/config/samples/app_retrievalqachain_knowledgebase.yaml index 21d998b1d..d7e4fafef 100644 --- a/config/samples/app_retrievalqachain_knowledgebase.yaml +++ b/config/samples/app_retrievalqachain_knowledgebase.yaml @@ -76,7 +76,7 @@ spec: description: "设定知识库应用的prompt, 来自 https://github.com/tmc/langchaingo/blob/af36340149bbf35ae51c80357fa80bf648c33512/chains/question_answering.go#L9" userMessage: | Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. - + -------- {{.context}} -------- Question: {{.question}} diff --git a/config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml b/config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml new file mode 100644 index 000000000..9feb4a951 --- /dev/null +++ b/config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml @@ -0,0 +1,96 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Application +metadata: + name: base-chat-with-knowledgebase-pgvector + namespace: arcadia +spec: + displayName: "知识库应用" + description: "最简单的和知识库对话的应用" + prologue: "Welcome to talk to the KnowledgeBase!🤖" + docNullReturn: "未找到您询问的内容,请详细描述您的问题,以便我们为您提供更好的服务" + nodes: + - name: Input + displayName: "用户输入" + description: "用户输入节点,必须" + ref: + kind: Input + name: Input + nextNodeName: ["prompt-node"] + - name: prompt-node + displayName: "prompt" + description: "设定prompt,template中可以使用{{xx}}来替换变量" + ref: + apiGroup: prompt.arcadia.kubeagi.k8s.com.cn + kind: Prompt + name: base-chat-with-knowledgebase + nextNodeName: ["chain-node"] + - name: llm-node + displayName: "zhipu大模型服务" + description: "设定大模型的访问信息" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: LLM + name: app-shared-llm-service + nextNodeName: ["chain-node"] + - name: knowledgebase-node + displayName: "使用的知识库" + description: "要用哪个知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: knowledgebase-sample-pgvector + nextNodeName: ["retriever-node"] + - name: retriever-node + displayName: "从知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-knowledgebase + nextNodeName: ["chain-node"] + - name: chain-node + displayName: "RetrievalQA chain" + description: "chain是langchain的核心概念,RetrievalQAChain用于从 retriever 中提取信息,供llm调用" + ref: + apiGroup: chain.arcadia.kubeagi.k8s.com.cn + kind: RetrievalQAChain + name: base-chat-with-knowledgebase + nextNodeName: ["Output"] + - name: Output + displayName: "最终输出" + description: "最终输出节点,必须" + ref: + kind: Output + name: Output + - name: documentloader-node + displayName: "documentloader" + description: "可选,如果需要对话中上传解析文件,需要添加这个节点" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: DocumentLoader + name: common + nextNodeName: [ "chain-node" ] + - name: conversation-knowledgebase-node + displayName: "对话知识库" + description: "对话知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: conversation-knowledgebase-placeholder + nextNodeName: ["conversation-knowledgebase-retriever-node"] + - name: conversation-knowledgebase-retriever-node + displayName: "从对话知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + - name: mergerretriever-node + displayName: "整合多个retriever的结果" + description: "整合多个retriever的结果" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MergerRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["chain-node"] diff --git a/config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml b/config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml new file mode 100644 index 000000000..328cc138c --- /dev/null +++ b/config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml @@ -0,0 +1,145 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Application +metadata: + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + namespace: arcadia +spec: + displayName: "知识库应用" + description: "最简单的和知识库对话的应用" + prologue: "Welcome to talk to the KnowledgeBase!🤖" + docNullReturn: "未找到您询问的内容,请详细描述您的问题,以便我们为您提供更好的服务" + nodes: + - name: Input + displayName: "用户输入" + description: "用户输入节点,必须" + ref: + kind: Input + name: Input + nextNodeName: ["prompt-node"] + - name: prompt-node + displayName: "prompt" + description: "设定prompt,template中可以使用{{xx}}来替换变量" + ref: + apiGroup: prompt.arcadia.kubeagi.k8s.com.cn + kind: Prompt + name: base-chat-with-knowledgebase + nextNodeName: ["chain-node"] + - name: llm-node + displayName: "大模型服务" + description: "设定大模型的访问信息" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: LLM + name: app-shared-llm-service + nextNodeName: ["chain-node"] + - name: knowledgebase-node + displayName: "使用的知识库" + description: "要用哪个知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: knowledgebase-sample-pgvector + nextNodeName: ["retriever-node"] + - name: retriever-node + displayName: "从知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-knowledgebase-pgvector-1 + nextNodeName: ["mergerretriever-node"] + - name: knowledgebase-1-node + displayName: "使用的知识库1" + description: "要用哪个知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: knowledgebase-sample-pgvector2 + nextNodeName: ["retriever-1-node"] + - name: retriever-1-node + displayName: "从知识库1提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + - name: mergerretriever-node + displayName: "整合多个retriever的结果" + description: "整合多个retriever的结果" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MergerRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["multiquery-node"] + - name: multiquery-node + displayName: "让LLM多角度提问的Retriever" + description: "让LLM多角度提问的Retriever" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: MultiQueryRetriever + name: base-chat-with-knowledgebase-pgvector-2 + nextNodeName: ["rerank-retriever-node"] + - name: rerank-retriever-node + displayName: "rerank retriever" + description: "重排" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: RerankRetriever + name: base-chat-with-knowledgebase-pgvector-rerank + nextNodeName: ["chain-node"] + - name: chain-node + displayName: "RetrievalQA chain" + description: "chain是langchain的核心概念,RetrievalQAChain用于从 retriever 中提取信息,供llm调用" + ref: + apiGroup: chain.arcadia.kubeagi.k8s.com.cn + kind: RetrievalQAChain + name: base-chat-with-knowledgebase + nextNodeName: ["Output"] + - name: Output + displayName: "最终输出" + description: "最终输出节点,必须" + ref: + kind: Output + name: Output + - name: conversation-knowledgebase-node + displayName: "对话知识库" + description: "对话知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: conversation-knowledgebase-placeholder + nextNodeName: ["conversation-knowledgebase-retriever-node"] + - name: conversation-knowledgebase-retriever-node + displayName: "从对话知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + nextNodeName: ["mergerretriever-node"] + +--- +apiVersion: retriever.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: KnowledgeBaseRetriever +metadata: + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"KnowledgeBase","group":"arcadia.kubeagi.k8s.com.cn","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"kind":"RetrievalQAChain","group":"chain.arcadia.kubeagi.k8s.com.cn","length":1}]' +spec: + displayName: "从知识库获取信息的Retriever" + scoreThreshold: 0.3 + numDocuments: 50 +--- +apiVersion: retriever.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: MergerRetriever +metadata: + name: base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"KnowledgeBase","group":"arcadia.kubeagi.k8s.com.cn","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"kind":"RetrievalQAChain","group":"chain.arcadia.kubeagi.k8s.com.cn","length":1}]' +spec: + displayName: "整合多个Retriever" diff --git a/config/samples/app_shared_llm_service_zhipu.yaml b/config/samples/app_shared_llm_service_zhipu.yaml index 9762e60a7..e44aa08ff 100644 --- a/config/samples/app_shared_llm_service_zhipu.yaml +++ b/config/samples/app_shared_llm_service_zhipu.yaml @@ -5,7 +5,7 @@ metadata: namespace: arcadia type: Opaque data: - apiKey: "MTZlZDcxYzcwMDE0NGFiMjIyMmI5YmEwZDFhMTBhZTUuUTljWVZtWWxmdjlnZGtDeQ==" # replace this with your API key + apiKey: "YTgyNTlhNjFmN2EwZGYzNmQ5N2Q3ZDIwOGVlMTQ0NTUuODc5OGJyeldwaGUzWUlCOA==" # replace this with your API key --- apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: LLM diff --git a/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml b/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml index 5e41e3f3f..2f67fce5e 100644 --- a/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml +++ b/config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml @@ -22,3 +22,27 @@ spec: files: - path: qa.csv - path: chunk.csv +--- +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: KnowledgeBase +metadata: + name: knowledgebase-sample-pgvector2 + namespace: arcadia +spec: + displayName: "测试 KnowledgeBase 2" + description: "测试 KnowledgeBase 2" + embedder: + kind: Embedders + name: embedders-sample + namespace: arcadia + vectorStore: + kind: VectorStores + name: pgvector-sample + namespace: arcadia + fileGroups: + - source: + kind: VersionedDataset + name: dataset-playground-v1 + namespace: arcadia + files: + - path: CODE_OF_CONDUCT.md diff --git a/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml b/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml new file mode 100644 index 000000000..221beee1b --- /dev/null +++ b/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml @@ -0,0 +1,16 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Model +metadata: + name: bce-reranker + namespace: arcadia +spec: + displayName: "bce-reranker-large" + description: | + BCEmbedding是由网易有道开发的中英双语和跨语种语义表征算法模型库,其中包含 EmbeddingModel和 RerankerModel两类基础模型。 + EmbeddingModel专门用于生成语义向量,在语义搜索和问答中起着关键作用,而 RerankerModel擅长优化语义搜索结果和语义相关顺序精排。 + + Github: https://github.com/netease-youdao/BCEmbedding + HuggingFace: https://huggingface.co/maidalun1020/bce-reranker-base_v1 + types: "reranking" + modelScopeRepo: maidalun/bce-reranker-base_v1 + modelSource: modelscope diff --git a/config/samples/arcadia_v1alpha1_versioneddataset.yaml b/config/samples/arcadia_v1alpha1_versioneddataset.yaml index 156a2ef31..d4e26a340 100644 --- a/config/samples/arcadia_v1alpha1_versioneddataset.yaml +++ b/config/samples/arcadia_v1alpha1_versioneddataset.yaml @@ -16,5 +16,6 @@ spec: files: - path: qa.csv - path: chunk.csv + - path: CODE_OF_CONDUCT.md released: 0 version: v1 diff --git a/controllers/app-node/retriever/merger_retriever_controller.go b/controllers/app-node/retriever/merger_retriever_controller.go new file mode 100644 index 000000000..6e1221e4c --- /dev/null +++ b/controllers/app-node/retriever/merger_retriever_controller.go @@ -0,0 +1,150 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package chain + +import ( + "context" + "reflect" + + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + + api "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" + arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" + appnode "github.com/kubeagi/arcadia/controllers/app-node" +) + +// MergerRetrieverReconciler reconciles a MergerRetriever object +type MergerRetrieverReconciler struct { + client.Client + Scheme *runtime.Scheme +} + +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=mergerretrievers,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=mergerretrievers/status,verbs=get;update;patch +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=mergerRetrievers/finalizers,verbs=update + +// Reconcile is part of the main kubernetes reconciliation loop which aims to +// move the current state of the cluster closer to the desired state. +// For more details, check Reconcile and its Result here: +// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.12.2/pkg/reconcile +func (r *MergerRetrieverReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := ctrl.LoggerFrom(ctx) + log.V(5).Info("Start MergerRetriever Reconcile") + instance := &api.MergerRetriever{} + if err := r.Get(ctx, req.NamespacedName, instance); err != nil { + // There's no need to requeue if the resource no longer exists. + // Otherwise, we'll be requeued implicitly because we return an error. + log.V(1).Info("Failed to get MergerRetriever") + return ctrl.Result{}, client.IgnoreNotFound(err) + } + log = log.WithValues("Generation", instance.GetGeneration(), "ObservedGeneration", instance.Status.ObservedGeneration, "creator", instance.Spec.Creator) + log.V(5).Info("Get MergerRetriever instance") + + // Add a finalizer.Then, we can define some operations which should + // occur before the MergerRetriever to be deleted. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/finalizers + if newAdded := controllerutil.AddFinalizer(instance, arcadiav1alpha1.Finalizer); newAdded { + log.Info("Try to add Finalizer for MergerRetriever") + if err := r.Update(ctx, instance); err != nil { + log.Error(err, "Failed to update MergerRetriever to add finalizer, will try again later") + return ctrl.Result{}, err + } + log.Info("Adding Finalizer for MergerRetriever done") + return ctrl.Result{}, nil + } + + // Check if the MergerRetriever instance is marked to be deleted, which is + // indicated by the deletion timestamp being set. + if instance.GetDeletionTimestamp() != nil && controllerutil.ContainsFinalizer(instance, arcadiav1alpha1.Finalizer) { + log.Info("Performing Finalizer Operations for MergerRetriever before delete CR") + // TODO perform the finalizer operations here, for example: remove vectorstore data? + log.Info("Removing Finalizer for MergerRetriever after successfully performing the operations") + controllerutil.RemoveFinalizer(instance, arcadiav1alpha1.Finalizer) + if err := r.Update(ctx, instance); err != nil { + log.Error(err, "Failed to remove the finalizer for MergerRetriever") + return ctrl.Result{}, err + } + log.Info("Remove MergerRetriever done") + return ctrl.Result{}, nil + } + + instance, result, err := r.reconcile(ctx, log, instance) + + // Update status after reconciliation. + if updateStatusErr := r.patchStatus(ctx, instance); updateStatusErr != nil { + log.Error(updateStatusErr, "unable to update status after reconciliation") + return ctrl.Result{Requeue: true}, updateStatusErr + } + + return result, err +} + +func (r *MergerRetrieverReconciler) reconcile(ctx context.Context, log logr.Logger, instance *api.MergerRetriever) (*api.MergerRetriever, ctrl.Result, error) { + // Observe generation change + if instance.Status.ObservedGeneration != instance.Generation { + instance.Status.ObservedGeneration = instance.Generation + r.setCondition(instance, instance.Status.WaitingCompleteCondition()...) + if updateStatusErr := r.patchStatus(ctx, instance); updateStatusErr != nil { + log.Error(updateStatusErr, "unable to update status after generation update") + return instance, ctrl.Result{Requeue: true}, updateStatusErr + } + } + + if instance.Status.IsReady() { + return instance, ctrl.Result{}, nil + } + // Note: should change here + // TODO: we should do more checks later.For example: + // LLM status + // Prompt status + if err := appnode.CheckAndUpdateAnnotation(ctx, log, r.Client, instance); err != nil { + instance.Status.SetConditions(instance.Status.ErrorCondition(err.Error())...) + } else { + instance.Status.SetConditions(instance.Status.ReadyCondition()...) + } + + return instance, ctrl.Result{}, nil +} + +func (r *MergerRetrieverReconciler) patchStatus(ctx context.Context, instance *api.MergerRetriever) error { + latest := &api.MergerRetriever{} + if err := r.Client.Get(ctx, client.ObjectKeyFromObject(instance), latest); err != nil { + return err + } + if reflect.DeepEqual(instance.Status, latest.Status) { + return nil + } + patch := client.MergeFrom(latest.DeepCopy()) + latest.Status = instance.Status + return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("MergerRetriever-controller")) +} + +// SetupWithManager sets up the controller with the Manager. +func (r *MergerRetrieverReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&api.MergerRetriever{}). + Complete(r) +} + +func (r *MergerRetrieverReconciler) setCondition(instance *api.MergerRetriever, condition ...arcadiav1alpha1.Condition) *api.MergerRetriever { + instance.Status.SetConditions(condition...) + return instance +} diff --git a/controllers/base/application_controller.go b/controllers/base/application_controller.go index 6c230184e..797ff589a 100644 --- a/controllers/base/application_controller.go +++ b/controllers/base/application_controller.go @@ -55,6 +55,7 @@ const ( KnowledgebaseRetrieverIndexKey = "metadata.knowledgebaseretriever" RerankRetrieverIndexKey = "metadata.rerankretriever" MultiQueryRetrieverIndexKey = "metadata.multiqueryretriever" + MergerRetrieverIndexKey = "metadata.mergerretriever" AgentIndexKey = "metadata.agent" DocumentLoaderIndexKey = "metadata.documentloader" ) @@ -368,6 +369,7 @@ func (r *ApplicationReconciler) SetupWithManager(ctx context.Context, mgr ctrl.M {KnowledgebaseRetrieverIndexKey, "retriever", "knowledgebaseretriever"}, {RerankRetrieverIndexKey, "retriever", "rerankretriever"}, {MultiQueryRetrieverIndexKey, "retriever", "multiqueryretriever"}, + {MergerRetrieverIndexKey, "retriever", "mergerretriever"}, {AgentIndexKey, "", "agent"}, {DocumentLoaderIndexKey, "", "documentloader"}, } @@ -424,6 +426,7 @@ func (r *ApplicationReconciler) SetupWithManager(ctx context.Context, mgr ctrl.M Watches(&source.Kind{Type: &retrieveralpha1.KnowledgeBaseRetriever{}}, getEventHandler(KnowledgebaseRetrieverIndexKey)). Watches(&source.Kind{Type: &retrieveralpha1.RerankRetriever{}}, getEventHandler(RerankRetrieverIndexKey)). Watches(&source.Kind{Type: &retrieveralpha1.MultiQueryRetriever{}}, getEventHandler(MultiQueryRetrieverIndexKey)). + Watches(&source.Kind{Type: &retrieveralpha1.MergerRetriever{}}, getEventHandler(MergerRetrieverIndexKey)). Watches(&source.Kind{Type: &agentv1alpha1.Agent{}}, getEventHandler(AgentIndexKey)). Watches(&source.Kind{Type: &documentloaderv1alpha1.DocumentLoader{}}, getEventHandler(DocumentLoaderIndexKey)). Complete(r) diff --git a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/deploy/charts/arcadia/templates/rbac.yaml b/deploy/charts/arcadia/templates/rbac.yaml index 1f74678ab..0119f0a1e 100644 --- a/deploy/charts/arcadia/templates/rbac.yaml +++ b/deploy/charts/arcadia/templates/rbac.yaml @@ -671,6 +671,32 @@ rules: - get - patch - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerRetrievers/finalizers + verbs: + - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - mergerretrievers/status + verbs: + - get + - patch + - update - apiGroups: - retriever.arcadia.kubeagi.k8s.com.cn resources: diff --git a/deploy/charts/arcadia/templates/role-templates.yaml b/deploy/charts/arcadia/templates/role-templates.yaml index cbb4f8362..218c7947e 100644 --- a/deploy/charts/arcadia/templates/role-templates.yaml +++ b/deploy/charts/arcadia/templates/role-templates.yaml @@ -171,6 +171,7 @@ spec: - knowledgebaseretrievers - multiqueryretrievers - rerankretrievers + - mergerretrievers verbs: - create - delete @@ -188,6 +189,7 @@ spec: - knowledgebaseretrievers/status - multiqueryretrievers/status - rerankretrievers/status + - mergerretrievers/status verbs: - get - patch diff --git a/go.mod b/go.mod index 8f8600ff6..22c661473 100644 --- a/go.mod +++ b/go.mod @@ -215,4 +215,4 @@ require ( sigs.k8s.io/yaml v1.3.0 // indirect ) -replace github.com/tmc/langchaingo => github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 // branch dev +replace github.com/tmc/langchaingo => github.com/kubeagi/langchaingo v0.0.0-20240416092403-dd907a8798bd // branch dev diff --git a/go.sum b/go.sum index b8718c1b0..5c1b36485 100644 --- a/go.sum +++ b/go.sum @@ -549,8 +549,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 h1:4VbKHgpTrG/EPWIn4n3FMIedvz3z2aYL2E0Z7RIEvik= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= +github.com/kubeagi/langchaingo v0.0.0-20240416092403-dd907a8798bd h1:3+lIBh7HtAqvbeBTpnuy79zvgBl4dCQz5YkhXKeocOc= +github.com/kubeagi/langchaingo v0.0.0-20240416092403-dd907a8798bd/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= diff --git a/main.go b/main.go index 09ac62956..40f1737bb 100644 --- a/main.go +++ b/main.go @@ -278,6 +278,13 @@ func main() { setupLog.Error(err, "unable to create controller", "controller", "MultiQueryRetriever") os.Exit(1) } + if err = (&retrievertrollers.MergerRetrieverReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + }).SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "MergerRetriever") + os.Exit(1) + } if err = (&promptcontrollers.PromptReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), diff --git a/pkg/appruntime/agent/executor.go b/pkg/appruntime/agent/executor.go index a657cfe71..64b1c478a 100644 --- a/pkg/appruntime/agent/executor.go +++ b/pkg/appruntime/agent/executor.go @@ -23,6 +23,7 @@ import ( "github.com/tmc/langchaingo/agents" "github.com/tmc/langchaingo/callbacks" + "github.com/tmc/langchaingo/chains" "github.com/tmc/langchaingo/llms" langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" @@ -79,15 +80,21 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a agents.WithCallbacksHandler(streamHandler)(o) } } - agents.WithMemory(chain.GetMemory(llm, instance.Spec.AgentConfig.Options.Memory, history, "", ""))(o) + agents.WithMemory(chain.GetMemory(llm, instance.Spec.AgentConfig.Options.Memory, history, "input", ""))(o) } - executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, executorOptions) + executor, err := agents.Initialize(llm, allowedTools, agents.ConversationalReactDescription, executorOptions) if err != nil { return args, fmt.Errorf("failed to initialize executor: %w", err) } + executor.CallbacksHandler = log.KLogHandler{LogLevel: 3} input := make(map[string]any) - input["input"] = fmt.Sprintf("%s, %s", instance.Spec.Prompt, args["question"]) - response, err := executor.Call(ctx, input) + if instance.Spec.Prompt != "" { + input["input"] = fmt.Sprintf("%s, %s", instance.Spec.Prompt, args["question"]) + } else { + input["input"] = args["question"] + } + // chains.Call will add history to args + response, err := chains.Call(ctx, executor, input) if err != nil { klog.FromContext(ctx).Error(err, "error when call agent") // return args, fmt.Errorf("error when call agent: %w", err) diff --git a/pkg/appruntime/app_runtime.go b/pkg/appruntime/app_runtime.go index 82e3d5595..6f4483e23 100644 --- a/pkg/appruntime/app_runtime.go +++ b/pkg/appruntime/app_runtime.go @@ -25,8 +25,6 @@ import ( "strings" langchaingoschema "github.com/tmc/langchaingo/schema" - apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "k8s.io/utils/strings/slices" "sigs.k8s.io/controller-runtime/pkg/client" @@ -149,32 +147,12 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha base.InputIsNeedStreamKeyInArg: input.NeedStream, base.LangchaingoChatMessageHistoryKeyInArg: input.History, // Use an empty context before run - "context": "", + "context": "", + base.ConversationIDInArg: input.ConversationID, } if a.Spec.DocNullReturn != "" { out[base.APPDocNullReturn] = a.Spec.DocNullReturn } - if input.ConversationID != "" { // means this is not a new conversation - conversationKnowledgebaseExist := true - kb := &arcadiav1alpha1.KnowledgeBase{} - err := cli.Get(ctx, types.NamespacedName{Namespace: a.Namespace, Name: input.ConversationID}, kb) - if err != nil { - if apierrors.IsNotFound(err) { - conversationKnowledgebaseExist = false - // TODO We can search for whether there should be a conversation knowledgebase from the pg - klog.FromContext(ctx).V(5).Info("conversation knowledgebase not exist", "ConversationID", input.ConversationID) - } else { - return output, err - } - } - if conversationKnowledgebaseExist { - if kb.Status.IsReady() { - out[base.ConversationKnowledgeBaseInArg] = kb - } else { - klog.FromContext(ctx).V(3).Info("conversation knowledgebase not ready", "ConversationID", input.ConversationID) - } - } - } visited := make(map[string]bool) waitRunningNodes := list.New() for _, v := range a.StartingNodes { @@ -194,7 +172,6 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha waitRunningNodes.PushBack(e) continue } - klog.FromContext(ctx).V(3).Info(fmt.Sprintf("try to run node:%s", e.Name())) defer func() { if r := recover(); r != nil { klog.FromContext(ctx).Info(fmt.Sprintf("Recovered from node:%s error:%s stack:%s", e.Name(), r, string(debug.Stack()))) @@ -253,7 +230,7 @@ func InitNode(ctx context.Context, appNamespace, name string, ref arcadiav1alpha } }() baseNode := base.NewBaseNode(appNamespace, name, ref) - err = fmt.Errorf("unknown kind %s:%v", name, ref) + err = fmt.Errorf("unknown kind %s:%v, get group:%s kind:%s", name, ref, baseNode.Group(), baseNode.Kind()) switch baseNode.Group() { case "chain": switch baseNode.Kind() { @@ -280,6 +257,9 @@ func InitNode(ctx context.Context, appNamespace, name string, ref arcadiav1alpha case "multiqueryretriever": logger.V(3).Info("initnode multiqueryretriever") return retriever.NewMultiQueryRetriever(baseNode), nil + case "mergerretriever": + logger.V(3).Info("initnode mergerretriever") + return retriever.NewMergerRetriever(baseNode), nil default: return nil, err } diff --git a/pkg/appruntime/base/keyword.go b/pkg/appruntime/base/keyword.go index 2171c84bc..47d9fa171 100644 --- a/pkg/appruntime/base/keyword.go +++ b/pkg/appruntime/base/keyword.go @@ -16,6 +16,12 @@ limitations under the License. package base +import ( + "errors" + + langchainschema "github.com/tmc/langchaingo/schema" +) + const ( InputQuestionKeyInArg = "question" InputIsNeedStreamKeyInArg = "_need_stream" @@ -25,9 +31,63 @@ const ( MapReduceDocumentOutputInArg = "_mapreduce_document_answer" OutputAnswerStreamChanKeyInArg = "_answer_stream" RuntimeRetrieverReferencesKeyInArg = "_references" - LangchaingoRetrieverKeyInArg = "retriever" + LangchaingoRetrieversKeyInArg = "retrievers" LangchaingoLLMKeyInArg = "llm" LangchaingoPromptKeyInArg = "prompt" APPDocNullReturn = "_app_doc_null_return" ConversationKnowledgeBaseInArg = "_conversation_knowledgebase" // the conversation Knowledgebase cr in args, status has ready + ConversationIDInArg = "_conversation_id" +) + +var ( + ErrNoQuestion = errors.New("no question in args") + ErrNoRetrievers = errors.New("no retrievers in args") ) + +func GetInputQuestionFromArg(args map[string]any) (string, error) { + q, ok := args[InputQuestionKeyInArg] + if !ok { + return "", ErrNoQuestion + } + query, ok := q.(string) + if !ok || len(query) == 0 { + return "", errors.New("empty question") + } + return query, nil +} + +func GetRetrieversFromArg(args map[string]any) ([]langchainschema.Retriever, error) { + v, ok := args[LangchaingoRetrieversKeyInArg] + if !ok { + return nil, ErrNoRetrievers + } + retrievers, ok := v.([]langchainschema.Retriever) + if !ok { + return nil, errors.New("retrievers not []schema.Retriever") + } + return retrievers, nil +} + +func GetAPPDocNullReturnFromArg(args map[string]any) (string, error) { + v, ok := args[APPDocNullReturn] + if !ok { + return "", nil + } + docNullReturn, ok := v.(string) + if !ok { + return "", errors.New("app doc null return not string type") + } + return docNullReturn, nil +} + +// AddKnowledgebaseRetrieverToArg add knowledgebase retriever to args +// Note: only knowledgebase retriever will be appended, other components like qachain will only use the first retriever in args +func AddKnowledgebaseRetrieverToArg(args map[string]any, retriever langchainschema.Retriever) map[string]any { + if _, exist := args[LangchaingoRetrieversKeyInArg]; !exist { + args[LangchaingoRetrieversKeyInArg] = make([]langchainschema.Retriever, 0) + } + retrievers := args[LangchaingoRetrieversKeyInArg].([]langchainschema.Retriever) + retrievers = append(retrievers, retriever) + args[LangchaingoRetrieversKeyInArg] = retrievers + return args +} diff --git a/pkg/appruntime/base/node.go b/pkg/appruntime/base/node.go index 7122c608b..c401fc639 100644 --- a/pkg/appruntime/base/node.go +++ b/pkg/appruntime/base/node.go @@ -111,8 +111,8 @@ func (c *BaseNode) Init(_ context.Context, _ client.Client, _ map[string]any) er return nil } -func (c *BaseNode) Run(_ context.Context, _ client.Client, _ map[string]any) (map[string]any, error) { - return nil, nil +func (c *BaseNode) Run(_ context.Context, _ client.Client, args map[string]any) (map[string]any, error) { + return args, nil } func (c *BaseNode) Ready() (bool, string) { diff --git a/pkg/appruntime/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go index 9883693bc..bb84cac1e 100644 --- a/pkg/appruntime/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -27,7 +27,6 @@ import ( langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" - "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" @@ -57,19 +56,6 @@ func (l *LLMChain) Init(ctx context.Context, cli client.Client, args map[string] } func (l *LLMChain) Run(ctx context.Context, cli client.Client, args map[string]any) (outArgs map[string]any, err error) { - if args[base.ConversationKnowledgeBaseInArg] != nil { - chain := NewRetrievalQAChain(l.BaseNode) - chain.Instance = &v1alpha1.RetrievalQAChain{ - Spec: v1alpha1.RetrievalQAChainSpec{ - CommonChainConfig: v1alpha1.CommonChainConfig{ - Memory: v1alpha1.Memory{ - ConversionWindowSize: pointer.Int(v1alpha1.DefaultConversionWindowSize), - }, - }, - }, - } - return chain.Run(ctx, cli, args) - } v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return args, errors.New("no llm") @@ -122,7 +108,7 @@ func (l *LLMChain) Run(ctx context.Context, cli client.Client, args map[string]a // Add the agent output to the context if it's not empty if args[base.AgentOutputInArg] != nil { klog.FromContext(ctx).V(5).Info(fmt.Sprintf("get answer from upstream: %s", args[base.AgentOutputInArg])) - args["context"] = fmt.Sprintf("%s\n%s", args["context"], args[base.AgentOutputInArg]) + args["context"] = fmt.Sprintf("%s\n %s", args["context"], fmt.Sprintf("Tool Output of question \"%s\" is: %s", args["question"].(string), args[base.AgentOutputInArg].(string))) } // Add the mapReduceDocument output to the context if it's not empty if args[base.MapReduceDocumentOutputInArg] != nil { diff --git a/pkg/appruntime/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go index e14a2eece..2c44aa618 100644 --- a/pkg/appruntime/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -27,12 +27,9 @@ import ( langchainschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" - "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" - apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" - arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/appruntime/log" appruntimeretriever "github.com/kubeagi/arcadia/pkg/appruntime/retriever" @@ -77,27 +74,19 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ if !ok { return args, errors.New("prompt not prompts.FormatPrompter") } - vkb, ok := args[base.ConversationKnowledgeBaseInArg] - if ok { - kb, ok := vkb.(*arcadiav1alpha1.KnowledgeBase) - if !ok { - return args, errors.New("knowledgebase not arcadiav1alpha1.KnowledgeBase") - } - args, finish, err := appruntimeretriever.GenerateKnowledgebaseRetriever(ctx, cli, kb.Name, kb.Namespace, apiretriever.CommonRetrieverConfig{ScoreThreshold: pointer.Float32(apiretriever.DefaultScoreThreshold), NumDocuments: 20}, args) - if err != nil { - return args, err - } - defer finish() - } - v3, ok := args["retriever"] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v3.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrieversInArg, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + klog.FromContext(ctx).Info("no retrievers found in chain, roll back to LLMChain") + llmchain := NewLLMChain(l.BaseNode) + llmchain.Instance = &v1alpha1.LLMChain{} + llmchain.Instance.Spec.CommonChainConfig = l.Instance.Spec.CommonChainConfig + return llmchain.Run(ctx, cli, args) + } + return args, err } + retriever := retrieversInArg[0] v4, ok := args[base.LangchaingoChatMessageHistoryKeyInArg] if !ok { @@ -130,6 +119,10 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ } } + docs, err := retriever.GetRelevantDocuments(ctx, args["question"].(string)) + if err != nil { + return args, fmt.Errorf("can't get doc from retriever: %w", err) + } /* Note: we can only add context to retriever's documents not args["context"] if we use ConversationalRetrievalQA chain see https://github.com/kubeagi/langchaingo/blob/ca2f549e8d91788fd76a9f8706afcaae617275c5/chains/stuff_documents.go#L54-L69 @@ -139,14 +132,17 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ // Add the agent output to qachain context if args[base.AgentOutputInArg] != nil { klog.FromContext(ctx).V(5).Info(fmt.Sprintf("get context from agent: %s", args[base.AgentOutputInArg])) - docs, err := retriever.GetRelevantDocuments(ctx, args["question"].(string)) - if err != nil { - return args, fmt.Errorf("can't get doc from retriever: %w", err) - } doc := langchainschema.Document{PageContent: fmt.Sprintf("Tool Output of %s is: %s", args["question"].(string), args[base.AgentOutputInArg].(string))} docs = append(docs, doc) retriever = &appruntimeretriever.Fakeretriever{Docs: docs, Name: "AddAgentOutputRetriever"} } + if len(docs) == 0 { + docNullReturn, err := base.GetAPPDocNullReturnFromArg(args) + if err == nil && len(docNullReturn) > 0 { + return args, &base.RetrieverGetNullDocError{Msg: docNullReturn} + } + } + // Add the mapReduceDocument output to the context if it's not empty if args[base.MapReduceDocumentOutputInArg] != nil { // Note: Now args[base.MapReduceDocumentOutputInArg] will not be null only for the first "summarize content of uploaded document" Chat request @@ -167,28 +163,42 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ condenseQustionGenerator.CallbacksHandler = log.KLogHandler{LogLevel: 3} chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), condenseQustionGenerator, retriever, GetMemory(llm, instance.Spec.Memory, history, "", "")) chain.RephraseQuestion = false + chain.ReturnSourceDocuments = true l.ConversationalRetrievalQA = chain args["query"] = args["question"] - var out string + var ( + out string + outputValues map[string]any + ) needStream := false needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool) if ok && needStream { options = append(options, chains.WithStreamingFunc(stream(args))) - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args, options...) } else { if len(options) > 0 { - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args, options...) } else { - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args) } } + // _llmChainDefaultOutputKey + out, _ = outputValues["text"].(string) out, err = handleNoErrNoOut(ctx, needStream, out, err, l.ConversationalRetrievalQA, args, options) klog.FromContext(ctx).V(5).Info("use retrievalqachain, blocking out:" + out) if err == nil { args[base.OutputAnswerKeyInArg] = out + // _conversationalRetrievalQADefaultSourceDocumentKey + doc, ok := outputValues["source_documents"].([]langchainschema.Document) + if ok { + _, refs := appruntimeretriever.ConvertDocuments(ctx, doc, "retrievalqachain") + // note: the references in args will be replaced, not append + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + } return args, nil } + return args, fmt.Errorf("retrievalqachain run error: %w", err) } diff --git a/pkg/appruntime/knowledgebase/knowledgebase.go b/pkg/appruntime/knowledgebase/knowledgebase.go index e57ee21d1..8448f7a36 100644 --- a/pkg/appruntime/knowledgebase/knowledgebase.go +++ b/pkg/appruntime/knowledgebase/knowledgebase.go @@ -23,6 +23,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + appnode "github.com/kubeagi/arcadia/api/app-node" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" ) @@ -39,6 +40,9 @@ func NewKnowledgebase(baseNode base.BaseNode) *Knowledgebase { } func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return nil + } instance := &v1alpha1.KnowledgeBase{} if err := cli.Get(ctx, types.NamespacedName{Namespace: k.RefNamespace(), Name: k.Ref.Name}, instance); err != nil { return fmt.Errorf("can't find the knowledgebase in cluster: %w", err) @@ -47,10 +51,9 @@ func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[strin return nil } -func (k *Knowledgebase) Run(_ context.Context, _ client.Client, args map[string]any) (map[string]any, error) { - return args, nil -} - func (k *Knowledgebase) Ready() (isReady bool, msg string) { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return true, "" + } return k.Instance.Status.IsReadyOrGetReadyMessage() } diff --git a/pkg/appruntime/log/log.go b/pkg/appruntime/log/log.go index e1c4ba7a3..5eb23c5ce 100644 --- a/pkg/appruntime/log/log.go +++ b/pkg/appruntime/log/log.go @@ -82,7 +82,8 @@ func (l KLogHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) { logger := klog.FromContext(ctx) logger.WithValues("logger", "arcadia") // Lower level for log streaming - logger.V(l.LogLevel + 1).Info("log streaming: " + string(chunk)) + // maybe 6 is enough + logger.V(6).Info("log streaming: " + string(chunk)) } func (l KLogHandler) HandleText(ctx context.Context, text string) { diff --git a/pkg/appruntime/retriever/common.go b/pkg/appruntime/retriever/common.go index 50bf2e2f9..4bd1f7a2f 100644 --- a/pkg/appruntime/retriever/common.go +++ b/pkg/appruntime/retriever/common.go @@ -56,6 +56,8 @@ type Reference struct { Metadata map[string]any `json:"-"` } +const RerankScoreCol string = "rerank_score" + func (reference Reference) String() string { bytes, err := json.Marshal(&reference) if err != nil { @@ -79,6 +81,7 @@ func AddReferencesToArgs(args map[string]any, refs []Reference) map[string]any { return args } +// ConvertDocuments, convert raw doc to what we want doc, for example, knowledgebase doc should add answer into page content func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, retrieverName string) (newDocs []langchaingoschema.Document, refs []Reference) { logger := klog.FromContext(ctx) docLen := len(docs) @@ -108,7 +111,7 @@ func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, re doc.PageContent = doc.PageContent + joinStr + answer } } - if retrieverName == "multiquery" { + if retrieverName == "multiquery" || retrieverName == "retrievalqachain" { // pageContent may have the answer in previous steps, and we want this field only has question in reference output pageContent = strings.TrimSuffix(pageContent, joinStr+answer) } @@ -145,6 +148,7 @@ func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, re content = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") } } + rerankScore, _ := doc.Metadata[RerankScoreCol].(float32) refs = append(refs, Reference{ Question: pageContent, Answer: answer, @@ -155,6 +159,7 @@ func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, re PageNumber: page, Content: content, Metadata: doc.Metadata, + RerankScore: rerankScore, }) docs[k] = doc } diff --git a/pkg/appruntime/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go index 28450a6fd..e1c207da1 100644 --- a/pkg/appruntime/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -18,16 +18,16 @@ package retriever import ( "context" - "errors" "fmt" - langchaingoschema "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" + appnode "github.com/kubeagi/arcadia/api/app-node" apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" @@ -101,7 +101,21 @@ func (l *KnowledgeBaseRetriever) Cleanup() { func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, knowledgebaseName, knowledgebaseNamespace string, retrieverConfig apiretriever.CommonRetrieverConfig, args map[string]any) (outArg map[string]any, finish func(), err error) { knowledgebase := &v1alpha1.KnowledgeBase{} + isConversationKnowledgebase := appnode.IsPlaceholderConversationKnowledgebase(knowledgebaseName) + if isConversationKnowledgebase { + v, ok := args[base.ConversationIDInArg] + if ok { + conversationID, ok := v.(string) + if ok && conversationID != "" { + knowledgebaseName = conversationID + } + } + } if err := cli.Get(ctx, types.NamespacedName{Namespace: knowledgebaseNamespace, Name: knowledgebaseName}, knowledgebase); err != nil { + if isConversationKnowledgebase && apierrors.IsNotFound(err) { // When there is a conversationID, look for the corresponding conversation knowledgebase. This knowledgebase may not exist. This is not a error + // TODO We can search for whether there should be a conversation knowledgebase from the pg + return args, nil, nil + } return nil, nil, fmt.Errorf("can't find the knowledgebase in cluster: %w", err) } @@ -138,40 +152,14 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } retriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} - question, ok := args["question"] - if !ok { - return nil, finish, errors.New("no question in args") - } - query, ok := question.(string) - if !ok { - return nil, finish, errors.New("question not string") + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return nil, finish, err } docs, err := retriever.GetRelevantDocuments(ctx, query) if err != nil { return nil, finish, fmt.Errorf("can't get relevant documents: %w", err) } - oldDocs := make([]langchaingoschema.Document, 0) - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if ok { - // may exist other retriever, like conversation retriever - oldRetriever, ok := v.(langchaingoschema.Retriever) - if ok { - oldDocs, err = oldRetriever.GetRelevantDocuments(ctx, query) - if err != nil { - return nil, finish, fmt.Errorf("can't get old doc: %w", err) - } - } - } - if len(docs) == 0 && len(oldDocs) == 0 { - // FIXME: 需要决定当知识库找不到相关内容,但是conversation知识库存在文档时如何处理 - v, exist := args[base.APPDocNullReturn] - if exist { - docNullReturn, ok := v.(string) - if ok && len(docNullReturn) > 0 { - return nil, finish, &base.RetrieverGetNullDocError{Msg: docNullReturn} - } - } - } // pgvector get score means vector distance, similarity = 1 - vector distance // chroma get score means similarity // we want similarity finally. @@ -181,7 +169,7 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } } docs, refs := ConvertDocuments(ctx, docs, "knowledgebase") - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: append(docs, oldDocs...), Name: "KnowledgebaseRetriever"} - AddReferencesToArgs(args, refs) + args = AddReferencesToArgs(args, refs) + args = base.AddKnowledgebaseRetrieverToArg(args, &Fakeretriever{Docs: docs, Name: "KnowledgebaseRetriever"}) return args, finish, nil } diff --git a/pkg/appruntime/retriever/mergerretriever.go b/pkg/appruntime/retriever/mergerretriever.go new file mode 100644 index 000000000..26330da95 --- /dev/null +++ b/pkg/appruntime/retriever/mergerretriever.go @@ -0,0 +1,79 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retriever + +import ( + "context" + "errors" + "fmt" + + langchainretrievers "github.com/tmc/langchaingo/retrievers" + "github.com/tmc/langchaingo/schema" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime/base" +) + +type MergerRetriever struct { + base.BaseNode + Instance *apiretriever.MergerRetriever +} + +func NewMergerRetriever(baseNode base.BaseNode) *MergerRetriever { + return &MergerRetriever{ + BaseNode: baseNode, + } +} + +func (l *MergerRetriever) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + instance := &apiretriever.MergerRetriever{} + if err := cli.Get(ctx, types.NamespacedName{Namespace: l.RefNamespace(), Name: l.BaseNode.Ref.Name}, instance); err != nil { + return fmt.Errorf("can't find the merger retriever in cluster: %w", err) + } + l.Instance = instance + return nil +} + +func (l *MergerRetriever) Run(ctx context.Context, _ client.Client, args map[string]any) (map[string]any, error) { + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + return args, nil + } + return args, err + } + r := langchainretrievers.NewMergerRetriever(retrievers) + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return args, err + } + docs, err := r.GetRelevantDocuments(ctx, query) + if err != nil { + return args, err + } + docs, refs := ConvertDocuments(ctx, docs, "mergerretriever") + // note: the references in args will be replaced, not append + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + args[base.LangchaingoRetrieversKeyInArg] = []schema.Retriever{&Fakeretriever{Docs: docs, Name: "mergerRetriever"}} + return args, nil +} + +func (l *MergerRetriever) Ready() (isReady bool, msg string) { + return l.Instance.Status.IsReadyOrGetReadyMessage() +} diff --git a/pkg/appruntime/retriever/multiqueryretriever.go b/pkg/appruntime/retriever/multiqueryretriever.go index 610ec9c1f..e2bca6d4f 100644 --- a/pkg/appruntime/retriever/multiqueryretriever.go +++ b/pkg/appruntime/retriever/multiqueryretriever.go @@ -69,13 +69,12 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m return args, errors.New("empty question") } - v1, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v1.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrieversInArg, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + return args, nil + } + return args, err } v2, ok := args[base.LangchaingoLLMKeyInArg] @@ -88,7 +87,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m } prompt := prompts.NewPromptTemplate(_defaultQueryTemplate, []string{"question"}) llmchain := chains.NewLLMChain(llm, prompt, chains.WithCallback(log.KLogHandler{LogLevel: 3})) - multiqueryRetriever := retrievers.NewMultiQueryRetriever(retriever, llmchain, true) + multiqueryRetriever := retrievers.NewMultiQueryRetriever(retrieversInArg[0], llmchain, true) multiqueryRetriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} docs, err := multiqueryRetriever.GetRelevantDocuments(ctx, query) if err != nil { @@ -107,7 +106,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m newDocs, newRef := ConvertDocuments(ctx, newDocs, "multiquery") // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"}} return args, nil } diff --git a/pkg/appruntime/retriever/rerankretriever.go b/pkg/appruntime/retriever/rerankretriever.go index 4ae0c1fe6..f7343357e 100644 --- a/pkg/appruntime/retriever/rerankretriever.go +++ b/pkg/appruntime/retriever/rerankretriever.go @@ -66,14 +66,8 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s return args, errors.New("empty references") } if len(references) == 0 { - v, exist := args[base.APPDocNullReturn] - if exist { - docNullReturn, ok := v.(string) - if ok && len(docNullReturn) > 0 { - return nil, &base.RetrieverGetNullDocError{Msg: docNullReturn} - } - } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: nil, Name: "RerankRetriever"} + klog.FromContext(ctx).V(3).Info("rerank retriever get no references, skip rerank") + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: nil, Name: "RerankRetriever"}} return args, nil } q, ok := args[base.InputQuestionKeyInArg] @@ -143,15 +137,14 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + if errors.Is(err, base.ErrNoRetrievers) { + return args, nil + } + return args, err } - docs, err := retriever.GetRelevantDocuments(ctx, query) + docs, err := retrievers[0].GetRelevantDocuments(ctx, query) if err != nil { return args, fmt.Errorf("get relevant documents failed: %w", err) } @@ -159,11 +152,15 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s for i := range newRef { for j := range docs { if newRef[i].Score == docs[j].Score && reflect.DeepEqual(newRef[i].Metadata, docs[j].Metadata) { + if v := docs[j].Metadata; len(v) == 0 { + docs[j].Metadata = make(map[string]any, 1) + } + docs[j].Metadata[RerankScoreCol] = newRef[i].RerankScore newDocs = append(newDocs, docs[j]) } } } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "RerankRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: newDocs, Name: "RerankRetriever"}} return args, nil } diff --git a/tests/deploy-values.yaml b/tests/deploy-values.yaml index 3b7e906a5..53fcad628 100644 --- a/tests/deploy-values.yaml +++ b/tests/deploy-values.yaml @@ -140,3 +140,8 @@ postgresql: repository: kubeagi/postgresql tag: 16.1.0-debian-11-r18-pgvector-v0.5.1 pullPolicy: IfNotPresent + +config: + embedder: + enabled: true + model: "arcadia-embedder" diff --git a/tests/example-test.sh b/tests/example-test.sh index 6e80dd971..a61168995 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -40,314 +40,11 @@ Timeout="${TimeoutSeconds}s" mkdir ${TempFilePath} || true env -function debugInfo { - if [[ $? -eq 0 ]]; then - exit 0 - fi - if [[ $debug -ne 0 ]]; then - exit 1 - fi - if [[ $GITHUB_ACTIONS == "true" ]]; then - warning "debugInfo start 🧐" - mkdir -p $LOG_DIR || true - df -h - - warning "1. Try to get all resources " - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl get -A --ignore-not-found=true --show-kind=true >$LOG_DIR/get-all-resources-list.log - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl get -A -oyaml --ignore-not-found=true --show-kind=true >$LOG_DIR/get-all-resources-yaml.log - - warning "2. Try to describe all resources " - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl describe -A >$LOG_DIR/describe-all-resources.log - - warning "3. Try to export kind logs to $LOG_DIR..." - kind export logs --name=${KindName} $LOG_DIR - sudo chown -R $USER:$USER $LOG_DIR - - warning "debugInfo finished ! " - warning "This means that some tests have failed. Please check the log. 🌚" - debug=1 - fi - exit 1 -} +source ./tests/scripts/utils.sh trap 'debugInfo $LINENO' ERR trap 'debugInfo $LINENO' EXIT debug=0 -function cecho() { - declare -A colors - colors=( - ['black']='\E[0;47m' - ['red']='\E[0;31m' - ['green']='\E[0;32m' - ['yellow']='\E[0;33m' - ['blue']='\E[0;34m' - ['magenta']='\E[0;35m' - ['cyan']='\E[0;36m' - ['white']='\E[0;37m' - ) - local defaultMSG="No message passed." - local defaultColor="black" - local defaultNewLine=true - while [[ $# -gt 1 ]]; do - key="$1" - case $key in - -c | --color) - color="$2" - shift - ;; - -n | --noline) - newLine=false - ;; - *) - # unknown option - ;; - esac - shift - done - message=${1:-$defaultMSG} # Defaults to default message. - color=${color:-$defaultColor} # Defaults to default color, if not specified. - newLine=${newLine:-$defaultNewLine} - echo -en "${colors[$color]}" - echo -en "$message" - if [ "$newLine" = true ]; then - echo - fi - tput sgr0 # Reset text attributes to normal without clearing screen. - return -} - -function warning() { - cecho -c 'yellow' "$@" -} - -function error() { - cecho -c 'red' "$@" -} - -function info() { - cecho -c 'blue' "$@" -} - -function waitPodReady() { - namespace=$1 - podLabel=$2 - START_TIME=$(date +%s) - while true; do - readStatus=$(kubectl -n${namespace} get po -l ${podLabel} --ignore-not-found=true -o json | jq -r '.items[0].status.conditions[] | select(."type"=="Ready") | .status') - if [[ $readStatus == "True" ]]; then - info "Pod:${podLabel} ready" - break - fi - kubectl -n${namespace} get po -l ${podLabel} - - CURRENT_TIME=$(date +%s) - ELAPSED_TIME=$((CURRENT_TIME - START_TIME)) - if [ $ELAPSED_TIME -gt $TimeoutSeconds ]; then - error "Timeout reached" - kubectl describe po -n${namespace} -l ${podLabel} - kubectl get po -n${namespace} --show-labels - exit 1 - fi - sleep 5 - done -} - -function EnableAPIServerPortForward() { - waitPodReady "arcadia" "app=arcadia-apiserver" - if [ $portal_pid -ne 0 ]; then - kill $portal_pid >/dev/null 2>&1 - fi - echo "re port-forward apiserver..." - kubectl port-forward svc/arcadia-apiserver -n arcadia 8081:8081 >/dev/null 2>&1 & - portal_pid=$! - sleep 3 - info "port-forward apiserver in pid: $portal_pid" -} - -function waitCRDStatusReady() { - source=$1 - namespace=$2 - name=$3 - START_TIME=$(date +%s) - while true; do - readStatus=$(kubectl -n${namespace} get ${source} ${name} --ignore-not-found=true -o json | jq -r '.status.conditions[0].status') - message=$(kubectl -n${namespace} get ${source} ${name} --ignore-not-found=true -o json | jq -r '.status.conditions[0].message') - if [[ $readStatus == "True" ]]; then - info $message - if [[ ${source} == "KnowledgeBase" ]]; then - fileStatus=$(kubectl get knowledgebase -n $namespace $name -o json | jq -r '.status.fileGroupDetail[0].fileDetails[0].phase') - if [[ $fileStatus != "Succeeded" ]]; then - kubectl get knowledgebase -n $namespace $name -o json | jq -r '.status.fileGroupDetail[0].fileDetails' - exit 1 - fi - fi - break - fi - - CURRENT_TIME=$(date +%s) - ELAPSED_TIME=$((CURRENT_TIME - START_TIME)) - if [[ ${source} == "Worker" ]]; then - if [ $ELAPSED_TIME -gt 1800 ]; then - error "Timeout reached" - exit 1 - fi - else - if [ $ELAPSED_TIME -gt $TimeoutSeconds ]; then - error "Timeout reached" - exit 1 - fi - fi - sleep 5 - done -} - -function getRespInAppChat() { - appname=$1 - namespace=$2 - query=$3 - conversationID=$4 - testStream=$5 - attempt=0 - while true; do - info "sleep 3 seconds" - sleep 3 - data=$(jq -n --arg appname "$appname" --arg query "$query" --arg conversationID "$conversationID" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname}') - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}") - ai_data=$(echo $resp | jq -r '.message') - references=$(echo $resp | jq -r '.references') - if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: ${query}" - echo "🤖: ${ai_data}" - echo "🔗: ${references}" - break - done - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - - if [ $testStream == "true" ]; then - attempt=0 - while true; do - info "sleep 5 seconds" - sleep 5 - info "just test stream mode" - data=$(jq -n --arg appname "$appname" --arg query "$query" --arg conversationID "$conversationID" '{"query":$query,"response_mode":"streaming","conversation_id":$conversationID,"app_name":$appname}') - curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}" - if [[ $? -ne 0 ]]; then - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - EnableAPIServerPortForward - echo "and wait 60s for google api error" - sleep 60 - continue - fi - break - done - fi -} - -function fileUploadSummarise() { - appname=$1 - namespace=$2 - filename=$3 - attempt=0 - while true; do - info "sleep 3 seconds" - sleep 3 - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST --form file=@$filename --form app_name=$appname -H "namespace: ${namespace}" -H "Content-Type: multipart/form-data" http://127.0.0.1:8081/chat/conversations/file) - doc_data=$(echo $resp | jq -r '.document') - if [ -z "$doc_data" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: ${filename}" - echo "🤖: ${doc_data}" - break - done - file_id=$(echo $resp | jq -r '.document.object') - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - attempt=0 - while true; do - info "sleep 3 seconds to sumerize doc" - sleep 3 - data=$(jq -n --arg fileid "$file_id" --arg appname "$appname" --arg query "总结一下" --arg conversationID "$resp_conversation_id" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname, "files": [$fileid]}') - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}") - ai_data=$(echo $resp | jq -r '.message') - references=$(echo $resp | jq -r '.references') - if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: 总结一下" - echo "🤖: ${ai_data}" - echo "🔗: ${references}" - break - done - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - - if [ $testStream == "true" ]; then - attempt=0 - while true; do - info "sleep 5 seconds" - sleep 5 - info "just test stream mode" - data=$(jq -n --arg fileid "$file_id" --arg appname "$appname" --arg query "总结一下" --arg conversationID "$resp_conversation_id" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname, "files": [$fileid]}') - curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}" - if [[ $? -ne 0 ]]; then - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - EnableAPIServerPortForward - echo "and wait 60s for google api error" - sleep 60 - continue - fi - break - done - fi -} - info "1. create kind cluster" make kind df -h @@ -410,6 +107,7 @@ s3_secret=$(kubectl get secrets -n arcadia datasource-sample-authsecret -o json export MC_HOST_arcadiatest=http://${s3_key}:${s3_secret}@127.0.0.1:9000 mc cp pkg/documentloaders/testdata/qa.csv arcadiatest/${bucket}/qa.csv mc cp pkg/documentloaders/testdata/chunk.csv arcadiatest/${bucket}/chunk.csv +mc cp CODE_OF_CONDUCT.md arcadiatest/${bucket}/CODE_OF_CONDUCT.md info "add tags to these files" mc tag set arcadiatest/${bucket}/qa.csv "object_type=QA" mc tag set arcadiatest/${bucket}/chunk.csv "object_type=QA" @@ -445,6 +143,7 @@ sleep 3 info "7.4.2 create knowledgebase based on pgvector and wait it ready" kubectl apply -f config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml waitCRDStatusReady "KnowledgeBase" "arcadia" "knowledgebase-sample-pgvector" +waitCRDStatusReady "KnowledgeBase" "arcadia" "knowledgebase-sample-pgvector2" info "7.5 check vectorstore has data" info "7.5.1 check chroma vectorstore has data" @@ -590,7 +289,14 @@ kubectl patch applications -n arcadia base-chat-with-knowledgebase-pgvector -p ' getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "飞天的主演是谁?" "" "true" info "8.2.3 QA app using knowledgebase base on pgvector and rerank" -kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce.yaml +if [[ $GITHUB_ACTIONS == "true" ]]; then + info "in github action, download model from huggingface" + kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce.yaml +else + # https://github.com/kubeagi/core-library/issues/54 + info "in local, download model from modelscope" + kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml +fi waitCRDStatusReady "Model" "arcadia" "bce-reranker" kubectl apply -f config/samples/arcadia_v1alpha1_worker_reranking_bce.yaml waitCRDStatusReady "Worker" "arcadia" "bce-reranker" @@ -668,6 +374,37 @@ info "8.2.5.3 When no related doc is found and application.spec.docNullReturn is kubectl patch applications -n arcadia base-chat-with-knowledgebase-pgvector-multiquery -p '{"spec":{"docNullReturn":""}}' --type='merge' getRespInAppChat "base-chat-with-knowledgebase-pgvector-multiquery" "arcadia" "飞天的主演是谁?" "" "true" +info "8.2.6 QA app using multiple knowledgebase base on pgvector and multiquery" +kubectl apply -f config/samples/app_retrievalqachain_multi_knowledgebase_pgvector_rerank_multiquery.yaml +waitCRDStatusReady "Application" "arcadia" "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" +sleep 3 +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "公司的考勤管理制度适用于哪些人员?" "" "true" +if [[ $ai_data != *"全体正式员工及实习生"* ]]; then + echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp + exit 1 +fi +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "怀孕9个月以上每月可以享受几天假期?" "" "true" +if [[ $ai_data != *"4"* ]]; then + echo "resp should contains '4', but resp is:"$resp + exit 1 +fi +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "arcadia follows which Conduct?" "" "true" +# FIXME: how to change cnfc to cncf +if [[ $ai_data != *"cncf"* ]] && [[ $ai_data != *"CNCF"* ]] && [[ $ai_data != *"CNFC"* ]]; then + echo "resp should contains 'cncf', but resp is:"$resp + exit 1 +fi +info "8.2.5.2 When no related doc is found, return application.spec.docNullReturn info, if set" +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "飞天的主演是谁?" "" "true" +expected=$(kubectl get applications -n arcadia base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery -o json | jq -r .spec.docNullReturn) +if [[ $ai_data != $expected ]]; then + echo "when no related doc is found, return application.spec.docNullReturn info should be:"$expected ", but resp:"$resp + exit 1 +fi +info "8.2.5.3 When no related doc is found and application.spec.docNullReturn is not set" +kubectl patch applications -n arcadia base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery -p '{"spec":{"docNullReturn":""}}' --type='merge' +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "飞天的主演是谁?" "" "true" + info "8.3 conversation chat app" kubectl apply -f config/samples/app_llmchain_chat_with_bot.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-with-bot" @@ -747,17 +484,38 @@ kubectl apply -f config/samples/app_llmchain_abstract.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-document-assistant" fileUploadSummarise "base-chat-document-assistant" "arcadia" "./pkg/documentloaders/testdata/arcadia-readme.pdf" getRespInAppChat "base-chat-document-assistant" "arcadia" "what is arcadia?" ${resp_conversation_id} "false" +if [[ $ai_data != *"mythology"* ]] && [[ $ai_data != *"LLMOps"* ]] && [[ $ai_data != *"Kubernetes"* ]]; then + echo "resp should contains 'mythology' or 'LLMOps' or 'Kubernetes', but resp is:"$ai_data + exit 1 +fi getRespInAppChat "base-chat-document-assistant" "arcadia" "Does your model based on gpt-3.5?" ${resp_conversation_id} "false" info "8.4.7 chat with document with knowledgebase" +kubectl apply -f config/samples/app_retrievalqachain_knowledgebase_pgvector-conversation.yaml fileUploadSummarise "base-chat-with-knowledgebase-pgvector" "arcadia" "./pkg/documentloaders/testdata/arcadia-readme.pdf" getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "what is arcadia?" ${resp_conversation_id} "false" +if [[ $ai_data != *"mythology"* ]] && [[ $ai_data != *"LLMOps"* ]] && [[ $ai_data != *"Kubernetes"* ]]; then + echo "resp should contains 'mythology' or 'LLMOps' or 'Kubernetes', but resp is:"$ai_data + exit 1 +fi getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "公司的考勤管理制度适用于哪些人员?" ${resp_conversation_id} "false" -# FIXME According to the log, we returned the corresponding document to the big model, but the big model probably returned unknown -#if [[ $ai_data != *"全体正式员工及实习生"* ]]; then -# echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp -# exit 1 -#fi +if [[ $ai_data != *"全体正式员工及实习生"* ]]; then + echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp + exit 1 +fi + +info "8.4.8 chat with document with multi knowledgebase" +fileUploadSummarise "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "./pkg/documentloaders/testdata/arcadia-readme.pdf" +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "what is arcadia?" ${resp_conversation_id} "false" +if [[ $ai_data != *"mythology"* ]] && [[ $ai_data != *"LLMOps"* ]] && [[ $ai_data != *"Kubernetes"* ]]; then + echo "resp should contains 'mythology' or 'LLMOps' or 'Kubernetes', but resp is:"$ai_data + exit 1 +fi +getRespInAppChat "base-chat-with-multi-knowledgebase-pgvector-rerank-multiquery" "arcadia" "公司的考勤管理制度适用于哪些人员?" ${resp_conversation_id} "false" +if [[ $ai_data != *"全体正式员工及实习生"* ]]; then + echo "resp should contains '公司全体正式员工及实习生', but resp is:"$resp + exit 1 +fi # There is uncertainty in the AI replies, most of the time, it will pass the test, a small percentage of the time, the AI will call names in each reply, causing the test to fail, therefore, temporarily disable the following tests #getRespInAppChat "base-chat-with-bot" "arcadia" "What is your model?" ${resp_conversation_id} "false" @@ -790,11 +548,13 @@ if [[ $ai_data != *"782"* ]]; then echo "resp should contains 782, but resp:"$resp exit 1 fi +sleep 1 getRespInAppChat "base-chat-with-bot-tool" "arcadia" "结果再乘2" ${resp_conversation_id} "false" if [[ $ai_data != *"1564"* ]]; then echo "resp should contains 1564, but resp:"$resp exit 1 fi +sleep 1 getRespInAppChat "base-chat-with-bot-tool" "arcadia" "结果再减去564" ${resp_conversation_id} "false" if [[ $ai_data != *"1000"* ]]; then echo "resp should contains 1000, but resp:"$resp @@ -871,3 +631,4 @@ kubectl logs --tail=100 -n arcadia -l app=arcadia-apiserver >/tmp/apiserver.log cat /tmp/apiserver.log info "all finished! ✅" +exit 1