Skip to content

Commit

Permalink
Extract type requirements from old code
Browse files Browse the repository at this point in the history
  • Loading branch information
SupunS committed Aug 23, 2024
1 parent c813439 commit 68d9ad5
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 21 deletions.
91 changes: 70 additions & 21 deletions internal/migrate/staging_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"fmt"
"strings"

"github.com/rs/zerolog"

"golang.org/x/exp/slices"

"github.com/onflow/cadence"
Expand All @@ -40,6 +42,7 @@ import (
"github.com/onflow/contract-updater/lib/go/templates"
flowsdk "github.com/onflow/flow-go-sdk"
"github.com/onflow/flow-go/cmd/util/ledger/migrations"
"github.com/onflow/flow-go/cmd/util/ledger/reporters"
"github.com/onflow/flow-go/model/flow"
"github.com/onflow/flowkit/v2"

Expand Down Expand Up @@ -187,10 +190,12 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)

v.stagedContracts = make(map[common.AddressLocation]stagedContractUpdate)
for _, stagedContract := range stagedContracts {
v.stagedContracts[stagedContract.DeployLocation] = stagedContract
stagedContractLocation := stagedContract.DeployLocation

v.stagedContracts[stagedContractLocation] = stagedContract

// Add the contract code to the contracts map for pretty printing
v.contracts[stagedContract.SourceLocation] = stagedContract.Code
v.contracts[stagedContractLocation] = stagedContract.Code
}

// Load system contracts
Expand All @@ -199,24 +204,64 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
// Parse and check all staged contracts
errs := v.checkAllStaged()

typeRequirements := &migrations.LegacyTypeRequirements{}

// Extract type requirements from the old codes for all staged contracts.
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Don't validate contracts with existing errors
if errs[location] != nil {
continue
}

// Get the account for the contract
address := flowsdk.Address(location.Address)
account, err := v.flow.GetAccount(context.Background(), address)
if err != nil {
return fmt.Errorf("failed to get account: %w", err)
}

// Get the target contract old code
contractName := location.Name
oldCode, ok := account.Contracts[contractName]
if !ok {
return fmt.Errorf("old contract code not found for contract: %s", contractName)
}

migrations.ExtractTypeRequirements(
migrations.AddressContract{
Location: location,
Code: oldCode,
},
zerolog.Nop(),
reporters.ReportNilWriter{},
typeRequirements,
)
}

// Validate all contract updates
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Don't validate contracts with existing errors
if errs[contract.SourceLocation] != nil {
if errs[location] != nil {
continue
}

// Validate the contract update
checker := v.checkingCache[contract.SourceLocation].checker
err := v.validateContractUpdate(contract, checker)
checker := v.checkingCache[location].checker
err := v.validateContractUpdate(contract, checker, typeRequirements)
if err != nil {
errs[contract.SourceLocation] = err
errs[location] = err
}
}

// Check for any upstream contract update failures
for _, contract := range v.stagedContracts {
err := errs[contract.SourceLocation]
location := contract.DeployLocation

err := errs[location]

// We will override any errors other than those related
// to missing dependencies, since they are more specific
Expand Down Expand Up @@ -245,8 +290,8 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
})

if len(badDeps) > 0 {
errs[contract.SourceLocation] = &upstreamValidationError{
Location: contract.SourceLocation,
errs[location] = &upstreamValidationError{
Location: location,
BadDependencies: badDeps,
}
}
Expand All @@ -257,7 +302,7 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
// Map errors to address locations
errsByAddress := make(map[common.AddressLocation]error)
for _, contract := range v.stagedContracts {
err := errs[contract.SourceLocation]
err := errs[contract.DeployLocation]
if err != nil {
errsByAddress[contract.DeployLocation] = err
}
Expand All @@ -267,12 +312,13 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
return nil
}

func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error {
errors := make(map[common.StringLocation]error)
func (v *stagingValidatorImpl) checkAllStaged() map[common.Location]error {
errors := make(map[common.Location]error)
for _, contract := range v.stagedContracts {
_, err := v.checkContract(contract.SourceLocation)
location := contract.DeployLocation
_, err := v.checkContract(location)
if err != nil {
errors[contract.SourceLocation] = err
errors[location] = err
}
}

Expand All @@ -281,6 +327,8 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error
// Note: nodes are not visited more than once so cyclic imports are not an issue
// They will be reported, however, by the checker, if they do exist
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Create a set of all dependencies
missingDependencies := make([]common.AddressLocation, 0)
v.forEachDependency(contract, func(dependency common.Location) {
Expand All @@ -292,15 +340,19 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error
})

if len(missingDependencies) > 0 {
errors[contract.SourceLocation] = &missingDependenciesError{
errors[location] = &missingDependenciesError{
MissingContracts: missingDependencies,
}
}
}
return errors
}

func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpdate, checker *sema.Checker) (err error) {
func (v *stagingValidatorImpl) validateContractUpdate(
contract stagedContractUpdate,
checker *sema.Checker,
typeRequirements *migrations.LegacyTypeRequirements,
) (err error) {
// Gracefully recover from panics
defer func() {
if r := recover(); r != nil {
Expand Down Expand Up @@ -333,7 +385,7 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd

// Check if contract code is valid according to Cadence V1 Update Checker
validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator(
contract.SourceLocation,
contract.DeployLocation,
contractName,
&accountContractNamesProviderImpl{
resolverFunc: v.resolveAddressContractNames,
Expand All @@ -349,9 +401,6 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd
return fmt.Errorf("unsupported network: %s", v.flow.Network().Name)
}

// TODO: extract type requirements from the old contracts
typeRequirements := &migrations.LegacyTypeRequirements{}

validator.WithUserDefinedTypeChangeChecker(
migrations.NewUserDefinedTypeChangeCheckerFunc(chainId, typeRequirements),
)
Expand Down Expand Up @@ -560,7 +609,7 @@ func (v *stagingValidatorImpl) resolveLocation(

// If the contract one of our staged contract updates, use the source location
if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok {
resolvedLocation = stagedUpdate.SourceLocation
resolvedLocation = stagedUpdate.DeployLocation
} else {
resolvedLocation = resovledAddrLocation
}
Expand Down
54 changes: 54 additions & 0 deletions internal/migrate/staging_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,4 +822,58 @@ func Test_StagingValidator(t *testing.T) {
err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}})
require.NoError(t, err)
})

t.Run("with type requirements", func(t *testing.T) {
// setup mocks
srv := setupValidatorMocks(t, []mockNetworkAccount{
{
address: flow.HexToAddress("01"),
contracts: map[string][]byte{"Foo": []byte(`
pub contract interface Foo {
pub let bar: @Bar?
pub resource Bar {}
}`)},
},
{
address: flow.HexToAddress("02"),
contracts: map[string][]byte{"Bar": []byte(`
import Foo from 0x01
pub contract FooImpl: Foo {
pub let bar: @Foo.Bar?
pub resource BarImpl {}
init() {
self.bar <- nil
}
}`)},
},
})

validator := newStagingValidator(srv)
err := validator.Validate([]stagedContractUpdate{
{
DeployLocation: simpleAddressLocation("0x01.Foo"),
SourceLocation: common.StringLocation("./Foo.cdc"),
Code: []byte(`
access(all) contract interface Foo {
access(all) let bar: @{Bar}?
access(all) resource interface Bar {}
}`),
},
{
DeployLocation: simpleAddressLocation("0x02.Bar"),
SourceLocation: common.StringLocation("./Bar.cdc"),
Code: []byte(`
import Foo from 0x01
access(all) contract FooImpl: Foo {
access(all) let bar: @{Foo.Bar}?
access(all) resource BarImpl: Foo.Bar {}
init() {
self.bar <- nil
}
}`),
},
})

require.NoError(t, err)
})
}

0 comments on commit 68d9ad5

Please sign in to comment.