Skip to content

Commit

Permalink
Add generalized all-to-all collectives
Browse files Browse the repository at this point in the history
  • Loading branch information
dalcinl committed May 21, 2013
1 parent 6fa6075 commit 0339f42
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 34 deletions.
52 changes: 18 additions & 34 deletions src/MPI/Comm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -638,15 +638,11 @@ cdef class Comm:
Generalized All-to-All communication allowing different
counts, displacements and datatypes for each partner
"""
sendbuf = recvbuf = None
raise NotImplementedError # XXX implement!
cdef void *sbuf = NULL, *rbuf = NULL
cdef int *scounts = NULL, *rcounts = NULL
cdef int *sdispls = NULL, *rdispls = NULL
cdef MPI_Datatype *stypes = NULL, *rtypes = NULL
cdef _p_msg_ccow m = message_ccow()
m.for_alltoallw(sendbuf, recvbuf, self.ob_mpi)
with nogil: CHKERR( MPI_Alltoallw(
sbuf, scounts, sdispls, stypes,
rbuf, rcounts, rdispls, rtypes,
m.sbuf, m.scounts, m.sdispls, m.stypes,
m.rbuf, m.rcounts, m.rdispls, m.rtypes,
self.ob_mpi) )


Expand Down Expand Up @@ -835,18 +831,14 @@ cdef class Comm:
"""
Nonblocking Generalized All-to-All
"""
sendbuf = recvbuf = None
raise NotImplementedError # XXX implement!
cdef void *sbuf = NULL, *rbuf = NULL
cdef int *scounts = NULL, *rcounts = NULL
cdef int *sdispls = NULL, *rdispls = NULL
cdef MPI_Datatype *stypes = NULL, *rtypes = NULL
cdef _p_msg_ccow m = message_ccow()
m.for_alltoallw(sendbuf, recvbuf, self.ob_mpi)
cdef Request request = <Request>Request.__new__(Request)
with nogil: CHKERR( MPI_Ialltoallw(
sbuf, scounts, sdispls, stypes,
rbuf, rcounts, rdispls, rtypes,
m.sbuf, m.scounts, m.sdispls, m.stypes,
m.rbuf, m.rcounts, m.rdispls, m.rtypes,
self.ob_mpi, &request.ob_mpi) )
request.ob_buf = None
request.ob_buf = m
return request

def Ireduce(self, sendbuf, recvbuf, Op op not None=SUM, int root=0):
Expand Down Expand Up @@ -1481,15 +1473,11 @@ cdef class Intracomm(Comm):
"""
Neighbor All-to-All Generalized
"""
sendbuf = recvbuf = None
raise NotImplementedError # XXX implement!
cdef void *sbuf = NULL, *rbuf = NULL
cdef int *scounts = NULL, *rcounts = NULL
cdef int *sdispls = NULL, *rdispls = NULL
cdef MPI_Datatype *stypes = NULL, *rtypes = NULL
cdef _p_msg_ccow m = message_ccow()
m.for_neighbor_alltoallw(sendbuf, recvbuf, self.ob_mpi)
with nogil: CHKERR( MPI_Neighbor_alltoallw(
sbuf, scounts, sdispls, stypes,
rbuf, rcounts, rdispls, rtypes,
m.sbuf, m.scounts, m.sdisplsA, m.stypes,
m.rbuf, m.rcounts, m.rdisplsA, m.rtypes,
self.ob_mpi) )

# Nonblocking Neighborhood Collectives
Expand Down Expand Up @@ -1554,18 +1542,14 @@ cdef class Intracomm(Comm):
"""
Nonblocking Neighbor All-to-All Generalized
"""
sendbuf = recvbuf = None
raise NotImplementedError # XXX implement!
cdef void *sbuf = NULL, *rbuf = NULL
cdef int *scounts = NULL, *rcounts = NULL
cdef int *sdispls = NULL, *rdispls = NULL
cdef MPI_Datatype *stypes = NULL, *rtypes = NULL
cdef _p_msg_ccow m = message_ccow()
m.for_neighbor_alltoallw(sendbuf, recvbuf, self.ob_mpi)
cdef Request request = <Request>Request.__new__(Request)
with nogil: CHKERR( MPI_Ineighbor_alltoallw(
sbuf, scounts, sdispls, stypes,
rbuf, rcounts, rdispls, rtypes,
m.sbuf, m.scounts, m.sdisplsA, m.stypes,
m.rbuf, m.rcounts, m.rdisplsA, m.rtypes,
self.ob_mpi, &request.ob_mpi) )
request.ob_buf = None
request.ob_buf = m
return request

# Python Communication
Expand Down
125 changes: 125 additions & 0 deletions src/MPI/msgbuffer.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,56 @@ cdef _p_message message_vector(object msg,
_type[0] = btype
return m

cdef tuple message_vecw_I(object msg,
                          int readonly,
                          int blocks,
                          # output arguments
                          void **_addr,
                          int **_counts,
                          int **_displs,
                          MPI_Datatype **_types,
                          ):
    # Unpack a generalized ("w") vector message whose displacements are
    # plain C ints.
    #
    # msg is a sequence of either
    #   3 items: (buffer, (counts, displs), types)
    #   4 items: (buffer, counts, displs, types)
    # Any other length raises ValueError.
    #
    # The buffer is obtained through getbuffer_r when `readonly` is true
    # and through getbuffer_w otherwise; counts, displs and types are
    # converted with `blocks` as the size argument.  The raw results are
    # written through the _addr/_counts/_displs/_types output pointers.
    #
    # Returns the tuple (buffer, counts, displs, types) of the objects
    # handed back by those helpers -- presumably so the caller can keep
    # them alive while the raw pointers are in use (confirm against the
    # helper definitions).
    cdef Py_ssize_t nitems = len(msg)
    if nitems != 3 and nitems != 4:
        raise ValueError("message: expecting 3 to 4 items")
    if nitems == 3:
        # counts and displacements arrive bundled as a pair
        buf, (counts, displs), types = msg
    else:
        buf, counts, displs, types = msg
    if not readonly:
        buf = getbuffer_w(buf, _addr, NULL)
    else:
        buf = getbuffer_r(buf, _addr, NULL)
    counts = asarray_int(counts, blocks, _counts)
    displs = asarray_int(displs, blocks, _displs)
    types = asarray_Datatype(types, blocks, _types)
    return (buf, counts, displs, types)

cdef tuple message_vecw_A(object msg,
                          int readonly,
                          int blocks,
                          # output arguments
                          void **_addr,
                          int **_counts,
                          MPI_Aint **_displs,
                          MPI_Datatype **_types,
                          ):
    # Unpack a generalized ("w") vector message whose displacements are
    # MPI_Aint values (converted with asarray_Aint instead of
    # asarray_int).
    #
    # msg is a sequence of either
    #   3 items: (buffer, (counts, displs), types)
    #   4 items: (buffer, counts, displs, types)
    # Any other length raises ValueError.
    #
    # The buffer is obtained through getbuffer_r when `readonly` is true
    # and through getbuffer_w otherwise; counts, displs and types are
    # converted with `blocks` as the size argument.  The raw results are
    # written through the _addr/_counts/_displs/_types output pointers.
    #
    # Returns the tuple (buffer, counts, displs, types) of the objects
    # handed back by those helpers -- presumably so the caller can keep
    # them alive while the raw pointers are in use (confirm against the
    # helper definitions).
    cdef Py_ssize_t nitems = len(msg)
    if nitems != 3 and nitems != 4:
        raise ValueError("message: expecting 3 to 4 items")
    if nitems == 3:
        # counts and displacements arrive bundled as a pair
        buf, (counts, displs), types = msg
    else:
        buf, counts, displs, types = msg
    if not readonly:
        buf = getbuffer_w(buf, _addr, NULL)
    else:
        buf = getbuffer_r(buf, _addr, NULL)
    counts = asarray_int(counts, blocks, _counts)
    displs = asarray_Aint(displs, blocks, _displs)
    types = asarray_Datatype(types, blocks, _types)
    return (buf, counts, displs, types)

#------------------------------------------------------------------------------

#@cython.final
Expand Down Expand Up @@ -403,13 +453,15 @@ cdef class _p_msg_cco:
sending = 1
else:
self.for_cco_recv(0, msg, root, 0)
sending = 0
else: # inter-communication
if ((root == <int>MPI_ROOT) or
(root == <int>MPI_PROC_NULL)):
self.for_cco_send(0, msg, root, 0)
sending = 1
else:
self.for_cco_recv(0, msg, root, 0)
sending = 0
if sending:
self.rbuf = self.sbuf
self.rcount = self.scount
Expand Down Expand Up @@ -756,6 +808,79 @@ cdef inline _p_msg_cco message_cco():

#------------------------------------------------------------------------------

#@cython.final
#@cython.internal
cdef class _p_msg_ccow:

    # Argument holder for the generalized all-to-all ("w") collectives.
    #
    # A for_*() method parses the Python-level send/receive messages and
    # fills in the raw C arguments below, which the caller then passes
    # to the corresponding MPI routine.

    # raw C-side arguments
    cdef void *sbuf, *rbuf
    cdef int *scounts, *rcounts
    # int displacements, filled by for_alltoallw()
    cdef int *sdispls, *rdispls
    # MPI_Aint displacements, filled by for_neighbor_alltoallw()
    cdef MPI_Aint *sdisplsA, *rdisplsA
    cdef MPI_Datatype *stypes, *rtypes
    # python-side arguments: the tuples returned by message_vecw_I/_A
    # (presumably retained so the objects backing the raw pointers stay
    # alive for as long as this holder does -- confirm against helpers)
    cdef object _smsg, _rmsg

    def __cinit__(self):
        self.sbuf = self.rbuf = NULL
        self.scounts = self.rcounts = NULL
        self.sdispls = self.rdispls = NULL
        self.sdisplsA = self.rdisplsA = NULL
        self.stypes = self.rtypes = NULL

    # alltoallw
    cdef int for_alltoallw(self,
                           object smsg, object rmsg,
                           MPI_Comm comm) except -1:
        # Fill the send/receive arguments for an all-to-all-w call on
        # `comm`.  Returns 0 on success; -1 signals a raised exception.
        # Nothing is parsed when `comm` is MPI_COMM_NULL.
        if comm == MPI_COMM_NULL: return 0
        cdef int inter=0, size=0
        CHKERR( MPI_Comm_test_inter(comm, &inter) )
        # number of blocks: local group size on an intracommunicator,
        # remote group size on an intercommunicator
        if not inter: # intra-communication
            CHKERR( MPI_Comm_size(comm, &size) )
        else: # inter-communication
            CHKERR( MPI_Comm_remote_size(comm, &size) )
        #
        self._rmsg = message_vecw_I(
            rmsg, 0, size,
            &self.rbuf, &self.rcounts,
            &self.rdispls, &self.rtypes)
        if not inter and smsg is __IN_PLACE__:
            # In-place operation (intracommunicators only): no send
            # message is parsed; the send-side arrays simply alias the
            # receive-side ones.  Only the per-block (plural) fields
            # exist on this class, so there is no scalar count to copy.
            self.sbuf = MPI_IN_PLACE
            self.scounts = self.rcounts
            self.sdispls = self.rdispls
            self.stypes = self.rtypes
            return 0
        self._smsg = message_vecw_I(
            smsg, 1, size,
            &self.sbuf, &self.scounts,
            &self.sdispls, &self.stypes)
        return 0

    # neighbor alltoallw
    cdef int for_neighbor_alltoallw(self,
                                    object smsg, object rmsg,
                                    MPI_Comm comm) except -1:
        # Fill the send/receive arguments for a neighborhood
        # all-to-all-w call on `comm`; displacements are MPI_Aint here.
        # Returns 0 on success; -1 signals a raised exception.
        # Nothing is parsed when `comm` is MPI_COMM_NULL.
        if comm == MPI_COMM_NULL: return 0
        cdef int sendsize=0, recvsize=0
        # note the argument order: receive count first, then send count
        comm_neighbors_count(comm, &recvsize, &sendsize)
        self._rmsg = message_vecw_A(
            rmsg, 0, recvsize,
            &self.rbuf, &self.rcounts,
            &self.rdisplsA, &self.rtypes)
        self._smsg = message_vecw_A(
            smsg, 1, sendsize,
            &self.sbuf, &self.scounts,
            &self.sdisplsA, &self.stypes)
        return 0


cdef inline _p_msg_ccow message_ccow():
    # Create a fresh _p_msg_ccow through __new__ and return it.
    return <_p_msg_ccow>_p_msg_ccow.__new__(_p_msg_ccow)

#------------------------------------------------------------------------------

#@cython.final
#@cython.internal
cdef class _p_msg_rma:
Expand Down
20 changes: 20 additions & 0 deletions test/test_cco_buf.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ def testAlltoall(self):
for value in rbuf.flat:
self.assertEqual(value, root)

def testAlltoallw(self):
    # Exchange n items with every rank through Alltoallw, for each array
    # implementation and typecode, and check that every received value
    # equals n.
    size = self.COMM.Get_size()
    for array in arrayimpl.ArrayTypes:
        for typecode in arrayimpl.TypeMap:
            for n in range(1, size+1):
                sbuf = array( n, typecode, (size, n))
                rbuf = array(-1, typecode, (size, n))
                sdt, rdt = sbuf.mpidtype, rbuf.mpidtype
                # displacements are built from the datatype extent:
                # block i starts at i * n * extent
                sstep = n*sdt.extent
                rstep = n*rdt.extent
                sdsp = list(range(0, size*sstep, sstep))
                rdsp = list(range(0, size*rstep, rstep))
                counts = [n]*size
                smsg = (sbuf.as_raw(), (counts, sdsp), [sdt]*size)
                rmsg = (rbuf.as_raw(), (counts, rdsp), [rdt]*size)
                try:
                    self.COMM.Alltoallw(smsg, rmsg)
                except NotImplementedError:
                    # operation not available: nothing left to check
                    return
                for value in rbuf.flat:
                    self.assertEqual(value, n)

def assertAlmostEqual(self, first, second):
num = float(float(second-first))
den = float(second+first)/2 or 1.0
Expand Down
20 changes: 20 additions & 0 deletions test/test_cco_nb_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,26 @@ def testAlltoallv3(self):
for v in rbuf:
self.assertEqual(v, root)

def testAlltoallw(self):
    # Exchange n items with every rank through the nonblocking
    # Ialltoallw (waited on immediately), for each array implementation
    # and typecode, and check that every received value equals n.
    size = self.COMM.Get_size()
    for array in arrayimpl.ArrayTypes:
        for typecode in arrayimpl.TypeMap:
            for n in range(1, size+1):
                sbuf = array( n, typecode, (size, n))
                rbuf = array(-1, typecode, (size, n))
                sdt, rdt = sbuf.mpidtype, rbuf.mpidtype
                # displacements are built from the datatype extent:
                # block i starts at i * n * extent
                sstep = n*sdt.extent
                rstep = n*rdt.extent
                sdsp = list(range(0, size*sstep, sstep))
                rdsp = list(range(0, size*rstep, rstep))
                counts = [n]*size
                smsg = (sbuf.as_raw(), (counts, sdsp), [sdt]*size)
                rmsg = (rbuf.as_raw(), (counts, rdsp), [rdt]*size)
                try:
                    self.COMM.Ialltoallw(smsg, rmsg).Wait()
                except NotImplementedError:
                    # operation not available: nothing left to check
                    return
                for v in rbuf.flat:
                    self.assertEqual(v, n)


class TestCCOVecSelf(BaseTestCCOVec, unittest.TestCase):
COMM = MPI.COMM_SELF
Expand Down

0 comments on commit 0339f42

Please sign in to comment.