Skip to content

Commit

Permalink
feat: batch requests low level
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Jan 21, 2025
1 parent a87fba6 commit 05fb492
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 18 deletions.
11 changes: 11 additions & 0 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,17 @@ def stream_request( # type: ignore[empty-body]
An iterator of items.
"""

@raises_not_implemented
def batch_requests(self, requests: list[dict]) -> Any:
"""
Send batched requests (multiple requests at once) to the RPC provider.
Args:
requests (list[dict]): The requests to send.
Returns: The results of each request.
"""

# TODO: In 0.9, delete this method.
def get_storage_at(self, *args, **kwargs) -> "HexBytes":
warnings.warn(
Expand Down
65 changes: 47 additions & 18 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

import ijson # type: ignore
import requests
import requests as requests_lib
from eth_pydantic_types import HexBytes
from eth_typing import BlockNumber, HexStr
from eth_utils import add_0x_prefix, is_hex, to_hex
Expand Down Expand Up @@ -1333,7 +1333,7 @@ def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any:
def _make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any:
parameters = parameters or []
try:
result = self.web3.provider.make_request(RPCEndpoint(rpc), parameters)
response = self.web3.provider.make_request(RPCEndpoint(rpc), parameters)
except HTTPError as err:
if "method not allowed" in str(err).lower():
raise APINotImplementedError(
Expand All @@ -1345,28 +1345,36 @@ def _make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any:

raise ProviderError(str(err)) from err

if "error" in result:
error = result["error"]
return self._get_result_from_rpc_response(rpc, response)

def _get_result_from_rpc_response(
self, rpc: str, response: dict, raise_on_failure: bool = True
) -> Any:
if "error" in response:
error = response["error"]
message = (
error["message"] if isinstance(error, dict) and "message" in error else str(error)
)
if raise_on_failure:
if (
"does not exist/is not available" in str(message)
or re.match(r"[m|M]ethod .*?not found", message)
or message.startswith("Unknown RPC Endpoint")
or "RPC Endpoint has not been implemented" in message
):
raise APINotImplementedError(
f"RPC method '{rpc}' is not implemented by this node instance."
)

if (
"does not exist/is not available" in str(message)
or re.match(r"[m|M]ethod .*?not found", message)
or message.startswith("Unknown RPC Endpoint")
or "RPC Endpoint has not been implemented" in message
):
raise APINotImplementedError(
f"RPC method '{rpc}' is not implemented by this node instance."
)
raise ProviderError(message)

raise ProviderError(message)
else:
return message

elif "result" in result:
return result.get("result", {})
elif "result" in response:
return response.get("result", {})

return result
return response

def stream_request(self, method: str, params: Iterable, iter_path: str = "result.item"):
if not (uri := self.http_uri):
Expand All @@ -1375,14 +1383,35 @@ def stream_request(self, method: str, params: Iterable, iter_path: str = "result
payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params}
results = ijson.sendable_list()
coroutine = ijson.items_coro(results, iter_path)
resp = requests.post(uri, json=payload, stream=True)
resp = requests_lib.post(uri, json=payload, stream=True)
resp.raise_for_status()

for chunk in resp.iter_content(chunk_size=2**17):
coroutine.send(chunk)
yield from results
del results[:]

def batch_requests(self, requests: list[dict]) -> Any:
if not (uri := self.http_uri):
raise ProviderError("This provider has no HTTP URI and is unable to batch requests.")

for idx, request in enumerate(requests):
if "jsonrpc" not in request:
request["jsonrpc"] = "2.0"
if "id" not in request:
request["id"] = idx + 1

response = requests_lib.post(uri, json=requests)
try:
response.raise_for_status()
except HTTPError as err:
raise ProviderError(str(err)) from err

return [
self._get_result_from_rpc_response(uri, r, raise_on_failure=False)
for r in response.json()
]

def create_access_list(
self, transaction: TransactionAPI, block_id: Optional["BlockID"] = None
) -> list[AccessList]:
Expand Down
19 changes: 19 additions & 0 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,25 @@ def test_make_request_not_exists(geth_provider):
geth_provider.make_request("ape_thisDoesNotExist")


@geth_process_test
def test_batch_requests(geth_account, geth_contract, geth_provider):
call = geth_contract.myNumber.as_transaction().model_dump()
results = geth_provider.batch_requests(
[
{
"method": "eth_call",
"params": [{"data": call["data"], "to": geth_contract.address}],
},
{
"method": "eth_getBalance",
"params": [geth_account.address, "latest"],
},
]
)
for result in results:
assert result.startswith("0x")


@geth_process_test
@pytest.mark.parametrize(
"message",
Expand Down

0 comments on commit 05fb492

Please sign in to comment.