Skip to content

Commit

Permalink
perf: utilize chain ID cache on re-connect in Ethereum node provider (#…
Browse files Browse the repository at this point in the history
…2464)

Co-authored-by: antazoey <[email protected]>
  • Loading branch information
antazoey and antazoey authored Jan 10, 2025
1 parent 61d5a57 commit 298d690
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 15 deletions.
1 change: 0 additions & 1 deletion src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,6 @@ def chain_id(self) -> int:
**NOTE**: Unless overridden, returns same as
:py:attr:`ape.api.providers.ProviderAPI.chain_id`.
"""

return self.provider.chain_id

@property
Expand Down
1 change: 0 additions & 1 deletion src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,6 @@ def chain_id(self) -> int:
The blockchain ID.
See `ChainList <https://chainlist.org/>`__ for a comprehensive list of IDs.
"""

network_name = self.provider.network.name
if network_name not in self._chain_id_map:
self._chain_id_map[network_name] = self.provider.chain_id
Expand Down
28 changes: 16 additions & 12 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,28 @@ def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] =
@cached_property
def chain_id(self) -> int:
default_chain_id = None
if self.network.name != "custom" and not self.network.is_dev:
# If using a live network, the chain ID is hardcoded.
if self.network.name not in ("adhoc", "custom") and not self.network.is_dev:
# If using a live plugin-based network, the chain ID is hardcoded.
default_chain_id = self.network.chain_id

try:
if hasattr(self.web3, "eth"):
return self.web3.eth.chain_id
return self._get_chain_id()

except ProviderNotConnectedError:
if default_chain_id is not None:
return default_chain_id

raise # Original error

except ValueError as err:
# Possible syncing error.
raise ProviderError(
err.args[0].get("message")
if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict)))
else "Error getting chain ID."
)

if default_chain_id is not None:
return default_chain_id

Expand All @@ -606,6 +614,10 @@ def priority_fee(self) -> int:
"eth_maxPriorityFeePerGas not supported in this RPC. Please specify manually."
) from err

def _get_chain_id(self) -> int:
result = self.make_request("eth_chainId", [])
return result if isinstance(result, int) else int(result, 16)

def get_block(self, block_id: "BlockID") -> BlockAPI:
if isinstance(block_id, str) and block_id.isnumeric():
block_id = int(block_id)
Expand Down Expand Up @@ -1603,15 +1615,7 @@ def _complete_connect(self):
if not self.network.is_dev:
self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy)

# Check for chain errors, including syncing
try:
chain_id = self.web3.eth.chain_id
except ValueError as err:
raise ProviderError(
err.args[0].get("message")
if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict)))
else "Error getting chain id."
)
chain_id = self.chain_id

# NOTE: We have to check both earliest and latest
# because if the chain was _ever_ PoA, we need
Expand Down
9 changes: 8 additions & 1 deletion tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ def test_connect_to_chain_that_started_poa(mock_web3, web3_factory, ethereum):
to fetch blocks during the PoA portion of the chain.
"""
mock_web3.eth.get_block.side_effect = ExtraDataLengthError
mock_web3.eth.chain_id = ethereum.sepolia.chain_id

def make_request(rpc, arguments):
if rpc == "eth_chainId":
return {"result": ethereum.sepolia.chain_id}

return None

mock_web3.provider.make_request.side_effect = make_request
web3_factory.return_value = mock_web3
provider = ethereum.sepolia.get_provider("node")
provider.provider_settings = {"uri": "http://node.example.com"} # fake
Expand Down
58 changes: 58 additions & 0 deletions tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,35 @@ def test_chain_id_is_cached(eth_tester_provider):
eth_tester_provider._web3 = web3 # Undo


def test_chain_id_from_ethereum_base_provider_is_cached(mock_web3, ethereum, eth_tester_provider):
"""
Simulated chain ID from a plugin (using base-ethereum class) to ensure is
also cached.
"""

def make_request(rpc, arguments):
if rpc == "eth_chainId":
return {"result": 11155111} # Sepolia

return eth_tester_provider.make_request(rpc, arguments)

mock_web3.provider.make_request.side_effect = make_request

class PluginProvider(Web3Provider):
def connect(self):
return

def disconnect(self):
return

provider = PluginProvider(name="sim", network=ethereum.sepolia)
provider._web3 = mock_web3
assert provider.chain_id == 11155111
# Unset to web3 to prove it does not check it again (else it would fail).
provider._web3 = None
assert provider.chain_id == 11155111


def test_chain_id_when_disconnected(eth_tester_provider):
eth_tester_provider.disconnect()
try:
Expand Down Expand Up @@ -658,3 +687,32 @@ def test_update_settings_invalidates_snapshots(eth_tester_provider, chain):
assert snapshot in chain._snapshots[eth_tester_provider.chain_id]
eth_tester_provider.update_settings({})
assert snapshot not in chain._snapshots[eth_tester_provider.chain_id]


def test_connect_uses_cached_chain_id(mocker, mock_web3, ethereum, eth_tester_provider):
class PluginProvider(EthereumNodeProvider):
pass

web3_factory_patch = mocker.patch("ape_ethereum.provider._create_web3")
web3_factory_patch.return_value = mock_web3

class ChainIDTracker:
call_count = 0

def make_request(self, rpc, args):
if rpc == "eth_chainId":
self.call_count += 1
return {"result": "0xaa36a7"} # Sepolia

return eth_tester_provider.make_request(rpc, args)

chain_id_tracker = ChainIDTracker()
mock_web3.provider.make_request.side_effect = chain_id_tracker.make_request

provider = PluginProvider(name="node", network=ethereum.sepolia)
provider.connect()
assert chain_id_tracker.call_count == 1
provider.disconnect()
provider.connect()
# It is still cached from the previous connection.
assert chain_id_tracker.call_count == 1

0 comments on commit 298d690

Please sign in to comment.