Skip to content

non-blocking collectives: retain MPI_op and MPI_Datatype(s) #2154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 187 additions & 2 deletions ompi/mca/coll/base/coll_base_util.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
* University of Stuttgart. All rights reserved.
* Copyright (c) 2004-2005 The Regents of the University of California.
* All rights reserved.
* Copyright (c) 2014-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2014-2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand All @@ -26,6 +26,7 @@
#include "ompi/communicator/communicator.h"
#include "ompi/mca/coll/base/coll_tags.h"
#include "ompi/mca/coll/base/coll_base_functions.h"
#include "ompi/mca/topo/base/base.h"
#include "ompi/mca/pml/pml.h"
#include "coll_base_util.h"

Expand Down Expand Up @@ -103,3 +104,187 @@ int ompi_rounddown(int num, int factor)
num /= factor;
return num * factor; /* floor(num / factor) * factor */
}

static void release_objs_callback(struct ompi_coll_base_nbc_request_t *request) {
if (NULL != request->data.objs.objs[0]) {
OBJ_RELEASE(request->data.objs.objs[0]);
}
if (NULL != request->data.objs.objs[1]) {
OBJ_RELEASE(request->data.objs.objs[1]);
}
}

static int complete_objs_callback(struct ompi_request_t *req) {
struct ompi_coll_base_nbc_request_t *request = (ompi_coll_base_nbc_request_t *)req;
int rc = OMPI_SUCCESS;
assert (NULL != request);
if (NULL != request->cb.req_complete_cb) {
rc = request->cb.req_complete_cb(request->req_complete_cb_data);
}
release_objs_callback(request);
return rc;
}

static int free_objs_callback(struct ompi_request_t **rptr) {
struct ompi_coll_base_nbc_request_t *request = *(ompi_coll_base_nbc_request_t **)rptr;
int rc = OMPI_SUCCESS;
if (NULL != request->cb.req_free) {
rc = request->cb.req_free(rptr);
}
release_objs_callback(request);
return rc;
}

int ompi_coll_base_retain_op( ompi_request_t *req, ompi_op_t *op,
ompi_datatype_t *type) {
ompi_coll_base_nbc_request_t *request = (ompi_coll_base_nbc_request_t *)req;
bool retain = false;
if (!ompi_op_is_intrinsic(op)) {
OBJ_RETAIN(op);
request->data.op.op = op;
retain = true;
}
if (!ompi_datatype_is_predefined(type)) {
OBJ_RETAIN(type);
request->data.op.datatype = type;
retain = true;
}
if (OPAL_UNLIKELY(retain)) {
/* We need to consider two cases :
* - non blocking collectives:
* the objects can be released when MPI_Wait() completes
* and we use the req_complete_cb callback
* - persistent non blocking collectives:
* the objects can only be released when the request is freed
* (e.g. MPI_Request_free() completes) and we use req_free callback
*/
if (req->req_persistent) {
request->cb.req_free = req->req_free;
req->req_free = free_objs_callback;
} else {
request->cb.req_complete_cb = req->req_complete_cb;
request->req_complete_cb_data = req->req_complete_cb_data;
req->req_complete_cb = complete_objs_callback;
req->req_complete_cb_data = request;
}
}
return OMPI_SUCCESS;
}

int ompi_coll_base_retain_datatypes( ompi_request_t *req, ompi_datatype_t *stype,
ompi_datatype_t *rtype) {
ompi_coll_base_nbc_request_t *request = (ompi_coll_base_nbc_request_t *)req;
bool retain = false;
if (NULL != stype && !ompi_datatype_is_predefined(stype)) {
OBJ_RETAIN(stype);
request->data.types.stype = stype;
retain = true;
}
if (NULL != rtype && !ompi_datatype_is_predefined(rtype)) {
OBJ_RETAIN(rtype);
request->data.types.rtype = rtype;
retain = true;
}
if (OPAL_UNLIKELY(retain)) {
if (req->req_persistent) {
request->cb.req_free = req->req_free;
req->req_free = free_objs_callback;
} else {
request->cb.req_complete_cb = req->req_complete_cb;
request->req_complete_cb_data = req->req_complete_cb_data;
req->req_complete_cb = complete_objs_callback;
req->req_complete_cb_data = request;
}
}
return OMPI_SUCCESS;
}

static void release_vecs_callback(ompi_coll_base_nbc_request_t *request) {
ompi_communicator_t *comm = request->super.req_mpi_object.comm;
int scount, rcount;
if (OMPI_COMM_IS_TOPO(comm)) {
(void)mca_topo_base_neighbor_count (comm, &rcount, &scount);
} else {
scount = rcount = OMPI_COMM_IS_INTER(comm)?ompi_comm_remote_size(comm):ompi_comm_size(comm);
}
for (int i=0; i<scount; i++) {
if (NULL != request->data.vecs.stypes && NULL != request->data.vecs.stypes[i]) {
OMPI_DATATYPE_RELEASE(request->data.vecs.stypes[i]);
}
}
for (int i=0; i<rcount; i++) {
if (NULL != request->data.vecs.rtypes && NULL != request->data.vecs.rtypes[i]) {
OMPI_DATATYPE_RELEASE(request->data.vecs.rtypes[i]);
}
}
}

static int complete_vecs_callback(struct ompi_request_t *req) {
ompi_coll_base_nbc_request_t *request = (ompi_coll_base_nbc_request_t *)req;
int rc = OMPI_SUCCESS;
assert (NULL != request);
if (NULL != request->cb.req_complete_cb) {
rc = request->cb.req_complete_cb(request->req_complete_cb_data);
}
release_vecs_callback(request);
return rc;
}

static int free_vecs_callback(struct ompi_request_t **rptr) {
struct ompi_coll_base_nbc_request_t *request = *(ompi_coll_base_nbc_request_t **)rptr;
int rc = OMPI_SUCCESS;
if (NULL != request->cb.req_free) {
rc = request->cb.req_free(rptr);
}
release_vecs_callback(request);
return rc;
}

int ompi_coll_base_retain_datatypes_w( ompi_request_t *req,
ompi_datatype_t *stypes[], ompi_datatype_t *rtypes[]) {
ompi_coll_base_nbc_request_t *request = (ompi_coll_base_nbc_request_t *)req;
bool retain = false;
ompi_communicator_t *comm = request->super.req_mpi_object.comm;
int scount, rcount;
if (OMPI_COMM_IS_TOPO(comm)) {
(void)mca_topo_base_neighbor_count (comm, &rcount, &scount);
} else {
scount = rcount = OMPI_COMM_IS_INTER(comm)?ompi_comm_remote_size(comm):ompi_comm_size(comm);
}

for (int i=0; i<scount; i++) {
if (NULL != stypes && NULL != stypes[i] && !ompi_datatype_is_predefined(stypes[i])) {
OBJ_RETAIN(stypes[i]);
retain = true;
}
}
for (int i=0; i<rcount; i++) {
if (NULL != rtypes && NULL != rtypes[i] && !ompi_datatype_is_predefined(rtypes[i])) {
OBJ_RETAIN(rtypes[i]);
retain = true;
}
}
if (OPAL_UNLIKELY(retain)) {
request->data.vecs.stypes = stypes;
request->data.vecs.rtypes = rtypes;
if (req->req_persistent) {
request->cb.req_free = req->req_free;
req->req_free = free_vecs_callback;
} else {
request->cb.req_complete_cb = req->req_complete_cb;
request->req_complete_cb_data = req->req_complete_cb_data;
req->req_complete_cb = complete_vecs_callback;
req->req_complete_cb_data = request;
}
}
return OMPI_SUCCESS;
}

static void nbc_req_cons(ompi_coll_base_nbc_request_t *req) {
req->cb.req_complete_cb = NULL;
req->req_complete_cb_data = NULL;
req->data.objs.objs[0] = NULL;
req->data.objs.objs[1] = NULL;
}

OBJ_CLASS_INSTANCE(ompi_coll_base_nbc_request_t, ompi_request_t, nbc_req_cons, NULL);
47 changes: 45 additions & 2 deletions ompi/mca/coll/base/coll_base_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
* University of Stuttgart. All rights reserved.
* Copyright (c) 2004-2005 The Regents of the University of California.
* All rights reserved.
* Copyright (c) 2014-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2014-2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand All @@ -27,10 +27,41 @@
#include "ompi/mca/mca.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/request/request.h"
#include "ompi/op/op.h"
#include "ompi/mca/pml/pml.h"

BEGIN_C_DECLS

struct ompi_coll_base_nbc_request_t {
ompi_request_t super;
union {
ompi_request_complete_fn_t req_complete_cb;
ompi_request_free_fn_t req_free;
} cb;
void *req_complete_cb_data;
union {
struct {
ompi_op_t *op;
ompi_datatype_t *datatype;
} op;
struct {
ompi_datatype_t *stype;
ompi_datatype_t *rtype;
} types;
struct {
opal_object_t *objs[2];
} objs;
struct {
ompi_datatype_t **stypes;
ompi_datatype_t **rtypes;
} vecs;
} data;
};

OMPI_DECLSPEC OBJ_CLASS_DECLARATION(ompi_coll_base_nbc_request_t);

typedef struct ompi_coll_base_nbc_request_t ompi_coll_base_nbc_request_t;

/**
* A MPI_like function doing a send and a receive simultaneously.
* If one of the communications results in a zero-byte message the
Expand Down Expand Up @@ -84,5 +115,17 @@ unsigned int ompi_mirror_perm(unsigned int x, int nbits);
*/
int ompi_rounddown(int num, int factor);

int ompi_coll_base_retain_op( ompi_request_t *request,
ompi_op_t *op,
ompi_datatype_t *type);

int ompi_coll_base_retain_datatypes( ompi_request_t *request,
ompi_datatype_t *stype,
ompi_datatype_t *rtype);

int ompi_coll_base_retain_datatypes_w( ompi_request_t *request,
ompi_datatype_t *stypes[],
ompi_datatype_t *rtypes[]);

END_C_DECLS
#endif /* MCA_COLL_BASE_UTIL_EXPORT_H */
14 changes: 7 additions & 7 deletions ompi/mca/coll/libnbc/coll_libnbc.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
* Copyright (c) 2008 Cisco Systems, Inc. All rights reserved.
* Copyright (c) 2013-2015 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2014-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2014-2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2016-2017 IBM Corporation. All rights reserved.
* Copyright (c) 2018 FUJITSU LIMITED. All rights reserved.
* $COPYRIGHT$
Expand All @@ -28,7 +28,7 @@
#define MCA_COLL_LIBNBC_EXPORT_H

#include "ompi/mca/coll/coll.h"
#include "ompi/request/request.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "opal/sys/atomic.h"

BEGIN_C_DECLS
Expand Down Expand Up @@ -121,7 +121,7 @@ typedef struct NBC_Schedule NBC_Schedule;
OBJ_CLASS_DECLARATION(NBC_Schedule);

struct ompi_coll_libnbc_request_t {
ompi_request_t super;
ompi_coll_base_nbc_request_t super;
MPI_Comm comm;
long row_offset;
bool nbc_complete; /* status in libnbc level */
Expand All @@ -145,13 +145,13 @@ typedef ompi_coll_libnbc_request_t NBC_Handle;
opal_free_list_item_t *item; \
item = opal_free_list_wait (&mca_coll_libnbc_component.requests); \
req = (ompi_coll_libnbc_request_t*) item; \
OMPI_REQUEST_INIT(&req->super, persistent); \
req->super.req_mpi_object.comm = comm; \
OMPI_REQUEST_INIT(&req->super.super, persistent); \
req->super.super.req_mpi_object.comm = comm; \
} while (0)

#define OMPI_COLL_LIBNBC_REQUEST_RETURN(req) \
do { \
OMPI_REQUEST_FINI(&(req)->super); \
OMPI_REQUEST_FINI(&(req)->super.super); \
opal_free_list_return (&mca_coll_libnbc_component.requests, \
(opal_free_list_item_t*) (req)); \
} while (0)
Expand Down
30 changes: 15 additions & 15 deletions ompi/mca/coll/libnbc/coll_libnbc_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
* Copyright (c) 2008 Cisco Systems, Inc. All rights reserved.
* Copyright (c) 2013-2015 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2016-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2016-2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2016 IBM Corporation. All rights reserved.
* Copyright (c) 2017 Ian Bradley Morgan and Anthony Skjellum. All
* rights reserved.
Expand Down Expand Up @@ -448,21 +448,21 @@ ompi_coll_libnbc_progress(void)
/* done, remove and complete */
OPAL_THREAD_LOCK(&mca_coll_libnbc_component.lock);
opal_list_remove_item(&mca_coll_libnbc_component.active_requests,
&request->super.super.super);
&request->super.super.super.super);
OPAL_THREAD_UNLOCK(&mca_coll_libnbc_component.lock);

if( OMPI_SUCCESS == res || NBC_OK == res || NBC_SUCCESS == res ) {
request->super.req_status.MPI_ERROR = OMPI_SUCCESS;
request->super.super.req_status.MPI_ERROR = OMPI_SUCCESS;
}
else {
request->super.req_status.MPI_ERROR = res;
request->super.super.req_status.MPI_ERROR = res;
}
if(request->super.req_persistent) {
if(request->super.super.req_persistent) {
/* reset for the next communication */
request->row_offset = 0;
}
if(!request->super.req_persistent || !REQUEST_COMPLETE(&request->super)) {
ompi_request_complete(&request->super, true);
if(!request->super.super.req_persistent || !REQUEST_COMPLETE(&request->super.super)) {
ompi_request_complete(&request->super.super, true);
}
}
OPAL_THREAD_LOCK(&mca_coll_libnbc_component.lock);
Expand Down Expand Up @@ -527,7 +527,7 @@ request_start(size_t count, ompi_request_t ** requests)
NBC_DEBUG(5, "tmpbuf address=%p size=%u\n", handle->tmpbuf, sizeof(handle->tmpbuf));
NBC_DEBUG(5, "--------------------------------\n");

handle->super.req_complete = REQUEST_PENDING;
handle->super.super.req_complete = REQUEST_PENDING;
handle->nbc_complete = false;

res = NBC_Start(handle);
Expand Down Expand Up @@ -557,7 +557,7 @@ request_free(struct ompi_request_t **ompi_req)
ompi_coll_libnbc_request_t *request =
(ompi_coll_libnbc_request_t*) *ompi_req;

if( !REQUEST_COMPLETE(&request->super) ) {
if( !REQUEST_COMPLETE(&request->super.super) ) {
return MPI_ERR_REQUEST;
}

Expand All @@ -571,11 +571,11 @@ request_free(struct ompi_request_t **ompi_req)
static void
request_construct(ompi_coll_libnbc_request_t *request)
{
request->super.req_type = OMPI_REQUEST_COLL;
request->super.req_status._cancelled = 0;
request->super.req_start = request_start;
request->super.req_free = request_free;
request->super.req_cancel = request_cancel;
request->super.super.req_type = OMPI_REQUEST_COLL;
request->super.super.req_status._cancelled = 0;
request->super.super.req_start = request_start;
request->super.super.req_free = request_free;
request->super.super.req_cancel = request_cancel;
}


Expand Down
Loading