Skip to content

Commit

Permalink
9k parallel inference works
Browse files Browse the repository at this point in the history
  • Loading branch information
cathalobrien committed Jan 15, 2025
2 parents 71fdf0e + b95e167 commit 9264754
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 8 deletions.
52 changes: 52 additions & 0 deletions src/anemoi/inference/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
import subprocess
import socket

import logging

LOG = logging.getLogger(__name__)

def getParallelLogger():
    """Return the module logger, silenced on every rank except global rank 0.

    The global rank is read from the ``SLURM_PROCID`` environment variable,
    defaulting to 0 (i.e. serial / single-process execution) when unset.

    Returns:
        logging.Logger: the module logger; on non-zero ranks its level is
        raised so that no records are emitted.
    """
    global_rank = int(os.environ.get("SLURM_PROCID", 0))
    logger = logging.getLogger(__name__)

    if global_rank != 0:
        # NOTE: logging.NOTSET would make the logger *inherit* its parent's
        # effective level (WARNING by default) rather than silence it, so we
        # use a level strictly above CRITICAL to drop all records on
        # non-zero ranks.
        logger.setLevel(logging.CRITICAL + 1)

    return logger

def init_network():
    """Set ``MASTER_ADDR`` and ``MASTER_PORT`` for torch.distributed rendezvous.

    The master address is the first node of ``SLURM_NODELIST`` (expanded via
    ``scontrol show hostname``), resolved to an IP address. The port is
    derived deterministically from the last four digits of ``SLURM_JOBID`` so
    that every rank of the same job computes the same value while different
    jobs on the same node are unlikely to collide.

    Raises:
        ValueError: if ``SLURM_NODELIST`` or ``SLURM_JOBID`` is unset, or if
            the master hostname cannot be resolved.
        subprocess.CalledProcessError: if ``scontrol`` exits non-zero.
    """
    # Get the master address from the SLURM_NODELIST environment variable
    slurm_nodelist = os.environ.get("SLURM_NODELIST")
    if not slurm_nodelist:
        raise ValueError("SLURM_NODELIST environment variable is not set.")

    # Expand the (possibly compressed, e.g. "node[1-4]") nodelist and take
    # the first hostname as the master
    result = subprocess.run(
        ["scontrol", "show", "hostname", slurm_nodelist],
        stdout=subprocess.PIPE,
        text=True,
        check=True,
    )
    master_addr = result.stdout.splitlines()[0]

    # Resolve the master hostname to an IP address
    try:
        resolved_addr = socket.gethostbyname(master_addr)
    except socket.gaierror as err:
        # Chain the resolver error so the root cause is kept in the traceback
        raise ValueError(f"Could not resolve hostname: {master_addr}") from err

    # Set the resolved address as MASTER_ADDR
    os.environ["MASTER_ADDR"] = resolved_addr

    # Calculate the MASTER_PORT using SLURM_JOBID
    slurm_jobid = os.environ.get("SLURM_JOBID")
    if not slurm_jobid:
        raise ValueError("SLURM_JOBID environment variable is not set.")

    master_port = 10000 + int(slurm_jobid[-4:])
    os.environ["MASTER_PORT"] = str(master_port)

    # Log the chosen rendezvous endpoint for debugging
    LOG.debug("MASTER_ADDR: %s", os.environ["MASTER_ADDR"])
    LOG.debug("MASTER_PORT: %s", os.environ["MASTER_PORT"])
25 changes: 17 additions & 8 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import os
import warnings
import random
from functools import cached_property

import numpy as np
Expand All @@ -26,6 +27,7 @@
from .postprocess import Accumulator
from .postprocess import Noop
from .precisions import PRECISIONS
from .parallel import init_network

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -252,25 +254,32 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
os.environ.get("SLURM_PROCID", 0)
) # Get rank of the current process, equivalent to dist.get_rank()
world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes
if local_rank == 0:
LOG.info("World size: %d", world_size)

# Create pytorch input tensor
input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)

LOG.info("Using autocast %s", self.autocast)

lead_time = to_timedelta(lead_time)
steps = lead_time // self.checkpoint.timestep

LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)
if global_rank == 0:
LOG.info("World size: %d", world_size)
LOG.info("Using autocast %s", self.autocast)
LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)

result = input_state.copy() # We should not modify the input state
result["fields"] = dict()

start = input_state["date"]

if world_size > 1:

#only rank 0 logs
if (local_rank != 0):
LOG.handlers.clear()

init_network()


dist.init_process_group(
backend="nccl",
init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}',
Expand Down Expand Up @@ -304,7 +313,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
for s in range(steps):
step = (s + 1) * self.checkpoint.timestep
date = start + step
if local_rank == 0:
if global_rank == 0:
LOG.info("Forecasting step %s (%s)", step, date)

result["date"] = date
Expand All @@ -315,7 +324,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
# y_pred = self.model.forward(input_tensor_torch, model_comm_group)
y_pred = self.model.predict_step(input_tensor_torch, model_comm_group)

if local_rank == 0:
if global_rank == 0:
# Detach tensor and squeeze (should we detach here?)
output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables)

Expand Down Expand Up @@ -361,7 +370,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_input_tensor("Next input tensor", input_tensor_torch)

dist.destroy_process_group()
#dist.destroy_process_group()

def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check):

Expand Down

0 comments on commit 9264754

Please sign in to comment.