From fd1d004f155ad72ce9dd1c541a6cb85ccd6b16d8 Mon Sep 17 00:00:00 2001 From: Arn0d Date: Fri, 1 Mar 2024 00:28:53 +0100 Subject: [PATCH 1/6] rework account as components and add preset account --- src/account.cairo | 2 + src/account/account.cairo | 97 +++++++++++++++++++---------------- src/interfaces/IAccount.cairo | 3 ++ src/lib.cairo | 1 + src/presets.cairo | 3 ++ src/presets/account.cairo | 34 ++++++++++++ src/registry/registry.cairo | 8 +-- tests/test_account.cairo | 2 +- tests/test_registry.cairo | 2 +- 9 files changed, 101 insertions(+), 51 deletions(-) create mode 100644 src/presets.cairo create mode 100644 src/presets/account.cairo diff --git a/src/account.cairo b/src/account.cairo index cfc38a2..41277d0 100644 --- a/src/account.cairo +++ b/src/account.cairo @@ -1 +1,3 @@ mod account; + +use account::AccountComponent; \ No newline at end of file diff --git a/src/account/account.cairo b/src/account/account.cairo index 8e2a0f5..5e96303 100644 --- a/src/account/account.cairo +++ b/src/account/account.cairo @@ -1,8 +1,8 @@ //////////////////////////////// // Account contract //////////////////////////////// -#[starknet::contract] -mod Account { +#[starknet::component] +mod AccountComponent { use starknet::{ get_tx_info, get_caller_address, get_contract_address, get_block_timestamp, ContractAddress, account::Call, call_contract_syscall, replace_class_syscall, ClassHash, SyscallResultTrait @@ -14,16 +14,13 @@ mod Account { use zeroable::Zeroable; use token_bound_accounts::interfaces::IERC721::{IERC721DispatcherTrait, IERC721Dispatcher}; use token_bound_accounts::interfaces::IAccount::IAccount; - - // SRC5 interface for token bound accounts - const TBA_INTERFACE_ID: felt252 = - 0x539036932a2ab9c4734fbfd9872a1f7791a3f577e45477336ae0fd0a00c9ff; + use token_bound_accounts::interfaces::IAccount::{TBA_INTERFACE_ID}; #[storage] struct Storage { - _token_contract: ContractAddress, // contract address of NFT - _token_id: u256, // token ID of NFT - _unlock_timestamp: u64, // time to unlock account when locked + Account_token_contract: ContractAddress, // contract address of NFT + Account_token_id: u256, // token ID of NFT + Account_unlock_timestamp: u64, // time to unlock account when locked } #[event] @@ -74,45 +71,40 @@ mod Account { duration: u64, } - #[constructor] - fn constructor(ref self: ContractState, token_contract: ContractAddress, token_id: u256) { - self._token_contract.write(token_contract); - self._token_id.write(token_id); - - let owner = self._get_owner(token_contract, token_id); - self.emit(AccountCreated { owner }); - } - - #[external(v0)] - impl IAccountImpl of IAccount { + #[embeddable_as(AccountImpl)] + impl Account< + TContractState, + +HasComponent, + +Drop + > of IAccount> { /// @notice used for signature validation /// @param hash The message hash /// @param signature The signature to be validated fn is_valid_signature( - self: @ContractState, hash: felt252, signature: Span + self: @ComponentState, hash: felt252, signature: Span ) -> felt252 { self._is_valid_signature(hash, signature) } fn __validate_deploy__( - self: @ContractState, class_hash: felt252, contract_address_salt: felt252, + self: @ComponentState, class_hash: felt252, contract_address_salt: felt252, ) -> felt252 { self._validate_transaction() } - fn __validate_declare__(self: @ContractState, class_hash: felt252) -> felt252 { + fn __validate_declare__(self: @ComponentState, class_hash: felt252) -> felt252 { self._validate_transaction() } /// @notice validate an account transaction /// @param calls an array of transactions to be executed - fn __validate__(ref self: ContractState, mut calls: Array) -> felt252 { + fn __validate__(ref self: ComponentState, mut calls: Array) -> felt252 { self._validate_transaction() } /// @notice executes a transaction /// @param calls an array of transactions to be executed - fn __execute__(ref self: ContractState, mut calls: Array) -> Array> { + fn __execute__(ref self: ComponentState, mut calls: Array) -> Array> { self._assert_only_owner(); let (lock_status, _) = self._is_locked(); assert(!lock_status, 'Account: account is locked!'); @@ -131,19 +123,19 @@ mod Account { /// @param token_contract the contract address of the NFT /// @param token_id the token ID of the NFT fn owner( - self: @ContractState, token_contract: ContractAddress, token_id: u256 + self: @ComponentState, token_contract: ContractAddress, token_id: u256 ) -> ContractAddress { self._get_owner(token_contract, token_id) } /// @notice returns the contract address and token ID of the NFT - fn token(self: @ContractState) -> (ContractAddress, u256) { + fn token(self: @ComponentState) -> (ContractAddress, u256) { self._get_token() } /// @notice ugprades an account implementation /// @param implementation the new class_hash - fn upgrade(ref self: ContractState, implementation: ClassHash) { + fn upgrade(ref self: ComponentState, implementation: ClassHash) { self._assert_only_owner(); let (lock_status, _) = self._is_locked(); assert(!lock_status, 'Account: account is locked!'); @@ -154,13 +146,13 @@ mod Account { // @notice protection mechanism for selling token bound accounts. can't execute when account is locked // @param duration for which to lock account - fn lock(ref self: ContractState, duration: u64) { + fn lock(ref self: ComponentState, duration: u64) { self._assert_only_owner(); let (lock_status, _) = self._is_locked(); assert(!lock_status, 'Account: account already locked'); let current_timestamp = get_block_timestamp(); let unlock_time = current_timestamp + duration; - self._unlock_timestamp.write(unlock_time); + self.Account_unlock_timestamp.write(unlock_time); self .emit( AccountLocked { @@ -170,13 +162,13 @@ mod Account { } // @notice returns account lock status and time left until account unlocks - fn is_locked(self: @ContractState) -> (bool, u64) { + fn is_locked(self: @ComponentState) -> (bool, u64) { return self._is_locked(); } // @notice check that account supports TBA interface // @param interface_id interface to be checked against - fn supports_interface(self: @ContractState, interface_id: felt252) -> bool { + fn supports_interface(self: @ComponentState, interface_id: felt252) -> bool { if (interface_id == TBA_INTERFACE_ID) { return true; } else { @@ -186,11 +178,26 @@ mod Account { } #[generate_trait] - impl internalImpl of InternalTrait { + impl InternalImpl< + TContractState, + +HasComponent, + +Drop + > of InternalTrait { + + /// @notice initializes the account by setting the initial token conrtact and token id + fn initializer(ref self: ComponentState, token_contract: ContractAddress, token_id: u256) { + + self.Account_token_contract.write(token_contract); + self.Account_token_id.write(token_id); + + let owner = self._get_owner(token_contract, token_id); + self.emit(AccountCreated { owner }); + } + /// @notice check that caller is the token bound account - fn _assert_only_owner(ref self: ContractState) { + fn _assert_only_owner(ref self: ComponentState) { let caller = get_caller_address(); - let owner = self._get_owner(self._token_contract.read(), self._token_id.read()); + let owner = self._get_owner(self.Account_token_contract.read(), self.Account_token_id.read()); assert(caller == owner, 'Account: unathorized'); } @@ -199,7 +206,7 @@ mod Account { // @param token_id token ID of NFT // NB: This function aims for compatibility with all contracts (snake or camel case) but do not work as expected on mainnet as low level calls do not return err at the moment. Should work for contracts which implements CamelCase but not snake_case until starknet v0.15. fn _get_owner( - self: @ContractState, token_contract: ContractAddress, token_id: u256 + self: @ComponentState, token_contract: ContractAddress, token_id: u256 ) -> ContractAddress { let mut calldata: Array = ArrayTrait::new(); Serde::serialize(@token_id, ref calldata); @@ -214,15 +221,15 @@ mod Account { } /// @notice internal transaction for returning the contract address and token ID of the NFT - fn _get_token(self: @ContractState) -> (ContractAddress, u256) { - let contract = self._token_contract.read(); - let tokenId = self._token_id.read(); + fn _get_token(self: @ComponentState) -> (ContractAddress, u256) { + let contract = self.Account_token_contract.read(); + let tokenId = self.Account_token_id.read(); (contract, tokenId) } // @notice protection mechanism for TBA trading. Returns the lock-status (true or false), and the remaning time till account unlocks. - fn _is_locked(self: @ContractState) -> (bool, u64) { - let unlock_timestamp = self._unlock_timestamp.read(); + fn _is_locked(self: @ComponentState) -> (bool, u64) { + let unlock_timestamp = self.Account_unlock_timestamp.read(); let current_time = get_block_timestamp(); if (current_time < unlock_timestamp) { let time_until_unlocks = unlock_timestamp - current_time; @@ -233,7 +240,7 @@ mod Account { } /// @notice internal function for tx validation - fn _validate_transaction(self: @ContractState) -> felt252 { + fn _validate_transaction(self: @ComponentState) -> felt252 { let tx_info = get_tx_info().unbox(); let tx_hash = tx_info.transaction_hash; let signature = tx_info.signature; @@ -246,13 +253,13 @@ mod Account { /// @notice internal function for signature validation fn _is_valid_signature( - self: @ContractState, hash: felt252, signature: Span + self: @ComponentState, hash: felt252, signature: Span ) -> felt252 { let signature_length = signature.len(); assert(signature_length == 2_u32, 'Account: invalid sig length'); let caller = get_caller_address(); - let owner = self._get_owner(self._token_contract.read(), self._token_id.read()); + let owner = self._get_owner(self.Account_token_contract.read(), self.Account_token_id.read()); if (caller == owner) { return starknet::VALIDATED; } else { @@ -262,7 +269,7 @@ mod Account { /// @notice internal function for executing transactions /// @param calls An array of transactions to be executed - fn _execute_calls(ref self: ContractState, mut calls: Span) -> Array> { + fn _execute_calls(ref self: ComponentState, mut calls: Span) -> Array> { let mut result: Array> = ArrayTrait::new(); let mut calls = calls; diff --git a/src/interfaces/IAccount.cairo b/src/interfaces/IAccount.cairo index 58a21cc..7f79809 100644 --- a/src/interfaces/IAccount.cairo +++ b/src/interfaces/IAccount.cairo @@ -2,6 +2,9 @@ use starknet::ContractAddress; use starknet::ClassHash; use starknet::account::Call; +// SRC5 interface for token bound accounts +const TBA_INTERFACE_ID: felt252 = 0x539036932a2ab9c4734fbfd9872a1f7791a3f577e45477336ae0fd0a00c9ff; + #[starknet::interface] trait IAccount { fn is_valid_signature( diff --git a/src/lib.cairo b/src/lib.cairo index e35e00b..4ad41ba 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -1,4 +1,5 @@ mod registry; mod account; mod interfaces; +mod presets; mod test_helper; diff --git a/src/presets.cairo b/src/presets.cairo new file mode 100644 index 0000000..74885b0 --- /dev/null +++ b/src/presets.cairo @@ -0,0 +1,3 @@ +mod account; + +use account::Account; \ No newline at end of file diff --git a/src/presets/account.cairo b/src/presets/account.cairo new file mode 100644 index 0000000..0430b15 --- /dev/null +++ b/src/presets/account.cairo @@ -0,0 +1,34 @@ +//////////////////////////////// +// Account contract +//////////////////////////////// +#[starknet::contract] +mod Account { + use token_bound_accounts::account::AccountComponent; + use starknet::ContractAddress; + + component!(path: AccountComponent, storage: account, event: AccountEvent); + + // Account + #[abi(embed_v0)] + impl AccountImpl = AccountComponent::AccountImpl; + impl AccountInternalImpl = AccountComponent::InternalImpl; + + #[storage] + struct Storage { + #[substorage(v0)] + account: AccountComponent::Storage + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + AccountEvent: AccountComponent::Event + } + + #[constructor] + fn constructor(ref self: ContractState, token_contract: ContractAddress, token_id: u256) { + self.account.initializer(token_contract, token_id); + } + +} diff --git a/src/registry/registry.cairo b/src/registry/registry.cairo index 8b9b442..f2beef7 100644 --- a/src/registry/registry.cairo +++ b/src/registry/registry.cairo @@ -20,7 +20,7 @@ mod Registry { #[storage] struct Storage { - registry_deployed_accounts: LegacyMap< + Registry_deployed_accounts: LegacyMap< (ContractAddress, u256), u8 >, // tracks no. of deployed accounts by registry for an NFT } @@ -68,10 +68,10 @@ mod Registry { let (account_address, _) = result.unwrap_syscall(); let new_deployment_index: u8 = self - .registry_deployed_accounts + .Registry_deployed_accounts .read((token_contract, token_id)) + 1_u8; - self.registry_deployed_accounts.write((token_contract, token_id), new_deployment_index); + self.Registry_deployed_accounts.write((token_contract, token_id), new_deployment_index); self.emit(AccountCreated { account_address, token_contract, token_id, }); @@ -116,7 +116,7 @@ mod Registry { fn total_deployed_accounts( self: @ContractState, token_contract: ContractAddress, token_id: u256 ) -> u8 { - self.registry_deployed_accounts.read((token_contract, token_id)) + self.Registry_deployed_accounts.read((token_contract, token_id)) } } diff --git a/tests/test_account.cairo b/tests/test_account.cairo index f51a274..b3e36da 100644 --- a/tests/test_account.cairo +++ b/tests/test_account.cairo @@ -13,7 +13,7 @@ use token_bound_accounts::interfaces::IAccount::IAccountDispatcher; use token_bound_accounts::interfaces::IAccount::IAccountDispatcherTrait; use token_bound_accounts::interfaces::IAccount::IAccountSafeDispatcher; use token_bound_accounts::interfaces::IAccount::IAccountSafeDispatcherTrait; -use token_bound_accounts::account::account::Account; +use token_bound_accounts::presets::account::Account; use token_bound_accounts::test_helper::hello_starknet::IHelloStarknetDispatcher; use token_bound_accounts::test_helper::hello_starknet::IHelloStarknetDispatcherTrait; diff --git a/tests/test_registry.cairo b/tests/test_registry.cairo index 38ecf25..7e1f50a 100644 --- a/tests/test_registry.cairo +++ b/tests/test_registry.cairo @@ -14,7 +14,7 @@ use token_bound_accounts::registry::registry::Registry; use token_bound_accounts::interfaces::IAccount::IAccountDispatcher; use token_bound_accounts::interfaces::IAccount::IAccountDispatcherTrait; -use token_bound_accounts::account::account::Account; +use token_bound_accounts::presets::account::Account; use token_bound_accounts::test_helper::erc721_helper::IERC721Dispatcher; use token_bound_accounts::test_helper::erc721_helper::IERC721DispatcherTrait; From 8f760a1debe6fc0a3163e6ef3a48153b82d951c4 Mon Sep 17 00:00:00 2001 From: Arn0d Date: Fri, 1 Mar 2024 00:47:49 +0100 Subject: [PATCH 2/6] rework registry as components and add preset registry --- src/presets.cairo | 4 +++- src/presets/registry.cairo | 28 ++++++++++++++++++++++++++++ src/registry.cairo | 2 ++ src/registry/registry.cairo | 26 +++++++++++++++++--------- tests/test_registry.cairo | 2 +- 5 files changed, 51 insertions(+), 11 deletions(-) create mode 100644 src/presets/registry.cairo diff --git a/src/presets.cairo b/src/presets.cairo index 74885b0..e373209 100644 --- a/src/presets.cairo +++ b/src/presets.cairo @@ -1,3 +1,5 @@ mod account; +mod registry; -use account::Account; \ No newline at end of file +use account::Account; +use registry::Registry; \ No newline at end of file diff --git a/src/presets/registry.cairo b/src/presets/registry.cairo new file mode 100644 index 0000000..adc3fce --- /dev/null +++ b/src/presets/registry.cairo @@ -0,0 +1,28 @@ +//////////////////////////////// +// Registry contract +//////////////////////////////// +#[starknet::contract] +mod Registry { + use token_bound_accounts::registry::RegistryComponent; + + component!(path: RegistryComponent, storage: registry, event: RegistryEvent); + + // Account + #[abi(embed_v0)] + impl RegistryImpl = RegistryComponent::RegistryImpl; + impl AccountInternalImpl = RegistryComponent::InternalImpl; + + #[storage] + struct Storage { + #[substorage(v0)] + registry: RegistryComponent::Storage + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + RegistryEvent: RegistryComponent::Event + } + +} diff --git a/src/registry.cairo b/src/registry.cairo index 516f5b2..f199435 100644 --- a/src/registry.cairo +++ b/src/registry.cairo @@ -1 +1,3 @@ mod registry; + +use registry::RegistryComponent; \ No newline at end of file diff --git a/src/registry/registry.cairo b/src/registry/registry.cairo index f2beef7..ddf5ee6 100644 --- a/src/registry/registry.cairo +++ b/src/registry/registry.cairo @@ -1,8 +1,8 @@ //////////////////////////////// // Registry contract //////////////////////////////// -#[starknet::contract] -mod Registry { +#[starknet::component] +mod RegistryComponent { use core::result::ResultTrait; use core::hash::HashStateTrait; use starknet::{ @@ -42,15 +42,19 @@ mod Registry { token_id: u256, } - #[external(v0)] - impl IRegistryImpl of IRegistry { + #[embeddable_as(RegistryImpl)] + impl Registry< + TContractState, + +HasComponent, + +Drop + > of IRegistry> { /// @notice deploys a new tokenbound account for an NFT /// @param implementation_hash the class hash of the reference account /// @param token_contract the contract address of the NFT /// @param token_id the ID of the NFT /// @param salt random salt for deployment fn create_account( - ref self: ContractState, + ref self: ComponentState, implementation_hash: felt252, token_contract: ContractAddress, token_id: u256, @@ -84,7 +88,7 @@ mod Registry { /// @param token_id the ID of the NFT /// @param salt random salt for deployment fn get_account( - self: @ContractState, + self: @ComponentState, implementation_hash: felt252, token_contract: ContractAddress, token_id: u256, @@ -114,20 +118,24 @@ mod Registry { /// @param token_contract the contract address of the NFT /// @param token_id the ID of the NFT fn total_deployed_accounts( - self: @ContractState, token_contract: ContractAddress, token_id: u256 + self: @ComponentState, token_contract: ContractAddress, token_id: u256 ) -> u8 { self.Registry_deployed_accounts.read((token_contract, token_id)) } } #[generate_trait] - impl internalImpl of InternalTrait { + impl InternalImpl< + TContractState, + +HasComponent, + +Drop + > of InternalTrait { /// @notice internal function for getting NFT owner /// @param token_contract contract address of NFT // @param token_id token ID of NFT // NB: This function aims for compatibility with all contracts (snake or camel case) but do not work as expected on mainnet as low level calls do not return err at the moment. Should work for contracts which implements CamelCase but not snake_case until starknet v0.15. fn _get_owner( - self: @ContractState, token_contract: ContractAddress, token_id: u256 + self: @ComponentState, token_contract: ContractAddress, token_id: u256 ) -> ContractAddress { let mut calldata: Array = ArrayTrait::new(); Serde::serialize(@token_id, ref calldata); diff --git a/tests/test_registry.cairo b/tests/test_registry.cairo index 7e1f50a..899e5a5 100644 --- a/tests/test_registry.cairo +++ b/tests/test_registry.cairo @@ -10,7 +10,7 @@ use snforge_std::{ use token_bound_accounts::interfaces::IRegistry::IRegistryDispatcherTrait; use token_bound_accounts::interfaces::IRegistry::IRegistryDispatcher; -use token_bound_accounts::registry::registry::Registry; +use token_bound_accounts::presets::registry::Registry; use token_bound_accounts::interfaces::IAccount::IAccountDispatcher; use token_bound_accounts::interfaces::IAccount::IAccountDispatcherTrait; From c6c43afe33e272c63ee77f552d5b46568df07d23 Mon Sep 17 00:00:00 2001 From: Arn0d Date: Fri, 1 Mar 2024 02:38:18 +0100 Subject: [PATCH 3/6] Add upgradeable components and tests --- src/account/account.cairo | 43 +++++++++++-------------------- src/interfaces.cairo | 1 + src/interfaces/IAccount.cairo | 1 - src/interfaces/IUpgradeable.cairo | 6 +++++ src/lib.cairo | 1 + src/presets/account.cairo | 26 ++++++++++++++++--- src/registry/registry.cairo | 2 +- src/upgradeable.cairo | 3 +++ src/upgradeable/upgradeable.cairo | 43 +++++++++++++++++++++++++++++++ tests/test_account.cairo | 38 ++++++++++++++------------- 10 files changed, 113 insertions(+), 51 deletions(-) create mode 100644 src/interfaces/IUpgradeable.cairo create mode 100644 src/upgradeable.cairo create mode 100644 src/upgradeable/upgradeable.cairo diff --git a/src/account/account.cairo b/src/account/account.cairo index 5e96303..bf6f4f1 100644 --- a/src/account/account.cairo +++ b/src/account/account.cairo @@ -1,5 +1,5 @@ //////////////////////////////// -// Account contract +// Account Component //////////////////////////////// #[starknet::component] mod AccountComponent { @@ -27,7 +27,6 @@ mod AccountComponent { #[derive(Drop, starknet::Event)] enum Event { AccountCreated: AccountCreated, - AccountUpgraded: AccountUpgraded, AccountLocked: AccountLocked, TransactionExecuted: TransactionExecuted } @@ -50,15 +49,6 @@ mod AccountComponent { response: Span> } - /// @notice Emitted when the account upgrades to a new implementation - /// @param account tokenbound account to be upgraded - /// @param implementation the upgraded account class hash - #[derive(Drop, starknet::Event)] - struct AccountUpgraded { - account: ContractAddress, - implementation: ClassHash - } - /// @notice Emitted when the account is locked /// @param account tokenbound account who's lock function was triggered /// @param locked_at timestamp at which the lock function was triggered @@ -71,6 +61,14 @@ mod AccountComponent { duration: u64, } + mod Errors { + const LOCKED_ACCOUNT: felt252 = 'Account: account is locked!'; + const INV_TX_VERSION: felt252 = 'Account: invalid tx version'; + const UNAUTHORIZED: felt252 = 'Account: unauthorized'; + const INV_SIG_LEN: felt252 = 'Account: invalid sig length'; + const INV_SIGNATURE: felt252 = 'Account: invalid signature'; + } + #[embeddable_as(AccountImpl)] impl Account< TContractState, @@ -107,10 +105,10 @@ mod AccountComponent { fn __execute__(ref self: ComponentState, mut calls: Array) -> Array> { self._assert_only_owner(); let (lock_status, _) = self._is_locked(); - assert(!lock_status, 'Account: account is locked!'); + assert(!lock_status, Errors::LOCKED_ACCOUNT); let tx_info = get_tx_info().unbox(); - assert(tx_info.version != 0, 'invalid tx version'); + assert(tx_info.version != 0, Errors::INV_TX_VERSION); let retdata = self._execute_calls(calls.span()); let hash = tx_info.transaction_hash; @@ -133,23 +131,12 @@ mod AccountComponent { self._get_token() } - /// @notice ugprades an account implementation - /// @param implementation the new class_hash - fn upgrade(ref self: ComponentState, implementation: ClassHash) { - self._assert_only_owner(); - let (lock_status, _) = self._is_locked(); - assert(!lock_status, 'Account: account is locked!'); - assert(!implementation.is_zero(), 'Invalid class hash'); - replace_class_syscall(implementation).unwrap_syscall(); - self.emit(AccountUpgraded { account: get_contract_address(), implementation, }); - } - // @notice protection mechanism for selling token bound accounts. can't execute when account is locked // @param duration for which to lock account fn lock(ref self: ComponentState, duration: u64) { self._assert_only_owner(); let (lock_status, _) = self._is_locked(); - assert(!lock_status, 'Account: account already locked'); + assert(!lock_status, Errors::LOCKED_ACCOUNT); let current_timestamp = get_block_timestamp(); let unlock_time = current_timestamp + duration; self.Account_unlock_timestamp.write(unlock_time); @@ -198,7 +185,7 @@ mod AccountComponent { fn _assert_only_owner(ref self: ComponentState) { let caller = get_caller_address(); let owner = self._get_owner(self.Account_token_contract.read(), self.Account_token_id.read()); - assert(caller == owner, 'Account: unathorized'); + assert(caller == owner, Errors::UNAUTHORIZED); } /// @notice internal function for getting NFT owner @@ -246,7 +233,7 @@ mod AccountComponent { let signature = tx_info.signature; assert( self._is_valid_signature(tx_hash, signature) == starknet::VALIDATED, - 'Account: invalid signature' + Errors::INV_SIGNATURE ); starknet::VALIDATED } @@ -256,7 +243,7 @@ mod AccountComponent { self: @ComponentState, hash: felt252, signature: Span ) -> felt252 { let signature_length = signature.len(); - assert(signature_length == 2_u32, 'Account: invalid sig length'); + assert(signature_length == 2_u32, Errors::INV_SIG_LEN); let caller = get_caller_address(); let owner = self._get_owner(self.Account_token_contract.read(), self.Account_token_id.read()); diff --git a/src/interfaces.cairo b/src/interfaces.cairo index 86a419e..5721060 100644 --- a/src/interfaces.cairo +++ b/src/interfaces.cairo @@ -1,3 +1,4 @@ mod IAccount; mod IERC721; mod IRegistry; +mod IUpgradeable; \ No newline at end of file diff --git a/src/interfaces/IAccount.cairo b/src/interfaces/IAccount.cairo index 7f79809..0e8645e 100644 --- a/src/interfaces/IAccount.cairo +++ b/src/interfaces/IAccount.cairo @@ -20,7 +20,6 @@ trait IAccount { fn owner( self: @TContractState, token_contract: ContractAddress, token_id: u256 ) -> ContractAddress; - fn upgrade(ref self: TContractState, implementation: ClassHash); fn lock(ref self: TContractState, duration: u64); fn is_locked(self: @TContractState) -> (bool, u64); fn supports_interface(self: @TContractState, interface_id: felt252) -> bool; diff --git a/src/interfaces/IUpgradeable.cairo b/src/interfaces/IUpgradeable.cairo new file mode 100644 index 0000000..0999f73 --- /dev/null +++ b/src/interfaces/IUpgradeable.cairo @@ -0,0 +1,6 @@ +use starknet::ClassHash; + +#[starknet::interface] +trait IUpgradeable { + fn upgrade(ref self: TContractState, new_class_hash: ClassHash); +} diff --git a/src/lib.cairo b/src/lib.cairo index 4ad41ba..e652203 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -3,3 +3,4 @@ mod account; mod interfaces; mod presets; mod test_helper; +mod upgradeable; diff --git a/src/presets/account.cairo b/src/presets/account.cairo index 0430b15..94cfe07 100644 --- a/src/presets/account.cairo +++ b/src/presets/account.cairo @@ -3,27 +3,38 @@ //////////////////////////////// #[starknet::contract] mod Account { - use token_bound_accounts::account::AccountComponent; use starknet::ContractAddress; + use starknet::ClassHash; + use token_bound_accounts::account::AccountComponent; + use token_bound_accounts::upgradeable::UpgradeableComponent; + use token_bound_accounts::interfaces::IUpgradeable::IUpgradeable; component!(path: AccountComponent, storage: account, event: AccountEvent); + component!(path: UpgradeableComponent, storage: upgradeable, event: UpgradeableEvent); // Account #[abi(embed_v0)] impl AccountImpl = AccountComponent::AccountImpl; impl AccountInternalImpl = AccountComponent::InternalImpl; + // Upgradeable + impl UpgradeableInternalImpl = UpgradeableComponent::InternalImpl; + #[storage] struct Storage { #[substorage(v0)] - account: AccountComponent::Storage + account: AccountComponent::Storage, + #[substorage(v0)] + upgradeable: UpgradeableComponent::Storage } #[event] #[derive(Drop, starknet::Event)] enum Event { #[flat] - AccountEvent: AccountComponent::Event + AccountEvent: AccountComponent::Event, + #[flat] + UpgradeableEvent: UpgradeableComponent::Event } #[constructor] @@ -31,4 +42,13 @@ mod Account { self.account.initializer(token_contract, token_id); } + #[external(v0)] + impl UpgradeableImpl of IUpgradeable { + fn upgrade(ref self: ContractState, new_class_hash: ClassHash) { + self.account._assert_only_owner(); + let (lock_status, _) = self.account._is_locked(); + assert(!lock_status, AccountComponent::Errors::LOCKED_ACCOUNT); + self.upgradeable._upgrade(new_class_hash); + } + } } diff --git a/src/registry/registry.cairo b/src/registry/registry.cairo index ddf5ee6..57a5f47 100644 --- a/src/registry/registry.cairo +++ b/src/registry/registry.cairo @@ -1,5 +1,5 @@ //////////////////////////////// -// Registry contract +// Registry Component //////////////////////////////// #[starknet::component] mod RegistryComponent { diff --git a/src/upgradeable.cairo b/src/upgradeable.cairo new file mode 100644 index 0000000..3d4ca2e --- /dev/null +++ b/src/upgradeable.cairo @@ -0,0 +1,3 @@ +mod upgradeable; + +use upgradeable::UpgradeableComponent; \ No newline at end of file diff --git a/src/upgradeable/upgradeable.cairo b/src/upgradeable/upgradeable.cairo new file mode 100644 index 0000000..de83cfe --- /dev/null +++ b/src/upgradeable/upgradeable.cairo @@ -0,0 +1,43 @@ +//////////////////////////////// +// Upgradeable Component +//////////////////////////////// +#[starknet::component] +mod UpgradeableComponent { + use starknet::ClassHash; + use starknet::SyscallResultTrait; + use token_bound_accounts::interfaces::IUpgradeable; + + #[storage] + struct Storage {} + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + Upgraded: Upgraded + } + + /// Emitted when the contract is upgraded. + #[derive(Drop, starknet::Event)] + struct Upgraded { + class_hash: ClassHash + } + + mod Errors { + const INVALID_CLASS: felt252 = 'Class hash cannot be zero'; + } + + #[generate_trait] + impl InternalImpl< + TContractState, + +HasComponent, + +Drop + > of InternalTrait { + /// @notice eplaces the contract's class hash with `new_class_hash`. + /// Emits an `Upgraded` event. + fn _upgrade(ref self: ComponentState, new_class_hash: ClassHash) { + assert(!new_class_hash.is_zero(), Errors::INVALID_CLASS); + starknet::replace_class_syscall(new_class_hash).unwrap_syscall(); + self.emit(Upgraded { class_hash: new_class_hash }); + } + } +} diff --git a/tests/test_account.cairo b/tests/test_account.cairo index b3e36da..8e4fa73 100644 --- a/tests/test_account.cairo +++ b/tests/test_account.cairo @@ -15,6 +15,10 @@ use token_bound_accounts::interfaces::IAccount::IAccountSafeDispatcher; use token_bound_accounts::interfaces::IAccount::IAccountSafeDispatcherTrait; use token_bound_accounts::presets::account::Account; +use token_bound_accounts::interfaces::IUpgradeable::IUpgradeableDispatcher; +use token_bound_accounts::interfaces::IUpgradeable::IUpgradeableDispatcherTrait; + + use token_bound_accounts::test_helper::hello_starknet::IHelloStarknetDispatcher; use token_bound_accounts::test_helper::hello_starknet::IHelloStarknetDispatcherTrait; use token_bound_accounts::test_helper::hello_starknet::HelloStarknet; @@ -206,7 +210,7 @@ fn test_owner() { #[test] fn test_upgrade() { let (contract_address, erc721_contract_address) = __setup__(); - let dispatcher = IAccountDispatcher { contract_address }; + // let dispatcher = IAccountDispatcher { contract_address }; let new_class_hash = declare('UpgradedAccount').class_hash; @@ -215,6 +219,7 @@ fn test_upgrade() { let token_owner = token_dispatcher.ownerOf(u256_from_felt252(1)); // call the upgrade function + let dispatcher = IUpgradeableDispatcher { contract_address }; start_prank(CheatTarget::One(contract_address), token_owner); dispatcher.upgrade(new_class_hash); @@ -223,18 +228,20 @@ fn test_upgrade() { let version = upgraded_dispatcher.version(); assert(version == 1_u8, 'upgrade unsuccessful'); stop_prank(CheatTarget::One(contract_address)); +} + +#[test] +#[should_panic(expected: ( 'Account: unauthorized',))] +fn test_upgrade_with_unauthorized() { + let (contract_address, erc721_contract_address) = __setup__(); + + let new_class_hash = declare('UpgradedAccount').class_hash; // call upgrade function with an unauthorized address start_prank(CheatTarget::One(contract_address), ACCOUNT2.try_into().unwrap()); - let safe_upgrade_dispatcher = IAccountSafeDispatcher { contract_address }; - match safe_upgrade_dispatcher.upgrade(new_class_hash) { - Result::Ok(_) => panic_with_felt252('expected to panic'), - Result::Err(panic_data) => { - stop_prank(CheatTarget::One(contract_address)); - panic_data.print(); - return (); - } - } + // let safe_upgrade_dispatcher = IAccountSafeDispatcher { contract_address }; + let safe_upgrade_dispatcher = IUpgradeableDispatcher { contract_address }; + safe_upgrade_dispatcher.upgrade(new_class_hash); } #[test] @@ -302,6 +309,7 @@ fn test_should_not_execute_when_locked() { } #[test] +#[should_panic(expected: ('Account: account is locked!',))] fn test_should_not_upgrade_when_locked() { let (contract_address, erc721_contract_address) = __setup__(); let dispatcher = IAccountSafeDispatcher { contract_address }; @@ -320,14 +328,8 @@ fn test_should_not_upgrade_when_locked() { let new_class_hash = declare('UpgradedAccount').class_hash; // call the upgrade function - match dispatcher.upgrade(new_class_hash) { - Result::Ok(_) => panic_with_felt252('should have panicked'), - Result::Err(panic_data) => { - stop_prank(CheatTarget::One(contract_address)); - panic_data.print(); - return (); - } - } + let dispatcher_upgradable = IUpgradeableDispatcher { contract_address }; + dispatcher_upgradable.upgrade(new_class_hash); } #[test] From 413eb4a197a3c7494b5b7a0458ff5b6fafcd5801 Mon Sep 17 00:00:00 2001 From: Arn0d Date: Fri, 1 Mar 2024 03:11:45 +0100 Subject: [PATCH 4/6] Add upgradeablefor Registry components and test --- src/presets/registry.cairo | 22 +++- src/test_helper.cairo | 1 + src/test_helper/registry_upgrade.cairo | 138 +++++++++++++++++++++++++ tests/test_account.cairo | 1 - tests/test_registry.cairo | 59 +++++++++++ 5 files changed, 218 insertions(+), 3 deletions(-) create mode 100644 src/test_helper/registry_upgrade.cairo diff --git a/src/presets/registry.cairo b/src/presets/registry.cairo index adc3fce..aefd2c1 100644 --- a/src/presets/registry.cairo +++ b/src/presets/registry.cairo @@ -3,26 +3,44 @@ //////////////////////////////// #[starknet::contract] mod Registry { + use starknet::ClassHash; use token_bound_accounts::registry::RegistryComponent; + use token_bound_accounts::upgradeable::UpgradeableComponent; + use token_bound_accounts::interfaces::IUpgradeable::IUpgradeable; component!(path: RegistryComponent, storage: registry, event: RegistryEvent); + component!(path: UpgradeableComponent, storage: upgradeable, event: UpgradeableEvent); // Account #[abi(embed_v0)] impl RegistryImpl = RegistryComponent::RegistryImpl; impl AccountInternalImpl = RegistryComponent::InternalImpl; + // Upgradeable + impl UpgradeableInternalImpl = UpgradeableComponent::InternalImpl; + #[storage] struct Storage { #[substorage(v0)] - registry: RegistryComponent::Storage + registry: RegistryComponent::Storage, + #[substorage(v0)] + upgradeable: UpgradeableComponent::Storage } #[event] #[derive(Drop, starknet::Event)] enum Event { #[flat] - RegistryEvent: RegistryComponent::Event + RegistryEvent: RegistryComponent::Event, + #[flat] + UpgradeableEvent: UpgradeableComponent::Event } + #[external(v0)] + impl UpgradeableImpl of IUpgradeable { + fn upgrade(ref self: ContractState, new_class_hash: ClassHash) { + self.upgradeable._upgrade(new_class_hash); + } + } + } diff --git a/src/test_helper.cairo b/src/test_helper.cairo index 57f17fb..53d5b9b 100644 --- a/src/test_helper.cairo +++ b/src/test_helper.cairo @@ -1,3 +1,4 @@ mod hello_starknet; mod account_upgrade; mod erc721_helper; +mod registry_upgrade; \ No newline at end of file diff --git a/src/test_helper/registry_upgrade.cairo b/src/test_helper/registry_upgrade.cairo new file mode 100644 index 0000000..29a3da4 --- /dev/null +++ b/src/test_helper/registry_upgrade.cairo @@ -0,0 +1,138 @@ +use array::{ArrayTrait, SpanTrait}; +use starknet::{account::Call, ContractAddress, ClassHash}; + +#[starknet::interface] +trait IUpgradedRegistry { + fn create_account( + ref self: TContractState, + implementation_hash: felt252, + token_contract: ContractAddress, + token_id: u256, + salt: felt252 + ) -> ContractAddress; + fn get_account( + self: @TContractState, + implementation_hash: felt252, + token_contract: ContractAddress, + token_id: u256, + salt: felt252 + ) -> ContractAddress; + fn total_deployed_accounts( + self: @TContractState, token_contract: ContractAddress, token_id: u256 + ) -> u8; + fn upgrade(ref self: TContractState, new_class_hash: ClassHash); + fn version(self: @TContractState) -> u8; +} + +#[starknet::contract] +mod UpgradedRegistry { + use core::hash::HashStateTrait; + use starknet::{ + ContractAddress, SyscallResultTrait, syscalls::call_contract_syscall, syscalls::deploy_syscall, get_caller_address, ClassHash + }; + use pedersen::PedersenTrait; + + + #[storage] + struct Storage { + Registry_deployed_accounts: LegacyMap< + (ContractAddress, u256), u8 + >, // tracks no. of deployed accounts by registry for an NFT + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + AccountCreated: AccountCreated, + Upgraded: Upgraded + } + + #[derive(Drop, starknet::Event)] + struct AccountCreated { + account_address: ContractAddress, + token_contract: ContractAddress, + token_id: u256, + } + + #[derive(Drop, starknet::Event)] + struct Upgraded { + class_hash: ClassHash + } + + + #[external(v0)] + impl RegistryImpl of super::IUpgradedRegistry { + + fn create_account( + ref self: ContractState, + implementation_hash: felt252, + token_contract: ContractAddress, + token_id: u256, + salt: felt252 + ) -> ContractAddress { + let mut constructor_calldata: Array = array![ + token_contract.into(), token_id.low.into(), token_id.high.into() + ]; + + let class_hash: ClassHash = implementation_hash.try_into().unwrap(); + let result = deploy_syscall(class_hash, salt, constructor_calldata.span(), true); + let (account_address, _) = result.unwrap_syscall(); + + let new_deployment_index: u8 = self + .Registry_deployed_accounts + .read((token_contract, token_id)) + + 1_u8; + self.Registry_deployed_accounts.write((token_contract, token_id), new_deployment_index); + + self.emit(AccountCreated { account_address, token_contract, token_id, }); + + account_address + } + + fn get_account( + self: @ContractState, + implementation_hash: felt252, + token_contract: ContractAddress, + token_id: u256, + salt: felt252 + ) -> ContractAddress { + let constructor_calldata_hash = PedersenTrait::new(0) + .update(token_contract.into()) + .update(token_id.low.into()) + .update(token_id.high.into()) + .update(3) + .finalize(); + + let prefix: felt252 = 'STARKNET_CONTRACT_ADDRESS'; + let account_address = PedersenTrait::new(0) + .update(prefix) + .update(0) + .update(salt) + .update(implementation_hash) + .update(constructor_calldata_hash) + .update(5) + .finalize(); + + account_address.try_into().unwrap() + } + + /// @notice returns the total no. of deployed tokenbound accounts for an NFT by the registry + /// @param token_contract the contract address of the NFT + /// @param token_id the ID of the NFT + fn total_deployed_accounts( + self: @ContractState, token_contract: ContractAddress, token_id: u256 + ) -> u8 { + self.Registry_deployed_accounts.read((token_contract, token_id)) + } + + fn upgrade(ref self: ContractState, new_class_hash: ClassHash) { + assert(!new_class_hash.is_zero(), 'Class hash cannot be zero'); + starknet::replace_class_syscall(new_class_hash).unwrap_syscall(); + self.emit(Upgraded { class_hash: new_class_hash }); + } + + fn version(self: @ContractState) -> u8 { + 1_u8 + } + } +} diff --git a/tests/test_account.cairo b/tests/test_account.cairo index 8e4fa73..3cc958b 100644 --- a/tests/test_account.cairo +++ b/tests/test_account.cairo @@ -210,7 +210,6 @@ fn test_owner() { #[test] fn test_upgrade() { let (contract_address, erc721_contract_address) = __setup__(); - // let dispatcher = IAccountDispatcher { contract_address }; let new_class_hash = declare('UpgradedAccount').class_hash; diff --git a/tests/test_registry.cairo b/tests/test_registry.cairo index 899e5a5..a2805c3 100644 --- a/tests/test_registry.cairo +++ b/tests/test_registry.cairo @@ -12,6 +12,13 @@ use token_bound_accounts::interfaces::IRegistry::IRegistryDispatcherTrait; use token_bound_accounts::interfaces::IRegistry::IRegistryDispatcher; use token_bound_accounts::presets::registry::Registry; +use token_bound_accounts::test_helper::registry_upgrade::IUpgradedRegistryDispatcher; +use token_bound_accounts::test_helper::registry_upgrade::IUpgradedRegistryDispatcherTrait; +use token_bound_accounts::test_helper::registry_upgrade::UpgradedRegistry; + +use token_bound_accounts::interfaces::IUpgradeable::IUpgradeableDispatcher; +use token_bound_accounts::interfaces::IUpgradeable::IUpgradeableDispatcherTrait; + use token_bound_accounts::interfaces::IAccount::IAccountDispatcher; use token_bound_accounts::interfaces::IAccount::IAccountDispatcherTrait; use token_bound_accounts::presets::account::Account; @@ -143,3 +150,55 @@ fn test_get_account() { // compare both addresses assert(account == account_address, 'get_account computes wrongly'); } + +// Upgradeable test cases + +#[test] +fn test_upgrade() { + let (registry_contract_address, erc721_contract_address) = __setup__(); + let registry_dispatcher = IRegistryDispatcher { contract_address: registry_contract_address }; + + // prank contract as token owner + let token_dispatcher = IERC721Dispatcher { contract_address: erc721_contract_address }; + let token_owner = token_dispatcher.ownerOf(u256_from_felt252(1)); + start_prank(CheatTarget::One(registry_contract_address), token_owner); + + // create account + let acct_class_hash = declare('Account').class_hash; + let account_address = registry_dispatcher + .create_account( + class_hash_to_felt252(acct_class_hash), + erc721_contract_address, + u256_from_felt252(1), + 245828 + ); + + // check total_deployed_accounts + let total_deployed_accounts = registry_dispatcher + .total_deployed_accounts(erc721_contract_address, u256_from_felt252(1)); + assert(total_deployed_accounts == 1_u8, 'invalid deployed TBA count'); + + // confirm account deployment by checking the account owner + let acct_dispatcher = IAccountDispatcher { contract_address: account_address }; + let TBA_owner = acct_dispatcher.owner(erc721_contract_address, u256_from_felt252(1)); + assert(TBA_owner == token_owner, 'acct deployed wrongly'); + + /////////////////////////// upgrade account /////////////////////////// + + let new_class_hash = declare('UpgradedRegistry').class_hash; + + // get token owner + let token_dispatcher = IERC721Dispatcher { contract_address: erc721_contract_address }; + let token_owner = token_dispatcher.ownerOf(u256_from_felt252(1)); + + // call the upgrade function + let dispatcher = IUpgradeableDispatcher { contract_address: registry_contract_address }; + start_prank(CheatTarget::One(registry_contract_address), token_owner); + dispatcher.upgrade(new_class_hash); + + // try to call the version function + let upgraded_dispatcher = IUpgradedRegistryDispatcher { contract_address: registry_contract_address }; + let version = upgraded_dispatcher.version(); + assert(version == 1_u8, 'upgrade unsuccessful'); + stop_prank(CheatTarget::One(registry_contract_address)); +} \ No newline at end of file From 5cb5345050fa3babbdf7a4d774afe749e0d820b6 Mon Sep 17 00:00:00 2001 From: Arn0d Date: Fri, 1 Mar 2024 03:30:23 +0100 Subject: [PATCH 5/6] Add error message for caller not being owner --- src/registry/registry.cairo | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/registry/registry.cairo b/src/registry/registry.cairo index 57a5f47..44c9180 100644 --- a/src/registry/registry.cairo +++ b/src/registry/registry.cairo @@ -42,6 +42,10 @@ mod RegistryComponent { token_id: u256, } + mod Errors { + const CALLER_IS_NOT_OWNER: felt252 = 'Registry: caller is not onwer'; + } + #[embeddable_as(RegistryImpl)] impl Registry< TContractState, @@ -61,7 +65,7 @@ mod RegistryComponent { salt: felt252 ) -> ContractAddress { let owner = self._get_owner(token_contract, token_id); - assert(owner == get_caller_address(), 'CALLER_IS_NOT_OWNER'); + assert(owner == get_caller_address(), Errors::CALLER_IS_NOT_OWNER); let mut constructor_calldata: Array = array![ token_contract.into(), token_id.low.into(), token_id.high.into() From bc029979459bdadda87a2fd2bf06af64f46ac34f Mon Sep 17 00:00:00 2001 From: Arn0d Date: Fri, 1 Mar 2024 13:59:51 +0100 Subject: [PATCH 6/6] Revert Registry Components rework and Upgradable --- src/presets.cairo | 4 +-- src/presets/registry.cairo | 46 ------------------------------- src/registry.cairo | 2 -- src/registry/registry.cairo | 26 +++++++----------- tests/test_registry.cairo | 54 +------------------------------------ 5 files changed, 11 insertions(+), 121 deletions(-) delete mode 100644 src/presets/registry.cairo diff --git a/src/presets.cairo b/src/presets.cairo index e373209..74885b0 100644 --- a/src/presets.cairo +++ b/src/presets.cairo @@ -1,5 +1,3 @@ mod account; -mod registry; -use account::Account; -use registry::Registry; \ No newline at end of file +use account::Account; \ No newline at end of file diff --git a/src/presets/registry.cairo b/src/presets/registry.cairo deleted file mode 100644 index aefd2c1..0000000 --- a/src/presets/registry.cairo +++ /dev/null @@ -1,46 +0,0 @@ -//////////////////////////////// -// Registry contract -//////////////////////////////// -#[starknet::contract] -mod Registry { - use starknet::ClassHash; - use token_bound_accounts::registry::RegistryComponent; - use token_bound_accounts::upgradeable::UpgradeableComponent; - use token_bound_accounts::interfaces::IUpgradeable::IUpgradeable; - - component!(path: RegistryComponent, storage: registry, event: RegistryEvent); - component!(path: UpgradeableComponent, storage: upgradeable, event: UpgradeableEvent); - - // Account - #[abi(embed_v0)] - impl RegistryImpl = RegistryComponent::RegistryImpl; - impl AccountInternalImpl = RegistryComponent::InternalImpl; - - // Upgradeable - impl UpgradeableInternalImpl = UpgradeableComponent::InternalImpl; - - #[storage] - struct Storage { - #[substorage(v0)] - registry: RegistryComponent::Storage, - #[substorage(v0)] - upgradeable: UpgradeableComponent::Storage - } - - #[event] - #[derive(Drop, starknet::Event)] - enum Event { - #[flat] - RegistryEvent: RegistryComponent::Event, - #[flat] - UpgradeableEvent: UpgradeableComponent::Event - } - - #[external(v0)] - impl UpgradeableImpl of IUpgradeable { - fn upgrade(ref self: ContractState, new_class_hash: ClassHash) { - self.upgradeable._upgrade(new_class_hash); - } - } - -} diff --git a/src/registry.cairo b/src/registry.cairo index f199435..516f5b2 100644 --- a/src/registry.cairo +++ b/src/registry.cairo @@ -1,3 +1 @@ mod registry; - -use registry::RegistryComponent; \ No newline at end of file diff --git a/src/registry/registry.cairo b/src/registry/registry.cairo index 44c9180..4b445f5 100644 --- a/src/registry/registry.cairo +++ b/src/registry/registry.cairo @@ -1,8 +1,8 @@ //////////////////////////////// // Registry Component //////////////////////////////// -#[starknet::component] -mod RegistryComponent { +#[starknet::contract] +mod Registry { use core::result::ResultTrait; use core::hash::HashStateTrait; use starknet::{ @@ -46,19 +46,15 @@ mod RegistryComponent { const CALLER_IS_NOT_OWNER: felt252 = 'Registry: caller is not onwer'; } - #[embeddable_as(RegistryImpl)] - impl Registry< - TContractState, - +HasComponent, - +Drop - > of IRegistry> { + #[external(v0)] + impl IRegistryImpl of IRegistry { /// @notice deploys a new tokenbound account for an NFT /// @param implementation_hash the class hash of the reference account /// @param token_contract the contract address of the NFT /// @param token_id the ID of the NFT /// @param salt random salt for deployment fn create_account( - ref self: ComponentState, + ref self: ContractState, implementation_hash: felt252, token_contract: ContractAddress, token_id: u256, @@ -92,7 +88,7 @@ mod RegistryComponent { /// @param token_id the ID of the NFT /// @param salt random salt for deployment fn get_account( - self: @ComponentState, + self: @ContractState, implementation_hash: felt252, token_contract: ContractAddress, token_id: u256, @@ -122,24 +118,20 @@ mod RegistryComponent { /// @param token_contract the contract address of the NFT /// @param token_id the ID of the NFT fn total_deployed_accounts( - self: @ComponentState, token_contract: ContractAddress, token_id: u256 + self: @ContractState, token_contract: ContractAddress, token_id: u256 ) -> u8 { self.Registry_deployed_accounts.read((token_contract, token_id)) } } #[generate_trait] - impl InternalImpl< - TContractState, - +HasComponent, - +Drop - > of InternalTrait { + impl internalImpl of InternalTrait { /// @notice internal function for getting NFT owner /// @param token_contract contract address of NFT // @param token_id token ID of NFT // NB: This function aims for compatibility with all contracts (snake or camel case) but do not work as expected on mainnet as low level calls do not return err at the moment. Should work for contracts which implements CamelCase but not snake_case until starknet v0.15. fn _get_owner( - self: @ComponentState, token_contract: ContractAddress, token_id: u256 + self: @ContractState, token_contract: ContractAddress, token_id: u256 ) -> ContractAddress { let mut calldata: Array = ArrayTrait::new(); Serde::serialize(@token_id, ref calldata); diff --git a/tests/test_registry.cairo b/tests/test_registry.cairo index a2805c3..4ce7223 100644 --- a/tests/test_registry.cairo +++ b/tests/test_registry.cairo @@ -10,7 +10,7 @@ use snforge_std::{ use token_bound_accounts::interfaces::IRegistry::IRegistryDispatcherTrait; use token_bound_accounts::interfaces::IRegistry::IRegistryDispatcher; -use token_bound_accounts::presets::registry::Registry; +use token_bound_accounts::registry::registry::Registry; use token_bound_accounts::test_helper::registry_upgrade::IUpgradedRegistryDispatcher; use token_bound_accounts::test_helper::registry_upgrade::IUpgradedRegistryDispatcherTrait; @@ -149,56 +149,4 @@ fn test_get_account() { // compare both addresses assert(account == account_address, 'get_account computes wrongly'); -} - -// Upgradeable test cases - -#[test] -fn test_upgrade() { - let (registry_contract_address, erc721_contract_address) = __setup__(); - let registry_dispatcher = IRegistryDispatcher { contract_address: registry_contract_address }; - - // prank contract as token owner - let token_dispatcher = IERC721Dispatcher { contract_address: erc721_contract_address }; - let token_owner = token_dispatcher.ownerOf(u256_from_felt252(1)); - start_prank(CheatTarget::One(registry_contract_address), token_owner); - - // create account - let acct_class_hash = declare('Account').class_hash; - let account_address = registry_dispatcher - .create_account( - class_hash_to_felt252(acct_class_hash), - erc721_contract_address, - u256_from_felt252(1), - 245828 - ); - - // check total_deployed_accounts - let total_deployed_accounts = registry_dispatcher - .total_deployed_accounts(erc721_contract_address, u256_from_felt252(1)); - assert(total_deployed_accounts == 1_u8, 'invalid deployed TBA count'); - - // confirm account deployment by checking the account owner - let acct_dispatcher = IAccountDispatcher { contract_address: account_address }; - let TBA_owner = acct_dispatcher.owner(erc721_contract_address, u256_from_felt252(1)); - assert(TBA_owner == token_owner, 'acct deployed wrongly'); - - /////////////////////////// upgrade account /////////////////////////// - - let new_class_hash = declare('UpgradedRegistry').class_hash; - - // get token owner - let token_dispatcher = IERC721Dispatcher { contract_address: erc721_contract_address }; - let token_owner = token_dispatcher.ownerOf(u256_from_felt252(1)); - - // call the upgrade function - let dispatcher = IUpgradeableDispatcher { contract_address: registry_contract_address }; - start_prank(CheatTarget::One(registry_contract_address), token_owner); - dispatcher.upgrade(new_class_hash); - - // try to call the version function - let upgraded_dispatcher = IUpgradedRegistryDispatcher { contract_address: registry_contract_address }; - let version = upgraded_dispatcher.version(); - assert(version == 1_u8, 'upgrade unsuccessful'); - stop_prank(CheatTarget::One(registry_contract_address)); } \ No newline at end of file