diff --git a/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c b/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c index e5626040f79..d979935d560 100644 --- a/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c +++ b/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c @@ -2167,9 +2167,10 @@ static void uct_ib_mlx5dv_check_dm_ksm_reg(uct_ib_mlx5_md_t *md) #endif } -ucs_status_t uct_ib_mlx5_devx_md_open(struct ibv_device *ibv_device, - const uct_ib_md_config_t *md_config, - uct_ib_md_t **p_md) +ucs_status_t uct_ib_mlx5_devx_md_open_common(const char *name, size_t size, + struct ibv_device *ibv_device, + const uct_ib_md_config_t *md_config, + uct_ib_md_t **p_md) { uint8_t lag_state = 0; size_t out_len = UCT_IB_MLX5DV_ST_SZ_BYTES(query_hca_cap_out); @@ -2217,8 +2218,7 @@ ucs_status_t uct_ib_mlx5_devx_md_open(struct ibv_device *ibv_device, goto err_free_context; } - md = ucs_derived_of(uct_ib_md_alloc(sizeof(*md), "ib_mlx5_devx_md", ctx), - uct_ib_mlx5_md_t); + md = ucs_derived_of(uct_ib_md_alloc(size, name, ctx), uct_ib_mlx5_md_t); if (md == NULL) { status = UCS_ERR_NO_MEMORY; goto err_free_context; @@ -3099,6 +3099,16 @@ uct_ib_mlx5_devx_md_query(uct_md_h uct_md, uct_md_attr_v2_t *md_attr) return UCS_OK; } +static ucs_status_t +uct_ib_mlx5_devx_md_open(struct ibv_device *ibv_device, + const uct_ib_md_config_t *md_config, + uct_ib_md_t **p_md) +{ + return uct_ib_mlx5_devx_md_open_common("ib_mlx5_devx_md", + sizeof(uct_ib_mlx5_md_t), + ibv_device, md_config, p_md); +} + static uct_ib_md_ops_t uct_ib_mlx5_devx_md_ops = { .super = { .close = uct_ib_mlx5_devx_md_close, diff --git a/src/uct/ib/mlx5/gga/gga_mlx5.c b/src/uct/ib/mlx5/gga/gga_mlx5.c index f120ac728c1..28a59054a31 100644 --- a/src/uct/ib/mlx5/gga/gga_mlx5.c +++ b/src/uct/ib/mlx5/gga/gga_mlx5.c @@ -23,10 +23,15 @@ UCT_IB_MLX5_MD_FLAG_MMO_DMA) #endif /* ENABLE_ASSERT */ +typedef struct { + uct_ib_mlx5_md_t super; + pthread_mutex_t mem_attach_lock; +} uct_gga_mlx5_md_t; + typedef struct { uct_ib_md_packed_mkey_t packed_mkey; uct_ib_mlx5_devx_mem_t *memh; - uct_ib_mlx5_md_t *md; + 
uct_gga_mlx5_md_t *md; uct_rkey_bundle_t rkey_ob; } uct_gga_mlx5_rkey_handle_t; @@ -93,6 +98,21 @@ uct_ib_mlx5_gga_md_query(uct_md_h uct_md, uct_md_attr_v2_t *md_attr) return UCS_OK; } +static ucs_status_t +uct_ib_mlx5_gga_mem_attach(uct_md_h uct_md, const void *mkey_buffer, + uct_md_mem_attach_params_t *params, + uct_mem_h *memh_p) +{ + uct_gga_mlx5_md_t *gga_md = ucs_derived_of(uct_md, uct_gga_mlx5_md_t); + ucs_status_t status; + + pthread_mutex_lock(&gga_md->mem_attach_lock); + status = uct_ib_mlx5_devx_mem_attach(uct_md, mkey_buffer, params, memh_p); + pthread_mutex_unlock(&gga_md->mem_attach_lock); + return status; +} + + static ucs_status_t uct_ib_mlx5_gga_mkey_pack(uct_md_h uct_md, uct_mem_h uct_memh, void *address, size_t length, @@ -151,7 +171,8 @@ uct_gga_mlx5_rkey_handle_dereg(uct_gga_mlx5_rkey_handle_t *rkey_handle) return; } - status = uct_ib_mlx5_devx_mem_dereg(&rkey_handle->md->super.super, &params); + status = uct_ib_mlx5_devx_mem_dereg(&rkey_handle->md->super.super.super, + &params); if (status != UCS_OK) { ucs_warn("md %p: failed to deregister GGA memh %p", rkey_handle->md, rkey_handle->memh); @@ -247,10 +268,10 @@ static uct_component_t uct_gga_component = { .md_vfs_init = (uct_component_md_vfs_init_func_t)ucs_empty_function }; -static UCS_F_ALWAYS_INLINE -void uct_gga_mlx5_rkey_trace(uct_ib_mlx5_md_t *md, - uct_gga_mlx5_rkey_handle_t *rkey_handle, - const char *prefix) +static UCS_F_ALWAYS_INLINE void + uct_gga_mlx5_rkey_trace(uct_gga_mlx5_md_t *md, + uct_gga_mlx5_rkey_handle_t *rkey_handle, + const char *prefix) { ucs_trace("md %p: %s resolved rkey %p: rkey_ob %"PRIx64"/%p", md, prefix, rkey_handle, rkey_handle->rkey_ob.rkey, @@ -258,13 +279,12 @@ void uct_gga_mlx5_rkey_trace(uct_ib_mlx5_md_t *md, } static UCS_F_ALWAYS_INLINE ucs_status_t -uct_gga_mlx5_rkey_resolve(uct_ib_mlx5_md_t *md, +uct_gga_mlx5_rkey_resolve(uct_gga_mlx5_md_t *md, uct_gga_mlx5_rkey_handle_t *rkey_handle) { - static pthread_mutex_t mem_attach_lock = PTHREAD_MUTEX_INITIALIZER; - 
uct_md_h uct_md = &md->super.super; - uct_md_mem_attach_params_t atach_params = { 0 }; - uct_md_mkey_pack_params_t repack_params = { 0 }; + uct_md_h uct_md = &md->super.super.super; + uct_md_mem_attach_params_t attach_params = { 0 }; + uct_md_mkey_pack_params_t repack_params = { 0 }; uint64_t repack_mkey; ucs_status_t status; @@ -273,14 +293,9 @@ uct_gga_mlx5_rkey_resolve(uct_ib_mlx5_md_t *md, return UCS_OK; } - /* TODO: this is a temporary solution to protect - @ref uct_ib_mlx5_md_t::umr::mkey_hash, - it should be reworked in PR #10236 */ - pthread_mutex_lock(&mem_attach_lock); - status = uct_ib_mlx5_devx_mem_attach(uct_md, &rkey_handle->packed_mkey, - &atach_params, - (uct_mem_h *)&rkey_handle->memh); - pthread_mutex_unlock(&mem_attach_lock); + status = uct_ib_mlx5_gga_mem_attach(uct_md, &rkey_handle->packed_mkey, + &attach_params, + (uct_mem_h *)&rkey_handle->memh); if (status != UCS_OK) { goto err_out; } @@ -536,8 +551,8 @@ uct_gga_mlx5_ep_put_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt, { UCT_RC_MLX5_BASE_EP_DECL(tl_ep, iface, ep); uct_gga_mlx5_ep_t *gga_ep = ucs_derived_of(ep, uct_gga_mlx5_ep_t); - uct_ib_mlx5_md_t *md = ucs_derived_of(iface->super.super.super.md, - uct_ib_mlx5_md_t); + uct_gga_mlx5_md_t *md = ucs_derived_of(iface->super.super.super.md, + uct_gga_mlx5_md_t); uct_gga_mlx5_rkey_handle_t *rkey_handle = (uct_gga_mlx5_rkey_handle_t*)rkey; uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; uct_rkey_t rkey_copy; @@ -584,8 +599,8 @@ uct_gga_mlx5_ep_get_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt, { UCT_RC_MLX5_BASE_EP_DECL(tl_ep, iface, ep); uct_gga_mlx5_ep_t *gga_ep = ucs_derived_of(ep, uct_gga_mlx5_ep_t); - uct_ib_mlx5_md_t *md = ucs_derived_of(iface->super.super.super.md, - uct_ib_mlx5_md_t); + uct_gga_mlx5_md_t *md = ucs_derived_of(iface->super.super.super.md, + uct_gga_mlx5_md_t); uct_gga_mlx5_rkey_handle_t *rkey_handle = (uct_gga_mlx5_rkey_handle_t*)rkey; size_t total_length = uct_iov_total_length(iov, iovcnt); uint8_t 
fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; @@ -818,15 +833,23 @@ UCT_TL_DEFINE_ENTRY(&uct_gga_component, gga_mlx5, uct_gga_mlx5_query_tl_devices, UCT_SINGLE_TL_INIT(&uct_gga_component, gga_mlx5, ctor,,) +static void uct_ib_mlx5_gga_md_close(uct_md_h md) +{ + uct_gga_mlx5_md_t *gga_md = ucs_derived_of(md, uct_gga_mlx5_md_t); + + pthread_mutex_destroy(&gga_md->mem_attach_lock); + uct_ib_mlx5_devx_md_close(md); +} + /* TODO: separate memh since atomic_mr is not relevant for GGA */ static uct_md_ops_t uct_mlx5_gga_md_ops = { - .close = uct_ib_mlx5_devx_md_close, + .close = uct_ib_mlx5_gga_md_close, .query = uct_ib_mlx5_gga_md_query, .mem_alloc = uct_ib_mlx5_devx_device_mem_alloc, .mem_free = uct_ib_mlx5_devx_device_mem_free, .mem_reg = uct_ib_mlx5_devx_mem_reg, .mem_dereg = uct_ib_mlx5_devx_mem_dereg, - .mem_attach = uct_ib_mlx5_devx_mem_attach, + .mem_attach = uct_ib_mlx5_gga_mem_attach, .mem_advise = uct_ib_mem_advise, .mkey_pack = uct_ib_mlx5_gga_mkey_pack, .detect_memory_type = ucs_empty_function_return_unsupported, @@ -841,7 +864,8 @@ uct_ib_mlx5_gga_md_open(uct_component_t *component, const char *md_name, struct ibv_device **ib_device_list, *ib_device; int num_devices, fork_init; ucs_status_t status; - uct_ib_md_t *md; + uct_gga_mlx5_md_t *md; + int ret; ucs_trace("opening GGA device %s", md_name); @@ -867,19 +891,30 @@ uct_ib_mlx5_gga_md_open(uct_component_t *component, const char *md_name, goto out_free_dev_list; } - status = uct_ib_mlx5_devx_md_open(ib_device, md_config, &md); + status = uct_ib_mlx5_devx_md_open_common("gga_mlx5_md_t", + sizeof(uct_gga_mlx5_md_t), + ib_device, md_config, + (uct_ib_md_t**)&md); if (status != UCS_OK) { goto out_free_dev_list; } - md->super.component = &uct_gga_component; - md->super.ops = &uct_mlx5_gga_md_ops; - md->name = UCT_IB_MD_NAME(gga); - md->fork_init = fork_init; - *md_p = &md->super; + md->super.super.super.component = &uct_gga_component; + md->super.super.super.ops = &uct_mlx5_gga_md_ops; + md->super.super.name = 
UCT_IB_MD_NAME(gga); + md->super.super.fork_init = fork_init; + + ret = pthread_mutex_init(&md->mem_attach_lock, NULL); + if (ret != 0) { + ucs_error("pthread_mutex_init failed with error: %s", strerror(ret)); + status = UCS_ERR_IO_ERROR; + uct_ib_mlx5_devx_md_close(&md->super.super.super); + goto out_free_dev_list; + } + + *md_p = &md->super.super.super; out_free_dev_list: ibv_free_device_list(ib_device_list); -out: return status; } diff --git a/src/uct/ib/mlx5/ib_mlx5.h b/src/uct/ib/mlx5/ib_mlx5.h index 2e082ef5e50..79fef59cba1 100644 --- a/src/uct/ib/mlx5/ib_mlx5.h +++ b/src/uct/ib/mlx5/ib_mlx5.h @@ -1196,9 +1196,10 @@ uct_ib_mlx5_devx_mkey_pack(uct_md_h uct_md, uct_mem_h uct_memh, struct ibv_context* uct_ib_mlx5_devx_open_device(struct ibv_device *ibv_device); -ucs_status_t uct_ib_mlx5_devx_md_open(struct ibv_device *ibv_device, - const uct_ib_md_config_t *md_config, - uct_ib_md_t **p_md); +ucs_status_t uct_ib_mlx5_devx_md_open_common(const char* name, size_t size, + struct ibv_device *ibv_device, + const uct_ib_md_config_t *md_config, + uct_ib_md_t **p_md); ucs_status_t uct_ib_mlx5_devx_reg_exported_key(uct_ib_mlx5_md_t *md, uct_ib_mlx5_devx_mem_t *memh);