diff --git a/src/components/lockable/lockable.cairo b/src/components/lockable/lockable.cairo index 63020e8..d5bca72 100644 --- a/src/components/lockable/lockable.cairo +++ b/src/components/lockable/lockable.cairo @@ -46,7 +46,7 @@ pub mod LockableComponent { // ************************************************************************* pub mod Errors { pub const UNAUTHORIZED: felt252 = 'Account: unauthorized'; - pub const NOT_OWNER: felt252 = 'Not Account Owner'; + pub const NOT_OWNER: felt252 = 'Account: Not Account Owner'; pub const EXCEEDS_MAX_LOCK_TIME: felt252 = 'Lock time exceeded'; pub const LOCKED_ACCOUNT: felt252 = 'Account Locked'; } @@ -69,16 +69,19 @@ pub mod LockableComponent { let account_comp = get_dep_component!(@self, Account); - // get the token owner - let owner = account_comp.owner(); - - assert(get_caller_address() == owner, Errors::NOT_OWNER); + let is_valid = account_comp._is_valid_signer(get_caller_address()); + assert(is_valid, Errors::UNAUTHORIZED); assert(lock_until <= current_timestamp + 356, Errors::EXCEEDS_MAX_LOCK_TIME); let lock_status = self.is_lock(); assert(lock_status != true, Errors::LOCKED_ACCOUNT); + + // update account state + let mut account_comp_mut = get_dep_component_mut!(ref self, Account); + account_comp_mut._update_state(); + // set the lock_util which set the period the account is lock self.lock_until.write(lock_until); // emit event diff --git a/src/components/presets/account_preset.cairo b/src/components/presets/account_preset.cairo index dc050f2..2c15148 100644 --- a/src/components/presets/account_preset.cairo +++ b/src/components/presets/account_preset.cairo @@ -64,6 +64,9 @@ pub mod AccountPreset { #[abi(embed_v0)] impl Executable of IExecutable { fn execute(ref self: ContractState, mut calls: Array) -> Array> { + // cannot make this call when the account is lock + let is_lock = self.lockable.is_lock(); + assert(is_lock != true, 'Account locked'); self.account._execute(calls) } } @@ -74,6 +77,9 @@ pub mod AccountPreset { #[abi(embed_v0)] impl Upgradeable of IUpgradeable { fn upgrade(ref self: ContractState, new_class_hash: ClassHash) { + // cannot make this call when the account is lock + let is_lock = self.lockable.is_lock(); + assert(is_lock != true, 'Account locked'); self.upgradeable._upgrade(new_class_hash); } } diff --git a/tests/test_lockable_component.cairo b/tests/test_lockable_component.cairo index f69787f..4b3b440 100644 --- a/tests/test_lockable_component.cairo +++ b/tests/test_lockable_component.cairo @@ -28,29 +28,6 @@ use token_bound_accounts::test_helper::{ }; -const ACCOUNT: felt252 = 1234; -const ACCOUNT2: felt252 = 5729; -const SALT: felt252 = 123; - -#[derive(Drop)] -struct SignedTransactionData { - private_key: felt252, - public_key: felt252, - transaction_hash: felt252, - r: felt252, - s: felt252 -} - -fn SIGNED_TX_DATA() -> SignedTransactionData { - SignedTransactionData { - private_key: 1234, - public_key: 883045738439352841478194533192765345509759306772397516907181243450667673002, - transaction_hash: 2717105892474786771566982177444710571376803476229898722748888396642649184538, - r: 3068558690657879390136740086327753007413919701043650133111397282816679110801, - s: 3355728545224320878895493649495491771252432631648740019139167265522817576501 - } -} - // ************************************************************************* // SETUP // ************************************************************************* @@ -116,6 +93,55 @@ fn test_lockable() { assert(check_lock == true, 'Account Not Lock'); stop_cheat_caller_address(contract_address); } + +#[test] +#[should_panic(expected: ('Account locked',))] +fn test_should_fail_when_locked() { + let (contract_address, _) = __setup__(); + let acct_dispatcher = IAccountDispatcher { contract_address: contract_address }; + let safe_dispatcher = IExecutableDispatcher { contract_address }; + + let owner = acct_dispatcher.owner(); + let lock_duration = 30_u64; + + let lockable_dispatcher = ILockableDispatcher { contract_address }; + + start_cheat_caller_address(contract_address, owner); + lockable_dispatcher.lock(lock_duration); + stop_cheat_caller_address(contract_address); + + // deploy `HelloStarknet` contract for testing + let test_contract = declare("HelloStarknet").unwrap(); + let (test_address, _) = test_contract.deploy(@array![]).unwrap(); + + // craft calldata for call array + let mut calldata = array![100].span(); + let call = Call { + to: test_address, + selector: 1530486729947006463063166157847785599120665941190480211966374137237989315360, + calldata: calldata + }; + + start_cheat_caller_address(contract_address, owner); + safe_dispatcher.execute(array![call]); +} + +#[test] +#[should_panic(expected: ('Lock time exceeded',))] +fn test_should_fail_when_lock_until_exceed() { + let (contract_address, _) = __setup__(); + // let safe_acc_dispatcher = IAccountSafeDispatcher { contract_address }; + let acct_dispatcher = IAccountDispatcher { contract_address: contract_address }; + + let owner = acct_dispatcher.owner(); + let lock_duration = 3000_u64; + + let lockable_dispatcher = ILockableDispatcher { contract_address }; + + start_cheat_caller_address(contract_address, owner); + lockable_dispatcher.lock(lock_duration); +} + #[test] fn test_lockable_emits_event() { let (contract_address, _) = __setup__();