Skip to content

Commit

Permalink
Sync again
Browse files Browse the repository at this point in the history
  • Loading branch information
tclem committed Jan 10, 2024
1 parent bdaf9e4 commit 716c187
Show file tree
Hide file tree
Showing 12 changed files with 520 additions and 491 deletions.
800 changes: 463 additions & 337 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Add the `twirp-build` crate as a build dependency in your `Cargo.toml` (you'll n
# Cargo.toml
[build-dependencies]
twirp-build = "0.1"
prost-build = "0.11"
prost-build = "0.12"
```

Add a `build.rs` file to your project to compile the protos and generate Rust code:
Expand Down
2 changes: 1 addition & 1 deletion crates/twirp-build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ authors = ["The blackbird team <[email protected]>"]
edition = "2021"

[dependencies]
prost-build = "0.11"
prost-build = "0.12"
4 changes: 2 additions & 2 deletions crates/twirp-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ where
// Define: <METHOD>
writeln!(
buf,
" async fn {}(&self, req: {}) -> Result<{}, twirp::client::ClientError>;",
" async fn {}(&self, req: {}) -> Result<{}, twirp::ClientError>;",
m.name, m.input_type, m.output_type,
)
.unwrap();
Expand All @@ -96,7 +96,7 @@ where
// Define the rpc `<METHOD>`
writeln!(
buf,
" async fn {}(&self, req: {}) -> Result<{}, twirp::client::ClientError> {{",
" async fn {}(&self, req: {}) -> Result<{}, twirp::ClientError> {{",
m.name, m.input_type, m.output_type,
)
.unwrap();
Expand Down
8 changes: 4 additions & 4 deletions crates/twirp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ server = ["dep:hyper"]

[dependencies]
futures = "0.3"
prost = "0.11"
prost = "0.12"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"

# For the client feature
reqwest = { version = "0.11", optional = true }
url = { version = "2.4", optional = true }
reqwest = { version = "0.11", features = ["default", "gzip", "json"], optional = true }
url = { version = "2.5", optional = true }

# For the server feature
hyper = { version = "0.14", features = ["full"], optional = true }

# For the test-support feature
async-trait = { version = "0.1", optional = true }
tokio = { version = "1.28", features = [], optional = true }
tokio = { version = "1.33", features = [], optional = true }
26 changes: 11 additions & 15 deletions crates/twirp/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use std::sync::Arc;

use async_trait::async_trait;
use hyper::header;
use hyper::header::{InvalidHeaderValue, CONTENT_TYPE};
use hyper::StatusCode;
use thiserror::Error;
use url::Url;

use crate::headers::*;
use crate::{error::*, to_proto_body};
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{to_proto_body, TwirpErrorResponse};

/// A twirp client error.
#[derive(Debug, Error)]
pub enum ClientError {
#[error(transparent)]
InvalidHeader(#[from] header::InvalidHeaderValue),
InvalidHeader(#[from] InvalidHeaderValue),
#[error("base_url must end in /, but got: {0}")]
InvalidBaseUrl(Url),
#[error(transparent)]
Expand All @@ -39,9 +38,8 @@ pub enum ClientError {
TwirpError(TwirpErrorResponse),
}

pub type Result<T> = core::result::Result<T, ClientError>;
pub type Result<T, E = ClientError> = std::result::Result<T, E>;

/// Use ClientBuilder to build a TwirpClient.
pub struct ClientBuilder {
base_url: Url,
http_client: reqwest::Client,
Expand All @@ -52,15 +50,11 @@ impl ClientBuilder {
pub fn new(base_url: Url, http_client: reqwest::Client) -> Self {
Self {
base_url,
http_client,
middleware: vec![],
http_client,
}
}

pub fn from_base_url(base_url: Url) -> Self {
Self::new(base_url, reqwest::Client::default())
}

/// Add middleware to the client that will be called on each request.
/// Middlewares are invoked in the order they are added as part of the
/// request cycle.
Expand Down Expand Up @@ -127,7 +121,7 @@ impl Client {
/// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
/// you create one and **reuse** it.
pub fn from_base_url(base_url: Url) -> Result<Self> {
Self::new(base_url, reqwest::Client::default(), vec![])
Self::new(base_url, reqwest::Client::new(), vec![])
}

/// Add middleware to this specific request stack. Middlewares are invoked
Expand All @@ -151,7 +145,7 @@ impl Client {
let req = self
.http_client
.post(url)
.header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
.header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
.body(to_proto_body(body))
.build()?;

Expand All @@ -161,8 +155,9 @@ impl Client {

// These have to be extracted because reading the body consumes `Response`.
let status = resp.status();
let content_type = resp.headers().get(header::CONTENT_TYPE).cloned();
let content_type = resp.headers().get(CONTENT_TYPE).cloned();

// TODO: Include more info in the error cases: request path, content-type, etc.
match (status, content_type) {
(status, Some(ct)) if status.is_success() && ct.as_bytes() == CONTENT_TYPE_PROTOBUF => {
O::decode(resp.bytes().await?).map_err(|e| e.into())
Expand Down Expand Up @@ -268,6 +263,7 @@ mod tests {
#[tokio::test]
async fn test_routes() {
let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();

let client = ClientBuilder::new(base_url, reqwest::Client::new())
.with(AssertRouting {
expected_url: "http://localhost:3001/twirp/test.TestAPI/Ping",
Expand Down
8 changes: 4 additions & 4 deletions crates/twirp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ pub mod server;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

pub use client::*;
pub use error::*;
pub use server::*;
pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result};
pub use error::*; // many constructors like `invalid_argument()`
pub use server::{serve, Router};

// Re-export `reqwest` so that it's easy to implement middleware.
pub use reqwest;

// Re-export `url` so that the generated code works without additional dependencies beyond just the `twirp` crate.
// Re-export `url so that the generated code works without additional dependencies beyond just the `twirp` crate.
pub use url;

pub(crate) fn to_proto_body<T>(m: T) -> hyper::Body
Expand Down
113 changes: 13 additions & 100 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,21 @@ use futures::Future;
use hyper::{header, Body, Method, Request, Response};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::time::{Duration, Instant};

use crate::error::*;
use crate::headers::*;
use crate::to_proto_body;
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{error, to_proto_body, GenericError, TwirpErrorResponse};

/// A function that handles a request and returns a response.
type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerResponse + Send + Sync>;

/// Type alias for a handler response.
type HandlerResponse =
Box<dyn Future<Output = Result<Response<Body>, GenericError>> + Unpin + Send>;

/// A Router maps a request (method, path) tuple to a handler.
type HandlerFn = Box<dyn Fn(Request<Body>) -> HandlerResponse + Send + Sync>;

/// A Router maps a request to a handler.
pub struct Router {
routes: HashMap<(Method, String), HandlerFn>,
prefix: &'static str,
}

/// The canonical twirp path prefix. You don't have to use this, but it's the default.
pub const DEFAULT_TWIRP_PATH_PREFIX: &str = "/twirp";

impl Default for Router {
Expand Down Expand Up @@ -53,7 +48,7 @@ impl Router {
}
}

/// Adds a sync handler to the router for the given method and path.
/// Adds a handler to the router for the given method and path.
pub fn add_sync_handler<F>(&mut self, method: Method, path: &str, f: F)
where
F: Fn(Request<Body>) -> Result<Response<Body>, GenericError>
Expand Down Expand Up @@ -101,53 +96,35 @@ impl Router {
> {
let f = f.clone();
Box::new(Box::pin(async move {
let mut timings = *req
.extensions()
.get::<Timings>()
.expect("timings must exist");
timings.request_received();
match parse_request(req).await {
Ok((req, resp_fmt)) => {
timings.request_parsed();
let res = f(req).await;
timings.response_handled();
write_response(res, resp_fmt)
}
Ok((req, resp_fmt)) => write_response(f(req).await, resp_fmt),
Err(err) => {
// This is the only place we use tracing (would be nice to remove)
// tracing::error!(?err, "failed to parse request");
// TODO: We don't want to loose the underlying error
// here, but it might not be safe to include in the
// response like this always.
timings.request_parsed();
let mut twirp_err = malformed("bad request");
let mut twirp_err = error::malformed("bad request");
twirp_err.insert_meta("error".to_string(), err.to_string());
twirp_err.to_response()
}
}
.map(|mut resp| {
timings.response_written();
resp.extensions_mut().insert(timings);
resp
})
}))
};
let key = (Method::POST, [self.prefix, path].join("/"));
self.routes.insert(key, Box::new(g));
}
}

/// Serve a request using the given router.
pub async fn serve(
router: Arc<Router>,
mut req: Request<Body>,
req: Request<Body>,
) -> Result<Response<Body>, GenericError> {
req.extensions_mut().insert(Timings::default());
let key = (req.method().clone(), req.uri().path().to_string());
if let Some(handler) = router.routes.get(&key) {
handler(req).await
} else {
bad_route("not found").to_response()
error::bad_route("not found").to_response()
}
}

Expand Down Expand Up @@ -214,70 +191,6 @@ where
Ok(res)
}

#[derive(Debug, Clone, Copy)]
pub struct Timings {
// When the request started.
pub start: Instant,
// When the request was received.
pub request_received: Option<Instant>,
// When the request body was parsed.
pub request_parsed: Option<Instant>,
// When the response handler returned.
pub response_handled: Option<Instant>,
// When the response was written.
pub response_written: Option<Instant>,
}

impl Default for Timings {
fn default() -> Self {
Self::new(Instant::now())
}
}

impl Timings {
pub fn new(start: Instant) -> Self {
Self {
start,
request_received: None,
request_parsed: None,
response_handled: None,
response_written: None,
}
}

pub fn received(&self) -> Option<Duration> {
self.request_received.map(|x| x - self.start)
}

fn request_received(&mut self) {
self.request_received = Some(Instant::now());
}

pub fn parsed(&self) -> Option<Duration> {
self.request_parsed.map(|x| x - self.start)
}

fn request_parsed(&mut self) {
self.request_parsed = Some(Instant::now());
}

pub fn handled(&self) -> Option<Duration> {
self.response_handled.map(|x| x - self.start)
}

fn response_handled(&mut self) {
self.response_handled = Some(Instant::now());
}

pub fn written(&self) -> Option<Duration> {
self.response_written.map(|x| x - self.start)
}

fn response_written(&mut self) {
self.response_written = Some(Instant::now());
}
}

#[cfg(test)]
mod tests {

Expand All @@ -290,7 +203,7 @@ mod tests {
let req = Request::get("/nothing").body(Body::empty()).unwrap();
let resp = serve(router, req).await.unwrap();
let data = read_err_body(resp.into_body()).await;
assert_eq!(data, bad_route("not found"));
assert_eq!(data, error::bad_route("not found"));
}

#[tokio::test]
Expand Down Expand Up @@ -326,7 +239,7 @@ mod tests {
// TODO: I think malformed should return some info about what was wrong
// with the request, but we don't want to leak server errors that have
// other details.
let mut expected = malformed("bad request");
let mut expected = error::malformed("bad request");
expected.insert_meta(
"error".to_string(),
"EOF while parsing a value at line 1 column 0".to_string(),
Expand All @@ -347,6 +260,6 @@ mod tests {
let resp = serve(router, req).await.unwrap();
assert!(resp.status().is_server_error(), "{:?}", resp);
let data = read_err_body(resp.into_body()).await;
assert_eq!(data, internal("boom!"));
assert_eq!(data, error::internal("boom!"));
}
}
Loading

0 comments on commit 716c187

Please sign in to comment.