Skip to content

Commit

Permalink
add a few commands
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 27, 2024
1 parent 57a1b35 commit e33c19f
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 1 deletion.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ classifiers = [
"Operating System :: OS Independent",
]

dependencies = ["anemoi-utils", "semantic-version"]
dependencies = [
"tomli", # Only needed before 3.11
"anemoi-utils",
"semantic-version",
"pyyaml",
]

[project.optional-dependencies]

Expand Down
7 changes: 7 additions & 0 deletions src/anemoi/inference/commands/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import json

import yaml

from ..checkpoint import Checkpoint
from . import Command

Expand All @@ -21,6 +23,7 @@ class CheckpointCmd(Command):

def add_arguments(self, command_parser):
command_parser.add_argument("--json", action="store_true", help="Output in JSON format")
command_parser.add_argument("--yaml", action="store_true", help="Output in YAML format")
command_parser.add_argument("path", help="Path to the checkpoint.")

def run(self, args):
Expand All @@ -30,6 +33,10 @@ def run(self, args):
print(json.dumps(c.to_dict(), indent=4, sort_keys=True))
return

if args.yaml:
print(yaml.dump(c.to_dict(), indent=4, sort_keys=True))
return

print("area:", c.area)
print("computed_constants:", c.computed_constants)
print("computed_constants_mask:", c.computed_constants_mask)
Expand Down
64 changes: 64 additions & 0 deletions src/anemoi/inference/commands/dump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import json

import yaml

from ..checkpoint import Checkpoint
from . import Command


def _dump(x, path):
if isinstance(x, dict):
for k, v in x.items():

path.append(k)
_dump(v, path)
path.pop()
return

if isinstance(x, list):
for i, v in enumerate(x):
path.append(str(i))
_dump(v, path)
path.pop()
return

name = ".".join(path)
print(f"{name}: {repr(x)}")


class DumpCmd(Command):
"""Dump the content of a checkpoint file as JSON or YAML."""

need_logging = False

def add_arguments(self, command_parser):
command_parser.add_argument("--json", action="store_true", help="Output in JSON format")
command_parser.add_argument("--yaml", action="store_true", help="Output in YAML format")
command_parser.add_argument("path", help="Path to the checkpoint.")

def run(self, args):

c = Checkpoint(args.path)
if args.json:
print(json.dumps(c.to_dict(), indent=4, sort_keys=True))
return

if args.yaml:
print(yaml.dump(c.to_dict(), indent=4, sort_keys=True))
return

_dump(c.to_dict(), [])


command = DumpCmd
57 changes: 57 additions & 0 deletions src/anemoi/inference/commands/edit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


from ..checkpoint import Checkpoint
from . import Command


def visit(x, path, name, value):
if isinstance(x, dict):
for k, v in x.items():
if k == name:
print(".".join(path), k, v)

if v == value:
print(".".join(path), k, v)

path.append(k)
visit(v, path, name, value)
path.pop()

if isinstance(x, list):
for i, v in enumerate(x):
path.append(str(i))
visit(v, path, name, value)
path.pop()


class EditCmd(Command):

need_logging = False

def add_arguments(self, command_parser):
command_parser.add_argument("path", help="Path to the checkpoint.")
command_parser.add_argument("--name", help="Search for a specific name.")
command_parser.add_argument("--value", help="Search for a specific value.")

def run(self, args):

checkpoint = Checkpoint(args.path)

visit(
checkpoint,
[],
args.name if args.name is not None else object(),
args.value if args.value is not None else object(),
)


command = EditCmd

0 comments on commit e33c19f

Please sign in to comment.