Skip to content

Commit

Permalink
feat: PackageManifest utilities for working with Compilers list [APE-…
Browse files Browse the repository at this point in the history
…1497] (#94)
  • Loading branch information
antazoey authored Nov 4, 2023
1 parent 9d826dc commit 36fa9b6
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 3 deletions.
33 changes: 32 additions & 1 deletion ethpm_types/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def check_contract_source_ids(cls, values):

@validator("contract_types")
def add_name_to_contract_types(cls, values):
aliases = list(values.keys())
aliases = list((values or {}).keys())
# NOTE: Must manually inject names to types here
for alias in aliases:
if not values[alias]:
Expand Down Expand Up @@ -263,3 +263,34 @@ def unpack_sources(self, destination: Path):
source_path.parent.mkdir(parents=True, exist_ok=True)

source_path.write_text(content)

def get_contract_compiler(self, contract_type_name: str) -> Optional[Compiler]:
"""
Get the compiler used to compile the contract type, if it exists.
Args:
contract_type_name (str): The name of the compiled contract.
Returns:
Optional[`~ethpm_types.source.Compiler`]
"""
for compiler in self.compilers or []:
if contract_type_name in (compiler.contractTypes or []):
return compiler

return None

def add_compilers(self, *compilers: Compiler):
"""
Update compilers in the manifest. This method appends any
given compiler with a a different name, version, and settings
combination.
Args:
compilers (List[`~ethpm_types.source.Compiler]`): A list of
compilers.
"""

start = self.compilers or []
start.extend([c for c in compilers if c not in start])
self.compilers = start
23 changes: 23 additions & 0 deletions ethpm_types/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ class Compiler(BaseModel):
that used this compiler to generate its outputs.
"""

def __eq__(self, other) -> bool:
if (
not hasattr(other, "name")
and not hasattr(other, "version")
and not hasattr(other, "settings")
):
return NotImplemented

return (
self.name == other.name
and self.version == other.version
and self.settings == other.settings
)


class Checksum(BaseModel):
"""Checksum information about the contents of a source file."""
Expand All @@ -59,6 +73,15 @@ class Checksum(BaseModel):
The hash of a source files contents generated with the corresponding algorithm.
"""

@classmethod
def from_file(cls, file: Union[Path, str], algorithm: Algorithm = Algorithm.MD5) -> "Checksum":
source_path = file if isinstance(file, Path) else Path(file)
return cls.from_bytes(source_path.read_bytes(), algorithm=algorithm)

@classmethod
def from_bytes(cls, data: bytes, algorithm: Algorithm = Algorithm.MD5) -> "Checksum":
return cls(algorithm=algorithm, hash=compute_checksum(data, algorithm=algorithm))


class Content(BaseModel):
"""
Expand Down
50 changes: 49 additions & 1 deletion tests/test_package_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import pytest
import requests

from ethpm_types import ContractType
from ethpm_types._pydantic_v1 import ValidationError
from ethpm_types.manifest import ALPHABET, NUMBERS, PackageManifest
from ethpm_types.source import Content, Source
from ethpm_types.source import Compiler, Content, Source

ETHPM_SPEC_REPO = github.Github(os.environ.get("GITHUB_ACCESS_TOKEN", None)).get_repo(
"ethpm/ethpm-spec"
Expand Down Expand Up @@ -83,6 +84,10 @@ def test_getattr(package_manifest, solidity_contract):
expected = solidity_contract
assert actual == expected

# Show when not an attribute or contract type.
with pytest.raises(AttributeError):
_ = package_manifest.contractTypes


def test_get_contract_type(package_manifest, solidity_contract):
actual = package_manifest.get_contract_type("SolidityContract")
Expand Down Expand Up @@ -117,3 +122,46 @@ def test_package_name_using_all_valid_characters():
name = "a" + "".join(list(ALPHABET.union(NUMBERS).union({"-"})))
manifest = PackageManifest(name=name, version="0.1.0")
assert manifest.name == name


def test_get_contract_compiler():
compiler = Compiler(name="vyper", version="0.3.7", settings={}, contractTypes=["foobar"])
manifest = PackageManifest(
compilers=[compiler], contractTypes={"foobar": ContractType(contractNam="foobar")}
)
assert manifest.get_contract_compiler("foobar") == compiler
assert manifest.get_contract_compiler("yoyoyo") is None


def test_add_compilers():
compiler = Compiler(name="vyper", version="0.3.7", settings={}, contractTypes=["foobar"])
manifest = PackageManifest(
compilers=[compiler],
contractTypes={
"foobar": ContractType(contractName="foobar", abi=[]),
"testtest": ContractType(contractName="testtest", abi=[]),
},
)
new_compilers = [
Compiler(name="vyper", version="0.3.7", settings={}, contractTypes=["foobar", "testtest"]),
Compiler(name="vyper", version="0.3.10", settings={}, contractTypes=["yoyo"]),
]
manifest.add_compilers(*new_compilers)
assert len(manifest.compilers) == 2


def test_contract_types():
"""
Tests against a bug where validators would fail because
they tried iterating None.
"""
manifest = PackageManifest(contractTypes=None)
assert manifest.contract_types is None

contract_types = {
"foobar": ContractType(contractName="foobar", abi=[]),
"testtest": ContractType(contractName="testtest", abi=[]),
}

manifest = PackageManifest(contractTypes=contract_types)
assert manifest.contract_types == contract_types
39 changes: 38 additions & 1 deletion tests/test_source.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import tempfile
from pathlib import Path

import pytest

from ethpm_types._pydantic_v1 import FileUrl
from ethpm_types.source import Content, ContractSource, Source
from ethpm_types.source import Checksum, Compiler, Content, ContractSource, Source
from ethpm_types.utils import Algorithm, compute_checksum

SOURCE_LOCATION = (
"https://github.com/OpenZeppelin/openzeppelin-contracts"
Expand Down Expand Up @@ -146,3 +150,36 @@ def test_contract_source_use_method_id(vyper_contract, source, source_base):
function = actual.lookup_function(location, method_id=method_id)
assert function.name == "getEmptyTupleOfDynArrayStructs"
assert function.full_name == "getEmptyTupleOfDynArrayStructs()"


def test_compiler_equality():
compiler_1 = Compiler(
name="yo", version="0.1.0", settings={"foo": "bar"}, contractType=["test1"]
)
compiler_2 = Compiler(
name="yo", version="0.1.0", settings={"foo": "bar"}, contractType=["test1", "test2"]
)
assert compiler_1 == compiler_2

compiler_1.name = "yo2"
assert compiler_1 != compiler_2
compiler_1.name = compiler_2.name

compiler_1.version = "0.100000.0"
assert compiler_1 != compiler_2
compiler_1.version = compiler_2.version

compiler_1.settings["test"] = "123"
assert compiler_1 != compiler_2
compiler_1.settings = compiler_2.settings


def test_checksum_from_file():
file = Path(tempfile.mktemp())
file.write_text("foobartest123")
actual = Checksum.from_file(file)
expected = Checksum(
algorithm=Algorithm.MD5,
hash=compute_checksum(file.read_bytes()),
)
assert actual == expected

0 comments on commit 36fa9b6

Please sign in to comment.