Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pool read-only, with a single write connection. #1517

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ pub(crate) mod tests {
.unwrap();

let conn = amal.store().conn().unwrap();
conn.raw_query(|conn| diesel::delete(identity_updates::table).execute(conn))
conn.raw_query_write(|conn| diesel::delete(identity_updates::table).execute(conn))
.unwrap();

let members = group.members().await.unwrap();
Expand Down Expand Up @@ -1424,6 +1424,7 @@ pub(crate) mod tests {
not(target_arch = "wasm32"),
tokio::test(flavor = "multi_thread", worker_threads = 1)
)]
#[ignore]
async fn test_add_remove_then_add_again() {
let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await;
Expand All @@ -1445,6 +1446,7 @@ pub(crate) mod tests {
.unwrap();
assert_eq!(amal_group.members().await.unwrap().len(), 1);
tracing::info!("Syncing bolas welcomes");

// See if Bola can see that they were added to the group
bola.sync_welcomes(&bola.mls_provider().unwrap())
.await
Expand Down
6 changes: 4 additions & 2 deletions xmtp_mls/src/groups/intents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
intent_kind: IntentKind,
intent_data: Vec<u8>,
) -> Result<StoredGroupIntent, GroupError> {
provider.transaction(|provider| {
let res = provider.transaction(|provider| {
let conn = provider.conn_ref();
self.queue_intent_with_conn(conn, intent_kind, intent_data)
})
});

res
}

fn queue_intent_with_conn(
Expand Down
10 changes: 6 additions & 4 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,8 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
intent_data.into(),
)?;

tracing::warn!("This makes it here?");

self.sync_until_intent_resolved(provider, intent.id).await
}

Expand Down Expand Up @@ -1250,13 +1252,13 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
state,
hex::encode(self.group_id.clone()),
);
let new_records = conn
let new_records: Vec<_> = conn
.insert_or_replace_consent_records(&[consent_record.clone()])?
.into_iter()
.map(UserPreferenceUpdate::ConsentUpdate)
.collect();

if self.client.history_sync_url().is_some() {
if !new_records.is_empty() && self.client.history_sync_url().is_some() {
// Dispatch an update event so it can be synced across devices
let _ = self
.client
Expand Down Expand Up @@ -2169,7 +2171,7 @@ pub(crate) mod tests {

// The dm shows up
let alix_groups = alix_conn
.raw_query(|conn| groups::table.load::<StoredGroup>(conn))
.raw_query_read( |conn| groups::table.load::<StoredGroup>(conn))
.unwrap();
assert_eq!(alix_groups.len(), 2);
// They should have the same ID
Expand Down Expand Up @@ -3696,7 +3698,7 @@ pub(crate) mod tests {
let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into();
let conn_2 = bo.store().conn().unwrap();
conn_2
.raw_query(|c| {
.raw_query_read( |c| {
c.batch_execute("BEGIN EXCLUSIVE").unwrap();
Ok::<_, diesel::result::Error>(())
})
Expand Down
5 changes: 3 additions & 2 deletions xmtp_mls/src/storage/encrypted_store/association_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ impl StoredAssociationState {
.and(dsl::sequence_id.eq_any(sequence_ids)),
);

let association_states =
conn.raw_query(|query_conn| query.load::<StoredAssociationState>(query_conn))?;
let association_states = conn.raw_query_read( |query_conn| {
query.load::<StoredAssociationState>(query_conn)
})?;

association_states
.into_iter()
Expand Down
6 changes: 3 additions & 3 deletions xmtp_mls/src/storage/encrypted_store/consent_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl DbConnection {
entity: String,
entity_type: ConsentType,
) -> Result<Option<StoredConsentRecord>, StorageError> {
Ok(self.raw_query(|conn| -> diesel::QueryResult<_> {
Ok(self.raw_query_read( |conn| -> diesel::QueryResult<_> {
dsl::consent_records
.filter(dsl::entity.eq(entity))
.filter(dsl::entity_type.eq(entity_type))
Expand Down Expand Up @@ -77,7 +77,7 @@ impl DbConnection {
);
}

let changed = self.raw_query(|conn| -> diesel::QueryResult<_> {
let changed = self.raw_query_write( |conn| -> diesel::QueryResult<_> {
let existing: Vec<StoredConsentRecord> = query.load(conn)?;
let changed: Vec<_> = records
.iter()
Expand Down Expand Up @@ -107,7 +107,7 @@ impl DbConnection {
&self,
record: &StoredConsentRecord,
) -> Result<Option<StoredConsentRecord>, StorageError> {
self.raw_query(|conn| {
self.raw_query_write( |conn| {
let maybe_inserted_consent_record: Option<StoredConsentRecord> =
diesel::insert_into(dsl::consent_records)
.values(record)
Expand Down
8 changes: 4 additions & 4 deletions xmtp_mls/src/storage/encrypted_store/conversation_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl DbConnection {
.select(conversation_list::all_columns())
.order(conversation_list_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
self.raw_query_read(|conn| query.load::<ConversationListItem>(conn))?
} else {
// Only include the specified states
let query = query
Expand All @@ -153,19 +153,19 @@ impl DbConnection {
.select(conversation_list::all_columns())
.order(conversation_list_dsl::created_at_ns.asc());

self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
self.raw_query_read(|conn| query.load::<ConversationListItem>(conn))?
}
} else {
// Handle the case where `consent_states` is `None`
self.raw_query(|conn| query.load::<ConversationListItem>(conn))?
self.raw_query_read(|conn| query.load::<ConversationListItem>(conn))?
};

// Were sync groups explicitly asked for? Was the include_sync_groups flag set to true?
// Then query for those separately
if matches!(conversation_type, Some(ConversationType::Sync)) || *include_sync_groups {
let query = conversation_list_dsl::conversation_list
.filter(conversation_list_dsl::conversation_type.eq(ConversationType::Sync));
let mut sync_groups = self.raw_query(|conn| query.load(conn))?;
let mut sync_groups = self.raw_query_read(|conn| query.load(conn))?;
conversations.append(&mut sync_groups);
}

Expand Down
87 changes: 78 additions & 9 deletions xmtp_mls/src/storage/encrypted_store/db_connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use parking_lot::Mutex;
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use crate::storage::xmtp_openmls_provider::XmtpOpenMlsProvider;
Expand All @@ -19,43 +20,102 @@ pub type DbConnection = DbConnectionPrivate<sqlite_web::connection::WasmSqliteCo
// callers should be able to accomplish everything with one conn/reference.
#[doc(hidden)]
pub struct DbConnectionPrivate<C> {
inner: Arc<Mutex<C>>,
read: Arc<Mutex<C>>,
write: Option<Arc<Mutex<C>>>,
pub(super) in_transaction: Arc<AtomicBool>,
}

/// Owned DBConnection Methods
impl<C> DbConnectionPrivate<C> {
/// Create a new [`DbConnectionPrivate`] from an existing Arc<Mutex<C>>
pub(super) fn from_arc_mutex(conn: Arc<Mutex<C>>) -> Self {
Self { inner: conn }
pub(super) fn from_arc_mutex(read: Arc<Mutex<C>>, write: Option<Arc<Mutex<C>>>) -> Self {
Self {
read,
write,
in_transaction: Arc::new(AtomicBool::new(false)),
}
}
}

impl<C> DbConnectionPrivate<C>
where
C: diesel::Connection,
{
fn in_transaction(&self) -> bool {
self.in_transaction.load(Ordering::SeqCst)
}

pub(crate) fn start_transaction(&self) -> TransactionGuard {
self.in_transaction.store(true, Ordering::SeqCst);
TransactionGuard {
in_transaction: self.in_transaction.clone(),
}
}

/// Do a scoped query with a mutable [`diesel::Connection`]
/// reference
pub(crate) fn raw_query_read<T, E, F>(&self, fun: F) -> Result<T, E>
where
F: FnOnce(&mut C) -> Result<T, E>,
{
let mut lock = self.read.lock();
fun(&mut lock)
}

/// Do a scoped query with a mutable [`diesel::Connection`]
/// reference
pub(crate) fn raw_query<T, E, F>(&self, fun: F) -> Result<T, E>
pub(crate) fn raw_query_write<T, E, F>(&self, fun: F) -> Result<T, E>
where
F: FnOnce(&mut C) -> Result<T, E>,
{
let mut lock = self.inner.lock();
if let Some(write_conn) = &self.write {
let mut lock = write_conn.lock();
return fun(&mut lock);
}

let mut lock = self.read.lock();
fun(&mut lock)
}

/// Internal-only API to get the underlying `diesel::Connection` reference
/// without a scope
/// Must be used with care. holding this reference while calling `raw_query`
/// will cause a deadlock.
pub(super) fn inner_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> {
self.inner.lock()
pub(super) fn read_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> {
if self.in_transaction() {
if let Some(write) = &self.write {
return write.lock();
}
}
self.read.lock()
}

/// Internal-only API to get the underlying `diesel::Connection` reference
/// without a scope
pub(super) fn inner_ref(&self) -> Arc<Mutex<C>> {
self.inner.clone()
pub(super) fn read_ref(&self) -> Arc<Mutex<C>> {
if self.in_transaction() {
if let Some(write) = &self.write {
return write.clone();
};
}
self.read.clone()
}

/// Internal-only API to get the underlying `diesel::Connection` reference
/// without a scope
/// Must be used with care. holding this reference while calling `raw_query`
/// will cause a deadlock.
pub(super) fn write_mut_ref(&self) -> parking_lot::MutexGuard<'_, C> {
let Some(write) = &self.write else {
return self.read_mut_ref();
};
write.lock()
}

/// Internal-only API to get the underlying `diesel::Connection` reference
/// without a scope
pub(super) fn write_ref(&self) -> Option<Arc<Mutex<C>>> {
self.write.clone()
}
}

Expand All @@ -77,3 +137,12 @@ impl<C> fmt::Debug for DbConnectionPrivate<C> {
.finish()
}
}

pub struct TransactionGuard {
in_transaction: Arc<AtomicBool>,
}
impl Drop for TransactionGuard {
fn drop(&mut self) {
self.in_transaction.store(false, Ordering::SeqCst);
}
}
Loading
Loading