feat: Add Web Crawler for Knowledge Bases (#600)
* Advanced Parsing

* reformatted

* webcrawler

* Set the knowledge base's webLocation on reference documents

* fix ci

* fix reviewed

* Fix a bug where sync did not run automatically

* Address review feedback

* Rename WebCrawlingFilters as WebCrawlingFiltersModel to obey coding convention

---------

Co-authored-by: statefb <[email protected]>
fsatsuki and statefb authored Nov 21, 2024
1 parent 8bb15e8 commit d807d0c
Showing 22 changed files with 1,974 additions and 25 deletions.
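Taken together, the changes let a bot's knowledge base ingest web pages through a Bedrock web crawler data source. As a rough sketch of the new surface (field names come from the schema diffs below; the enclosing payload shape, URLs, and patterns are illustrative assumptions):

# Hypothetical fragment of a bot create/modify payload using the new fields.
# web_crawling_scope accepts DEFAULT | HOST_ONLY | SUBDOMAINS; the filter
# patterns are regular expressions matched against crawled URLs.
bedrock_knowledge_base = {
    "web_crawling_scope": "HOST_ONLY",
    "web_crawling_filters": {
        "exclude_patterns": [r".*\.zip$"],
        "include_patterns": [r"https://docs\.example\.com/.*"],
    },
}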
12 changes: 11 additions & 1 deletion backend/app/repositories/models/custom_bot_kb.py
@@ -1,11 +1,12 @@
from app.routes.schemas.bot_kb import (
    type_kb_chunking_strategy,
    type_kb_embeddings_model,
    type_kb_parsing_model,
    type_kb_search_type,
    type_kb_web_crawling_scope,
    type_os_character_filter,
    type_os_token_filter,
    type_os_tokenizer,
)
from pydantic import BaseModel

@@ -53,6 +54,11 @@ class NoneParamsModel(BaseModel):
chunking_strategy: type_kb_chunking_strategy = "none"


class WebCrawlingFiltersModel(BaseModel):
exclude_patterns: list[str]
include_patterns: list[str]


class BedrockKnowledgeBaseModel(BaseModel):
embeddings_model: type_kb_embeddings_model
open_search: OpenSearchParamsModel
@@ -68,3 +74,7 @@ class BedrockKnowledgeBaseModel(BaseModel):
knowledge_base_id: str | None = None
data_source_ids: list[str] | None = None
parsing_model: type_kb_parsing_model = "disabled"
web_crawling_scope: type_kb_web_crawling_scope = "DEFAULT"
web_crawling_filters: WebCrawlingFiltersModel = WebCrawlingFiltersModel(
exclude_patterns=[], include_patterns=[]
)
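The filters model is deliberately minimal: two lists of pattern strings. A standalone sketch of its behavior (not part of the commit; assumes pydantic v2, which deep-copies field defaults, so the mutable-looking class default above is safe per instance):

from pydantic import BaseModel

class WebCrawlingFiltersModel(BaseModel):
    exclude_patterns: list[str]
    include_patterns: list[str]

filters = WebCrawlingFiltersModel(
    exclude_patterns=[r".*\.pdf$"],                      # skip PDF links
    include_patterns=[r"https://docs\.example\.com/.*"], # stay under the docs site
)
print(filters.model_dump())
# {'exclude_patterns': ['.*\\.pdf$'], 'include_patterns': ['https://docs\\.example\\.com/.*']}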
53 changes: 50 additions & 3 deletions backend/app/routes/schemas/bot.py
@@ -116,13 +116,51 @@ class BotModifyInput(BaseSchema):
bedrock_knowledge_base: BedrockKnowledgeBaseInput | None = None
bedrock_guardrails: BedrockGuardrailsInput | None = None

def has_update_files(self) -> bool:
def _has_update_files(self) -> bool:
return self.knowledge is not None and (
len(self.knowledge.added_filenames) > 0
or len(self.knowledge.deleted_filenames) > 0
)

def guardrails_update_required(self, current_bot_model: BotModel) -> bool:
def _has_update_source_urls(self, current_bot_model: BotModel) -> bool:
return self.knowledge is not None and (
len(self.knowledge.source_urls) > 0
and (
set(self.knowledge.source_urls)
!= set(current_bot_model.knowledge.source_urls)
)
)

def _is_crawling_scope_modified(self, current_bot_model: BotModel) -> bool:
if (
self.bedrock_knowledge_base is None
or current_bot_model.bedrock_knowledge_base is None
):
return False
return (
self.bedrock_knowledge_base.web_crawling_scope
!= current_bot_model.bedrock_knowledge_base.web_crawling_scope
)

def _is_crawling_filters_modified(self, current_bot_model: BotModel) -> bool:
if (
self.bedrock_knowledge_base is None
or current_bot_model.bedrock_knowledge_base is None
or self.bedrock_knowledge_base.web_crawling_filters is None
or current_bot_model.bedrock_knowledge_base.web_crawling_filters is None
):
return False
return set(
self.bedrock_knowledge_base.web_crawling_filters.exclude_patterns
) != set(
current_bot_model.bedrock_knowledge_base.web_crawling_filters.exclude_patterns
) or set(
self.bedrock_knowledge_base.web_crawling_filters.include_patterns
) != set(
current_bot_model.bedrock_knowledge_base.web_crawling_filters.include_patterns
)

def is_guardrails_update_required(self, current_bot_model: BotModel) -> bool:
# Check if self.bedrock_guardrails is None
if not self.bedrock_guardrails:
return False
@@ -151,7 +189,16 @@ def guardrails_update_required(self, current_bot_model: BotModel) -> bool:
return False

def is_embedding_required(self, current_bot_model: BotModel) -> bool:
if self.has_update_files():
if self._has_update_files():
return True

if self._has_update_source_urls(current_bot_model):
return True

if self._is_crawling_scope_modified(current_bot_model):
return True

if self._is_crawling_filters_modified(current_bot_model):
return True

if self.knowledge is not None and current_bot_model.has_knowledge():
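Note that _is_crawling_filters_modified compares sets, so merely reordering patterns does not queue a re-embedding; only adding, removing, or changing a pattern does. A toy illustration (standalone, not from the commit):

current = ["a", "b"]
print(set(["b", "a"]) != set(current))       # False -> no re-sync needed
print(set(["a", "b", "c"]) != set(current))  # True  -> re-embedding queued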
14 changes: 14 additions & 0 deletions backend/app/routes/schemas/bot_kb.py
@@ -16,6 +16,7 @@
type_kb_parsing_model = Literal[
"anthropic.claude-3-sonnet-v1", "anthropic.claude-3-haiku-v1", "disabled"
]
type_kb_web_crawling_scope = Literal["DEFAULT", "HOST_ONLY", "SUBDOMAINS"]

# OpenSearch Serverless Analyzer
# Ref: https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-genref.html
@@ -75,6 +76,11 @@ class NoneParams(BaseSchema):
chunking_strategy: type_kb_chunking_strategy = "none"


class WebCrawlingFilters(BaseSchema):
exclude_patterns: list[str] = Field(default_factory=list)
include_patterns: list[str] = Field(default_factory=list)


class BedrockKnowledgeBaseInput(BaseSchema):
embeddings_model: type_kb_embeddings_model
open_search: OpenSearchParams
@@ -88,6 +94,10 @@ class BedrockKnowledgeBaseInput(BaseSchema):
search_params: SearchParams
knowledge_base_id: str | None = None
parsing_model: type_kb_parsing_model = "disabled"
web_crawling_scope: type_kb_web_crawling_scope = "DEFAULT"
web_crawling_filters: WebCrawlingFilters = WebCrawlingFilters(
exclude_patterns=[], include_patterns=[]
)


class BedrockKnowledgeBaseOutput(BaseSchema):
@@ -105,3 +115,7 @@ class BedrockKnowledgeBaseOutput(BaseSchema):
knowledge_base_id: str | None = None
data_source_ids: list[str] | None = None
parsing_model: type_kb_parsing_model = "disabled"
web_crawling_scope: type_kb_web_crawling_scope = "DEFAULT"
web_crawling_filters: WebCrawlingFilters = WebCrawlingFilters(
exclude_patterns=[], include_patterns=[]
)
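Because the scope is a Literal type, invalid values are rejected at validation time rather than reaching the CDK deploy. A standalone sketch of that behavior (assumes pydantic, which the schemas are built on):

from typing import Literal
from pydantic import BaseModel, ValidationError

type_kb_web_crawling_scope = Literal["DEFAULT", "HOST_ONLY", "SUBDOMAINS"]

class Demo(BaseModel):
    web_crawling_scope: type_kb_web_crawling_scope = "DEFAULT"

print(Demo().web_crawling_scope)  # DEFAULT
try:
    Demo(web_crawling_scope="EVERYTHING")
except ValidationError:
    print("rejected: must be DEFAULT, HOST_ONLY, or SUBDOMAINS")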
2 changes: 1 addition & 1 deletion backend/app/usecases/bot.py
@@ -317,7 +317,7 @@ def modify_owned_bot(
sync_status = (
"QUEUED"
if modify_input.is_embedding_required(bot)
or modify_input.guardrails_update_required(bot)
or modify_input.is_guardrails_update_required(bot)
else "SUCCEEDED"
)

23 changes: 14 additions & 9 deletions backend/app/vector_search.py
@@ -71,11 +71,9 @@ def get_source_link(source: str) -> tuple[Literal["s3", "url"], str]:
client_method="get_object",
)
return "s3", source_link
elif source.startswith("http://") or source.startswith("https://"):
return "url", source
else:
# Assume source is a youtube video id
return "url", f"https://www.youtube.com/watch?v={source}"
# Return the source as is for knowledge base references
return "url", source


def _bedrock_knowledge_base_search(bot: BotModel, query: str) -> list[SearchResult]:
@@ -102,14 +100,21 @@ def _bedrock_knowledge_base_search(bot: BotModel, query: str) -> list[SearchResu
},
)

def extract_source_from_retrieval_result(retrieval_result):
"""Extract source URL/URI from retrieval result based on location type."""
location = retrieval_result.get("location", {})
location_type = location.get("type")

if location_type == "WEB":
return location.get("webLocation", {}).get("url", "")
elif location_type == "S3":
return location.get("s3Location", {}).get("uri", "")
return ""

search_results = []
for i, retrieval_result in enumerate(response.get("retrievalResults", [])):
content = retrieval_result.get("content", {}).get("text", "")
source = (
retrieval_result.get("location", {})
.get("s3Location", {})
.get("uri", "")
)
source = extract_source_from_retrieval_result(retrieval_result)

search_results.append(
SearchResult(rank=i, bot_id=bot.id, content=content, source=source)
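The new helper dispatches on the retrieval result's location type, so crawled pages surface their URL while S3 documents keep their URI. A standalone check, with payload shapes mirroring the webLocation/s3Location fields used above:

def extract_source(retrieval_result):
    """Return the URL (WEB) or URI (S3) of a retrieval result."""
    location = retrieval_result.get("location", {})
    if location.get("type") == "WEB":
        return location.get("webLocation", {}).get("url", "")
    elif location.get("type") == "S3":
        return location.get("s3Location", {}).get("uri", "")
    return ""

print(extract_source({"location": {"type": "WEB", "webLocation": {"url": "https://example.com/a"}}}))  # https://example.com/a
print(extract_source({"location": {"type": "S3", "s3Location": {"uri": "s3://bucket/doc.pdf"}}}))      # s3://bucket/doc.pdf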
1 change: 1 addition & 0 deletions backend/tests/test_repositories/test_custom_bot.py
@@ -78,6 +78,7 @@ def test_store_and_find_bot(self):
overlap_percentage=0,
),
parsing_model="anthropic.claude-3-sonnet-v1",
web_crawling_scope="DEFAULT",
),
bedrock_guardrails=BedrockGuardrailsModel(
is_guardrail_enabled=True,
15 changes: 15 additions & 0 deletions cdk/bin/bedrock-custom-bot.ts
@@ -6,7 +6,12 @@ import {
getChunkingStrategy,
getAnalyzer,
getParsingModel,
getCrowlingScope,
getCrawlingFilters,
} from "../lib/utils/bedrock-knowledge-base-args";
import {
CrawlingFilters,
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock/data-sources/web-crawler-data-source";

const app = new cdk.App();

@@ -40,6 +45,9 @@ const guardrails = JSON.parse(BEDROCK_GUARDRAILS);
const existingS3Urls: string[] = knowledge.s3_urls.L.map(
(s3Url: any) => s3Url.S
);
const sourceUrls: string[] = knowledge.source_urls.L.map(
(sourceUrl: any) => sourceUrl.S
);
const useStandbyReplicas: boolean = USE_STAND_BY_REPLICAS === "true";

console.log("ownerUserId: ", ownerUserId);
@@ -48,9 +56,12 @@ console.log("knowledgeBase: ", knowledgeBase);
console.log("knowledge: ", knowledge);
console.log("guardrails: ", guardrails);
console.log("existingS3Urls: ", existingS3Urls);
console.log("sourceUrls: ", sourceUrls);

const embeddingsModel = getEmbeddingModel(knowledgeBase.embeddings_model.S);
const parsingModel = getParsingModel(knowledgeBase.parsing_model.S)
const crawlingScope = getCrowlingScope(knowledgeBase.web_crawling_scope.S)
const crawlingFilters: CrawlingFilters = getCrawlingFilters(knowledgeBase.web_crawling_filters.M)
const maxTokens: number | undefined = knowledgeBase.chunking_configuration.M.max_tokens
? Number(knowledgeBase.chunking_configuration.M.max_tokens.N)
: undefined;
@@ -137,6 +148,7 @@ console.log("relevanceThreshold: ", relevanceThreshold);
console.log("guardrailArn: ", guardrailArn);
console.log("guardrailVersion: ", guardrailVersion);
console.log("parsingModel: ", parsingModel);
console.log("crawlingScope: ", crawlingScope);

if (analyzer) {
console.log(
@@ -161,10 +173,13 @@ const bedrockCustomBotStack = new BedrockCustomBotStack(
botId,
embeddingsModel,
parsingModel,
crawlingScope,
crawlingFilters,
bedrockClaudeChatDocumentBucketName:
BEDROCK_CLAUDE_CHAT_DOCUMENT_BUCKET_NAME,
chunkingStrategy,
existingS3Urls,
sourceUrls,
maxTokens,
instruction,
analyzer,
29 changes: 29 additions & 0 deletions cdk/lib/bedrock-custom-bot-stack.ts
@@ -16,6 +16,11 @@ import {
import {
S3DataSource,
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock/data-sources/s3-data-source";
import {
WebCrawlerDataSource,
CrawlingScope,
CrawlingFilters,
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock/data-sources/web-crawler-data-source";
import {
ParsingStategy
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock/data-sources/parsing";
@@ -49,12 +54,15 @@ interface BedrockCustomBotStackProps extends StackProps {
readonly bedrockClaudeChatDocumentBucketName: string;
readonly chunkingStrategy: ChunkingStrategy;
readonly existingS3Urls: string[];
readonly sourceUrls: string[];
readonly maxTokens?: number;
readonly instruction?: string;
readonly analyzer?: Analyzer;
readonly overlapPercentage?: number;
readonly guardrail?: BedrockGuardrailProps;
readonly useStandbyReplicas?: boolean;
readonly crawlingScope?: CrawlingScope;
readonly crawlingFilters?: CrawlingFilters;
}

export class BedrockCustomBotStack extends Stack {
@@ -114,6 +122,27 @@ export class BedrockCustomBotStack extends Stack {
});
});

// Add Web Crawler Data Sources
if (props.sourceUrls.length > 0) {
const webCrawlerDataSource = new WebCrawlerDataSource(this, 'WebCrawlerDataSource', {
knowledgeBase: kb,
sourceUrls: props.sourceUrls,
chunkingStrategy: props.chunkingStrategy,
parsingStrategy: props.parsingModel ? ParsingStategy.foundationModel({
parsingModel: props.parsingModel.asIModel(this),
}) : undefined,
crawlingScope: props.crawlingScope,
        filters: {
          excludePatterns: props.crawlingFilters?.excludePatterns,
          includePatterns: props.crawlingFilters?.includePatterns,
        },
      });
new CfnOutput(this, 'DataSourceIdWebCrawler', {
value: webCrawlerDataSource.dataSourceId
})
}

if (props.guardrail?.is_guardrail_enabled == true) {
// Use only parameters with a value greater than or equal to 0
let contentPolicyConfigFiltersConfig = [];
39 changes: 39 additions & 0 deletions cdk/lib/utils/bedrock-knowledge-base-args.ts
@@ -1,10 +1,15 @@
import { unmarshall } from "@aws-sdk/util-dynamodb";
import {
BedrockFoundationModel,
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock";
import {
HierarchicalChunkingProps,
ChunkingStrategy,
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock/data-sources/chunking";
import {
CrawlingScope,
CrawlingFilters,
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock/data-sources/web-crawler-data-source";
import { Analyzer } from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/opensearch-vectorindex";
import {
CharacterFilterType,
@@ -51,6 +56,40 @@ export const getParsingModel = (
}
}

export const getCrowlingScope = (
web_crawling_scope: string
): CrawlingScope | undefined => {

switch(web_crawling_scope) {
case "DEFAULT":
return CrawlingScope.DEFAULT
case "HOST_ONLY":
return CrawlingScope.HOST_ONLY
case "SUBDOMAINS":
return CrawlingScope.SUBDOMAINS
default:
return undefined
}
}

export const getCrawlingFilters = (
web_crawling_filters: any
): CrawlingFilters => {
const regularJson = unmarshall(web_crawling_filters);
console.log(`regularJson: ${JSON.stringify(regularJson)}`)

let excludePatterns = undefined
let includePatterns = undefined

if (regularJson.exclude_patterns.length > 0 && regularJson.exclude_patterns[0] != "") excludePatterns = regularJson.exclude_patterns
if (regularJson.include_patterns.length > 0 && regularJson.include_patterns[0] != "") includePatterns = regularJson.include_patterns

return {
excludePatterns,
includePatterns,
}
}

export const getChunkingStrategy = (
chunkingStrategy: string,
embeddingsModel: string,
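For readers more at home in the backend, here is the same unmarshal-and-drop-empties logic as getCrawlingFilters sketched in Python (hypothetical analogue, not in the commit; assumes boto3's TypeDeserializer for the DynamoDB-marshalled input):

from boto3.dynamodb.types import TypeDeserializer

deserializer = TypeDeserializer()
raw = {  # DynamoDB-marshalled filters as stored on the bot item
    "exclude_patterns": {"L": [{"S": ""}]},
    "include_patterns": {"L": [{"S": r"https://docs\.example\.com/.*"}]},
}
plain = {key: deserializer.deserialize(value) for key, value in raw.items()}

def non_empty(patterns):
    # Mirror the TypeScript guard: treat [] and [""] as "no filter set".
    return patterns if patterns and patterns[0] != "" else None

print(non_empty(plain["exclude_patterns"]))  # None
print(non_empty(plain["include_patterns"]))  # ['https://docs\\.example\\.com/.*']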