diff --git a/migrations/20240923163452_charges-fix.sql b/migrations/20240923163452_charges-fix.sql new file mode 100644 index 00000000..378bd488 --- /dev/null +++ b/migrations/20240923163452_charges-fix.sql @@ -0,0 +1,12 @@ +CREATE TABLE charges ( + id bigint PRIMARY KEY, + user_id bigint REFERENCES users NOT NULL, + price_id bigint REFERENCES products_prices NOT NULL, + amount bigint NOT NULL, + currency_code text NOT NULL, + subscription_id bigint REFERENCES users_subscriptions NULL, + interval text NULL, + status varchar(255) NOT NULL, + due timestamptz DEFAULT CURRENT_TIMESTAMP NOT NULL, + last_attempt timestamptz NOT NULL +); \ No newline at end of file diff --git a/src/database/models/charge_item.rs b/src/database/models/charge_item.rs new file mode 100644 index 00000000..af535a7a --- /dev/null +++ b/src/database/models/charge_item.rs @@ -0,0 +1,131 @@ +use crate::database::models::{ + ChargeId, DatabaseError, ProductPriceId, UserId, UserSubscriptionId, +}; +use crate::models::billing::{ChargeStatus, PriceDuration}; +use chrono::{DateTime, Utc}; +use std::convert::TryFrom; + +pub struct ChargeItem { + pub id: ChargeId, + pub user_id: UserId, + pub price_id: ProductPriceId, + pub amount: i64, + pub currency_code: String, + pub subscription_id: Option, + pub interval: Option, + pub status: ChargeStatus, + pub due: DateTime, + pub last_attempt: Option>, +} + +struct ChargeResult { + id: i64, + user_id: i64, + price_id: i64, + amount: i64, + currency_code: String, + subscription_id: Option, + interval: Option, + status: String, + due: DateTime, + last_attempt: Option>, +} + +impl TryFrom for ChargeItem { + type Error = serde_json::Error; + + fn try_from(r: ChargeResult) -> Result { + Ok(ChargeItem { + id: ChargeId(r.id), + user_id: UserId(r.user_id), + price_id: ProductPriceId(r.price_id), + amount: r.amount, + currency_code: r.currency_code, + subscription_id: r.subscription_id.map(UserSubscriptionId), + interval: r.interval.map(|x| serde_json::from_str(&x)).transpose()?, + status: serde_json::from_str(&r.status)?, + due: r.due, + last_attempt: r.last_attempt, + }) + } +} + +macro_rules! select_charges_with_predicate { + ($predicate:tt, $param:ident) => { + sqlx::query_as!( + ChargeResult, + r#" + SELECT id, user_id, price_id, amount, currency_code, subscription_id, interval, status, due, last_attempt + FROM charges + "# + + $predicate, + $param + ) + }; +} + +impl ChargeItem { + pub async fn insert( + &self, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result { + sqlx::query!( + r#" + INSERT INTO charges (id, user_id, price_id, amount, currency_code, subscription_id, interval, status, due, last_attempt) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + "#, + self.id.0, + self.user_id.0, + self.price_id.0, + self.amount, + self.currency_code, + self.subscription_id.map(|x| x.0), + self.interval.map(|x| x.as_str()), + self.status.as_str(), + self.due, + self.last_attempt, + ) + .execute(exec) + .await?; + + Ok(self.id) + } + + pub async fn get( + id: ChargeId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let res = select_charges_with_predicate!("WHERE id = $1", id) + .fetch_optional(exec) + .await?; + + Ok(res.map(|r| r.try_into())) + } + + pub async fn get_from_user( + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let res = select_charges_with_predicate!("WHERE user_id = $1", user_id) + .fetch_all(exec) + .await?; + + Ok(res + .into_iter() + .map(|r| r.try_into()) + .collect::, serde_json::Error>>()?) + } + + pub async fn get_chargeable( + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let res = select_charges_with_predicate!("WHERE (status = 'open' AND due < NOW()) OR (status = 'failed' AND last_attempt < NOW() - INTERVAL '2 days')") + .fetch_all(exec) + .await?; + + Ok(res + .into_iter() + .map(|r| r.try_into()) + .collect::, serde_json::Error>>()?) + } +} diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index fd85a64c..be380924 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -256,6 +256,14 @@ generate_ids!( UserSubscriptionId ); +generate_ids!( + pub generate_charge_id, + ChargeId, + 8, + "SELECT EXISTS(SELECT 1 FROM charges WHERE id=$1)", + ChargeId +); + #[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)] #[sqlx(transparent)] pub struct UserId(pub i64); @@ -386,6 +394,10 @@ pub struct ProductPriceId(pub i64); #[sqlx(transparent)] pub struct UserSubscriptionId(pub i64); +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct ChargeId(pub i64); + use crate::models::ids; impl From for ProjectId { @@ -571,3 +583,14 @@ impl From for ids::UserSubscriptionId { ids::UserSubscriptionId(id.0 as u64) } } + +impl From for ChargeId { + fn from(id: ids::ChargeId) -> Self { + ChargeId(id.0 as i64) + } +} +impl From for ids::ChargeId { + fn from(id: ChargeId) -> Self { + ids::ChargeId(id.0 as u64) + } +} diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index 51dafed6..dabcfdda 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -1,6 +1,7 @@ use thiserror::Error; pub mod categories; +pub mod charge_item; pub mod collection_item; pub mod flow_item; pub mod ids; diff --git a/src/database/models/user_subscription_item.rs b/src/database/models/user_subscription_item.rs index b13d319f..fb5af314 100644 --- a/src/database/models/user_subscription_item.rs +++ b/src/database/models/user_subscription_item.rs @@ -89,17 +89,6 @@ impl UserSubscriptionItem { Ok(results.into_iter().map(|r| r.into()).collect()) } - pub async fn get_all_expired( - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result, DatabaseError> { - let now = Utc::now(); - let results = select_user_subscriptions_with_predicate!("WHERE expires < $1", now) - .fetch_all(exec) - .await?; - - Ok(results.into_iter().map(|r| r.into()).collect()) - } - pub async fn upsert( &self, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, diff --git a/src/models/v3/billing.rs b/src/models/v3/billing.rs index c78583da..c1e3d703 100644 --- a/src/models/v3/billing.rs +++ b/src/models/v3/billing.rs @@ -134,3 +134,53 @@ impl SubscriptionStatus { } } } + +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct ChargeId(pub u64); + +#[derive(Serialize, Deserialize)] +pub struct Charge { + pub id: ChargeId, + pub user_id: UserId, + pub price_id: ProductPriceId, + pub amount: i64, + pub currency_code: String, + pub subscription_id: Option, + pub interval: Option, + pub status: ChargeStatus, + pub due: DateTime, + pub last_attempt: Option>, +} + +#[derive(Serialize, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum ChargeStatus { + // Open charges are for the next billing interval + Open, + Processing, + Succeeded, + Failed, +} + +impl ChargeStatus { + pub fn from_string(string: &str) -> ChargeStatus { + match string { + "processing" => ChargeStatus::Processing, + "succeeded" => ChargeStatus::Succeeded, + "failed" => ChargeStatus::Failed, + "open" => ChargeStatus::Open, + _ => ChargeStatus::Failed, + } + } + + pub fn as_str(&self) -> &'static str { + match self { + ChargeStatus::Processing => "processing", + ChargeStatus::Succeeded => "succeeded", + ChargeStatus::Failed => "failed", + ChargeStatus::Open => "open", + } + } +} diff --git a/src/models/v3/ids.rs b/src/models/v3/ids.rs index 839d6587..3af87437 100644 --- a/src/models/v3/ids.rs +++ b/src/models/v3/ids.rs @@ -13,8 +13,7 @@ pub use super::teams::TeamId; pub use super::threads::ThreadId; pub use super::threads::ThreadMessageId; pub use super::users::UserId; -pub use crate::models::billing::UserSubscriptionId; -pub use crate::models::v3::billing::{ProductId, ProductPriceId}; +pub use crate::models::billing::{ChargeId, ProductId, ProductPriceId, UserSubscriptionId}; use thiserror::Error; /// Generates a random 64 bit integer that is exactly `n` characters @@ -137,6 +136,7 @@ base62_id_impl!(PayoutId, PayoutId); base62_id_impl!(ProductId, ProductId); base62_id_impl!(ProductPriceId, ProductPriceId); base62_id_impl!(UserSubscriptionId, UserSubscriptionId); +base62_id_impl!(ChargeId, ChargeId); pub mod base62_impl { use serde::de::{self, Deserializer, Visitor}; diff --git a/src/routes/internal/billing.rs b/src/routes/internal/billing.rs index a63442dc..424d0a1f 100644 --- a/src/routes/internal/billing.rs +++ b/src/routes/internal/billing.rs @@ -1,11 +1,11 @@ use crate::auth::{get_user_from_headers, send_email}; use crate::database::models::{ - generate_user_subscription_id, product_item, user_subscription_item, + generate_charge_id, generate_user_subscription_id, product_item, user_subscription_item, }; use crate::database::redis::RedisPool; use crate::models::billing::{ - Price, PriceDuration, Product, ProductMetadata, ProductPrice, SubscriptionStatus, - UserSubscription, + Charge, ChargeStatus, Price, PriceDuration, Product, ProductMetadata, ProductPrice, + SubscriptionStatus, UserSubscription, }; use crate::models::ids::base62_impl::{parse_base62, to_base62}; use crate::models::pats::Scopes; @@ -140,6 +140,8 @@ pub async fn cancel_subscription( ) .execute(&mut *transaction) .await?; + + // TODO: delete open charges for this subscription } else { subscription.status = SubscriptionStatus::Cancelled; subscription.upsert(&mut transaction).await?; @@ -191,7 +193,6 @@ pub async fn charges( pool: web::Data, redis: web::Data, session_queue: web::Data, - stripe_client: web::Data, ) -> Result { let user = get_user_from_headers( &req, @@ -203,25 +204,27 @@ pub async fn charges( .await? .1; - if let Some(customer_id) = user - .stripe_customer_id - .as_ref() - .and_then(|x| stripe::CustomerId::from_str(x).ok()) - { - let charges = stripe::Charge::list( - &stripe_client, - &ListCharges { - customer: Some(customer_id), - limit: Some(100), - ..Default::default() - }, - ) - .await?; + let charges = + crate::database::models::charge_item::ChargeItem::get_from_user(user.id.into(), &**pool) + .await?; - Ok(HttpResponse::Ok().json(charges.data)) - } else { - Ok(HttpResponse::NoContent().finish()) - } + Ok(HttpResponse::Ok().json( + charges + .into_iter() + .map(|x| Charge { + id: x.id.into(), + user_id: x.user_id.into(), + price_id: x.price_id.into(), + amount: x.amount, + currency_code: x.currency_code, + subscription_id: x.subscription_id.map(|x| x.into()), + interval: x.interval, + status: x.status, + due: x.due, + last_attempt: x.last_attempt, + }) + .collect::>(), + )) } #[post("payment_method")] @@ -466,15 +469,85 @@ pub enum PaymentRequestType { ConfirmationToken { token: String }, } +#[derive(Deserialize)] +pub enum ChargeRequestType { + Existing { + id: crate::models::ids::ChargeId, + }, + New { + product_id: crate::models::ids::ProductId, + interval: Option, + }, +} + #[derive(Deserialize)] pub struct PaymentRequest { - pub product_id: crate::models::ids::ProductId, - pub interval: Option, #[serde(flatten)] pub type_: PaymentRequestType, + pub charge: ChargeRequestType, pub existing_payment_intent: Option, } +fn infer_currency_code(country: &str) -> String { + match country { + "US" => "USD", + "GB" => "GBP", + "EU" => "EUR", + "AT" => "EUR", + "BE" => "EUR", + "CY" => "EUR", + "EE" => "EUR", + "FI" => "EUR", + "FR" => "EUR", + "DE" => "EUR", + "GR" => "EUR", + "IE" => "EUR", + "IT" => "EUR", + "LV" => "EUR", + "LT" => "EUR", + "LU" => "EUR", + "MT" => "EUR", + "NL" => "EUR", + "PT" => "EUR", + "SK" => "EUR", + "SI" => "EUR", + "RU" => "RUB", + "BR" => "BRL", + "JP" => "JPY", + "ID" => "IDR", + "MY" => "MYR", + "PH" => "PHP", + "TH" => "THB", + "VN" => "VND", + "KR" => "KRW", + "TR" => "TRY", + "UA" => "UAH", + "MX" => "MXN", + "CA" => "CAD", + "NZ" => "NZD", + "NO" => "NOK", + "PL" => "PLN", + "CH" => "CHF", + "LI" => "CHF", + "IN" => "INR", + "CL" => "CLP", + "PE" => "PEN", + "CO" => "COP", + "ZA" => "ZAR", + "HK" => "HKD", + "AR" => "ARS", + "KZ" => "KZT", + "UY" => "UYU", + "CN" => "CNY", + "AU" => "AUD", + "TW" => "TWD", + "SA" => "SAR", + "QA" => "QAR", + _ => "USD", + } + .to_string() +} + #[post("payment")] pub async fn initiate_payment( req: HttpRequest, @@ -494,12 +567,6 @@ pub async fn initiate_payment( .await? .1; - let product = product_item::ProductItem::get(payment_request.product_id.into(), &**pool) - .await? - .ok_or_else(|| { - ApiError::InvalidInput("Specified product could not be found!".to_string()) - })?; - let (user_country, payment_method) = match &payment_request.type_ { PaymentRequestType::PaymentMethod { id } => { let payment_method_id = stripe::PaymentMethodId::from_str(id) @@ -551,93 +618,112 @@ pub async fn initiate_payment( }; let country = user_country.as_deref().unwrap_or("US"); - let recommended_currency_code = match country { - "US" => "USD", - "GB" => "GBP", - "EU" => "EUR", - "AT" => "EUR", - "BE" => "EUR", - "CY" => "EUR", - "EE" => "EUR", - "FI" => "EUR", - "FR" => "EUR", - "DE" => "EUR", - "GR" => "EUR", - "IE" => "EUR", - "IT" => "EUR", - "LV" => "EUR", - "LT" => "EUR", - "LU" => "EUR", - "MT" => "EUR", - "NL" => "EUR", - "PT" => "EUR", - "SK" => "EUR", - "SI" => "EUR", - "RU" => "RUB", - "BR" => "BRL", - "JP" => "JPY", - "ID" => "IDR", - "MY" => "MYR", - "PH" => "PHP", - "TH" => "THB", - "VN" => "VND", - "KR" => "KRW", - "TR" => "TRY", - "UA" => "UAH", - "MX" => "MXN", - "CA" => "CAD", - "NZ" => "NZD", - "NO" => "NOK", - "PL" => "PLN", - "CH" => "CHF", - "LI" => "CHF", - "IN" => "INR", - "CL" => "CLP", - "PE" => "PEN", - "CO" => "COP", - "ZA" => "ZAR", - "HK" => "HKD", - "AR" => "ARS", - "KZ" => "KZT", - "UY" => "UYU", - "CN" => "CNY", - "AU" => "AUD", - "TW" => "TWD", - "SA" => "SAR", - "QA" => "QAR", - _ => "USD", - }; - - let mut product_prices = - product_item::ProductPriceItem::get_all_product_prices(product.id, &**pool).await?; - - let price_item = if let Some(pos) = product_prices - .iter() - .position(|x| x.currency_code == recommended_currency_code) - { - product_prices.remove(pos) - } else if let Some(pos) = product_prices.iter().position(|x| x.currency_code == "USD") { - product_prices.remove(pos) - } else { - return Err(ApiError::InvalidInput( - "Could not find a valid price for the user's country".to_string(), - )); - }; + let recommended_currency_code = infer_currency_code(country); + + let (price, currency_code, interval, price_id, charge_id) = match payment_request.charge { + ChargeRequestType::Existing { id } => { + let charge = crate::database::models::charge_item::ChargeItem::get(id.into(), &**pool) + .await? + .ok_or_else(|| { + ApiError::InvalidInput("Specified charge could not be found!".to_string()) + })?; + + ( + charge.amount, + charge.currency_code, + charge.interval, + charge.price_id, + Some(id), + ) + } + ChargeRequestType::New { + product_id, + interval, + } => { + let product = product_item::ProductItem::get(product_id.into(), &**pool) + .await? + .ok_or_else(|| { + ApiError::InvalidInput("Specified product could not be found!".to_string()) + })?; + + let mut product_prices = + product_item::ProductPriceItem::get_all_product_prices(product.id, &**pool).await?; + + let price_item = if let Some(pos) = product_prices + .iter() + .position(|x| x.currency_code == recommended_currency_code) + { + product_prices.remove(pos) + } else if let Some(pos) = product_prices.iter().position(|x| x.currency_code == "USD") { + product_prices.remove(pos) + } else { + return Err(ApiError::InvalidInput( + "Could not find a valid price for the user's country".to_string(), + )); + }; + + let price = match price_item.prices { + Price::OneTime { price } => price, + Price::Recurring { ref intervals } => { + let interval = interval.ok_or_else(|| { + ApiError::InvalidInput( + "Could not find a valid interval for the user's country".to_string(), + ) + })?; + + *intervals.get(&interval).ok_or_else(|| { + ApiError::InvalidInput( + "Could not find a valid price for the user's country".to_string(), + ) + })? + } + }; + + if let Price::Recurring { .. } = price_item.prices { + if product.unitary { + let user_subscriptions = + user_subscription_item::UserSubscriptionItem::get_all_user( + user.id.into(), + &**pool, + ) + .await?; + + let user_products = product_item::ProductPriceItem::get_many( + &user_subscriptions + .iter() + .map(|x| x.price_id) + .collect::>(), + &**pool, + ) + .await?; - let price = match price_item.prices { - Price::OneTime { price } => price, - Price::Recurring { ref intervals } => { - let interval = payment_request.interval.ok_or_else(|| { - ApiError::InvalidInput( - "Could not find a valid interval for the user's country".to_string(), - ) - })?; + if let Some(product) = user_products + .into_iter() + .find(|x| x.product_id == product.id) + { + if let Some(subscription) = user_subscriptions + .into_iter() + .find(|x| x.price_id == product.id) + { + return Err(ApiError::InvalidInput(if !(subscription.status == SubscriptionStatus::Cancelled + || subscription.status == SubscriptionStatus::PaymentFailed) + { + "You are already subscribed to this product!" + } else { + "You are already subscribed to this product, but the payment failed!" + }.to_string())); + } + } + } + } - *intervals.get(&interval).ok_or_else(|| { - ApiError::InvalidInput( - "Could not find a valid price for the user's country".to_string(), - ) - })? + ( + price as i64, + price_item.currency_code, + interval, + price_item.id, + None, + ) } }; @@ -650,31 +736,17 @@ pub async fn initiate_payment( &redis, ) .await?; - let stripe_currency = Currency::from_str(&price_item.currency_code.to_lowercase()) + let stripe_currency = Currency::from_str(¤cy_code.to_lowercase()) .map_err(|_| ApiError::InvalidInput("Invalid currency code".to_string()))?; if let Some(payment_intent_id) = &payment_request.existing_payment_intent { let mut update_payment_intent = stripe::UpdatePaymentIntent { - amount: Some(price as i64), + amount: Some(price), currency: Some(stripe_currency), customer: Some(customer), ..Default::default() }; - let mut metadata = HashMap::new(); - metadata.insert("modrinth_user_id".to_string(), to_base62(user.id.0)); - metadata.insert( - "modrinth_price_id".to_string(), - to_base62(price_item.id.0 as u64), - ); - if let Some(interval) = payment_request.interval { - metadata.insert( - "modrinth_subscription_interval".to_string(), - interval.as_str().to_string(), - ); - } - update_payment_intent.metadata = Some(metadata); - if let PaymentRequestType::PaymentMethod { .. } = payment_request.type_ { update_payment_intent.payment_method = Some(payment_method.id.clone()); } @@ -689,68 +761,20 @@ pub async fn initiate_payment( "payment_method": payment_method, }))) } else { - let mut intent = CreatePaymentIntent::new(price as i64, stripe_currency); + let mut intent = CreatePaymentIntent::new(price, stripe_currency); - let mut transaction = pool.begin().await?; let mut metadata = HashMap::new(); metadata.insert("modrinth_user_id".to_string(), to_base62(user.id.0)); - metadata.insert( - "modrinth_price_id".to_string(), - to_base62(price_item.id.0 as u64), - ); - - if let Price::Recurring { .. } = price_item.prices { - if product.unitary { - let user_subscriptions = - user_subscription_item::UserSubscriptionItem::get_all_user( - user.id.into(), - &**pool, - ) - .await?; - - let user_products = product_item::ProductPriceItem::get_many( - &user_subscriptions - .iter() - .map(|x| x.price_id) - .collect::>(), - &**pool, - ) - .await?; - - if let Some(product) = user_products - .into_iter() - .find(|x| x.product_id == product.id) - { - if let Some(subscription) = user_subscriptions - .into_iter() - .find(|x| x.price_id == product.id) - { - if subscription.status == SubscriptionStatus::Cancelled - || subscription.status == SubscriptionStatus::PaymentFailed - { - metadata.insert( - "modrinth_subscription_id".to_string(), - to_base62(subscription.id.0 as u64), - ); - } else { - return Err(ApiError::InvalidInput( - "You are already subscribed to this product!".to_string(), - )); - } - } - } - } - - if !metadata.contains_key("modrinth_subscription_id") { - let user_subscription_id = generate_user_subscription_id(&mut transaction).await?; - metadata.insert( - "modrinth_subscription_id".to_string(), - to_base62(user_subscription_id.0 as u64), - ); - } + if let Some(charge_id) = charge_id { + metadata.insert("modrinth_charge_id".to_string(), to_base62(charge_id.0)); + } else { + metadata.insert( + "modrinth_price_id".to_string(), + to_base62(price_id.0 as u64), + ); - if let Some(interval) = payment_request.interval { + if let Some(interval) = interval { metadata.insert( "modrinth_subscription_interval".to_string(), interval.as_str().to_string(), @@ -766,14 +790,12 @@ pub async fn initiate_payment( }); intent.receipt_email = user.email.as_deref(); intent.setup_future_usage = Some(PaymentIntentSetupFutureUsage::OffSession); - intent.payment_method_types = Some(vec!["card".to_string(), "cashapp".to_string()]); if let PaymentRequestType::PaymentMethod { .. } = payment_request.type_ { intent.payment_method = Some(payment_method.id.clone()); } let payment_intent = stripe::PaymentIntent::create(&stripe_client, intent).await?; - transaction.commit().await?; Ok(HttpResponse::Ok().json(serde_json::json!({ "payment_intent_id": payment_intent.id, @@ -814,6 +836,7 @@ pub async fn stripe_webhook( user_subscription: Option, product: product_item::ProductItem, product_price: product_item::ProductPriceItem, + charge_item: Option, } async fn get_payment_intent_metadata( @@ -1115,16 +1138,29 @@ async fn get_or_create_customer( } pub async fn task(stripe_client: stripe::Client, pool: PgPool, redis: RedisPool) { + // Check for open charges which are open AND last charge hasn't already been attempted + // CHeck for open charges which are failed ANd last attempt > 2 days ago (and unprovision) + // if subscription is cancelled and expired, unprovision and remove // if subscription is payment failed and last attempt is > 2 days ago, try again to charge and unprovision // if subscription is active and expired, attempt to charge and set as processing loop { info!("Indexing billing queue"); let res = async { - let expired = - user_subscription_item::UserSubscriptionItem::get_all_expired(&pool).await?; + let charges_to_do = crate::database::models::charge_item::ChargeItem::get_chargeable(&pool).await?; + + let subscription_items = user_subscription_item::UserSubscriptionItem::get_many( + &charges_to_do + .iter() + .flat_map(|x| x.subscription_id) + .collect::>() + .into_iter() + .collect::>(), + &pool, + ) + .await?; let subscription_prices = product_item::ProductPriceItem::get_many( - &expired + &charges_to_do .iter() .map(|x| x.price_id) .collect::>() @@ -1144,7 +1180,7 @@ pub async fn task(stripe_client: stripe::Client, pool: PgPool, redis: RedisPool) ) .await?; let users = crate::database::models::User::get_many_ids( - &expired + &charges_to_do .iter() .map(|x| x.user_id) .collect::>() @@ -1158,13 +1194,13 @@ pub async fn task(stripe_client: stripe::Client, pool: PgPool, redis: RedisPool) let mut transaction = pool.begin().await?; let mut clear_cache_users = Vec::new(); - for mut subscription in expired { - let user = users.iter().find(|x| x.id == subscription.user_id); + for mut charge in charges_to_do { + let user = users.iter().find(|x| x.id == charge.user_id); if let Some(user) = user { let product_price = subscription_prices .iter() - .find(|x| x.id == subscription.price_id); + .find(|x| x.id == charge.price_id); if let Some(product_price) = product_price { let product = subscription_products @@ -1238,11 +1274,21 @@ pub async fn task(stripe_client: stripe::Client, pool: PgPool, redis: RedisPool) ) .await?; - let mut intent = CreatePaymentIntent::new( - *price as i64, - Currency::from_str(&product_price.currency_code) - .unwrap_or(Currency::USD), - ); + let currency = match Currency::from_str( + &product_price.currency_code.to_lowercase(), + ) { + Ok(x) => x, + Err(_) => { + warn!( + "Could not find currency for {}", + product_price.currency_code + ); + continue; + } + }; + + let mut intent = + CreatePaymentIntent::new(*price as i64, currency); let mut metadata = HashMap::new(); metadata.insert(