From ac479384625a0da79c2fe93a93c3cfd8ff452f98 Mon Sep 17 00:00:00 2001 From: ranfdev Date: Mon, 27 Nov 2023 23:13:03 +0100 Subject: [PATCH] Implement this_client() --- capnp-rpc/src/local.rs | 29 +++++++++++++++++++++--- capnp-rpc/test/impls.rs | 40 ++++++++++++++++++++++++++++++++- capnp-rpc/test/test.capnp | 5 +++++ capnp-rpc/test/test.rs | 19 ++++++++++++++++ capnp/src/private/capability.rs | 5 +++++ capnpc/src/codegen.rs | 4 ++++ 6 files changed, 98 insertions(+), 4 deletions(-) diff --git a/capnp-rpc/src/local.rs b/capnp-rpc/src/local.rs index 50ac0d128..c83ca1122 100644 --- a/capnp-rpc/src/local.rs +++ b/capnp-rpc/src/local.rs @@ -368,15 +368,38 @@ where let inner = self.inner.clone(); Promise::from_future(async move { let f = { + let this_hook = Self { inner }; + + // We need to provide a way to get the corresponding ClientHook of a Server. + // Passing the ClientHook inside `Params` would require a lot of changes, breaking compatiblity + // with existing code and ruining the developer experience, because `Params` would end up containing more generics. + // Instead, I'm passing the `ClientHook` through a global static variable, CURRENT_THIS. + // This operation will be called for every method call, so we cannot allocate a `Box` everytime. + // To save us from allocating a `Box` even when we don't access CURRENT_THIS, we set the static variable to a closure + // returning `Box`, making the allocation "lazy". + let prev = unsafe { + // This is a gimmick to make Rust happy. We can only set static values on a `static` variable. + // `static_hook` is not actually static, because it's referencing a variable on the stack. + // Before `this_hook` is dropped, we remove `static_hook` from CURRENT_THIS. + let static_hook = &this_hook as *const (dyn ClientHook + 'static); + capnp::private::capability::CURRENT_THIS.replace(Some(&{ + move || (&*static_hook as &dyn ClientHook).add_ref() + } + as *const (dyn Fn() -> Box<(dyn ClientHook + 'static)> + 'static))) + }; + let server = &mut *this_hook.inner.borrow_mut(); + // We put this borrow_mut() inside a block to avoid a potential // double borrow during f.await - let server = &mut *inner.borrow_mut(); - server.dispatch_call( + let f = server.dispatch_call( interface_id, method_id, ::capnp::capability::Params::new(params), ::capnp::capability::Results::new(results), - ) + ); + + capnp::private::capability::CURRENT_THIS.replace(prev); + f }; f.await }) diff --git a/capnp-rpc/test/impls.rs b/capnp-rpc/test/impls.rs index 71571bb50..c2dbeb8f6 100644 --- a/capnp-rpc/test/impls.rs +++ b/capnp-rpc/test/impls.rs @@ -21,7 +21,7 @@ use crate::test_capnp::{ bootstrap, test_call_order, test_capability_server_set, test_extends, test_handle, - test_interface, test_more_stuff, test_pipeline, + test_interface, test_more_stuff, test_pipeline, test_recursive_client_factorial, }; use capnp::capability::Promise; @@ -113,6 +113,16 @@ impl bootstrap::Server for Bootstrap { .set_cap(capnp_rpc::new_client(TestCapabilityServerSet::new())); Promise::ok(()) } + fn test_recursive_client_factorial( + &mut self, + _params: bootstrap::TestRecursiveClientFactorialParams, + mut results: bootstrap::TestRecursiveClientFactorialResults, + ) -> Promise<(), Error> { + results.get().set_cap(capnp_rpc::new_client( + TestRecursiveClientFactorial::default(), + )); + Promise::ok(()) + } } #[derive(Default)] @@ -656,3 +666,31 @@ impl test_capability_server_set::Server for TestCapabilityServerSet { }) } } + +#[derive(Default)] +pub struct TestRecursiveClientFactorial {} + +impl test_recursive_client_factorial::Server for TestRecursiveClientFactorial { + fn fact( + &mut self, + params: test_recursive_client_factorial::FactParams, + mut results: test_recursive_client_factorial::FactResults, + ) -> Promise<(), Error> { + // the points is to test `this_client()` + let client = self.this_client(); + Promise::from_future(async move { + let n = params.get()?.get_n(); + + let res_number = if n <= 1 { + n + } else { + let mut req = client.fact_request(); + req.get().set_n(n - 1); + let res = req.send().promise.await?; + n * res.get()?.get_res() + }; + results.get().set_res(res_number); + Ok(()) + }) + } +} diff --git a/capnp-rpc/test/test.capnp b/capnp-rpc/test/test.capnp index cbd34a932..701211b39 100644 --- a/capnp-rpc/test/test.capnp +++ b/capnp-rpc/test/test.capnp @@ -79,6 +79,7 @@ interface Bootstrap { testCallOrder @4 () -> (cap: TestCallOrder); testMoreStuff @5 () -> (cap: TestMoreStuff); testCapabilityServerSet @6 () -> (cap: TestCapabilityServerSet); + testRecursiveClientFactorial @7 () -> (cap: TestRecursiveClientFactorial); } interface TestInterface { @@ -183,3 +184,7 @@ interface TestCapabilityServerSet { createHandle @0 () -> (handle :Handle); checkHandle @1 (handle: Handle) -> (isOurs :Bool); } + +interface TestRecursiveClientFactorial { + fact @0 (n: Int32) -> (res :Int32); +} diff --git a/capnp-rpc/test/test.rs b/capnp-rpc/test/test.rs index f0d6268bd..abbfb9221 100644 --- a/capnp-rpc/test/test.rs +++ b/capnp-rpc/test/test.rs @@ -1014,3 +1014,22 @@ fn capability_server_set_rpc() { Ok(()) }) } +#[test] +fn recursive_client() { + rpc_top_level(|_spawner, client| async move { + let response1 = client + .test_recursive_client_factorial_request() + .send() + .promise + .await?; + let client1 = response1.get()?.get_cap()?; + + let mut req = client1.fact_request(); + req.get().set_n(4); + + let res = req.send().promise.await?; + assert_eq!(res.get()?.get_res(), 24); + + Ok(()) + }) +} diff --git a/capnp/src/private/capability.rs b/capnp/src/private/capability.rs index ba725fb60..96668dc9d 100644 --- a/capnp/src/private/capability.rs +++ b/capnp/src/private/capability.rs @@ -22,11 +22,16 @@ #![cfg(feature = "alloc")] use alloc::boxed::Box; use alloc::vec::Vec; +use core::cell::RefCell; use crate::any_pointer; use crate::capability::{Params, Promise, RemotePromise, Request, Results}; use crate::MessageSize; +thread_local! { + pub static CURRENT_THIS: RefCell Box>> = Default::default(); +} + pub trait ResponseHook { fn get(&self) -> crate::Result>; } diff --git a/capnpc/src/codegen.rs b/capnpc/src/codegen.rs index 9449718eb..372ad9ced 100644 --- a/capnpc/src/codegen.rs +++ b/capnpc/src/codegen.rs @@ -2745,6 +2745,10 @@ fn generate_node( params.params, server_base, params.where_clause )), indent(server_interior), + indent(line(format!("fn this_client<'a>(&'a mut self) -> Client<{}> {{", params.params))), // are these generics ALWAYS the same as the Client generics? It seems like so. + indent(indent(Line(format!("::capnp::private::capability::CURRENT_THIS.with_borrow(|curr_this| as ::capnp::capability::FromClientHook>::new(", params.params)))), // Should replace ::capnp with {capnp} maybe? + indent(indent(Line(fmt!(ctx, "unsafe {{&*curr_this.unwrap() as &dyn Fn() -> Box}}()))")))), + indent(line("}")), line("}"), ]));