Skip to content

Commit

Permalink
Merge pull request #8 from github/jorendorff/axum-encapsulate
Browse files Browse the repository at this point in the history
Tidy up after hyper 1.x update
  • Loading branch information
jorendorff authored Jan 25, 2024
2 parents 4fc11f2 + 538ee51 commit eee91a5
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 60 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ mod haberdash {
include!(concat!(env!("OUT_DIR"), "/service.haberdash.v1.rs"));
}

use axum::Router;
use haberdash::{MakeHatRequest, MakeHatResponse};

#[tokio::main]
Expand Down Expand Up @@ -87,6 +88,10 @@ impl haberdash::HaberdasherAPI for HaberdasherAPIServer {
}
```

This code creates an `axum::Router`, then hands it off to `axum::serve()` to handle networking.
This use of `axum::serve` is optional. After building `app`, you can instead invoke it from any
`hyper`-based server by importing `twirp::tower::Service` and doing `app.call(request).await`.

## Usage (client side)

On the client side, you also get a generated twirp client (based on the rpc endpoints in your proto). Include the generated code, create a client, and start making rpc calls:
Expand Down
26 changes: 7 additions & 19 deletions crates/twirp-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
let service_fqn = format!("{}.{}", service.package, service_name);
writeln!(buf).unwrap();

writeln!(buf, "pub const SERVICE_FQN: &str = \"{service_fqn}\";").unwrap();
writeln!(buf, "pub const SERVICE_FQN: &str = \"/{service_fqn}\";").unwrap();

//
// generate the twirp server
Expand All @@ -43,37 +43,25 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
where
T: {service_name} + Send + Sync + 'static,
{{
twirp::Router::new()"#,
twirp::details::TwirpRouterBuilder::new(api)"#,
)
.unwrap();
for m in &service.methods {
let uri = &m.proto_name;
let req_type = &m.input_type;
let rust_method_name = &m.name;
writeln!(
buf,
r#" .route(
"/{uri}",
twirp::details::post(
|twirp::details::State(api): twirp::details::State<std::sync::Arc<T>>,
req: twirp::details::Request| async move {{
twirp::server::handle_request(
req,
move |req| async move {{
api.{rust_method_name}(req).await
}},
)
.await
}},
),
)"#,
r#" .route("/{uri}", |api: std::sync::Arc<T>, req: {req_type}| async move {{
api.{rust_method_name}(req).await
}})"#,
)
.unwrap();
}
writeln!(
buf,
r#"
.with_state(api)
.fallback(twirp::server::not_found_handler)
.build()
}}"#
)
.unwrap();
Expand Down
60 changes: 54 additions & 6 deletions crates/twirp/src/details.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,58 @@
//! Undocumented features that are public for use in generated code (see `twirp-build`).
#[doc(hidden)]
pub use axum::extract::{Request, State};
use std::future::Future;

#[doc(hidden)]
pub use axum::routing::post;
use axum::extract::{Request, State};
use axum::Router;

#[doc(hidden)]
pub use axum::response::Response;
use crate::{server, TwirpErrorResponse};

/// Builder object used by generated code to build a Twirp service.
///
/// The type `S` is something like `Arc<MyExampleAPIServer>`, which can be cheaply cloned for each
/// incoming request, providing access to the Rust value that actually implements the RPCs.
pub struct TwirpRouterBuilder<S> {
service: S,
router: Router<S>,
}

impl<S> TwirpRouterBuilder<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn new(service: S) -> Self {
TwirpRouterBuilder {
service,
router: Router::new(),
}
}

/// Add a handler for an `rpc` to the router.
///
/// The generated code passes a closure that calls the method, like
/// `|api: Arc<HaberdasherAPIServer>, req: MakeHatRequest| async move { api.make_hat(req) }`.
pub fn route<F, Fut, Req, Res>(self, url: &str, f: F) -> Self
where
F: Fn(S, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Res, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Res: prost::Message + serde::Serialize,
{
TwirpRouterBuilder {
service: self.service,
router: self.router.route(
url,
axum::routing::post(move |State(api): State<S>, req: Request| async move {
server::handle_request(api, req, f).await
}),
),
}
}

/// Finish building the axum router.
pub fn build(self) -> axum::Router {
self.router
.fallback(crate::server::not_found_handler)
.with_state(self.service)
}
}
14 changes: 10 additions & 4 deletions crates/twirp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,24 @@ pub mod server;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

#[doc(hidden)]
pub mod details;

pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result};
pub use error::*; // many constructors like `invalid_argument()`
pub use server::{Router, Timings};

// Re-export `reqwest` so that it's easy to implement middleware.
// Re-export this crate's dependencies that users are likely to code against. These can be used to
// import the exact versions of these libraries `twirp` is built with -- useful if your project is
// so sprawling that it builds multiple versions of some crates.
pub use axum;
pub use reqwest;

// Re-export `url so that the generated code works without additional dependencies beyond just the `twirp` crate.
pub use tower;
pub use url;

/// Re-export of `axum::Router`, the type that encapsulates a server-side implementation of a Twirp
/// service.
pub use axum::Router;

pub(crate) fn serialize_proto_message<T>(m: T) -> Vec<u8>
where
T: prost::Message,
Expand Down
34 changes: 26 additions & 8 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
//! Support for serving Twirp APIs.
//!
//! There is not much to see in the documentation here. This API is meant to be used with
//! `twirp-build`. See <https://github.com/github/twirp-rs#usage> for details and an example.
use std::fmt::Debug;

use axum::body::Body;
use axum::response::IntoResponse;
pub use axum::Router;
use futures::Future;
use http_body_util::BodyExt;
use hyper::{header, Request, Response};
Expand All @@ -13,9 +17,6 @@ use tokio::time::{Duration, Instant};
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse};

/// The canonical twirp path prefix. You don't have to use this, but it's the default.
pub const DEFAULT_TWIRP_PATH_PREFIX: &str = "/twirp";

// TODO: Properly implement JsonPb (de)serialization as it is slightly different
// than standard JSON.
#[derive(Debug, Clone, Copy, Default)]
Expand All @@ -39,9 +40,13 @@ impl BodyFormat {
}

/// Entry point used in code generated by `twirp-build`.
pub async fn handle_request<F, Fut, Req, Resp>(req: Request<Body>, f: F) -> Response<Body>
pub(crate) async fn handle_request<S, F, Fut, Req, Resp>(
service: S,
req: Request<Body>,
f: F,
) -> Response<Body>
where
F: FnOnce(Req) -> Fut + Clone + Sync + Send + 'static,
F: FnOnce(S, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Resp, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Resp: prost::Message + serde::Serialize,
Expand All @@ -64,7 +69,7 @@ where
}
};

let res = f(req).await;
let res = f(service, req).await;
timings.set_response_handled();

let mut resp = match write_response(res, resp_fmt) {
Expand Down Expand Up @@ -111,7 +116,7 @@ where
BodyFormat::Pb => Response::builder()
.header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
.body(Body::from(serialize_proto_message(response)))?,
_ => {
BodyFormat::JsonPb => {
let data = serde_json::to_string(&response)?;
Response::builder()
.header(header::CONTENT_TYPE, CONTENT_TYPE_JSON)
Expand All @@ -126,6 +131,19 @@ where
/// Axum handler function that returns 404 Not Found with a Twirp JSON payload.
///
/// `axum::Router`'s default fallback handler returns a 404 Not Found with no body content.
/// Use this fallback instead for full Twirp compliance.
///
/// # Usage
///
/// ```
/// use axum::Router;
///
/// # fn build_app(twirp_routes: Router) -> Router {
/// let app = Router::new()
/// .nest("/twirp", twirp_routes)
/// .fallback(twirp::server::not_found_handler);
/// # app }
/// ```
pub async fn not_found_handler() -> Response<Body> {
error::bad_route("not found").into_response()
}
Expand Down
28 changes: 5 additions & 23 deletions crates/twirp/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use serde::de::DeserializeOwned;
use tokio::task::JoinHandle;
use tokio::time::Instant;

use crate::details::TwirpRouterBuilder;
use crate::server::Timings;
use crate::{error, Client, Result, TwirpErrorResponse};

Expand All @@ -28,35 +29,16 @@ pub fn test_api_router() -> Router {
let api = Arc::new(TestAPIServer {});

// NB: This part would be generated
let test_router = crate::Router::new()
let test_router = TwirpRouterBuilder::new(api)
.route(
"/Ping",
crate::details::post(
|crate::details::State(api): crate::details::State<Arc<TestAPIServer>>,
req: crate::details::Request| async move {
crate::server::handle_request(
req,
move |req| async move { api.ping(req).await },
)
.await
},
),
|api: Arc<TestAPIServer>, req: PingRequest| async move { api.ping(req).await },
)
.route(
"/Boom",
crate::details::post(
|crate::details::State(api): crate::details::State<Arc<TestAPIServer>>,
req: crate::details::Request| async move {
crate::server::handle_request(
req,
move |req| async move { api.boom(req).await },
)
.await
},
),
|api: Arc<TestAPIServer>, req: PingRequest| async move { api.boom(req).await },
)
.fallback(crate::server::not_found_handler)
.with_state(api);
.build();

axum::Router::new()
.nest("/twirp/test.TestAPI", test_router)
Expand Down

0 comments on commit eee91a5

Please sign in to comment.