Skip to content

Commit

Permalink
Fix Vertex system prompt (#1319)
Browse files Browse the repository at this point in the history
The request body for Vertex calls was malformed whenever the prompt contained a system message. Example function:

```baml
function TestVertexWithSystemInstructions() -> string {
  client Vertex
  prompt #"{{_.role("system")}} You are a helpful assistant
  {{_.role("user")}} Write a poem about llamas
  "#
}
```
Request body produced before this fix (the system message is incorrectly sent as an entry in `contents`):

```json
{
  "contents": [
    {
      "role": "system",
      "parts": [
        {
          "text": "You are a helpful assistant"
        }
      ]
    },
    {
      "role": "user",
      "parts": [
        {
          "text": "Write a poem about llamas"
        }
      ]
    }
  ]
}
```

The expected JSON should look like this:

```json
{
  "system_instruction": {
    "parts": {
      "text": "You are a helpful assistant"
    }
  },
  "contents": [
    {
      "role": "user",
      "parts": [
        {
          "text": "Write a poem about llamas"
        }
      ]
    }
  ]
}
```

This PR updates the Vertex client so that system messages are sent in a top-level `system_instruction` field instead of inside `contents`.

- [Gemini API
Docs](https://ai.google.dev/gemini-api/docs/text-generation?lang=rest#system-instructions)
- [Discord Thread (Bug
Report)](https://discord.com/channels/1119368998161752075/1326035112852328541/1326979978050408582)
<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Fixes Vertex client request body to correctly handle system
instructions and adds related tests.
> 
>   - **Behavior**:
> - Fixes request body format in `vertex_client.rs` to correctly handle
system instructions by separating them from user content.
> - Logs a warning if multiple system instructions are provided, using
only the last one.
>   - **Testing**:
> - Adds `TestVertexWithSystemInstructions` function in `vertex.baml`
and `async_client.py`.
> - Adds test case `should support vertex with system instructions` in
`vertex.test.ts`.
>   - **Misc**:
>     - Adds `CustomRetry` policy in `prompt_fiddle_example.baml`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 52350de. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Jan 12, 2025
1 parent 1ea1d8b commit 4b7db0f
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 871 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,17 @@ client<llm> GPT4 {

// LLM client `GPT4o`: uses the `openai` provider with model `gpt-4o`.
client<llm> GPT4o {
provider openai
// Attaches the `CustomRetry` retry policy declared in this file.
retry_policy CustomRetry
options {
model gpt-4o
// API key is supplied through `env.OPENAI_API_KEY`.
api_key env.OPENAI_API_KEY
}
}

// Retry policy referenced by the `GPT4o` client; caps retries at 3.
retry_policy CustomRetry {
max_retries 3
}

client<llm> Claude {
provider anthropic
options {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,16 +547,34 @@ impl ToProviderMessageExt for VertexClient {
&self,
chat: &[RenderedChatMessage],
) -> Result<serde_json::Map<String, serde_json::Value>> {
// merge all adjacent roles of the same type
let mut res = serde_json::Map::new();

res.insert(
"contents".into(),
chat.iter()
.map(|c| self.role_to_message(c))
.collect::<Result<Vec<_>>>()?
.into(),
);
// https://ai.google.dev/gemini-api/docs/text-generation?lang=rest#system-instructions
let mut system_instructions = vec![];
let mut contents = vec![];

for rendered_chat_message in chat {
let mut message = self.role_to_message(rendered_chat_message)?;

if rendered_chat_message.role == "system" {
// No role here.
message.remove("role");
system_instructions.push(message);
} else {
// User-Model chat.
contents.push(message);
}
}

if let Some(system_instruction) = system_instructions.pop() {
res.insert("system_instruction".into(), system_instruction.into());
}

if !system_instructions.is_empty() {
log::warn!("Vertex API only supports one system instruction, using last one and ignoring the rest");
}

res.insert("contents".into(), contents.into());

Ok(res)
}
Expand Down
1 change: 0 additions & 1 deletion integ-tests/baml_src/clients.baml
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,3 @@ client<llm> TogetherAi {
model "meta-llama/Llama-3-70b-chat-hf"
}
}

9 changes: 8 additions & 1 deletion integ-tests/baml_src/test-files/providers/vertex.baml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,11 @@ function TestVertex(input: string) -> string {
prompt #"
Write a nice short story about {{ input }}
"#
}
}

// Integration-test function for the `Vertex` client: the prompt consists of a
// "system" role message followed by a "user" role message, and the function
// returns the model output as a string. Takes no arguments.
function TestVertexWithSystemInstructions() -> string {
client Vertex
prompt #"{{_.role("system")}} You are a helpful assistant
{{_.role("user")}} Write a poem about llamas
"#
}
52 changes: 52 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3086,6 +3086,29 @@ async def TestVertex(
)
return cast(str, raw.cast_to(types, types))

async def TestVertexWithSystemInstructions(
    self,
    baml_options: BamlCallOptions = {},
) -> str:
    """Invoke the BAML function ``TestVertexWithSystemInstructions`` asynchronously.

    Args:
        baml_options: Optional call options. Only the ``"tb"`` and
            ``"client_registry"`` entries are read here.

    Returns:
        The runtime result cast to ``str``.
    """
    type_builder = baml_options.get("tb", None)
    # Pass the builder's underlying `_tb` object to the runtime, or None when no builder was given.
    tb = None if type_builder is None else type_builder._tb  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    # The function takes no arguments, hence the empty argument mapping.
    raw = await self.__runtime.call_function(
        "TestVertexWithSystemInstructions",
        {},
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )
    return cast(str, raw.cast_to(types, types))

async def UnionTest_Function(
self,
input: Union[str, bool],
Expand Down Expand Up @@ -7154,6 +7177,35 @@ def TestVertex(
self.__ctx_manager.get(),
)

def TestVertexWithSystemInstructions(
    self,
    baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[str], str]:
    """Start a streaming call to the BAML function ``TestVertexWithSystemInstructions``.

    Args:
        baml_options: Optional call options. Only the ``"tb"`` and
            ``"client_registry"`` entries are read here.

    Returns:
        A ``baml_py.BamlStream`` wrapping the raw runtime stream, with one
        converter for partial results and one for the final result.
    """
    type_builder = baml_options.get("tb", None)
    # Pass the builder's underlying `_tb` object to the runtime, or None when no builder was given.
    tb = None if type_builder is None else type_builder._tb  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    # The function takes no arguments, hence the empty argument mapping.
    raw = self.__runtime.stream_function(
        "TestVertexWithSystemInstructions",
        {},
        None,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    def to_partial(x):
        # Partial chunks are cast against the partial type definitions.
        return cast(Optional[str], x.cast_to(types, partial_types))

    def to_final(x):
        # The final value is cast against the full type definitions.
        return cast(str, x.cast_to(types, types))

    return baml_py.BamlStream[Optional[str], str](
        raw,
        to_partial,
        to_final,
        self.__ctx_manager.get(),
    )

def UnionTest_Function(
self,
input: Union[str, bool],
Expand Down
4 changes: 2 additions & 2 deletions integ-tests/python/baml_client/inlinedbaml.py

Large diffs are not rendered by default.

52 changes: 52 additions & 0 deletions integ-tests/python/baml_client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3083,6 +3083,29 @@ def TestVertex(
)
return cast(str, raw.cast_to(types, types))

def TestVertexWithSystemInstructions(
    self,
    baml_options: BamlCallOptions = {},
) -> str:
    """Invoke the BAML function ``TestVertexWithSystemInstructions`` synchronously.

    Args:
        baml_options: Optional call options. Only the ``"tb"`` and
            ``"client_registry"`` entries are read here.

    Returns:
        The runtime result cast to ``str``.
    """
    type_builder = baml_options.get("tb", None)
    # Pass the builder's underlying `_tb` object to the runtime, or None when no builder was given.
    tb = None if type_builder is None else type_builder._tb  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    # The function takes no arguments, hence the empty argument mapping.
    raw = self.__runtime.call_function_sync(
        "TestVertexWithSystemInstructions",
        {},
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )
    return cast(str, raw.cast_to(types, types))

def UnionTest_Function(
self,
input: Union[str, bool],
Expand Down Expand Up @@ -7152,6 +7175,35 @@ def TestVertex(
self.__ctx_manager.get(),
)

def TestVertexWithSystemInstructions(
    self,
    baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[Optional[str], str]:
    """Start a synchronous streaming call to ``TestVertexWithSystemInstructions``.

    Args:
        baml_options: Optional call options. Only the ``"tb"`` and
            ``"client_registry"`` entries are read here.

    Returns:
        A ``baml_py.BamlSyncStream`` wrapping the raw runtime stream, with one
        converter for partial results and one for the final result.
    """
    type_builder = baml_options.get("tb", None)
    # Pass the builder's underlying `_tb` object to the runtime, or None when no builder was given.
    tb = None if type_builder is None else type_builder._tb  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    # The function takes no arguments, hence the empty argument mapping.
    raw = self.__runtime.stream_function_sync(
        "TestVertexWithSystemInstructions",
        {},
        None,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    def to_partial(x):
        # Partial chunks are cast against the partial type definitions.
        return cast(Optional[str], x.cast_to(types, partial_types))

    def to_final(x):
        # The final value is cast against the full type definitions.
        return cast(str, x.cast_to(types, types))

    return baml_py.BamlSyncStream[Optional[str], str](
        raw,
        to_partial,
        to_final,
        self.__ctx_manager.get(),
    )

def UnionTest_Function(
self,
input: Union[str, bool],
Expand Down
6 changes: 6 additions & 0 deletions integ-tests/python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,12 @@ async def test_should_work_with_vertex():
assert_that("donkey kong" in res.lower())


@pytest.mark.asyncio
async def test_should_work_with_vertex_adding_system_instructions():
    """Vertex must return a non-empty completion when the prompt has a system message."""
    res = await b.TestVertexWithSystemInstructions()
    # Use a plain `assert` so the condition is actually enforced. A bare
    # `assert_that(<bool>)` with no chained matcher would only construct an
    # assertion object under assertpy semantics and never fail.
    # NOTE(review): confirm `assert_that` in this module is assertpy's.
    assert len(res) > 0, "expected a non-empty response from Vertex"


@pytest.mark.asyncio
async def test_should_work_with_image_base64():
res = await b.TestImageInput(img=baml_py.Image.from_base64("image/png", image_b64))
Expand Down
67 changes: 67 additions & 0 deletions integ-tests/ruby/baml_client/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4274,6 +4274,38 @@ def TestVertex(
(raw.parsed_using_types(Baml::Types))
end

# Calls the BAML function "TestVertexWithSystemInstructions" through the runtime
# and returns the result parsed with Baml::Types.
#
# Positional arguments are rejected; `baml_options` may only contain the keys
# :client_registry and :tb. Raises ArgumentError otherwise.
sig {
  params(
    varargs: T.untyped,
    baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)]
  ).returns(String)
}
def TestVertexWithSystemInstructions(*varargs, baml_options: {})
  raise ArgumentError, "TestVertexWithSystemInstructions may only be called with keyword arguments" if varargs.any?

  raise ArgumentError, "Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}" if (baml_options.keys - [:client_registry, :tb]).any?

  # The type builder (when given) is passed to the runtime as its @registry object.
  type_registry = baml_options[:tb]&.instance_variable_get(:@registry)

  # The function takes no arguments, hence the empty hash.
  raw = @runtime.call_function(
    "TestVertexWithSystemInstructions",
    {},
    @ctx_manager,
    type_registry,
    baml_options[:client_registry],
  )
  raw.parsed_using_types(Baml::Types)
end

sig {
params(
varargs: T.untyped,
Expand Down Expand Up @@ -9035,6 +9067,41 @@ def TestVertex(
)
end

# Starts a streaming call to the BAML function "TestVertexWithSystemInstructions"
# and wraps the raw runtime stream in a Baml::BamlStream.
#
# Positional arguments are rejected; `baml_options` may only contain the keys
# :client_registry and :tb. Raises ArgumentError otherwise.
sig {
  params(
    varargs: T.untyped,
    baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)]
  ).returns(Baml::BamlStream[String])
}
def TestVertexWithSystemInstructions(*varargs, baml_options: {})
  raise ArgumentError, "TestVertexWithSystemInstructions may only be called with keyword arguments" if varargs.any?

  raise ArgumentError, "Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}" if (baml_options.keys - [:client_registry, :tb]).any?

  # The type builder (when given) is passed to the runtime as its @registry object.
  type_registry = baml_options[:tb]&.instance_variable_get(:@registry)

  # The function takes no arguments, hence the empty hash.
  raw = @runtime.stream_function(
    "TestVertexWithSystemInstructions",
    {},
    @ctx_manager,
    type_registry,
    baml_options[:client_registry],
  )
  Baml::BamlStream[T.nilable(String), String].new(ffi_stream: raw, ctx_manager: @ctx_manager)
end

sig {
params(
varargs: T.untyped,
Expand Down
Loading

0 comments on commit 4b7db0f

Please sign in to comment.