diff --git a/binaries.py b/binaries.py index 0b817f63..d54220f1 100644 --- a/binaries.py +++ b/binaries.py @@ -8,15 +8,21 @@ import subprocess import sys from collections import OrderedDict +from typing import Generator, Tuple - -class environment: +class Environment: + """ + Environment class to handle the build and distribution process for different operating systems. + """ WIN = "win" LINUX = "linux" MACOS = "macos" - def __init__(self): + def __init__(self) -> None: + """ + Initialize the environment based on the BINARY_OS environment variable. + """ os_mapping = { "windows-latest": self.WIN, "ubuntu-20.04": self.LINUX, @@ -25,7 +31,13 @@ def __init__(self): self.os = os_mapping[os.getenv("BINARY_OS")] @property - def python(self): + def python(self) -> Generator[Tuple[int, str], None, None]: + """ + Generator to yield the architecture and corresponding Python executable path. + + Yields: + Generator[Tuple[int, str], None, None]: Architecture and Python executable path. + """ for arch, python in self.PYTHON_BINARIES[self.os].items(): yield arch, python @@ -49,11 +61,15 @@ def python(self): } } - def run(self, command): - """Runs the given command via subprocess.check_output. + def run(self, command: str) -> None: + """ + Runs the given command via subprocess.run. - Exits with -1 if the command wasn't successfull. + Args: + command (str): The command to run. + Exits: + Exits with -1 if the command wasn't successful. """ try: print(f"RUNNING: {command}") @@ -68,7 +84,7 @@ def run(self, command): print(e.output and e.output.decode('utf-8')) sys.exit(-1) - def install(self): + def install(self) -> None: """ Install required dependencies """ @@ -76,8 +92,10 @@ def install(self): self.run(f"{python} -m pip install pyinstaller") self.run(f"{python} -m pip install -r test_requirements.txt") - def dist(self): - """Runs Pyinstaller producing a binary for every platform arch.""" + def dist(self) -> None: + """ + Runs PyInstaller to produce a binary for every platform architecture. + """ for arch, python in self.python: # Build the binary @@ -102,9 +120,9 @@ def dist(self): else: self.run(f"cp {binary_path} {artifact_path}") - def test(self): + def test(self) -> None: """ - Runs tests for every available arch on the current platform. + Runs tests for every available architecture on the current platform. """ for arch, python in self.python: self.run(f"{python} -m pytest --log-level=DEBUG") @@ -116,7 +134,7 @@ def test(self): print("usage: binaries.py [install|test|dist]") sys.exit(-1) - env = environment() + env = Environment() # Runs the command in sys.argv[1] (install|test|dist) getattr(env, sys.argv[1])() diff --git a/safety/util.py b/safety/util.py index feef3747..420eb13a 100644 --- a/safety/util.py +++ b/safety/util.py @@ -7,7 +7,7 @@ from datetime import datetime from difflib import SequenceMatcher from threading import Lock -from typing import List, Optional +from typing import List, Optional, Dict, Generator, Tuple, Union, Any import click from click import BadParameter @@ -27,17 +27,45 @@ LOG = logging.getLogger(__name__) -def is_a_remote_mirror(mirror): +def is_a_remote_mirror(mirror: str) -> bool: + """ + Check if a mirror URL is remote. + + Args: + mirror (str): The mirror URL. + + Returns: + bool: True if the mirror URL is remote, False otherwise. + """ return mirror.startswith("http://") or mirror.startswith("https://") -def is_supported_by_parser(path): +def is_supported_by_parser(path: str) -> bool: + """ + Check if the file path is supported by the parser. + + Args: + path (str): The file path. + + Returns: + bool: True if the file path is supported, False otherwise. + """ supported_types = (".txt", ".in", ".yml", ".ini", "Pipfile", "Pipfile.lock", "setup.cfg", "poetry.lock") return path.endswith(supported_types) -def parse_requirement(dep, found): +def parse_requirement(dep: Any, found: str) -> SafetyRequirement: + """ + Parse a requirement. + + Args: + dep (Any): The dependency. + found (str): The location where the dependency was found. + + Returns: + SafetyRequirement: The parsed requirement. + """ req = SafetyRequirement(dep) req.found = found @@ -47,7 +75,16 @@ def parse_requirement(dep, found): return req -def find_version(requirements): +def find_version(requirements: List[SafetyRequirement]) -> Optional[str]: + """ + Find the version of a requirement. + + Args: + requirements (List[SafetyRequirement]): The list of requirements. + + Returns: + Optional[str]: The version if found, None otherwise. + """ ver = None if len(requirements) != 1: @@ -61,12 +98,16 @@ def find_version(requirements): return ver -def read_requirements(fh, resolve=True): +def read_requirements(fh: Any, resolve: bool = True) -> Generator[Package, None, None]: """ - Reads requirements from a file like object and (optionally) from referenced files. - :param fh: file like object to read from - :param resolve: boolean. resolves referenced files. - :return: generator + Reads requirements from a file-like object and (optionally) from referenced files. + + Args: + fh (Any): The file-like object to read from. + resolve (bool): Resolves referenced files. + + Returns: + Generator: Yields Package objects. """ is_temp_file = not hasattr(fh, 'name') path = None @@ -111,14 +152,35 @@ def read_requirements(fh, resolve=True): more_info_url=None) -def get_proxy_dict(proxy_protocol, proxy_host, proxy_port): +def get_proxy_dict(proxy_protocol: str, proxy_host: str, proxy_port: int) -> Optional[Dict[str, str]]: + """ + Get the proxy dictionary for requests. + + Args: + proxy_protocol (str): The proxy protocol. + proxy_host (str): The proxy host. + proxy_port (int): The proxy port. + + Returns: + Optional[Dict[str, str]]: The proxy dictionary if all parameters are provided, None otherwise. + """ if proxy_protocol and proxy_host and proxy_port: # Safety only uses https request, so only https dict will be passed to requests return {'https': f"{proxy_protocol}://{proxy_host}:{str(proxy_port)}"} return None -def get_license_name_by_id(license_id, db): +def get_license_name_by_id(license_id: int, db: Dict[str, Any]) -> Optional[str]: + """ + Get the license name by its ID. + + Args: + license_id (int): The license ID. + db (Dict[str, Any]): The database containing license information. + + Returns: + Optional[str]: The license name if found, None otherwise. + """ licenses = db.get('licenses', []) for name, id in licenses.items(): if id == license_id: @@ -126,7 +188,13 @@ def get_license_name_by_id(license_id, db): return None -def get_flags_from_context(): +def get_flags_from_context() -> Dict[str, str]: + """ + Get the flags from the current click context. + + Returns: + Dict[str, str]: A dictionary of flags and their corresponding option names. + """ flags = {} context = click.get_current_context(silent=True) @@ -139,7 +207,13 @@ def get_flags_from_context(): return flags -def get_used_options(): +def get_used_options() -> Dict[str, Dict[str, int]]: + """ + Get the used options from the command-line arguments. + + Returns: + Dict[str, Dict[str, int]]: A dictionary of used options and their counts. + """ flags = get_flags_from_context() used_options = {} @@ -156,12 +230,27 @@ def get_used_options(): return used_options -def get_safety_version(): +def get_safety_version() -> str: + """ + Get the version of Safety. + + Returns: + str: The Safety version. + """ from safety import VERSION return VERSION -def get_primary_announcement(announcements): +def get_primary_announcement(announcements: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + Get the primary announcement from a list of announcements. + + Args: + announcements (List[Dict[str, Any]]): The list of announcements. + + Returns: + Optional[Dict[str, Any]]: The primary announcement if found, None otherwise. + """ for announcement in announcements: if announcement.get('type', '').lower() == 'primary_announcement': try: @@ -176,20 +265,50 @@ def get_primary_announcement(announcements): return None -def get_basic_announcements(announcements, include_local: bool = True): +def get_basic_announcements(announcements: List[Dict[str, Any]], include_local: bool = True) -> List[Dict[str, Any]]: + """ + Get the basic announcements from a list of announcements. + + Args: + announcements (List[Dict[str, Any]]): The list of announcements. + include_local (bool): Whether to include local announcements. + + Returns: + List[Dict[str, Any]]: The list of basic announcements. + """ return [announcement for announcement in announcements if announcement.get('type', '').lower() != 'primary_announcement' and not announcement.get('local', False) or (announcement.get('local', False) and include_local)] -def filter_announcements(announcements, by_type='error'): +def filter_announcements(announcements: List[Dict[str, Any]], by_type: str = 'error') -> List[Dict[str, Any]]: + """ + Filter announcements by type. + + Args: + announcements (List[Dict[str, Any]]): The list of announcements. + by_type (str): The type of announcements to filter by. + + Returns: + List[Dict[str, Any]]: The filtered announcements. + """ return [announcement for announcement in announcements if announcement.get('type', '').lower() == by_type] -def build_telemetry_data(telemetry = True, - command: Optional[str] = None, +def build_telemetry_data(telemetry: bool = True, + command: Optional[str] = None, subcommand: Optional[str] = None) -> TelemetryModel: + """Build telemetry data for the Safety context. + + Args: + telemetry (bool): Whether telemetry is enabled. + command (Optional[str]): The command. + subcommand (Optional[str]): The subcommand. + + Returns: + TelemetryModel: The telemetry data model. + """ context = SafetyContext() body = { @@ -212,10 +331,15 @@ def build_telemetry_data(telemetry = True, return TelemetryModel(**body) -def build_git_data(): +def build_git_data() -> Dict[str, Any]: + """Build git data for the repository. + + Returns: + Dict[str, str]: The git data. + """ import subprocess - def git_command(commandline): + def git_command(commandline: List[str]) -> str: return subprocess.run(commandline, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8').strip() try: @@ -251,7 +375,17 @@ def git_command(commandline): } -def output_exception(exception, exit_code_output=True): +def output_exception(exception: Exception, exit_code_output: bool = True) -> None: + """ + Output an exception message to the console and exit. + + Args: + exception (Exception): The exception to output. + exit_code_output (bool): Whether to output the exit code. + + Exits: + Exits the program with the appropriate exit code. + """ click.secho(str(exception), fg="red", file=sys.stderr) if exit_code_output: @@ -264,7 +398,19 @@ def output_exception(exception, exit_code_output=True): sys.exit(exit_code) def build_remediation_info_url(base_url: str, version: Optional[str], spec: str, - target_version: Optional[str] = ''): + target_version: Optional[str] = '') -> str: + """ + Build the remediation info URL. + + Args: + base_url (str): The base URL. + version (Optional[str]): The current version. + spec (str): The specification. + target_version (Optional[str]): The target version. + + Returns: + str: The remediation info URL. + """ params = {'from': version, 'to': target_version} @@ -277,8 +423,23 @@ def build_remediation_info_url(base_url: str, version: Optional[str], spec: str, return req.url -def get_processed_options(policy_file, ignore, ignore_severity_rules, exit_code, ignore_unpinned_requirements=None, - project=None): +def get_processed_options(policy_file: Dict[str, Any], ignore: Dict[str, Any], ignore_severity_rules: Dict[str, Any], + exit_code: bool, ignore_unpinned_requirements: Optional[bool] = None, + project: Optional[str] = None) -> Tuple[Dict[str, Any], Dict[str, Any], bool, Optional[bool], Optional[str]]: + """ + Get processed options from the policy file. + + Args: + policy_file (Dict[str, Any]): The policy file. + ignore (Dict[str, Any]): The ignore settings. + ignore_severity_rules (Dict[str, Any]): The ignore severity rules. + exit_code (bool): The exit code setting. + ignore_unpinned_requirements (Optional[bool]): The ignore unpinned requirements setting. + project (Optional[str]): The project setting. + + Returns: + Tuple[Dict[str, Any], Dict[str, Any], bool, Optional[bool], Optional[str]]: The processed options. + """ if policy_file: project_config = policy_file.get('project', {}) security = policy_file.get('security', {}) @@ -306,7 +467,17 @@ def get_processed_options(policy_file, ignore, ignore_severity_rules, exit_code, return ignore, ignore_severity_rules, exit_code, ignore_unpinned_requirements, project -def get_fix_options(policy_file, auto_remediation_limit): +def get_fix_options(policy_file: Dict[str, Any], auto_remediation_limit: int) -> int: + """ + Get fix options from the policy file. + + Args: + policy_file (Dict[str, Any]): The policy file. + auto_remediation_limit (int): The auto remediation limit. + + Returns: + int: The auto remediation limit. + """ auto_fix = [] source = click.get_current_context().get_parameter_source("auto_remediation_limit") @@ -323,6 +494,10 @@ def get_fix_options(policy_file, auto_remediation_limit): class MutuallyExclusiveOption(click.Option): + """ + A click option that is mutually exclusive with other options. + """ + def __init__(self, *args, **kwargs): self.mutually_exclusive = set(kwargs.pop('mutually_exclusive', [])) self.with_values = kwargs.pop('with_values', {}) @@ -335,7 +510,18 @@ def __init__(self, *args, **kwargs): ) super(MutuallyExclusiveOption, self).__init__(*args, **kwargs) - def handle_parse_result(self, ctx, opts, args): + def handle_parse_result(self, ctx: click.Context, opts: Dict[str, Any], args: List[str]) -> Tuple[Any, List[str]]: + """ + Handle the parse result for mutually exclusive options. + + Args: + ctx (click.Context): The click context. + opts (Dict[str, Any]): The options dictionary. + args (List[str]): The arguments list. + + Returns: + Tuple[Any, List[str]]: The result and remaining arguments. + """ m_exclusive_used = self.mutually_exclusive.intersection(opts) option_used = m_exclusive_used and self.name in opts @@ -363,6 +549,9 @@ def handle_parse_result(self, ctx, opts, args): class DependentOption(click.Option): + """ + A click option that depends on other options. + """ def __init__(self, *args, **kwargs): self.required_options = set(kwargs.pop('required_options', [])) help = kwargs.get('help', '') @@ -373,7 +562,18 @@ def __init__(self, *args, **kwargs): ) super(DependentOption, self).__init__(*args, **kwargs) - def handle_parse_result(self, ctx, opts, args): + def handle_parse_result(self, ctx: click.Context, opts: Dict[str, Any], args: List[str]) -> Tuple[Any, List[str]]: + """ + Handle the parse result for dependent options. + + Args: + ctx (click.Context): The click context. + opts (Dict[str, Any]): The options dictionary. + args (List[str]): The arguments list. + + Returns: + Tuple[Any, List[str]]: The result and remaining arguments. + """ missing_required_arguments = None if self.name in opts: @@ -395,7 +595,18 @@ def handle_parse_result(self, ctx, opts, args): ) -def transform_ignore(ctx, param, value): +def transform_ignore(ctx: click.Context, param: click.Parameter, value: Tuple[str]) -> Dict[str, Dict[str, Optional[str]]]: + """ + Transform ignore parameters into a dictionary. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (Tuple[str]): The parameter value. + + Returns: + Dict[str, Dict[str, Optional[str]]]: The transformed ignore parameters. + """ ignored_default_dict = {'reason': '', 'expires': None} if isinstance(value, tuple) and any(value): # Following code is required to support the 2 ways of providing 'ignore' @@ -409,7 +620,18 @@ def transform_ignore(ctx, param, value): return {} -def active_color_if_needed(ctx, param, value): +def active_color_if_needed(ctx: click.Context, param: click.Parameter, value: str) -> str: + """ + Activate color if needed based on the context and environment variables. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (str): The parameter value. + + Returns: + str: The parameter value. + """ if value == 'screen': ctx.color = True @@ -426,24 +648,63 @@ def active_color_if_needed(ctx, param, value): return value -def json_alias(ctx, param, value): +def json_alias(ctx: click.Context, param: click.Parameter, value: bool) -> Optional[bool]: + """ + Set the SAFETY_OUTPUT environment variable to 'json' if the parameter is used. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (bool): The parameter value. + + Returns: + bool: The parameter value. + """ if value: os.environ['SAFETY_OUTPUT'] = 'json' return value -def html_alias(ctx, param, value): +def html_alias(ctx: click.Context, param: click.Parameter, value: bool) -> Optional[bool]: + """ + Set the SAFETY_OUTPUT environment variable to 'html' if the parameter is used. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (bool): The parameter value. + + Returns: + bool: The parameter value. + """ if value: os.environ['SAFETY_OUTPUT'] = 'html' return value -def bare_alias(ctx, param, value): +def bare_alias(ctx: click.Context, param: click.Parameter, value: bool) -> Optional[bool]: + """ + Set the SAFETY_OUTPUT environment variable to 'bare' if the parameter is used. + + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + value (bool): The parameter value. + + Returns: + bool: The parameter value. + """ if value: os.environ['SAFETY_OUTPUT'] = 'bare' return value -def get_terminal_size(): +def get_terminal_size() -> os.terminal_size: + """ + Get the terminal size. + + Returns: + os.terminal_size: The terminal size. + """ from shutil import get_terminal_size as t_size # get_terminal_size can report 0, 0 if run from pseudo-terminal prior Python 3.11 versions @@ -453,7 +714,16 @@ def get_terminal_size(): return os.terminal_size((columns, lines)) -def clean_project_id(input_string): +def clean_project_id(input_string: str) -> str: + """ + Clean a project ID by removing non-alphanumeric characters and normalizing the string. + + Args: + input_string (str): The input string. + + Returns: + str: The cleaned project ID. + """ input_string = re.sub(r'[^a-zA-Z0-9]+', '-', input_string) input_string = input_string.strip('-') input_string = input_string.lower() @@ -461,7 +731,16 @@ def clean_project_id(input_string): return input_string -def validate_expiration_date(expiration_date): +def validate_expiration_date(expiration_date: str) -> Optional[datetime]: + """ + Validate an expiration date string. + + Args: + expiration_date (str): The expiration date string. + + Returns: + Optional[datetime]: The validated expiration date if valid, None otherwise. + """ d = None if expiration_date: @@ -480,7 +759,7 @@ def validate_expiration_date(expiration_date): class SafetyPolicyFile(click.ParamType): """ - Custom Safety Policy file to hold validations + Custom Safety Policy file to hold validations. """ name = "filename" @@ -489,7 +768,7 @@ class SafetyPolicyFile(click.ParamType): def __init__( self, mode: str = "r", - encoding: str = None, + encoding: Optional[str] = None, errors: str = "strict", pure: bool = os.environ.get('SAFETY_PURE_YAML', 'false').lower() == 'true' ) -> None: @@ -499,12 +778,33 @@ def __init__( self.basic_msg = '\n' + click.style('Unable to load the Safety Policy file "{name}".', fg='red') self.pure = pure - def to_info_dict(self): + def to_info_dict(self) -> Dict[str, Any]: + """ + Convert the object to an info dictionary. + + Returns: + Dict[str, Any]: The info dictionary. + """ info_dict = super().to_info_dict() info_dict.update(mode=self.mode, encoding=self.encoding) return info_dict - def fail_if_unrecognized_keys(self, used_keys, valid_keys, param=None, ctx=None, msg='{hint}', context_hint=''): + def fail_if_unrecognized_keys(self, used_keys: List[str], valid_keys: List[str], param: Optional[click.Parameter] = None, + ctx: Optional[click.Context] = None, msg: str = '{hint}', context_hint: str = '') -> None: + """ + Fail if unrecognized keys are found in the policy file. + + Args: + used_keys (List[str]): The used keys. + valid_keys (List[str]): The valid keys. + param (Optional[click.Parameter]): The click parameter. + ctx (Optional[click.Context]): The click context. + msg (str): The error message template. + context_hint (str): The context hint for the error message. + + Raises: + click.UsageError: If unrecognized keys are found. + """ for keyword in used_keys: if keyword not in valid_keys: match = None @@ -521,31 +821,60 @@ def fail_if_unrecognized_keys(self, used_keys, valid_keys, param=None, ctx=None, self.fail(msg.format(hint=f'{context_hint}"{keyword}" is not a valid keyword.{maybe_msg}'), param, ctx) - def fail_if_wrong_bool_value(self, keyword, value, msg='{hint}'): + def fail_if_wrong_bool_value(self, keyword: str, value: Any, msg: str = '{hint}') -> None: + """ + Fail if a boolean value is invalid. + + Args: + keyword (str): The keyword. + value (Any): The value. + msg (str): The error message template. + + Raises: + click.UsageError: If the boolean value is invalid. + """ if value is not None and not isinstance(value, bool): self.fail(msg.format(hint=f"'{keyword}' value needs to be a boolean. " "You can use True, False, TRUE, FALSE, true or false")) - def convert(self, value, param, ctx): - try: + def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]) -> Any: + """ + Convert the parameter value to a Safety policy file. + + Args: + value (Any): The parameter value. + param (Optional[click.Parameter]): The click parameter. + ctx (Optional[click.Context]): The click context. + + Returns: + Any: The converted policy file. + Raises: + click.UsageError: If the policy file is invalid. + """ + try: + # Check if the value is already a file-like object if hasattr(value, "read") or hasattr(value, "write"): return value + # Prepare the error message template msg = self.basic_msg.format(name=value) + '\n' + click.style('HINT:', fg='yellow') + ' {hint}' + # Open the file stream f, _ = click.types.open_stream( value, self.mode, self.encoding, self.errors, atomic=False ) filename = '' try: + # Read the content of the file raw = f.read() yaml = YAML(typ='safe', pure=self.pure) safety_policy = yaml.load(raw) filename = f.name f.close() except Exception as e: + # Handle YAML parsing errors show_parsed_hint = isinstance(e, MarkedYAMLError) hint = str(e) if show_parsed_hint: @@ -553,6 +882,7 @@ def convert(self, value, param, ctx): self.fail(msg.format(name=value, hint=hint), param, ctx) + # Validate the structure of the safety policy if not safety_policy or not isinstance(safety_policy, dict) or not safety_policy.get('security', None): hint = "you are missing the security root tag" try: @@ -566,33 +896,34 @@ def convert(self, value, param, ctx): self.fail( msg.format(hint=hint), param, ctx) + # Validate 'security' section keys security_config = safety_policy.get('security', {}) security_keys = ['ignore-cvss-severity-below', 'ignore-cvss-unknown-severity', 'ignore-vulnerabilities', 'continue-on-vulnerability-error', 'ignore-unpinned-requirements'] self.fail_if_unrecognized_keys(security_config.keys(), security_keys, param=param, ctx=ctx, msg=msg, context_hint='"security" -> ') + # Validate 'ignore-cvss-severity-below' value ignore_cvss_security_below = security_config.get('ignore-cvss-severity-below', None) - if ignore_cvss_security_below: limit = 0.0 - try: limit = float(ignore_cvss_security_below) except ValueError as e: self.fail(msg.format(hint="'ignore-cvss-severity-below' value needs to be an integer or float.")) - if limit < 0 or limit > 10: self.fail(msg.format(hint="'ignore-cvss-severity-below' needs to be a value between 0 and 10")) + # Validate 'continue-on-vulnerability-error' value continue_on_vulnerability_error = security_config.get('continue-on-vulnerability-error', None) self.fail_if_wrong_bool_value('continue-on-vulnerability-error', continue_on_vulnerability_error, msg) + # Validate 'ignore-cvss-unknown-severity' value ignore_cvss_unknown_severity = security_config.get('ignore-cvss-unknown-severity', None) self.fail_if_wrong_bool_value('ignore-cvss-unknown-severity', ignore_cvss_unknown_severity, msg) + # Validate 'ignore-vulnerabilities' section ignore_vulns = safety_policy.get('security', {}).get('ignore-vulnerabilities', {}) - if ignore_vulns: if not isinstance(ignore_vulns, dict): self.fail(msg.format(hint="Vulnerability IDs under the 'ignore-vulnerabilities' key, need to " @@ -626,7 +957,7 @@ def convert(self, value, param, ctx): f"be a positive integer") ) - # Validate expires + # Validate expires date d = validate_expiration_date(expires) if expires and not d: @@ -644,9 +975,9 @@ def convert(self, value, param, ctx): else: safety_policy['security']['ignore-vulnerabilities'] = {} + # Validate 'fix' section keys fix_config = safety_policy.get('fix', {}) - self.fail_if_unrecognized_keys(fix_config.keys(), ['auto-security-updates-limit'], param=param, ctx=ctx, msg=msg, - context_hint='"fix" -> ') + self.fail_if_unrecognized_keys(fix_config.keys(), ['auto-security-updates-limit'], param=param, ctx=ctx, msg=msg, context_hint='"fix" -> ') auto_remediation_limit = fix_config.get('auto-security-updates-limit', None) if auto_remediation_limit: @@ -658,7 +989,7 @@ def convert(self, value, param, ctx): except BadParameter as expected_e: raise expected_e except Exception as e: - # Don't fail in the default case + # Handle file not found errors gracefully, don't fail in the default case if ctx and isinstance(e, OSError): default = ctx.get_parameter_source source = default("policy_file") if default("policy_file") else default("policy_file_path") @@ -670,14 +1001,19 @@ def convert(self, value, param, ctx): self.fail(f"{problem}\n{hint}", param, ctx) def shell_complete( - self, ctx: "Context", param: "Parameter", incomplete: str + self, ctx: click.Context, param: click.Parameter, incomplete: str ): - """Return a special completion marker that tells the completion + """ + Return a special completion marker that tells the completion system to use the shell to provide file path completions. - :param ctx: Invocation context for this command. - :param param: The parameter that is requesting completion. - :param incomplete: Value being completed. May be empty. + Args: + ctx (click.Context): The click context. + param (click.Parameter): The click parameter. + incomplete (str): The value being completed. May be empty. + + Returns: + List[click.shell_completion.CompletionItem]: The completion items. .. versionadded:: 8.0 """ @@ -687,12 +1023,15 @@ def shell_complete( class SingletonMeta(type): + """ + A metaclass for singleton classes. + """ - _instances = {} + _instances: Dict[type, Any] = {} - _lock = Lock() + _lock: Lock = Lock() - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: Any, **kwargs: Any) -> Any: with cls._lock: if cls not in cls._instances: instance = super().__call__(*args, **kwargs) @@ -701,6 +1040,9 @@ def __call__(cls, *args, **kwargs): class SafetyContext(metaclass=SingletonMeta): + """ + A singleton class to hold the Safety context. + """ packages = [] key = False db_mirror = False @@ -725,6 +1067,9 @@ class SafetyContext(metaclass=SingletonMeta): def sync_safety_context(f): + """ + Decorator to sync the Safety context with the function arguments. + """ def new_func(*args, **kwargs): ctx = SafetyContext() @@ -746,12 +1091,16 @@ def new_func(*args, **kwargs): @sync_safety_context -def get_packages_licenses(*, packages=None, licenses_db=None): - """Get the licenses for the specified packages based on their version. +def get_packages_licenses(*, packages: Optional[List[Package]] = None, licenses_db: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: + """ + Get the licenses for the specified packages based on their version. + + Args: + packages (Optional[List[Package]]): The list of packages. + licenses_db (Optional[Dict[str, Any]]): The licenses database. - :param packages: packages list - :param licenses_db: the licenses db in the raw form. - :return: list of objects with the packages and their respectives licenses. + Returns: + List[Dict[str, Any]]: The list of packages and their licenses. """ SafetyContext().command = 'license' @@ -776,7 +1125,7 @@ def get_packages_licenses(*, packages=None, licenses_db=None): if is_pinned_requirement(req.specifier): pkg.version = next(iter(req.specifier)).version break - + if not pkg.version: continue version_requested = parse_version(pkg.version) @@ -806,7 +1155,19 @@ def get_packages_licenses(*, packages=None, licenses_db=None): return filtered_packages_licenses -def get_requirements_content(files): +def get_requirements_content(files: List[click.File]) -> Dict[str, str]: + """ + Get the content of the requirements files. + + Args: + files (List[click.File]): The list of requirement files. + + Returns: + Dict[str, str]: The content of the requirement files. + + Raises: + InvalidProvidedReportError: If a file cannot be read. + """ requirements_files = {} for f in files: @@ -820,16 +1181,43 @@ def get_requirements_content(files): return requirements_files -def is_ignore_unpinned_mode(version): +def is_ignore_unpinned_mode(version: str) -> bool: + """ + Check if unpinned mode is enabled based on the version. + + Args: + version (str): The version string. + + Returns: + bool: True if unpinned mode is enabled, False otherwise. + """ ignore = SafetyContext().params.get('ignore_unpinned_requirements') return (ignore is None or ignore) and not version -def get_remediations_count(remediations): +def get_remediations_count(remediations: Dict[str, Any]) -> int: + """ + Get the count of remediations. + + Args: + remediations (Dict[str, Any]): The remediations dictionary. + + Returns: + int: The count of remediations. + """ return sum((len(rem.keys()) for pkg, rem in remediations.items())) -def get_hashes(dependency): +def get_hashes(dependency: Any) -> List[Dict[str, str]]: + """ + Get the hashes for a dependency. + + Args: + dependency (Any): The dependency. + + Returns: + List[Dict[str, str]]: The list of hashes. + """ pattern = re.compile(HASH_REGEX_GROUPS) return [{'method': method, 'hash': hsh} for method, hsh in @@ -837,9 +1225,19 @@ def get_hashes(dependency): def pluralize(word: str, count: int = 0) -> str: + """ + Pluralize a word based on the count. + + Args: + word (str): The word to pluralize. + count (int): The count. + + Returns: + str: The pluralized word. + """ if count == 1: return word - + default = {"was": "were", "this": "these", "has": "have"} if word in default: @@ -858,7 +1256,10 @@ def pluralize(word: str, count: int = 0) -> str: return word + "s" -def initializate_config_dirs(): +def initializate_config_dirs() -> None: + """ + Initialize the configuration directories. + """ USER_CONFIG_DIR.mkdir(parents=True, exist_ok=True) try: