Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some memory debugging info #101

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,4 @@ _version.py
*.gif
*.zarr/
/*-plots/
/definitions*/
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface
- Add memory debugging code

### Changed
- Change `write_initial_state` default value to `true`
Expand Down
13 changes: 13 additions & 0 deletions src/anemoi/inference/commands/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def add_arguments(self, command_parser):
command_parser.add_argument("--output", type=str, help="Output file")
command_parser.add_argument("--staging-dates", type=str, help="Path to a file with staging dates")
command_parser.add_argument("--extra", action="append", help="Additional request values. Can be repeated")
command_parser.add_argument("--retrieve-fields-type", type=str, help="Type of fields to retrieve")
command_parser.add_argument("overrides", nargs="*", help="Overrides.")

def run(self, args):
Expand All @@ -45,6 +46,18 @@ def run(self, args):
area = runner.checkpoint.area
grid = runner.checkpoint.grid

if args.retrieve_fields_type is not None:
selected = set()

for name, kinds in runner.checkpoint.variable_categories().items():
if "computed" in kinds:
continue
for kind in kinds:
if args.retrieve_fields_type.startswith(kind): # PrepML adds an 's' to the type
selected.add(name)

variables = sorted(selected)

extra = postproc(grid, area)

for r in args.extra or []:
Expand Down
6 changes: 6 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class Config:
development_hacks: dict = {}
"""A dictionary of development hacks to apply to the runner. This is used to test new features or to work around"""

debugging_info: dict = {}
"""A dictionary to store debug information. This is ignored."""

memory_debugging: bool = False
"""If True, the runner will print memory debugging information."""


def load_config(path, overrides, defaults=None, Configuration=Configuration):

Expand Down
5 changes: 2 additions & 3 deletions src/anemoi/inference/grib/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,10 @@ def grib_keys(
result.update(grib2_keys.get(param, {}))

result.setdefault("type", "fc")
type = result.get("type")

if type is not None:
if result.get("type") in ("an", "fc"):
# For organisations that do not use type
result.setdefault("dataType", type)
result.setdefault("dataType", result.pop("type"))

# if stream is not None:
# result.setdefault("stream", stream)
Expand Down
86 changes: 86 additions & 0 deletions src/anemoi/inference/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import gc
import logging
import os

import torch
from anemoi.utils.humanize import bytes_to_human

LOG = logging.getLogger(__name__)


class Collector:
"""Collects memory usage information."""

def __init__(self):
self.known = set()
self.last_total = 0
self.last_title = "START"
self.n = 0
self.plot = int(os.environ.get("MEMORY_PLOT", "0"))

def __call__(self, title):
gc.collect()
total = 0
tensors = set()
newobj = []

for obj in gc.get_objects():
if obj.__class__.__name__ == "SourceMaker":
# Remove some warnings from earthkit-data
continue
try:
if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
if id(obj) not in self.known:
newobj.append(obj)
tensors.add(id(obj))
total += obj.element_size() * obj.nelement()

except Exception:
pass

added = tensors - self.known
removed = self.known - tensors

if newobj and self.plot:
import objgraph

gc.collect()
objgraph.show_backrefs(newobj, filename=f"backref-{self.n:04d}.pdf")
self.n += 1

delta = total - self.last_total
if delta < 0:
what = "decrease"
delta = -delta
else:
what = "increase"

LOG.info(
"[%s] => [%s] memory %s %s (memory=%s, tensors=%s, added=%s, removed=%s), allocated=%s, reserved=%s",
self.last_title,
title,
what,
bytes_to_human(delta),
bytes_to_human(total),
len(tensors),
len(added),
len(removed),
bytes_to_human(torch.cuda.memory_allocated(0)),
bytes_to_human(torch.cuda.memory_reserved(0)),
)

self.last_total = total
self.last_title = title
self.known = tensors

free, total = torch.cuda.mem_get_info("cuda:0")
LOG.info(f"Pytorch mem_get_info says 'cuda:0' is using {bytes_to_human(total - free)} of memory.")
61 changes: 56 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,58 @@
LOG = logging.getLogger(__name__)


class HindcastOutput:

def __init__(self, reference_year):
self.reference_year = reference_year

def __call__(self, values, template, keys):

if "date" not in keys:
assert template.metadata("hdate", default=None) is None, template
date = template.metadata("date")
else:
date = keys.pop("date")

for k in ("date", "hdate"):
keys.pop(k, None)

keys["edition"] = 1
keys["localDefinitionNumber"] = 30
keys["dataDate"] = int(to_datetime(date).strftime("%Y%m%d"))
keys["referenceDate"] = int(to_datetime(date).replace(year=self.reference_year).strftime("%Y%m%d"))

return values, template, keys


MODIFIERS = dict(hindcast=HindcastOutput)


def modifier_factory(modifiers):

if modifiers is None:
return []

if not isinstance(modifiers, list):
modifiers = [modifiers]

result = []
for modifier in modifiers:
assert isinstance(modifier, dict), modifier
assert len(modifier) == 1, modifier

klass = list(modifier.keys())[0]
result.append(MODIFIERS[klass](**modifier[klass]))

return result


class GribOutput(Output):
"""
Handles grib
"""

def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None):
def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None):
super().__init__(context)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
Expand All @@ -40,6 +86,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self._template_date = None
self._template_reuse = None
self.use_closest_template = False # Off for now
self.modifiers = modifier_factory(modifiers)

def write_initial_state(self, state):
# We trust the GribInput class to provide the templates
Expand Down Expand Up @@ -76,7 +123,8 @@ def write_initial_state(self, state):
quiet=self.quiet,
)

# LOG.info("Step 0 GRIB %s\n%s", template, json.dumps(keys, indent=4))
for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

self.write_message(values, template=template, **keys)

Expand All @@ -95,7 +143,7 @@ def write_state(self, state):
self.quiet.add("_grib_templates_for_output")
LOG.warning("Input is not GRIB.")

for name, value in state["fields"].items():
for name, values in state["fields"].items():
keys = {}

variable = self.typed_variables[name]
Expand All @@ -118,7 +166,7 @@ def write_state(self, state):
keys.update(self.encoding)

keys = grib_keys(
values=value,
values=values,
template=template,
date=reference_date.strftime("%Y-%m-%d"),
time=reference_date.hour,
Expand All @@ -131,11 +179,14 @@ def write_state(self, state):
quiet=self.quiet,
)

for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

if LOG.isEnabledFor(logging.DEBUG):
LOG.info("Encoding GRIB %s\n%s", template, json.dumps(keys, indent=4))

try:
self.write_message(value, template=template, **keys)
self.write_message(values, template=template, **keys)
except Exception:
LOG.error("Error writing field %s", name)
LOG.error("Template: %s", template)
Expand Down
74 changes: 72 additions & 2 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@

LOG = logging.getLogger(__name__)

# There is a bug with hindcasts, where these keys are not added to the 'mars' namespace
MARS_MAYBE_MISSING_KEYS = (
"number",
"step",
"time",
"date",
"hdate",
"type",
"stream",
"expver",
"class",
"levtype",
"levelist",
"param",
)


def _is_valid(mars, keys):
if "number" in keys and "number" not in mars:
LOG.warning("`number` is missing from mars namespace")
return False

if "referenceDate" in keys and "hdate" not in mars:
LOG.warning("`hdate` is missing from mars namespace")
return False

return True


class ArchiveCollector:
"""Collects archive requests"""
Expand Down Expand Up @@ -61,19 +89,42 @@ def __init__(
templates=None,
grib1_keys=None,
grib2_keys=None,
modifiers=None,
**kwargs,
):
super().__init__(context, encoding=encoding, templates=templates, grib1_keys=grib1_keys, grib2_keys=grib2_keys)
super().__init__(
context,
encoding=encoding,
templates=templates,
grib1_keys=grib1_keys,
grib2_keys=grib2_keys,
modifiers=modifiers,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
self.archiving = defaultdict(ArchiveCollector)
self.archive_requests = archive_requests
self.check_encoding = check_encoding
self._namespace_bug_fix = False

def __repr__(self):
return f"GribFileOutput({self.path})"

def write_message(self, message, template, **keys):
# Make sure `name` is not in the keys, otherwise grib_encoding will fail
if template is not None and template.metadata("name", default=None) is not None:
# We cannot clear the metadata...
class Dummy:
def __init__(self, template):
self.template = template
self.handle = template.handle

def __repr__(self):
return f"Dummy({self.template})"

template = Dummy(template)

# LOG.info("Writing message to %s %s", template, keys)
try:
self.collect_archive_requests(
self.output.write(
Expand All @@ -90,6 +141,7 @@ def write_message(self, message, template, **keys):

LOG.error("Error writing message to %s", self.path)
LOG.error("eccodes: %s", eccodes.__version__)
LOG.error("Template: %s, Keys: %s", template, keys)
LOG.error("Exception: %s", e)
if message is not None and np.isnan(message.data).any():
LOG.error("Message contains NaNs (%s, %s) (allow_nans=%s)", keys, template, self.context.allow_nans)
Expand All @@ -102,7 +154,25 @@ def collect_archive_requests(self, written, template, **keys):

handle, path = written

mars = handle.as_namespace("mars")
while True:

if self._namespace_bug_fix:
import eccodes
from earthkit.data.readers.grib.codes import GribCodesHandle

handle = GribCodesHandle(eccodes.codes_clone(handle._handle), None, None)

mars = {k: v for k, v in handle.items("mars")}

if _is_valid(mars, keys):
break

if self._namespace_bug_fix:
raise ValueError("Namespace bug: %s" % mars)

# Try again with the namespace bug
LOG.warning("Namespace bug detected, trying again")
self._namespace_bug_fix = True

if self.check_encoding:
check_encoding(handle, keys)
Expand Down
Loading
Loading