diff --git a/setup.py b/setup.py index 76f40d5f..893b4409 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ "eth-pydantic-types", # Use same version as eth-ape "packaging", # Use same version as eth-ape "pydantic_settings", # Use same version as eth-ape + "quattro>=24.1,<25", "taskiq[metrics]>=0.11.3,<0.12", ], entry_points={ diff --git a/silverback/_cli.py b/silverback/_cli.py index 17781839..b17870c3 100644 --- a/silverback/_cli.py +++ b/silverback/_cli.py @@ -109,8 +109,9 @@ async def run_worker(broker: AsyncBroker, worker_count=2, shutdown_timeout=90): callback=_recorder_callback, ) @click.option("-x", "--max-exceptions", type=int, default=3) +@click.option("--debug", is_flag=True, default=False) @click.argument("path") -def run(cli_ctx, account, runner_class, recorder, max_exceptions, path): +def run(cli_ctx, account, runner_class, recorder, max_exceptions, debug, path): if not runner_class: # NOTE: Automatically select runner class if cli_ctx.provider.ws_uri: @@ -124,7 +125,7 @@ def run(cli_ctx, account, runner_class, recorder, max_exceptions, path): app = import_from_string(path) runner = runner_class(app, recorder=recorder, max_exceptions=max_exceptions) - asyncio.run(runner.run()) + asyncio.run(runner.run(), debug=debug) @cli.command(cls=ConnectedProviderCommand, help="Run Silverback application task workers") @@ -138,7 +139,11 @@ def run(cli_ctx, account, runner_class, recorder, max_exceptions, path): @click.option("-w", "--workers", type=int, default=2) @click.option("-x", "--max-exceptions", type=int, default=3) @click.option("-s", "--shutdown_timeout", type=int, default=90) +@click.option("--debug", is_flag=True, default=False) @click.argument("path") -def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, path): +def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, debug, path): app = import_from_string(path) - asyncio.run(run_worker(app.broker, worker_count=workers, shutdown_timeout=shutdown_timeout)) + asyncio.run( + run_worker(app.broker, worker_count=workers, shutdown_timeout=shutdown_timeout), + debug=debug, + ) diff --git a/silverback/runner.py b/silverback/runner.py index 3856b0e6..7d51306e 100644 --- a/silverback/runner.py +++ b/silverback/runner.py @@ -1,6 +1,8 @@ import asyncio +import atexit from abc import ABC, abstractmethod +import quattro from ape import chain from ape.logging import logger from ape.utils import ManagerAccessMixin @@ -109,6 +111,10 @@ async def _event_task(self, task_data: TaskData): handle an event handler task for the given contract event """ + def _shutdown(self): + asyncio.run(self.app.broker.shutdown(), debug=True) + logger.info("Application shutdown completed") + async def run(self): """ Run the task broker client for the assembled ``SilverbackApp`` application. @@ -124,6 +130,8 @@ async def run(self): """ # Initialize broker (run worker startup events) await self.app.broker.startup() + # NOTE: Always ensure we shutdown the broker no matter what + atexit.register(self._shutdown) # Obtain system configuration for worker result = await run_taskiq_task_wait_result( @@ -133,18 +141,18 @@ async def run(self): raise StartupFailure("Unable to determine system configuration of worker") # NOTE: Increase the specifier set here if there is a breaking change to this - if Version(result.return_value.sdk_version) not in SpecifierSet(">=0.5.0"): - # TODO: set to next breaking change release before release + if (sdk_version := Version(result.return_value.sdk_version)) not in SpecifierSet(">=0.5.0"): raise StartupFailure("Worker SDK version too old, please rebuild") if not ( system_tasks := set(TaskType(task_name) for task_name in result.return_value.task_types) ): + # NOTE: Guaranteed to be at least one because of `TaskType.SYSTEM_CONFIG` raise StartupFailure("No system tasks detected, startup failure") - # NOTE: Guaranteed to be at least one because of `TaskType.SYSTEM_CONFIG` + system_tasks_str = "\n- ".join(system_tasks) logger.info( - f"Worker using Silverback SDK v{result.return_value.sdk_version}" + f"Worker using Silverback SDK v{sdk_version}" f", available task types:\n- {system_tasks_str}" ) @@ -163,20 +171,18 @@ async def run(self): self.state = AppState(last_block_seen=-1, last_block_processed=-1) # Execute Silverback startup task before we init the rest - startup_taskdata_result = await run_taskiq_task_wait_result( - self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.STARTUP - ) - - if startup_taskdata_result.is_err: - raise StartupFailure(startup_taskdata_result.error) - - else: - startup_task_handlers = map( - self._create_task_kicker, startup_taskdata_result.return_value + if ( + startup_taskdata_result := await run_taskiq_task_wait_result( + self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.STARTUP ) + ).is_err: + raise StartupFailure(startup_taskdata_result.error) + elif startup_task_handlers := tuple( + map(self._create_task_kicker, startup_taskdata_result.return_value) + ): startup_task_results = await run_taskiq_task_group_wait_results( - (task_handler for task_handler in startup_task_handlers), self.state + startup_task_handlers, self.state ) if any(result.is_err for result in startup_task_results): @@ -187,21 +193,26 @@ async def run(self): elif self.recorder: converted_results = map(TaskResult.from_taskiq, startup_task_results) - await asyncio.gather(*(self.recorder.add_result(r) for r in converted_results)) + await quattro.gather(*(self.recorder.add_result(r) for r in converted_results)) - # NOTE: No need to handle results otherwise + # else: No need to handle results otherwise + + else: + logger.info("No startup tasks detected") # Create our long-running event listeners - new_block_taskdata_results = await run_taskiq_task_wait_result( - self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK - ) - if new_block_taskdata_results.is_err: + if ( + new_block_taskdata_results := await run_taskiq_task_wait_result( + self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK + ) + ).is_err: raise StartupFailure(new_block_taskdata_results.error) - event_log_taskdata_results = await run_taskiq_task_wait_result( - self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.EVENT_LOG - ) - if event_log_taskdata_results.is_err: + if ( + event_log_taskdata_results := await run_taskiq_task_wait_result( + self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.EVENT_LOG + ) + ).is_err: raise StartupFailure(event_log_taskdata_results.error) if ( @@ -212,50 +223,28 @@ async def run(self): raise NoTasksAvailableError() # NOTE: Any propagated failure in here should be handled such that shutdown tasks also run - # TODO: `asyncio.TaskGroup` added in Python 3.11 - listener_tasks = ( - *( - asyncio.create_task(self._block_task(task_def)) - for task_def in new_block_taskdata_results.return_value - ), - *( - asyncio.create_task(self._event_task(task_def)) - for task_def in event_log_taskdata_results.return_value - ), - ) - - # NOTE: Safe to do this because no tasks were actually scheduled to run - if len(listener_tasks) == 0: - raise NoTasksAvailableError() - - # Run until one task bubbles up an exception that should stop execution - tasks_with_errors, tasks_running = await asyncio.wait( - listener_tasks, return_when=asyncio.FIRST_EXCEPTION + exceptions_or_none = await quattro.gather( + *(self._block_task(task_def) for task_def in new_block_taskdata_results.return_value), + *(self._event_task(task_def) for task_def in event_log_taskdata_results.return_value), + return_exceptions=True, ) - if runtime_errors := "\n".join(str(task.exception()) for task in tasks_with_errors): - # NOTE: In case we are somehow not displaying the error correctly with task status - logger.debug(f"Runtime error(s) detected, shutting down:\n{runtime_errors}") - # Cancel any still running - (task.cancel() for task in tasks_running) - # NOTE: All listener tasks are shut down now + # NOTE: Result is either None or Exception + if err_msg := "\n\n".join(str(e) for e in exceptions_or_none if e): + logger.error(f"Runtime error(s) detected, shutting down:\n{err_msg}") # Execute Silverback shutdown task(s) before shutting down the broker and app - shutdown_taskdata_result = await run_taskiq_task_wait_result( - self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.SHUTDOWN - ) - - if shutdown_taskdata_result.is_err: - raise StartupFailure(shutdown_taskdata_result.error) - - else: - shutdown_task_handlers = map( - self._create_task_kicker, shutdown_taskdata_result.return_value + if ( + shutdown_taskdata_result := await run_taskiq_task_wait_result( + self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.SHUTDOWN ) + ).is_err: + raise RuntimeError(shutdown_taskdata_result.error) - shutdown_task_results = await run_taskiq_task_group_wait_results( - (task_handler for task_handler in shutdown_task_handlers) - ) + elif shutdown_task_handlers := tuple( + map(self._create_task_kicker, shutdown_taskdata_result.return_value) + ): + shutdown_task_results = await run_taskiq_task_group_wait_results(shutdown_task_handlers) if any(result.is_err for result in shutdown_task_results): errors_str = "\n".join( @@ -265,11 +254,14 @@ async def run(self): elif self.recorder: converted_results = map(TaskResult.from_taskiq, shutdown_task_results) - await asyncio.gather(*(self.recorder.add_result(r) for r in converted_results)) + await quattro.gather(*(self.recorder.add_result(r) for r in converted_results)) + + # else: No need to handle results otherwise - # NOTE: No need to handle results otherwise + else: + logger.info("No shutdown tasks detected") - await self.app.broker.shutdown() + # NOTE: atexit handles self.app.broker.shutdown() class WebsocketRunner(BaseRunner, ManagerAccessMixin): diff --git a/silverback/subscriptions.py b/silverback/subscriptions.py index d99a6488..fe836ffb 100644 --- a/silverback/subscriptions.py +++ b/silverback/subscriptions.py @@ -1,8 +1,10 @@ import asyncio import json +from collections import defaultdict from enum import Enum from typing import AsyncGenerator +import quattro from ape.logging import logger from websockets import ConnectionClosedError from websockets import client as ws_client @@ -28,7 +30,7 @@ def __init__(self, ws_provider_uri: str): # Stateful self._connection: ws_client.WebSocketClientProtocol | None = None self._last_request: int = 0 - self._subscriptions: dict[str, asyncio.Queue] = {} + self._subscriptions: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue) self._rpc_msg_buffer: list[dict] = [] self._ws_lock = asyncio.Lock() @@ -36,17 +38,17 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} uri={self._ws_provider_uri}>" async def __aenter__(self) -> "Web3SubscriptionsManager": - self.connection = await ws_client.connect(self._ws_provider_uri) + self._connection = await ws_client.connect(self._ws_provider_uri) return self def __aiter__(self) -> "Web3SubscriptionsManager": return self async def __anext__(self) -> str: - if not self.connection: + if not self._connection: raise StopAsyncIteration - message = await self.connection.recv() + message = await self._connection.recv() # TODO: Handle retries when connection breaks response = json.loads(message) @@ -56,9 +58,6 @@ async def __anext__(self) -> str: logger.debug(f"Corrupted subscription data: {response}") return response - if sub_id not in self._subscriptions: - self._subscriptions[sub_id] = asyncio.Queue() - await self._subscriptions[sub_id].put(sub_params.get("result", {})) else: @@ -94,7 +93,7 @@ async def _get_response(self, request_id: int) -> dict: raise RuntimeError("Timeout waiting for response.") async def subscribe(self, type: SubscriptionType, **filter_params) -> str: - if not self.connection: + if not self._connection: raise ValueError("Connection required.") if type is SubscriptionType.BLOCKS and filter_params: @@ -104,7 +103,7 @@ async def subscribe(self, type: SubscriptionType, **filter_params) -> str: "eth_subscribe", [type.value, filter_params] if type is SubscriptionType.EVENTS else [type.value], ) - await self.connection.send(json.dumps(request)) + await self._connection.send(json.dumps(request)) response = await self._get_response(request.get("id") or self._last_request) sub_id = response.get("result") @@ -116,24 +115,27 @@ async def subscribe(self, type: SubscriptionType, **filter_params) -> str: async def get_subscription_data(self, sub_id: str) -> AsyncGenerator[dict, None]: while True: - if not (queue := self._subscriptions.get(sub_id)) or queue.empty(): + if self._subscriptions[sub_id].empty(): async with self._ws_lock: # Keep pulling until a message comes to process # NOTE: Python <3.10 does not support `anext` function await self.__anext__() else: - yield await queue.get() + yield await self._subscriptions[sub_id].get() async def unsubscribe(self, sub_id: str) -> bool: if sub_id not in self._subscriptions: raise ValueError(f"Unknown sub_id '{sub_id}'") - if not self.connection: + if not self._connection: # Nothing to unsubscribe. return True request = self._create_request("eth_unsubscribe", [sub_id]) - await self.connection.send(json.dumps(request)) + try: + await self._connection.send(json.dumps(request)) + except ConnectionClosedError: + return False response = await self._get_response(request.get("id") or self._last_request) if success := response.get("result", False): @@ -142,16 +144,16 @@ async def unsubscribe(self, sub_id: str) -> bool: return success async def __aexit__(self, exc_type, exc, tb): - try: - # Try to gracefully unsubscribe to all events - await asyncio.gather(*(self.unsubscribe(sub_id) for sub_id in self._subscriptions)) - - except ConnectionClosedError: - pass # Websocket already closed (ctrl+C and patiently waiting) - - finally: - # Disconnect and release websocket - try: - await self.connection.close() - except RuntimeError: - pass # No running event loop to disconnect from (multiple ctrl+C presses) + if not all( + is_successful is True + for is_successful in await quattro.gather( + # Try to gracefully unsubscribe to all events + *(self.unsubscribe(sub_id) for sub_id in self._subscriptions), + # NOTE: Do not catch error + return_exceptions=True, + ) + ): + logger.debug("Failed to unsubscribe from all tasks") + + # Disconnect and release websocket + await self._connection.close()