diff --git a/alchemiscale/compute/api.py b/alchemiscale/compute/api.py index 9337055b..c1f1a7f4 100644 --- a/alchemiscale/compute/api.py +++ b/alchemiscale/compute/api.py @@ -13,6 +13,7 @@ from fastapi import FastAPI, APIRouter, Body, Depends from fastapi.middleware.gzip import GZipMiddleware from gufe.tokenization import GufeTokenizable, JSON_HANDLER +from gufe.protocols import ProtocolDAGResult from ..base.api import ( QueryGUFEHandler, @@ -329,7 +330,7 @@ def set_task_result( validate_scopes(task_sk.scope, token) pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder) - pdr = GufeTokenizable.from_dict(pdr) + pdr: ProtocolDAGResult = GufeTokenizable.from_dict(pdr) tf_sk, _ = n4js.get_task_transformation( task=task_scoped_key, @@ -351,7 +352,11 @@ def set_task_result( if protocoldagresultref.ok: n4js.set_task_complete(tasks=[task_sk]) else: + n4js.add_protocol_dag_result_ref_tracebacks( + pdr.protocol_unit_failures, result_sk + ) n4js.set_task_error(tasks=[task_sk]) + n4js.resolve_task_restarts(tasks=[task_sk]) return result_sk diff --git a/alchemiscale/interface/api.py b/alchemiscale/interface/api.py index 5b6aeb1e..a5211010 100644 --- a/alchemiscale/interface/api.py +++ b/alchemiscale/interface/api.py @@ -976,6 +976,100 @@ def get_task_transformation( return str(transformation) +@router.post("/networks/{network_scoped_key}/restartpatterns/add") +def add_task_restart_patterns( + network_scoped_key: str, + *, + patterns: list[str] = Body(embed=True), + num_allowed_restarts: int = Body(embed=True), + n4js: Neo4jStore = Depends(get_n4js_depends), + token: TokenData = Depends(get_token_data_depends), +): + sk = ScopedKey.from_str(network_scoped_key) + validate_scopes(sk.scope, token) + + taskhub_scoped_key = n4js.get_taskhub(sk) + n4js.add_task_restart_patterns(taskhub_scoped_key, patterns, num_allowed_restarts) + + +@router.post("/networks/{network_scoped_key}/restartpatterns/remove") +def remove_task_restart_patterns( + network_scoped_key: str, + *, + patterns: list[str] = Body(embed=True), + n4js: Neo4jStore = Depends(get_n4js_depends), + token: TokenData = Depends(get_token_data_depends), +): + sk = ScopedKey.from_str(network_scoped_key) + validate_scopes(sk.scope, token) + + taskhub_scoped_key = n4js.get_taskhub(sk) + n4js.remove_task_restart_patterns(taskhub_scoped_key, patterns) + + +@router.get("/networks/{network_scoped_key}/restartpatterns/clear") +def clear_task_restart_patterns( + network_scoped_key: str, + *, + n4js: Neo4jStore = Depends(get_n4js_depends), + token: TokenData = Depends(get_token_data_depends), +): + sk = ScopedKey.from_str(network_scoped_key) + validate_scopes(sk.scope, token) + + taskhub_scoped_key = n4js.get_taskhub(sk) + n4js.clear_task_restart_patterns(taskhub_scoped_key) + return [network_scoped_key] + + +@router.post("/bulk/networks/restartpatterns/get") +def get_task_restart_patterns( + *, + networks: list[str] = Body(embed=True), + n4js: Neo4jStore = Depends(get_n4js_depends), + token: TokenData = Depends(get_token_data_depends), +) -> dict[str, set[tuple[str, int]]]: + + network_scoped_keys = [ScopedKey.from_str(network) for network in networks] + for sk in network_scoped_keys: + validate_scopes(sk.scope, token) + + taskhub_scoped_keys = n4js.get_taskhubs(network_scoped_keys) + + taskhub_network_map = { + taskhub_scoped_key: network_scoped_key + for taskhub_scoped_key, network_scoped_key in zip( + taskhub_scoped_keys, network_scoped_keys + ) + } + + restart_patterns = n4js.get_task_restart_patterns(taskhub_scoped_keys) + + network_patterns = { + 
str(taskhub_network_map[key]): value for key, value in restart_patterns.items() + } + + return network_patterns + + +@router.post("/networks/{network_scoped_key}/restartpatterns/maxretries") +def set_task_restart_patterns_max_retries( + network_scoped_key: str, + *, + patterns: list[str] = Body(embed=True), + max_retries: int = Body(embed=True), + n4js: Neo4jStore = Depends(get_n4js_depends), + token: TokenData = Depends(get_token_data_depends), +): + sk = ScopedKey.from_str(network_scoped_key) + validate_scopes(sk.scope, token) + + taskhub_scoped_key = n4js.get_taskhub(sk) + n4js.set_task_restart_patterns_max_retries( + taskhub_scoped_key, patterns, max_retries + ) + + ### results diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index 7bd1311f..72a93ef9 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -1602,7 +1602,6 @@ def get_transformation_results( visualize If ``True``, show retrieval progress indicators. - """ if not return_protocoldagresults: @@ -1739,3 +1738,112 @@ def get_task_failures( ) return pdrs + + def add_task_restart_patterns( + self, + network_scoped_key: ScopedKey, + patterns: list[str], + num_allowed_restarts: int, + ) -> ScopedKey: + """Add a list of `Task` restart patterns to an `AlchemicalNetwork`. + + Parameters + ---------- + network_scoped_key + The `ScopedKey` for the `AlchemicalNetwork` to add the patterns to. + patterns + The regular expression strings to compare to `ProtocolUnitFailure` + tracebacks. Matching patterns will set the `Task` status back to + 'waiting'. + num_allowed_restarts + The number of times each pattern will be able to restart each + `Task`. When this number is exceeded, the `Task` is canceled from + the `AlchemicalNetwork` and left with the `error` status. + + Returns + ------- + network_scoped_key + The `ScopedKey` of the `AlchemicalNetwork` the patterns were added to. + """ + data = {"patterns": patterns, "num_allowed_restarts": num_allowed_restarts} + self._post_resource(f"/networks/{network_scoped_key}/restartpatterns/add", data) + return network_scoped_key + + def get_task_restart_patterns( + self, network_scoped_key: ScopedKey + ) -> dict[str, int]: + """Get the `Task` restart patterns applied to an `AlchemicalNetwork` + along with the number of retries allowed for each pattern. + + Parameters + ---------- + network_scoped_key + The `ScopedKey` of the `AlchemicalNetwork` to query. + + Returns + ------- + patterns + A dictionary whose keys are all of the patterns applied to the + `AlchemicalNetwork` and whose values are the number of retries each + pattern will allow. + """ + data = {"networks": [str(network_scoped_key)]} + mapped_patterns = self._post_resource( + "/bulk/networks/restartpatterns/get", data=data + ) + network_patterns = mapped_patterns[str(network_scoped_key)] + patterns_with_retries = {pattern: retry for pattern, retry in network_patterns} + return patterns_with_retries + + def set_task_restart_patterns_allowed_restarts( + self, + network_scoped_key: ScopedKey, + patterns: list[str], + num_allowed_restarts: int, + ) -> None: + """Set the number of `Task` restarts that patterns are allowed to + perform for the given `AlchemicalNetwork`. + + Parameters + ---------- + network_scoped_key + The `ScopedKey` of the `AlchemicalNetwork` the `patterns` are + applied to. + patterns + The patterns to set the number of allowed restarts for. + num_allowed_restarts + The new number of allowed restarts. 
+ """ + data = {"patterns": patterns, "max_retries": num_allowed_restarts} + self._post_resource( + f"/networks/{network_scoped_key}/restartpatterns/maxretries", data + ) + + def remove_task_restart_patterns( + self, network_scoped_key: ScopedKey, patterns: list[str] + ) -> None: + """Remove specific `Task` restart patterns from an `AlchemicalNetwork`. + + Parameters + ---------- + network_scoped_key + The `ScopedKey` of the `AlchemicalNetwork` the `patterns` are + applied to. + patterns + The patterns to remove from the `AlchemicalNetwork`. + """ + data = {"patterns": patterns} + self._post_resource( + f"/networks/{network_scoped_key}/restartpatterns/remove", data + ) + + def clear_task_restart_patterns(self, network_scoped_key: ScopedKey) -> None: + """Clear all restart patterns from an `AlchemicalNetwork`. + + Parameters + ---------- + network_scoped_key + The `ScopedKey` of the `AlchemicalNetwork` to be cleared of restart + patterns. + """ + self._query_resource(f"/networks/{network_scoped_key}/restartpatterns/clear") diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py index c9b000b8..844e2ebf 100644 --- a/alchemiscale/storage/models.py +++ b/alchemiscale/storage/models.py @@ -8,12 +8,12 @@ from copy import copy from datetime import datetime from enum import Enum -from typing import Union, Dict, Optional +from typing import Union, Optional, List from uuid import uuid4 import hashlib -from pydantic import BaseModel, Field +from pydantic import BaseModel from gufe.tokenization import GufeTokenizable, GufeKey from ..models import ScopedKey, Scope @@ -143,6 +143,113 @@ def _defaults(cls): return super()._defaults() +class TaskRestartPattern(GufeTokenizable): + """A pattern to compare returned Task tracebacks to. + + Attributes + ---------- + pattern: str + A regular expression pattern that can match to returned tracebacks of errored Tasks. + max_retries: int + The number of times the pattern can trigger a restart for a Task. + taskhub_sk: str + The TaskHub the pattern is bound to. This is needed to properly set a unique Gufe key. + """ + + pattern: str + max_retries: int + taskhub_sk: str + + def __init__( + self, pattern: str, max_retries: int, taskhub_scoped_key: Union[str, ScopedKey] + ): + + if not isinstance(pattern, str) or pattern == "": + raise ValueError("`pattern` must be a non-empty string") + + self.pattern = pattern + + if not isinstance(max_retries, int) or max_retries <= 0: + raise ValueError("`max_retries` must have a positive integer value.") + self.max_retries = max_retries + + self.taskhub_scoped_key = str(taskhub_scoped_key) + + def _gufe_tokenize(self): + key_string = self.pattern + self.taskhub_scoped_key + return hashlib.md5(key_string.encode()).hexdigest() + + @classmethod + def _defaults(cls): + return super()._defaults() + + @classmethod + def _from_dict(cls, dct): + return cls(**dct) + + def _to_dict(self): + return { + "pattern": self.pattern, + "max_retries": self.max_retries, + "taskhub_scoped_key": self.taskhub_scoped_key, + } + + # TODO: should this also compare taskhub scoped keys? + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.pattern == other.pattern + + +class Tracebacks(GufeTokenizable): + """ + Attributes + ---------- + tracebacks: list[str] + The tracebacks returned with the ProtocolUnitFailures. + source_keys: list[GufeKey] + The GufeKeys of the ProtocolUnits that failed. + failure_keys: list[GufeKey] + The GufeKeys of the ProtocolUnitFailures. 
+ """ + + def __init__( + self, + tracebacks: List[str], + source_keys: List[GufeKey], + failure_keys: List[GufeKey], + ): + value_error = ValueError( + "`tracebacks` must be a non-empty list of non-empty string values" + ) + if not isinstance(tracebacks, list) or tracebacks == []: + raise value_error + + all_string_values = all([isinstance(value, str) for value in tracebacks]) + if not all_string_values or "" in tracebacks: + raise value_error + + # TODO: validate + self.tracebacks = tracebacks + self.source_keys = source_keys + self.failure_keys = failure_keys + + @classmethod + def _defaults(cls): + return super()._defaults() + + @classmethod + def _from_dict(cls, dct): + return cls(**dct) + + def _to_dict(self): + return { + "tracebacks": self.tracebacks, + "source_keys": self.source_keys, + "failure_keys": self.failure_keys, + } + + class TaskHub(GufeTokenizable): """ diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 801b6600..f034c727 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -8,9 +8,11 @@ from datetime import datetime from contextlib import contextmanager import json -from functools import lru_cache -from operator import ne -from typing import Dict, List, Optional, Union, Tuple +import re +from functools import lru_cache, update_wrapper +from typing import Dict, List, Optional, Union, Tuple, Set +from collections import defaultdict +from collections.abc import Iterable import weakref import numpy as np @@ -23,6 +25,7 @@ ) from gufe.settings import SettingsBaseModel from gufe.tokenization import GufeTokenizable, GufeKey, JSON_HANDLER +from gufe.protocols import ProtocolUnitFailure from neo4j import Transaction, GraphDatabase, Driver @@ -31,10 +34,12 @@ ComputeServiceRegistration, NetworkMark, NetworkStateEnum, + ProtocolDAGResultRef, Task, TaskHub, + TaskRestartPattern, TaskStatusEnum, - ProtocolDAGResultRef, + Tracebacks, ) from ..strategies import Strategy from ..models import Scope, ScopedKey @@ -158,6 +163,19 @@ def transaction(self, ignore_exceptions=False) -> Transaction: else: tx.commit() + def chainable(func): + def inner(self, *args, **kwargs): + if kwargs.get("tx") is not None: + return func(self, *args, **kwargs) + + with self.transaction() as tx: + kwargs.update(tx=tx) + return func(self, *args, **kwargs) + + update_wrapper(inner, func) + + return inner + def execute_query(self, *args, **kwargs): kwargs.update({"database_": self.db_name}) return self.graph.execute_query(*args, **kwargs) @@ -1183,6 +1201,54 @@ def query_taskhubs( """ return self._query(qualname="TaskHub", scope=scope, return_gufe=return_gufe) + def get_taskhubs( + self, network_scoped_keys: list[ScopedKey], return_gufe: bool = False + ) -> list[Union[ScopedKey, TaskHub]]: + """Get the TaskHubs for the given AlchemicalNetworks. + + Parameters + ---------- + return_gufe + If True, return `TaskHub` instances. + Otherwise, return `ScopedKey`s. 
+ + """ + + # TODO: this could fail better, report all instances rather than first + for network_scoped_key in network_scoped_keys: + if network_scoped_key.qualname != "AlchemicalNetwork": + raise ValueError( + "`network` ScopedKey does not correspond to an `AlchemicalNetwork`" + ) + + query = """ + UNWIND $network_scoped_keys AS network_scoped_key + MATCH (th:TaskHub {network: network_scoped_key})-[:PERFORMS]->(an:AlchemicalNetwork) + RETURN th, an + """ + + query_results = self.execute_query( + query, network_scoped_keys=list(map(str, network_scoped_keys)) + ) + + def _node_to_gufe(node): + return self._subgraph_to_gufe([node], node)[node] + + def _node_to_scoped_key(node): + return ScopedKey.from_str(node["_scoped_key"]) + + transform_function = _node_to_gufe if return_gufe else _node_to_scoped_key + transform_results = defaultdict(None) + for record in query_results.records: + node = record_data_to_node(record["th"]) + network_scoped_key = record["an"]["_scoped_key"] + transform_results[network_scoped_key] = transform_function(node) + + return [ + transform_results[str(network_scoped_key)] + for network_scoped_key in network_scoped_keys + ] + def get_taskhub( self, network: ScopedKey, return_gufe: bool = False ) -> Union[ScopedKey, TaskHub]: @@ -1195,27 +1261,8 @@ def get_taskhub( Otherwise, return a `ScopedKey`. """ - if network.qualname != "AlchemicalNetwork": - raise ValueError( - "`network` ScopedKey does not correspond to an `AlchemicalNetwork`" - ) - q = """ - MATCH (th:TaskHub {network: $network})-[:PERFORMS]->(an:AlchemicalNetwork) - RETURN th - """ - - try: - node = record_data_to_node( - self.execute_query(q, network=str(network)).records[0]["th"] - ) - except IndexError: - raise KeyError("No such object in database") - - if return_gufe: - return self._subgraph_to_gufe([node], node)[node] - else: - return ScopedKey.from_str(node["_scoped_key"]) + return self.get_taskhubs([network], return_gufe)[0] def delete_taskhub( self, @@ -1396,32 +1443,51 @@ def action_tasks( # so we can properly return `None` if needed task_map = {str(task): None for task in tasks} - tasks_scoped_keys = [str(task) for task in tasks if task is not None] + task_scoped_keys = [str(task) for task in tasks if task is not None] - q = f""" + q = """ // get our TaskHub - UNWIND $tasks as task_sk - MATCH (th:TaskHub {{_scoped_key: $taskhub}})-[:PERFORMS]->(an:AlchemicalNetwork) + UNWIND $task_scoped_keys as task_sk + MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[:PERFORMS]->(an:AlchemicalNetwork) // get the task we want to add to the hub; check that it connects to same network - MATCH (task:Task {{_scoped_key: task_sk}})-[:PERFORMS]->(tf:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an) + MATCH (task:Task {_scoped_key: task_sk})-[:PERFORMS]->(:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an) // only proceed for cases where task is not already actioned on hub // and where the task is either in 'waiting', 'running', or 'error' status WITH th, an, task WHERE NOT (th)-[:ACTIONS]->(task) - AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}'] + AND task.status IN [$waiting, $running, $error] // create the connection - CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task) + CREATE (th)-[ar:ACTIONS {weight: 0.5}]->(task) // set the task property to the scoped key of the Task // this is a convenience for when we have to loop over relationships in Python SET ar.task = task._scoped_key + // we want to preserve the list of tasks for the return, 
so we need to make a subquery + // since the subsequent WHERE clause could reduce the records in task + WITH task, th + CALL { + WITH task, th + MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th) + WHERE NOT (trp)-[:APPLIES]->(task) + + CREATE (trp)-[:APPLIES {num_retries: 0}]->(task) + } + RETURN task """ - results = self.execute_query(q, tasks=tasks_scoped_keys, taskhub=str(taskhub)) + + results = self.execute_query( + q, + task_scoped_keys=task_scoped_keys, + taskhub_scoped_key=str(taskhub), + waiting=TaskStatusEnum.waiting.value, + running=TaskStatusEnum.running.value, + error=TaskStatusEnum.error.value, + ) # update our map with the results, leaving None for tasks that aren't found for task_record in results.records: @@ -1566,10 +1632,12 @@ def get_task_weights( return weights + @chainable def cancel_tasks( self, tasks: List[ScopedKey], taskhub: ScopedKey, + tx=None, ) -> List[Union[ScopedKey, None]]: """Remove Tasks from the TaskHub for a given AlchemicalNetwork. @@ -1581,15 +1649,23 @@ def cancel_tasks( """ query = """ UNWIND $task_scoped_keys AS task_scoped_key - MATCH (:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: task_scoped_key}) + MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: task_scoped_key}) DELETE ar + + WITH task, th + CALL { + WITH task, th + MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern)-[:ENFORCES]->(th) + DELETE applies + } + RETURN task._scoped_key as task_scoped_key """ - results = self.execute_query( + results = tx.run( query, task_scoped_keys=list(map(str, tasks)), taskhub_scoped_key=str(taskhub), - ) + ).to_eager_result() returned_keys = {record["task_scoped_key"] for record in results.records} filtered_tasks = [ @@ -2426,6 +2502,59 @@ def get_task_failures(self, task: ScopedKey) -> List[ProtocolDAGResultRef]: """ return self._get_protocoldagresultrefs(q, task) + def add_protocol_dag_result_ref_tracebacks( + self, + protocol_unit_failures: List[ProtocolUnitFailure], + protocol_dag_result_ref_scoped_key: ScopedKey, + ): + subgraph = Subgraph() + + with self.transaction() as tx: + + query = """ + MATCH (pdrr:ProtocolDAGResultRef {`_scoped_key`: $protocol_dag_result_ref_scoped_key}) + RETURN pdrr + """ + + pdrr_result = tx.run( + query, + protocol_dag_result_ref_scoped_key=str( + protocol_dag_result_ref_scoped_key + ), + ).to_eager_result() + + try: + protocol_dag_result_ref_node = record_data_to_node( + pdrr_result.records[0]["pdrr"] + ) + except IndexError: + raise KeyError("Could not find ProtocolDAGResultRef in database.") + + failure_keys = [] + source_keys = [] + tracebacks = [] + + for puf in protocol_unit_failures: + failure_keys.append(puf.key) + source_keys.append(puf.source_key) + tracebacks.append(puf.traceback) + + tracebacks = Tracebacks(tracebacks, source_keys, failure_keys) + + _, tracebacks_node, _ = self._gufe_to_subgraph( + tracebacks.to_shallow_dict(), + labels=["GufeTokenizable", tracebacks.__class__.__name__], + gufe_key=tracebacks.key, + scope=protocol_dag_result_ref_scoped_key.scope, + ) + + subgraph |= Relationship.type("DETAILS")( + tracebacks_node, + protocol_dag_result_ref_node, + ) + + merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key") + def set_task_status( self, tasks: List[ScopedKey], status: TaskStatusEnum, raise_error: bool = False ) -> List[Optional[ScopedKey]]: @@ -2592,9 +2721,11 @@ def set_task_complete( WITH scoped_key, t, t_ // if we changed the status to complete, - // drop all ACTIONS relationships + // drop all taskhub 
ACTIONS and task restart APPLIES relationships OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern) DELETE ar + DELETE applies WITH scoped_key, t, t_ @@ -2676,10 +2807,14 @@ def set_task_invalid( WITH scoped_key, t, t_, extends_task OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub) - OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (extends_task)<-[ar_e:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern) + OPTIONAL MATCH (extends_task)<-[applies_e:APPLIES]-(:TaskRestartPattern) DELETE ar - DELETE are + DELETE ar_e + DELETE applies + DELETE applies_e WITH scoped_key, t, t_ @@ -2726,10 +2861,14 @@ def set_task_deleted( WITH scoped_key, t, t_, extends_task OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub) - OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (extends_task)<-[ar_e:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern) + OPTIONAL MATCH (extends_task)<-[applies_e:APPLIES]-(:TaskRestartPattern) DELETE ar - DELETE are + DELETE ar_e + DELETE applies + DELETE applies_e WITH scoped_key, t, t_ @@ -2745,6 +2884,328 @@ def err_msg(t, status): return self._set_task_status(tasks, q, err_msg, raise_error=raise_error) + ## task restart policies + + def add_task_restart_patterns( + self, taskhub: ScopedKey, patterns: list[str], number_of_retries: int + ): + """Add a list of restart policy patterns to a `TaskHub` along with the number of retries allowed. + + Parameters + ---------- + taskhub : ScopedKey + TaskHub for the restart patterns to enforce. + patterns: list[str] + Regular expression patterns that will be compared to tracebacks returned by ProtocolUnitFailures. + number_of_retries: int + The number of times the given patterns will apply to a single Task, attempts to restart beyond + this value will result in a canceled Task with an error status. + + Raises + ------ + KeyError + Raised when the provided TaskHub ScopedKey cannot be associated with a TaskHub in the database. 
+ """ + + # get taskhub node + q = """ + MATCH (th:TaskHub {`_scoped_key`: $taskhub}) + RETURN th + """ + results = self.execute_query(q, taskhub=str(taskhub)) + + # raise error if taskhub not found + if not results.records: + raise KeyError("No such TaskHub in the database") + + record_data = results.records[0]["th"] + taskhub_node = record_data_to_node(record_data) + scope = taskhub.scope + + with self.transaction() as tx: + actioned_tasks_query = """ + MATCH (taskhub: TaskHub {`_scoped_key`: $taskhub_scoped_key})-[:ACTIONS]->(task: Task) + RETURN task + """ + + actioned_task_records = ( + tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub)) + .to_eager_result() + .records + ) + + subgraph = Subgraph() + + actioned_task_nodes = [] + + for actioned_tasks_record in actioned_task_records: + actioned_task_nodes.append( + record_data_to_node(actioned_tasks_record["task"]) + ) + + for pattern in patterns: + task_restart_pattern = TaskRestartPattern( + pattern, + max_retries=number_of_retries, + taskhub_scoped_key=str(taskhub), + ) + + _, task_restart_pattern_node, scoped_key = self._gufe_to_subgraph( + task_restart_pattern.to_shallow_dict(), + labels=["GufeTokenizable", task_restart_pattern.__class__.__name__], + gufe_key=task_restart_pattern.key, + scope=scope, + ) + + subgraph |= Relationship.type("ENFORCES")( + task_restart_pattern_node, + taskhub_node, + _org=scope.org, + _campaign=scope.campaign, + _project=scope.project, + ) + + for actioned_task_node in actioned_task_nodes: + subgraph |= Relationship.type("APPLIES")( + task_restart_pattern_node, + actioned_task_node, + num_retries=0, + ) + merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key") + + actioned_task_scoped_keys: List[ScopedKey] = [] + + for actioned_task_record in actioned_task_records: + actioned_task_scoped_keys.append( + ScopedKey.from_str(actioned_task_record["task"]["_scoped_key"]) + ) + + self.resolve_task_restarts(actioned_task_scoped_keys, tx=tx) + + def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: list[str]): + """Remove a list of restart patterns enforcing a TaskHub from the database. + + Parameters + ---------- + taskhub: ScopedKey + The ScopedKey of the TaskHub that the patterns enforce. + patterns: list[str] + The patterns to remove. Patterns not enforcing the TaskHub are ignored. + """ + q = """ + UNWIND $patterns AS pattern + + MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key}) + + DETACH DELETE trp + """ + + self.execute_query(q, patterns=patterns, taskhub_scoped_key=str(taskhub)) + + def clear_task_restart_patterns(self, taskhub: ScopedKey): + """Clear all restart patterns from a TaskHub. + + Parameters + ---------- + taskhub: ScopedKey + The ScopedKey of the TaskHub to clear of restart patterns. + """ + q = """ + MATCH (trp: TaskRestartPattern {taskhub_scoped_key: $taskhub_scoped_key}) + DETACH DELETE trp + """ + self.execute_query(q, taskhub_scoped_key=str(taskhub)) + + def set_task_restart_patterns_max_retries( + self, + taskhub_scoped_key: ScopedKey, + patterns: list[str], + max_retries: int, + ): + """Set the maximum number of retries of a pattern enforcing a TaskHub. + + Parameters + ---------- + taskhub_scoped_key: ScopedKey + The ScopedKey of the TaskHub that the patterns enforce. + patterns: list[str] + The patterns to change the maximum retries value for. + max_retries: int + The new maximum retries value. 
+ """ + query = """ + UNWIND $patterns AS pattern + MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key}) + SET trp.max_retries = $max_retries + """ + + self.execute_query( + query, + patterns=patterns, + taskhub_scoped_key=str(taskhub_scoped_key), + max_retries=max_retries, + ) + + # TODO: validation of taskhubs variable, will fail in weird ways if not enforced + def get_task_restart_patterns( + self, taskhubs: list[ScopedKey] + ) -> dict[ScopedKey, set[tuple[str, int]]]: + """For a list of TaskHub ScopedKeys, get the associated restart + patterns along with the maximum number of retries for each pattern. + + Parameters + ---------- + taskhubs: list[ScopedKey] + The ScopedKeys of the TaskHubs to get the restart patterns of. + + Returns + ------- + dict[ScopedKey, set[tuple[str, int]]] + A dictionary with ScopedKeys of the TaskHubs provided as keys, and a + set of tuples containing the patterns enforcing each TaskHub along + with their associated maximum number of retries as values. + """ + + q = """ + UNWIND $taskhub_scoped_keys as taskhub_scoped_key + MATCH (trp: TaskRestartPattern)-[ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_scoped_key}) + RETURN th, trp + """ + + records = self.execute_query( + q, taskhub_scoped_keys=list(map(str, taskhubs)) + ).records + + data: dict[ScopedKey, set[tuple[str, int]]] = { + taskhub: set() for taskhub in taskhubs + } + + for record in records: + pattern = record["trp"]["pattern"] + max_retries = record["trp"]["max_retries"] + taskhub_sk = ScopedKey.from_str(record["th"]["_scoped_key"]) + data[taskhub_sk].add((pattern, max_retries)) + + return data + + @chainable + def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=None): + """Determine whether or not Tasks need to be restarted or canceled and perform that action. + + Parameters + ---------- + task_scoped_keys: Iterable[ScopedKey] + An iterable of Task ScopedKeys that need to be resolved. Tasks without the error status + are filtered out and ignored. + """ + + # Given the scoped keys of a list of Tasks, find all tasks that have an + # error status and have a TaskRestartPattern applied. 
A subquery is executed
+        # to optionally get the latest traceback associated with the task
+        query = """
+            UNWIND $task_scoped_keys AS task_scoped_key
+            MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub)
+            CALL {
+                WITH task
+                OPTIONAL MATCH (task:Task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef)<-[:DETAILS]-(tracebacks:Tracebacks)
+                RETURN tracebacks
+                ORDER BY pdrr.datetime_created DESCENDING
+                LIMIT 1
+            }
+            WITH task, tracebacks, trp, app, taskhub
+            RETURN task, tracebacks, trp, app, taskhub
+        """
+
+        results = tx.run(
+            query,
+            task_scoped_keys=list(map(str, task_scoped_keys)),
+            error=TaskStatusEnum.error.value,
+        ).to_eager_result()
+
+        if not results:
+            return
+
+        # iterate over all of the results to determine whether an applied pattern's
+        # retry count needs to be incremented or the task needs to be canceled outright
+
+        # Keep track of which task/taskhub pairs would need to be canceled
+        # None => the pair never had a matching restart pattern
+        # True => at least one pattern's max_retries was exceeded
+        # False => at least one regex matched, but no pattern's max_retries was exceeded
+        cancel_map: dict[Tuple[str, str], Optional[bool]] = {}
+        to_increment: List[Tuple[str, str]] = []
+        all_task_taskhub_pairs: set[Tuple[str, str]] = set()
+        for record in results.records:
+            task_restart_pattern = record["trp"]
+            applies_relationship = record["app"]
+            task = record["task"]
+            taskhub = record["taskhub"]
+            _tracebacks = record["tracebacks"]
+
+            task_taskhub_tuple = (task["_scoped_key"], taskhub["_scoped_key"])
+
+            all_task_taskhub_pairs.add(task_taskhub_tuple)
+
+            # TODO: remove in v1.0.0
+            # tasks that errored prior to the introduction of task restart policies will have no tracebacks in the database
+            if _tracebacks is None:
+                cancel_map[task_taskhub_tuple] = True
+
+            # skip if we have already determined that the task is to be canceled;
+            # cancel_map is only ever truthy when a task needs to be canceled.
+ if cancel_map.get(task_taskhub_tuple): + continue + + num_retries = applies_relationship["num_retries"] + max_retries = task_restart_pattern["max_retries"] + pattern = task_restart_pattern["pattern"] + tracebacks: List[str] = _tracebacks["tracebacks"] + + compiled_pattern = re.compile(pattern) + + if any([compiled_pattern.search(message) for message in tracebacks]): + if num_retries + 1 > max_retries: + cancel_map[task_taskhub_tuple] = True + else: + to_increment.append( + (task["_scoped_key"], task_restart_pattern["_scoped_key"]) + ) + cancel_map[task_taskhub_tuple] = False + + increment_query = """ + UNWIND $task_trp_pairs as pairs + WITH pairs[0] as task_scoped_key, pairs[1] as task_restart_pattern_scoped_key + MATCH (:Task {`_scoped_key`: task_scoped_key})<-[app:APPLIES]-(:TaskRestartPattern {`_scoped_key`: task_restart_pattern_scoped_key}) + SET app.num_retries = app.num_retries + 1 + """ + + tx.run(increment_query, task_trp_pairs=to_increment) + + # cancel all tasks (from a taskhub) that didn't trigger any restart patterns (None) + # or exceeded a pattern's max_retries value (True) + cancel_groups: defaultdict[str, list[str]] = defaultdict(list) + for task_taskhub_pair in all_task_taskhub_pairs: + cancel_result = cancel_map.get(task_taskhub_pair) + if cancel_result in (True, None): + cancel_groups[task_taskhub_pair[1]].append(task_taskhub_pair[0]) + + for taskhub, tasks in cancel_groups.items(): + self.cancel_tasks(tasks, taskhub, tx=tx) + + # any tasks that are still associated with a TaskHub and a TaskRestartPattern must then be okay to switch to waiting + renew_waiting_status_query = """ + UNWIND $task_scoped_keys AS task_scoped_key + MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub) + SET task.status = $waiting + """ + + tx.run( + renew_waiting_status_query, + task_scoped_keys=list(map(str, task_scoped_keys)), + waiting=TaskStatusEnum.waiting.value, + error=TaskStatusEnum.error.value, + ) + ## authentication def create_credentialed_entity(self, entity: CredentialedEntity): diff --git a/alchemiscale/tests/integration/conftest.py b/alchemiscale/tests/integration/conftest.py index 054c7d48..2c3629e1 100644 --- a/alchemiscale/tests/integration/conftest.py +++ b/alchemiscale/tests/integration/conftest.py @@ -177,6 +177,49 @@ def n4js_fresh(graph): return n4js +@fixture +def n4js_task_restart_policy( + n4js_fresh: Neo4jStore, network_tyk2: AlchemicalNetwork, scope_test +): + + n4js = n4js_fresh + + _, taskhub_scoped_key_with_policy, _ = n4js.assemble_network( + network_tyk2, scope_test + ) + + _, taskhub_scoped_key_no_policy, _ = n4js.assemble_network( + network_tyk2.copy_with_replacements(name=network_tyk2.name + "_no_policy"), + scope_test, + ) + + transformation_1_scoped_key, transformation_2_scoped_key = map( + lambda transformation: n4js.get_scoped_key(transformation, scope_test), + list(network_tyk2.edges)[:2], + ) + + # create 4 tasks for each of the 2 selected transformations + task_scoped_keys = n4js.create_tasks( + [transformation_1_scoped_key] * 4 + [transformation_2_scoped_key] * 4 + ) + + # action the tasks for transformation 1 on the taskhub with no policy + # action the tasks for both transformations on the taskhub with a policy + assert all(n4js.action_tasks(task_scoped_keys[:4], taskhub_scoped_key_no_policy)) + assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy)) + + patterns = [ + r"Error message \d, round \d", + "This is an example pattern that will be 
used as a restart string.", + ] + + n4js.add_task_restart_patterns( + taskhub_scoped_key_with_policy, patterns=patterns, number_of_retries=2 + ) + + return n4js + + @fixture(scope="module") def s3objectstore_settings(): os.environ["AWS_ACCESS_KEY_ID"] = "test-key-id" diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index 7146c50f..1b3d4eac 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -2151,3 +2151,143 @@ def test_get_task_failures( # TODO: can we mix in a success in here somewhere? # not possible with current BrokenProtocol, unfortunately + + +class TestTaskRestartPolicy: + + default_max_retries = 3 + default_patterns = ["Pattern 1", "Pattern 2", "Pattern 3"] + + def create_default_network(self, network, client, scope): + network_scoped_key = client.create_network(network, scope) + client.add_task_restart_patterns( + network_scoped_key, self.default_patterns, self.default_max_retries + ) + return network_scoped_key + + def test_add_task_restart_patterns( + self, user_client, network_tyk2, scope_test, n4js_preloaded + ): + + network_scoped_key = self.create_default_network( + network_tyk2, user_client, scope_test + ) + + query = """ + MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(:TaskHub)-[:PERFORMS]->(:AlchemicalNetwork {`_scoped_key`: $network_scoped_key}) + RETURN trp + """ + + results = n4js_preloaded.execute_query( + query, network_scoped_key=str(network_scoped_key) + ) + + assert len(results.records) == 3 + + patterns_list = list(self.default_patterns) + for record in results.records: + trp = record["trp"] + assert trp["pattern"] in patterns_list + patterns_list.remove(trp["pattern"]) + + assert len(patterns_list) == 0 + + def test_get_task_restart_patterns( + self, + user_client: client.AlchemiscaleClient, + network_tyk2, + scope_test, + n4js_preloaded, + ): + network_scoped_key = self.create_default_network( + network_tyk2, user_client, scope_test + ) + taskrestartpatterns = user_client.get_task_restart_patterns(network_scoped_key) + expected = { + pattern: self.default_max_retries for pattern in self.default_patterns + } + assert taskrestartpatterns == expected + + def test_remove_task_restart_patterns( + self, + user_client: client.AlchemiscaleClient, + network_tyk2, + scope_test, + n4js_preloaded, + ): + network_scoped_key = self.create_default_network( + network_tyk2, user_client, scope_test + ) + expected = { + pattern: self.default_max_retries for pattern in self.default_patterns + } + + # check that we have the expected 3 restart patterns + assert user_client.get_task_restart_patterns(network_scoped_key) == expected + + pattern_to_remove = next(iter(expected)) + user_client.remove_task_restart_patterns( + network_scoped_key, [pattern_to_remove] + ) + del expected[pattern_to_remove] + + # check that one was removed + assert user_client.get_task_restart_patterns(network_scoped_key) == expected + + patterns_to_remove = [pattern for pattern in expected] + user_client.remove_task_restart_patterns(network_scoped_key, patterns_to_remove) + + # check the remaining patterns are removed + assert user_client.get_task_restart_patterns(network_scoped_key) == {} + + def test_clear_task_restart_patterns( + self, + user_client: client.AlchemiscaleClient, + network_tyk2, + scope_test, + n4js_preloaded, + ): + network_scoped_key = self.create_default_network( + network_tyk2, user_client, scope_test + 
) + + query = """ + MATCH (trp:TaskRestartPattern)-[:ENFORCES]->(:TaskHub)-[:PERFORMS]->(:AlchemicalNetwork {`_scoped_key`: $network_scoped_key}) + RETURN trp + """ + + assert ( + len( + n4js_preloaded.execute_query( + query, network_scoped_key=str(network_scoped_key) + ).records + ) + == 3 + ) + user_client.clear_task_restart_patterns(network_scoped_key) + assert ( + len( + n4js_preloaded.execute_query( + query, network_scoped_key=str(network_scoped_key) + ).records + ) + == 0 + ) + + def test_set_task_restart_patterns_allowed_restarts( + self, + user_client: client.AlchemiscaleClient, + network_tyk2, + scope_test, + n4js_preloaded, + ): + network_scoped_key = self.create_default_network( + network_tyk2, user_client, scope_test + ) + user_client.set_task_restart_patterns_allowed_restarts( + network_scoped_key, self.default_patterns[:2], 1 + ) + + expected = {pattern: 1 for pattern in self.default_patterns[:2]} + expected[self.default_patterns[-1]] = self.default_max_retries + assert user_client.get_task_restart_patterns(network_scoped_key) == expected diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index f2f25ef5..6e82c8b4 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -2,12 +2,15 @@ import random from typing import List, Dict from pathlib import Path -from itertools import chain from functools import reduce +from itertools import chain +import operator +from collections import defaultdict import pytest from gufe import AlchemicalNetwork from gufe.tokenization import TOKENIZABLE_REGISTRY +from gufe.protocols import ProtocolUnitFailure from gufe.protocols.protocoldag import execute_DAG from alchemiscale.storage.statestore import Neo4jStore @@ -28,6 +31,13 @@ ) from alchemiscale.security.auth import hash_key +from alchemiscale.tests.integration.storage.utils import ( + complete_tasks, + fail_task, + tasks_are_errored, + tasks_are_not_actioned_on_taskhub, + tasks_are_waiting, +) from ..conftest import DummyProtocolA, DummyProtocolB, DummyProtocolC @@ -1151,6 +1161,65 @@ def test_action_task(self, n4js: Neo4jStore, network_tyk2, scope_test): task_sks_fail = n4js.action_tasks(task_sks, taskhub_sk2) assert all([i is None for i in task_sks_fail]) + # test for APPLIES relationship between an ACTIONED task and a TaskRestartPattern + + ## create a restart pattern, should already create APPLIES relationships with those + ## already actioned + n4js.add_task_restart_patterns(taskhub_sk, ["test_pattern"], 5) + + query = """ + MATCH (:TaskRestartPattern)-[applies:APPLIES]->(Task)<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key}) + // change this so that later tests can show the value was not overwritten + SET applies.num_retries = 1 + RETURN count(applies) AS applies_count + """ + + ## sanity check that this number makes sense + applies_count = n4js.execute_query( + query, taskhub_scoped_key=str(taskhub_sk) + ).records[0]["applies_count"] + + assert applies_count == 10 + + # create 10 more tasks and action them + task_sks = n4js.create_tasks([transformation_sk] * 10) + n4js.action_tasks(task_sks, taskhub_sk) + + assert len(n4js.get_taskhub_actioned_tasks([taskhub_sk])[0]) == 20 + + # same as above query without the set num_retries = 1 + query = """ + MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task)<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key}) + RETURN count(applies) AS applies_count + """ + + 
applies_count = n4js.execute_query( + query, taskhub_scoped_key=str(taskhub_sk) + ).records[0]["applies_count"] + + query = """ + MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task) + RETURN applies + """ + + results = n4js.execute_query(query) + + count_0, count_1 = 0, 0 + for count in map( + lambda record: record["applies"]["num_retries"], results.records + ): + match count: + case 0: + count_0 += 1 + case 1: + count_1 += 1 + case _: + raise AssertionError( + "Unexpected count value found in num_retries field" + ) + + assert count_0 == count_1 == 10 + def test_action_task_other_statuses( self, n4js: Neo4jStore, network_tyk2, scope_test ): @@ -1266,6 +1335,14 @@ def test_cancel_task(self, n4js, network_tyk2, scope_test): # cancel the second and third task we created canceled = n4js.cancel_tasks(task_sks[1:3], taskhub_sk) + # cancel a fake task + fake_canceled = n4js.cancel_tasks( + [ScopedKey.from_str("Task-FAKE-test_org-test_campaign-test_project")], + taskhub_sk, + ) + + assert fake_canceled[0] is None + # check that the hub has the contents we expect q = """ MATCH (:TaskHub {_scoped_key: $taskhub_scoped_key})-[:ACTIONS]->(task:Task) @@ -1280,9 +1357,30 @@ def test_cancel_task(self, n4js, network_tyk2, scope_test): assert len(tasks) == 8 assert set(tasks) == set(actioned) - set(canceled) - # cancel the remaining tasks and check for Nones - canceled = n4js.cancel_tasks(task_sks, taskhub_sk) - assert canceled == [task_sks[0]] + [None, None] + task_sks[3:] + # create a TaskRestartPattern + n4js.add_task_restart_patterns(taskhub_sk, ["Test pattern"], 1) + + query = """ + MATCH (:TaskHub {`_scoped_key`: $taskhub_scoped_key})<-[:ENFORCES]-(:TaskRestartPattern)-[applies:APPLIES]->(:Task) + RETURN count(applies) AS applies_count + """ + + assert ( + n4js.execute_query(query, taskhub_scoped_key=str(taskhub_sk)).records[0][ + "applies_count" + ] + == 8 + ) + + # cancel the fourth and fifth task we created + canceled = n4js.cancel_tasks(task_sks[3:5], taskhub_sk) + + assert ( + n4js.execute_query(query, taskhub_scoped_key=str(taskhub_sk)).records[0][ + "applies_count" + ] + == 6 + ) def test_get_taskhub_tasks(self, n4js, network_tyk2, scope_test): an = network_tyk2 @@ -1958,6 +2056,534 @@ def test_get_task_failures( assert pdr_ref_sk in failure_pdr_ref_sks assert pdr_ref2_sk in failure_pdr_ref_sks + @pytest.mark.parametrize("failure_count", (1, 2, 3, 4)) + def test_add_protocol_dag_result_ref_traceback( + self, + network_tyk2_failure, + n4js, + scope_test, + transformation_failure, + protocoldagresults_failure, + failure_count: int, + ): + + an = network_tyk2_failure.copy_with_replacements( + name=network_tyk2_failure.name + + "_test_add_protocol_dag_result_ref_traceback" + ) + n4js.assemble_network(an, scope_test) + transformation_scoped_key = n4js.get_scoped_key( + transformation_failure, scope_test + ) + + # create a task; pretend we computed it, submit reference for pre-baked + # result + task_scoped_key = n4js.create_task(transformation_scoped_key) + + protocol_unit_failure = protocoldagresults_failure[0].protocol_unit_failures[0] + + pdrr = ProtocolDAGResultRef( + scope=task_scoped_key.scope, + obj_key=protocoldagresults_failure[0].key, + ok=protocoldagresults_failure[0].ok(), + ) + + # push the result + pdrr_scoped_key = n4js.set_task_result(task_scoped_key, pdrr) + + # simulating many failures + protocol_unit_failures = [] + for failure_index in range(failure_count): + protocol_unit_failures.append( + protocol_unit_failure.copy_with_replacements( + 
traceback=protocol_unit_failure.traceback + "_" + str(failure_index) + ) + ) + + n4js.add_protocol_dag_result_ref_tracebacks( + protocol_unit_failures, pdrr_scoped_key + ) + + query = """ + MATCH (traceback:Tracebacks)-[:DETAILS]->(:ProtocolDAGResultRef {`_scoped_key`: $pdrr_scoped_key}) + RETURN traceback + """ + + results = n4js.execute_query(query, pdrr_scoped_key=str(pdrr_scoped_key)) + + returned_tracebacks = results.records[0]["traceback"]["tracebacks"] + + assert returned_tracebacks == [puf.traceback for puf in protocol_unit_failures] + + ### task restart policies + + class TestTaskRestartPolicy: + + @pytest.mark.parametrize("status", ("complete", "invalid", "deleted")) + def test_task_status_change(self, n4js, network_tyk2, scope_test, status): + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + "_test_task_status_change" + ) + _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test) + transformation = list(an.edges)[0] + transformation_scoped_key = n4js.get_scoped_key(transformation, scope_test) + task_scoped_keys = n4js.create_tasks([transformation_scoped_key]) + n4js.action_tasks(task_scoped_keys, taskhub_scoped_key) + + n4js.add_task_restart_patterns(taskhub_scoped_key, ["Test pattern"], 10) + + query = """ + MATCH (:TaskRestartPattern)-[:APPLIES]->(task:Task {`_scoped_key`: $task_scoped_key})<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key}) + RETURN task + """ + + results = n4js.execute_query( + query, + task_scoped_key=str(task_scoped_keys[0]), + taskhub_scoped_key=str(taskhub_scoped_key), + ) + + assert len(results.records) == 1 + + if status == "complete": + n4js.set_task_running(task_scoped_keys) + + assert n4js.set_task_status(task_scoped_keys)[0] is not None + + query = """ + MATCH (:TaskRestartPattern)-[:APPLIES]->(task:Task) + RETURN task + """ + + results = n4js.execute_query( + query, + task_scoped_key=str(task_scoped_keys[0]), + taskhub_scoped_key=str(taskhub_scoped_key), + ) + + assert len(results.records) == 0 + + def test_add_task_restart_patterns(self, n4js, network_tyk2, scope_test): + # create three new alchemical networks (and taskhubs) + taskhub_sks = [] + for network_index in range(3): + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + + f"_test_add_task_restart_patterns_{network_index}" + ) + _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test) + + # don't action tasks on every network, take every other + if network_index % 2 == 0: + transformation = list(an.edges)[0] + transformation_sk = n4js.get_scoped_key(transformation, scope_test) + task_sks = n4js.create_tasks([transformation_sk] * 3) + n4js.action_tasks(task_sks, taskhub_scoped_key) + + taskhub_sks.append(taskhub_scoped_key) + + # test a shared pattern with and without shared number of restarts + # this will create 6 unique patterns + for network_index in range(3): + taskhub_scoped_key = taskhub_sks[network_index] + n4js.add_task_restart_patterns( + taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5 + ) + n4js.add_task_restart_patterns( + taskhub_scoped_key, + ["shared_pattern_and_different_restarts.+"], + network_index + 1, + ) + + q = """UNWIND $taskhub_sks AS taskhub_sk + MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_sk}) RETURN trp, th + """ + + taskhub_sks = list(map(str, taskhub_sks)) + records = n4js.execute_query(q, taskhub_sks=taskhub_sks).records + + assert len(records) == 6 + + taskhub_scoped_key_set = set() + taskrestartpattern_scoped_key_set = set() + + for record in records: 
+ taskhub_scoped_key = ScopedKey.from_str(record["th"]["_scoped_key"]) + taskrestartpattern_scoped_key = ScopedKey.from_str( + record["trp"]["_scoped_key"] + ) + + taskhub_scoped_key_set.add(taskhub_scoped_key) + taskrestartpattern_scoped_key_set.add(taskrestartpattern_scoped_key) + + assert len(taskhub_scoped_key_set) == 3 + assert len(taskrestartpattern_scoped_key_set) == 6 + + # check that the applies relationships were correctly added + + ## first check that the number of applies relationships is correct and + ## that the number of retries is zero + applies_query = """ + MATCH (trp: TaskRestartPattern)-[app:APPLIES {num_retries: 0}]->(task: Task)<-[:ACTIONS]-(th: TaskHub) + RETURN th, count(app) AS num_applied + """ + + records = n4js.execute_query(applies_query).records + + ### one record per taskhub with tasks actioned, each with six num_applied + assert len(records) == 2 + assert records[0]["num_applied"] == records[1]["num_applied"] == 6 + + applies_nonzero_retries = """ + MATCH (trp: TaskRestartPattern)-[app:APPLIES]->(task: Task)<-[:ACTIONS]-(th: TaskHub) + WHERE app.num_retries <> 0 + RETURN th, count(app) AS num_applied + """ + assert len(n4js.execute_query(applies_nonzero_retries).records) == 0 + + def test_remove_task_restart_patterns(self, n4js, network_tyk2, scope_test): + + # collect what we expect `get_task_restart_patterns` to return + expected_results = defaultdict(set) + + # create three new alchemical networks (and taskhubs) + taskhub_sks = [] + for network_index in range(3): + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + + f"_test_remove_task_restart_patterns_{network_index}" + ) + _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test) + taskhub_sks.append(taskhub_scoped_key) + + # test a shared pattern with and without shared number of restarts + # this will create 6 unique patterns + for network_index in range(3): + taskhub_scoped_key = taskhub_sks[network_index] + n4js.add_task_restart_patterns( + taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5 + ) + expected_results[taskhub_scoped_key].add( + ("shared_pattern_and_restarts.+", 5) + ) + + n4js.add_task_restart_patterns( + taskhub_scoped_key, + ["shared_pattern_and_different_restarts.+"], + network_index + 1, + ) + expected_results[taskhub_scoped_key].add( + ("shared_pattern_and_different_restarts.+", network_index + 1) + ) + + # remove both patterns enforcing the first taskhub at the same time, two patterns + target_taskhub = taskhub_sks[0] + target_patterns = [] + + for pattern, _ in expected_results[target_taskhub]: + target_patterns.append(pattern) + + expected_results[target_taskhub].clear() + + n4js.remove_task_restart_patterns(target_taskhub, target_patterns) + assert expected_results == n4js.get_task_restart_patterns(taskhub_sks) + + # remove both patterns enforcing the second taskhub one at a time, two patterns + target_taskhub = taskhub_sks[1] + # pointer to underlying set, pops will update comparison data structure + target_patterns = expected_results[target_taskhub] + + pattern, _ = target_patterns.pop() + n4js.remove_task_restart_patterns(target_taskhub, [pattern]) + assert expected_results == n4js.get_task_restart_patterns(taskhub_sks) + + pattern, _ = target_patterns.pop() + n4js.remove_task_restart_patterns(target_taskhub, [pattern]) + assert expected_results == n4js.get_task_restart_patterns(taskhub_sks) + + def test_set_task_restart_patterns_max_retries( + self, n4js, network_tyk2, scope_test + ): + network_name = ( + network_tyk2.name + 
"_test_set_task_restart_patterns_max_retries" + ) + an = network_tyk2.copy_with_replacements(name=network_name) + _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test) + + pattern_data = [("pattern_1", 5), ("pattern_2", 5), ("pattern_3", 5)] + + n4js.add_task_restart_patterns( + taskhub_scoped_key, + patterns=[data[0] for data in pattern_data], + number_of_retries=5, + ) + + expected_results = {taskhub_scoped_key: set(pattern_data)} + + assert expected_results == n4js.get_task_restart_patterns( + [taskhub_scoped_key] + ) + + # reflect changing just one max_retry + new_pattern_1_tuple = ("pattern_1", 1) + + expected_results[taskhub_scoped_key].remove(pattern_data[0]) + expected_results[taskhub_scoped_key].add(new_pattern_1_tuple) + + n4js.set_task_restart_patterns_max_retries( + taskhub_scoped_key, new_pattern_1_tuple[0], new_pattern_1_tuple[1] + ) + + assert expected_results == n4js.get_task_restart_patterns( + [taskhub_scoped_key] + ) + + # reflect changing more than one at a time + new_pattern_2_tuple = ("pattern_2", 2) + new_pattern_3_tuple = ("pattern_3", 2) + + expected_results[taskhub_scoped_key].remove(pattern_data[1]) + expected_results[taskhub_scoped_key].add(new_pattern_2_tuple) + + expected_results[taskhub_scoped_key].remove(pattern_data[2]) + expected_results[taskhub_scoped_key].add(new_pattern_3_tuple) + + n4js.set_task_restart_patterns_max_retries( + taskhub_scoped_key, [new_pattern_2_tuple[0], new_pattern_3_tuple[0]], 2 + ) + + assert expected_results == n4js.get_task_restart_patterns( + [taskhub_scoped_key] + ) + + def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test): + # create three new alchemical networks (and taskhubs) + taskhub_sks = [] + for network_index in range(3): + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + + f"_test_add_task_restart_patterns_{network_index}" + ) + _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test) + taskhub_sks.append(taskhub_scoped_key) + + expected_results = defaultdict(set) + # test a shared pattern with and without shared number of restarts + # this will create 6 unique patterns + for network_index in range(3): + taskhub_scoped_key = taskhub_sks[network_index] + n4js.add_task_restart_patterns( + taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5 + ) + expected_results[taskhub_scoped_key].add( + ("shared_pattern_and_restarts.+", 5) + ) + n4js.add_task_restart_patterns( + taskhub_scoped_key, + ["shared_pattern_and_different_restarts.+"], + network_index + 1, + ) + expected_results[taskhub_scoped_key].add( + ("shared_pattern_and_different_restarts.+", network_index + 1) + ) + + taskhub_grouped_patterns = n4js.get_task_restart_patterns(taskhub_sks) + + assert taskhub_grouped_patterns == expected_results + + def test_resolve_task_restarts( + self, + n4js_task_restart_policy: Neo4jStore, + ): + n4js = n4js_task_restart_policy + + # get the actioned tasks for each taskhub + taskhub_actioned_tasks = {} + for taskhub_scoped_key in n4js.query_taskhubs(): + taskhub_actioned_tasks[taskhub_scoped_key] = set( + n4js.get_taskhub_actioned_tasks([taskhub_scoped_key])[0] + ) + + restart_patterns = n4js.get_task_restart_patterns( + list(taskhub_actioned_tasks.keys()) + ) + + # create a map of the transformations and all of the tasks that perform them + transformation_tasks: dict[ScopedKey, list[ScopedKey]] = defaultdict(list) + for task in n4js.query_tasks(status=TaskStatusEnum.waiting.value): + transformation_scoped_key, _ = n4js.get_task_transformation( + task, return_gufe=False + ) 
+ transformation_tasks[transformation_scoped_key].append(task) + + # get a list of all tasks for more convient calls of the resolve method + all_tasks = [] + for task_group in transformation_tasks.values(): + all_tasks.extend(task_group) + + taskhub_scoped_key_no_policy = None + taskhub_scoped_key_with_policy = None + + # bind taskhub scoped keys to variables for convenience later + for taskhub_scoped_key, patterns in restart_patterns.items(): + if not patterns: + taskhub_scoped_key_no_policy = taskhub_scoped_key + continue + else: + taskhub_scoped_key_with_policy = taskhub_scoped_key + continue + + if patterns and taskhub_scoped_key_with_policy: + raise AssertionError("More than one TaskHub has restart patterns") + + assert ( + taskhub_scoped_key_no_policy + and taskhub_scoped_key_with_policy + and (taskhub_scoped_key_no_policy != taskhub_scoped_key_with_policy) + ) + + # we first check the behavior involving tasks that are actioned by both taskhubs + # this involves confirming: + # + # 1. Completed Tasks do not have an actions relationship with either TaskHub + # 2. A Task entering the error state is switched back to waiting if any restart patterns apply + # 3. A Task entering the error state is left in the error state if no patterns apply and only the TaskHub without + # an enforcing task restart policy actions the Task + # + # Tasks will be set to the error state with a spoofing method, which will create a fake ProtocolDAGResultRef + # and Tracebacks. This is done since making a protocol fail systematically in the testing environment is not + # obvious at this time. + + # reduce down all tasks until only the common elements between taskhubs exist + tasks_actioned_by_all_taskhubs: List[ScopedKey] = list( + reduce(operator.and_, taskhub_actioned_tasks.values()) + ) + + assert len(tasks_actioned_by_all_taskhubs) == 4 + + # we're going to just pass the first 2 and fail the second 2 + tasks_to_complete = tasks_actioned_by_all_taskhubs[:2] + tasks_to_fail = tasks_actioned_by_all_taskhubs[2:] + + complete_tasks(n4js, tasks_to_complete) + + records = n4js.execute_query( + """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {_scoped_key: task_scoped_key})-[:RESULTS_IN]->(:ProtocolDAGResultRef) + RETURN count(task) as task_count + """, + task_scoped_keys=list(map(str, tasks_to_complete)), + ).records + + assert records[0]["task_count"] == 2 + + # test the behavior of the compute API + for i, task in enumerate(tasks_to_fail): + error_messages = [ + f"Error message {repeat}, round {i}" for repeat in range(3) + ] + + fail_task( + n4js, + task, + resolve=False, + error_messages=error_messages, + ) + + n4js.resolve_task_restarts(all_tasks) + + # both tasks should have the waiting status and the APPLIES + # relationship num_retries should have incremented by 1 + query = """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {`_scoped_key`: task_scoped_key, status: $waiting})<-[:APPLIES {num_retries: 1}]-(:TaskRestartPattern {max_retries: 2}) + RETURN count(DISTINCT task) as renewed_waiting_tasks + """ + + renewed_waiting = n4js.execute_query( + query, + task_scoped_keys=list(map(str, tasks_to_fail)), + waiting=TaskStatusEnum.waiting.value, + ).records[0]["renewed_waiting_tasks"] + + assert renewed_waiting == 2 + + # we want the resolve restarts to cancel a task. 
+ # deconstruct the tasks to fail, where the first + # one will be cancelled and the second will continue to wait + task_to_cancel, task_to_wait = tasks_to_fail + + # error out the first task + for _ in range(2): + error_messages = [ + f"Error message {repeat}, round {i}" for repeat in range(3) + ] + + fail_task( + n4js, + task_to_cancel, + resolve=False, + error_messages=error_messages, + ) + + n4js.resolve_task_restarts(tasks_to_fail) + + # check that it is no longer actioned on the enforced taskhub + assert tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_cancel], + taskhub_scoped_key_with_policy, + ) + + # check that it is still actioned on the unenforced taskhub + assert not tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_cancel], + taskhub_scoped_key_no_policy, + ) + + # it should still be errored though! + assert tasks_are_errored(n4js, [task_to_cancel]) + + # fail the second task one time + error_messages = [ + f"Error message {repeat}, round {i}" for repeat in range(3) + ] + + fail_task( + n4js, + task_to_wait, + resolve=False, + error_messages=error_messages, + ) + + n4js.resolve_task_restarts(tasks_to_fail) + + # check that the waiting task is actioned on both taskhubs + assert not tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_wait], + taskhub_scoped_key_with_policy, + ) + + assert not tasks_are_not_actioned_on_taskhub( + n4js, + [task_to_wait], + taskhub_scoped_key_no_policy, + ) + + # it should be waiting + assert tasks_are_waiting(n4js, [task_to_wait]) + + @pytest.mark.xfail(raises=NotImplementedError) + def test_task_actioning_applies_relationship(self): + raise NotImplementedError + + @pytest.mark.xfail(raises=NotImplementedError) + def test_task_deaction_applies_relationship(self): + raise NotImplementedError + ### authentication @pytest.mark.parametrize( diff --git a/alchemiscale/tests/integration/storage/utils.py b/alchemiscale/tests/integration/storage/utils.py new file mode 100644 index 00000000..520e25c0 --- /dev/null +++ b/alchemiscale/tests/integration/storage/utils.py @@ -0,0 +1,103 @@ +from datetime import datetime + +from gufe.protocols import ProtocolUnitFailure + +from alchemiscale.storage.statestore import Neo4jStore +from alchemiscale import ScopedKey +from alchemiscale.storage.models import TaskStatusEnum, ProtocolDAGResultRef + + +def tasks_are_not_actioned_on_taskhub( + n4js: Neo4jStore, + task_scoped_keys: list[ScopedKey], + taskhub_scoped_key: ScopedKey, +) -> bool: + + actioned_tasks = n4js.get_taskhub_actioned_tasks([taskhub_scoped_key]) + + return set(actioned_tasks[0].keys()).isdisjoint(set(task_scoped_keys)) + + +def tasks_are_errored(n4js: Neo4jStore, task_scoped_keys: list[ScopedKey]) -> bool: + query = """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {_scoped_key: task_scoped_key, status: $error}) + RETURN task + """ + + results = n4js.execute_query( + query, + task_scoped_keys=list(map(str, task_scoped_keys)), + error=TaskStatusEnum.error.value, + ) + + return len(results.records) == len(task_scoped_keys) + + +def tasks_are_waiting(n4js: Neo4jStore, task_scoped_keys: list[ScopedKey]) -> bool: + query = """ + UNWIND $task_scoped_keys as task_scoped_key + MATCH (task:Task {_scoped_key: task_scoped_key, status: $waiting}) + RETURN task + """ + + results = n4js.execute_query( + query, + task_scoped_keys=list(map(str, task_scoped_keys)), + waiting=TaskStatusEnum.waiting.value, + ) + + return len(results.records) == len(task_scoped_keys) + + +def complete_tasks( + n4js: Neo4jStore, + tasks: list[ScopedKey], +): + 
n4js.set_task_running(tasks) + for task in tasks: + ok_pdrr = ProtocolDAGResultRef( + ok=True, + datetime_created=datetime.utcnow(), + obj_key=task.gufe_key, + scope=task.scope, + ) + + _ = n4js.set_task_result(task, ok_pdrr) + + n4js.set_task_complete(tasks) + + +def fail_task( + n4js: Neo4jStore, + task: ScopedKey, + resolve: bool = False, + error_messages: list[str] = [], +) -> None: + n4js.set_task_running([task]) + + not_ok_pdrr = ProtocolDAGResultRef( + ok=False, + datetime_created=datetime.utcnow(), + obj_key=task.gufe_key, + scope=task.scope, + ) + + protocol_unit_failures = [] + for j, message in enumerate(error_messages): + puf = ProtocolUnitFailure( + source_key=f"FakeProtocolUnitKey-123{j}", + inputs={}, + outputs={}, + exception=RuntimeError, + traceback=message, + ) + protocol_unit_failures.append(puf) + + pdrr_scoped_key = n4js.set_task_result(task, not_ok_pdrr) + + n4js.add_protocol_dag_result_ref_tracebacks(protocol_unit_failures, pdrr_scoped_key) + n4js.set_task_error([task]) + + if resolve: + n4js.resolve_task_restarts([task]) diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py index 36678b9a..4d1bd720 100644 --- a/alchemiscale/tests/unit/test_storage_models.py +++ b/alchemiscale/tests/unit/test_storage_models.py @@ -1,6 +1,11 @@ import pytest -from alchemiscale.storage.models import NetworkStateEnum, NetworkMark +from alchemiscale.storage.models import ( + NetworkStateEnum, + NetworkMark, + TaskRestartPattern, + Tracebacks, +) from alchemiscale import ScopedKey @@ -38,3 +43,168 @@ def test_suggested_states_message(self): assert len(suggested_states) == len(NetworkStateEnum) for state in suggested_states: NetworkStateEnum(state) + + +class TestTaskRestartPattern(object): + + pattern_value_error = "`pattern` must be a non-empty string" + max_retries_value_error = "`max_retries` must have a positive integer value." 
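+    # note: the two message strings above double as regular expressions for
+    # ``pytest.raises(..., match=...)`` in the tests below, which search them
+    # against the text of the raised ValueError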
+ + def test_empty_pattern(self): + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern( + "", 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + def test_non_string_pattern(self): + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern( + None, 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern( + [], 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + def test_non_positive_max_retries(self): + + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern( + "Example pattern", + 0, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern( + "Example pattern", + -1, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + + def test_non_int_max_retries(self): + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern( + "Example pattern", + 4.0, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + + def test_to_dict(self): + trp = TaskRestartPattern( + "Example pattern", + 3, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + dict_trp = trp.to_dict() + + assert len(dict_trp.keys()) == 6 + + assert dict_trp.pop("__qualname__") == "TaskRestartPattern" + assert dict_trp.pop("__module__") == "alchemiscale.storage.models" + assert ( + dict_trp.pop("taskhub_scoped_key") + == "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + # light test of the version key + try: + dict_trp.pop(":version:") + except KeyError: + raise AssertionError("expected to find :version:") + + expected = {"pattern": "Example pattern", "max_retries": 3} + + assert expected == dict_trp + + def test_from_dict(self): + + original_pattern = "Example pattern" + original_max_retries = 3 + original_taskhub_scoped_key = ( + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + trp_orig = TaskRestartPattern( + original_pattern, original_max_retries, original_taskhub_scoped_key + ) + trp_dict = trp_orig.to_dict() + trp_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(trp_dict) + + assert trp_reconstructed.pattern == original_pattern + assert trp_reconstructed.max_retries == original_max_retries + assert trp_reconstructed.taskhub_scoped_key == original_taskhub_scoped_key + + assert trp_orig is trp_reconstructed + + +class TestTracebacks(object): + + valid_entry = ["traceback1", "traceback2", "traceback3"] + source_keys = ["ProtocolUnit-ABC123", "ProtocolUnit-DEF456", "ProtocolUnit-GHI789"] + failure_keys = [ + "ProtocolUnitFailure-ABC123", + "ProtocolUnitFailure-DEF456", + "ProtocolUnitFailure-GHI789", + ] + tracebacks_value_error = ( + "`tracebacks` must be a non-empty list of non-empty string values" + ) + + def test_empty_string_element(self): + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Tracebacks(self.valid_entry + [""], self.source_keys, self.failure_keys) + + def test_non_list_parameter(self): + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Tracebacks(None, self.source_keys, self.failure_keys) + + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Tracebacks(100, self.source_keys, self.failure_keys) + + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Tracebacks( + "not a list, but still an iterable 
that yields strings",
+                self.source_keys,
+                self.failure_keys,
+            )
+
+    def test_list_non_string_elements(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Tracebacks(self.valid_entry + [None], self.source_keys, self.failure_keys)
+
+    def test_empty_list(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Tracebacks([], self.source_keys, self.failure_keys)
+
+    def test_to_dict(self):
+        tb = Tracebacks(self.valid_entry, self.source_keys, self.failure_keys)
+        tb_dict = tb.to_dict()
+
+        assert len(tb_dict) == 6
+
+        assert tb_dict.pop("__qualname__") == "Tracebacks"
+        assert tb_dict.pop("__module__") == "alchemiscale.storage.models"
+
+        # light test of the version key
+        try:
+            tb_dict.pop(":version:")
+        except KeyError:
+            raise AssertionError("expected to find :version:")
+
+        expected = {
+            "tracebacks": self.valid_entry,
+            "source_keys": self.source_keys,
+            "failure_keys": self.failure_keys,
+        }
+
+        assert expected == tb_dict
+
+    def test_from_dict(self):
+        tb_orig = Tracebacks(self.valid_entry, self.source_keys, self.failure_keys)
+        tb_dict = tb_orig.to_dict()
+        tb_reconstructed: Tracebacks = Tracebacks.from_dict(tb_dict)
+
+        assert tb_reconstructed.tracebacks == self.valid_entry
+        assert tb_orig is tb_reconstructed
diff --git a/docs/user_guide.rst b/docs/user_guide.rst
index dec60bf7..24b2642a 100644
--- a/docs/user_guide.rst
+++ b/docs/user_guide.rst
@@ -510,6 +510,48 @@ If you’re feeling confident, you could set all errored :py:class:`~alchemiscal
 , ...]
 
+***************************************************
+Re-running Errored Tasks with Task Restart Patterns
+***************************************************
+
+Re-running errored :py:class:`~alchemiscale.storage.models.Task`\s manually for known failure modes (such as those described in the previous section) quickly becomes tedious, especially for large networks.
+Alternatively, you can add `regular expression `_ strings as :py:class:`~alchemiscale.storage.models.Task` restart patterns to an :external+gufe:py:class:`~gufe.network.AlchemicalNetwork`.
+:py:class:`~alchemiscale.storage.models.Task`\s actioned on that :external+gufe:py:class:`~gufe.network.AlchemicalNetwork` will be automatically restarted if the :py:class:`~alchemiscale.storage.models.Task` fails during any part of its execution, provided that an enforcing pattern matches a traceback within the :py:class:`~alchemiscale.storage.models.Task`\'s failed :external+gufe:py:class:`~gufe.protocols.ProtocolDAGResult`.
+The number of restarts is controlled by the ``num_allowed_restarts`` parameter of the :py:meth:`~alchemiscale.interface.client.AlchemiscaleClient.add_task_restart_patterns` method.
+If a :py:class:`~alchemiscale.storage.models.Task` is restarted more than ``num_allowed_restarts`` times, the :py:class:`~alchemiscale.storage.models.Task` is canceled on that :external+gufe:py:class:`~gufe.network.AlchemicalNetwork` and left in an ``error`` status.
+
+As an example, if you wanted to rerun any :py:class:`~alchemiscale.storage.models.Task` that failed with a ``RuntimeError`` or a ``MemoryError`` and restart it up to 5 times, you could add the following patterns::
+
+    >>> asc.add_task_restart_patterns(an_sk, [r"RuntimeError: .+", r"MemoryError: Unable to allocate \d+ GiB"], 5)
+
+If you provide too general a pattern, such as the example above, you may consume compute resources on failures that are unavoidable.
+On the other hand, an overly strict pattern (such as one specifying explicit ``gufe`` keys) will likely never match anything.
+Therefore, it is best to find a balance in your patterns that matches your use case.
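+
+Before enforcing a new pattern network-wide, it can be worth checking that it actually matches a traceback you have already observed.
+A minimal sketch of one way to do this, assuming ``task_sk`` holds the ``ScopedKey`` of a previously errored :py:class:`~alchemiscale.storage.models.Task`, uses :py:meth:`~alchemiscale.interface.client.AlchemiscaleClient.get_task_failures` to pull the failed :external+gufe:py:class:`~gufe.protocols.ProtocolDAGResult`\s and searches their tracebacks directly::
+
+    >>> import re
+    >>> pattern = re.compile(r"MemoryError: Unable to allocate \d+ GiB")
+    >>> failed_pdrs = asc.get_task_failures(task_sk)
+    >>> any(pattern.search(puf.traceback)
+    ...     for pdr in failed_pdrs
+    ...     for puf in pdr.protocol_unit_failures)
+    True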
+
+Restart patterns enforced on an :external+gufe:py:class:`~gufe.network.AlchemicalNetwork` can be retrieved with::
+
+    >>> asc.get_task_restart_patterns(an_sk)
+    {"RuntimeError: .+": 5, "MemoryError: Unable to allocate \d+ GiB": 5}
+
+The number of allowed restarts can also be modified::
+
+    >>> asc.set_task_restart_patterns_allowed_restarts(an_sk, [r"RuntimeError: .+"], 3)
+    >>> asc.set_task_restart_patterns_allowed_restarts(an_sk, [r"MemoryError: Unable to allocate \d+ GiB"], 2)
+    >>> asc.get_task_restart_patterns(an_sk)
+    {"RuntimeError: .+": 3, "MemoryError: Unable to allocate \d+ GiB": 2}
+
+Patterns can be removed by specifying them in a list::
+
+    >>> asc.remove_task_restart_patterns(an_sk, [r"MemoryError: Unable to allocate \d+ GiB"])
+    >>> asc.get_task_restart_patterns(an_sk)
+    {"RuntimeError: .+": 3}
+
+Or by clearing all enforcing patterns::
+
+    >>> asc.clear_task_restart_patterns(an_sk)
+    >>> asc.get_task_restart_patterns(an_sk)
+    {}
+
 ***********************************
 Marking Tasks as deleted or invalid