Skip to content

Commit

Permalink
fix: 429 issue on chain ID request #100 (#101)
Browse files Browse the repository at this point in the history
Co-authored-by: antazoey <[email protected]>
  • Loading branch information
antazoey and antazoey authored Jan 8, 2025
1 parent 963b637 commit c87afb1
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 46 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ repos:
additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic, flake8-type-checking]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.1
hooks:
- id: mypy
additional_dependencies: [types-PyYAML, types-requests, types-setuptools, pydantic]

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.19
rev: 0.7.21
hooks:
- id: mdformat
additional_dependencies: [mdformat-gfm, mdformat-frontmatter, mdformat-pyproject]
Expand Down
41 changes: 16 additions & 25 deletions ape_infura/provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import os
import random
import time
from collections.abc import Callable
from functools import cached_property
from typing import Optional

from ape.api import UpstreamProvider
from ape.exceptions import ContractLogicError, ProviderError, VirtualMachineError
from ape.logging import logger
from ape.utils.rpc import request_with_retry
from ape_ethereum.provider import Web3Provider
from requests import HTTPError, Session
from requests import Session
from web3 import HTTPProvider, Web3
from web3.exceptions import ContractLogicError as Web3ContractLogicError
from web3.exceptions import ExtraDataLengthError
Expand Down Expand Up @@ -140,9 +138,21 @@ def ws_uri(self) -> Optional[str]:
def connection_str(self) -> str:
return self.uri

@property
@cached_property
def chain_id(self):
return _run_with_retry(lambda: self._web3.eth.chain_id)
return request_with_retry(
lambda: self._get_chain_id(),
max_retries=_MAX_REQUEST_RETRIES,
min_retry_delay=_REQUEST_RETRY_DELAY * 1_000,
)

def _get_chain_id(self):
result = self.make_request("eth_chainId", [])
if isinstance(result, int):
return result

# Is a hex.
return int(result, 16)

def connect(self):
session = _get_session()
Expand Down Expand Up @@ -232,22 +242,3 @@ def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMa

def _create_web3(http_provider: HTTPProvider) -> Web3:
return Web3(http_provider)


def _run_with_retry(
func: Callable, max_retries: int = _MAX_REQUEST_RETRIES, retry_delay: int = _REQUEST_RETRY_DELAY
):
retries = 0
while retries < max_retries:
try:
return func()
except HTTPError as err:
if err.response.status_code == 429:
logger.debug(f"429 Too Many Requests. Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
retries += 1
retry_delay += retry_delay
else:
raise # Re-raise non-429 HTTP errors

raise ProviderError(f"Exceeded maximum retries ({max_retries}).")
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ include = '\.pyi?$'

[tool.pytest.ini_options]
addopts = """
-p no:ape_test
--cov-branch
--cov-report term
--cov-report html
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
],
"lint": [
"black>=24.10.0,<25", # Auto-formatter and linter
"mypy>=1.13.0,<2", # Static type analyzer
"mypy>=1.14.1,<2", # Static type analyzer
"types-setuptools", # Needed for mypy type shed
"flake8>=7.1.1,<8", # Style linter
"flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code
"flake8-print>=5.0.0,<6", # Detect print statements left in code
"flake8-pydantic", # For detecting issues with Pydantic models
"flake8-type-checking", # Detect imports to move in/out of type-checking blocks
"isort>=5.13.2,<6", # Import sorting linter
"mdformat>=0.7.19", # Auto-formatter for markdown
"mdformat>=0.7.21", # Auto-formatter for markdown
"mdformat-gfm>=0.3.5", # Needed for formatting GitHub-flavored markdown
"mdformat-frontmatter>=0.4.1", # Needed for frontmatters-style headers in issue templates
"mdformat-pyproject>=0.0.2", # Allows configuring in pyproject.toml
Expand Down
16 changes: 0 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,10 @@
import ape
import pytest

from ape_infura import NETWORKS

NETWORK_SKIPS = ("starknet",)


@pytest.fixture
def accounts():
return ape.accounts


@pytest.fixture
def Contract():
return ape.Contract


@pytest.fixture
def networks():
return ape.networks


# NOTE: Using a `str` as param for better pytest test-case name generation.
@pytest.fixture(
params=[f"{e}:{n}" for e, values in NETWORKS.items() if e not in NETWORK_SKIPS for n in values]
Expand Down
47 changes: 47 additions & 0 deletions tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ def test_dynamic_poa_check(mocker):
infura = Infura(name=real.name, network=real.network)
patch = mocker.patch("ape_infura.provider._create_web3")
patch.return_value = mock_web3

def make_request(rpc, arguments):
if rpc == "eth_chainId":
return {"result": "0x4268"}

mock_web3.provider.make_request.side_effect = make_request

infura.connect()
mock_web3.middleware_onion.inject.assert_called_once_with(ExtraDataToPOAMiddleware, layer=0)

Expand All @@ -128,3 +135,43 @@ def test_api_secret():
session = _get_session()
assert session.auth == ("", "123")
del os.environ["WEB3_INFURA_PROJECT_SECRET"]


def test_chain_id(networks):
with networks.ethereum.sepolia.use_provider("infura") as infura:
assert infura.chain_id == 11155111

with networks.ethereum.holesky.use_provider("infura") as infura:
assert infura.chain_id == 17000


def test_chain_id_cached(mocker, networks):
"""
A test just showing we utilize a cached chain ID
to limit unnecessary requests.
"""

infura = networks.ethereum.sepolia.get_provider("infura")
infura.connect()

class ChainIdTracker:
call_count = 0

def make_request(self, rpc, arguments):
if rpc == "eth_chainId":
self.call_count += 1
return {"result": "0x4268"}

tracker = ChainIdTracker()
mock_web3 = mocker.MagicMock()
mock_web3.provider.make_request.side_effect = tracker.make_request
infura._web3 = mock_web3

# Start off fresh for the sake of the test.
infura.__dict__.pop("chain_id")

_ = infura.chain_id
_ = infura.chain_id # Call again!
_ = infura.chain_id # Once more!

assert tracker.call_count == 1

0 comments on commit c87afb1

Please sign in to comment.