Skip to content

Commit

Permalink
Add in the dragon shared memory hack. Make it so we don't deadlock
Browse files Browse the repository at this point in the history
  • Loading branch information
gwm17 committed Dec 17, 2024
1 parent bd9af4a commit 81ff0e2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
10 changes: 9 additions & 1 deletion src/spyral/core/dragon_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ def _run_shared_memory_manager(
and indicate that all shared memory should be cleaned up.
"""
handles = {}
pipeline.create_shared_data(handles)
try:
pipeline.create_shared_data(handles)
except Exception:
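# Hack: still signal `ready` when creation fails so that anything waiting on it does not hang.
# NOTE(review): the exception is discarded here without being logged or re-raised.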
ready.set()
return
ready.set()
shutdown.wait()
for handle in handles.values():
Expand All @@ -50,6 +55,9 @@ def start_pipeline_dragon(
worker_cpus = total_cpus - n_nodes - 1

print(SPLASH)
print(f"Total cpus available: {total_cpus}")
print(f"Number of nodes: {n_nodes}")
print(f"Calculated worker cpus: {worker_cpus}")
print(f"Creating workspace: {pipeline.workspace} ...", end=" ")
pipeline.create_workspace()
print("Done.")
Expand Down
18 changes: 15 additions & 3 deletions src/spyral/phases/interp_solver_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,21 @@ def create_shared_data(
# Create a block of shared memory of the same total size as the mesh
# Note that we don't have a lock on the shared memory as the mesh is
# used read-only
handle = SharedMemory(
name=self.shared_mesh_name, create=True, size=mesh_data.nbytes
)
# Dragon hack: Dragon does not currently implement SharedMemory, so for now
# we fall back to the base multiprocessing impl when we receive a
# NotImplementedError
handle = None
try:
handle = SharedMemory(
name=self.shared_mesh_name, create=True, size=mesh_data.nbytes
)
except NotImplementedError:
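# Hacky: assumes the first base class of the imported SharedMemory is the
# standard multiprocessing implementation (TODO confirm against Dragon's class hierarchy)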
MultiprocessingSharedMemory = SharedMemory.__bases__[0]
handle = MultiprocessingSharedMemory(
name=self.shared_mesh_name, create=True, size=mesh_data.nbytes
)

handles[handle.name] = handle
spyral_info(
__name__,
Expand Down

0 comments on commit 81ff0e2

Please sign in to comment.