Skip to content

Commit

Permalink
Turn on python 3.9 QA and fix typing (#104)
Browse files Browse the repository at this point in the history
* ci: add python 3.9 to qa matrix

* fix: typing
  • Loading branch information
gmertes authored Jan 13, 2025
1 parent 90728d5 commit ff2bd59
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
checks:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2
with:
python-version: ${{ matrix.python-version }}
3 changes: 2 additions & 1 deletion src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections import defaultdict
from functools import cached_property
from pathlib import Path
from typing import Optional

from anemoi.utils.checkpoints import load_metadata
from earthkit.data.utils.dates import to_datetime
Expand Down Expand Up @@ -179,7 +180,7 @@ def validate_environment(
*,
all_packages: bool = False,
on_difference: str = "warn",
exempt_packages: list[str] | None = None,
exempt_packages: Optional[list[str]] = None,
) -> bool:
return self._metadata.validate_environment(
all_packages=all_packages, on_difference=on_difference, exempt_packages=exempt_packages
Expand Down
18 changes: 10 additions & 8 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import Any
from typing import Dict
from typing import Literal
from typing import Optional
from typing import Union

import yaml
from pydantic import BaseModel
Expand All @@ -27,12 +29,12 @@ class Configuration(BaseModel):
class Config:
extra = "forbid"

description: str | None = None
description: Optional[str] = None

checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any] | str]
"""A path to an Anemoi checkpoint file."""

date: str | int | datetime.datetime | None = None
date: Union[str, int, datetime.datetime, None] = None
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`.
"""

Expand All @@ -41,7 +43,7 @@ class Config:
If an integer, it represents a number of hours. Otherwise, it is parsed by :func:`anemoi.utils.dates.as_timedelta`.
"""

name: str | None = None
name: Optional[str] = None
"""Used by prepml."""

verbosity: int = 0
Expand All @@ -50,19 +52,19 @@ class Config:
report_error: bool = False
"""If True, the runner list the training versions of the packages in case of error."""

input: str | Dict | None = "test"
output: str | Dict | None = "printer"
input: Union[str, Dict, None] = "test"
output: Union[str, Dict, None] = "printer"

forcings: Dict[str, Dict] | None = None
forcings: Union[Dict[str, Dict], None] = None
"""Where to find the forcings."""

device: str = "cuda"
"""The device on which the model should run. This can be "cpu", "cuda" or any other value supported by PyTorch."""

precision: str | None = None
precision: Optional[str] = None
"""The precision in which the model should run. If not provided, the model will use the precision used during training."""

allow_nans: bool | None = None
allow_nans: Optional[bool] = None
"""
- If None (default), the model will check for NaNs in the input. If NaNs are found, the model issue a warning and `allow_nans` to True.
- If False, the model will raise an exception if NaNs are found in the input and output.
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from functools import cached_property
from types import MappingProxyType as frozendict
from typing import Literal
from typing import Optional

import numpy as np
from anemoi.transform.variables import Variable
Expand Down Expand Up @@ -461,7 +462,7 @@ def validate_environment(
*,
all_packages: bool = False,
on_difference: Literal["warn", "error", "ignore"] = "warn",
exempt_packages: list[str] | None = None,
exempt_packages: Optional[list[str]] = None,
) -> bool:
"""
Validate environment of the checkpoint against the current environment.
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/inference/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
from typing import TYPE_CHECKING
from typing import Literal
from typing import Optional

from anemoi.utils.provenance import gather_provenance_info
from packaging.version import Version
Expand Down Expand Up @@ -43,7 +44,7 @@ def validate_environment(
*,
all_packages: bool = False,
on_difference: Literal["warn", "error", "ignore"] = "warn",
exempt_packages: list[str] | None = None,
exempt_packages: Optional[list[str]] = None,
) -> bool:
"""
Validate environment of the checkpoint against the current environment.
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/inference/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Optional
from typing import TypeVar

import numpy as np
Expand Down Expand Up @@ -163,7 +164,7 @@ def from_xarray(
self,
data: xr.Dataset | xr.DataArray,
*,
flatten: str | None = None,
flatten: Optional[str] = None,
variable_dim: str = "variable",
private_info: Any = None,
) -> State:
Expand Down

0 comments on commit ff2bd59

Please sign in to comment.