From 087e263c57fc28c79eca5557f9df04d675fec5f1 Mon Sep 17 00:00:00 2001
From: Yukinobu Mine <14157373+Yukinobu-Mine@users.noreply.github.com>
Date: Wed, 18 Dec 2024 18:21:28 +0900
Subject: [PATCH] Fix: Nova doesn't work (#645)
* If model is Amazon Nova, combine multiple system prompts into one text. #629
* If model is Amazon Nova, set the upper limit of topK to 128. #629
* Remove title key from input JSON schema of AgentTool. #629
* Optimization for Amazon Nova.
- Change the system prompt when doing 'Retrieved Context Citation' with Amazon Nova.
- If the tool result has more than one element, pass it as single text content formatted as JSON array.
* Fix: mypy errors.
* Add stack trace to backend error logs.
* Fix for multimodal tool results for Amazon Nova
* Add comments, and document changes.
- `RemoveTitle` in app/agents/tools/agent_tool.py
- `_prepare_nova_model_params()` in app/bedrock.py
- `build_rag_prompt()` and `get_prompt_to_cite_tool_results()` in app/prompt.py
- `BaseTool` -> `AgentTool` in docs/AGENT.md
* Move `is_nova_model()` back to app/bedrock.py
- To avoid circular imports, add `from __future__ import annotations` and `if TYPE_CHECKING` to app/bedrock.py
* Update document of Agent functionality
- docs/AGENT.md
* Change `run_result_to_tool_result_content_model()` to be an instance method of `ToolResultContentModel`
- `agent_tool.run_result_to_tool_result_content_model()` -> `ToolResultContentModel.from_tool_run_result()`
---
backend/app/agents/tools/agent_tool.py | 50 ++++++-------
backend/app/agents/tools/knowledge.py | 4 +-
backend/app/bedrock.py | 74 +++++++++++++------
backend/app/prompt.py | 52 ++++++++++++-
backend/app/repositories/conversation.py | 1 +
.../app/repositories/models/conversation.py | 61 ++++++++++++++-
backend/app/usecases/chat.py | 17 +++--
backend/app/websocket.py | 11 ++-
.../test_agent/test_tools/test_agent_tool.py | 6 +-
.../test_tools/test_internet_search.py | 6 +-
.../test_agent/test_tools/test_knowledge.py | 10 ++-
backend/tests/test_usecases/test_chat.py | 1 +
docs/AGENT.md | 10 +--
examples/agents/tools/bmi/test_bmi.py | 9 ++-
14 files changed, 235 insertions(+), 77 deletions(-)
diff --git a/backend/app/agents/tools/agent_tool.py b/backend/app/agents/tools/agent_tool.py
index 7af079e44..2e4536135 100644
--- a/backend/app/agents/tools/agent_tool.py
+++ b/backend/app/agents/tools/agent_tool.py
@@ -5,12 +5,11 @@
TextToolResultModel,
JsonToolResultModel,
RelatedDocumentModel,
- ToolResultContentModel,
- ToolResultContentModelBody,
)
from app.repositories.models.custom_bot import BotModel
from app.routes.schemas.conversation import type_model_name
from pydantic import BaseModel, JsonValue
+from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from mypy_boto3_bedrock_runtime.type_defs import (
ToolSpecificationTypeDef,
)
@@ -27,28 +26,22 @@ class ToolRunResult(TypedDict):
related_documents: list[RelatedDocumentModel]
-def run_result_to_tool_result_content_model(
- run_result: ToolRunResult, display_citation: bool
-) -> ToolResultContentModel:
- return ToolResultContentModel(
- content_type="toolResult",
- body=ToolResultContentModelBody(
- tool_use_id=run_result["tool_use_id"],
- content=[
- related_document.to_tool_result_model(
- display_citation=display_citation,
- )
- for related_document in run_result["related_documents"]
- ],
- status=run_result["status"],
- ),
- )
-
-
class InvalidToolError(Exception):
pass
+class RemoveTitle(GenerateJsonSchema):
+ """Custom JSON schema generator that doesn't output `title`s for types and parameters."""
+
+ def field_title_should_be_set(self, schema) -> bool:
+ return False
+
+ def generate(self, schema, mode="validation") -> JsonSchemaValue:
+ value = super().generate(schema, mode)
+ del value["title"]
+ return value
+
+
class AgentTool(Generic[T]):
def __init__(
self,
@@ -59,19 +52,16 @@ def __init__(
[T, BotModel | None, type_model_name | None],
ToolFunctionResult | list[ToolFunctionResult],
],
- bot: BotModel | None = None,
- model: type_model_name | None = None,
):
self.name = name
self.description = description
self.args_schema = args_schema
self.function = function
- self.bot = bot
- self.model: type_model_name | None = model
def _generate_input_schema(self) -> dict[str, Any]:
"""Converts the Pydantic model to a JSON schema."""
- return self.args_schema.model_json_schema()
+ # Specify a custom generator `RemoveTitle` because some foundation models do not work properly if there are unnecessary titles.
+ return self.args_schema.model_json_schema(schema_generator=RemoveTitle)
def to_converse_spec(self) -> ToolSpecificationTypeDef:
return ToolSpecificationTypeDef(
@@ -80,10 +70,16 @@ def to_converse_spec(self) -> ToolSpecificationTypeDef:
inputSchema={"json": self._generate_input_schema()},
)
- def run(self, tool_use_id: str, input: dict[str, JsonValue]) -> ToolRunResult:
+ def run(
+ self,
+ tool_use_id: str,
+ input: dict[str, JsonValue],
+ model: type_model_name,
+ bot: BotModel | None = None,
+ ) -> ToolRunResult:
try:
arg = self.args_schema.model_validate(input)
- res = self.function(arg, self.bot, self.model)
+ res = self.function(arg, bot, model)
if isinstance(res, list):
related_documents = [
_function_result_to_related_document(
diff --git a/backend/app/agents/tools/knowledge.py b/backend/app/agents/tools/knowledge.py
index e90e16cbc..3fa1db6cf 100644
--- a/backend/app/agents/tools/knowledge.py
+++ b/backend/app/agents/tools/knowledge.py
@@ -39,7 +39,7 @@ def search_knowledge(
raise e
-def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
+def create_knowledge_tool(bot: BotModel) -> AgentTool:
description = (
"Answer a user's question using information. The description is: {}".format(
bot.knowledge.__str_in_claude_format__()
@@ -51,6 +51,4 @@ def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
description=description,
args_schema=KnowledgeToolInput,
function=search_knowledge,
- bot=bot,
- model=model,
)
diff --git a/backend/app/bedrock.py b/backend/app/bedrock.py
index cdfbb5f67..abe87ee6b 100644
--- a/backend/app/bedrock.py
+++ b/backend/app/bedrock.py
@@ -1,29 +1,34 @@
+from __future__ import annotations
+
import logging
import os
-from typing import TypeGuard, Dict, Any, Optional, Tuple
+from typing import TypeGuard, Dict, Any, Optional, Tuple, TYPE_CHECKING
-from app.agents.tools.agent_tool import AgentTool
from app.config import BEDROCK_PRICING
from app.config import DEFAULT_GENERATION_CONFIG as DEFAULT_CLAUDE_GENERATION_CONFIG
from app.config import DEFAULT_MISTRAL_GENERATION_CONFIG
-from app.repositories.models.conversation import (
- SimpleMessageModel,
- ContentModel,
-)
+
from app.repositories.models.custom_bot import GenerationParamsModel
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
from app.routes.schemas.conversation import type_model_name
from app.utils import get_bedrock_runtime_client
-from mypy_boto3_bedrock_runtime.type_defs import (
- ConverseStreamRequestRequestTypeDef,
- MessageTypeDef,
- ConverseResponseTypeDef,
- ContentBlockTypeDef,
- GuardrailConverseContentBlockTypeDef,
- InferenceConfigurationTypeDef,
-)
-from mypy_boto3_bedrock_runtime.literals import ConversationRoleType
+if TYPE_CHECKING:
+ from app.agents.tools.agent_tool import AgentTool
+ from app.repositories.models.conversation import (
+ SimpleMessageModel,
+ ContentModel,
+ )
+ from mypy_boto3_bedrock_runtime.type_defs import (
+ ConverseStreamRequestRequestTypeDef,
+ MessageTypeDef,
+ ConverseResponseTypeDef,
+ ContentBlockTypeDef,
+ GuardrailConverseContentBlockTypeDef,
+ InferenceConfigurationTypeDef,
+ SystemContentBlockTypeDef,
+ )
+ from mypy_boto3_bedrock_runtime.literals import ConversationRoleType
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -46,7 +51,7 @@ def _is_conversation_role(role: str) -> TypeGuard[ConversationRoleType]:
return role in ["user", "assistant"]
-def _is_nova_model(model: type_model_name) -> bool:
+def is_nova_model(model: type_model_name) -> bool:
"""Check if the model is an Amazon Nova model"""
return model in ["amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro"]
@@ -83,7 +88,14 @@ def _prepare_nova_model_params(
# Add top_k if specified in generation params
if generation_params and generation_params.top_k is not None:
- additional_fields["inferenceConfig"]["topK"] = generation_params.top_k
+ top_k = generation_params.top_k
+ if top_k > 128:
+ logger.warning(
+ "In Amazon Nova, an 'unexpected error' occurs if topK exceeds 128. To avoid errors, the upper limit of A is set to 128."
+ )
+ top_k = 128
+
+ additional_fields["inferenceConfig"]["topK"] = top_k
return inference_config, additional_fields
@@ -131,11 +143,24 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
]
# Prepare model-specific parameters
- if _is_nova_model(model):
+ inference_config: InferenceConfigurationTypeDef
+ additional_model_request_fields: dict[str, Any]
+ system_prompts: list[SystemContentBlockTypeDef]
+ if is_nova_model(model):
# Special handling for Nova models
inference_config, additional_model_request_fields = _prepare_nova_model_params(
model, generation_params
)
+ system_prompts = (
+ [
+ {
+ "text": "\n\n".join(instructions),
+ }
+ ]
+ if len(instructions) > 0
+ else []
+ )
+
else:
# Standard handling for non-Nova models
inference_config = {
@@ -167,17 +192,20 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
else DEFAULT_GENERATION_CONFIG["top_k"]
)
}
+ system_prompts = [
+ {
+ "text": instruction,
+ }
+ for instruction in instructions
+ if len(instruction) > 0
+ ]
# Construct the base arguments
args: ConverseStreamRequestRequestTypeDef = {
"inferenceConfig": inference_config,
"modelId": get_model_id(model),
"messages": arg_messages,
- "system": [
- {"text": instruction}
- for instruction in instructions
- if len(instruction) > 0
- ],
+ "system": system_prompts,
"additionalModelRequestFields": additional_model_request_fields,
}
diff --git a/backend/app/prompt.py b/backend/app/prompt.py
index 753916d98..680167d0c 100644
--- a/backend/app/prompt.py
+++ b/backend/app/prompt.py
@@ -1,14 +1,18 @@
+from app.bedrock import is_nova_model
from app.vector_search import SearchResult
+from app.routes.schemas.conversation import type_model_name
def build_rag_prompt(
search_results: list[SearchResult],
+ model: type_model_name,
display_citation: bool = True,
) -> str:
context_prompt = ""
for result in search_results:
context_prompt += f"\n\n{result['content']}\n\n"
+ # Prompt for RAG
inserted_prompt = """To answer the user's question, you are given a set of search results. Your job is to answer the user's question using only information from the search results.
If the search results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true, make sure to double check the search results to validate a user's assertion.
@@ -24,6 +28,7 @@ def build_rag_prompt(
)
if display_citation:
+ # Prompt for 'Retrieved Context Citation'.
inserted_prompt += """
If you reference information from a search result within your answer, you must include a citation to source where the information was found.
Each result has a corresponding source ID that you should reference.
@@ -32,7 +37,23 @@ def build_rag_prompt(
Do NOT outputs sources at the end of your answer.
Followings are examples of how to reference sources in your answer. Note that the source ID is embedded in the answer in the format [^].
+"""
+ # Prompt to output Markdown-style citation.
+ if is_nova_model(model=model):
+ # For Amazon Nova, provides only good examples.
+ inserted_prompt += """
+
+first answer [^3]. second answer [^1][^2].
+
+
+
+first answer [^1][^5]. second answer [^2][^3][^4]. third answer [^4].
+
+"""
+ else:
+ # For other models, provide good examples and bad examples.
+ inserted_prompt += """
first answer [^3]. second answer [^1][^2].
@@ -57,9 +78,17 @@ def build_rag_prompt(
"""
else:
+ # Prompt when 'Retrieved Context Citation' is not specified.
inserted_prompt += """
Do NOT include citations in the format [^] in your answer.
+"""
+ if is_nova_model(model=model):
+ # For Amazon Nova, do not provide examples.
+ pass
+ else:
+ # For other models, suppress output of Markdown-style citation.
+ inserted_prompt += """
Followings are examples of how to answer.
@@ -78,7 +107,9 @@ def build_rag_prompt(
return inserted_prompt
-PROMPT_TO_CITE_TOOL_RESULTS = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
+def get_prompt_to_cite_tool_results(model: type_model_name) -> str:
+ # Prompt for 'Retrieved Context Citation' of agent chat.
+ inserted_prompt = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
If the tool results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true, make sure to double check the tool results to validate a user's assertion.
@@ -86,6 +117,23 @@ def build_rag_prompt(
If you reference information from a tool result within your answer, you must include a citation to source_id where the information was found.
Followings are examples of how to reference source_id in your answer. Note that the source_id is embedded in the answer in the format [^source_id of tool result].
+"""
+ # Prompt to output Markdown-style citation.
+ if is_nova_model(model=model):
+ # For Amazon Nova, provides only good examples.
+ inserted_prompt += """
+
+first answer [^ccc]. second answer [^aaa][^bbb].
+
+
+
+first answer [^aaa][^eee]. second answer [^bbb][^ccc][^ddd]. third answer [^ddd].
+
+"""
+
+ else:
+ # For other models, provide good examples and bad examples.
+ inserted_prompt += """
first answer [^ccc]. second answer [^aaa][^bbb].
@@ -110,3 +158,5 @@ def build_rag_prompt(
"""
+
+ return inserted_prompt
diff --git a/backend/app/repositories/conversation.py b/backend/app/repositories/conversation.py
index 9a9917b04..bf8856630 100644
--- a/backend/app/repositories/conversation.py
+++ b/backend/app/repositories/conversation.py
@@ -272,6 +272,7 @@ def delete_large_messages(items):
except ClientError as e:
logger.error(f"An error occurred: {e.response['Error']['Message']}")
+ raise e
def change_conversation_title(user_id: str, conversation_id: str, new_title: str):
diff --git a/backend/app/repositories/models/conversation.py b/backend/app/repositories/models/conversation.py
index 1b87938f8..5d3725d66 100644
--- a/backend/app/repositories/models/conversation.py
+++ b/backend/app/repositories/models/conversation.py
@@ -1,8 +1,9 @@
from __future__ import annotations
+import json
import re
from pathlib import Path
-from typing import Annotated, Any, Literal, Self, TypedDict, TypeGuard
+from typing import Annotated, Any, Literal, Self, TypeGuard, TYPE_CHECKING
from urllib.parse import urlparse
from app.repositories.models.common import Base64EncodedBytes
@@ -36,6 +37,9 @@
)
from pydantic import BaseModel, Discriminator, Field, JsonValue, field_validator
+if TYPE_CHECKING:
+ from app.agents.tools.agent_tool import ToolRunResult
+
class TextContentModel(BaseModel):
content_type: Literal["text"]
@@ -474,6 +478,61 @@ def from_tool_result_content(cls, content: ToolResultContent) -> Self:
body=ToolResultContentModelBody.from_tool_result_content_body(content.body),
)
+ @classmethod
+ def from_tool_run_result(
+ cls,
+ run_result: ToolRunResult,
+ model: type_model_name,
+ display_citation: bool,
+ ) -> Self:
+ result_contents = [
+ related_document.to_tool_result_model(
+ display_citation=display_citation,
+ )
+ for related_document in run_result["related_documents"]
+ ]
+
+ from app.bedrock import is_nova_model
+
+ if is_nova_model(model=model):
+ text_or_json_contents = [
+ result_content
+ for result_content in result_contents
+ if isinstance(result_content, TextToolResultModel)
+ or isinstance(result_content, JsonToolResultModel)
+ ]
+ if len(text_or_json_contents) > 1:
+ return cls(
+ content_type="toolResult",
+ body=ToolResultContentModelBody(
+ tool_use_id=run_result["tool_use_id"],
+ content=[
+ TextToolResultModel(
+ text=json.dumps(
+ [
+ (
+ content.json_
+ if isinstance(content, JsonToolResultModel)
+ else content.text
+ )
+ for content in text_or_json_contents
+ ]
+ ),
+ ),
+ ],
+ status=run_result["status"],
+ ),
+ )
+
+ return cls(
+ content_type="toolResult",
+ body=ToolResultContentModelBody(
+ tool_use_id=run_result["tool_use_id"],
+ content=result_contents,
+ status=run_result["status"],
+ ),
+ )
+
def to_content(self) -> Content:
return ToolResultContent(
content_type="toolResult",
diff --git a/backend/app/usecases/chat.py b/backend/app/usecases/chat.py
index 44d7808db..6ab996e37 100644
--- a/backend/app/usecases/chat.py
+++ b/backend/app/usecases/chat.py
@@ -3,12 +3,11 @@
from app.agents.tools.agent_tool import (
ToolRunResult,
- run_result_to_tool_result_content_model,
)
from app.agents.tools.knowledge import create_knowledge_tool
from app.agents.utils import get_tool_by_name
from app.bedrock import call_converse_api, compose_args_for_converse_api
-from app.prompt import PROMPT_TO_CITE_TOOL_RESULTS, build_rag_prompt
+from app.prompt import build_rag_prompt, get_prompt_to_cite_tool_results
from app.repositories.conversation import (
RecordNotFoundError,
find_conversation_by_id,
@@ -260,11 +259,15 @@ def chat(
if bot.is_agent_enabled():
if bot.has_knowledge():
# Add knowledge tool
- knowledge_tool = create_knowledge_tool(bot, chat_input.message.model)
+ knowledge_tool = create_knowledge_tool(bot=bot)
tools[knowledge_tool.name] = knowledge_tool
if display_citation:
- instructions.append(PROMPT_TO_CITE_TOOL_RESULTS)
+ instructions.append(
+ get_prompt_to_cite_tool_results(
+ model=chat_input.message.model,
+ )
+ )
elif bot.has_knowledge():
# Fetch most related documents from vector store
@@ -306,6 +309,7 @@ def chat(
instructions.append(
build_rag_prompt(
search_results=search_results,
+ model=chat_input.message.model,
display_citation=display_citation,
)
)
@@ -432,6 +436,8 @@ def chat(
run_result = tool.run(
tool_use_id=content.body.tool_use_id,
input=content.body.input,
+ model=chat_input.message.model,
+ bot=bot,
)
run_results.append(run_result)
@@ -444,8 +450,9 @@ def chat(
tool_result_message = SimpleMessageModel(
role="user",
content=[
- run_result_to_tool_result_content_model(
+ ToolResultContentModel.from_tool_run_result(
run_result=result,
+ model=chat_input.message.model,
display_citation=display_citation,
)
for result in run_results
diff --git a/backend/app/websocket.py b/backend/app/websocket.py
index e601f3520..f59e3f93b 100644
--- a/backend/app/websocket.py
+++ b/backend/app/websocket.py
@@ -69,13 +69,13 @@ def run(self):
gatewayapi.exceptions.GoneException,
gatewayapi.exceptions.ForbiddenException,
) as e:
- logger.error(
+ logger.exception(
f"Shutdown the notification sender due to an exception: {e}"
)
break
except Exception as e:
- logger.error(f"Failed to send notification: {e}")
+ logger.exception(f"Failed to send notification: {e}")
elif command["type"] == "finish":
break
@@ -212,7 +212,7 @@ def process_chat_input(
}
except Exception as e:
- logger.error(f"Failed to run stream handler: {e}")
+ logger.exception(f"Failed to run stream handler: {e}")
return {
"statusCode": 500,
"body": json.dumps(
@@ -269,7 +269,7 @@ def handler(event, context):
# Verify JWT token
decoded = verify_token(token)
except Exception as e:
- logger.error(f"Invalid token: {e}")
+ logger.exception(f"Invalid token: {e}")
return {
"statusCode": 403,
"body": json.dumps(
@@ -356,8 +356,7 @@ def handler(event, context):
return {"statusCode": 200, "body": "Message part received."}
except Exception as e:
- logger.error(f"Operation failed: {e}")
- logger.error("".join(traceback.format_tb(e.__traceback__)))
+ logger.exception(f"Operation failed: {e}")
return {
"statusCode": 500,
"body": json.dumps(
diff --git a/backend/tests/test_agent/test_tools/test_agent_tool.py b/backend/tests/test_agent/test_tools/test_agent_tool.py
index 19f7e177c..54244f8ee 100644
--- a/backend/tests/test_agent/test_tools/test_agent_tool.py
+++ b/backend/tests/test_agent/test_tools/test_agent_tool.py
@@ -90,7 +90,11 @@ def test_run(self):
arg3=1,
arg4=["test"],
)
- result = self.tool.run(tool_use_id="dummy", input=arg.model_dump())
+ result = self.tool.run(
+ tool_use_id="dummy",
+ input=arg.model_dump(),
+ model="claude-v3.5-sonnet-v2",
+ )
self.assertEqual(
result["related_documents"],
[
diff --git a/backend/tests/test_agent/test_tools/test_internet_search.py b/backend/tests/test_agent/test_tools/test_internet_search.py
index 5dbfa7bf5..372ba41e6 100644
--- a/backend/tests/test_agent/test_tools/test_internet_search.py
+++ b/backend/tests/test_agent/test_tools/test_internet_search.py
@@ -13,7 +13,11 @@ def test_internet_search(self):
time_limit = "d"
country = "jp-jp"
arg = InternetSearchInput(query=query, time_limit=time_limit, country=country)
- response = internet_search_tool.run(tool_use_id="dummy", input=arg.model_dump())
+ response = internet_search_tool.run(
+ tool_use_id="dummy",
+ input=arg.model_dump(),
+ model="claude-v3.5-sonnet-v2",
+ )
self.assertIsInstance(response["related_documents"], list)
self.assertEqual(response["status"], "success")
print(response)
diff --git a/backend/tests/test_agent/test_tools/test_knowledge.py b/backend/tests/test_agent/test_tools/test_knowledge.py
index c5cf3c5d5..e15913ffd 100644
--- a/backend/tests/test_agent/test_tools/test_knowledge.py
+++ b/backend/tests/test_agent/test_tools/test_knowledge.py
@@ -5,6 +5,7 @@
from app.agents.tools.knowledge import KnowledgeToolInput, create_knowledge_tool
from app.repositories.models.custom_bot import (
+ ActiveModelsModel,
AgentModel,
BotModel,
GenerationParamsModel,
@@ -53,10 +54,15 @@ def test_knowledge_tool(self):
conversation_quick_starters=[],
bedrock_knowledge_base=None,
bedrock_guardrails=None,
+ active_models=ActiveModelsModel(),
)
arg = KnowledgeToolInput(query="What are delicious Japanese dishes?")
- tool = create_knowledge_tool(bot, model="claude-v3-sonnet")
- response = tool.run(tool_use_id="dummy", input=arg.model_dump())
+ tool = create_knowledge_tool(bot=bot)
+ response = tool.run(
+ tool_use_id="dummy",
+ input=arg.model_dump(),
+ model="claude-v3.5-sonnet-v2",
+ )
self.assertIsInstance(response["related_documents"], list)
self.assertEqual(response["status"], "success")
print(response)
diff --git a/backend/tests/test_usecases/test_chat.py b/backend/tests/test_usecases/test_chat.py
index db530066b..ea3846a75 100644
--- a/backend/tests/test_usecases/test_chat.py
+++ b/backend/tests/test_usecases/test_chat.py
@@ -998,6 +998,7 @@ def test_insert_knowledge(self):
]
instruction = build_rag_prompt(
search_results=results,
+ model="claude-v3.5-sonnet-v2",
display_citation=True,
)
print(instruction)
diff --git a/docs/AGENT.md b/docs/AGENT.md
index 223da0e59..90af9324d 100644
--- a/docs/AGENT.md
+++ b/docs/AGENT.md
@@ -10,11 +10,11 @@ This sample implements an Agent using the [ReAct (Reasoning + Acting)](https://w
An Agent using ReAct can be applied in various scenarios, providing accurate and efficient solutions.
-### Text-to-SQL Example
+### Text-to-SQL
-A user asks for "the total sales for the last quarter." The Agent interprets this request, converts it into a SQL query, executes it against the database, and presents the results. For the detail, see: [Text-to-SQL tool](../examples/agents/tools/text_to_sql/)
+A user asks for "the total sales for the last quarter." The Agent interprets this request, converts it into a SQL query, executes it against the database, and presents the results.
-### Financial Forecasting Example
+### Financial Forecasting
A financial analyst needs to forecast next quarter's revenue. The Agent gathers relevant data, performs necessary calculations using financial models, and generates a detailed forecast report, ensuring the accuracy of the projections.
@@ -46,7 +46,7 @@ This tool depends [DuckDuckGo](https://duckduckgo.com/) which has rate limit. It
To develop your own custom tools for the Agent, follow these guidelines:
-- Create a new class that inherits from the `BaseTool` class. Although the interface is compatible with LangChain, this sample implementation provides its own `BaseTool` class, which you should inherit from ([source](../backend/app/agents/tools/base.py)).
+- Create a new class that inherits from the `AgentTool` class. Although the interface is compatible with LangChain, this sample implementation provides its own `AgentTool` class, which you should inherit from ([source](../backend/app/agents/tools/agent_tool.py)).
- Refer to the sample implementation of a [BMI calculation tool](../examples/agents/tools/bmi/bmi.py). This example demonstrates how to create a tool that calculates the Body Mass Index (BMI) based on user input.
@@ -63,8 +63,6 @@ To develop your own custom tools for the Agent, follow these guidelines:
- Run `cdk deploy` to deploy your changes. This will make your custom tool available in the custom bot screen.
-In addition to the BMI calculation example, there are other tool examples available for reference, including [Text-to-SQL](../examples/agents/tools/text_to_sql/). Feel free to explore these [examples](../examples/agents/tools/) to gain insights and inspiration for creating your own tools.
-
## Contribution
**Contributions to the tool repository are welcome!** If you develop a useful and well-implemented tool, consider contributing it to the project by submitting an issue or a pull request.
diff --git a/examples/agents/tools/bmi/test_bmi.py b/examples/agents/tools/bmi/test_bmi.py
index b83c86d31..ff689fa40 100644
--- a/examples/agents/tools/bmi/test_bmi.py
+++ b/examples/agents/tools/bmi/test_bmi.py
@@ -8,7 +8,14 @@
class TestBmiTool(unittest.TestCase):
def test_bmi(self):
- result = bmi_tool.run(tool_use_id="dummy", input={"height": 170, "weight": 70})
+ result = bmi_tool.run(
+ tool_use_id="dummy",
+ input={
+ "height": 170,
+ "weight": 70,
+ },
+ model="claude-v3.5-sonnet-v2",
+ )
print(result)
self.assertEqual(type(result), str)