diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 40ce637..f585e42 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -140,44 +140,43 @@ def split_streaming( secondary_backends = {b for b in backend_usage if b != primary_backend} _log.info(f"Backend split: {primary_backend=} {secondary_backends=}") - primary_id = main_subgraph_id - primary_pg = {} primary_has_load_collection = False - primary_dependencies = [] - + sub_graphs: List[Tuple[NodeId, Set[NodeId], BackendId]] = [] for node_id, node in process_graph.items(): if node["process_id"] == "load_collection": bid = backend_per_collection[node["arguments"]["id"]] if bid == primary_backend and (not self._always_split or not primary_has_load_collection): - # Add to primary pg - primary_pg[node_id] = node primary_has_load_collection = True else: - # New secondary pg - sub_id = f"{bid}:{node_id}" - sub_pg = { - node_id: node, - "sr1": { - # TODO: other/better choices for save_result format (e.g. based on backend support)? - "process_id": "save_result", - "arguments": { - "data": {"from_node": node_id}, - # TODO: particular format options? - # "format": "NetCDF", - "format": "GTiff", - }, - "result": True, - }, - } + sub_graphs.append((node_id, {node_id}, bid)) - yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), []) + primary_graph_node_ids = set(process_graph.keys()).difference(n for _, ns, _ in sub_graphs for n in ns) + primary_pg = {k: process_graph[k] for k in primary_graph_node_ids} + primary_dependencies = [] - # Link secondary pg into primary pg - primary_pg.update(get_replacement(node_id=node_id, node=node, subgraph_id=sub_id)) - primary_dependencies.append(sub_id) - else: - primary_pg[node_id] = node + for node_id, subgraph_node_ids, backend_id in sub_graphs: + # New secondary pg + sub_id = f"{backend_id}:{node_id}" + sub_pg = {k: v for k, v in process_graph.items() if k in subgraph_node_ids} + # Add new `save_result` node to the subgraphs + sub_pg["_agg_crossbackend_save_result"] = { + # TODO: other/better choices for save_result format (e.g. based on backend support, cube type)? + "process_id": "save_result", + "arguments": { + "data": {"from_node": node_id}, + # TODO: particular format options? + # "format": "NetCDF", + "format": "GTiff", + }, + "result": True, + } + yield (sub_id, SubJob(process_graph=sub_pg, backend_id=backend_id), []) + # Link secondary pg into primary pg + primary_pg.update(get_replacement(node_id=node_id, node=process_graph[node_id], subgraph_id=sub_id)) + primary_dependencies.append(sub_id) + + primary_id = main_subgraph_id yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies) def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: @@ -427,6 +426,7 @@ class _FrozenGraph: """ # TODO: find better class name: e.g. SplitGraphView, GraphSplitUtility, GraphSplitter, ...? + # TODO: add more logging of what is happening under the hood def __init__(self, graph: dict[NodeId, _FrozenNode]): # Work with a read-only proxy to prevent accidental changes diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index f2c8c85..d513068 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -50,7 +50,7 @@ def test_split_basic(self): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -77,7 +77,7 @@ def test_split_basic(self): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -91,7 +91,7 @@ def test_split_basic(self): "process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, @@ -113,7 +113,7 @@ def test_split_streaming_basic(self): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -132,7 +132,7 @@ def test_split_streaming_basic(self): "process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, @@ -152,7 +152,7 @@ def test_split_streaming_basic(self): "process_id": "merge_cubes", "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -204,7 +204,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict: SubJob( process_graph={ "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}}, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, @@ -219,7 +219,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict: SubJob( process_graph={ "lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}}, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc3"}, "format": "GTiff"}, "result": True, @@ -369,7 +369,7 @@ def test_basic(self, aggregator: _FakeAggregator): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -404,7 +404,7 @@ def test_basic(self, aggregator: _FakeAggregator): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -415,7 +415,7 @@ def test_basic(self, aggregator: _FakeAggregator): "process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True,