diff --git a/cyclopts/__init__.py b/cyclopts/__init__.py index 8f899d35..d36d535c 100644 --- a/cyclopts/__init__.py +++ b/cyclopts/__init__.py @@ -5,6 +5,7 @@ "App", "CommandCollisionError", "CycloptsError", + "MissingArgumentError", "MissingTypeError", "Parameter", "RepeatKeywordError", @@ -17,6 +18,7 @@ from cyclopts.exceptions import ( CommandCollisionError, CycloptsError, + MissingArgumentError, MissingTypeError, RepeatKeywordError, UnknownKeywordError, diff --git a/cyclopts/bind.py b/cyclopts/bind.py index 85e5642f..4a44376f 100644 --- a/cyclopts/bind.py +++ b/cyclopts/bind.py @@ -4,7 +4,12 @@ from typing import Callable, Dict, Iterable, Tuple from cyclopts.coercion import default_coercion_lookup -from cyclopts.exceptions import MissingTypeError, UnknownKeywordError, UnreachableError +from cyclopts.exceptions import ( + MissingArgumentError, + MissingTypeError, + UnknownKeywordError, + UnreachableError, +) from cyclopts.parameter import get_hint_param @@ -65,7 +70,10 @@ def _parse_kw_and_flags(f, tokens): token = token[2:] # remove the leading "--" if "=" in token: cli_key, cli_value = token.split("=", 1) - parameter = cli2kw[cli_key] + try: + parameter = cli2kw[cli_key] + except KeyError as e: + raise UnknownKeywordError(cli_key) from e elif token in cli2flag: parameter, cli_value = cli2flag[token] elif token in cli2kw: @@ -73,7 +81,7 @@ def _parse_kw_and_flags(f, tokens): try: cli_value = tokens[i + 1] except IndexError as e: - raise ValueError("No value supplied for keyword --{cli_key}") from e + raise MissingArgumentError(f"Unknown CLI keyword --{cli_key}") from e parameter = cli2kw[cli_key] skip_next_iteration = True else: @@ -120,7 +128,11 @@ def create_bound_arguments(f, tokens) -> Tuple[inspect.BoundArguments, Iterable[ f_pos, remaining_tokens = _parse_pos(f, remaining_tokens, f_kwargs) signature = inspect.signature(f) - bound = signature.bind(*f_pos, **f_kwargs) + try: + bound = signature.bind(*f_pos, **f_kwargs) + except TypeError as e: + raise MissingArgumentError from e + bound.apply_defaults() return bound, remaining_tokens diff --git a/cyclopts/exceptions.py b/cyclopts/exceptions.py index 81c74b3f..4c7b9f4e 100644 --- a/cyclopts/exceptions.py +++ b/cyclopts/exceptions.py @@ -33,6 +33,10 @@ def __init__(self, value, message="Unknown keyword or flag: "): super().__init__(self.message) +class MissingArgumentError(CycloptsError): + pass + + class RepeatKeywordError(CycloptsError): def __init__(self, message="Cannot specify Parameter multiple times per annotation."): super().__init__(message) diff --git a/tests/bind/test_basic.py b/tests/bind/test_basic.py deleted file mode 100644 index 02e4ae6c..00000000 --- a/tests/bind/test_basic.py +++ /dev/null @@ -1,43 +0,0 @@ -import inspect - -import pytest - -import cyclopts - - -@pytest.fixture -def app(): - return cyclopts.App() - - -def test_missing_positional_type(app): - with pytest.raises(cyclopts.MissingTypeError): - - @app.command - def foo(a, b, c): - pass - - -@pytest.mark.parametrize( - "cmd_str", - [ - "foo 1 2 3", - "foo --a 1 --b 2 --c 3", - "foo --c 3 1 2", - "foo --c 3 --b=2 1", - "foo --c 3 --b=2 --a 1", - "foo 1 --b=2 3", - ], -) -def test_basic_1(app, cmd_str): - @app.command - def foo(a: int, b: int, c: int): - pass - - signature = inspect.signature(foo) - expected_bind = signature.bind(1, 2, 3) - - actual_command, actual_bind, unused_args = app.parse_known_args(cmd_str) - assert actual_command == foo - assert actual_bind == expected_bind - assert unused_args == [] diff --git a/tests/bind/test_kw_only.py b/tests/bind/test_kw_only.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/bind/test_pos_only.py b/tests/bind/test_pos_only.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/test_bind.py b/tests/test_bind.py new file mode 100644 index 00000000..c77929a5 --- /dev/null +++ b/tests/test_bind.py @@ -0,0 +1,143 @@ +import inspect + +import pytest + +import cyclopts +from cyclopts import MissingArgumentError +from cyclopts.exceptions import UnknownKeywordError + + +@pytest.fixture +def app(): + return cyclopts.App() + + +def test_missing_positional_type(app): + with pytest.raises(cyclopts.MissingTypeError): + + @app.command + def foo(a, b, c): + pass + + +@pytest.mark.parametrize( + "cmd_str", + [ + "foo 1 2 3", + "foo --a 1 --b 2 --c 3", + "foo --c 3 1 2", + "foo --c 3 --b=2 1", + "foo --c 3 --b=2 --a 1", + "foo 1 --b=2 3", + ], +) +def test_basic_1(app, cmd_str): + @app.command + def foo(a: int, b: int, c: int): + pass + + signature = inspect.signature(foo) + expected_bind = signature.bind(1, 2, 3) + + actual_command, actual_bind, unused_args = app.parse_known_args(cmd_str) + assert actual_command == foo + assert actual_bind == expected_bind + assert unused_args == [] + + +@pytest.mark.parametrize( + "cmd_str", + [ + "foo 1 2 3 --d 10 --some-flag", + "foo --some-flag 1 --b=2 3 --d 10", + "foo 1 2 --some-flag 3 --d 10", + ], +) +def test_basic_2(app, cmd_str): + @app.command + def foo(a: int, b: int, c: int, d: int = 5, some_flag: bool = False): + pass + + signature = inspect.signature(foo) + expected_bind = signature.bind(1, 2, 3, d=10, some_flag=True) + + actual_command, actual_bind, unused_args = app.parse_known_args(cmd_str) + assert actual_command == foo + assert actual_bind == expected_bind + assert unused_args == [] + + +@pytest.mark.parametrize( + "cmd_str", + [ + "foo 1 2 3", + ], +) +def test_basic_pos_only(app, cmd_str): + @app.command + def foo(a: int, b: int, c: int, /): + pass + + signature = inspect.signature(foo) + expected_bind = signature.bind(1, 2, 3) + + actual_command, actual_bind, unused_args = app.parse_known_args(cmd_str) + assert actual_command == foo + assert actual_bind == expected_bind + assert unused_args == [] + + +@pytest.mark.parametrize( + "cmd_str_e", + [ + ("foo 1 2 --c=3", UnknownKeywordError), + ], +) +def test_basic_pos_only_exceptions(app, cmd_str_e): + cmd_str, e = cmd_str_e + + @app.command + def foo(a: int, b: int, c: int, /): + pass + + with pytest.raises(e): + app.parse_known_args(cmd_str) + + +@pytest.mark.parametrize( + "cmd_str", + [ + "foo 1 2 3 4", + "foo 1 2 3 --d 4", + "foo 1 2 --d=4 3", + ], +) +def test_basic_pos_only_extended(app, cmd_str): + @app.command + def foo(a: int, b: int, c: int, /, d: int): + pass + + signature = inspect.signature(foo) + expected_bind = signature.bind(1, 2, 3, 4) + + actual_command, actual_bind, unused_args = app.parse_known_args(cmd_str) + assert actual_command == foo + assert actual_bind == expected_bind + assert unused_args == [] + + +@pytest.mark.parametrize( + "cmd_str_e", + [ + ("foo 1 2 3", MissingArgumentError), + ], +) +def test_basic_pos_only_extended_exceptions(app, cmd_str_e): + cmd_str, e = cmd_str_e + + @app.command + def foo(a: int, b: int, c: int, /, d: int): + pass + + with pytest.raises(e): + app.parse_known_args(cmd_str)