Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: PackageManifest utilities for working with Compilers list [APE-1497] #94

Merged
merged 12 commits into from
Nov 4, 2023
33 changes: 32 additions & 1 deletion ethpm_types/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def check_contract_source_ids(cls, values):

@validator("contract_types")
def add_name_to_contract_types(cls, values):
aliases = list(values.keys())
aliases = list((values or {}).keys())
# NOTE: Must manually inject names to types here
for alias in aliases:
if not values[alias]:
Expand Down Expand Up @@ -263,3 +263,34 @@ def unpack_sources(self, destination: Path):
source_path.parent.mkdir(parents=True, exist_ok=True)

source_path.write_text(content)

def get_contract_compiler(self, contract_type_name: str) -> Optional[Compiler]:
"""
Get the compiler used to compile the contract type, if it exists.

Args:
contract_type_name (str): The name of the compiled contract.

Returns:
Optional[`~ethpm_types.source.Compiler`]
"""
for compiler in self.compilers or []:
if contract_type_name in (compiler.contractTypes or []):
return compiler

return None

def add_compilers(self, *compilers: Compiler):
"""
Update compilers in the manifest. This method appends any
given compiler with a a different name, version, and settings
combination.

Args:
compilers (List[`~ethpm_types.source.Compiler]`): A list of
compilers.
"""

start = self.compilers or []
start.extend([c for c in compilers if c not in start])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the future I think .extend takes in a sequence so you can just do .extend(c for c in ...)

self.compilers = start
23 changes: 23 additions & 0 deletions ethpm_types/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ class Compiler(BaseModel):
that used this compiler to generate its outputs.
"""

def __eq__(self, other) -> bool:
if (
not hasattr(other, "name")
and not hasattr(other, "version")
and not hasattr(other, "settings")
):
return NotImplemented

return (
self.name == other.name
and self.version == other.version
and self.settings == other.settings
)


class Checksum(BaseModel):
"""Checksum information about the contents of a source file."""
Expand All @@ -59,6 +73,15 @@ class Checksum(BaseModel):
The hash of a source files contents generated with the corresponding algorithm.
"""

@classmethod
def from_file(cls, file: Union[Path, str], algorithm: Algorithm = Algorithm.MD5) -> "Checksum":
source_path = file if isinstance(file, Path) else Path(file)
return cls.from_bytes(source_path.read_bytes(), algorithm=algorithm)

@classmethod
def from_bytes(cls, data: bytes, algorithm: Algorithm = Algorithm.MD5) -> "Checksum":
return cls(algorithm=algorithm, hash=compute_checksum(data, algorithm=algorithm))


class Content(BaseModel):
"""
Expand Down
50 changes: 49 additions & 1 deletion tests/test_package_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import pytest
import requests

from ethpm_types import ContractType
from ethpm_types._pydantic_v1 import ValidationError
from ethpm_types.manifest import ALPHABET, NUMBERS, PackageManifest
from ethpm_types.source import Content, Source
from ethpm_types.source import Compiler, Content, Source

ETHPM_SPEC_REPO = github.Github(os.environ.get("GITHUB_ACCESS_TOKEN", None)).get_repo(
"ethpm/ethpm-spec"
Expand Down Expand Up @@ -83,6 +84,10 @@ def test_getattr(package_manifest, solidity_contract):
expected = solidity_contract
assert actual == expected

# Show when not an attribute or contract type.
with pytest.raises(AttributeError):
_ = package_manifest.contractTypes


def test_get_contract_type(package_manifest, solidity_contract):
actual = package_manifest.get_contract_type("SolidityContract")
Expand Down Expand Up @@ -117,3 +122,46 @@ def test_package_name_using_all_valid_characters():
name = "a" + "".join(list(ALPHABET.union(NUMBERS).union({"-"})))
manifest = PackageManifest(name=name, version="0.1.0")
assert manifest.name == name


def test_get_contract_compiler():
compiler = Compiler(name="vyper", version="0.3.7", settings={}, contractTypes=["foobar"])
manifest = PackageManifest(
compilers=[compiler], contractTypes={"foobar": ContractType(contractNam="foobar")}
)
assert manifest.get_contract_compiler("foobar") == compiler
assert manifest.get_contract_compiler("yoyoyo") is None


def test_add_compilers():
compiler = Compiler(name="vyper", version="0.3.7", settings={}, contractTypes=["foobar"])
manifest = PackageManifest(
compilers=[compiler],
contractTypes={
"foobar": ContractType(contractName="foobar", abi=[]),
"testtest": ContractType(contractName="testtest", abi=[]),
},
)
new_compilers = [
Compiler(name="vyper", version="0.3.7", settings={}, contractTypes=["foobar", "testtest"]),
Compiler(name="vyper", version="0.3.10", settings={}, contractTypes=["yoyo"]),
]
manifest.add_compilers(*new_compilers)
assert len(manifest.compilers) == 2


def test_contract_types():
"""
Tests against a bug where validators would fail because
they tried iterating None.
"""
manifest = PackageManifest(contractTypes=None)
assert manifest.contract_types is None

contract_types = {
"foobar": ContractType(contractName="foobar", abi=[]),
"testtest": ContractType(contractName="testtest", abi=[]),
}

manifest = PackageManifest(contractTypes=contract_types)
assert manifest.contract_types == contract_types
39 changes: 38 additions & 1 deletion tests/test_source.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import tempfile
from pathlib import Path

import pytest

from ethpm_types._pydantic_v1 import FileUrl
from ethpm_types.source import Content, ContractSource, Source
from ethpm_types.source import Checksum, Compiler, Content, ContractSource, Source
from ethpm_types.utils import Algorithm, compute_checksum

SOURCE_LOCATION = (
"https://github.com/OpenZeppelin/openzeppelin-contracts"
Expand Down Expand Up @@ -146,3 +150,36 @@ def test_contract_source_use_method_id(vyper_contract, source, source_base):
function = actual.lookup_function(location, method_id=method_id)
assert function.name == "getEmptyTupleOfDynArrayStructs"
assert function.full_name == "getEmptyTupleOfDynArrayStructs()"


def test_compiler_equality():
compiler_1 = Compiler(
name="yo", version="0.1.0", settings={"foo": "bar"}, contractType=["test1"]
)
compiler_2 = Compiler(
name="yo", version="0.1.0", settings={"foo": "bar"}, contractType=["test1", "test2"]
)
assert compiler_1 == compiler_2

compiler_1.name = "yo2"
assert compiler_1 != compiler_2
compiler_1.name = compiler_2.name

compiler_1.version = "0.100000.0"
assert compiler_1 != compiler_2
compiler_1.version = compiler_2.version

compiler_1.settings["test"] = "123"
assert compiler_1 != compiler_2
compiler_1.settings = compiler_2.settings


def test_checksum_from_file():
file = Path(tempfile.mktemp())
file.write_text("foobartest123")
actual = Checksum.from_file(file)
expected = Checksum(
algorithm=Algorithm.MD5,
hash=compute_checksum(file.read_bytes()),
)
assert actual == expected