Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract type requirements from old code for staged contracts #1705

Merged
merged 6 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 81 additions & 44 deletions internal/migrate/staging_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"fmt"
"strings"

"github.com/rs/zerolog"

"golang.org/x/exp/slices"

"github.com/onflow/cadence"
Expand All @@ -40,6 +42,7 @@ import (
"github.com/onflow/contract-updater/lib/go/templates"
flowsdk "github.com/onflow/flow-go-sdk"
"github.com/onflow/flow-go/cmd/util/ledger/migrations"
"github.com/onflow/flow-go/cmd/util/ledger/reporters"
"github.com/onflow/flow-go/model/flow"
"github.com/onflow/flowkit/v2"

Expand All @@ -59,9 +62,13 @@ type stagingValidatorImpl struct {

// Cache for account contract names so we don't have to fetch them multiple times
accountContractNames map[common.Address][]string

// All resolved contract code
contracts map[common.Location][]byte

// Contract codes that are not updated/staged
oldCodes map[common.Location][]byte

// Dependency graph for staged contracts
// This root level map holds all nodes
graph map[common.Location]node
Expand Down Expand Up @@ -173,6 +180,7 @@ func newStagingValidator(flow flowkit.Services) *stagingValidatorImpl {
checkingCache: make(map[common.Location]*cachedCheckingResult),
accountContractNames: make(map[common.Address][]string),
graph: make(map[common.Location]node),
oldCodes: make(map[common.Location][]byte),
}
}

Expand All @@ -187,10 +195,12 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)

v.stagedContracts = make(map[common.AddressLocation]stagedContractUpdate)
for _, stagedContract := range stagedContracts {
v.stagedContracts[stagedContract.DeployLocation] = stagedContract
stagedContractLocation := stagedContract.DeployLocation
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contract update validator relies on typeIDs. So had to use the DeployLocation in all places, instead of SourceLocation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a dedicated test for ^above: bd87b9f


v.stagedContracts[stagedContractLocation] = stagedContract

// Add the contract code to the contracts map for pretty printing
v.contracts[stagedContract.SourceLocation] = stagedContract.Code
v.contracts[stagedContractLocation] = stagedContract.Code
}

// Load system contracts
Expand All @@ -199,24 +209,65 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
// Parse and check all staged contracts
errs := v.checkAllStaged()

typeRequirements := &migrations.LegacyTypeRequirements{}

// Extract type requirements from the old codes for all staged contracts.
for _, contract := range v.stagedContracts {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actual improvement is this block: extract type requirements from old code

location := contract.DeployLocation

// Don't validate contracts with existing errors
if errs[location] != nil {
continue
}

// Get the account for the contract
address := flowsdk.Address(location.Address)
account, err := v.flow.GetAccount(context.Background(), address)
if err != nil {
return fmt.Errorf("failed to get account: %w", err)
}

// Get the target contract old code
contractName := location.Name
oldCode, ok := account.Contracts[contractName]
if !ok {
return fmt.Errorf("old contract code not found for contract: %s", contractName)
}
v.oldCodes[location] = oldCode

migrations.ExtractTypeRequirements(
migrations.AddressContract{
Location: location,
Code: oldCode,
},
zerolog.Nop(),
reporters.ReportNilWriter{},
typeRequirements,
)
}

// Validate all contract updates
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Don't validate contracts with existing errors
if errs[contract.SourceLocation] != nil {
if errs[location] != nil {
continue
}

// Validate the contract update
checker := v.checkingCache[contract.SourceLocation].checker
err := v.validateContractUpdate(contract, checker)
checker := v.checkingCache[location].checker
err := v.validateContractUpdate(contract, checker, typeRequirements)
if err != nil {
errs[contract.SourceLocation] = err
errs[location] = err
}
}

// Check for any upstream contract update failures
for _, contract := range v.stagedContracts {
err := errs[contract.SourceLocation]
location := contract.DeployLocation

err := errs[location]

// We will override any errors other than those related
// to missing dependencies, since they are more specific
Expand All @@ -234,19 +285,14 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)

badDeps := make([]common.Location, 0)
v.forEachDependency(contract, func(dependency common.Location) {
strLocation, ok := dependency.(common.StringLocation)
if !ok {
return
}

if errs[strLocation] != nil {
if errs[dependency] != nil {
badDeps = append(badDeps, dependency)
}
})

if len(badDeps) > 0 {
errs[contract.SourceLocation] = &upstreamValidationError{
Location: contract.SourceLocation,
errs[location] = &upstreamValidationError{
Location: location,
BadDependencies: badDeps,
}
}
Expand All @@ -257,7 +303,7 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
// Map errors to address locations
errsByAddress := make(map[common.AddressLocation]error)
for _, contract := range v.stagedContracts {
err := errs[contract.SourceLocation]
err := errs[contract.DeployLocation]
if err != nil {
errsByAddress[contract.DeployLocation] = err
}
Expand All @@ -267,12 +313,13 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
return nil
}

func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error {
errors := make(map[common.StringLocation]error)
func (v *stagingValidatorImpl) checkAllStaged() map[common.Location]error {
errors := make(map[common.Location]error)
for _, contract := range v.stagedContracts {
_, err := v.checkContract(contract.SourceLocation)
location := contract.DeployLocation
_, err := v.checkContract(location)
if err != nil {
errors[contract.SourceLocation] = err
errors[location] = err
}
}

Expand All @@ -281,6 +328,8 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error
// Note: nodes are not visited more than once so cyclic imports are not an issue
// They will be reported, however, by the checker, if they do exist
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Create a set of all dependencies
missingDependencies := make([]common.AddressLocation, 0)
v.forEachDependency(contract, func(dependency common.Location) {
Expand All @@ -292,32 +341,31 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error
})

if len(missingDependencies) > 0 {
errors[contract.SourceLocation] = &missingDependenciesError{
errors[location] = &missingDependenciesError{
MissingContracts: missingDependencies,
}
}
}
return errors
}

func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpdate, checker *sema.Checker) (err error) {
func (v *stagingValidatorImpl) validateContractUpdate(
contract stagedContractUpdate,
checker *sema.Checker,
typeRequirements *migrations.LegacyTypeRequirements,
) (err error) {
// Gracefully recover from panics
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during contract update validation: %v", r)
}
}()

// Get the account for the contract
address := flowsdk.Address(contract.DeployLocation.Address)
account, err := v.flow.GetAccount(context.Background(), address)
if err != nil {
return fmt.Errorf("failed to get account: %w", err)
}
location := contract.DeployLocation
contractName := location.Name

// Get the target contract old code
contractName := contract.DeployLocation.Name
contractCode, ok := account.Contracts[contractName]
contractCode, ok := v.oldCodes[location]
if !ok {
return fmt.Errorf("old contract code not found for contract: %s", contractName)
}
Expand All @@ -333,7 +381,7 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd

// Check if contract code is valid according to Cadence V1 Update Checker
validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator(
contract.SourceLocation,
location,
contractName,
&accountContractNamesProviderImpl{
resolverFunc: v.resolveAddressContractNames,
Expand All @@ -349,9 +397,6 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd
return fmt.Errorf("unsupported network: %s", v.flow.Network().Name)
}

// TODO: extract type requirements from the old contracts
typeRequirements := &migrations.LegacyTypeRequirements{}

validator.WithUserDefinedTypeChangeChecker(
migrations.NewUserDefinedTypeChangeCheckerFunc(chainId, typeRequirements),
)
Expand Down Expand Up @@ -552,19 +597,11 @@ func (v *stagingValidatorImpl) resolveLocation(
for i := range resolvedLocations {
identifier := identifiers[i]

var resolvedLocation common.Location
resovledAddrLocation := common.AddressLocation{
resolvedLocation := common.AddressLocation{
Address: addressLocation.Address,
Name: identifier.Identifier,
}

// If the contract one of our staged contract updates, use the source location
if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok {
resolvedLocation = stagedUpdate.SourceLocation
} else {
resolvedLocation = resovledAddrLocation
}

Comment on lines -592 to -598
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no need to separate the staged contracts vs other contracts now, given staged contracts also uses DeployLocation, which is an address location.

resolvedLocations[i] = runtime.ResolvedLocation{
Location: resolvedLocation,
Identifiers: []runtime.Identifier{identifier},
Expand Down Expand Up @@ -725,7 +762,7 @@ func (v *stagingValidatorImpl) forEachDependency(
}
}
}
traverse(contract.SourceLocation)
traverse(contract.DeployLocation)
}

// Helper for pretty printing errors
Expand Down
56 changes: 55 additions & 1 deletion internal/migrate/staging_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ func Test_StagingValidator(t *testing.T) {
// check that error exists & ensure that the local contract names are used (not the deploy locations)
fooErr := validatorErr.errors[simpleAddressLocation("0x01.Foo")]
require.ErrorContains(t, fooErr, "mismatched types")
require.ErrorContains(t, fooErr, "Foo.cdc")
require.ErrorContains(t, fooErr, "0000000000000001.Foo")

// Bar should have an error related to
var upstreamErr *upstreamValidationError
Expand Down Expand Up @@ -822,4 +822,58 @@ func Test_StagingValidator(t *testing.T) {
err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}})
require.NoError(t, err)
})

t.Run("with type requirements", func(t *testing.T) {
// setup mocks
srv := setupValidatorMocks(t, []mockNetworkAccount{
{
address: flow.HexToAddress("01"),
contracts: map[string][]byte{"Foo": []byte(`
pub contract interface Foo {
pub let bar: @Bar?
pub resource Bar {}
}`)},
},
{
address: flow.HexToAddress("02"),
contracts: map[string][]byte{"Bar": []byte(`
import Foo from 0x01
pub contract FooImpl: Foo {
pub let bar: @Foo.Bar?
pub resource BarImpl {}
init() {
self.bar <- nil
}
}`)},
},
})

validator := newStagingValidator(srv)
err := validator.Validate([]stagedContractUpdate{
{
DeployLocation: simpleAddressLocation("0x01.Foo"),
SourceLocation: common.StringLocation("./Foo.cdc"),
Code: []byte(`
access(all) contract interface Foo {
access(all) let bar: @{Bar}?
access(all) resource interface Bar {}
}`),
},
{
DeployLocation: simpleAddressLocation("0x02.Bar"),
SourceLocation: common.StringLocation("./Bar.cdc"),
Code: []byte(`
import Foo from 0x01
access(all) contract FooImpl: Foo {
access(all) let bar: @{Foo.Bar}?
access(all) resource BarImpl: Foo.Bar {}
init() {
self.bar <- nil
}
}`),
},
})

require.NoError(t, err)
})
}
Loading