Skip to content

Commit

Permalink
fix: infinite reading loop when encountering eofs
Browse files Browse the repository at this point in the history
  • Loading branch information
threadexio committed May 11, 2024
1 parent d6e7c0e commit 6df031d
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 134 deletions.
56 changes: 46 additions & 10 deletions channels-io/src/async_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::future::Future;
use core::pin::Pin;
use core::task::{ready, Context, Poll};

use crate::ReadBuf;
use crate::{ReadBuf, ReadError};

/// This trait is the asynchronous version of [`Read`].
///
Expand All @@ -11,7 +11,7 @@ pub trait AsyncRead: Unpin {
/// Error type for [`read()`].
///
/// [`read()`]: AsyncRead::read
type Error;
type Error: ReadError;

/// Asynchronously read some bytes into `buf`.
///
Expand All @@ -32,7 +32,39 @@ pub trait AsyncRead: Unpin {
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<Result<(), Self::Error>>;
) -> Poll<Result<(), Self::Error>> {
default_poll_read(self, cx, buf)
}

/// Poll the reader once and read some bytes into the slice `buf`.
///
/// This method reads bytes directly into `buf` and reports how many bytes it
/// read.
fn poll_read_slice(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<Result<usize, Self::Error>>;
}

fn default_poll_read<T: AsyncRead + ?Sized>(
mut reader: Pin<&mut T>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<Result<(), T::Error>> {
while !buf.unfilled().is_empty() {
match ready!(reader
.as_mut()
.poll_read_slice(cx, buf.unfilled_mut()))
{
Ok(0) => return Poll::Ready(Err(T::Error::eof())),
Ok(n) => buf.advance(n),
Err(e) if e.should_retry() => continue,
Err(e) => return Poll::Ready(Err(e)),
}
}

Poll::Ready(Ok(()))
}

#[allow(missing_debug_implementations)]
Expand All @@ -58,13 +90,8 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Self::Output> {
let Self { ref mut reader, ref mut buf, .. } = *self;

while !buf.unfilled().is_empty() {
ready!(Pin::new(&mut **reader).poll_read(cx, buf))?;
}

Poll::Ready(Ok(()))
let Self { ref mut reader, ref mut buf } = *self;
Pin::new(&mut **reader).poll_read(cx, buf)
}
}

Expand All @@ -80,6 +107,15 @@ macro_rules! forward_impl_async_read {
let this = Pin::new(&mut **self);
<$to>::poll_read(this, cx, buf)
}

fn poll_read_slice(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<Result<usize, Self::Error>> {
let this = Pin::new(&mut **self);
<$to>::poll_read_slice(this, cx, buf)
}
};
}

Expand Down
11 changes: 11 additions & 0 deletions channels-io/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/// A trait that describes an error returned by [`Read`] and [`AsyncRead`].
///
/// [`Read`]: trait@crate::Read
/// [`AsyncRead`]: trait@crate::AsyncRead
pub trait ReadError {
/// Create a new End-Of-File error.
fn eof() -> Self;

/// Checks whether the given error indicates that the operation should be retried.
fn should_retry(&self) -> bool;
}
31 changes: 16 additions & 15 deletions channels-io/src/impls/core2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
use super::prelude::*;

use ::core2::io::ErrorKind as E;

impl ReadError for ::core2::io::Error {
fn eof() -> Self {
Self::from(E::UnexpectedEof)
}

fn should_retry(&self) -> bool {
self.kind() == E::UnexpectedEof
}
}

newtype! {
/// Wrapper IO type for [`core2::io::Read`] and [`core2::io::Write`].
Core2
Expand All @@ -13,20 +25,11 @@ where
{
type Error = ::core2::io::Error;

fn read(
fn read_slice(
&mut self,
mut buf: &mut [u8],
) -> Result<(), Self::Error> {
while !buf.is_empty() {
use ::core2::io::ErrorKind as E;
match self.0.read(buf) {
Ok(i) => buf = &mut buf[i..],
Err(e) if e.kind() == E::Interrupted => continue,
Err(e) => return Err(e),
}
}

Ok(())
buf: &mut [u8],
) -> Result<usize, Self::Error> {
self.0.read(buf)
}
}

Expand All @@ -40,7 +43,6 @@ where

fn write(&mut self, mut buf: &[u8]) -> Result<(), Self::Error> {
while !buf.is_empty() {
use ::core2::io::ErrorKind as E;
match self.0.write(buf) {
Ok(i) => buf = &buf[i..],
Err(e) if e.kind() == E::Interrupted => continue,
Expand All @@ -53,7 +55,6 @@ where

fn flush(&mut self) -> Result<(), Self::Error> {
loop {
use ::core2::io::ErrorKind as E;
match self.0.flush() {
Ok(()) => break Ok(()),
Err(e) if e.kind() == E::Interrupted => continue,
Expand Down
27 changes: 6 additions & 21 deletions channels-io/src/impls/futures.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::prelude::*;

use ::std::io::ErrorKind as E;

newtype! {
/// Wrapper IO type for [`futures::AsyncRead`] and [`futures::AsyncWrite`].
Futures
Expand All @@ -13,25 +15,12 @@ where
{
type Error = ::futures::io::Error;

fn poll_read(
fn poll_read_slice(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<Result<(), Self::Error>> {
use ::std::io::ErrorKind as E;

while !buf.unfilled().is_empty() {
match ready!(Pin::new(&mut self.0)
.poll_read(cx, buf.unfilled_mut()))
{
Ok(0) => break,
Ok(n) => buf.advance(n),
Err(e) if e.kind() == E::Interrupted => continue,
Err(e) => return Ready(Err(e)),
}
}

Ready(Ok(()))
buf: &mut [u8],
) -> Poll<Result<usize, Self::Error>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

Expand All @@ -48,8 +37,6 @@ where
cx: &mut Context,
buf: &mut WriteBuf,
) -> Poll<Result<(), Self::Error>> {
use ::std::io::ErrorKind as E;

while !buf.remaining().is_empty() {
match ready!(
Pin::new(&mut self.0).poll_write(cx, buf.remaining())
Expand All @@ -68,8 +55,6 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<(), Self::Error>> {
use ::std::io::ErrorKind as E;

loop {
match ready!(Pin::new(&mut self.0).poll_flush(cx)) {
Ok(()) => return Ready(Ok(())),
Expand Down
3 changes: 2 additions & 1 deletion channels-io/src/impls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ use newtype;
#[allow(unused_imports)]
mod prelude {
pub(super) use crate::{
AsyncRead, AsyncWrite, Read, ReadBuf, Write, WriteBuf,
AsyncRead, AsyncWrite, Read, ReadBuf, ReadError, Write,
WriteBuf,
};

pub(super) use super::{
Expand Down
15 changes: 11 additions & 4 deletions channels-io/src/impls/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ where
fn read(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
self.0.read(buf)
}

fn read_slice(
&mut self,
buf: &mut [u8],
) -> Result<usize, Self::Error> {
self.0.read_slice(buf)
}
}

impl_newtype_write! { Native: Write }
Expand Down Expand Up @@ -48,12 +55,12 @@ where
{
type Error = T::Error;

fn poll_read(
fn poll_read_slice(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0).poll_read(cx, buf)
buf: &mut [u8],
) -> Poll<Result<usize, Self::Error>> {
Pin::new(&mut self.0).poll_read_slice(cx, buf)
}
}

Expand Down
39 changes: 18 additions & 21 deletions channels-io/src/impls/smol.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
use super::prelude::*;

#[allow(unused_imports)]
use ::smol::io::ErrorKind as E;

#[cfg(not(feature = "std"))]
impl ReadError for ::smol::io::Error {
fn eof() -> Self {
Self::from(E::UnexpectedEof)
}

fn should_retry(&self) -> bool {
self.kind() == E::Interrupted
}
}

newtype! {
/// Wrapper IO type for [`smol::io::AsyncRead`] and [`smol::io::AsyncWrite`].
Smol
Expand All @@ -13,25 +27,12 @@ where
{
type Error = ::smol::io::Error;

fn poll_read(
fn poll_read_slice(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<Result<(), Self::Error>> {
use ::smol::io::ErrorKind as E;

while !buf.unfilled().is_empty() {
match ready!(Pin::new(&mut self.0)
.poll_read(cx, buf.unfilled_mut()))
{
Ok(0) => break,
Ok(n) => buf.advance(n),
Err(e) if e.kind() == E::Interrupted => continue,
Err(e) => return Ready(Err(e)),
}
}

Ready(Ok(()))
buf: &mut [u8],
) -> Poll<Result<usize, Self::Error>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

Expand All @@ -48,8 +49,6 @@ where
cx: &mut Context,
buf: &mut WriteBuf,
) -> Poll<Result<(), Self::Error>> {
use ::smol::io::ErrorKind as E;

while !buf.remaining().is_empty() {
match ready!(
Pin::new(&mut self.0).poll_write(cx, buf.remaining())
Expand All @@ -68,8 +67,6 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<(), Self::Error>> {
use ::smol::io::ErrorKind as E;

loop {
match ready!(Pin::new(&mut self.0).poll_flush(cx)) {
Ok(()) => return Ready(Ok(())),
Expand Down
31 changes: 16 additions & 15 deletions channels-io/src/impls/std.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
use super::prelude::*;

use ::std::io::ErrorKind as E;

impl ReadError for ::std::io::Error {
fn eof() -> Self {
Self::from(E::UnexpectedEof)
}

fn should_retry(&self) -> bool {
self.kind() == E::Interrupted
}
}

newtype! {
/// Wrapper IO type for [`std::io::Read`] and [`std::io::Write`].
Std
Expand All @@ -13,20 +25,11 @@ where
{
type Error = ::std::io::Error;

fn read(
fn read_slice(
&mut self,
mut buf: &mut [u8],
) -> Result<(), Self::Error> {
while !buf.is_empty() {
use ::std::io::ErrorKind as E;
match self.0.read(buf) {
Ok(i) => buf = &mut buf[i..],
Err(e) if e.kind() == E::Interrupted => continue,
Err(e) => return Err(e),
}
}

Ok(())
buf: &mut [u8],
) -> Result<usize, Self::Error> {
self.0.read(buf)
}
}

Expand All @@ -40,7 +43,6 @@ where

fn write(&mut self, mut buf: &[u8]) -> Result<(), Self::Error> {
while !buf.is_empty() {
use ::std::io::ErrorKind as E;
match self.0.write(buf) {
Ok(i) => buf = &buf[i..],
Err(e) if e.kind() == E::Interrupted => continue,
Expand All @@ -53,7 +55,6 @@ where

fn flush(&mut self) -> Result<(), Self::Error> {
loop {
use ::std::io::ErrorKind as E;
match self.0.flush() {
Ok(()) => break Ok(()),
Err(e) if e.kind() == E::Interrupted => continue,
Expand Down
Loading

0 comments on commit 6df031d

Please sign in to comment.