diff --git a/src/MPI/Comm.pyx b/src/MPI/Comm.pyx index 78446ff..b04830d 100644 --- a/src/MPI/Comm.pyx +++ b/src/MPI/Comm.pyx @@ -638,15 +638,11 @@ cdef class Comm: Generalized All-to-All communication allowing different counts, displacements and datatypes for each partner """ - sendbuf = recvbuf = None - raise NotImplementedError # XXX implement! - cdef void *sbuf = NULL, *rbuf = NULL - cdef int *scounts = NULL, *rcounts = NULL - cdef int *sdispls = NULL, *rdispls = NULL - cdef MPI_Datatype *stypes = NULL, *rtypes = NULL + cdef _p_msg_ccow m = message_ccow() + m.for_alltoallw(sendbuf, recvbuf, self.ob_mpi) with nogil: CHKERR( MPI_Alltoallw( - sbuf, scounts, sdispls, stypes, - rbuf, rcounts, rdispls, rtypes, + m.sbuf, m.scounts, m.sdispls, m.stypes, + m.rbuf, m.rcounts, m.rdispls, m.rtypes, self.ob_mpi) ) @@ -835,18 +831,14 @@ cdef class Comm: """ Nonblocking Generalized All-to-All """ - sendbuf = recvbuf = None - raise NotImplementedError # XXX implement! - cdef void *sbuf = NULL, *rbuf = NULL - cdef int *scounts = NULL, *rcounts = NULL - cdef int *sdispls = NULL, *rdispls = NULL - cdef MPI_Datatype *stypes = NULL, *rtypes = NULL + cdef _p_msg_ccow m = message_ccow() + m.for_alltoallw(sendbuf, recvbuf, self.ob_mpi) cdef Request request = Request.__new__(Request) with nogil: CHKERR( MPI_Ialltoallw( - sbuf, scounts, sdispls, stypes, - rbuf, rcounts, rdispls, rtypes, + m.sbuf, m.scounts, m.sdispls, m.stypes, + m.rbuf, m.rcounts, m.rdispls, m.rtypes, self.ob_mpi, &request.ob_mpi) ) - request.ob_buf = None + request.ob_buf = m return request def Ireduce(self, sendbuf, recvbuf, Op op not None=SUM, int root=0): @@ -1481,15 +1473,11 @@ cdef class Intracomm(Comm): """ Neighbor All-to-All Generalized """ - sendbuf = recvbuf = None - raise NotImplementedError # XXX implement! - cdef void *sbuf = NULL, *rbuf = NULL - cdef int *scounts = NULL, *rcounts = NULL - cdef int *sdispls = NULL, *rdispls = NULL - cdef MPI_Datatype *stypes = NULL, *rtypes = NULL + cdef _p_msg_ccow m = message_ccow() + m.for_neighbor_alltoallw(sendbuf, recvbuf, self.ob_mpi) with nogil: CHKERR( MPI_Neighbor_alltoallw( - sbuf, scounts, sdispls, stypes, - rbuf, rcounts, rdispls, rtypes, + m.sbuf, m.scounts, m.sdisplsA, m.stypes, + m.rbuf, m.rcounts, m.rdisplsA, m.rtypes, self.ob_mpi) ) # Nonblocking Neighborhood Collectives @@ -1554,18 +1542,14 @@ cdef class Intracomm(Comm): """ Nonblocking Neighbor All-to-All Generalized """ - sendbuf = recvbuf = None - raise NotImplementedError # XXX implement! - cdef void *sbuf = NULL, *rbuf = NULL - cdef int *scounts = NULL, *rcounts = NULL - cdef int *sdispls = NULL, *rdispls = NULL - cdef MPI_Datatype *stypes = NULL, *rtypes = NULL + cdef _p_msg_ccow m = message_ccow() + m.for_neighbor_alltoallw(sendbuf, recvbuf, self.ob_mpi) cdef Request request = Request.__new__(Request) with nogil: CHKERR( MPI_Ineighbor_alltoallw( - sbuf, scounts, sdispls, stypes, - rbuf, rcounts, rdispls, rtypes, + m.sbuf, m.scounts, m.sdisplsA, m.stypes, + m.rbuf, m.rcounts, m.rdisplsA, m.rtypes, self.ob_mpi, &request.ob_mpi) ) - request.ob_buf = None + request.ob_buf = m return request # Python Communication diff --git a/src/MPI/msgbuffer.pxi b/src/MPI/msgbuffer.pxi index c874c47..f9b8833 100644 --- a/src/MPI/msgbuffer.pxi +++ b/src/MPI/msgbuffer.pxi @@ -287,6 +287,56 @@ cdef _p_message message_vector(object msg, _type[0] = btype return m +cdef tuple message_vecw_I(object msg, + int readonly, + int blocks, + # + void **_addr, + int **_counts, + int **_displs, + MPI_Datatype **_types, + ): + cdef Py_ssize_t nargs = len(msg) + if nargs == 3: + o_buffer, (o_counts, o_displs), o_types = msg + elif nargs == 4: + o_buffer, o_counts, o_displs, o_types = msg + else: + raise ValueError("message: expecting 3 to 4 items") + if readonly: + o_buffer = getbuffer_r(o_buffer, _addr, NULL) + else: + o_buffer = getbuffer_w(o_buffer, _addr, NULL) + o_counts = asarray_int(o_counts, blocks, _counts) + o_displs = asarray_int(o_displs, blocks, _displs) + o_types = asarray_Datatype(o_types, blocks, _types) + return (o_buffer, o_counts, o_displs, o_types) + +cdef tuple message_vecw_A(object msg, + int readonly, + int blocks, + # + void **_addr, + int **_counts, + MPI_Aint **_displs, + MPI_Datatype **_types, + ): + cdef Py_ssize_t nargs = len(msg) + if nargs == 3: + o_buffer, (o_counts, o_displs), o_types = msg + elif nargs == 4: + o_buffer, o_counts, o_displs, o_types = msg + else: + raise ValueError("message: expecting 3 to 4 items") + if readonly: + o_buffer = getbuffer_r(o_buffer, _addr, NULL) + else: + o_buffer = getbuffer_w(o_buffer, _addr, NULL) + o_counts = asarray_int(o_counts, blocks, _counts) + o_displs = asarray_Aint(o_displs, blocks, _displs) + o_types = asarray_Datatype(o_types, blocks, _types) + return (o_buffer, o_counts, o_displs, o_types) + #------------------------------------------------------------------------------ #@cython.final @@ -403,6 +453,7 @@ cdef class _p_msg_cco: sending = 1 else: self.for_cco_recv(0, msg, root, 0) + sending = 0 else: # inter-communication if ((root == MPI_ROOT) or (root == MPI_PROC_NULL)): @@ -410,6 +461,7 @@ cdef class _p_msg_cco: sending = 1 else: self.for_cco_recv(0, msg, root, 0) + sending = 0 if sending: self.rbuf = self.sbuf self.rcount = self.scount @@ -756,6 +808,79 @@ cdef inline _p_msg_cco message_cco(): #------------------------------------------------------------------------------ +#@cython.final +#@cython.internal +cdef class _p_msg_ccow: + + # raw C-side arguments + cdef void *sbuf, *rbuf + cdef int *scounts, *rcounts + cdef int *sdispls, *rdispls + cdef MPI_Aint *sdisplsA, *rdisplsA + cdef MPI_Datatype *stypes, *rtypes + # python-side arguments + cdef object _smsg, _rmsg + + def __cinit__(self): + self.sbuf = self.rbuf = NULL + self.scounts = self.rcounts = NULL + self.sdispls = self.rdispls = NULL + self.sdisplsA = self.rdisplsA = NULL + self.stypes = self.rtypes = NULL + + # alltoallw + cdef int for_alltoallw(self, + object smsg, object rmsg, + MPI_Comm comm) except -1: + if comm == MPI_COMM_NULL: return 0 + cdef int inter=0, size=0 + CHKERR( MPI_Comm_test_inter(comm, &inter) ) + if not inter: # intra-communication + CHKERR( MPI_Comm_size(comm, &size) ) + else: # inter-communication + CHKERR( MPI_Comm_remote_size(comm, &size) ) + # + self._rmsg = message_vecw_I( + rmsg, 0, size, + &self.rbuf, &self.rcounts, + &self.rdispls, &self.rtypes) + if not inter and smsg is __IN_PLACE__: + self.sbuf = MPI_IN_PLACE + self.scount = self.rcount + self.scounts = self.rcounts + self.sdispls = self.rdispls + self.stypes = self.rtypes + return 0 + self._smsg = message_vecw_I( + smsg, 1, size, + &self.sbuf, &self.scounts, + &self.sdispls, &self.stypes) + return 0 + + # neighbor alltoallw + cdef int for_neighbor_alltoallw(self, + object smsg, object rmsg, + MPI_Comm comm) except -1: + if comm == MPI_COMM_NULL: return 0 + cdef int sendsize=0, recvsize=0 + comm_neighbors_count(comm, &recvsize, &sendsize) + self._rmsg = message_vecw_A( + rmsg, 0, recvsize, + &self.rbuf, &self.rcounts, + &self.rdisplsA, &self.rtypes) + self._smsg = message_vecw_A( + smsg, 1, sendsize, + &self.sbuf, &self.scounts, + &self.sdisplsA, &self.stypes) + return 0 + + +cdef inline _p_msg_ccow message_ccow(): + cdef _p_msg_ccow msg = <_p_msg_ccow>_p_msg_ccow.__new__(_p_msg_ccow) + return msg + +#------------------------------------------------------------------------------ + #@cython.final #@cython.internal cdef class _p_msg_rma: diff --git a/test/test_cco_buf.py b/test/test_cco_buf.py index b1ae2c4..c154306 100644 --- a/test/test_cco_buf.py +++ b/test/test_cco_buf.py @@ -99,6 +99,26 @@ def testAlltoall(self): for value in rbuf.flat: self.assertEqual(value, root) + def testAlltoallw(self): + size = self.COMM.Get_size() + rank = self.COMM.Get_rank() + for array in arrayimpl.ArrayTypes: + for typecode in arrayimpl.TypeMap: + for n in range(1,size+1): + sbuf = array( n, typecode, (size, n)) + rbuf = array(-1, typecode, (size, n)) + sdt, rdt = sbuf.mpidtype, rbuf.mpidtype + sdsp = list(range(0, size*n*sdt.extent, n*sdt.extent)) + rdsp = list(range(0, size*n*rdt.extent, n*rdt.extent)) + smsg = (sbuf.as_raw(), ([n]*size, sdsp), [sdt]*size) + rmsg = (rbuf.as_raw(), ([n]*size, rdsp), [rdt]*size) + try: + self.COMM.Alltoallw(smsg, rmsg) + except NotImplementedError: + return + for value in rbuf.flat: + self.assertEqual(value, n) + def assertAlmostEqual(self, first, second): num = float(float(second-first)) den = float(second+first)/2 or 1.0 diff --git a/test/test_cco_nb_vec.py b/test/test_cco_nb_vec.py index e43a4b8..ddca2b0 100644 --- a/test/test_cco_nb_vec.py +++ b/test/test_cco_nb_vec.py @@ -285,6 +285,26 @@ def testAlltoallv3(self): for v in rbuf: self.assertEqual(v, root) + def testAlltoallw(self): + size = self.COMM.Get_size() + rank = self.COMM.Get_rank() + for array in arrayimpl.ArrayTypes: + for typecode in arrayimpl.TypeMap: + for n in range(1, size+1): + sbuf = array( n, typecode, (size, n)) + rbuf = array(-1, typecode, (size, n)) + sdt, rdt = sbuf.mpidtype, rbuf.mpidtype + sdsp = list(range(0, size*n*sdt.extent, n*sdt.extent)) + rdsp = list(range(0, size*n*rdt.extent, n*rdt.extent)) + smsg = (sbuf.as_raw(), ([n]*size, sdsp), [sdt]*size) + rmsg = (rbuf.as_raw(), ([n]*size, rdsp), [rdt]*size) + try: + self.COMM.Ialltoallw(smsg, rmsg).Wait() + except NotImplementedError: + return + for v in rbuf.flat: + self.assertEqual(v, n) + class TestCCOVecSelf(BaseTestCCOVec, unittest.TestCase): COMM = MPI.COMM_SELF