From b1b32c060101aa8c5f2b64a63de5cc51453cce0c Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Sun, 12 Jan 2025 00:17:17 +0100 Subject: [PATCH] Enable ALPN protocol in `reqwest` --- .devcontainer/Dockerfile | 3 +- engine/Cargo.lock | 100 ++++++++++++------ engine/Cargo.toml | 3 +- engine/baml-runtime/Cargo.toml | 2 +- .../primitive/google/googleai_client.rs | 12 ++- .../src/internal/llm_client/traits/mod.rs | 15 +-- 6 files changed, 87 insertions(+), 48 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index b9a041dd7..aaef43ddb 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -4,6 +4,7 @@ FROM ubuntu:24.04 RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y \ build-essential \ clang \ + pkg-config \ zlib1g \ zlib1g-dev \ libssl-dev \ @@ -29,7 +30,7 @@ RUN curl https://mise.run | sh \ && ~/.local/bin/mise --verbose use -g --yes ruby@3.1 node@20.14 pnpm@9.9 \ && echo 'export PATH="$HOME/.local/bin:$HOME/.local/share/mise/shims:$PATH"' >> ~/.bash_profile \ && echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/share/mise/installs/ruby/3.1/lib' >> ~/.bash_profile - + # Install Rust RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/engine/Cargo.lock b/engine/Cargo.lock index 71ba45d1a..a672d3395 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -735,7 +735,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -780,7 +780,7 @@ dependencies = [ "pin-project-lite", "serde", "serde_json", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -1261,9 +1261,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "bytes-utils" @@ -2446,9 +2446,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", @@ -2459,7 +2459,6 @@ dependencies = [ "pin-project-lite", "socket2 0.5.7", "tokio", - "tower", "tower-service", "tracing", ] @@ -3448,9 +3447,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "300.3.1+3.3.1" +version = "300.4.1+3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7259953d42a81bf137fbbd73bd30a8e1914d6dce43c2b90ed575783a22608b91" +checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c" dependencies = [ "cc", ] @@ -4081,9 +4080,9 @@ checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] name = "reqwest" -version = "0.12.5" +version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", @@ -4115,13 +4114,14 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-util", + "tower 0.5.2", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", "web-sys", - "winreg", + "windows-registry", ] [[package]] @@ -4778,23 +4778,26 @@ name = "sync_wrapper" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] [[package]] name = "system-configuration" -version = "0.5.1" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", "core-foundation", "system-configuration-sys", ] [[package]] name = "system-configuration-sys" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" dependencies = [ "core-foundation-sys", "libc", @@ -5081,17 +5084,32 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 1.0.1", + "tokio", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -5614,6 +5632,36 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -5771,16 +5819,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "winreg" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" -dependencies = [ - "cfg-if", - "windows-sys 0.48.0", -] - [[package]] name = "winsafe" version = "0.0.19" diff --git a/engine/Cargo.toml b/engine/Cargo.toml index e454bcbbd..b26571a55 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -70,9 +70,10 @@ minijinja = { version = "1.0.16", default-features = false, features = [ pretty_assertions = "1.4.1" rand = "0.8.5" regex = "1.10.4" -reqwest = { version = "0.12.5", features = [ +reqwest = { version = "0.12.12", features = [ "json", "native-tls-vendored", + "native-tls-alpn", "stream", ] } scopeguard = "1.2.0" diff --git a/engine/baml-runtime/Cargo.toml b/engine/baml-runtime/Cargo.toml index fdd11b554..fabaeff76 100644 --- a/engine/baml-runtime/Cargo.toml +++ b/engine/baml-runtime/Cargo.toml @@ -103,7 +103,7 @@ colored = { version = "2.1.0", default-features = false, features = [ ] } futures-timer = { version = "3.0.3", features = ["wasm-bindgen"] } js-sys = "0.3.69" -reqwest = { version = "0.12.5", features = ["stream", "json"] } +reqwest = { version = "0.12.12", features = ["stream", "json", "native-tls-vendored", "native-tls-alpn"] } # send_wrapper = { version = "0.6.0", features = ["futures"] } serde-wasm-bindgen = "0.6.5" diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs index b92fa1f56..8fd314642 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs @@ -168,7 +168,11 @@ impl SseResponseTrait for GoogleAIClient { if let Some(choice) = event.candidates.get(0) { let part_index = content_part(&model_id); - if let Some(content) = choice.content.as_ref().and_then(|c| c.parts.get(part_index)) { + if let Some(content) = choice + .content + .as_ref() + .and_then(|c| c.parts.get(part_index)) + { inner.content += &content.text; } if let Some(FinishReason::Stop) = choice.finish_reason.as_ref() { @@ -458,11 +462,11 @@ impl ToProviderMessage for GoogleAIClient { /// The Google Gemini 2 model has an experimental feature /// called Flash Thinking Mode, which is turned on in a particular /// named model: gemini-2.0-flash-thinking-exp-1219 -/// +/// /// When run in this mode, Gemini returns `candidates` with 2 parts each. /// Part 0 is the chain of thought, part 1 is the actual output. /// Other Gemini models put the output data in part 0. -/// +/// /// TODO: Explicitly represent Flash Thinking Mode response and /// do more thorough checking for the content part. /// For examples of how to introspect the response more safely, see: @@ -473,4 +477,4 @@ fn content_part(model_name: &str) -> usize { } else { 0 } -} \ No newline at end of file +} diff --git a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs index 8d28d766b..f37457e94 100644 --- a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs @@ -150,8 +150,8 @@ where { #[allow(async_fn_in_trait)] async fn single_call(&self, ctx: &RuntimeContext, prompt: &RenderedPrompt) -> LLMResponse { - if let RenderedPrompt::Chat(chat) = &prompt { - match process_media_urls( + match prompt { + RenderedPrompt::Chat(chat) => match process_media_urls( self.model_features().resolve_media_urls, true, None, @@ -160,15 +160,10 @@ where ) .await { - Ok(messages) => return self.chat(ctx, &messages).await, - Err(e) => { - return LLMResponse::InternalFailure(format!("Error occurred:\n\n{:?}", e)) - } - } - } + Ok(messages) => self.chat(ctx, &messages).await, + Err(e) => LLMResponse::InternalFailure(format!("Error occurred:\n\n{:?}", e)), + }, - match prompt { - RenderedPrompt::Chat(p) => self.chat(ctx, p).await, RenderedPrompt::Completion(p) => self.completion(ctx, p).await, } }