From c30db8935541d67c30758d4a1d9fe6ee44d4e628 Mon Sep 17 00:00:00 2001 From: Elise Amber Katze Date: Tue, 5 Nov 2024 15:58:14 +0100 Subject: [PATCH] fix: re-registering interest on Windows (#274) --- monoio/src/driver/legacy/iocp/mod.rs | 41 +++++++++++--------------- monoio/src/driver/legacy/iocp/state.rs | 17 +++++++---- monoio/src/driver/legacy/mod.rs | 17 ++++++++++- monoio/src/driver/scheduled_io.rs | 9 +++++- 4 files changed, 54 insertions(+), 30 deletions(-) diff --git a/monoio/src/driver/legacy/iocp/mod.rs b/monoio/src/driver/legacy/iocp/mod.rs index 3aad2057..afec11ca 100644 --- a/monoio/src/driver/legacy/iocp/mod.rs +++ b/monoio/src/driver/legacy/iocp/mod.rs @@ -128,7 +128,8 @@ impl Poller { token: mio::Token, interests: mio::Interest, ) -> std::io::Result<()> { - if state.inner.is_none() { + let mut state_inner = state.inner.lock().unwrap(); + if state_inner.inner.is_none() { let flags = interests_to_afd_flags(interests); let inner = { @@ -143,9 +144,9 @@ impl Poller { self.queue_state(inner.clone()); unsafe { self.update_sockets_events_if_polling()? }; - state.inner = Some(inner); - state.token = token; - state.interest = interests; + state_inner.inner = Some(inner); + state_inner.token = token; + state_inner.interest = interests; Ok(()) } else { @@ -155,37 +156,31 @@ impl Poller { pub fn reregister( &self, - state: &mut SocketState, + state: Pin>>, token: mio::Token, interests: mio::Interest, ) -> std::io::Result<()> { - if let Some(inner) = state.inner.as_mut() { - { - let event = Event { - flags: interests_to_afd_flags(interests), - data: token.0 as u64, - }; - - inner.lock().unwrap().set_event(event); - } - - state.token = token; - state.interest = interests; + { + let event = Event { + flags: interests_to_afd_flags(interests), + data: token.0 as u64, + }; - self.queue_state(inner.clone()); - unsafe { self.update_sockets_events_if_polling() } - } else { - Err(std::io::ErrorKind::NotFound.into()) + state.lock().unwrap().set_event(event); } + + self.queue_state(state.clone()); + unsafe { self.update_sockets_events_if_polling() } } pub fn deregister(&mut self, state: &mut SocketState) -> std::io::Result<()> { - if let Some(inner) = state.inner.as_mut() { + let mut state_inner = state.inner.lock().unwrap(); + if let Some(inner) = state_inner.inner.as_mut() { { let mut sock_state = inner.lock().unwrap(); sock_state.mark_delete(); } - state.inner = None; + state_inner.inner = None; Ok(()) } else { Err(std::io::ErrorKind::NotFound.into()) diff --git a/monoio/src/driver/legacy/iocp/state.rs b/monoio/src/driver/legacy/iocp/state.rs index a550eb6e..7271a153 100644 --- a/monoio/src/driver/legacy/iocp/state.rs +++ b/monoio/src/driver/legacy/iocp/state.rs @@ -25,20 +25,27 @@ pub enum SockPollStatus { } #[derive(Debug)] -pub struct SocketState { - pub socket: RawSocket, +pub struct SocketStateInner { pub inner: Option>>>, pub token: mio::Token, pub interest: mio::Interest, } +#[derive(Debug)] +pub struct SocketState { + pub socket: RawSocket, + pub inner: Arc>, +} + impl SocketState { pub fn new(socket: RawSocket) -> Self { Self { socket, - inner: None, - token: mio::Token(0), - interest: mio::Interest::READABLE, + inner: Arc::new(Mutex::new(SocketStateInner { + inner: None, + token: mio::Token(0), + interest: mio::Interest::READABLE, + })) } } } diff --git a/monoio/src/driver/legacy/mod.rs b/monoio/src/driver/legacy/mod.rs index 9d0f796e..8f6f7160 100644 --- a/monoio/src/driver/legacy/mod.rs +++ b/monoio/src/driver/legacy/mod.rs @@ -182,7 +182,7 @@ impl LegacyDriver { interest: mio::Interest, ) -> io::Result { let inner = unsafe { &mut *this.get() }; - let io = ScheduledIo::default(); + let io = ScheduledIo::new(state.inner.clone()); let token = inner.io_dispatch.insert(io); match inner.poll.register(state, mio::Token(token), interest) { @@ -303,6 +303,21 @@ impl LegacyInner { flags: 0, }), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + #[cfg(windows)] + { + if let Some((sock_state, token, interest)) = { + let socket_state_lock = ref_mut.state.lock().unwrap(); + socket_state_lock.inner.clone().map(|inner| (inner, socket_state_lock.token, socket_state_lock.interest)) + } { + if let Err(e) = inner.poll.reregister(sock_state, token, interest) { + return Poll::Ready(CompletionMeta { + result: Err(e), + flags: 0, + }); + } + } + } + ref_mut.clear_readiness(direction.mask()); ref_mut.set_waker(cx, direction); Poll::Pending diff --git a/monoio/src/driver/scheduled_io.rs b/monoio/src/driver/scheduled_io.rs index d164a1d3..5b5cd578 100644 --- a/monoio/src/driver/scheduled_io.rs +++ b/monoio/src/driver/scheduled_io.rs @@ -9,8 +9,13 @@ pub(crate) struct ScheduledIo { reader: Option, /// Waker used for AsyncWrite. writer: Option, + + #[cfg(windows)] + pub state: std::sync::Arc>, } + +#[cfg(not(windows))] impl Default for ScheduledIo { #[inline] fn default() -> Self { @@ -19,11 +24,13 @@ impl Default for ScheduledIo { } impl ScheduledIo { - pub(crate) const fn new() -> Self { + pub(crate) const fn new(#[cfg(windows)] state: std::sync::Arc>) -> Self { Self { readiness: Ready::EMPTY, reader: None, writer: None, + #[cfg(windows)] + state, } }