Skip to content

Commit

Permalink
async
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Jan 17, 2025
1 parent a79f083 commit f5ff61c
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 74 deletions.
6 changes: 3 additions & 3 deletions xmtp_mls/src/groups/device_sync/backup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ impl From<BackupOptions> for BackupMetadata {
}

impl BackupOptions {
pub fn export_to_file(
pub async fn export_to_file(
self,
provider: XmtpOpenMlsProvider,
path: impl AsRef<Path>,
) -> Result<(), DeviceSyncError> {
let provider = Arc::new(provider);
let mut exporter = BackupExporter::new(self, &provider);
exporter.write_to_file(path)?;
exporter.write_to_file(path).await?;

Ok(())
}
Expand Down Expand Up @@ -89,7 +89,7 @@ mod tests {
let mut exporter = BackupExporter::new(opts, &alix_provider);
let path = Path::new("archive.zstd");
let _ = std::fs::remove_file(path);
exporter.write_to_file(path).unwrap();
exporter.write_to_file(path).await.unwrap();

let alix2_wallet = generate_local_wallet();
let alix2 = ClientBuilder::new_test_client(&alix2_wallet).await;
Expand Down
108 changes: 66 additions & 42 deletions xmtp_mls/src/groups/device_sync/backup/backup_exporter.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use super::{export_stream::BatchExportStream, BackupOptions};
use crate::{groups::device_sync::DeviceSyncError, XmtpOpenMlsProvider};
use async_compression::tokio::write::ZstdEncoder;
use futures::{pin_mut, StreamExt};
use prost::Message;
use std::{
io::{Read, Write},
path::Path,
sync::Arc,
use std::{future::Future, path::Path, pin::Pin, sync::Arc, task::Poll};
use tokio::{
fs::File,
io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
};
use xmtp_proto::xmtp::device_sync::{backup_element::Element, BackupElement, BackupMetadata};
use zstd::stream::Encoder;

pub(super) struct BackupExporter<'a> {
pub(super) struct BackupExporter {
stage: Stage,
metadata: BackupMetadata,
stream: BatchExportStream,
position: usize,
encoder: Encoder<'a, Vec<u8>>,
zstd_encoder: ZstdEncoder<Vec<u8>>,
encoder_finished: bool,
}

Expand All @@ -25,83 +26,106 @@ pub(super) enum Stage {
Elements,
}

impl<'a> BackupExporter<'a> {
impl BackupExporter {
pub(super) fn new(opts: BackupOptions, provider: &Arc<XmtpOpenMlsProvider>) -> Self {
Self {
position: 0,
stage: Stage::default(),
stream: BatchExportStream::new(&opts, provider),
metadata: opts.into(),
encoder: Encoder::new(Vec::new(), 0).unwrap(),
zstd_encoder: ZstdEncoder::new(Vec::new()),
encoder_finished: false,
}
}

pub fn write_to_file(&mut self, path: impl AsRef<Path>) -> Result<(), DeviceSyncError> {
let mut file = std::fs::File::create(path.as_ref())?;
pub async fn write_to_file(&mut self, path: impl AsRef<Path>) -> Result<(), DeviceSyncError> {
let mut file = File::create(path.as_ref()).await?;
let mut buffer = [0u8; 1024];

let mut amount = self.read(&mut buffer)?;
let mut amount = self.read(&mut buffer).await?;
while amount != 0 {
file.write(&buffer[..amount])?;
amount = self.read(&mut buffer)?;
file.write(&buffer[..amount]).await?;
amount = self.read(&mut buffer).await?;
}

file.flush()?;
file.flush().await?;

Ok(())
}
}

impl<'a> Read for BackupExporter<'a> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
impl AsyncRead for BackupExporter {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.get_mut();

{
// Read from the buffer while there is data
let buffer_inner = self.encoder.get_ref();
if self.position < buffer_inner.len() {
let available = &buffer_inner[self.position..];
let amount = available.len().min(buf.len());
let buffer_inner = this.zstd_encoder.get_ref();
if this.position < buffer_inner.len() {
let available = &buffer_inner[this.position..];
let amount = available.len().min(buf.remaining());
buf.put_slice(&available[..amount]);
this.position += amount;

buf[..amount].clone_from_slice(&available[..amount]);
self.position += amount;
return Ok(amount);
return Poll::Ready(Ok(()));
}
}

// The buffer is consumed. Reset.
self.position = 0;
self.encoder.get_mut().clear();
this.position = 0;
this.zstd_encoder.get_mut().clear();

// Time to fill the buffer with more data 8kb at a time.
while self.encoder.get_ref().len() < 8_000 {
let bytes = match self.stage {
while this.zstd_encoder.get_ref().len() < 8_000 {
let mut element = match this.stage {
Stage::Metadata => {
self.stage = Stage::Elements;
this.stage = Stage::Elements;
BackupElement {
element: Some(Element::Metadata(self.metadata.clone())),
element: Some(Element::Metadata(this.metadata.clone())),
}
.encode_to_vec()
}
Stage::Elements => match self.stream.next() {
Some(element) => element.encode_to_vec(),
None => {
if !self.encoder_finished {
self.encoder_finished = true;
self.encoder.do_finish()?;
Stage::Elements => match this.stream.poll_next_unpin(cx) {
Poll::Ready(Some(element)) => element.encode_to_vec(),
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
if !this.encoder_finished {
this.encoder_finished = true;
let fut = this.zstd_encoder.shutdown();
pin_mut!(fut);
let _ = fut.poll(cx)?;
}
break;
}
},
};
self.encoder.write(&(bytes.len() as u32).to_le_bytes())?;
self.encoder.write(&bytes)?;

let mut bytes = (element.len() as u32).to_le_bytes().to_vec();
bytes.append(&mut element);
let fut = this.zstd_encoder.write(&bytes);
pin_mut!(fut);
match fut.poll(cx) {
Poll::Ready(Ok(_amt)) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
}

// Flush the encoder
if !this.encoder_finished {
let fut = this.zstd_encoder.flush();
pin_mut!(fut);
let _ = fut.poll(cx)?;
}
self.encoder.flush()?;

if self.encoder.get_ref().is_empty() {
Ok(0)
if this.zstd_encoder.get_ref().is_empty() {
Poll::Ready(Ok(()))
} else {
self.read(buf)
Pin::new(&mut *this).poll_read(cx, buf)
}
}
}
69 changes: 40 additions & 29 deletions xmtp_mls/src/groups/device_sync/backup/export_stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::BackupOptions;
use crate::XmtpOpenMlsProvider;
use std::{marker::PhantomData, sync::Arc};
use futures::{Stream, StreamExt};
use std::{marker::PhantomData, pin::Pin, sync::Arc, task::Poll};
use xmtp_proto::xmtp::device_sync::{
consent_backup::ConsentSave, group_backup::GroupSave, message_backup::GroupMessageSave,
BackupElement, BackupElementSelection,
Expand All @@ -10,10 +11,7 @@ pub(crate) mod consent_save;
pub(crate) mod group_save;
pub(crate) mod message_save;

pub(super) trait ExportStream {
fn next(&mut self) -> Option<Vec<BackupElement>>;
}
type BackupInputStream = Box<dyn ExportStream>;
type BackupInputStream = Pin<Box<dyn Stream<Item = Vec<BackupElement>>>>;

/// A stream that curates a collection of streams for backup.
pub(super) struct BatchExportStream {
Expand Down Expand Up @@ -45,26 +43,38 @@ impl BatchExportStream {
}
}

impl BatchExportStream {
pub(super) fn next(&mut self) -> Option<BackupElement> {
if let Some(element) = self.buffer.pop() {
return Some(element);
impl Stream for BatchExportStream {
type Item = BackupElement;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

if let Some(element) = this.buffer.pop() {
return Poll::Ready(Some(element));
}

loop {
let Some(last) = self.input_streams.last_mut() else {
let Some(last) = this.input_streams.last_mut() else {
// No streams left, we're done.
return None;
return Poll::Ready(None);
};

if let Some(buffer) = last.next() {
self.buffer = buffer;
if let Some(element) = self.buffer.pop() {
return Some(element);
match last.poll_next_unpin(cx) {
Poll::Ready(Some(buffer)) => {
this.buffer = buffer;
if let Some(element) = this.buffer.pop() {
return Poll::Ready(Some(element));
}
}
Poll::Ready(None) => {
// It's ended - pop the stream off below and continue
}
Poll::Pending => return Poll::Pending,
}

self.input_streams.pop();
this.input_streams.pop();
}
}
}
Expand All @@ -87,7 +97,7 @@ pub(crate) struct BackupRecordStreamer<R> {

impl<R> BackupRecordStreamer<R>
where
R: BackupRecordProvider + 'static,
R: BackupRecordProvider + Unpin + 'static,
{
pub(super) fn new(
provider: &Arc<XmtpOpenMlsProvider>,
Expand All @@ -101,26 +111,27 @@ where
_phantom: PhantomData,
};

Box::new(stream)
Box::pin(stream)
}
}

impl<R> ExportStream for BackupRecordStreamer<R>
impl<R> Stream for BackupRecordStreamer<R>
where
R: BackupRecordProvider,
R: BackupRecordProvider + Unpin,
{
fn next(&mut self) -> Option<Vec<BackupElement>> {
let batch = R::backup_records(self);
type Item = Vec<BackupElement>;
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let batch = R::backup_records(this);

// If no records found, we've reached the end of the stream
if batch.is_empty() {
return None;
return Poll::Ready(None);
}

// Update offset for next batch
self.offset += R::BATCH_SIZE;

// Return the current batch
Some(batch)
this.offset += R::BATCH_SIZE;
Poll::Ready(Some(batch))
}
}

0 comments on commit f5ff61c

Please sign in to comment.