From dd195a2eb960758fceef574f6ffa0a3d2e8a8a4f Mon Sep 17 00:00:00 2001 From: Yeison Vargas Date: Tue, 30 Apr 2024 17:34:17 -0500 Subject: [PATCH] feat: add SAFETY_DB_DIR env var to the scan command --- safety/auth/cli_utils.py | 2 +- safety/auth/models.py | 4 ++++ safety/cli.py | 11 ++++++++++- safety/safety.py | 27 ++++++++++++++++++++------- safety/scan/decorators.py | 13 +++++++++---- safety/scan/finder/handlers.py | 10 ++++++++-- safety/scan/validators.py | 4 ++++ tests/scan/test_file_handlers.py | 29 +++++++++++++++++++++++++++++ 8 files changed, 85 insertions(+), 15 deletions(-) create mode 100644 tests/scan/test_file_handlers.py diff --git a/safety/auth/cli_utils.py b/safety/auth/cli_utils.py index 4cffe1bd..527f3690 100644 --- a/safety/auth/cli_utils.py +++ b/safety/auth/cli_utils.py @@ -54,7 +54,7 @@ def update_token(tokens, **kwargs): try: openid_config = client_session.get(url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT).json() except Exception as e: - LOG.exception('Unable to load the openID config: %s', e) + LOG.debug('Unable to load the openID config: %s', e) openid_config = {} client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint", diff --git a/safety/auth/models.py b/safety/auth/models.py index 5a14f11b..34801eab 100644 --- a/safety/auth/models.py +++ b/safety/auth/models.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import os from typing import Any, Optional from authlib.integrations.base_client import BaseOAuth @@ -26,6 +27,9 @@ class Auth: email_verified: bool = False def is_valid(self) -> bool: + if os.getenv("SAFETY_DB_DIR"): + return True + if not self.client: return False diff --git a/safety/cli.py b/safety/cli.py index 3e8bcad0..6edef4dc 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -49,11 +49,20 @@ LOG = logging.getLogger(__name__) + +def configure_logger(ctx, param, debug): + level = logging.CRITICAL + + if debug: + level = logging.DEBUG + + logging.basicConfig(format='%(asctime)s %(name)s => %(message)s', level=level) + @click.group(cls=SafetyCLILegacyGroup, help=CLI_MAIN_INTRODUCTION, epilog=DEFAULT_EPILOG) @auth_options() @proxy_options @click.option('--disable-optional-telemetry', default=False, is_flag=True, show_default=True, help=CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP) -@click.option('--debug', default=False, help=CLI_DEBUG_HELP) +@click.option('--debug', default=False, help=CLI_DEBUG_HELP, callback=configure_logger) @click.version_option(version=get_safety_version()) @click.pass_context @inject_session diff --git a/safety/safety.py b/safety/safety.py index 86aa172f..a368561b 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -199,12 +199,21 @@ def post_results(session, safety_json, policy_file): return {} -def fetch_database_file(path, db_name, ecosystem: Ecosystem = Ecosystem.PYTHON): - full_path = os.path.join(path, db_name) - if not os.path.exists(full_path): +def fetch_database_file(path: str, db_name: str, cached = 0, + ecosystem: Optional[Ecosystem] = None): + full_path = (Path(path) / (ecosystem.value if ecosystem else '') / db_name).expanduser().resolve() + + if not full_path.exists(): raise DatabaseFileNotFoundError(db=path) + with open(full_path) as f: - return json.loads(f.read()) + data = json.loads(f.read()) + + if cached: + LOG.info('Writing %s to cache because cached value was %s', db_name, cached) + write_to_cache(db_name, data) + + return data def is_valid_database(db) -> bool: @@ -218,7 +227,8 @@ def is_valid_database(db) -> bool: def fetch_database(session, full=False, db=False, cached=0, telemetry=True, - ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True): + ecosystem: Optional[Ecosystem] = None, from_cache=True): + if session.is_using_auth_credentials(): mirrors = API_MIRRORS elif db: @@ -230,10 +240,13 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True, for mirror in mirrors: # mirror can either be a local path or a URL if is_a_remote_mirror(mirror): + if ecosystem is None: + ecosystem = Ecosystem.PYTHON data = fetch_database_url(session, mirror, db_name=db_name, cached=cached, telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache) else: - data = fetch_database_file(mirror, db_name=db_name, ecosystem=ecosystem) + data = fetch_database_file(mirror, db_name=db_name, cached=cached, + ecosystem=ecosystem) if data: if is_valid_database(data): return data @@ -1000,7 +1013,7 @@ def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True): licenses = fetch_database_url(session, mirror, db_name=db_name, cached=cached, telemetry=telemetry) else: - licenses = fetch_database_file(mirror, db_name=db_name) + licenses = fetch_database_file(mirror, db_name=db_name, ecosystem=None) if licenses: return licenses raise DatabaseFetchError() diff --git a/safety/scan/decorators.py b/safety/scan/decorators.py index 519c4023..2f41a7c8 100644 --- a/safety/scan/decorators.py +++ b/safety/scan/decorators.py @@ -1,5 +1,6 @@ from functools import wraps import logging +import os from pathlib import Path from random import randint import sys @@ -135,11 +136,15 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, if ctx.obj.auth.client.get_authentication_type() == "api_key": details = {"Account": f"API key used"} else: - content = ctx.obj.auth.email - if ctx.obj.auth.name != ctx.obj.auth.email: - content = f"{ctx.obj.auth.name}, {ctx.obj.auth.email}" - details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"} + if ctx.obj.auth.client.get_authentication_type() == "token": + content = ctx.obj.auth.email + if ctx.obj.auth.name != ctx.obj.auth.email: + content = f"{ctx.obj.auth.name}, {ctx.obj.auth.email}" + + details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"} + else: + details = {"Account": f"Offline - {os.getenv('SAFETY_DB_DIR')}"} if ctx.obj.project.id: details["Project"] = ctx.obj.project.id diff --git a/safety/scan/finder/handlers.py b/safety/scan/finder/handlers.py index 395d9e0b..4e2f6966 100644 --- a/safety/scan/finder/handlers.py +++ b/safety/scan/finder/handlers.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import os from pathlib import Path from types import MappingProxyType from typing import Dict, List, Optional, Tuple @@ -49,12 +50,17 @@ def __init__(self) -> None: def download_required_assets(self, session): from safety.safety import fetch_database + + SAFETY_DB_DIR = os.getenv("SAFETY_DB_DIR") + + db = False if SAFETY_DB_DIR is None else SAFETY_DB_DIR + - fetch_database(session=session, full=False, db=False, cached=True, + fetch_database(session=session, full=False, db=db, cached=True, telemetry=True, ecosystem=Ecosystem.PYTHON, from_cache=False) - fetch_database(session=session, full=True, db=False, cached=True, + fetch_database(session=session, full=True, db=db, cached=True, telemetry=True, ecosystem=Ecosystem.PYTHON, from_cache=False) diff --git a/safety/scan/validators.py b/safety/scan/validators.py index b9e3aa18..12aa777f 100644 --- a/safety/scan/validators.py +++ b/safety/scan/validators.py @@ -1,4 +1,5 @@ +import os from pathlib import Path from typing import Optional, Tuple import typer @@ -42,6 +43,9 @@ def fail_if_not_allowed_stage(ctx: typer.Context): stage = ctx.obj.auth.stage auth_type: AuthenticationType = ctx.obj.auth.client.get_authentication_type() + if os.getenv("SAFETY_DB_DIR"): + return + if not auth_type.is_allowed_in(stage): raise typer.BadParameter(f"'{auth_type.value}' auth type isn't allowed with " \ f"the '{stage}' stage.") diff --git a/tests/scan/test_file_handlers.py b/tests/scan/test_file_handlers.py new file mode 100644 index 00000000..43b59be6 --- /dev/null +++ b/tests/scan/test_file_handlers.py @@ -0,0 +1,29 @@ +import os +import pytest +from unittest.mock import Mock, patch +from safety.scan.finder.handlers import PythonFileHandler + +@patch('safety.safety.fetch_database') +def test_download_required_assets(mock_fetch_database): + handler = PythonFileHandler() + session = Mock() + + os.environ["SAFETY_DB_DIR"] = "/path/to/db" + handler.download_required_assets(session) + + _, kwargs = mock_fetch_database.call_args + + assert kwargs['db'] == "/path/to/db" + +@patch('safety.safety.fetch_database') +def test_download_required_assets_no_db_dir(mock_fetch_database): + handler = PythonFileHandler() + session = Mock() + + if "SAFETY_DB_DIR" in os.environ: + del os.environ["SAFETY_DB_DIR"] + handler.download_required_assets(session) + + _, kwargs = mock_fetch_database.call_args + + assert kwargs['db'] == False