Skip to content

Commit

Permalink
Mranged params (openucx#681)
Browse files Browse the repository at this point in the history
* UTIL: mranged unsigned parameter

* TL/UCP: range_uint for sra radix

* BUILD: fix linter warns

* TL/UCP: range_uint for kn_radix and cleanup

* UTIL: fix additional auto case

* REVIEW: code review fixes and gtest mrange parse

Co-authored-by: Valentin Petrov <[email protected]>
  • Loading branch information
shimmybalsam and Valentin Petrov authored Dec 19, 2022
1 parent 8ac022e commit a7a641c
Show file tree
Hide file tree
Showing 18 changed files with 489 additions and 74 deletions.
1 change: 1 addition & 0 deletions src/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ libucc_la_SOURCES = \
coll_score/ucc_coll_score_map.c \
utils/ini.c \
utils/ucc_component.c \
utils/ucc_datastruct.c \
utils/ucc_status.c \
utils/ucc_mpool.c \
utils/ucc_math.c \
Expand Down
67 changes: 17 additions & 50 deletions src/coll_score/ucc_coll_score.c
Original file line number Diff line number Diff line change
Expand Up @@ -475,44 +475,6 @@ static ucc_status_t str_to_coll_type(const char *str, unsigned *ct_n,
return status;
}

/**
 * Parse a comma-separated list of memory type names into a newly allocated
 * array of ucc_memory_type_t values.
 *
 * @param [in]  str   Comma-separated list of memory type names.
 * @param [out] mt_n  Number of entries written to *mt. Reset to 0 when a
 *                    token is rejected.
 * @param [out] mt    On success, array obtained from ucc_malloc() that the
 *                    caller releases with ucc_free(). Set to NULL when the
 *                    allocation fails or a token is rejected.
 *
 * @return UCC_OK on success,
 *         UCC_ERR_INVALID_PARAM if @a str could not be split,
 *         UCC_ERR_NO_MEMORY if the output array could not be allocated,
 *         UCC_ERR_NOT_FOUND if a token does not name a memory type.
 */
static ucc_status_t str_to_mem_type(const char *str, unsigned *mt_n,
                                    ucc_memory_type_t **mt)
{
    ucc_status_t      status = UCC_OK;
    char            **tokens;
    unsigned          i, n_tokens;
    ucc_memory_type_t t;

    tokens = ucc_str_split(str, ",");
    if (!tokens) {
        status = UCC_ERR_INVALID_PARAM;
        goto out;
    }
    n_tokens = ucc_str_split_count(tokens);
    *mt = ucc_malloc(n_tokens * sizeof(ucc_memory_type_t), "ucc_mem_types");
    if (!(*mt)) {
        /* size_t argument: %zu is the matching conversion specifier */
        ucc_error("failed to allocate %zu bytes for ucc_mem_types",
                  sizeof(ucc_memory_type_t) * n_tokens);
        status = UCC_ERR_NO_MEMORY;
        goto out;
    }
    *mt_n = 0;
    for (i = 0; i < n_tokens; i++) {
        t = ucc_mem_type_from_str(tokens[i]);
        if (t == UCC_MEMORY_TYPE_LAST) {
            /* entry does not match any memory type name: drop the partial
               result so that both outputs describe an empty list */
            ucc_free(*mt);
            *mt     = NULL;
            *mt_n   = 0;
            status  = UCC_ERR_NOT_FOUND;
            goto out;
        }
        (*mt)[*mt_n] = t;
        (*mt_n)++;
    }
out:
    ucc_str_split_free(tokens);
    return status;
}

static ucc_status_t str_to_score(const char *str, ucc_score_t *score)
{
if (0 == strcasecmp("inf", str)) {
Expand Down Expand Up @@ -676,17 +638,17 @@ static ucc_status_t ucc_coll_score_parse_str(const char *str,
{
ucc_status_t status = UCC_OK;
ucc_coll_type_t *ct = NULL;
ucc_memory_type_t *mt = NULL;
size_t *msg = NULL;
ucc_rank_t *tsizes = NULL;
ucc_base_coll_init_fn_t alg_init = NULL;
const char* alg_id = NULL;
ucc_score_t score_v = UCC_SCORE_INVALID;
int ts_skip = 0;
uint32_t mtypes = 0;
char **tokens;
unsigned i, n_tokens, ct_n, mt_n, c, m, n_ranges, r, n_tsizes;
unsigned i, n_tokens, ct_n, c, m, n_ranges, r, n_tsizes;

mt_n = ct_n = n_ranges = n_tsizes = 0;
ct_n = n_ranges = n_tsizes = 0;
tokens = ucc_str_split(str, ":");
if (!tokens) {
status = UCC_ERR_INVALID_PARAM;
Expand All @@ -697,7 +659,8 @@ static ucc_status_t ucc_coll_score_parse_str(const char *str,
if (!ct && UCC_OK == str_to_coll_type(tokens[i], &ct_n, &ct)) {
continue;
}
if (!mt && UCC_OK == str_to_mem_type(tokens[i], &mt_n, &mt)) {
if (!mtypes && UCC_OK == ucc_str_to_mtype_map(tokens[i], ",",
&mtypes)) {
continue;
}
if ((UCC_SCORE_INVALID == score_v) &&
Expand Down Expand Up @@ -733,18 +696,23 @@ static ucc_status_t ucc_coll_score_parse_str(const char *str,
if (!ts_skip && (UCC_SCORE_INVALID != score_v || NULL != alg_id)) {
/* Score provided but not coll_types/mem_types.
This means: apply score to ALL coll_types/mem_types */
if (!ct)
if (!ct) {
ct_n = UCC_COLL_TYPE_NUM;
if (!mt)
mt_n = UCC_MEMORY_TYPE_LAST;
if (!msg)
}
if (!mtypes) {
mtypes = UCC_MEM_TYPE_MASK_FULL;
}
if (!msg) {
n_ranges = 1;
}
for (c = 0; c < ct_n; c++) {
for (m = 0; m < mt_n; m++) {
for (m = 0; m < UCC_MEMORY_TYPE_LAST; m++) {
if (!(UCC_BIT(m) & mtypes)) {
continue;
}
ucc_coll_type_t coll_type = ct ? ct[c] :
(ucc_coll_type_t)UCC_BIT(c);
ucc_memory_type_t mem_type = mt ? mt[m] :
(ucc_memory_type_t)m;
ucc_memory_type_t mem_type = (ucc_memory_type_t)m;
if (alg_id) {
if (!alg_fn) {
status = UCC_ERR_NOT_SUPPORTED;
Expand Down Expand Up @@ -799,7 +767,6 @@ static ucc_status_t ucc_coll_score_parse_str(const char *str,
}
out:
ucc_free(ct);
ucc_free(mt);
ucc_free(msg);
ucc_free(tsizes);
ucc_str_split_free(tokens);
Expand Down
3 changes: 2 additions & 1 deletion src/components/cl/hier/allreduce/allreduce_split_rail.c
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,9 @@ static ucc_status_t ucc_cl_hier_allreduce_split_rail_frag_init(
rs_args.args.dst.info_v.counts = counts;
rs_args.args.dst.info_v.mem_type = coll_args->args.dst.info.mem_type;
rs_args.args.dst.info_v.datatype = coll_args->args.dst.info.datatype;
/* linter thinks node_size can be 0 - false positive */
rs_args.max_frag_count = ucc_buffer_block_count(
ucc_buffer_block_count(total_count, n_frags, 0), node_size, 0);
ucc_buffer_block_count(total_count, n_frags, 0), node_size, 0); //NOLINT
rs_args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;


Expand Down
23 changes: 18 additions & 5 deletions src/components/tl/ucp/allreduce/allreduce_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,22 @@ ucc_status_t ucc_tl_ucp_allreduce_knomial_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t size = (ucc_rank_t)task->subset.map.ep_num;
ucc_rank_t rank = task->subset.myrank;
ucc_memory_type_t mem_type = TASK_ARGS(task).dst.info.mem_type;
size_t count = TASK_ARGS(task).dst.info.count;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t data_size = count * ucc_dt_size(dt);
ucc_mrange_uint_t *p =
&UCC_TL_UCP_TEAM_LIB(team)->cfg.allreduce_kn_radix;
ucc_kn_radix_t cfg_radix;
ucc_status_t status;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_kn_start", 0);
task->allreduce_kn.phase = UCC_KN_PHASE_INIT;
ucc_assert(UCC_IS_INPLACE(TASK_ARGS(task)) ||
(TASK_ARGS(task).src.info.mem_type ==
TASK_ARGS(task).dst.info.mem_type));
ucc_knomial_pattern_init(size, rank,
ucc_min(team->cfg.allreduce_kn_radix, size),
(TASK_ARGS(task).src.info.mem_type == mem_type));
cfg_radix = ucc_tl_ucp_get_radix_from_range(team, data_size,
mem_type, p);
ucc_knomial_pattern_init(size, rank, ucc_min(cfg_radix, size),
&task->allreduce_kn.p);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
status =
Expand All @@ -209,17 +216,23 @@ ucc_status_t ucc_tl_ucp_allreduce_knomial_start(ucc_coll_task_t *coll_task)
ucc_status_t ucc_tl_ucp_allreduce_knomial_init_common(ucc_tl_ucp_task_t *task)
{
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_memory_type_t mem_type = TASK_ARGS(task).dst.info.mem_type;
size_t count = TASK_ARGS(task).dst.info.count;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t data_size = count * ucc_dt_size(dt);
ucc_rank_t size = (ucc_rank_t)task->subset.map.ep_num;
ucc_kn_radix_t radix = ucc_min(team->cfg.allreduce_kn_radix, size);
ucc_mrange_uint_t *p =
&UCC_TL_UCP_TEAM_LIB(team)->cfg.allreduce_kn_radix;
ucc_kn_radix_t radix, cfg_radix;
ucc_status_t status;

task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
task->super.post = ucc_tl_ucp_allreduce_knomial_start;
task->super.progress = ucc_tl_ucp_allreduce_knomial_progress;
task->super.finalize = ucc_tl_ucp_allreduce_knomial_finalize;
cfg_radix = ucc_tl_ucp_get_radix_from_range(team, data_size,
mem_type, p);
radix = ucc_min(cfg_radix, size);
status = ucc_mc_alloc(&task->allreduce_kn.scratch_mc_header,
(radix - 1) * data_size,
TASK_ARGS(task).dst.info.mem_type);
Expand Down
22 changes: 18 additions & 4 deletions src/components/tl/ucp/allreduce/allreduce_sra_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,35 @@ static ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_frag_init(
ucc_base_team_t *team, ucc_schedule_t **frag_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
size_t count = coll_args->args.dst.info.count;
ucc_datatype_t dtype = coll_args->args.dst.info.datatype;
ucc_memory_type_t mem_type = coll_args->args.dst.info.mem_type;
ucc_base_coll_args_t args = *coll_args;
ucc_mrange_uint_t *p =
&UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.allreduce_sra_kn_radix;
ucc_schedule_t *schedule;
ucc_coll_task_t *task, *rs_task;
ucc_status_t status;
ucc_kn_radix_t radix, cfg_radix;
size_t count;

status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
(ucc_tl_ucp_schedule_t **)&schedule);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
cfg_radix = tl_team->cfg.allreduce_sra_kn_radix;
radix = ucc_knomial_pattern_get_min_radix(cfg_radix,
UCC_TL_TEAM_SIZE(tl_team), count);

if (coll_args->mask & UCC_BASE_CARGS_MAX_FRAG_COUNT) {
count = coll_args->max_frag_count;
} else {
count = coll_args->args.dst.info.count;
}

cfg_radix = ucc_tl_ucp_get_radix_from_range(tl_team,
count * ucc_dt_size(dtype),
mem_type, p);
radix = ucc_knomial_pattern_get_min_radix(cfg_radix,
UCC_TL_TEAM_SIZE(tl_team),
count);

/* 1st step of allreduce: knomial reduce_scatter */
UCC_CHECK_GOTO(
Expand Down
8 changes: 4 additions & 4 deletions src/components/tl/ucp/tl_ucp.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = {
ucc_offsetof(ucc_tl_ucp_lib_config_t, fanout_kn_radix),
UCC_CONFIG_TYPE_UINT},

{"ALLREDUCE_KN_RADIX", "4",
{"ALLREDUCE_KN_RADIX", "auto",
"Radix of the recursive-knomial allreduce algorithm",
ucc_offsetof(ucc_tl_ucp_lib_config_t, allreduce_kn_radix),
UCC_CONFIG_TYPE_UINT},
UCC_CONFIG_TYPE_UINT_RANGED},

{"ALLREDUCE_SRA_KN_RADIX", "4",
{"ALLREDUCE_SRA_KN_RADIX", "auto",
"Radix of the scatter-reduce-allgather (SRA) knomial allreduce algorithm",
ucc_offsetof(ucc_tl_ucp_lib_config_t, allreduce_sra_kn_radix),
UCC_CONFIG_TYPE_UINT},
UCC_CONFIG_TYPE_UINT_RANGED},

{"ALLREDUCE_SRA_KN_PIPELINE", "n",
"Pipelining settings for SRA Knomial allreduce algorithm",
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ typedef struct ucc_tl_ucp_lib_config {
uint32_t fanin_kn_radix;
uint32_t fanout_kn_radix;
uint32_t barrier_kn_radix;
uint32_t allreduce_kn_radix;
uint32_t allreduce_sra_kn_radix;
ucc_mrange_uint_t allreduce_kn_radix;
ucc_mrange_uint_t allreduce_sra_kn_radix;
uint32_t reduce_scatter_kn_radix;
uint32_t allgather_kn_radix;
uint32_t bcast_kn_radix;
Expand Down
16 changes: 16 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "components/ec/ucc_ec.h"
#include "tl_ucp_tag.h"

#define UCC_UUNITS_AUTO_RADIX 4
#define UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR 5
extern const char
*ucc_tl_ucp_default_alg_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR];
Expand Down Expand Up @@ -307,5 +308,20 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
ucc_memory_type_t mem_type,
ucc_base_coll_init_fn_t *init);

/**
 * Look up the radix configured for a message size / memory type pair.
 *
 * The value is fetched from the ranged parameter @a p via
 * ucc_mrange_uint_get(). When that lookup yields UCC_UUNITS_AUTO the
 * UCC_UUNITS_AUTO_RADIX constant is returned instead.
 *
 * @param [in] team      TL/UCP team; not read by the current implementation.
 * @param [in] msgsize   Message size passed to the range lookup.
 * @param [in] mem_type  Memory type passed to the range lookup.
 * @param [in] p         Ranged unsigned parameter to query.
 *
 * @return The configured radix, or UCC_UUNITS_AUTO_RADIX for "auto".
 */
static inline unsigned
ucc_tl_ucp_get_radix_from_range(ucc_tl_ucp_team_t *team,
                                size_t             msgsize,
                                ucc_memory_type_t  mem_type,
                                ucc_mrange_uint_t *p)
{
    const unsigned configured = ucc_mrange_uint_get(p, msgsize, mem_type);

    /* "auto" currently resolves to a fixed fallback radix */
    return (UCC_UUNITS_AUTO == configured) ? UCC_UUNITS_AUTO_RADIX
                                           : configured;
}
#endif
9 changes: 6 additions & 3 deletions src/components/tl/ucp/tl_ucp_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_lib_t, const ucc_base_lib_params_t *params,

UCC_CLASS_CALL_SUPER_INIT(ucc_tl_lib_t, &ucc_tl_ucp.super,
&tl_ucp_config->super);
memcpy(&self->cfg, tl_ucp_config, sizeof(*tl_ucp_config));
status = ucc_config_clone_table(tl_ucp_config, &self->cfg,
ucc_tl_ucp_lib_config_table);
if (UCC_OK != status) {
return status;
}

if (tl_ucp_config->kn_radix > 0) {
self->cfg.barrier_kn_radix = tl_ucp_config->kn_radix;
self->cfg.allreduce_kn_radix = tl_ucp_config->kn_radix;
self->cfg.allreduce_sra_kn_radix = tl_ucp_config->kn_radix;
self->cfg.reduce_scatter_kn_radix = tl_ucp_config->kn_radix;
self->cfg.allgather_kn_radix = tl_ucp_config->kn_radix;
self->cfg.bcast_kn_radix = tl_ucp_config->kn_radix;
Expand Down
2 changes: 2 additions & 0 deletions src/utils/ucc_coll_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
#define UCC_COLL_ARGS_ACTIVE_SET(_args) \
((_args)->mask & UCC_COLL_ARGS_FIELD_ACTIVE_SET)

/* Memory type bitmap with every bit set. Parenthesized so the negative
   literal expands as a single operand in any surrounding expression. */
#define UCC_MEM_TYPE_MASK_FULL (-1)

static inline size_t
ucc_coll_args_get_count(const ucc_coll_args_t *args, const ucc_count_t *counts,
ucc_rank_t idx)
Expand Down
46 changes: 46 additions & 0 deletions src/utils/ucc_datastruct.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* See file LICENSE for terms.
*/

#include "ucc_datastruct.h"
#include "ucc_malloc.h"
#include "ucc_compiler_def.h"
#include "ucc_log.h"

/**
 * Deep-copy a ranged unsigned parameter.
 *
 * Copies the default value and duplicates every range element of @a src
 * into a freshly initialized list in @a dst. @a dst is overwritten without
 * being destroyed first.
 *
 * @param [out] dst  Destination parameter.
 * @param [in]  src  Source parameter.
 *
 * @return UCC_OK on success. UCC_ERR_NO_MEMORY if a range element could not
 *         be allocated; the elements already added to @a dst are released
 *         through ucc_mrange_uint_destroy() before returning.
 */
ucc_status_t ucc_mrange_uint_copy(ucc_mrange_uint_t       *dst,
                                  const ucc_mrange_uint_t *src)
{
    ucc_mrange_t *r, *r_dup;

    dst->default_value = src->default_value;
    ucc_list_head_init(&dst->ranges);
    ucc_list_for_each(r, &src->ranges, list_elem) {
        r_dup = ucc_malloc(sizeof(*r_dup), "range_dup");
        if (ucc_unlikely(!r_dup)) {
            /* sizeof yields size_t: %zu is the matching specifier */
            ucc_error("failed to allocate %zu bytes for mrange",
                      sizeof(*r_dup));
            goto err;
        }
        r_dup->start  = r->start;
        r_dup->end    = r->end;
        r_dup->value  = r->value;
        r_dup->mtypes = r->mtypes;
        ucc_list_add_tail(&dst->ranges, &r_dup->list_elem);
    }

    return UCC_OK;
err:
    /* release the partially built copy */
    ucc_mrange_uint_destroy(dst);
    return UCC_ERR_NO_MEMORY;
}

/**
 * Release every range element attached to a ranged unsigned parameter.
 *
 * Each element is unlinked from param->ranges and returned to ucc_free().
 * The @a param object itself is not freed.
 *
 * @param [in,out] param  Parameter whose range list is emptied.
 */
void ucc_mrange_uint_destroy(ucc_mrange_uint_t *param)
{
    ucc_mrange_t *cur;
    ucc_mrange_t *next;

    /* the "safe" iterator allows unlinking the current element */
    ucc_list_for_each_safe(cur, next, &param->ranges, list_elem) {
        ucc_list_del(&cur->list_elem);
        ucc_free(cur);
    }
}
Loading

0 comments on commit a7a641c

Please sign in to comment.