Skip to content

Commit

Permalink
Wrap fd listening callbacks in a class
Browse files Browse the repository at this point in the history
  • Loading branch information
LasseBlaauwbroek committed Apr 3, 2023
1 parent c2e0765 commit 64e9141
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 69 deletions.
64 changes: 22 additions & 42 deletions capnp/helpers/asyncProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,15 @@ static constexpr uint NEW_FD_FLAGS =
class OwnedFileDescriptor {
public:
OwnedFileDescriptor(int fd, uint flags,
void (*ar)(int, void (*cb)(void* data), void* data),
void (*rr)(int),
void (*aw)(int, void (*cb)(void* data), void* data),
void (*rw)(int))
: fd(applyFlags(fd, flags)), flags(flags),
add_reader(ar), remove_reader(rr), add_writer(ar), remove_writer(rw) {
PyFdListener *fdListener)
: fd(applyFlags(fd, flags)), flags(flags), fdListener(fdListener) {
readRegistered = false;
writeRegistered = false;
}

~OwnedFileDescriptor() noexcept(false) {
remove_reader(fd);
remove_writer(fd);
fdListener->remove_reader(fd);
fdListener->remove_writer(fd);

// Don't use KJ_SYSCALL() here because close() should not be repeated on EINTR.
if ((flags & kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP) && close(fd) < 0) {
Expand All @@ -104,10 +100,7 @@ class OwnedFileDescriptor {
protected:
const int fd;
uint flags;
void (*add_reader)(int, void (*cb)(void* data), void* data);
void (*remove_reader)(int);
void (*add_writer)(int, void (*cb)(void* data), void* data);
void (*remove_writer)(int);
PyFdListener *fdListener;

private:
bool readRegistered;
Expand All @@ -117,12 +110,12 @@ class OwnedFileDescriptor {
public:
ReadPromiseAdapter(kj::PromiseFulfiller<void>& fulfiller, OwnedFileDescriptor& ofd)
: fulfiller(fulfiller), ofd(ofd) {
ofd.add_reader(ofd.fd, &readCallback, (void*)this);
ofd.fdListener->add_reader(ofd.fd, &readCallback, (void*)this);
ofd.readRegistered = true;
}

~ReadPromiseAdapter() {
ofd.remove_reader(ofd.fd);
ofd.fdListener->remove_reader(ofd.fd);
ofd.readRegistered = false;
}

Expand All @@ -145,12 +138,12 @@ class OwnedFileDescriptor {
public:
WritePromiseAdapter(kj::PromiseFulfiller<void>& fulfiller, OwnedFileDescriptor& ofd)
: fulfiller(fulfiller), ofd(ofd) {
ofd.add_writer(ofd.fd, &writeCallback, (void*)this);
ofd.fdListener->add_writer(ofd.fd, &writeCallback, (void*)this);
ofd.writeRegistered = true;
}

~WritePromiseAdapter() {
ofd.remove_writer(ofd.fd);
ofd.fdListener->remove_writer(ofd.fd);
ofd.writeRegistered = false;
}

Expand Down Expand Up @@ -178,12 +171,8 @@ class PyIoStream: public OwnedFileDescriptor, public kj::AsyncIoStream {
// TODO(cleanup): Allow better code sharing between the two.

public:
PyIoStream(int fd, uint flags,
void (*ar)(int, void (*cb)(void* data), void* data),
void (*rr)(int),
void (*aw)(int, void (*cb)(void* data), void* data),
void (*rw)(int))
: OwnedFileDescriptor(fd, flags, ar, rr, aw, rw) {}
PyIoStream(int fd, uint flags, PyFdListener *fdListener)
: OwnedFileDescriptor(fd, flags, fdListener) {}
virtual ~PyIoStream() noexcept(false) {}

kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
Expand Down Expand Up @@ -340,12 +329,8 @@ class PyConnectionReceiver final: public kj::ConnectionReceiver, public OwnedFil
// Like PyIoStream but for ConnectionReceiver. This is also largely copied from kj/async-io.c++.

public:
PyConnectionReceiver(int fd, uint flags,
void (*ar)(int, void (*cb)(void* data), void* data),
void (*rr)(int),
void (*aw)(int, void (*cb)(void* data), void* data),
void (*rw)(int))
: OwnedFileDescriptor(fd, flags, ar, rr, aw, rw) {}
PyConnectionReceiver(int fd, uint flags, PyFdListener *fdListener)
: OwnedFileDescriptor(fd, flags, fdListener) {}

kj::Promise<kj::Own<kj::AsyncIoStream>> accept() override {
int newFd;
Expand All @@ -358,9 +343,7 @@ class PyConnectionReceiver final: public kj::ConnectionReceiver, public OwnedFil
#endif

if (newFd >= 0) {
return kj::Own<kj::AsyncIoStream>(kj::heap<PyIoStream>(newFd, NEW_FD_FLAGS,
add_reader, remove_reader,
add_writer, remove_writer));
return kj::Own<kj::AsyncIoStream>(kj::heap<PyIoStream>(newFd, NEW_FD_FLAGS, fdListener));
} else {
int error = errno;

Expand Down Expand Up @@ -412,21 +395,18 @@ class PyConnectionReceiver final: public kj::ConnectionReceiver, public OwnedFil
}
};

PyLowLevelAsyncIoProvider::PyLowLevelAsyncIoProvider(void (*ar)(int, void (*cb)(void* data), void* data),
void (*rr)(int),
void (*aw)(int, void (*cb)(void* data), void* data),
void (*rw)(int),
PyLowLevelAsyncIoProvider::PyLowLevelAsyncIoProvider(PyFdListener *fdListener,
kj::Timer* t) :
add_reader(ar), remove_reader(rr), add_writer(ar), remove_writer(rw), timer(t) {}
fdListener(fdListener), timer(t) {}

kj::Own<kj::AsyncInputStream> PyLowLevelAsyncIoProvider::wrapInputFd(int fd, uint flags) {
return kj::heap<PyIoStream>(fd, flags, add_reader, remove_reader, add_writer, remove_writer);
return kj::heap<PyIoStream>(fd, flags, fdListener);
}
kj::Own<kj::AsyncOutputStream> PyLowLevelAsyncIoProvider::wrapOutputFd(int fd, uint flags) {
return kj::heap<PyIoStream>(fd, flags, add_reader, remove_reader, add_writer, remove_writer);
return kj::heap<PyIoStream>(fd, flags, fdListener);
}
kj::Own<kj::AsyncIoStream> PyLowLevelAsyncIoProvider::wrapSocketFd(int fd, uint flags) {
return kj::heap<PyIoStream>(fd, flags, add_reader, remove_reader, add_writer, remove_writer);
return kj::heap<PyIoStream>(fd, flags, fdListener);
}
kj::Promise<kj::Own<kj::AsyncIoStream>> PyLowLevelAsyncIoProvider::wrapConnectingSocketFd(
int fd, const struct sockaddr* addr, uint addrlen, uint flags) {
Expand All @@ -448,7 +428,7 @@ kj::Promise<kj::Own<kj::AsyncIoStream>> PyLowLevelAsyncIoProvider::wrapConnectin
}
}

auto result = kj::heap<PyIoStream>(fd, flags, add_reader, remove_reader, add_writer, remove_writer);
auto result = kj::heap<PyIoStream>(fd, flags, fdListener);
auto connected = result->onWritable();
return connected.then(kj::mvCapture(result,
[fd](kj::Own<kj::AsyncIoStream>&& stream) {
Expand All @@ -464,14 +444,14 @@ kj::Promise<kj::Own<kj::AsyncIoStream>> PyLowLevelAsyncIoProvider::wrapConnectin

#if CAPNP_VERSION < 7000
kj::Own<kj::ConnectionReceiver> PyLowLevelAsyncIoProvider::wrapListenSocketFd(int fd, uint flags) {
return kj::heap<PyConnectionReceiver>(fd, flags, add_reader, remove_reader, add_writer, remove_writer);
return kj::heap<PyConnectionReceiver>(fd, flags, fdListener);
}
#else
kj::Own<kj::ConnectionReceiver> PyLowLevelAsyncIoProvider::wrapListenSocketFd(int fd,
kj::LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) {
// TODO(soon): TODO(security): Actually use `filter`. Currently no API is exposed to set a
// filter so it's not important yet.
return kj::heap<PyConnectionReceiver>(fd, flags, add_reader, remove_reader, add_writer, remove_writer);
return kj::heap<PyConnectionReceiver>(fd, flags, fdListener);
}
#endif

Expand Down
18 changes: 10 additions & 8 deletions capnp/helpers/asyncProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@

using namespace kj;

class PyFdListener {
public:
virtual void add_reader(int, void (*cb)(void* data), void* data) = 0;
virtual void remove_reader(int) = 0;
virtual void add_writer(int, void (*cb)(void* data), void* data) = 0;
virtual void remove_writer(int) = 0;
};

class PyLowLevelAsyncIoProvider final: public kj::LowLevelAsyncIoProvider {
public:
PyLowLevelAsyncIoProvider(void (*ar)(int, void (*cb)(void* data), void* data),
void (*rr)(int),
void (*aw)(int, void (*cb)(void* data), void* data),
void (*rw)(int),
PyLowLevelAsyncIoProvider(PyFdListener *fdListener,
kj::Timer* timer);

kj::Own<kj::AsyncInputStream> wrapInputFd(Fd fd, uint flags = 0);
Expand All @@ -20,9 +25,6 @@ class PyLowLevelAsyncIoProvider final: public kj::LowLevelAsyncIoProvider {
kj::Own<kj::ConnectionReceiver> wrapListenSocketFd(Fd fd, NetworkFilter& filter, uint flags = 0);

private:
void (*add_reader)(int, void (*cb)(void* data), void* data);
void (*remove_reader)(int);
void (*add_writer)(int, void (*cb)(void* data), void* data);
void (*remove_writer)(int);
PyFdListener *fdListener;
kj::Timer *timer;
};
10 changes: 6 additions & 4 deletions capnp/includes/capnp_cpp.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -547,12 +547,14 @@ cdef extern from "kj/async.h" namespace " ::kj":
PyPromiseArray joinPromises(Array[PyPromise]) nogil

cdef extern from "capnp/helpers/asyncProvider.h":
cdef cppclass PyFdListener:
void add_reader(int, void (*cb)(void* data), void* data) with gil
void remove_reader(int) with gil
void add_writer(int, void (*cb)(void* data), void* data) with gil
void remove_writer(int) with gil
cdef cppclass PyLowLevelAsyncIoProvider(LowLevelAsyncIoProvider):
pass

Own[LowLevelAsyncIoProvider] makePyLowLevelAsyncIoProvider" ::kj::heap<PyLowLevelAsyncIoProvider>"(
void (*ar)(int, void (*cb)(void* data), void* data),
void (*rr)(int),
void (*aw)(int, void (*cb)(void* data), void* data),
void (*rw)(int),
PyFdListener *fdListener,
Timer *timer)
40 changes: 25 additions & 15 deletions capnp/lib/capnp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
cimport cython # noqa: E402

from capnp.helpers.helpers cimport init_capnp_api
from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, makePyLowLevelAsyncIoProvider, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS
from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, makePyLowLevelAsyncIoProvider, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS, PyFdListener

from cpython cimport array, Py_buffer, PyObject_CheckBuffer
from cpython.buffer cimport PyBUF_SIMPLE, PyBUF_WRITABLE
Expand Down Expand Up @@ -1782,21 +1782,42 @@ cdef void kjloop_advance_callback(void* data) with gil:
assert port.runHandle is not None
port.timerImpl.advanceTo(systemPreciseMonotonicClock().now())

cdef cppclass AsyncIoPyFdListener(PyFdListener):
object loop

__init__(object loop):
this.loop = loop

void add_reader(int fd, void (*cb)(void* data), void* data) with gil:
this.loop.add_reader(fd, lambda: cb(data))

void remove_reader(int fd) with gil:
this.loop.remove_reader(fd)

void add_writer(int fd, void (*cb)(void* data), void* data) with gil:
this.loop.add_writer(fd, lambda: cb(data))

void remove_writer(int fd) with gil:
this.loop.remove_writer(fd)

cdef cppclass AsyncIoEventPort(EventPort):
EventLoop *kjLoop
TimerImpl *timerImpl;
AsyncIoPyFdListener *fdListener
object asyncioLoop;
object runHandle;

__init__(object asyncioLoop):
this.kjLoop = new EventLoop(deref(this))
this.timerImpl = new TimerImpl(systemPreciseMonotonicClock().now())
this.fdListener = new AsyncIoPyFdListener(asyncioLoop)
this.runHandle = None
this.asyncioLoop = asyncioLoop

__dealloc__():
del this.timerImpl
del this.kjLoop
del this.fdListener

cbool wait() with gil:
raise KjException("Currently you cannot wait for promises while pycapnp is running in asyncio mode. " +
Expand Down Expand Up @@ -1836,18 +1857,8 @@ cdef cppclass AsyncIoEventPort(EventPort):
Timer *getTimer():
return this.timerImpl;


cdef void add_reader(int fd, void (*cb)(void* data), void* data) with gil:
asyncio.get_running_loop().add_reader(fd, lambda: cb(data))

cdef void remove_reader(int fd) with gil:
asyncio.get_running_loop().remove_reader(fd)

cdef void add_writer(int fd, void (*cb)(void* data), void* data) with gil:
asyncio.get_running_loop().add_writer(fd, lambda: cb(data))

cdef void remove_writer(int fd) with gil:
asyncio.get_running_loop().remove_writer(fd)
PyFdListener *getFdListener():
return this.fdListener

from libcpp.utility cimport move

Expand All @@ -1866,8 +1877,7 @@ cdef class _EventLoop:
loop = asyncio.get_running_loop()
self.customPort = new AsyncIoEventPort(loop)
kjLoop = self.customPort.getKjLoop()
self.lowLevelProvider = makePyLowLevelAsyncIoProvider(&add_reader, &remove_reader,
&add_writer, &remove_writer,
self.lowLevelProvider = makePyLowLevelAsyncIoProvider(self.customPort.getFdListener(),
self.customPort.getTimer())
self.waitScope = new WaitScope(deref(kjLoop))
self.provider = newAsyncIoProvider(deref(self.lowLevelProvider))
Expand Down

0 comments on commit 64e9141

Please sign in to comment.