Skip to content

Commit

Permalink
Work on checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 27, 2024
1 parent 9ca3d1e commit 1654d46
Showing 1 changed file with 26 additions and 29 deletions.
55 changes: 26 additions & 29 deletions src/anemoi/inference/commands/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,13 @@
#


from ..checkpoint import Checkpoint
from . import Command


def visit(x, path, name, value):
if isinstance(x, dict):
for k, v in x.items():
if k == name:
print(".".join(path), k, v)
import os
import subprocess
from tempfile import TemporaryDirectory

if v == value:
print(".".join(path), k, v)
import yaml

path.append(k)
visit(v, path, name, value)
path.pop()

if isinstance(x, list):
for i, v in enumerate(x):
path.append(str(i))
visit(v, path, name, value)
path.pop()
from . import Command


class EditCmd(Command):
Expand All @@ -39,19 +24,31 @@ class EditCmd(Command):

def add_arguments(self, command_parser):
command_parser.add_argument("path", help="Path to the checkpoint.")
command_parser.add_argument("--name", help="Search for a specific name.")
command_parser.add_argument("--value", help="Search for a specific value.")
command_parser.add_argument("--editor", help="Editor to use.", default=os.environ.get("EDITOR", "vi"))

def run(self, args):

checkpoint = Checkpoint(args.path)
from anemoi.utils.checkpoints import DEFAULT_NAME
from anemoi.utils.checkpoints import load_metadata
from anemoi.utils.checkpoints import metadata_files
from anemoi.utils.checkpoints import replace_metadata

OLD_NAME = "ai-models.json"

names = metadata_files(args.path)
name = OLD_NAME if OLD_NAME in names else DEFAULT_NAME

with TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "checkpoint.yaml")
with open(path, "w") as f:
yaml.dump(load_metadata(args.path, name), f)

subprocess.check_call([args.editor, path])

with open(path) as f:
replace_metadata(args.path, yaml.safe_load(f), OLD_NAME)

visit(
checkpoint,
[],
args.name if args.name is not None else object(),
args.value if args.value is not None else object(),
)
# checkpoint.pack(temp_dir, args.path)


command = EditCmd

0 comments on commit 1654d46

Please sign in to comment.