Skip to content

Commit

Permalink
Better prepml support (#89)
Browse files Browse the repository at this point in the history
* Better prepml support

* add support for hindcasts

* tmpfix for missing "number" in hindcasts

* Add more keys to maybe missing mars keys

* address mars namespace bug

---------

Co-authored-by: Gert Mertes <[email protected]>
  • Loading branch information
b8raoult and gmertes authored Jan 13, 2025
1 parent 92eb869 commit 05d5e6f
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,4 @@ _version.py
*.gif
*.zarr/
/*-plots/
/definitions*/
13 changes: 13 additions & 0 deletions src/anemoi/inference/commands/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def add_arguments(self, command_parser):
command_parser.add_argument("--output", type=str, help="Output file")
command_parser.add_argument("--staging-dates", type=str, help="Path to a file with staging dates")
command_parser.add_argument("--extra", action="append", help="Additional request values. Can be repeated")
command_parser.add_argument("--retrieve-fields-type", type=str, help="Type of fields to retrieve")
command_parser.add_argument("overrides", nargs="*", help="Overrides.")

def run(self, args):
Expand All @@ -45,6 +46,18 @@ def run(self, args):
area = runner.checkpoint.area
grid = runner.checkpoint.grid

if args.retrieve_fields_type is not None:
selected = set()

for name, kinds in runner.checkpoint.variable_categories().items():
if "computed" in kinds:
continue
for kind in kinds:
if args.retrieve_fields_type.startswith(kind): # PrepML adds an 's' to the type
selected.add(name)

variables = sorted(selected)

extra = postproc(grid, area)

for r in args.extra or []:
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class Config:
development_hacks: dict = {}
"""A dictionary of development hacks to apply to the runner. This is used to test new features or to work around"""

debugging_info: dict = {}
"""A dictionary to store debug information. This is ignored."""


def load_config(path, overrides, defaults=None, Configuration=Configuration):

Expand Down
5 changes: 2 additions & 3 deletions src/anemoi/inference/grib/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,10 @@ def grib_keys(
result.update(grib2_keys.get(param, {}))

result.setdefault("type", "fc")
type = result.get("type")

if type is not None:
if result.get("type") in ("an", "fc"):
# For organisations that do not use type
result.setdefault("dataType", type)
result.setdefault("dataType", result.pop("type"))

# if stream is not None:
# result.setdefault("stream", stream)
Expand Down
61 changes: 56 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,58 @@
LOG = logging.getLogger(__name__)


class HindcastOutput:

def __init__(self, reference_year):
self.reference_year = reference_year

def __call__(self, values, template, keys):

if "date" not in keys:
assert template.metadata("hdate", default=None) is None, template
date = template.metadata("date")
else:
date = keys.pop("date")

for k in ("date", "hdate"):
keys.pop(k, None)

keys["edition"] = 1
keys["localDefinitionNumber"] = 30
keys["dataDate"] = int(to_datetime(date).strftime("%Y%m%d"))
keys["referenceDate"] = int(to_datetime(date).replace(year=self.reference_year).strftime("%Y%m%d"))

return values, template, keys


MODIFIERS = dict(hindcast=HindcastOutput)


def modifier_factory(modifiers):

if modifiers is None:
return []

if not isinstance(modifiers, list):
modifiers = [modifiers]

result = []
for modifier in modifiers:
assert isinstance(modifier, dict), modifier
assert len(modifier) == 1, modifier

klass = list(modifier.keys())[0]
result.append(MODIFIERS[klass](**modifier[klass]))

return result


class GribOutput(Output):
"""
Handles grib
"""

def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None):
def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None):
super().__init__(context)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
Expand All @@ -40,6 +86,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self._template_date = None
self._template_reuse = None
self.use_closest_template = False # Off for now
self.modifiers = modifier_factory(modifiers)

def write_initial_state(self, state):
# We trust the GribInput class to provide the templates
Expand Down Expand Up @@ -76,7 +123,8 @@ def write_initial_state(self, state):
quiet=self.quiet,
)

# LOG.info("Step 0 GRIB %s\n%s", template, json.dumps(keys, indent=4))
for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

self.write_message(values, template=template, **keys)

Expand All @@ -95,7 +143,7 @@ def write_state(self, state):
self.quiet.add("_grib_templates_for_output")
LOG.warning("Input is not GRIB.")

for name, value in state["fields"].items():
for name, values in state["fields"].items():
keys = {}

variable = self.typed_variables[name]
Expand All @@ -118,7 +166,7 @@ def write_state(self, state):
keys.update(self.encoding)

keys = grib_keys(
values=value,
values=values,
template=template,
date=reference_date.strftime("%Y-%m-%d"),
time=reference_date.hour,
Expand All @@ -131,11 +179,14 @@ def write_state(self, state):
quiet=self.quiet,
)

for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

if LOG.isEnabledFor(logging.DEBUG):
LOG.info("Encoding GRIB %s\n%s", template, json.dumps(keys, indent=4))

try:
self.write_message(value, template=template, **keys)
self.write_message(values, template=template, **keys)
except Exception:
LOG.error("Error writing field %s", name)
LOG.error("Template: %s", template)
Expand Down
74 changes: 72 additions & 2 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@

LOG = logging.getLogger(__name__)

# There is a bug with hindcasts, where these keys are not added to the 'mars' namespace
MARS_MAYBE_MISSING_KEYS = (
"number",
"step",
"time",
"date",
"hdate",
"type",
"stream",
"expver",
"class",
"levtype",
"levelist",
"param",
)


def _is_valid(mars, keys):
if "number" in keys and "number" not in mars:
LOG.warning("`number` is missing from mars namespace")
return False

if "referenceDate" in keys and "hdate" not in mars:
LOG.warning("`hdate` is missing from mars namespace")
return False

return True


class ArchiveCollector:
"""Collects archive requests"""
Expand Down Expand Up @@ -61,19 +89,42 @@ def __init__(
templates=None,
grib1_keys=None,
grib2_keys=None,
modifiers=None,
**kwargs,
):
super().__init__(context, encoding=encoding, templates=templates, grib1_keys=grib1_keys, grib2_keys=grib2_keys)
super().__init__(
context,
encoding=encoding,
templates=templates,
grib1_keys=grib1_keys,
grib2_keys=grib2_keys,
modifiers=modifiers,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
self.archiving = defaultdict(ArchiveCollector)
self.archive_requests = archive_requests
self.check_encoding = check_encoding
self._namespace_bug_fix = False

def __repr__(self):
return f"GribFileOutput({self.path})"

def write_message(self, message, template, **keys):
# Make sure `name` is not in the keys, otherwise grib_encoding will fail
if template is not None and template.metadata("name", default=None) is not None:
# We cannot clear the metadata...
class Dummy:
def __init__(self, template):
self.template = template
self.handle = template.handle

def __repr__(self):
return f"Dummy({self.template})"

template = Dummy(template)

# LOG.info("Writing message to %s %s", template, keys)
try:
self.collect_archive_requests(
self.output.write(
Expand All @@ -90,6 +141,7 @@ def write_message(self, message, template, **keys):

LOG.error("Error writing message to %s", self.path)
LOG.error("eccodes: %s", eccodes.__version__)
LOG.error("Template: %s, Keys: %s", template, keys)
LOG.error("Exception: %s", e)
if message is not None and np.isnan(message.data).any():
LOG.error("Message contains NaNs (%s, %s) (allow_nans=%s)", keys, template, self.context.allow_nans)
Expand All @@ -102,7 +154,25 @@ def collect_archive_requests(self, written, template, **keys):

handle, path = written

mars = handle.as_namespace("mars")
while True:

if self._namespace_bug_fix:
import eccodes
from earthkit.data.readers.grib.codes import GribCodesHandle

handle = GribCodesHandle(eccodes.codes_clone(handle._handle), None, None)

mars = {k: v for k, v in handle.items("mars")}

if _is_valid(mars, keys):
break

if self._namespace_bug_fix:
raise ValueError("Namespace bug: %s" % mars)

# Try again with the namespace bug
LOG.warning("Namespace bug detected, trying again")
self._namespace_bug_fix = True

if self.check_encoding:
check_encoding(handle, keys)
Expand Down

0 comments on commit 05d5e6f

Please sign in to comment.