Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/hop by hop async #3

Merged
merged 6 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 144 additions & 28 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,175 @@
use crate::diameter::DiameterMessage;
use crate::error::Error;
use std::collections::HashMap;
use std::io::Cursor;
use std::sync::{Arc, Mutex};
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::sync::oneshot::Receiver;
use tokio::sync::oneshot::Sender;

pub struct DiameterClient {
stream: Option<TcpStream>,
writer: Option<Arc<Mutex<OwnedWriteHalf>>>,
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
}

impl DiameterClient {
pub fn new() -> DiameterClient {
DiameterClient { stream: None }
DiameterClient {
writer: None,
msg_caches: Arc::new(Mutex::new(HashMap::new())),
}
}

pub async fn connect(&mut self, addr: &str) -> Result<(), Error> {
let stream = TcpStream::connect(addr).await?;
self.stream = Some(stream);

let (mut reader, writer) = stream.into_split();
let writer = Arc::new(Mutex::new(writer));
self.writer = Some(writer);

let msg_caches = Arc::clone(&self.msg_caches);
tokio::spawn(async move {
loop {
match Self::read(&mut reader).await {
Ok(res) => {
if let Err(e) = Self::process_response(msg_caches.clone(), res).await {
log::error!("Failed to process response; error: {:?}", e);
return;
}
}
Err(e) => {
log::error!("Failed to read message from socket; error: {:?}", e);
return;
}
}
}
});

Ok(())
}

pub async fn send(&mut self, req: DiameterMessage) -> Result<DiameterMessage, Error> {
if let Some(stream) = self.stream.as_mut() {
// Encode Request
let mut encoded = Vec::new();
req.encode_to(&mut encoded)?;
async fn process_response(
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
res: DiameterMessage,
) -> Result<(), Error> {
let hop_by_hop = res.get_hop_by_hop_id();

// Send Request
stream.write_all(&encoded).await?;
let sender_opt = {
let mut msg_caches = msg_caches.lock()?;

// Read first 4 bytes to determine the length
let mut b = [0; 4];
stream.read_exact(&mut b).await?;
let length = u32::from_be_bytes([0, b[1], b[2], b[3]]);

// Limit to 1MB
if length as usize > 1024 * 1024 {
return Err(Error::ClientError("Message too large ".into()));
msg_caches.remove(&hop_by_hop)
};
match sender_opt {
Some(sender) => {
sender.send(res).map_err(|e| {
Error::ClientError(format!("Failed to send response; error: {:?}", e))
})?;
}
None => {
Err(Error::ClientError(format!(
"No request found for hop_by_hop_id {}",
hop_by_hop
)))?;
}
};
Ok(())
}

// Read the rest of the message
let mut buffer = Vec::with_capacity(length as usize);
buffer.extend_from_slice(&b);
buffer.resize(length as usize, 0);
stream.read_exact(&mut buffer[4..]).await?;
async fn read(reader: &mut OwnedReadHalf) -> Result<DiameterMessage, Error> {
let mut b = [0; 4];
reader.read_exact(&mut b).await?;
let length = u32::from_be_bytes([0, b[1], b[2], b[3]]);

// Decode Response
let mut cursor = Cursor::new(buffer);
let res = DiameterMessage::decode_from(&mut cursor)?;
// Limit to 1MB
if length as usize > 1024 * 1024 {
return Err(Error::ClientError("Message too large to read".into()));
}

// Read the rest of the message
let mut buffer = Vec::with_capacity(length as usize);
buffer.extend_from_slice(&b);
buffer.resize(length as usize, 0);
reader.read_exact(&mut buffer[4..]).await?;

// Decode Response
let mut cursor = Cursor::new(buffer);
let res = DiameterMessage::decode_from(&mut cursor)?;
Ok(res)
}

Ok(res)
pub fn request(&mut self, req: DiameterMessage) -> Result<DiameterRequest, Error> {
if let Some(writer) = &self.writer {
let (tx, rx) = oneshot::channel();
let hop_by_hop = req.get_hop_by_hop_id();
{
let mut msg_caches = self.msg_caches.lock()?;
msg_caches.insert(hop_by_hop, tx);
}

Ok(DiameterRequest::new(req, rx, Arc::clone(&writer)))
} else {
Err(Error::ClientError("Not connected".into()))
}
}

pub async fn send_message(&mut self, req: DiameterMessage) -> Result<DiameterMessage, Error> {
let mut request = self.request(req)?;
let _ = request.send().await?;
let response = request.response().await?;
Ok(response)
}
}

pub struct DiameterRequest {
request: DiameterMessage,
receiver: Arc<Mutex<Option<Receiver<DiameterMessage>>>>,
writer: Arc<Mutex<OwnedWriteHalf>>,
}

impl DiameterRequest {
pub fn new(
request: DiameterMessage,
receiver: Receiver<DiameterMessage>,
writer: Arc<Mutex<OwnedWriteHalf>>,
) -> Self {
DiameterRequest {
request,
receiver: Arc::new(Mutex::new(Some(receiver))),
writer,
}
}

pub fn get_request(&self) -> &DiameterMessage {
&self.request
}

pub async fn send(&mut self) -> Result<(), Error> {
let mut encoded = Vec::new();
self.request.encode_to(&mut encoded)?;

let mut writer = self.writer.lock()?;
writer.write_all(&encoded).await?;

Ok(())
}

pub async fn response(&self) -> Result<DiameterMessage, Error> {
let rx = self
.receiver
.lock()?
.take()
.ok_or_else(|| Error::ClientError("Response already taken".into()))?;

let res = rx.await.map_err(|e| {
Error::ClientError(format!("Failed to receive response; error: {:?}", e))
})?;

Ok(res)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -85,7 +201,7 @@ mod tests {

let mut client = DiameterClient::new();
let _ = client.connect("localhost:3868").await;
let response = client.send(ccr).await.unwrap();
let response = client.send_message(ccr).await.unwrap();
println!("Response: {}", response);
}
}
12 changes: 12 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt;
use std::sync::{MutexGuard, PoisonError};

#[derive(Debug)]
pub enum CLientError {}
Expand All @@ -12,6 +13,7 @@ pub enum Error {
ServerError(String),
IoError(std::io::Error),
TryFromSliceError(std::array::TryFromSliceError),
LockError(String),
}

impl fmt::Display for Error {
Expand All @@ -24,20 +26,30 @@ impl fmt::Display for Error {
Error::ServerError(msg) => write!(f, "{}", msg),
Error::IoError(e) => write!(f, "{}", e),
Error::TryFromSliceError(e) => write!(f, "{}", e),
Error::LockError(msg) => write!(f, "{}", msg),
}
}
}

impl std::error::Error for Error {}

// io error
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Error::IoError(err)
}
}

// try from slice error
impl From<std::array::TryFromSliceError> for Error {
fn from(err: std::array::TryFromSliceError) -> Error {
Error::TryFromSliceError(err)
}
}

// lock error
impl<'a, T> From<PoisonError<MutexGuard<'a, T>>> for Error {
fn from(err: PoisonError<MutexGuard<'a, T>>) -> Self {
Error::LockError(format!("Lock error: {}", err))
}
}
5 changes: 3 additions & 2 deletions src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ mod tests {
#[tokio::test]
async fn test_diameter_transport() {
// Diameter Server
let mut server = DiameterServer::new("0.0.0.0:3868").await.unwrap();

tokio::spawn(async move {
let mut server = DiameterServer::new("0.0.0.0:3868").await.unwrap();
server
.handle(|req| -> Result<DiameterMessage, Error> {
println!("Request : {}", req);
Expand Down Expand Up @@ -55,7 +56,7 @@ mod tests {
ccr.add_avp(avp!(263, None, UTF8StringAvp::new("ses;12345888"), true));
ccr.add_avp(avp!(416, None, EnumeratedAvp::new(1), true));
ccr.add_avp(avp!(415, None, Unsigned32Avp::new(1000), true));
let cca = client.send(ccr).await.unwrap();
let cca = client.send_message(ccr).await.unwrap();

println!("Response: {}", cca);

Expand Down