def select_conversation_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
    """Select only user-input <> completion pairs plus the current scratchpad.

    Intermediate steps (tool/function-call messages) of *previous* turns are
    dropped; only the opening ``HumanMessage`` and the closing ``AIMessage``
    of each completed exchange are kept. The trailing, still-in-progress
    scratchpad is kept in full. An exchange whose last message is not an
    ``AIMessage`` was interrupted and is discarded entirely.

    Args:
        messages: Full conversation history, oldest first.

    Returns:
        Filtered list of messages, oldest first.
    """
    new_messages: list[BaseMessage] = []
    # Accumulates the current exchange: one HumanMessage followed by the
    # AI/tool messages produced in response to it.
    _messages: list[BaseMessage] = []
    for m in messages:
        if isinstance(m, HumanMessage):
            # Guard: _messages is empty before the first human message, so
            # there is no previous exchange to flush (the original code
            # indexed _messages[-1] here and raised IndexError).
            if _messages:
                if isinstance(_messages[-1], AIMessage):
                    # Completed exchange: keep only the first (Human) and
                    # last (AI) message, dropping the scratchpad in between.
                    new_messages.append(_messages[0])
                    new_messages.append(_messages[-1])
                # else: the previous run was interrupted (did not end with
                # an AIMessage) — discard it entirely rather than merging
                # it into the next exchange.
            # Start a new exchange at this human message. (The original
            # `continue` dropped the new HumanMessage and kept appending to
            # the stale scratchpad.)
            _messages = [m]
        else:
            _messages.append(m)
    # The trailing scratchpad is the current run; keep all of its messages.
    new_messages.extend(_messages)
    return new_messages