Skip to content

Commit

Permalink
chore: use of handle_all_response_variants macro instead of handle_re…
Browse files Browse the repository at this point in the history
…sponse_variants (#3277)

commit-id:fa953022
  • Loading branch information
lev-starkware authored Jan 14, 2025
1 parent dd2af61 commit 2245c5e
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 170 deletions.
86 changes: 1 addition & 85 deletions crates/papyrus_proc_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,7 @@ use std::str::FromStr;

use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream, Result};
use syn::{
parse_macro_input,
ExprLit,
Ident,
ItemFn,
ItemTrait,
LitBool,
LitStr,
Meta,
Token,
TraitItem,
};
use syn::{parse_macro_input, ExprLit, Ident, ItemFn, ItemTrait, LitBool, LitStr, Meta, TraitItem};

/// This macro is a wrapper around the "rpc" macro supplied by the jsonrpsee library that generates
/// a server and client traits from a given trait definition. The wrapper gets a version id and
Expand Down Expand Up @@ -182,78 +170,6 @@ pub fn latency_histogram(attr: TokenStream, input: TokenStream) -> TokenStream {
modified_function.to_token_stream().into()
}

struct HandleResponseVariantsMacroInput {
response_enum: Ident,
request_response_enum_var: Ident,
component_client_error: Ident,
component_error: Ident,
}

impl Parse for HandleResponseVariantsMacroInput {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let response_enum = input.parse()?;
input.parse::<Token![,]>()?;
let request_response_enum_var = input.parse()?;
input.parse::<Token![,]>()?;
let component_client_error = input.parse()?;
input.parse::<Token![,]>()?;
let component_error = input.parse()?;
Ok(HandleResponseVariantsMacroInput {
response_enum,
request_response_enum_var,
component_client_error,
component_error,
})
}
}

/// A macro for generating code that handles the received communication response.
/// Takes the following arguments:
/// * response_enum -- the response enum type
/// * request_response_enum_var -- the request/response enum variant corresponding to the invoked
/// function
/// * component_client_error -- the component client error type
/// * component_error -- the component error type
///
/// For example, the following code:
/// ```rust,ignore
/// handle_response_variants!(MempoolResponse, GetTransactions, MempoolClientError, MempoolError)
/// ``````
///
/// Results in:
/// ```rust,ignore
/// match response {
/// MempoolResponse::GetTransactions(Ok(response)) => Ok(response),
/// MempoolResponse::GetTransactions(Err(response)) => {
/// Err(MempoolClientError::MempoolError(response))
/// }
/// unexpected_response => Err(MempoolClientError::ClientError(
/// ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
/// )),
/// }
/// ```
#[proc_macro]
pub fn handle_response_variants(input: TokenStream) -> TokenStream {
let HandleResponseVariantsMacroInput {
response_enum,
request_response_enum_var,
component_client_error,
component_error,
} = parse_macro_input!(input as HandleResponseVariantsMacroInput);

let expanded = quote! {
match response? {
#response_enum::#request_response_enum_var(Ok(response)) => Ok(response),
#response_enum::#request_response_enum_var(Err(response)) => {
Err(#component_client_error::#component_error(response))
}
unexpected_response => Err(#component_client_error::ClientError(ClientError::UnexpectedResponse(format!("{unexpected_response:?}")))),
}
};

TokenStream::from(expanded)
}

struct HandleAllResponseVariantsMacroInput {
response_enum: Ident,
request_response_enum_var: Ident,
Expand Down
62 changes: 41 additions & 21 deletions crates/starknet_batcher_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use async_trait::async_trait;
#[cfg(any(feature = "testing", test))]
use mockall::automock;
use papyrus_proc_macros::handle_response_variants;
use papyrus_proc_macros::handle_all_response_variants;
use serde::{Deserialize, Serialize};
use starknet_sequencer_infra::component_client::{
ClientError,
Expand Down Expand Up @@ -120,58 +120,73 @@ where
{
async fn propose_block(&self, input: ProposeBlockInput) -> BatcherClientResult<()> {
let request = BatcherRequest::ProposeBlock(input);
let response = self.send(request).await;
handle_response_variants!(BatcherResponse, ProposeBlock, BatcherClientError, BatcherError)
handle_all_response_variants!(
BatcherResponse,
ProposeBlock,
BatcherClientError,
BatcherError,
Direct
)
}

async fn get_proposal_content(
&self,
input: GetProposalContentInput,
) -> BatcherClientResult<GetProposalContentResponse> {
let request = BatcherRequest::GetProposalContent(input);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
BatcherResponse,
GetProposalContent,
BatcherClientError,
BatcherError
BatcherError,
Direct
)
}

async fn validate_block(&self, input: ValidateBlockInput) -> BatcherClientResult<()> {
let request = BatcherRequest::ValidateBlock(input);
let response = self.send(request).await;
handle_response_variants!(BatcherResponse, ValidateBlock, BatcherClientError, BatcherError)
handle_all_response_variants!(
BatcherResponse,
ValidateBlock,
BatcherClientError,
BatcherError,
Direct
)
}

async fn send_proposal_content(
&self,
input: SendProposalContentInput,
) -> BatcherClientResult<SendProposalContentResponse> {
let request = BatcherRequest::SendProposalContent(input);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
BatcherResponse,
SendProposalContent,
BatcherClientError,
BatcherError
BatcherError,
Direct
)
}

async fn start_height(&self, input: StartHeightInput) -> BatcherClientResult<()> {
let request = BatcherRequest::StartHeight(input);
let response = self.send(request).await;
handle_response_variants!(BatcherResponse, StartHeight, BatcherClientError, BatcherError)
handle_all_response_variants!(
BatcherResponse,
StartHeight,
BatcherClientError,
BatcherError,
Direct
)
}

async fn get_height(&self) -> BatcherClientResult<GetHeightResponse> {
let request = BatcherRequest::GetCurrentHeight;
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
BatcherResponse,
GetCurrentHeight,
BatcherClientError,
BatcherError
BatcherError,
Direct
)
}

Expand All @@ -180,18 +195,23 @@ where
input: DecisionReachedInput,
) -> BatcherClientResult<DecisionReachedResponse> {
let request = BatcherRequest::DecisionReached(input);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
BatcherResponse,
DecisionReached,
BatcherClientError,
BatcherError
BatcherError,
Direct
)
}

async fn add_sync_block(&self, sync_block: SyncBlock) -> BatcherClientResult<()> {
let request = BatcherRequest::AddSyncBlock(sync_block);
let response = self.send(request).await;
handle_response_variants!(BatcherResponse, AddSyncBlock, BatcherClientError, BatcherError)
handle_all_response_variants!(
BatcherResponse,
AddSyncBlock,
BatcherClientError,
BatcherError,
Direct
)
}
}
26 changes: 13 additions & 13 deletions crates/starknet_class_manager_types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub mod transaction_converter;
use std::sync::Arc;

use async_trait::async_trait;
use papyrus_proc_macros::handle_response_variants;
use papyrus_proc_macros::handle_all_response_variants;
use serde::{Deserialize, Serialize};
use starknet_api::contract_class::ContractClass;
use starknet_api::core::{ClassHash, CompiledClassHash};
Expand Down Expand Up @@ -90,12 +90,12 @@ where
class: Class,
) -> ClassManagerClientResult<ExecutableClassHash> {
let request = ClassManagerRequest::AddClass(class_id, class);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
ClassManagerResponse,
AddClass,
ClassManagerClientError,
ClassManagerError
ClassManagerError,
Direct
)
}

Expand All @@ -105,34 +105,34 @@ where
class: DeprecatedClass,
) -> ClassManagerClientResult<()> {
let request = ClassManagerRequest::AddDeprecatedClass(class_id, class);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
ClassManagerResponse,
AddDeprecatedClass,
ClassManagerClientError,
ClassManagerError
ClassManagerError,
Direct
)
}

async fn get_executable(&self, class_id: ClassId) -> ClassManagerClientResult<ExecutableClass> {
let request = ClassManagerRequest::GetExecutable(class_id);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
ClassManagerResponse,
GetExecutable,
ClassManagerClientError,
ClassManagerError
ClassManagerError,
Direct
)
}

async fn get_sierra(&self, class_id: ClassId) -> ClassManagerClientResult<Class> {
let request = ClassManagerRequest::GetSierra(class_id);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
ClassManagerResponse,
GetSierra,
ClassManagerClientError,
ClassManagerError
ClassManagerError,
Direct
)
}
}
11 changes: 8 additions & 3 deletions crates/starknet_gateway_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use async_trait::async_trait;
#[cfg(any(feature = "testing", test))]
use mockall::automock;
use papyrus_proc_macros::handle_response_variants;
use papyrus_proc_macros::handle_all_response_variants;
use serde::{Deserialize, Serialize};
use starknet_api::transaction::TransactionHash;
use starknet_sequencer_infra::component_client::{
Expand Down Expand Up @@ -62,7 +62,12 @@ where
#[instrument(skip(self))]
async fn add_tx(&self, gateway_input: GatewayInput) -> GatewayClientResult<TransactionHash> {
let request = GatewayRequest::AddTransaction(gateway_input);
let response = self.send(request).await;
handle_response_variants!(GatewayResponse, AddTransaction, GatewayClientError, GatewayError)
handle_all_response_variants!(
GatewayResponse,
AddTransaction,
GatewayClientError,
GatewayError,
Direct
)
}
}
14 changes: 7 additions & 7 deletions crates/starknet_l1_provider_types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use async_trait::async_trait;
#[cfg(any(feature = "testing", test))]
use mockall::automock;
use papyrus_base_layer::L1Event;
use papyrus_proc_macros::handle_response_variants;
use papyrus_proc_macros::handle_all_response_variants;
use serde::{Deserialize, Serialize};
use starknet_api::block::BlockNumber;
use starknet_api::executable_transaction::L1HandlerTransaction;
Expand Down Expand Up @@ -70,24 +70,24 @@ where
height: BlockNumber,
) -> L1ProviderClientResult<Vec<L1HandlerTransaction>> {
let request = L1ProviderRequest::GetTransactions { n_txs, height };
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
L1ProviderResponse,
GetTransactions,
L1ProviderClientError,
L1ProviderError
L1ProviderError,
Direct
)
}

#[instrument(skip(self))]
async fn add_events(&self, events: Vec<Event>) -> L1ProviderClientResult<()> {
let request = L1ProviderRequest::AddEvents(events);
let response = self.send(request).await;
handle_response_variants!(
handle_all_response_variants!(
L1ProviderResponse,
AddEvents,
L1ProviderClientError,
L1ProviderError
L1ProviderError,
Direct
)
}

Expand Down
Loading

0 comments on commit 2245c5e

Please sign in to comment.