diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index 41ddfada3..8e90110ef 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -25,6 +25,9 @@ import ( "fmt" "strings" + "github.com/rs/zerolog" + "golang.org/x/exp/slices" + "github.com/onflow/cadence" "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/ast" @@ -35,12 +38,14 @@ import ( "github.com/onflow/cadence/runtime/pretty" "github.com/onflow/cadence/runtime/sema" "github.com/onflow/cadence/runtime/stdlib" - "github.com/onflow/contract-updater/lib/go/templates" - flowsdk "github.com/onflow/flow-go-sdk" + "github.com/onflow/flow-go/cmd/util/ledger/migrations" + "github.com/onflow/flow-go/cmd/util/ledger/reporters" "github.com/onflow/flow-go/model/flow" + + "github.com/onflow/contract-updater/lib/go/templates" + flowsdk "github.com/onflow/flow-go-sdk" "github.com/onflow/flowkit/v2" - "golang.org/x/exp/slices" "github.com/onflow/flow-cli/internal/util" ) @@ -58,9 +63,13 @@ type stagingValidatorImpl struct { // Cache for account contract names so we don't have to fetch them multiple times accountContractNames map[common.Address][]string + // All resolved contract code contracts map[common.Location][]byte + // Contract codes that are not updated/staged + oldCodes map[common.Location][]byte + // Dependency graph for staged contracts // This root level map holds all nodes graph map[common.Location]node @@ -172,6 +181,7 @@ func newStagingValidator(flow flowkit.Services) *stagingValidatorImpl { checkingCache: make(map[common.Location]*cachedCheckingResult), accountContractNames: make(map[common.Address][]string), graph: make(map[common.Location]node), + oldCodes: make(map[common.Location][]byte), } } @@ -186,10 +196,12 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) v.stagedContracts = make(map[common.AddressLocation]stagedContractUpdate) for _, stagedContract := range stagedContracts { - v.stagedContracts[stagedContract.DeployLocation] = stagedContract + stagedContractLocation := stagedContract.DeployLocation + + v.stagedContracts[stagedContractLocation] = stagedContract // Add the contract code to the contracts map for pretty printing - v.contracts[stagedContract.SourceLocation] = stagedContract.Code + v.contracts[stagedContractLocation] = stagedContract.Code } // Load system contracts @@ -198,24 +210,72 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) // Parse and check all staged contracts errs := v.checkAllStaged() + typeRequirements := &migrations.LegacyTypeRequirements{} + + // Extract type requirements from the old codes for all staged contracts. + for _, contract := range v.stagedContracts { + location := contract.DeployLocation + + // Don't validate contracts with existing errors + if errs[location] != nil { + continue + } + + // Get the account for the contract + address := flowsdk.Address(location.Address) + + var account *flowsdk.Account + var err error + + err = withRetry(func() error { + account, err = v.flow.GetAccount(context.Background(), address) + return err + }) + if err != nil { + return fmt.Errorf("failed to get account: %w", err) + } + + // Get the target contract old code + contractName := location.Name + oldCode, ok := account.Contracts[contractName] + if !ok { + return fmt.Errorf("old contract code not found for contract: %s", contractName) + } + v.oldCodes[location] = oldCode + + migrations.ExtractTypeRequirements( + migrations.AddressContract{ + Location: location, + Code: oldCode, + }, + zerolog.Nop(), + reporters.ReportNilWriter{}, + typeRequirements, + ) + } + // Validate all contract updates for _, contract := range v.stagedContracts { + location := contract.DeployLocation + // Don't validate contracts with existing errors - if errs[contract.SourceLocation] != nil { + if errs[location] != nil { continue } // Validate the contract update - checker := v.checkingCache[contract.SourceLocation].checker - err := v.validateContractUpdate(contract, checker) + checker := v.checkingCache[location].checker + err := v.validateContractUpdate(contract, checker, typeRequirements) if err != nil { - errs[contract.SourceLocation] = err + errs[location] = err } } // Check for any upstream contract update failures for _, contract := range v.stagedContracts { - err := errs[contract.SourceLocation] + location := contract.DeployLocation + + err := errs[location] // We will override any errors other than those related // to missing dependencies, since they are more specific @@ -233,19 +293,14 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) badDeps := make([]common.Location, 0) v.forEachDependency(contract, func(dependency common.Location) { - strLocation, ok := dependency.(common.StringLocation) - if !ok { - return - } - - if errs[strLocation] != nil { + if errs[dependency] != nil { badDeps = append(badDeps, dependency) } }) if len(badDeps) > 0 { - errs[contract.SourceLocation] = &upstreamValidationError{ - Location: contract.SourceLocation, + errs[location] = &upstreamValidationError{ + Location: location, BadDependencies: badDeps, } } @@ -256,7 +311,7 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) // Map errors to address locations errsByAddress := make(map[common.AddressLocation]error) for _, contract := range v.stagedContracts { - err := errs[contract.SourceLocation] + err := errs[contract.DeployLocation] if err != nil { errsByAddress[contract.DeployLocation] = err } @@ -266,12 +321,13 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) return nil } -func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error { - errs := make(map[common.StringLocation]error) +func (v *stagingValidatorImpl) checkAllStaged() map[common.Location]error { + errs := make(map[common.Location]error) for _, contract := range v.stagedContracts { - _, err := v.checkContract(contract.SourceLocation) + location := contract.DeployLocation + _, err := v.checkContract(location) if err != nil { - errs[contract.SourceLocation] = err + errs[location] = err } } @@ -280,6 +336,8 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error // Note: nodes are not visited more than once so cyclic imports are not an issue // They will be reported, however, by the checker, if they do exist for _, contract := range v.stagedContracts { + location := contract.DeployLocation + // Create a set of all dependencies missingDependencies := make([]common.AddressLocation, 0) v.forEachDependency(contract, func(dependency common.Location) { @@ -293,7 +351,7 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error if len(missingDependencies) > 0 { // If an error exists, only overwrite if it is a checking error - existingErr, ok := errs[contract.SourceLocation] + existingErr, ok := errs[location] if ok { var existingCheckingErr *sema.CheckerError if !errors.As(existingErr, &existingCheckingErr) { @@ -301,7 +359,7 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error } } - errs[contract.SourceLocation] = &missingDependenciesError{ + errs[location] = &missingDependenciesError{ MissingContracts: missingDependencies, } } @@ -310,7 +368,11 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error return errs } -func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpdate, checker *sema.Checker) (err error) { +func (v *stagingValidatorImpl) validateContractUpdate( + contract stagedContractUpdate, + checker *sema.Checker, + typeRequirements *migrations.LegacyTypeRequirements, +) (err error) { // Gracefully recover from panics defer func() { if r := recover(); r != nil { @@ -318,21 +380,11 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd } }() - // Get the account for the contract - address := flowsdk.Address(contract.DeployLocation.Address) - - var account *flowsdk.Account - err = withRetry(func() error { - account, err = v.flow.GetAccount(context.Background(), address) - return err - }) - if err != nil { - return fmt.Errorf("failed to get account: %w", err) - } + location := contract.DeployLocation + contractName := location.Name // Get the target contract old code - contractName := contract.DeployLocation.Name - contractCode, ok := account.Contracts[contractName] + contractCode, ok := v.oldCodes[location] if !ok { return fmt.Errorf("old contract code not found for contract: %s", contractName) } @@ -348,7 +400,7 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd // Check if contract code is valid according to Cadence V1 Update Checker validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( - contract.SourceLocation, + location, contractName, &accountContractNamesProviderImpl{ resolverFunc: func(address common.Address) ([]string, error) { @@ -371,9 +423,6 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd return fmt.Errorf("unsupported network: %s", v.flow.Network().Name) } - // TODO: extract type requirements from the old contracts - typeRequirements := &migrations.LegacyTypeRequirements{} - validator.WithUserDefinedTypeChangeChecker( migrations.NewUserDefinedTypeChangeCheckerFunc(chainId, typeRequirements), ) @@ -583,19 +632,11 @@ func (v *stagingValidatorImpl) resolveLocation( for i := range resolvedLocations { identifier := identifiers[i] - var resolvedLocation common.Location - resovledAddrLocation := common.AddressLocation{ + resolvedLocation := common.AddressLocation{ Address: addressLocation.Address, Name: identifier.Identifier, } - // If the contract one of our staged contract updates, use the source location - if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok { - resolvedLocation = stagedUpdate.SourceLocation - } else { - resolvedLocation = resovledAddrLocation - } - resolvedLocations[i] = runtime.ResolvedLocation{ Location: resolvedLocation, Identifiers: []runtime.Identifier{identifier}, @@ -765,7 +806,7 @@ func (v *stagingValidatorImpl) forEachDependency( } } } - traverse(contract.SourceLocation) + traverse(contract.DeployLocation) } // Helper for pretty printing errors diff --git a/internal/migrate/staging_validator_test.go b/internal/migrate/staging_validator_test.go index f91c57cc0..a1e15dc09 100644 --- a/internal/migrate/staging_validator_test.go +++ b/internal/migrate/staging_validator_test.go @@ -496,7 +496,7 @@ func Test_StagingValidator(t *testing.T) { // check that error exists & ensure that the local contract names are used (not the deploy locations) fooErr := validatorErr.errors[simpleAddressLocation("0x01.Foo")] require.ErrorContains(t, fooErr, "mismatched types") - require.ErrorContains(t, fooErr, "Foo.cdc") + require.ErrorContains(t, fooErr, "0000000000000001.Foo") // Bar should have an error related to var upstreamErr *upstreamValidationError @@ -822,4 +822,115 @@ func Test_StagingValidator(t *testing.T) { err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) require.NoError(t, err) }) + + t.Run("with type requirements", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract interface Foo { + pub let bar: @Bar? + pub resource Bar {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + import Foo from 0x01 + pub contract FooImpl: Foo { + pub let bar: @Foo.Bar? + pub resource BarImpl {} + init() { + self.bar <- nil + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract interface Foo { + access(all) let bar: @{Bar}? + access(all) resource interface Bar {} + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract FooImpl: Foo { + access(all) let bar: @{Foo.Bar}? + access(all) resource BarImpl: Foo.Bar {} + init() { + self.bar <- nil + } + }`), + }, + }) + + require.NoError(t, err) + }) + + t.Run("contract update with entitlements", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract Foo { + pub resource Bar {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Test": []byte(` + import Foo from 0x01 + pub contract Test { + pub resource R { + pub var bar: auth &Foo.Bar? + init() { + self.bar = nil + } + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract Foo { + access(all) resource Bar { + access(E) fun foo(){} + } + access(all) entitlement E + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Test"), + SourceLocation: common.StringLocation("./Test.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract Test { + access(all) resource R { + access(all) var bar: auth(Foo.E) &Foo.Bar? + init() { + self.bar = nil + } + } + }`), + }, + }) + + require.NoError(t, err) + }) }