Skip to content

Commit

Permalink
Linter Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinav-nain committed Jul 4, 2024
1 parent 6e1e392 commit 8a7c549
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 43 deletions.
8 changes: 5 additions & 3 deletions lib/artifacts/scw_artifact.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import requests
import utils.utils

BASE_SCW_URL = 'https://integration-api.securecodewarrior.com/api/v1/trial?id=bugcrowd&mappingList=vrt&mappingKey='
BASE_SCW_URL = 'https://integration-api.securecodewarrior.com\
/api/v1/trial?id=bugcrowd&mappingList=vrt&mappingKey='
OUTPUT_FILENAME = 'scw_links.json'


Expand All @@ -23,7 +23,9 @@ def scw_mapping(vrt_id):


def join_vrt_id(parent_id, child_id):
return '.'.join([parent_id, child_id]) if parent_id is not None else child_id
return '.'.join(
[parent_id, child_id]
) if parent_id is not None else child_id


def generate_urls(vrt, content, parent_id=None):
Expand Down
26 changes: 14 additions & 12 deletions lib/tests/test_artifact_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@
import os
import unittest


class TestArtifactFormat(unittest.TestCase):
def setUp(self):
print("\n`---{}---`".format(self._testMethodName))
self.scw_artifact_path = os.path.join(
utils.THIRD_PARTY_MAPPING_DIR,
utils.SCW_DIR,
utils.SCW_FILENAME
)
def setUp(self):
print("\n`---{}---`".format(self._testMethodName))
self.scw_artifact_path = os.path.join(
utils.THIRD_PARTY_MAPPING_DIR,
utils.SCW_DIR,
utils.SCW_FILENAME
)

def test_artifact_loads_valid_json(self):
self.assertTrue(
utils.get_json(self.scw_artifact_path),
self.scw_artifact_path + ' is not valid JSON.'
)

def test_artifact_loads_valid_json(self):
self.assertTrue(
utils.get_json(self.scw_artifact_path),
self.scw_artifact_path + ' is not valid JSON.'
)

if __name__ == "__main__":
unittest.main()
33 changes: 26 additions & 7 deletions lib/tests/test_deprecated_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ class TestDeprecatedMapping(unittest.TestCase):
def setUp(self):
print("\n`---{}---`".format(self._testMethodName))
self.vrt_versions = utils.all_versions(utils.VRT_FILENAME)
self.last_tagged_version = max([Version.coerce(x) for x in self.vrt_versions.keys() if x != 'current'])
self.deprecated_json = utils.get_json(utils.DEPRECATED_MAPPING_FILENAME)
self.last_tagged_version = max(
[
Version.coerce(x) for x in self.vrt_versions.keys()
if x != 'current'
]
)
self.deprecated_json = utils.get_json(
utils.DEPRECATED_MAPPING_FILENAME
)

def test_old_vrt_ids_have_current_node(self):
for version, vrt in self.vrt_versions.items():
Expand All @@ -17,17 +24,28 @@ def test_old_vrt_ids_have_current_node(self):
for id_list in utils.all_id_lists(vrt):
vrt_id = '.'.join(id_list)
if vrt_id in self.deprecated_json:
max_ver = sorted(self.deprecated_json[vrt_id].keys(), key=lambda s: map(int, s.split('.')))[-1]
max_ver = sorted(
self.deprecated_json[vrt_id].keys(),
key=lambda s: map(int, s.split('.'))
)[-1]
vrt_id = self.deprecated_json[vrt_id][max_ver]
id_list = vrt_id.split('.')
self.assertTrue(vrt_id == 'other' or self.check_mapping(id_list),
'%s from v%s has no mapping' % (vrt_id, version))
self.assertTrue(
vrt_id == 'other' or self.check_mapping(id_list),
'%s from v%s has no mapping' % (vrt_id, version)
)

def test_deprecated_nodes_map_valid_node(self):
for old_id, mapping in self.deprecated_json.items():
for new_version, new_id in mapping.items():
self.assertTrue(new_id == 'other' or utils.id_valid(self.vrt_version(new_version), new_id.split('.')),
new_id + ' is not valid')
self.assertTrue(
new_id == 'other' or utils.id_valid(
self.vrt_version(
new_version
), new_id.split('.')
),
new_id + ' is not valid'
)

def check_mapping(self, id_list):
if utils.id_valid(self.vrt_versions['current'], id_list):
Expand All @@ -45,5 +63,6 @@ def vrt_version(self, version):
else:
self.fail('Unknown version: %s' % version)


if __name__ == "__main__":
unittest.main()
44 changes: 34 additions & 10 deletions lib/tests/test_vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import glob
import os


class TestVrt(unittest.TestCase):
def setUp(self):
print("\n`---{}---`".format(self._testMethodName))
self.vrt = utils.get_json(utils.VRT_FILENAME)
self.mappings = [
{ 'filename': f, 'name': os.path.splitext(os.path.basename(f))[0] }
for f in glob.glob(utils.MAPPING_DIR + '/**/*.json', recursive=True) if 'schema' not in f
{'filename': f, 'name': os.path.splitext(os.path.basename(f))[0]}
for f in glob.glob(
utils.MAPPING_DIR + '/**/*.json', recursive=True
)
if 'schema' not in f
]

@unittest.skip('need to decide the best way to handle this')
Expand All @@ -20,15 +24,20 @@ def test_changelog_updated(self):
Checks if CHANGELOG.md is being updated with the current commit
and prompts the user if it isn't
"""
p = subprocess.Popen('git diff HEAD --stat --staged CHANGELOG.md | wc -l', shell=True, stdout=subprocess.PIPE)
p = subprocess.Popen(
'git diff HEAD --stat --staged CHANGELOG.md | wc -l',
shell=True, stdout=subprocess.PIPE
)
out, _err = p.communicate()
self.assertGreater(int(out), 0, 'CHANGELOG.md not updated')

def validate_schema(self, schema_file, data_file):
schema = utils.get_json(schema_file)
data = utils.get_json(data_file)
jsonschema.Draft4Validator.check_schema(schema)
error = jsonschema.exceptions.best_match(jsonschema.Draft4Validator(schema).iter_errors(data))
error = jsonschema.exceptions.best_match(
jsonschema.Draft4Validator(schema).iter_errors(data)
)
if error:
raise error

Expand All @@ -41,19 +50,30 @@ def test_mapping_schemas(self):
f'{utils.MAPPING_DIR}/**/{mapping["name"]}.schema.json',
recursive=True
)[0]
self.assertTrue(os.path.isfile(schema_file), 'Missing schema file for %s mapping' % mapping['name'])
self.assertTrue(
os.path.isfile(schema_file),
'Missing schema file for %s mapping' % mapping['name']
)
self.validate_schema(schema_file, mapping['filename'])

def all_vrt_ids_have_mapping(self, mappping_filename, key):
mapping = utils.get_json(mappping_filename)
keyed_mapping = utils.key_by_id(mapping['content'])
for vrt_id_list in utils.all_id_lists(self.vrt, include_internal=False):
for vrt_id_list in utils.all_id_lists(
self.vrt, include_internal=False
):
result = utils.has_mapping(keyed_mapping, vrt_id_list, key)
if key == 'cwe' and not result:
print('WARNING: no ' + key + ' mapping for ' + '.'.join(vrt_id_list))
print('WARNING: no ' + key + ' mapping for ' + '.'.join(
vrt_id_list
))
else:
self.assertTrue(utils.has_mapping(keyed_mapping, vrt_id_list, key),
'no ' + key + ' mapping for ' + '.'.join(vrt_id_list))
self.assertTrue(
utils.has_mapping(
keyed_mapping, vrt_id_list, key
),
'no ' + key + ' mapping for ' + '.'.join(vrt_id_list)
)

def test_all_vrt_ids_have_all_mappings(self):
for mapping in self.mappings:
Expand All @@ -63,7 +83,11 @@ def only_map_valid_ids(self, mapping_filename):
vrt_ids = utils.all_id_lists(self.vrt)
mapping_ids = utils.all_id_lists(utils.get_json(mapping_filename))
for id_list in mapping_ids:
self.assertIn(id_list, vrt_ids, 'invalid id in ' + mapping_filename + ' - ' + '.'.join(id_list))
self.assertIn(
id_list,
vrt_ids,
'invalid id in ' + mapping_filename + ' - ' + '.'.join(id_list)
)

def test_only_map_valid_ids(self):
for mapping in self.mappings:
Expand Down
36 changes: 27 additions & 9 deletions lib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
SCW_DIR = 'remediation_training'
THIRD_PARTY_MAPPING_DIR = 'third-party-mappings'


def get_json(filename):
with open(filename) as f:
return json.loads(f.read())


def all_versions(filename):
"""
Find, open and parse all tagged versions of a json file, including the current version
Find, open and parse all tagged versions of a json file,
including the current version
:param filename: The filename to find
:return: a dictionary of all the versions, in the form
Expand All @@ -41,10 +44,12 @@ def id_valid(vrt, id_list):
Check if a vrt id is valid
:param vrt: The vrt object
:param id_list: The vrt id, split into components, eg ['category', 'subcategory', 'variant']
:param id_list: The vrt id, split into components,
eg ['category', 'subcategory', 'variant']
:return: True/False
"""
# this is not particularly efficient, but it's more readable than other options so until we need to care...
# this is not particularly efficient, but it's more readable than other
# options so until we need to care...
return id_list in all_id_lists(vrt)


Expand All @@ -53,7 +58,8 @@ def has_mapping(mapping, id_list, key):
Check if a vrt id has a mapping
:param mapping: The mapping object, keyed by id
:param id_list: The vrt id, split into components, eg ['category', 'subcategory', 'variant']
:param id_list: The vrt id, split into components,
eg ['category', 'subcategory', 'variant']
:param key: The mapping key to look for, eg 'cvss_v3'
:return: True/False
"""
Expand All @@ -72,9 +78,16 @@ def key_by_id(mapping):
Converts arrays to hashes keyed by the id attribute for easier lookup. So
[{'id': 'one', 'foo': 'bar'}, {'id': 'two', 'foo': 'baz'}]
becomes
{'one': {'id': 'one', 'foo': 'bar'}, 'two': {'id': 'two', 'foo': 'baz'}}
{
'one': {'id': 'one', 'foo': 'bar'},
'two': {'id': 'two', 'foo': 'baz'}
}
"""
if isinstance(mapping, list) and isinstance(mapping[0], dict) and 'id' in mapping[0]:
if isinstance(
mapping, list
) and isinstance(
mapping[0], dict
) and 'id' in mapping[0]:
return {x['id']: key_by_id(x) for x in mapping}
elif isinstance(mapping, dict):
return {k: key_by_id(v) for k, v in mapping.items()}
Expand All @@ -84,10 +97,12 @@ def key_by_id(mapping):

def all_id_lists(vrt, include_internal=True):
"""
Get all valid vrt ids for a given vrt object, including internal nodes by default
Get all valid vrt ids for a given vrt object, including internal nodes
by default
:param vrt: The vrt object
:param include_internal: Whether to include internal nodes or only leaf nodes
:param include_internal: Whether to include internal nodes or only
leaf nodes
:return: ids in the form
[
['category'],
Expand All @@ -98,7 +113,10 @@ def all_id_lists(vrt, include_internal=True):
"""
def _all_id_lists(sub_vrt, prefix):
if isinstance(sub_vrt, list):
return [vrt_id for entry in sub_vrt for vrt_id in _all_id_lists(entry, prefix)]
return [
vrt_id for entry in sub_vrt
for vrt_id in _all_id_lists(entry, prefix)
]
elif isinstance(sub_vrt, dict):
if 'children' in sub_vrt:
new_prefix = prefix + [sub_vrt['id']]
Expand Down
11 changes: 9 additions & 2 deletions lib/validate_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from artifacts import scw_artifact

artifact_json = utils.get_json(scw_artifact.OUTPUT_FILENAME)
repo_path = os.path.join(utils.THIRD_PARTY_MAPPING_DIR, utils.SCW_DIR, utils.SCW_FILENAME)
repo_path = os.path.join(
utils.THIRD_PARTY_MAPPING_DIR,
utils.SCW_DIR,
utils.SCW_FILENAME
)
print(os.path.abspath(repo_path))
repo_json = utils.get_json(repo_path)

Expand All @@ -16,5 +20,8 @@
print('SCW Document is valid!')
sys.exit(0)
else:
print('SCW Document is invalid, copy the artifact to the remediation training')
print(
'SCW Document is invalid, copy the artifact to the remediation\
training'
)
sys.exit(1)

0 comments on commit 8a7c549

Please sign in to comment.