Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(upgrader): assume that target canister is always set in production #479

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions core/upgrader/impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use ic_stable_structures::{
};
use lazy_static::lazy_static;
use orbit_essentials::storable;
use std::{cell::RefCell, sync::Arc, thread::LocalKey};
use std::{cell::RefCell, sync::Arc};
use upgrade::{UpgradeError, UpgradeParams};
use upgrader_api::{InitArg, TriggerUpgradeError};

Expand All @@ -29,7 +29,6 @@ pub mod utils;
type Memory = VirtualMemory<DefaultMemoryImpl>;
type StableMap<K, V> = StableBTreeMap<K, V, Memory>;
type StableValue<T> = StableMap<(), T>;
type LocalRef<T> = &'static LocalKey<RefCell<T>>;

const MEMORY_ID_TARGET_CANISTER_ID: u8 = 0;
const MEMORY_ID_DISASTER_RECOVERY: u8 = 1;
Expand All @@ -51,6 +50,15 @@ thread_local! {
);
}

/// Returns the principal of the canister targeted by this upgrader.
///
/// Looks up the single entry stored under the unit key in
/// `TARGET_CANISTER_ID`. If no entry has been stored, the anonymous
/// principal is returned instead of an error.
pub fn get_target_canister() -> Principal {
    TARGET_CANISTER_ID.with(|cell| {
        let stored = cell.borrow().get(&());
        match stored {
            Some(principal) => principal.0,
            None => Principal::anonymous(),
        }
    })
}

#[init]
fn init_fn(InitArg { target_canister }: InitArg) {
TARGET_CANISTER_ID.with(|id| {
Expand All @@ -61,13 +69,13 @@ fn init_fn(InitArg { target_canister }: InitArg) {

lazy_static! {
// Lazily-built upgrade pipeline. Every `let u = ...` below wraps the previous
// `u`, so the first layer listed is the innermost and the last one listed is
// the outermost layer of the boxed `Upgrade` object.
static ref UPGRADER: Box<dyn Upgrade> = {
let u = Upgrader::new(&TARGET_CANISTER_ID);
let u = WithStop(u, &TARGET_CANISTER_ID);
let u = WithStart(u, &TARGET_CANISTER_ID);
let u = Upgrader {};
let u = WithStop(u);
let u = WithStart(u);
// Inner logging layer, labelled "upgrade".
let u = WithLogs(u, "upgrade".to_string());
let u = WithBackground(Arc::new(u), &TARGET_CANISTER_ID);
let u = CheckController(u, &TARGET_CANISTER_ID);
let u = WithAuthorization(u, &TARGET_CANISTER_ID);
let u = WithBackground(Arc::new(u));
let u = CheckController(u);
let u = WithAuthorization(u);
// Outer logging layer, labelled "trigger_upgrade".
let u = WithLogs(u, "trigger_upgrade".to_string());
Box::new(u)
};
Expand Down
91 changes: 14 additions & 77 deletions core/upgrader/impl/src/services/disaster_recovery.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
sync::Arc,
};

use super::{InstallCanister, LoggerService, INSTALL_CANISTER};
use crate::{
errors::UpgraderApiError,
get_target_canister,
model::{
Asset, DisasterRecoveryInProgressLog, DisasterRecoveryResultLog, DisasterRecoveryStartLog,
LogEntryType, MultiAssetAccount, RequestDisasterRecoveryLog, SetAccountsAndAssetsLog,
SetAccountsLog, SetCommitteeLog,
Account, AdminUser, Asset, DisasterRecovery, DisasterRecoveryCommittee,
DisasterRecoveryInProgressLog, DisasterRecoveryResultLog, DisasterRecoveryStartLog,
InstallMode, LogEntryType, MultiAssetAccount, RecoveryEvaluationResult, RecoveryFailure,
RecoveryResult, RecoveryStatus, RequestDisasterRecoveryLog, SetAccountsAndAssetsLog,
SetAccountsLog, SetCommitteeLog, StationRecoveryRequest,
},
services::LOGGER_SERVICE,
upgrader_ic_cdk::{api::time, spawn},
StableValue, MEMORY_ID_DISASTER_RECOVERY, MEMORY_MANAGER,
};

use candid::Principal;
use ic_stable_structures::memory_manager::MemoryId;
use lazy_static::lazy_static;
use orbit_essentials::{api::ServiceResult, utils::sha256_hash};

use crate::{
model::{
Account, AdminUser, DisasterRecovery, DisasterRecoveryCommittee, InstallMode,
RecoveryEvaluationResult, RecoveryFailure, RecoveryResult, RecoveryStatus,
StationRecoveryRequest,
},
StableValue, MEMORY_ID_DISASTER_RECOVERY, MEMORY_MANAGER, TARGET_CANISTER_ID,
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
sync::Arc,
};

use super::{InstallCanister, LoggerService, INSTALL_CANISTER};

pub const DISASTER_RECOVERY_REQUEST_EXPIRATION_NS: u64 = 60 * 60 * 24 * 7 * 1_000_000_000; // 1 week
pub const DISASTER_RECOVERY_IN_PROGESS_EXPIRATION_NS: u64 = 60 * 60 * 1_000_000_000; // 1 hour

Expand Down Expand Up @@ -299,15 +293,7 @@ impl DisasterRecoveryService {
return;
}

let Some(station_canister_id) =
TARGET_CANISTER_ID.with(|id| id.borrow().get(&()).map(|id| id.0))
else {
value.last_recovery_result = Some(RecoveryResult::Failure(RecoveryFailure {
reason: "Station canister ID not set".to_string(),
}));
storage.set(value);
return;
};
let station_canister_id = get_target_canister();

value.recovery_status = RecoveryStatus::InProgress { since: time() };
storage.set(value);
Expand Down Expand Up @@ -432,7 +418,6 @@ mod tests {
services::{
DisasterRecoveryService, DisasterRecoveryStorage, InstallCanister, LoggerService,
},
StorablePrincipal, TARGET_CANISTER_ID,
};

#[derive(Default)]
Expand Down Expand Up @@ -592,11 +577,6 @@ mod tests {

#[tokio::test]
async fn test_do_recovery() {
TARGET_CANISTER_ID.with(|id| {
id.borrow_mut()
.insert((), StorablePrincipal(Principal::anonymous()));
});

let storage: DisasterRecoveryStorage = Default::default();
let logger = Arc::new(LoggerService::default());
let recovery_request = StationRecoveryRequest {
Expand Down Expand Up @@ -672,51 +652,8 @@ mod tests {
);
}

/// Runs `DisasterRecoveryService::do_recovery` without storing a target
/// canister id first and checks that the recovery is recorded as failed.
#[tokio::test]
async fn test_failing_do_recovery_with_no_target_canister_id() {
// setup: TARGET_CANISTER_ID is not set, so recovery should fail

let storage: DisasterRecoveryStorage = Default::default();
let logger = Arc::new(LoggerService::default());
// Arbitrary request payload; the test only inspects the stored outcome.
let recovery_request = StationRecoveryRequest {
user_id: [1; 16],
wasm_module: vec![1, 2, 3],
wasm_module_extra_chunks: None,
wasm_sha256: vec![4, 5, 6],
install_mode: InstallMode::Reinstall,
arg: vec![7, 8, 9],
arg_sha256: vec![10, 11, 12],
submitted_at: 0,
};

let installer = Arc::new(TestInstaller::default());

DisasterRecoveryService::do_recovery(
storage.clone(),
installer.clone(),
logger.clone(),
recovery_request.clone(),
)
.await;

// The failure must be stored as the last recovery result.
assert!(matches!(
storage.get().last_recovery_result,
Some(RecoveryResult::Failure(_))
));

// The recovery status must still be `Idle` after the failed attempt.
assert!(matches!(
storage.get().recovery_status,
RecoveryStatus::Idle
));
}

#[tokio::test]
async fn test_failing_do_recovery_with_panicking_install() {
TARGET_CANISTER_ID.with(|id| {
id.borrow_mut()
.insert((), StorablePrincipal(Principal::anonymous()));
});

let storage: DisasterRecoveryStorage = Default::default();
let logger = Arc::new(LoggerService::default());
let recovery_request = StationRecoveryRequest {
Expand Down
103 changes: 40 additions & 63 deletions core/upgrader/impl/src/upgrade.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use crate::{
get_target_canister,
model::{LogEntryType, UpgradeResultLog},
services::LOGGER_SERVICE,
LocalRef, StableValue, StorablePrincipal,
};
use anyhow::{anyhow, Context};
use anyhow::anyhow;
use async_trait::async_trait;
use candid::Principal;
use ic_cdk::api::management_canister::main::{
self as mgmt, CanisterIdRecord, CanisterInfoRequest, CanisterInstallMode,
};
Expand Down Expand Up @@ -41,23 +40,12 @@ pub trait Upgrade: 'static + Sync + Send {
}

/// Base `Upgrade` implementation: its `upgrade` method resolves the target
/// canister and passes it to `install_chunked_code`.
#[derive(Clone)]
pub struct Upgrader {
// Handle to the stable cell from which the target canister id is read.
target: LocalRef<StableValue<StorablePrincipal>>,
}

impl Upgrader {
/// Creates an upgrader that reads the target canister id from `target`.
pub fn new(target: LocalRef<StableValue<StorablePrincipal>>) -> Self {
Self { target }
}
}
pub struct Upgrader {}

#[async_trait]
impl Upgrade for Upgrader {
async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> {
let target_canister = self
.target
.with(|id| id.borrow().get(&()).context("canister id not set"))?
.0;
let target_canister = get_target_canister();

install_chunked_code(
target_canister,
Expand All @@ -71,25 +59,23 @@ impl Upgrade for Upgrader {
}
}

/// `Upgrade` wrapper that stops the target canister before delegating to the
/// wrapped upgrade in field `0`.
pub struct WithStop<T>(pub T, pub LocalRef<StableValue<StorablePrincipal>>);
pub struct WithStop<T>(pub T);

#[async_trait]
impl<T: Upgrade> Upgrade for WithStop<T> {
/// Perform an upgrade but ensure that the target canister is stopped first
async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> {
let id = self
.1
.with(|id| id.borrow().get(&()).context("canister id not set"))?;
let id = get_target_canister();

// A failed stop call is propagated as an error, so the wrapped upgrade
// below is not run in that case.
mgmt::stop_canister(CanisterIdRecord { canister_id: id.0 })
mgmt::stop_canister(CanisterIdRecord { canister_id: id })
.await
.map_err(|(_, err)| anyhow!("failed to stop canister: {err}"))?;

self.0.upgrade(ps).await
}
}

pub struct WithStart<T>(pub T, pub LocalRef<StableValue<StorablePrincipal>>);
pub struct WithStart<T>(pub T);

#[async_trait]
impl<T: Upgrade> Upgrade for WithStart<T> {
Expand All @@ -98,57 +84,52 @@ impl<T: Upgrade> Upgrade for WithStart<T> {
async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> {
let out = self.0.upgrade(ps).await;

let id = self
.1
.with(|id| id.borrow().get(&()).context("canister id not set"))?;
let id = get_target_canister();

mgmt::start_canister(CanisterIdRecord { canister_id: id.0 })
mgmt::start_canister(CanisterIdRecord { canister_id: id })
.await
.map_err(|(_, err)| anyhow!("failed to start canister: {err}"))?;

out
}
}

pub struct WithBackground<T>(pub Arc<T>, pub LocalRef<StableValue<StorablePrincipal>>);
pub struct WithBackground<T>(pub Arc<T>);

#[async_trait]
impl<T: Upgrade> Upgrade for WithBackground<T> {
/// Spawn a background task performing the upgrade
/// so that it is performed in a non-blocking manner
async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> {
let u = self.0.clone();
let target_canister_id: Option<Principal> =
self.1.with(|p| p.borrow().get(&()).map(|sp| sp.0));
let target_canister_id = get_target_canister();

ic_cdk::spawn(async move {
let res = u.upgrade(ps).await;
// Notify the target canister about a failed upgrade unless the call is unauthorized
// (we don't want to spam the target canister with such errors).
if let Some(target_canister_id) = target_canister_id {
if let Err(ref err) = res {
let err = match err {
UpgradeError::UnexpectedError(err) => Some(err.to_string()),
UpgradeError::NotController => Some(
"The upgrader canister is not a controller of the target canister"
.to_string(),
),
UpgradeError::Unauthorized => None,
};
if let Some(err) = err {
let notify_failed_station_upgrade_input =
NotifyFailedStationUpgradeInput { reason: err };
let notify_res = call::<_, (ApiResult<()>,)>(
target_canister_id,
"notify_failed_station_upgrade",
(notify_failed_station_upgrade_input,),
)
.await
.map(|r| r.0);
// Log an error if the notification can't be made.
if let Err(e) = notify_res {
print(format!("notify_failed_station_upgrade failed: {:?}", e));
}
if let Err(ref err) = res {
let err = match err {
UpgradeError::UnexpectedError(err) => Some(err.to_string()),
UpgradeError::NotController => Some(
"The upgrader canister is not a controller of the target canister"
.to_string(),
),
UpgradeError::Unauthorized => None,
};
if let Some(err) = err {
let notify_failed_station_upgrade_input =
NotifyFailedStationUpgradeInput { reason: err };
let notify_res = call::<_, (ApiResult<()>,)>(
target_canister_id,
"notify_failed_station_upgrade",
(notify_failed_station_upgrade_input,),
)
.await
.map(|r| r.0);
// Log an error if the notification can't be made.
if let Err(e) = notify_res {
print(format!("notify_failed_station_upgrade failed: {:?}", e));
}
}
}
Expand All @@ -158,34 +139,30 @@ impl<T: Upgrade> Upgrade for WithBackground<T> {
}
}

/// `Upgrade` wrapper that only runs the wrapped upgrade in field `0` when the
/// caller is the target canister.
pub struct WithAuthorization<T>(pub T, pub LocalRef<StableValue<StorablePrincipal>>);
pub struct WithAuthorization<T>(pub T);

#[async_trait]
impl<T: Upgrade> Upgrade for WithAuthorization<T> {
/// Returns `UpgradeError::Unauthorized` when `ic_cdk::caller()` differs from
/// the target canister id; otherwise delegates to the wrapped upgrade.
async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> {
let id = self
.1
.with(|id| id.borrow().get(&()).context("canister id not set"))?;
let id = get_target_canister();

// Reject any caller other than the target canister itself.
if !ic_cdk::caller().eq(&id.0) {
if !ic_cdk::caller().eq(&id) {
return Err(UpgradeError::Unauthorized);
}

self.0.upgrade(ps).await
}
}

pub struct CheckController<T>(pub T, pub LocalRef<StableValue<StorablePrincipal>>);
pub struct CheckController<T>(pub T);

#[async_trait]
impl<T: Upgrade> Upgrade for CheckController<T> {
async fn upgrade(&self, ps: UpgradeParams) -> Result<(), UpgradeError> {
let id = self
.1
.with(|id| id.borrow().get(&()).context("canister id not set"))?;
let id = get_target_canister();

let (resp,) = mgmt::canister_info(CanisterInfoRequest {
canister_id: id.0,
canister_id: id,
num_requested_changes: None,
})
.await
Expand Down
Loading