From 68d9ad5f3756bd316ac682bd2e8677cf9a817123 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Fri, 23 Aug 2024 15:38:35 -0700 Subject: [PATCH] Extract type requirements from old code --- internal/migrate/staging_validator.go | 91 +++++++++++++++++----- internal/migrate/staging_validator_test.go | 54 +++++++++++++ 2 files changed, 124 insertions(+), 21 deletions(-) diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index 5f9c95458..ddfb9b895 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -25,6 +25,8 @@ import ( "fmt" "strings" + "github.com/rs/zerolog" + "golang.org/x/exp/slices" "github.com/onflow/cadence" @@ -40,6 +42,7 @@ import ( "github.com/onflow/contract-updater/lib/go/templates" flowsdk "github.com/onflow/flow-go-sdk" "github.com/onflow/flow-go/cmd/util/ledger/migrations" + "github.com/onflow/flow-go/cmd/util/ledger/reporters" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flowkit/v2" @@ -187,10 +190,12 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) v.stagedContracts = make(map[common.AddressLocation]stagedContractUpdate) for _, stagedContract := range stagedContracts { - v.stagedContracts[stagedContract.DeployLocation] = stagedContract + stagedContractLocation := stagedContract.DeployLocation + + v.stagedContracts[stagedContractLocation] = stagedContract // Add the contract code to the contracts map for pretty printing - v.contracts[stagedContract.SourceLocation] = stagedContract.Code + v.contracts[stagedContractLocation] = stagedContract.Code } // Load system contracts @@ -199,24 +204,64 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) // Parse and check all staged contracts errs := v.checkAllStaged() + typeRequirements := &migrations.LegacyTypeRequirements{} + + // Extract type requirements from the old codes for all staged contracts. + for _, contract := range v.stagedContracts { + location := contract.DeployLocation + + // Don't validate contracts with existing errors + if errs[location] != nil { + continue + } + + // Get the account for the contract + address := flowsdk.Address(location.Address) + account, err := v.flow.GetAccount(context.Background(), address) + if err != nil { + return fmt.Errorf("failed to get account: %w", err) + } + + // Get the target contract old code + contractName := location.Name + oldCode, ok := account.Contracts[contractName] + if !ok { + return fmt.Errorf("old contract code not found for contract: %s", contractName) + } + + migrations.ExtractTypeRequirements( + migrations.AddressContract{ + Location: location, + Code: oldCode, + }, + zerolog.Nop(), + reporters.ReportNilWriter{}, + typeRequirements, + ) + } + // Validate all contract updates for _, contract := range v.stagedContracts { + location := contract.DeployLocation + // Don't validate contracts with existing errors - if errs[contract.SourceLocation] != nil { + if errs[location] != nil { continue } // Validate the contract update - checker := v.checkingCache[contract.SourceLocation].checker - err := v.validateContractUpdate(contract, checker) + checker := v.checkingCache[location].checker + err := v.validateContractUpdate(contract, checker, typeRequirements) if err != nil { - errs[contract.SourceLocation] = err + errs[location] = err } } // Check for any upstream contract update failures for _, contract := range v.stagedContracts { - err := errs[contract.SourceLocation] + location := contract.DeployLocation + + err := errs[location] // We will override any errors other than those related // to missing dependencies, since they are more specific @@ -245,8 +290,8 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) }) if len(badDeps) > 0 { - errs[contract.SourceLocation] = &upstreamValidationError{ - Location: contract.SourceLocation, + errs[location] = &upstreamValidationError{ + Location: location, BadDependencies: badDeps, } } @@ -257,7 +302,7 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) // Map errors to address locations errsByAddress := make(map[common.AddressLocation]error) for _, contract := range v.stagedContracts { - err := errs[contract.SourceLocation] + err := errs[contract.DeployLocation] if err != nil { errsByAddress[contract.DeployLocation] = err } @@ -267,12 +312,13 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) return nil } -func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error { - errors := make(map[common.StringLocation]error) +func (v *stagingValidatorImpl) checkAllStaged() map[common.Location]error { + errors := make(map[common.Location]error) for _, contract := range v.stagedContracts { - _, err := v.checkContract(contract.SourceLocation) + location := contract.DeployLocation + _, err := v.checkContract(location) if err != nil { - errors[contract.SourceLocation] = err + errors[location] = err } } @@ -281,6 +327,8 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error // Note: nodes are not visited more than once so cyclic imports are not an issue // They will be reported, however, by the checker, if they do exist for _, contract := range v.stagedContracts { + location := contract.DeployLocation + // Create a set of all dependencies missingDependencies := make([]common.AddressLocation, 0) v.forEachDependency(contract, func(dependency common.Location) { @@ -292,7 +340,7 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error }) if len(missingDependencies) > 0 { - errors[contract.SourceLocation] = &missingDependenciesError{ + errors[location] = &missingDependenciesError{ MissingContracts: missingDependencies, } } @@ -300,7 +348,11 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error return errors } -func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpdate, checker *sema.Checker) (err error) { +func (v *stagingValidatorImpl) validateContractUpdate( + contract stagedContractUpdate, + checker *sema.Checker, + typeRequirements *migrations.LegacyTypeRequirements, +) (err error) { // Gracefully recover from panics defer func() { if r := recover(); r != nil { @@ -333,7 +385,7 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd // Check if contract code is valid according to Cadence V1 Update Checker validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( - contract.SourceLocation, + contract.DeployLocation, contractName, &accountContractNamesProviderImpl{ resolverFunc: v.resolveAddressContractNames, @@ -349,9 +401,6 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd return fmt.Errorf("unsupported network: %s", v.flow.Network().Name) } - // TODO: extract type requirements from the old contracts - typeRequirements := &migrations.LegacyTypeRequirements{} - validator.WithUserDefinedTypeChangeChecker( migrations.NewUserDefinedTypeChangeCheckerFunc(chainId, typeRequirements), ) @@ -560,7 +609,7 @@ func (v *stagingValidatorImpl) resolveLocation( // If the contract one of our staged contract updates, use the source location if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok { - resolvedLocation = stagedUpdate.SourceLocation + resolvedLocation = stagedUpdate.DeployLocation } else { resolvedLocation = resovledAddrLocation } diff --git a/internal/migrate/staging_validator_test.go b/internal/migrate/staging_validator_test.go index b6b01b9ce..af8a7aeaf 100644 --- a/internal/migrate/staging_validator_test.go +++ b/internal/migrate/staging_validator_test.go @@ -822,4 +822,58 @@ func Test_StagingValidator(t *testing.T) { err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) require.NoError(t, err) }) + + t.Run("with type requirements", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract interface Foo { + pub let bar: @Bar? + pub resource Bar {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + import Foo from 0x01 + pub contract FooImpl: Foo { + pub let bar: @Foo.Bar? + pub resource BarImpl {} + init() { + self.bar <- nil + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract interface Foo { + access(all) let bar: @{Bar}? + access(all) resource interface Bar {} + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract FooImpl: Foo { + access(all) let bar: @{Foo.Bar}? + access(all) resource BarImpl: Foo.Bar {} + init() { + self.bar <- nil + } + }`), + }, + }) + + require.NoError(t, err) + }) }