diff --git a/safety/safety.py b/safety/safety.py
index a368561b..ae6d1fc5 100644
--- a/safety/safety.py
+++ b/safety/safety.py
@@ -12,7 +12,7 @@ import time
 from collections import defaultdict
 from datetime import datetime
-from typing import Dict, Optional, List
+from typing import Dict, Optional, List, Any
 
 import click
 import requests
@@ -21,6 +21,7 @@ from packaging.utils import canonicalize_name
 from packaging.version import parse as parse_version, Version
 from pydantic.json import pydantic_encoder
+from filelock import FileLock
 
 from safety_schemas.models import Ecosystem, FileType
@@ -41,34 +42,38 @@
 LOG = logging.getLogger(__name__)
 
 
-def get_from_cache(db_name, cache_valid_seconds=0, skip_time_verification=False):
-    if os.path.exists(DB_CACHE_FILE):
-        with open(DB_CACHE_FILE) as f:
-            try:
-                data = json.loads(f.read())
-                if db_name in data:
+def get_from_cache(db_name: str, cache_valid_seconds: int = 0, skip_time_verification: bool = False) -> Optional[Dict[str, Any]]:
+    cache_file_lock = f"{DB_CACHE_FILE}.lock"
+    os.makedirs(os.path.dirname(cache_file_lock), exist_ok=True)
+    lock = FileLock(cache_file_lock, timeout=10)
+    with lock:
+        if os.path.exists(DB_CACHE_FILE):
+            with open(DB_CACHE_FILE) as f:
+                try:
+                    data = json.loads(f.read())
+                    if db_name in data:
-                    if "cached_at" in data[db_name]:
-                        if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
-                            LOG.debug('Getting the database from cache at %s, cache setting: %s',
-                                      data[db_name]["cached_at"], cache_valid_seconds)
-
-                            try:
-                                data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
-                            except KeyError as e:
-                                pass
+                        if "cached_at" in data[db_name]:
+                            if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
+                                LOG.debug('Getting the database from cache at %s, cache setting: %s',
+                                          data[db_name]["cached_at"], cache_valid_seconds)
-                            return data[db_name]["db"]
+                                try:
+                                    data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
+                                except KeyError as e:
+                                    pass
-                        LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
-                    else:
-                        LOG.debug('There is not the cached_at key in %s database', data[db_name])
+                                return data[db_name]["db"]
-            except json.JSONDecodeError:
-                LOG.debug('JSONDecodeError trying to get the cached database.')
-    else:
-        LOG.debug("Cache file doesn't exist...")
-    return False
+                            LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
+                        else:
+                            LOG.debug('There is not the cached_at key in %s database', data[db_name])
+
+                except json.JSONDecodeError:
+                    LOG.debug('JSONDecodeError trying to get the cached database.')
+        else:
+            LOG.debug("Cache file doesn't exist...")
+    return None
 
 
 def write_to_cache(db_name, data):
@@ -95,25 +100,31 @@ def write_to_cache(db_name, data):
             if exc.errno != errno.EEXIST:
                 raise
 
-    with open(DB_CACHE_FILE, "r") as f:
-        try:
-            cache = json.loads(f.read())
-        except json.JSONDecodeError:
-            LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
+    cache_file_lock = f"{DB_CACHE_FILE}.lock"
+    lock = FileLock(cache_file_lock, timeout=10)
+    with lock:
+        if os.path.exists(DB_CACHE_FILE):
+            with open(DB_CACHE_FILE, "r") as f:
+                try:
+                    cache = json.loads(f.read())
+                except json.JSONDecodeError:
+                    LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
+                    cache = {}
+        else:
             cache = {}
 
-    with open(DB_CACHE_FILE, "w") as f:
-        cache[db_name] = {
-            "cached_at": time.time(),
-            "db": data
-        }
-        f.write(json.dumps(cache))
-        LOG.debug('Safety updated the cache file for %s database.', db_name)
+        with open(DB_CACHE_FILE, "w") as f:
+            cache[db_name] = {
+                "cached_at": time.time(),
+                "db": data
+            }
+            f.write(json.dumps(cache))
+            LOG.debug('Safety updated the cache file for %s database.', db_name)
 
 
 def fetch_database_url(session, mirror, db_name, cached, telemetry=True,
                        ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True):
-    headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value}
+    headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value}
 
     if cached and from_cache:
         cached_data = get_from_cache(db_name=db_name, cache_valid_seconds=cached)
@@ -122,13 +133,13 @@ def fetch_database_url(session, mirror, db_name, cached, telemetry=True,
         return cached_data
 
     url = mirror + db_name
-
+
     telemetry_data = {
-        'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry),
+        'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry),
                                 default=pydantic_encoder)}
 
     try:
-        r = session.get(url=url, timeout=REQUEST_TIMEOUT,
+        r = session.get(url=url, timeout=REQUEST_TIMEOUT,
                         headers=headers, params=telemetry_data)
     except requests.exceptions.ConnectionError:
         raise NetworkConnectionError()
@@ -205,10 +216,10 @@ def fetch_database_file(path: str, db_name: str, cached = 0,
 
     if not full_path.exists():
         raise DatabaseFileNotFoundError(db=path)
-
+
     with open(full_path) as f:
         data = json.loads(f.read())
-
+
     if cached:
         LOG.info('Writing %s to cache because cached value was %s', db_name, cached)
         write_to_cache(db_name, data)
@@ -226,7 +237,7 @@ def is_valid_database(db) -> bool:
     return False
 
 
-def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
+def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
                    ecosystem: Optional[Ecosystem] = None, from_cache=True):
 
     if session.is_using_auth_credentials():
@@ -242,7 +253,7 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
         if is_a_remote_mirror(mirror):
             if ecosystem is None:
                 ecosystem = Ecosystem.PYTHON
-            data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
+            data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
                                       telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache)
         else:
             data = fetch_database_file(mirror, db_name=db_name, cached=cached,
@@ -562,16 +573,16 @@ def compute_sec_ver(remediations, packages: Dict[str, Package], secure_vulns_by_
             secure_v = compute_sec_ver_for_user(package=pkg, secure_vulns_by_user=secure_vulns_by_user, db_full=db_full)
 
         rem['closest_secure_version'] = get_closest_ver(secure_v, version, spec)
-
+
         upgrade = rem['closest_secure_version'].get('upper', None)
         downgrade = rem['closest_secure_version'].get('lower', None)
         recommended_version = None
-
+
         if upgrade:
             recommended_version = upgrade
         elif downgrade:
             recommended_version = downgrade
-
+
         rem['recommended_version'] = recommended_version
         rem['other_recommended_versions'] = [other_v for other_v in secure_v if
                                              other_v != str(recommended_version)]
@@ -645,12 +656,12 @@ def process_fixes(files, remediations, auto_remediation_limit, output, no_output
 def process_fixes_scan(file_to_fix, to_fix_spec, auto_remediation_limit, output, no_output=True, prompt=False):
     to_fix_remediations = []
-
+
     def get_remmediation_from(spec):
         upper = None
         lower = None
         recommended = None
-
+
         try:
             upper = Version(spec.remediation.closest_secure.upper) if spec.remediation.closest_secure.upper else None
         except Exception as e:
@@ -664,7 +675,7 @@ def get_remmediation_from(spec):
 
         try:
             recommended = Version(spec.remediation.recommended)
         except Exception as e:
-            LOG.error(f'Error getting recommended version for remediation, ignoring', exc_info=True)
+            LOG.error(f'Error getting recommended version for remediation, ignoring', exc_info=True)
 
         return {
             "vulnerabilities_found": spec.remediation.vulnerabilities_found,
@@ -672,7 +683,7 @@ def get_remmediation_from(spec):
             "requirement": spec,
             "more_info_url": spec.remediation.more_info_url,
             "closest_secure_version": {
-                'upper': upper,
+                'upper': upper,
                 'lower': lower
             },
             "recommended_version": recommended,
@@ -690,7 +701,7 @@ def get_remmediation_from(spec):
         'files': {str(file_to_fix.location): {'content': None, 'fixes': {'TO_SKIP': [], 'TO_APPLY': [], 'TO_CONFIRM': []},
                                               'supported': False, 'filename': file_to_fix.location.name}},
         'dependencies': defaultdict(dict),
     }
-
+
     fixes = apply_fixes(requirements, output, no_output, prompt, scan_flow=True, auto_remediation_limit=auto_remediation_limit)
 
     return fixes
@@ -822,7 +833,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
     for name, data in requirements['files'].items():
         output = [('', {}),
                   (f"Analyzing {name}... [{get_fix_opt_used_msg(auto_remediation_limit)} limit]",
                    {'styling': {'bold': True}, 'start_line_decorator': '->', 'indent': ' '})]
-
+
         r_skip = data['fixes']['TO_SKIP']
         r_apply = data['fixes']['TO_APPLY']
         r_confirm = data['fixes']['TO_CONFIRM']
@@ -901,7 +912,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
             else:
                 not_supported_filename = data.get('filename', name)
                 output.append(
-                    (f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
+                    (f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
                      {'start_line_decorator': ' -', 'indent': ' '}))
                 output.append(('', {}))
@@ -999,7 +1010,7 @@ def review(*, report=None, params=None):
 
 @sync_safety_context
 def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True):
-
+
     if db_mirror:
         mirrors = [db_mirror]
     else:
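Note on the hunks above: every read and write of the shared JSON cache now happens under an inter-process file lock, and a cache miss is reported as None (typed Optional) instead of False. A minimal, self-contained sketch of the same pattern follows; CACHE_FILE and read_cached are illustrative names, not identifiers from this codebase.

    import json
    import os
    import time
    from typing import Any, Dict, Optional

    from filelock import FileLock

    CACHE_FILE = os.path.expanduser("~/.example/cache.json")  # illustrative path

    def read_cached(name: str, max_age: int) -> Optional[Dict[str, Any]]:
        # A sibling .lock file guards the cache. The context manager blocks
        # for up to the timeout (in seconds) given to the constructor and
        # raises filelock.Timeout if the lock cannot be acquired.
        os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
        with FileLock(f"{CACHE_FILE}.lock", timeout=10):
            if not os.path.exists(CACHE_FILE):
                return None
            try:
                with open(CACHE_FILE) as f:
                    data = json.load(f)
            except json.JSONDecodeError:
                return None  # a corrupt cache is treated as a miss
            entry = data.get(name)
            if entry and entry.get("cached_at", 0) + max_age > time.time():
                return entry["db"]
            return None

Because get_from_cache and write_to_cache each create their own FileLock on the same path and never nest the two calls, each cache operation contends for at most the 10-second timeout.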
diff --git a/safety/scan/ecosystems/python/main.py b/safety/scan/ecosystems/python/main.py
index 275b089d..bd9353bf 100644
--- a/safety/scan/ecosystems/python/main.py
+++ b/safety/scan/ecosystems/python/main.py
@@ -29,19 +29,19 @@ LOG = logging.getLogger(__name__)
 
 
-def ignore_vuln_if_needed(dependency: PythonDependency, file_type: FileType,
+def ignore_vuln_if_needed(dependency: PythonDependency, file_type: FileType,
                           vuln_id: str, cve, ignore_vulns,
                           ignore_unpinned: bool, ignore_environment: bool,
                           specification: PythonSpecification,
                           ignore_severity: List[VulnerabilitySeverityLabels] = []):
-
-    vuln_ignored: bool = vuln_id in ignore_vulns
+
+    vuln_ignored: bool = vuln_id in ignore_vulns
 
     if vuln_ignored and ignore_vulns[vuln_id].code is IgnoreCodes.manual:
-        if (not ignore_vulns[vuln_id].expires
+        if (not ignore_vulns[vuln_id].expires
             or ignore_vulns[vuln_id].expires > datetime.utcnow().date()):
             return
-
+
         del ignore_vulns[vuln_id]
 
     if ignore_environment and file_type is FileType.VIRTUAL_ENVIRONMENT:
@@ -56,7 +56,7 @@ def ignore_vuln_if_needed(dependency: PythonDependency, file_type: FileType,
         if cve.cvssv3 and cve.cvssv3.get("base_severity", None):
             severity_label = VulnerabilitySeverityLabels(
                 cve.cvssv3["base_severity"].lower())
-
+
             if severity_label in ignore_severity:
                 reason = f"{severity_label.value.capitalize()} severity ignored by rule in policy file."
                 ignore_vulns[vuln_id] = IgnoredItemDetail(
@@ -75,7 +75,7 @@ def ignore_vuln_if_needed(dependency: PythonDependency, file_type: FileType,
         specifications = set()
         specifications.add(str(specification.specifier))
         ignore_vulns[vuln_id] = IgnoredItemDetail(
-            code=IgnoreCodes.unpinned_specification, reason=reason,
+            code=IgnoreCodes.unpinned_specification, reason=reason,
             specifications=specifications)
 
 
@@ -84,7 +84,7 @@ def should_fail(config: ConfigModel, vulnerability: Vulnerability) -> bool:
         return False
 
     # If Severity is None type, it will be considered as UNKNOWN and NONE
-    # They are not the same, but we are handling like the same when a
+    # They are not the same, but we are handling like the same when a
     # vulnerability does not have a severity value.
     severities = [VulnerabilitySeverityLabels.NONE,
                   VulnerabilitySeverityLabels.UNKNOWN]
@@ -127,7 +127,7 @@ def get_vulnerability(vuln_id: str, cve,
         unpinned_ignored = ignore_vulns[vuln_id].specifications \
             if vuln_id in ignore_vulns.keys() else None
         should_ignore = not unpinned_ignored or str(affected.specifier) in unpinned_ignored
-        ignored: bool = bool(ignore_vulns and
+        ignored: bool = bool(ignore_vulns and
                              vuln_id in ignore_vulns and
                              should_ignore)
 
     more_info_url = f"{base_domain}{data.get('more_info_path', '')}"
@@ -175,13 +175,13 @@ def get_vulnerability(vuln_id: str, cve,
     )
 
 
 class PythonFile(InspectableFile, Remediable):
-
+
     def __init__(self, file_type: FileType, file: FileTextWrite) -> None:
         super().__init__(file=file)
         self.ecosystem = file_type.ecosystem
         self.file_type = file_type
 
-    def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependency],
+    def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependency],
                                             config: ConfigModel):
         ignored_vulns_data = {}
         ignore_vulns = {} \
@@ -191,8 +191,11 @@ def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependenc
         ignore_severity = config.depedendency_vulnerability.ignore_cvss_severity
         ignore_unpinned = config.depedendency_vulnerability.python_ignore.unpinned_specifications
         ignore_environment = config.depedendency_vulnerability.python_ignore.environment_results
-
+
         db = get_from_cache(db_name="insecure.json", skip_time_verification=True)
+        if not db:
+            LOG.debug("Cache data for insecure.json is not available or is invalid.")
+            return
         db_full = None
         vulnerable_packages = frozenset(db.get('vulnerable_packages', []))
         found_dependencies = {}
@@ -214,8 +217,11 @@ def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependenc
 
             if not dependency.version:
                 if not db_full:
-                    db_full = get_from_cache(db_name="insecure_full.json",
+                    db_full = get_from_cache(db_name="insecure_full.json",
                                              skip_time_verification=True)
+                    if not db_full:
+                        LOG.debug("Cache data for insecure_full.json is not available or is invalid.")
+                        return
                 dependency.refresh_from(db_full)
 
             if name in vulnerable_packages:
@@ -225,8 +231,11 @@ def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependenc
                     if spec.is_vulnerable(spec_set, dependency.insecure_versions):
                         if not db_full:
-                            db_full = get_from_cache(db_name="insecure_full.json",
+                            db_full = get_from_cache(db_name="insecure_full.json",
                                                      skip_time_verification=True)
+                            if not db_full:
+                                LOG.debug("Cache data for insecure_full.json is not available or is invalid.")
+                                return
                         if not dependency.latest_version:
                             dependency.refresh_from(db_full)
@@ -247,23 +256,23 @@ def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependenc
                             vuln_id=vuln_id, cve=cve, ignore_vulns=ignore_vulns,
                             ignore_severity=ignore_severity,
-                            ignore_unpinned=ignore_unpinned,
-                            ignore_environment=ignore_environment,
+                            ignore_unpinned=ignore_unpinned,
+                            ignore_environment=ignore_environment,
                             specification=spec)
 
                         include_ignored = True
-                        vulnerability = get_vulnerability(vuln_id, cve, data,
+                        vulnerability = get_vulnerability(vuln_id, cve, data,
                                                           specifier, db_full, name,
                                                           ignore_vulns, spec)
 
-                        should_add_vuln = not (vulnerability.is_transitive and
-                                               dependency.found and
+                        should_add_vuln = not (vulnerability.is_transitive and
+                                               dependency.found and
                                                dependency.found.parts[-1] == FileType.VIRTUAL_ENVIRONMENT.value)
-
+
                         if vulnerability.ignored:
                             ignored_vulns_data[
                                 vulnerability.vulnerability_id] = vulnerability
-
+
                         if not self.dependency_results.failed and not vulnerability.ignored:
                             self.dependency_results.failed = should_fail(config, vulnerability)
@@ -277,16 +286,16 @@ def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependenc
         self.dependency_results.dependencies = [dep for _, dep in found_dependencies.items()]
         self.dependency_results.ignored_vulns = ignore_vulns
         self.dependency_results.ignored_vulns_data = ignored_vulns_data
-
+
     def inspect(self, config: ConfigModel):
-
+
         # We only support vulnerability checking for now
         dependencies = get_dependencies(self)
 
         if not dependencies:
             self.results = []
-
-        self.__find_dependency_vulnerabilities__(dependencies=dependencies,
+
+        self.__find_dependency_vulnerabilities__(dependencies=dependencies,
                                                  config=config)
 
     def __get_secure_specifications_for_user__(self, dependency: PythonDependency, db_full,
@@ -309,26 +318,26 @@ def __get_secure_specifications_for_user__(self, dependency: PythonDependency, d
         sec_ver_for_user = list(versions.difference(affected_v))
 
         return sorted(sec_ver_for_user, key=lambda ver: parse_version(ver), reverse=True)
-
+
     def remediate(self):
-        db_full = get_from_cache(db_name="insecure_full.json",
+        db_full = get_from_cache(db_name="insecure_full.json",
                                  skip_time_verification=True)
         if not db_full:
             return
 
         for dependency in self.dependency_results.get_affected_dependencies():
             secure_versions = dependency.secure_versions
-
+
             if not secure_versions:
                 secure_versions = []
 
             secure_vulns_by_user = set(self.dependency_results.ignored_vulns.keys())
 
             if not secure_vulns_by_user:
-                secure_v = sorted(secure_versions, key=lambda ver: parse_version(ver),
+                secure_v = sorted(secure_versions, key=lambda ver: parse_version(ver),
                                   reverse=True)
             else:
                 secure_v = self.__get_secure_specifications_for_user__(
-                    dependency=dependency, db_full=db_full,
+                    dependency=dependency, db_full=db_full,
                     secure_vulns_by_user=secure_vulns_by_user)
 
             for specification in dependency.specifications:
@@ -338,35 +347,35 @@ def remediate(self):
                 version = None
                 if is_pinned_requirement(specification.specifier):
                     version = next(iter(specification.specifier)).version
-                closest_secure = {key: str(value) if value else None for key, value in
-                                  get_closest_ver(secure_v,
-                                                  version,
+                closest_secure = {key: str(value) if value else None for key, value in
+                                  get_closest_ver(secure_v,
+                                                  version,
                                                   specification.specifier).items()}
 
                 closest_secure = ClosestSecureVersion(**closest_secure)
                 recommended = None
-
+
                 if closest_secure.upper:
                     recommended = closest_secure.upper
                 elif closest_secure.lower:
                     recommended = closest_secure.lower
-
+
                 other_recommended = [other_v for other_v in secure_v if other_v != str(recommended)]
 
                 remed_more_info_url = dependency.more_info_url
 
                 if remed_more_info_url:
                     remed_more_info_url = build_remediation_info_url(
-                        base_url=remed_more_info_url, version=version,
+                        base_url=remed_more_info_url, version=version,
                         spec=str(specification.specifier),
                         target_version=recommended)
-
+
                 if not remed_more_info_url:
                     remed_more_info_url = "-"
 
                 vulns_found = sum(1 for vuln in specification.vulnerabilities if not vuln.ignored)
 
-                specification.remediation = RemediationModel(vulnerabilities_found=vulns_found,
-                                                             more_info_url=remed_more_info_url,
-                                                             closest_secure=closest_secure if recommended else None,
-                                                             recommended=recommended,
+                specification.remediation = RemediationModel(vulnerabilities_found=vulns_found,
+                                                             more_info_url=remed_more_info_url,
+                                                             closest_secure=closest_secure if recommended else None,
+                                                             recommended=recommended,
                                                              other_recommended=other_recommended)
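The scanner-side hunks above consume the new Optional contract: wherever insecure.json or insecure_full.json cannot be loaded from the cache, the code logs and returns early instead of failing later on a falsy value. One edge worth noting: get_from_cache itself does not catch the lock timeout, so a caller that prefers degrading to a fresh fetch over failing would have to handle filelock.Timeout on its own. A hedged sketch, where fetch_fallback is a hypothetical helper and not part of this patch:

    from filelock import Timeout

    try:
        db_full = get_from_cache(db_name="insecure_full.json",
                                 skip_time_verification=True)
    except Timeout:
        db_full = None  # another process held the cache lock past the 10 s timeout

    if db_full is None:
        db_full = fetch_fallback("insecure_full.json")  # hypothetical fallback fetch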
diff --git a/setup.cfg b/setup.cfg
index 97a72546..3d8ab9b3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -52,6 +52,7 @@ install_requires =
     pydantic>=1.10.12
     safety_schemas>=0.0.2
     typing-extensions>=4.7.1
+    filelock~=3.12.2
 
 [options.entry_points]
 console_scripts =
diff --git a/test_requirements.txt b/test_requirements.txt
index f1cc7af2..b465c4a9 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -19,3 +19,4 @@ typer
 pydantic>=1.10.12
 safety_schemas>=0.0.2
 typing-extensions>=4.7.1
+filelock~=3.12.2
\ No newline at end of file
diff --git a/tests/test_safety.py b/tests/test_safety.py
index d829d6b2..630704e6 100644
--- a/tests/test_safety.py
+++ b/tests/test_safety.py
@@ -171,6 +171,9 @@ def test_check_live(self):
     def test_check_live_cached(self):
         from safety.constants import DB_CACHE_FILE
 
+        # Ensure the cache directory and file exist
+        os.makedirs(os.path.dirname(DB_CACHE_FILE), exist_ok=True)
+
         # lets clear the cache first
         try:
             with open(DB_CACHE_FILE, 'w') as f:
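Packaging note: the new dependency is declared with the compatible-release operator in both setup.cfg and test_requirements.txt, so filelock~=3.12.2 accepts 3.12.x patch releases but excludes 3.13. A quick sanity check of that semantics with packaging, which the patched code already imports:

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=3.12.2")  # PEP 440: equivalent to >=3.12.2, ==3.12.*
    assert "3.12.4" in spec
    assert "3.13.0" not in spec

The test change follows the same logic as the production code: test_check_live_cached creates the cache directory up front, because opening DB_CACHE_FILE directly with open(..., 'w') would otherwise fail on a fresh environment where the directory does not yet exist.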