Skip to content

Commit

Permalink
Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
tclem committed Nov 3, 2023
1 parent 0e1ba34 commit 56f62a0
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 30 deletions.
55 changes: 26 additions & 29 deletions crates/twirp/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::sync::Arc;

use async_trait::async_trait;
use hyper::header::{InvalidHeaderValue, CONTENT_TYPE};
use hyper::http::HeaderValue;
use hyper::{HeaderMap, StatusCode};
use hyper::header;
use hyper::StatusCode;
use thiserror::Error;
use url::Url;

Expand All @@ -14,7 +13,7 @@ use crate::{error::*, to_proto_body};
#[derive(Debug, Error)]
pub enum ClientError {
#[error(transparent)]
InvalidHeader(#[from] InvalidHeaderValue),
InvalidHeader(#[from] header::InvalidHeaderValue),
#[error("base_url must end in /, but got: {0}")]
InvalidBaseUrl(Url),
#[error(transparent)]
Expand Down Expand Up @@ -45,19 +44,23 @@ pub type Result<T> = core::result::Result<T, ClientError>;
/// Use ClientBuilder to build a TwirpClient.
pub struct ClientBuilder {
base_url: Url,
builder: reqwest::ClientBuilder,
http_client: reqwest::Client,
middleware: Vec<Arc<dyn Middleware>>,
}

impl ClientBuilder {
pub fn new(base_url: Url) -> Self {
pub fn new(base_url: Url, http_client: reqwest::Client) -> Self {
Self {
base_url,
builder: reqwest::ClientBuilder::default(),
http_client,
middleware: vec![],
}
}

pub fn from_base_url(base_url: Url) -> Self {
Self::new(base_url, reqwest::Client::default())
}

/// Add middleware to the client that will be called on each request.
/// Middlewares are invoked in the order they are added as part of the
/// request cycle.
Expand All @@ -69,21 +72,13 @@ impl ClientBuilder {
mw.push(Arc::new(middleware));
Self {
base_url: self.base_url,
builder: self.builder,
http_client: self.http_client,
middleware: mw,
}
}

pub fn with_client_builder(self, builder: reqwest::ClientBuilder) -> Self {
Self {
base_url: self.base_url,
builder,
middleware: self.middleware,
}
}

pub fn build(self) -> Result<Client> {
Client::new(self.base_url, self.builder, self.middleware)
Client::new(self.base_url, self.http_client, self.middleware)
}
}

Expand All @@ -92,15 +87,15 @@ impl ClientBuilder {
#[derive(Clone)]
pub struct Client {
pub base_url: Arc<Url>,
client: Arc<reqwest::Client>,
http_client: Arc<reqwest::Client>,
middlewares: Vec<Arc<dyn Middleware>>,
}

impl std::fmt::Debug for Client {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TwirpClient")
.field("base_url", &self.base_url)
.field("client", &self.client)
.field("client", &self.http_client)
.field("middlewares", &self.middlewares.len())
.finish()
}
Expand All @@ -113,16 +108,13 @@ impl Client {
/// you create one and **reuse** it.
pub fn new(
base_url: Url,
b: reqwest::ClientBuilder,
http_client: reqwest::Client,
middlewares: Vec<Arc<dyn Middleware>>,
) -> Result<Self> {
if base_url.path().ends_with('/') {
let mut headers: HeaderMap<HeaderValue> = HeaderMap::default();
headers.insert(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF.try_into()?);
let client = b.default_headers(headers).build()?;
Ok(Client {
base_url: Arc::new(base_url),
client: Arc::new(client),
http_client: Arc::new(http_client),
middlewares,
})
} else {
Expand All @@ -135,7 +127,7 @@ impl Client {
/// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
/// you create one and **reuse** it.
pub fn from_base_url(base_url: Url) -> Result<Self> {
Self::new(base_url, reqwest::ClientBuilder::default(), vec![])
Self::new(base_url, reqwest::Client::default(), vec![])
}

/// Add middleware to this specific request stack. Middlewares are invoked
Expand All @@ -156,15 +148,20 @@ impl Client {
O: prost::Message + Default,
{
let path = url.path().to_string();
let req = self.client.post(url).body(to_proto_body(body)).build()?;
let req = self
.http_client
.post(url)
.header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
.body(to_proto_body(body))
.build()?;

// Create and execute the middleware handlers
let next = Next::new(&self.client, &self.middlewares);
let next = Next::new(&self.http_client, &self.middlewares);
let resp = next.run(req).await?;

// These have to be extracted because reading the body consumes `Response`.
let status = resp.status();
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
let content_type = resp.headers().get(header::CONTENT_TYPE).cloned();

match (status, content_type) {
(status, Some(ct)) if status.is_success() && ct.as_bytes() == CONTENT_TYPE_PROTOBUF => {
Expand Down Expand Up @@ -271,7 +268,7 @@ mod tests {
#[tokio::test]
async fn test_routes() {
let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();
let client = ClientBuilder::new(base_url)
let client = ClientBuilder::new(base_url, reqwest::Client::new())
.with(AssertRouting {
expected_url: "http://localhost:3001/twirp/test.TestAPI/Ping",
})
Expand Down
16 changes: 16 additions & 0 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ impl Router {
self.routes.insert(key, Box::new(g));
}

/// Adds an async handler to the router for the given method and path.
pub fn add_async_handler<F, Fut>(&mut self, method: Method, path: &str, f: F)
where
F: Fn(Request<Body>) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Response<Body>, GenericError>> + Send,
{
let g = move |req| -> Box<
dyn Future<Output = Result<Response<Body>, GenericError>> + Unpin + Send,
> {
let f = f.clone();
Box::new(Box::pin(async move { f(req).await }))
};
let key = (method, path.to_string());
self.routes.insert(key, Box::new(g));
}

/// Adds a twirp method handler to the router for the given path.
pub fn add_method<F, Fut, Req, Resp>(&mut self, path: &str, f: F)
where
Expand Down
2 changes: 1 addition & 1 deletion example/src/bin/example-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub async fn main() -> Result<(), GenericError> {
eprintln!("{:?}", resp);

// customize the client with middleware
let client = ClientBuilder::new(Url::parse("http://xyz:3000/twirp/")?)
let client = ClientBuilder::from_base_url(Url::parse("http://xyz:3000/twirp/")?)
.with(RequestHeaders { hmac_key: None })
.build()?;
let resp = client
Expand Down

0 comments on commit 56f62a0

Please sign in to comment.