From 81ff0e2cc634274e72a2befbdc4c054d677759b9 Mon Sep 17 00:00:00 2001 From: gwm17 Date: Tue, 17 Dec 2024 11:48:42 -0500 Subject: [PATCH] Add in the dragon shared memory hack. Make it so we don't deadlock --- src/spyral/core/dragon_start.py | 10 +++++++++- src/spyral/phases/interp_solver_phase.py | 18 +++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/spyral/core/dragon_start.py b/src/spyral/core/dragon_start.py index 73389f5..3bbd313 100644 --- a/src/spyral/core/dragon_start.py +++ b/src/spyral/core/dragon_start.py @@ -25,7 +25,12 @@ def _run_shared_memory_manager( and indicate that all shared memory should be cleaned up. """ handles = {} - pipeline.create_shared_data(handles) + try: + pipeline.create_shared_data(handles) + except Exception: + # Hack to make sure we don't hang when creation fails + ready.set() + return ready.set() shutdown.wait() for handle in handles.values(): @@ -50,6 +55,9 @@ def start_pipeline_dragon( worker_cpus = total_cpus - n_nodes - 1 print(SPLASH) + print(f"Total cpus available: {total_cpus}") + print(f"Number of nodes: {n_nodes}") + print(f"Calculated worker cpus: {worker_cpus}") print(f"Creating workspace: {pipeline.workspace} ...", end=" ") pipeline.create_workspace() print("Done.") diff --git a/src/spyral/phases/interp_solver_phase.py b/src/spyral/phases/interp_solver_phase.py index 8a1fd7f..c68ac75 100644 --- a/src/spyral/phases/interp_solver_phase.py +++ b/src/spyral/phases/interp_solver_phase.py @@ -173,9 +173,21 @@ def create_shared_data( # Create a block of shared memory of the same total size as the mesh # Note that we don't have a lock on the shared memory as the mesh is # used read-only - handle = SharedMemory( - name=self.shared_mesh_name, create=True, size=mesh_data.nbytes - ) + # Dragon hack: Dragon does not currently implement SharedMemory, so for now + # we fall back to the base multiprocessing impl when we receive a + # NotImplementedError + handle = None + try: + handle = SharedMemory( + name=self.shared_mesh_name, create=True, size=mesh_data.nbytes + ) + except NotImplementedError: + # Hacky + MultiprocessingSharedMemory = SharedMemory.__bases__[0] + handle = MultiprocessingSharedMemory( + name=self.shared_mesh_name, create=True, size=mesh_data.nbytes + ) + handles[handle.name] = handle spyral_info( __name__,