Skip to content

Commit

Permalink
fix(oauth): arc_swap
Browse files Browse the repository at this point in the history
  • Loading branch information
sigaloid committed Jun 28, 2024
1 parent 4dc7ff8 commit 3b2ad21
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ fastrand = "2.0.1"
log = "0.4.20"
pretty_env_logger = "0.5.0"
dotenvy = "0.15.7"
arc-swap = "1.7.1"

[dev-dependencies]
lipsum = "0.9.0"
Expand Down
8 changes: 4 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use arc_swap::ArcSwap;
use cached::proc_macro::cached;
use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt};
Expand All @@ -13,7 +14,6 @@ use serde_json::Value;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicBool, AtomicU16};
use std::{io, result::Result};
use tokio::sync::RwLock;

use crate::dbg_msg;
use crate::oauth::{force_refresh_token, token_daemon, Oauth};
Expand All @@ -32,10 +32,10 @@ pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
client::Client::builder().build(https)
});

pub static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
pub static OAUTH_CLIENT: Lazy<ArcSwap<Oauth>> = Lazy::new(|| {
let client = block_on(Oauth::new());
tokio::spawn(token_daemon());
RwLock::new(client)
ArcSwap::new(client.into())
});

pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
Expand Down Expand Up @@ -177,7 +177,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
let client: Client<_, Body> = CLIENT.clone();

let (token, vendor_id, device_id, user_agent, loid) = {
let client = block_on(OAUTH_CLIENT.read());
let client = OAUTH_CLIENT.load_full();
(
client.token.clone(),
client.headers_map.get("Client-Vendor-Id").cloned().unwrap_or_default(),
Expand Down
23 changes: 8 additions & 15 deletions src/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,13 @@ impl Oauth {

Some(())
}

async fn refresh(&mut self) -> Option<()> {
// Refresh is actually just a subsequent login with the same headers (without the old token
// or anything). This logic is handled in login, so we just call login again.
let refresh = self.login().await;
info!("Refreshing OAuth token... {}", if refresh.is_some() { "success" } else { "failed" });
refresh
}
}

pub async fn token_daemon() {
// Monitor for refreshing token
loop {
// Get expiry time - be sure to not hold the read lock
let expires_in = { OAUTH_CLIENT.read().await.expires_in };
let expires_in = { OAUTH_CLIENT.load_full().expires_in };

// sleep for the expiry time minus 2 minutes
let duration = Duration::from_secs(expires_in - 120);
Expand All @@ -125,7 +117,7 @@ pub async fn token_daemon() {

// Refresh token - in its own scope
{
OAUTH_CLIENT.write().await.refresh().await;
force_refresh_token().await;
}
}
}
Expand All @@ -137,7 +129,8 @@ pub async fn force_refresh_token() {
}

trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst));
OAUTH_CLIENT.write().await.refresh().await;
let new_client = Oauth::new().await;
OAUTH_CLIENT.swap(new_client.into());
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
OAUTH_IS_ROLLING_OVER.store(false, Ordering::SeqCst);
}
Expand Down Expand Up @@ -187,21 +180,21 @@ fn choose<T: Copy>(list: &[T]) -> T {

#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_client() {
assert!(!OAUTH_CLIENT.read().await.token.is_empty());
assert!(!OAUTH_CLIENT.load_full().token.is_empty());
}

#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_client_refresh() {
OAUTH_CLIENT.write().await.refresh().await.unwrap();
force_refresh_token().await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_token_exists() {
assert!(!OAUTH_CLIENT.read().await.token.is_empty());
assert!(!OAUTH_CLIENT.load_full().token.is_empty());
}

#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_headers_len() {
assert!(OAUTH_CLIENT.read().await.headers_map.len() >= 3);
assert!(OAUTH_CLIENT.load_full().headers_map.len() >= 3);
}

#[test]
Expand Down

0 comments on commit 3b2ad21

Please sign in to comment.