diff --git a/cyclopts/_convert.py b/cyclopts/_convert.py index cca2c483..cd73e9d3 100644 --- a/cyclopts/_convert.py +++ b/cyclopts/_convert.py @@ -216,7 +216,8 @@ def _convert( out = convert_tuple(type_, token, converter=converter) else: out = convert_tuple(type_, *token, converter=converter) - elif origin_type in ITERABLE_TYPES: # NOT including tuple + elif origin_type in ITERABLE_TYPES: + # NOT including tuple; handled in ``origin_type is tuple`` body above. count, _ = token_count(inner_types[0]) if not isinstance(token, Sequence): raise ValueError diff --git a/cyclopts/argument.py b/cyclopts/argument.py index 996ea890..c9ad7985 100644 --- a/cyclopts/argument.py +++ b/cyclopts/argument.py @@ -8,6 +8,7 @@ import cyclopts.utils from cyclopts._convert import ( + ITERABLE_TYPES, convert, token_count, ) @@ -112,6 +113,16 @@ def _identity_converter(type_, token): return token +def _get_parameters(hint: Any) -> tuple[Any, list[Parameter]]: + """At root level, checks for cyclopts.Parameter annotations.""" + if is_annotated(hint): + inner = get_args(hint) + hint = inner[0] + return hint, [x for x in inner[1:] if isinstance(x, Parameter)] + else: + return hint, [] + + class ArgumentCollection(list["Argument"]): """A list-like container for :class:`Argument`.""" @@ -206,10 +217,23 @@ def _from_type( cyclopts_parameters_no_group = [] hint = field_info.hint - if is_annotated(hint): - annotations = hint.__metadata__ # pyright: ignore - hint = get_args(hint)[0] - cyclopts_parameters_no_group.extend(x for x in annotations if isinstance(x, Parameter)) + hint, hint_parameters = _get_parameters(hint) + cyclopts_parameters_no_group.extend(hint_parameters) + + # Handle annotations where ``Annotated`` is not at the root level; e.g. ``list[Annotated[...]]``. + # Multiple inner Parameter Annotations only make sense if providing specific converter/validators. + origin = get_origin(hint) + if origin is tuple: + # handled in _convert.py + pass + elif origin in ITERABLE_TYPES: + inner_hints = get_args(hint) + if len(inner_hints) > 1: + raise NotImplementedError(f"Did not expect multiple inner type arguments: {inner_hints}.") + elif len(inner_hints) == 1: + inner_hint = inner_hints[0] + _, hint_parameters = _get_parameters(inner_hint) + cyclopts_parameters_no_group.extend(hint_parameters) if not keys: # root hint annotation if field_info.kind is field_info.VAR_KEYWORD: @@ -245,6 +269,7 @@ def _from_type( # if not immediate_parameter.parse: # return out + # resolve/derive the parameter name if keys: cparam = Parameter.combine( upstream_parameter, diff --git a/pyproject.toml b/pyproject.toml index 8019cd7e..ad10fcc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,6 +184,7 @@ convention = "numpy" "D106", "D107", "D205", + "D400", "D404", "S102", # use of "exec" "S106", # possible hardcoded password. diff --git a/tests/types/test_types_number.py b/tests/types/test_types_number.py index 3c8c3bda..09457540 100644 --- a/tests/types/test_types_number.py +++ b/tests/types/test_types_number.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from cyclopts.exceptions import ValidationError @@ -18,3 +20,23 @@ def default(color: tuple[UInt8, UInt8, UInt8] = (0x00, 0x00, 0x00)): with pytest.raises(ValidationError) as e: app.parse_args("--color 100 200 300", exit_on_error=False) assert str(e.value) == 'Invalid value "300" for "--color". Must be <= 255.' + + +def test_nested_list_annotated_validator(app, assert_parse_args): + @app.default + def default(color: Optional[list[tuple[UInt8, UInt8, UInt8]]] = None): + pass + + assert_parse_args( + default, + "0x12 0x34 0x56 0x78 0x90 0xAB", + [(0x12, 0x34, 0x56), (0x78, 0x90, 0xAB)], + ) + + with pytest.raises(ValidationError) as e: + app.parse_args("100 200 300", exit_on_error=False) + assert str(e.value) == 'Invalid value "300" for "COLOR". Must be <= 255.' + + with pytest.raises(ValidationError) as e: + app.parse_args("--color 100 200 300", exit_on_error=False) + assert str(e.value) == 'Invalid value "300" for "--color". Must be <= 255.' diff --git a/tests/types/test_types_path.py b/tests/types/test_types_path.py index 6ac34e77..a75f8340 100644 --- a/tests/types/test_types_path.py +++ b/tests/types/test_types_path.py @@ -28,6 +28,28 @@ def test_types_existing_file(convert, tmp_file): assert tmp_file == convert(ct.ExistingFile, tmp_file) +def test_types_existing_file_app(app): + """https://github.com/BrianPugh/cyclopts/issues/287""" + + @app.default + def main(f: ct.ExistingFile): + pass + + with pytest.raises(ValidationError): + app(["this-file-does-not-exist"], exit_on_error=False) + + +def test_types_existing_file_app_list(app): + """https://github.com/BrianPugh/cyclopts/issues/287""" + + @app.default + def main(f: list[ct.ExistingFile]): + pass + + with pytest.raises(ValidationError): + app(["this-file-does-not-exist"], exit_on_error=False) + + def test_types_existing_file_validation_error(convert, tmp_path): with pytest.raises(ValidationError): convert(ct.ExistingFile, tmp_path)