diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index 0d77772cb8ea..1881c968b560 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -23,7 +23,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest + image: vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index 93286b62610a..1fccbece2994 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -19,7 +19,7 @@ concurrency: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 61011a85b92c..f61637be7e0e 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -21,7 +21,7 @@ permissions: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index 2b74e7e155df..6b339f457802 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -22,7 +22,7 @@ concurrency: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/nv-mii.yml b/.github/workflows/nv-mii.yml index 0b3f128be5a4..31379f7e758b 100644 --- a/.github/workflows/nv-mii.yml +++ b/.github/workflows/nv-mii.yml @@ -27,7 +27,7 @@ concurrency: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml index e540b5acaf33..ca091990cf4b 100644 --- a/.github/workflows/nv-nightly.yml +++ b/.github/workflows/nv-nightly.yml @@ -15,7 +15,7 @@ permissions: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 @@ -25,7 +25,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --index-url https://download.pytorch.org/whl/cu116 + pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --index-url https://download.pytorch.org/whl/cu117 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -55,7 +55,7 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="1.13" --cuda_ver="11.6" + pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="1.13" --cuda_ver="11.7" - name: Open GitHub issue if nightly CI fails if: ${{ failure() && (github.event_name == 'schedule') }} diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index 18db40380577..6e308242ecf0 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -36,7 +36,7 @@ jobs: #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . + DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . - name: DS Report run: | ds_report diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index e2d0f172dcbf..14d33680521d 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -19,7 +19,7 @@ concurrency: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml index f46c5089b241..bd13047f6078 100644 --- a/.github/workflows/nv-torch-nightly-v100.yml +++ b/.github/workflows/nv-torch-nightly-v100.yml @@ -15,7 +15,7 @@ permissions: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml index 4fbc42abec5f..75f53c95c235 100644 --- a/.github/workflows/nv-transformers-v100.yml +++ b/.github/workflows/nv-transformers-v100.yml @@ -18,7 +18,7 @@ concurrency: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu116, v100] + runs-on: [self-hosted, nvidia, cu117, v100] steps: - uses: actions/checkout@v3 diff --git a/blogs/deepspeed-fp6/03-05-2024/README.md b/blogs/deepspeed-fp6/03-05-2024/README.md index dbd6b2d081aa..0285dd79b87d 100755 --- a/blogs/deepspeed-fp6/03-05-2024/README.md +++ b/blogs/deepspeed-fp6/03-05-2024/README.md @@ -43,7 +43,7 @@ To cite DeepSpeed-FP6, please cite the following two arxiv reports - ZeroQuant(4 In the evolving landscape of Large Language Models (LLMs) like GPT, our research aims to boost computational efficiency and storage while preserving model quality. This focus brings us to tackle the complex challenges of 4-bit quantization, where optimizing performance, efficiency, and accuracy is crucial. -**Exploring the Challenges of 4-bit Quantization** In our recent research findings -- ZeroQuant (4+2)[1], we explore the capabilities of INT4 quantization techniques (like the GPTQ algorithm) for serving Large Language Models (LLMs). While these techniques reduce memory and computational requirements, they often perform poorly on a broad array of tasks, including generative tasks such as code generation and summarization, due to overfitting issues. This highlights the urgent need for new quantization approaches that simultanenously improve both the efficiency and effectiveness of LLMs. +**Exploring the Challenges of 4-bit Quantization** In our recent research findings -- ZeroQuant (4+2)[1], we explore the capabilities of INT4 quantization techniques (like the GPTQ algorithm) for serving Large Language Models (LLMs). While these techniques reduce memory and computational requirements, they often perform poorly on a broad array of tasks, including generative tasks such as code generation and summarization, due to overfitting issues. This highlights the urgent need for new quantization approaches that simultaneously improve both the efficiency and effectiveness of LLMs. **Breakthroughs with FP6 Precision** Our exploration of different quantization methods led us to the FP6 precision standard. Despite the challenges in integrating and accelerating FP6 with current AI hardware -- which we will address in the next section - this format excels in performance and flexibility across various tasks. Notably, we observe that for generative tasks, FP6 quantization can match the performance of the half-precision (FP16) format. For example, with FP6 quantization, StarCoder-15B achieves comparable code generation results to the FP16 variant, while a smaller model, such as BART-460M, achieves comparable summarization performance to the standard FP16 equivalent. In order to preserve these quality gains, while matching the system efficiency of INT4 quantization on AI hardware, we propose a novel 4+2 FP6 scheme. This innovation makes FP6 a promising direction for improving the efficiency of LLMs, marking a significant leap in AI technology advancement. For more details, please refer to our research paper - ZeroQuant (4+2)[1]. diff --git a/blogs/deepspeed-ulysses/README.md b/blogs/deepspeed-ulysses/README.md index aa4416521dd1..375eb1190325 100644 --- a/blogs/deepspeed-ulysses/README.md +++ b/blogs/deepspeed-ulysses/README.md @@ -233,7 +233,7 @@ at different sequence length and GPU count.* Next, we evaluate Ulysses on 7 billion (7B) and 30 billion (30B) parameter GPT dense attention models and compare against Megatron-LM's sequence -parallelism (Megatron LM) and Colosal AI sequence parallelism (ColAI-SP) on +parallelism (Megatron LM) and Colossal AI sequence parallelism (ColAI-SP) on 32 and 64 A100 GPUs respectively. The results of these evaluations are shown in Figures 3 and 4. diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp index 6428ab5cbfa5..786906717f23 100644 --- a/csrc/cpu/comm/ccl.cpp +++ b/csrc/cpu/comm/ccl.cpp @@ -5,281 +5,24 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include "shm.h" -// states for collectives -enum coll_state { - coll_begin = 0, - // coll states for naive allreduce - coll_allreduce_naive__copy_in_done, // this state is for rank != 0 - coll_allreduce_naive__reduce_done, // this state is for rank == 0 - coll_allreduce_naive__copy_out_done, // this state is for rank != 0 -}; - -// SHM building blocks -struct SharedData { - const char* name; - int descriptor; - void* bytes; - size_t nbytes; -}; - -void shared_open(SharedData* data, const char* name, size_t nbytes) -{ - int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); - if (d != -1) { - void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0); - data->name = name; - data->descriptor = d; - data->bytes = bytes; - data->nbytes = nbytes; - } else { - printf("shared_open %s failed\n", name); - data->descriptor = -1; - } -} - -void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) -{ - int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); - if (d != -1) { - if (nbytes = write(d, bytes, nbytes)) { shared_open(data, name, nbytes); } - } else { - printf("shared_create %s failed\n", name); - } -} - -void shared_close(SharedData* data) -{ - if (data->descriptor != -1) { - munmap(data->bytes, data->nbytes); - shm_unlink(data->name); - } -} - -// SHM based allreduce helper functions -// buffer that holds shm name -#define NAME_BUF_SIZE 1000 -#define MAX_BUF_SIZE 1048576 -#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" -SharedData allreduce_buffer; -struct allreduce_workspace { - enum coll_state state; - char buffer[MAX_BUF_SIZE]; -}; -struct allreduce_workspace* workspace; - -void wait_buffer_state_until(int index, enum coll_state state) -{ - volatile enum coll_state* state_ptr = &(workspace[index].state); - - while (*state_ptr != state) - ; -} - -void wait_buffer_state_until_not(int index, enum coll_state state) -{ - volatile enum coll_state* state_ptr = &(workspace[index].state); - - while (*state_ptr == state) - ; -} - -__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); -inline __m512 cvt_bf16_to_fp32(const __m256i src) -{ - auto y = _mm512_cvtepu16_epi32(src); - return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); -} - -inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw"))); -inline __m256i cvt_fp32_to_bf16(const __m512 src) -{ - __m512i value = _mm512_castps_si512(src); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - // uint32_t lsb = (input >> 16) & 1; - auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); - // uint32_t rounding_bias = 0x7fff + lsb; - t_value = _mm512_add_epi32(t_value, vec_bias); - // input += rounding_bias; - t_value = _mm512_add_epi32(t_value, value); - // input = input >> 16; - t_value = _mm512_srli_epi32(t_value, 16); - // Check NaN before converting back to bf16 - t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); - return _mm512_cvtusepi32_epi16(t_value); -} - -void reduce_2_bf16_buffers(int num_elements, void* in_out, void* in) - __attribute__((target("avx512bw"))); - -void reduce_bf16_buffers(int num_elements, int num_buffers, struct allreduce_workspace* workspace) - __attribute__((target("avx512bw"))); - -void reduce_2_fp32_buffers(int num_elements, void* in_out, void* in) - __attribute__((target("avx512bw"))); - -void reduce_fp32_buffers(int num_elements, int num_buffers, struct allreduce_workspace* workspace) - __attribute__((target("avx512bw"))); - -// N_REDUCE_LIMIT is the number of buffers that can be reduced together in one shot. -// Compared with do N-1 2-reduces which needs 2*(N-1) read and N-1 write, -// N-reduce only needs N read and 1 write, this saves 2/3 memory bandwidth. -// When increase N_REDUCE_LIMIT to a bigger number, do the following steps -// 1. Extend REPEAT_ macros list down below -// 2. Extend switch cases which call "REPEAT(X, ...)" down below -#define N_REDUCE_LIMIT 8 - -void reduce_all_buffers(struct allreduce_workspace* workspace, - int num_elements, - c10::ScalarType scalar_type, - int num_buffers) -{ - switch (scalar_type) { - case c10::ScalarType::BFloat16: - if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { - reduce_bf16_buffers(num_elements, num_buffers, workspace); - } else { - for (int i = 1; i < num_buffers; i++) { - reduce_2_bf16_buffers(num_elements, workspace[0].buffer, workspace[i].buffer); - } - } - break; - case c10::ScalarType::Float: - if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { - reduce_fp32_buffers(num_elements, num_buffers, workspace); - } else { - for (int i = 1; i < num_buffers; i++) { - reduce_2_fp32_buffers(num_elements, workspace[0].buffer, workspace[i].buffer); - } - } - break; - default: assert(!"Should not get here"); - } -} +// #define DO_PROFILE +#ifdef DO_PROFILE +#include +#include +#endif -#define REPEAT(N, x) REPEAT_##N(x) -#define REPEAT_1(x) x(1) -#define REPEAT_2(x) \ - REPEAT_1(x); \ - x(2) -#define REPEAT_3(x) \ - REPEAT_2(x); \ - x(3) -#define REPEAT_4(x) \ - REPEAT_3(x); \ - x(4) -#define REPEAT_5(x) \ - REPEAT_4(x); \ - x(5) -#define REPEAT_6(x) \ - REPEAT_5(x); \ - x(6) -#define REPEAT_7(x) \ - REPEAT_6(x); \ - x(7) - -#define CVT_ADD_BF16(x) \ - do { \ - auto in##x##_val = \ - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[x].buffer + i))); \ - inout_val = _mm512_add_ps(inout_val, in##x##_val); \ - } while (0) - -// Reduce functions down below use vectorized algorithm, the number of bytes processed each -// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes -// If you change implementation of reduce_2_bf16_buffers or reduce_2_fp32_buffers, check -// whether this number needs to be changed -#define VECTOR_LENGTH_IN_BYTES 32 - -// num_elements must be divisible by 16 (caller check) -void reduce_bf16_buffers(int num_elements, int num_buffers, struct allreduce_workspace* workspace) -{ -#pragma omp parallel for - for (int i = 0; i < num_elements * 2; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[0].buffer + i))); - switch (num_buffers) { - case 8: REPEAT(7, CVT_ADD_BF16); break; - case 7: REPEAT(6, CVT_ADD_BF16); break; - case 6: REPEAT(5, CVT_ADD_BF16); break; - case 5: REPEAT(4, CVT_ADD_BF16); break; - case 4: REPEAT(3, CVT_ADD_BF16); break; - case 3: REPEAT(2, CVT_ADD_BF16); break; - default: assert(!"Should not get here."); - } - _mm256_storeu_si256((__m256i*)(workspace[0].buffer + i), cvt_fp32_to_bf16(inout_val)); - } -} - -void reduce_2_bf16_buffers(int num_elements, void* in_out, void* in1) -{ -#pragma omp parallel for - for (int i = 0; i < num_elements * 2; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in_out + i))); - auto in1_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in1 + i))); - inout_val = _mm512_add_ps(inout_val, in1_val); - _mm256_storeu_si256((__m256i*)((char*)in_out + i), cvt_fp32_to_bf16(inout_val)); - } -} - -#define CVT_ADD_F32(x) \ - do { \ - auto in##x##_val = _mm256_loadu_ps((float*)(workspace[x].buffer + i)); \ - inout_val = _mm256_add_ps(inout_val, in##x##_val); \ - } while (0) - -// num_elements must be divisible by 16 (caller check) -void reduce_fp32_buffers(int num_elements, int num_buffers, struct allreduce_workspace* workspace) -{ -#pragma omp parallel for - for (int i = 0; i < num_elements * 4; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = _mm256_loadu_ps((float*)(workspace[0].buffer + i)); - switch (num_buffers) { - case 8: REPEAT(7, CVT_ADD_F32); break; - case 7: REPEAT(6, CVT_ADD_F32); break; - case 6: REPEAT(5, CVT_ADD_F32); break; - case 5: REPEAT(4, CVT_ADD_F32); break; - case 4: REPEAT(3, CVT_ADD_F32); break; - case 3: REPEAT(2, CVT_ADD_F32); break; - default: assert(!"Should not get here."); - } - _mm256_storeu_ps((float*)(workspace[0].buffer + i), inout_val); - } -} +// Communication settings +static int world_rank = -1; +static int world_size = -1; -void reduce_2_fp32_buffers(int num_elements, void* in_out, void* in1) -{ -#pragma omp parallel for - for (int i = 0; i < num_elements * 4; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = _mm256_loadu_ps((float*)((char*)in_out + i)); - auto in1_val = _mm256_loadu_ps((float*)((char*)in1 + i)); - inout_val = _mm256_add_ps(inout_val, in1_val); - _mm256_storeu_ps((float*)((char*)in_out + i), inout_val); - } -} - -// Communicatiooon settings -int world_rank = -1; -int world_size = -1; - -std::set _comm_ids; -std::set _colors; -std::vector _ccl_comms; -ccl::shared_ptr_class sub_kvs; -std::map, int> group_to_comm_id; +static std::set _comm_ids; +static std::set _colors; +static std::vector _ccl_comms; +static ccl::shared_ptr_class sub_kvs; +static std::map, int> group_to_comm_id; ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; } ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; } @@ -300,11 +43,11 @@ ccl::communicator& _get_comm_from_group(std::vector ranks) #define KVS_CREATE_SUCCESS 0 #define KVS_CREATE_FAILURE -1 -bool is_initialized = 0; +static bool is_initialized = 0; -ccl::shared_ptr_class kvs; +static ccl::shared_ptr_class kvs; -bool all_ranks_local_p = false; +static bool all_ranks_local_p = false; void initialize(int size, int rank, torch::Tensor& kvs_data) { @@ -336,30 +79,8 @@ void initialize(int size, int rank, torch::Tensor& kvs_data) if (addr_string == NULL) { addr_string = ""; } auto port_string = std::getenv("MASTER_PORT"); if (port_string == NULL) { port_string = ""; } - char shm_name[NAME_BUF_SIZE]; - snprintf(shm_name, - NAME_BUF_SIZE, - "%s_%d_%s_%s", - SHM_BUFFER_NAME, - getuid(), - addr_string, - port_string); - // create shared workspace for SHM based allreduce - if (all_ranks_local_p) { - if (rank == 0) { - workspace = - (struct allreduce_workspace*)malloc(size * sizeof(struct allreduce_workspace)); - shared_create( - &allreduce_buffer, shm_name, workspace, size * sizeof(struct allreduce_workspace)); - workspace = (struct allreduce_workspace*)allreduce_buffer.bytes; - for (int i = 0; i < size; i++) { workspace[i].state = coll_begin; } - } - CCLCHECK(ccl::barrier(_get_comm_from_group()).wait()); - if (rank != 0) { - shared_open(&allreduce_buffer, shm_name, size * sizeof(struct allreduce_workspace)); - } - workspace = (struct allreduce_workspace*)allreduce_buffer.bytes; - } + + if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); } } /* @@ -526,19 +247,22 @@ void all_reduce_caching(torch::Tensor& data, .wait()); } -static void parallel_memcpy(void* to, void* from, size_t n_bytes) - __attribute__((target("avx512bw"))); -static void parallel_memcpy(void* to, void* from, size_t n_bytes) -{ -#pragma omp parallel for - for (int i = 0; i < n_bytes; i += VECTOR_LENGTH_IN_BYTES) { - auto val = _mm256_loadu_si256((__m256i*)((char*)from + i)); - _mm256_storeu_si256((__m256i*)((char*)to + i), val); - } -} - void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op) { +#ifdef DO_PROFILE + static double total_time = 0.0; + static double total_time_sq = 0.0; + static int count = -16; // warmup + static double max_time = 0.0; + static double min_time = DBL_MAX; + // make sure all rank reach this point before measuring time + // turn on this if you suspect each rank didn't reach here at the same time (stragger) + // if (all_ranks_local_p) { + // barrier_wait(0, world_size); + //} + auto start = std::chrono::system_clock::now(); +#endif + static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp"); static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value")); @@ -555,7 +279,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op) default: data_type_fallback = true; } - if (data_type_fallback || (data_size % VECTOR_LENGTH_IN_BYTES) != 0 || !all_ranks_local_p) { + if (data_type_fallback || !all_ranks_local_p) { // fallback to oneccl allreduce CCLCHECK(ccl::allreduce(data.data_ptr(), data.data_ptr(), @@ -564,50 +288,31 @@ void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op) get_ccl_reduce_op(op, data), _get_comm_from_group()) .wait()); - return; + } else { + all_reduce_outer_loop(data, numel, data_size); } - for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { - auto data_ptr = ((char*)(data.data_ptr()) + offset); - size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; - size_t chunk_el = chunk_size / (data_size / numel); - - parallel_memcpy(workspace[world_rank].buffer, data_ptr, chunk_size); - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank].state = coll_allreduce_naive__copy_in_done; - - if (world_rank == 0) { - // compute allreduce result on rank 0 - for (int i = 1; i < world_size; i++) { - // wait until the other rank copy the buffer - wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done); - } - reduce_all_buffers(workspace, chunk_el, data.scalar_type(), world_size); - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank].state = coll_allreduce_naive__reduce_done; - parallel_memcpy(data_ptr, workspace[0].buffer, chunk_size); - } - if (world_rank != 0) { - wait_buffer_state_until(0, coll_allreduce_naive__reduce_done); - parallel_memcpy(data_ptr, workspace[0].buffer, chunk_size); - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank].state = coll_allreduce_naive__copy_out_done; - } - if (world_rank == 0) { - for (int i = 1; i < world_size; i++) { - wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done); - } - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank].state = coll_begin; - } - if (world_rank != 0) { - // if rank 0 spin too fast it could be in state 1 of next allreduce - // in this case wait_buffer_state_until(0, 0) may cause deadlock - // what we are certain is when rank 0 finishes the state won't be 2 - wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done); - workspace[world_rank].state = coll_begin; +#ifdef DO_PROFILE + auto end = std::chrono::system_clock::now(); + count++; + if (count > 0) { + double elapsed = std::chrono::duration_cast(end - start).count(); + if (elapsed > max_time) { max_time = elapsed; } + if (elapsed < min_time) { min_time = elapsed; } + total_time += elapsed; + total_time_sq += elapsed * elapsed; + if (world_rank == 0 && count == 1000) { + auto avg = total_time / count; + auto sd = + sqrt(total_time_sq / count - total_time * total_time / (count * count)) / avg * 100; + printf(" C++ kernel\t\t %.2f\t %.2f\t%.2f\t %.2f\n", + min_time, + max_time, + total_time / count, + sd); } } +#endif } void barrier(std::vector group, bool async_op) diff --git a/csrc/cpu/comm/shm.cpp b/csrc/cpu/comm/shm.cpp new file mode 100644 index 000000000000..859c2fec292d --- /dev/null +++ b/csrc/cpu/comm/shm.cpp @@ -0,0 +1,686 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include +#include +#include +#include +#include +#include "shm.h" + +// #define DO_PROFILE +#ifdef DO_PROFILE +#include +#include +#endif + +// states for collectives +enum coll_state { + coll_begin = 0, + coll_allreduce_naive__copy_in_done, // this state is for rank != 0 + coll_allreduce_naive__reduce_done, // this state is for rank == 0 + coll_allreduce_naive__copy_out_done, // this state is for rank != 0 +}; + +// SHM building blocks +struct SharedData { + const char* name; + int descriptor; + void* bytes; + size_t nbytes; +}; + +void shared_open(SharedData* data, const char* name, size_t nbytes) +{ + int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0); + data->name = name; + data->descriptor = d; + data->bytes = bytes; + data->nbytes = nbytes; + } else { + if (errno != ENOENT) { + // don't print if shm can not be found because we want to loop over from + // caller again until the other ranks created the shm + printf("shared_open %s failed, errno=%d\n", name, errno); + } + data->descriptor = -1; + } +} + +void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) +{ + int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + if (nbytes = write(d, bytes, nbytes)) { shared_open(data, name, nbytes); } + } else { + printf("shared_create %s failed\n", name); + } +} + +void shared_close(SharedData* data) +{ + if (data->descriptor != -1) { + munmap(data->bytes, data->nbytes); + shm_unlink(data->name); + } +} + +// SHM based allreduce helper functions +// buffer that holds shm name +#define NAME_BUF_SIZE 1000 +#define MAX_BUF_SIZE 1048576 * 32 +#define NAIVE_ALLREDUCE_THRESHOLD 1048576 +#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" +struct allreduce_workspace { + enum coll_state state; + sem_t mutex; + sem_t turnstile1; + sem_t turnstile2; + int counter; + char buffer[MAX_BUF_SIZE]; +}; +struct allreduce_workspace** workspace; + +void wait_buffer_state_until(int index, enum coll_state state) +{ + volatile enum coll_state* state_ptr = &(workspace[index]->state); + + while (*state_ptr != state) + ; +} + +void wait_buffer_state_until_range(int index, enum coll_state start, int size) +{ + volatile enum coll_state* state_ptr = &(workspace[index]->state); + enum coll_state end = (enum coll_state)(start + size); + + while (1) { + volatile enum coll_state cur_state = *state_ptr; + if (cur_state >= start and cur_state < end) break; + } +} + +void wait_buffer_state_until_not(int index, enum coll_state state) +{ + volatile enum coll_state* state_ptr = &(workspace[index]->state); + + while (*state_ptr == state) + ; +} + +void barrier_wait(int root_idx, int num_ranks) +{ + // Phase 1: Wait for all threads to enter the barrier + auto shared = workspace[root_idx]; + sem_wait(&shared->mutex); + shared->counter++; + if (shared->counter == num_ranks) { + for (int i = 0; i < num_ranks; ++i) { sem_post(&shared->turnstile1); } + } + sem_post(&shared->mutex); + sem_wait(&shared->turnstile1); + + // Phase 2: Wait for all threads to exit the barrier + sem_wait(&shared->mutex); + shared->counter--; + if (shared->counter == 0) { + for (int i = 0; i < num_ranks; ++i) { sem_post(&shared->turnstile2); } + } + sem_post(&shared->mutex); + sem_wait(&shared->turnstile2); +} + +__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_bf16_to_fp32(const __m256i src) +{ + auto y = _mm512_cvtepu16_epi32(src); + return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); +} + +inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_bf16(const __m512 src) +{ + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out) + __attribute__((target("avx512bw"))); + +void reduce_bf16_buffers(int start_elements, + int num_elements, + int num_buffers, + int to_buffer_idx, + struct allreduce_workspace** workspace) + __attribute__((target("avx512bw"))); + +void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out) + __attribute__((target("avx512bw"))); + +void reduce_fp32_buffers(int start_elements, + int num_elements, + int num_buffers, + int to_buffer_idx, + struct allreduce_workspace** workspace) + __attribute__((target("avx512bw"))); + +// N_REDUCE_LIMIT is the number of buffers that can be reduced together in one shot. +// Compared with do N-1 2-reduces which needs 2*(N-1) read and N-1 write, +// N-reduce only needs N read and 1 write, this saves 2/3 memory bandwidth. +// When increase N_REDUCE_LIMIT to a bigger number, do the following steps +// 1. Extend REPEAT_ macros list down below +// 2. Extend switch cases which call "REPEAT(X, ...)" down below +#define N_REDUCE_LIMIT 16 + +void reduce_all_buffers(struct allreduce_workspace** workspace, + int start_elements, + int num_elements, + c10::ScalarType scalar_type, + int num_buffers, + int to_buffer_idx) +{ + switch (scalar_type) { + case c10::ScalarType::BFloat16: + if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { + reduce_bf16_buffers( + start_elements, num_elements, num_buffers, to_buffer_idx, workspace); + } else { + for (int i = 0; i < num_buffers; i++) { + if (i == to_buffer_idx) continue; + reduce_2_bf16_buffers_iio( + num_elements, + workspace[i]->buffer + start_elements * 2, + workspace[to_buffer_idx]->buffer + start_elements * 2, + workspace[to_buffer_idx]->buffer + start_elements * 2); + } + } + break; + case c10::ScalarType::Float: + if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { + reduce_fp32_buffers( + start_elements, num_elements, num_buffers, to_buffer_idx, workspace); + } else { + for (int i = 0; i < num_buffers; i++) { + if (i == to_buffer_idx) continue; + reduce_2_fp32_buffers_iio( + num_elements, + workspace[i]->buffer + start_elements * 4, + workspace[to_buffer_idx]->buffer + start_elements * 4, + workspace[to_buffer_idx]->buffer + start_elements * 4); + } + } + break; + default: assert(!"Should not get here"); + } +} + +#define REPEAT(N, x) REPEAT_##N(x) +#define REPEAT_1(x) x(1) +#define REPEAT_2(x) \ + REPEAT_1(x); \ + x(2) +#define REPEAT_3(x) \ + REPEAT_2(x); \ + x(3) +#define REPEAT_4(x) \ + REPEAT_3(x); \ + x(4) +#define REPEAT_5(x) \ + REPEAT_4(x); \ + x(5) +#define REPEAT_6(x) \ + REPEAT_5(x); \ + x(6) +#define REPEAT_7(x) \ + REPEAT_6(x); \ + x(7) +#define REPEAT_8(x) \ + REPEAT_7(x); \ + x(8) +#define REPEAT_9(x) \ + REPEAT_8(x); \ + x(9) +#define REPEAT_10(x) \ + REPEAT_9(x); \ + x(10) +#define REPEAT_11(x) \ + REPEAT_10(x); \ + x(11) +#define REPEAT_12(x) \ + REPEAT_11(x); \ + x(12) +#define REPEAT_13(x) \ + REPEAT_12(x); \ + x(13) +#define REPEAT_14(x) \ + REPEAT_13(x); \ + x(14) +#define REPEAT_15(x) \ + REPEAT_14(x); \ + x(15) + +#define CVT_ADD_BF16(x) \ + do { \ + auto in##x##_val = \ + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[x]->buffer + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ + } while (0) + +// Reduce functions down below use vectorized algorithm, the number of bytes processed each +// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes +// If you change implementation of reduce_2_bf16_buffers_iio or reduce_2_fp32_buffers_iio, check +// whether this number needs to be changed +#define VECTOR_LENGTH_IN_BYTES 32 + +void reduce_bf16_buffers(int start_elements, + int num_elements, + int num_buffers, + int to_buffer_idx, + struct allreduce_workspace** workspace) +{ + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[0]->buffer + i))); + switch (num_buffers) { + case 16: REPEAT(15, CVT_ADD_BF16); break; + case 15: REPEAT(14, CVT_ADD_BF16); break; + case 14: REPEAT(13, CVT_ADD_BF16); break; + case 13: REPEAT(12, CVT_ADD_BF16); break; + case 12: REPEAT(11, CVT_ADD_BF16); break; + case 11: REPEAT(10, CVT_ADD_BF16); break; + case 10: REPEAT(9, CVT_ADD_BF16); break; + case 9: REPEAT(8, CVT_ADD_BF16); break; + case 8: REPEAT(7, CVT_ADD_BF16); break; + case 7: REPEAT(6, CVT_ADD_BF16); break; + case 6: REPEAT(5, CVT_ADD_BF16); break; + case 5: REPEAT(4, CVT_ADD_BF16); break; + case 4: REPEAT(3, CVT_ADD_BF16); break; + case 3: REPEAT(2, CVT_ADD_BF16); break; + default: assert(!"Should not get here."); + } + _mm256_storeu_si256((__m256i*)(workspace[to_buffer_idx]->buffer + i), + cvt_fp32_to_bf16(inout_val)); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < num_buffers; j++) { val += *(at::BFloat16*)(workspace[j]->buffer + i); } + *(at::BFloat16*)(workspace[to_buffer_idx]->buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out) +{ + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) { + auto in0_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in0 + i))); + auto in1_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in1 + i))); + auto out_val = _mm512_add_ps(in0_val, in1_val); + _mm256_storeu_si256((__m256i*)((char*)out + i), cvt_fp32_to_bf16(out_val)); + } + + // process remaining part + int i = main_elements * element_size; + while (remain_elements > 0) { + float in0_val = *((at::BFloat16*)((char*)in0 + i)); + float in1_val = *((at::BFloat16*)((char*)in1 + i)); + *((at::BFloat16*)((char*)out + i)) = in0_val + in1_val; + remain_elements--; + i += element_size; + } +} + +#define CVT_ADD_F32(x) \ + do { \ + auto in##x##_val = _mm256_loadu_ps((float*)(workspace[x]->buffer + i)); \ + inout_val = _mm256_add_ps(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp32_buffers(int start_elements, + int num_elements, + int num_buffers, + int to_buffer_idx, + struct allreduce_workspace** workspace) +{ + const int element_size = 4; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = _mm256_loadu_ps((float*)(workspace[0]->buffer + i)); + switch (num_buffers) { + case 16: REPEAT(15, CVT_ADD_F32); break; + case 15: REPEAT(14, CVT_ADD_F32); break; + case 14: REPEAT(13, CVT_ADD_F32); break; + case 13: REPEAT(12, CVT_ADD_F32); break; + case 12: REPEAT(11, CVT_ADD_F32); break; + case 11: REPEAT(10, CVT_ADD_F32); break; + case 10: REPEAT(9, CVT_ADD_F32); break; + case 9: REPEAT(8, CVT_ADD_F32); break; + case 8: REPEAT(7, CVT_ADD_F32); break; + case 7: REPEAT(6, CVT_ADD_F32); break; + case 6: REPEAT(5, CVT_ADD_F32); break; + case 5: REPEAT(4, CVT_ADD_F32); break; + case 4: REPEAT(3, CVT_ADD_F32); break; + case 3: REPEAT(2, CVT_ADD_F32); break; + default: assert(!"Should not get here."); + } + _mm256_storeu_ps((float*)(workspace[to_buffer_idx]->buffer + i), inout_val); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < num_buffers; j++) { val += *(float*)(workspace[j]->buffer + i); } + *(float*)(workspace[to_buffer_idx]->buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out) +{ + const int element_size = 4; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) { + auto in0_val = _mm256_loadu_ps((float*)((char*)in0 + i)); + auto in1_val = _mm256_loadu_ps((float*)((char*)in1 + i)); + auto out_val = _mm256_add_ps(in0_val, in1_val); + _mm256_storeu_ps((float*)((char*)out + i), out_val); + } + + // process remaining part + int i = main_elements * element_size; + while (remain_elements > 0) { + float in0_val = *((float*)((char*)in0 + i)); + float in1_val = *((float*)((char*)in1 + i)); + *((float*)((char*)out + i)) = in0_val + in1_val; + remain_elements--; + i += element_size; + } +} + +static bool is_initialized = 0; +static int world_size; +static int world_rank; + +void shm_initialize(int size, int rank, char* addr_string, char* port_string) +{ + if (is_initialized) return; + is_initialized = 1; + + world_size = size; + world_rank = rank; + + char shm_name_prefix[NAME_BUF_SIZE]; + char shm_name[NAME_BUF_SIZE]; + snprintf(shm_name_prefix, + NAME_BUF_SIZE, + "%s_%d_%s_%s", + SHM_BUFFER_NAME, + getuid(), + addr_string, + port_string); + // create shared workspace for SHM based allreduce + SharedData allreduce_buffer; + // allocate workspace_buf for current rank + struct allreduce_workspace* workspace_buf; + struct allreduce_workspace* workspace_buf_other; + workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); + shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace)); + workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace_buf->state = coll_begin; + + // create the workspace pointer list + workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*)); + + // map shm of all ranks + for (int i = 0; i < size; i++) { + if (i != rank) { + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); + // printf("open %s, %d\n", shm_name, rank); + do { + shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace)); + } while (allreduce_buffer.descriptor == -1 && errno == ENOENT); + workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace[i] = workspace_buf_other; + } else { + workspace[i] = workspace_buf; + workspace_buf->counter = 0; + sem_init(&workspace_buf->mutex, 1, 1); + sem_init(&workspace_buf->turnstile1, 1, 0); + sem_init(&workspace_buf->turnstile2, 1, 0); + } + } +} + +static void parallel_memcpy(void* to, void* from, size_t n_bytes) + __attribute__((target("avx512bw"))); +static void parallel_memcpy(void* to, void* from, size_t n_bytes) +{ + auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES); + // process aligned part +#pragma omp parallel for + for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) { + auto val = _mm256_loadu_si256((__m256i*)((char*)from + i)); + _mm256_storeu_si256((__m256i*)((char*)to + i), val); + } + + // process remaining part + for (int i = aligned_bytes; i < n_bytes; i++) { *((char*)to + i) = *((char*)from + i); } +} + +#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod)) +#define rank_mod(rank) positive_mod(rank, world_size) +size_t slice_size(size_t chunk_el, int slice_idx) +{ + size_t slice_size = chunk_el / world_size; + return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size; +} + +char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) +{ + size_t slice_size = chunk_el / world_size; + size_t el_offset = slice_size * slice_idx; + return data_ptr + el_offset * el_size; +} + +size_t slice_el_start(size_t chunk_el, int slice_idx) +{ + size_t slice_size = chunk_el / world_size; + return slice_size * slice_idx; +} + +void naive_all_reduce(char* data_ptr, + c10::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) +{ + parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_allreduce_naive__copy_in_done; + + if (world_rank == 0) { + // compute allreduce result on rank 0 + for (int i = 1; i < world_size; i++) { + // wait until the other rank copy the buffer + wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done); + } + reduce_all_buffers(workspace, 0, chunk_el, scalar_type, world_size, 0); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_allreduce_naive__reduce_done; + parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size); + } + if (world_rank != 0) { + wait_buffer_state_until(0, coll_allreduce_naive__reduce_done); + parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_allreduce_naive__copy_out_done; + } + if (world_rank == 0) { + for (int i = 1; i < world_size; i++) { + wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done); + } + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_begin; + } + if (world_rank != 0) { + // if rank 0 spin too fast it could be in state 1 of next allreduce + // in this case wait_buffer_state_until(0, 0) may cause deadlock + // what we are certain is when rank 0 finishes the state won't be 2 + wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done); + workspace[world_rank]->state = coll_begin; + } +} + +// naive allreduce distributed, each rank do naive reduce on its slice +void distributed_naive_reduce(char* data_ptr, + c10::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) +{ +#ifdef DO_PROFILE + static double total_t1_t0 = 0.0; + static double total_t2_t1 = 0.0; + static double total_t3_t2 = 0.0; + static double total_t4_t3 = 0.0; + static double total_t5_t4 = 0.0; + static int count = -16; // warmup + auto t0 = std::chrono::system_clock::now(); +#endif + + int data_size = chunk_size / chunk_el; + parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_allreduce_naive__copy_in_done; + +#ifdef DO_PROFILE + auto t1 = std::chrono::system_clock::now(); +#endif + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks copy the buffer + wait_buffer_state_until_range(i, coll_allreduce_naive__copy_in_done, 2); + } + +#ifdef DO_PROFILE + auto t2 = std::chrono::system_clock::now(); +#endif + + // reduce scatter + reduce_all_buffers(workspace, + slice_el_start(chunk_el, world_rank), + slice_size(chunk_el, world_rank), + scalar_type, + world_size, + world_rank); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_allreduce_naive__reduce_done; + +#ifdef DO_PROFILE + auto t3 = std::chrono::system_clock::now(); +#endif + + for (int i = 0; i < world_size; i++) { + int rank = (i + world_rank) % world_size; + // wait until the other rank reduce the buffer + wait_buffer_state_until_range(rank, coll_allreduce_naive__reduce_done, 2); + parallel_memcpy(slice_data(data_ptr, chunk_el, data_size, rank), + slice_data(workspace[rank]->buffer, chunk_el, chunk_size / chunk_el, rank), + slice_size(chunk_el, rank) * data_size); + } + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_allreduce_naive__copy_out_done; + +#ifdef DO_PROFILE + auto t4 = std::chrono::system_clock::now(); +#endif + + for (int i = 0; i < world_size; i++) { + wait_buffer_state_until_not(i, coll_allreduce_naive__reduce_done); + } + + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->state = coll_begin; + +#ifdef DO_PROFILE + auto t5 = std::chrono::system_clock::now(); + count++; + if (count > 0) { + total_t1_t0 += std::chrono::duration_cast(t1 - t0).count(); + total_t2_t1 += std::chrono::duration_cast(t2 - t1).count(); + total_t3_t2 += std::chrono::duration_cast(t3 - t2).count(); + total_t4_t3 += std::chrono::duration_cast(t4 - t3).count(); + total_t5_t4 += std::chrono::duration_cast(t5 - t4).count(); + if (world_rank == 0 && count == 1000) { + printf("distributed_naive_reduce time breakdown:\n"); + printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count); + printf("\twait for copy: %.2f\n", total_t2_t1 / count); + printf("\treduce: %.2f\n", total_t3_t2 / count); + printf("\tcopy buffer to output: %.2f\n", total_t4_t3 / count); + printf("\twait finish: %.2f\n", total_t5_t4 / count); + } + } +#endif +} + +void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) +{ + for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { + auto data_ptr = ((char*)(data.data_ptr()) + offset); + size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; + size_t chunk_el = chunk_size / (data_size / numel); + if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) + naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); + else + distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); + } +} diff --git a/csrc/cpu/comm/shm.h b/csrc/cpu/comm/shm.h new file mode 100644 index 000000000000..7f73197a8caa --- /dev/null +++ b/csrc/cpu/comm/shm.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#ifndef __SHM_COLLECTIVES__ +#define __SHM_COLLECTIVES__ +#define VECTOR_LENGTH_IN_BYTES 32 +void shm_initialize(int size, int rank, char* addr_string, char* port_string); +void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size); +void barrier_wait(int root_idx, int num_ranks); +#endif diff --git a/csrc/fp_quantizer/includes/context.h b/csrc/fp_quantizer/includes/context.h new file mode 100644 index 000000000000..5bd9badbcb4f --- /dev/null +++ b/csrc/fp_quantizer/includes/context.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#define WARP_SIZE 32 + +class FPContext { +public: + FPContext() : _seed(42) + { + curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(_gen, 123); + } + + virtual ~FPContext() {} + + static FPContext& Instance() + { + static FPContext _ctx; + return _ctx; + } + + curandGenerator_t& GetRandGenerator() { return _gen; } + + cudaStream_t GetCurrentStream() + { + // get current pytorch stream. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return stream; + } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + +private: + curandGenerator_t _gen; + cublasHandle_t _cublasHandle; + uint64_t _seed; + uint64_t _curr_offset; +}; diff --git a/csrc/fp_quantizer/includes/quantize.h b/csrc/fp_quantizer/includes/quantize.h new file mode 100644 index 000000000000..2204c1ba74fc --- /dev/null +++ b/csrc/fp_quantizer/includes/quantize.h @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include + +#include + +#include +#include +#include + +#define QUANT_SWITCH(Q_BITS, ...) \ + [&] { \ + if (12 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + __VA_ARGS__(); \ + } else if (13 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + __VA_ARGS__(); \ + } else if (10 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (11 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 8; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (28 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 12; \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + __VA_ARGS__(); \ + } else if (29 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 12; \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + __VA_ARGS__(); \ + } else if (6 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 6; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (7 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 6; \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + __VA_ARGS__(); \ + } else if (2 == Q_BITS) { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ + constexpr int CONST_Q_BITS = 4; \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + __VA_ARGS__(); \ + } else { \ + constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ + constexpr int CONST_Q_BITS = 4; \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + __VA_ARGS__(); \ + } \ + }() + +#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \ + [&] { \ + if (12 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 3; \ + constexpr int CONST_Q_EXPONENT_BITS = 4; \ + __VA_ARGS__(); \ + } else if (10 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + constexpr int CONST_Q_EXPONENT_BITS = 5; \ + __VA_ARGS__(); \ + } else if (28 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 7; \ + constexpr int CONST_Q_EXPONENT_BITS = 4; \ + __VA_ARGS__(); \ + } else if (6 == Q_MANTISA_EXPONENT_BITS) { \ + constexpr int CONST_Q_MANTISA_BITS = 2; \ + constexpr int CONST_Q_EXPONENT_BITS = 3; \ + __VA_ARGS__(); \ + } else { \ + constexpr int CONST_Q_MANTISA_BITS = 1; \ + constexpr int CONST_Q_EXPONENT_BITS = 2; \ + __VA_ARGS__(); \ + } \ + }() + +template +void launch_quantization(T* val, + uint8_t* q_val, + int num_groups, + int group_size, + cudaStream_t stream, + float q_range, + int q_bits, + int q_mantisa_bits, + int stochastic_rounding); + +template +void launch_dequantization(uint8_t* val, + T* q_val, + int num_groups, + int group_size, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream); diff --git a/csrc/fp_quantizer/quantize.cpp b/csrc/fp_quantizer/quantize.cpp new file mode 100644 index 000000000000..4a88ff767636 --- /dev/null +++ b/csrc/fp_quantizer/quantize.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "quantize.h" + +#include +#include +#include + +#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_quantization((C_TYPE*)val.data_ptr(), \ + (uint8_t*)out.data_ptr(), \ + num_groups, \ + group_size, \ + at::cuda::getCurrentCUDAStream(), \ + q_range, \ + q_bits, \ + q_mantisa_bits, \ + stochastic_rounding); \ + } + +at::Tensor quantize(torch::Tensor& val, + int group_size, + int stochastic_rounding, + int q_bits, + int q_mantisa_bits) +{ + int total_elems = at::numel(val); + auto options = at::TensorOptions() + .dtype(torch::kInt8) + .layout(val.layout()) + .device(val.device()) + .requires_grad(false); + float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges + (q_bits == 12 ? 510.0 : // fp12 range + (q_bits == 6 ? 28.0 : // fp6 range + 6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4 + // in case accuracy is not matching! + int num_groups = total_elems / group_size; + auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options); + + DISPATCH_QUANTIZE(kHalf, __half, 23, 8); +#ifdef BF16_AVAILABLE + DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8); +#endif + + return out; +} + +#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_dequantization((uint8_t*)val_q.data_ptr(), \ + (C_TYPE*)val.data_ptr(), \ + num_groups, \ + group_size, \ + q_mantisa_bits, \ + q_exponent_bits, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +void dequantize(torch::Tensor& val, + torch::Tensor& val_q, + int group_size, + int q_mantisa_bits, + int q_exponent_bits) +{ + int total_elems = at::numel(val); + + int num_groups = total_elems / group_size; + + DISPATCH_DEQUANTIZE(kHalf, __half, 10); +#ifdef BF16_AVAILABLE + DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7); +#endif +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("quantize", &quantize, "quantize function"); + m.def("dequantize", &dequantize, "dequantize function"); +} diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu new file mode 100644 index 000000000000..37be6cc0657c --- /dev/null +++ b/csrc/fp_quantizer/quantize.cu @@ -0,0 +1,427 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "context.h" +#include "memory_access_utils.h" +#include "quantize.h" +#include "reduction_utils.h" + +#include +#include + +#include +#include + +#include +#include + +using ROp = reduce::ROpType; + +namespace quantization { + +constexpr int access_granularity = 16; +constexpr int quanitzed_access_granularity = 4; +constexpr int quanitzed_access_granularity_6bits = 2; +constexpr int threads = 256; +constexpr int warps = threads / 32; + +} // namespace quantization + +template +__device__ void round(uint32_t& mantisa, uint32_t& dst_exponent, curandStatePhilox4_32_10_t* state) +{ + constexpr uint32_t mantisa_mask = (1 << (_mantisa_bits - q_mantisa_bits)) - 1; + uint32_t offset = stochastic_rounding ? (curand_poisson(state, 10) & mantisa_mask) + : 1 << (_mantisa_bits - q_mantisa_bits - 1); + mantisa += offset; + dst_exponent += (((mantisa & ~mantisa_mask) == (1 << _mantisa_bits)) ? 1 : 0); +} + +template +__device__ void clip(uint32_t& exponent, uint32_t& mantisa) +{ + constexpr uint32_t max_exponent = (1 << (q_exponent_bits - 1)) + (1 << (_exponent_bits - 1)); + constexpr uint32_t min_exponent = + (1 << (_exponent_bits - 1)) - ((1 << (q_exponent_bits - 1)) - 1); + if (exponent > max_exponent) { + exponent = max_exponent; + mantisa = (((uint32_t)-1) >> (32 - q_mantisa_bits)) << 1; //.11 .. 10 + } + if (exponent < min_exponent) { + exponent = min_exponent; + mantisa = 0; + } +} + +template +__global__ void apply_quantization(T* val, + uint8_t* q_val, + int group_size, + std::pair seed, + float q_range) +{ + int tidx = threadIdx.x; + int wid = tidx >> 5; + int lane = tidx & 0x1f; + int gid = blockIdx.x * quantization::warps + wid; + + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint32_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint32_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + constexpr uint32_t load_stride = vector_size * hw_warp_size; + constexpr uint32_t store_stride = (total_q_bits * vector_size / 8) * hw_warp_size; + const uint32_t thread_offset = lane * vector_size; + const uint32_t store_thread_offset = lane * (total_q_bits * vector_size / 8); + const uint32_t base_load_offset = gid * group_size + thread_offset; + const uint32_t base_store_offset = + gid * ((group_size * total_q_bits / 8) + 4) + + store_thread_offset; // 4-byte for saving the scale per group + const T* load_base_ptr = val + base_load_offset; + T tmp_buf[unroll * vector_size]; + T cur_max; + reduce::init(&cur_max); + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + mem_access::load_global( + &tmp_buf[vector_size * i], load_base_ptr + i * load_stride); + for (int j = 0; j < vector_size; j++) + cur_max = reduce::element(cur_max, __habs(tmp_buf[i * vector_size + j])); + } + } + reduce::_block(tb, warp, &cur_max); + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + uint8_t* store_base_ptr = q_val + base_store_offset; + float scale = (float)q_range / conversion::to(cur_max); +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + uint64_t q_buf = 0; + uint64_t q_buf1 = 0; +#pragma unroll + for (int j = 0; j < vector_size; j++) { + float val_f = conversion::to(tmp_buf[i * vector_size + j]) * scale; + uint32_t* data = reinterpret_cast(&val_f); + uint32_t sign = (data[0] & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint32_t cur_exponent = (data[0] & _exponent_mask) >> _mantisa_bits; + uint32_t dst_mantisa = (data[0] & _mantisa_mask); + + uint32_t dst_exponent = cur_exponent; + + round<_mantisa_bits, q_mantisa_bits, stochastic_rounding>( + dst_mantisa, dst_exponent, &state); + if (cur_exponent != 0) + clip<_mantisa_bits, _exponent_bits, q_mantisa_bits, q_exponent_bits>( + dst_exponent, dst_mantisa); + + dst_mantisa = (dst_mantisa & mantisa_mask) >> (_mantisa_bits - q_mantisa_bits); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + if (total_q_bits == 8 || total_q_bits == 4 || total_q_bits == 6) + q_buf = q_buf | + ((uint64_t)((uint8_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << j * total_q_bits); + else if (total_q_bits == 12) { + if (j < 5) + q_buf = + q_buf | + ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << j * total_q_bits); + else + q_buf1 = + q_buf1 | + ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) | + (dst_exponent << q_mantisa_bits) | dst_mantisa)) + << (j - 5) * total_q_bits); + } + } + if (total_q_bits == 12) { + uint64_t last_nibble_mask = 0xf; + last_nibble_mask = q_buf1 & last_nibble_mask; + q_buf = (last_nibble_mask << 60) | q_buf; + q_buf1 >>= 4; + } + uint8_t* int8_data = reinterpret_cast(&q_buf); + uint8_t* int8_data1 = reinterpret_cast(&q_buf1); + if (total_q_bits == 6) { + mem_access::store_global( + store_base_ptr + i * store_stride, int8_data); + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity_6bits, + int8_data + quantization::quanitzed_access_granularity_6bits); + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity_6bits * 2, + int8_data + 2 * quantization::quanitzed_access_granularity_6bits); + } else { + mem_access::store_global( + store_base_ptr + i * store_stride, int8_data); + + if (total_q_bits > 4) { + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity, + int8_data + quantization::quanitzed_access_granularity); + if (total_q_bits == 12) { + mem_access::store_global( + store_base_ptr + i * store_stride + + quantization::quanitzed_access_granularity * 2, + int8_data1); + } + } + } + } + } + if (lane == 0) { + float q_scale = conversion::to(cur_max) / (float)q_range; + uint8_t* scale_as_int8 = reinterpret_cast(&q_scale); + uint32_t scale_offset = + gid * ((group_size * total_q_bits / 8) + 4) + (group_size * total_q_bits / 8); + if (total_q_bits != 6) + mem_access::store_global( + q_val + scale_offset, scale_as_int8); + else { + mem_access::store_global( + q_val + scale_offset, scale_as_int8); + mem_access::store_global( + q_val + scale_offset + quantization::quanitzed_access_granularity_6bits, + scale_as_int8 + quantization::quanitzed_access_granularity_6bits); + } + } +} + +template +__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size) +{ + int tidx = threadIdx.x; + int wid = tidx >> 5; + int lane = tidx & 0x1f; + int gid = blockIdx.x * quantization::warps + wid; + constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); + + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + constexpr uint32_t load_stride = vector_size * hw_warp_size; + const uint32_t thread_offset = lane * vector_size; + const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8; + const uint32_t base_load_offset = + gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset; // 4-byte scale offset + const uint32_t base_store_offset = gid * group_size + thread_offset; + const uint8_t* load_base_ptr = val + base_load_offset; + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + T* store_base_ptr = q_val + base_store_offset; + float scale; //= q_scale[gid]; + + uint8_t* scale_as_int8 = reinterpret_cast(&scale); + if (quantized_bits == 6) { + mem_access::load_global( + scale_as_int8, + val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8)); + mem_access::load_global( + scale_as_int8 + quantization::quanitzed_access_granularity_6bits, + val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) + + quantization::quanitzed_access_granularity_6bits); + } else + mem_access::load_global( + scale_as_int8, + val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8)); + +#pragma unroll + for (int i = 0; i < unroll; i++) { + if (i * load_stride + thread_offset < group_size) { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + uint32_t loading_offset = i * load_stride * quantized_bits / 8; + if (quantized_bits == 6) { + mem_access::load_global( + int8_data, load_base_ptr + loading_offset); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity_6bits * 2); + } else { + mem_access::load_global( + int8_data, load_base_ptr + loading_offset); + if (quantized_bits > 4) { + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) { + mem_access::load_global( + int8_data1, + load_base_ptr + loading_offset + + quantization::quanitzed_access_granularity * 2); + } + } + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); + } else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) | + (dst_exponent << q_mantisa_bits) | + (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); + } + mem_access::store_global( + store_base_ptr + i * load_stride, store_buf); + } + } +} + +#define LAUNCH_FOR_QUANTIZATION_UNROLL(COUNT) \ + case COUNT: \ + apply_quantization \ + <<>>(val, q_val, group_size, seed, q_range); \ + break; + +template +void launch_quantization(T* val, + uint8_t* q_val, + int num_groups, + int group_size, + cudaStream_t stream, + float q_range, + int q_bits, + int q_mantisa_bits, + int stochastic_rounding) +{ + const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps); + const dim3 block(quantization::threads); + + std::pair seed = FPContext::Instance().IncrementOffset(16); + + constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T); + + const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll; + QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] { + switch (copy_unroll) { + LAUNCH_FOR_QUANTIZATION_UNROLL(1) + LAUNCH_FOR_QUANTIZATION_UNROLL(2) + LAUNCH_FOR_QUANTIZATION_UNROLL(3) + LAUNCH_FOR_QUANTIZATION_UNROLL(4) + LAUNCH_FOR_QUANTIZATION_UNROLL(5) + LAUNCH_FOR_QUANTIZATION_UNROLL(6) + } + }); +} +#define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \ + template void launch_quantization( \ + T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int); +// fp8(E4M3), nearest-rounding +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8); +#endif +INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8); + +#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \ + case COUNT: \ + apply_dequantization \ + <<>>(val, q_val, group_size); \ + break; + +template +void launch_dequantization(uint8_t* val, + T* q_val, + int num_groups, + int group_size, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream) +{ + const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps); + const dim3 block(quantization::threads); + + constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T); + const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll; + + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + switch (copy_unroll) { + LAUNCH_FOR_DEQUANTIZATION_UNROLL(1) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(2) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(3) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(4) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(5) + LAUNCH_FOR_DEQUANTIZATION_UNROLL(6) + } + }); +} +#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \ + template void launch_dequantization(uint8_t*, T*, int, int, int, int, cudaStream_t); +// fp8(E4M3) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7); +#endif +INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10); diff --git a/csrc/includes/memory_access_utils.h b/csrc/includes/memory_access_utils.h index 6789714d27c7..cb990b75bbe8 100644 --- a/csrc/includes/memory_access_utils.h +++ b/csrc/includes/memory_access_utils.h @@ -868,6 +868,35 @@ __device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(voi #endif } +template <> +__device__ __forceinline__ void store_global<2>(void* dst, const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + +template <> +__device__ __forceinline__ void store_global<2, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const int16_t* data = reinterpret_cast(src); + + int16_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +} + /////////// Store Shared /////////// template <> diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index eb8efab77ac1..eb9afb66a894 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -159,6 +159,12 @@ DS_D_INLINE float element(const float lhs, const float rhs) return lhs + rhs; } +template <> +DS_D_INLINE double element(const double lhs, const double rhs) +{ + return lhs + rhs; +} + template <> DS_D_INLINE float element(const float lhs, const float rhs) { @@ -189,6 +195,19 @@ DS_D_INLINE __half element(const __half lhs, const __half rhs) #endif } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 element(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} +#endif + template <> DS_D_INLINE __half element(const __half lhs, const __half rhs) { @@ -220,6 +239,21 @@ DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) #endif } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 element(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __nv_bfloat162 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} +#endif + template <> DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) { @@ -295,6 +329,11 @@ DS_D_INLINE float init() { return 0.0f; } +template <> +DS_D_INLINE double init() +{ + return (double)0.0f; +} template <> DS_D_INLINE float init() @@ -331,6 +370,15 @@ DS_D_INLINE __half init() return __half(neg_inf); } +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 init() +{ + constexpr __nv_bfloat16_raw neg_inf = {0xFF80}; + return __nv_bfloat16(neg_inf); +} +#endif + template <> DS_D_INLINE __half2 init() { diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index f1d99e1b0e43..fe0043547860 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -26,6 +26,7 @@ from . import module_inject from .accelerator import get_accelerator +from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER from .runtime.hybrid_engine import DeepSpeedHybridEngine @@ -71,6 +72,7 @@ def initialize(args=None, model_parameters: Optional[torch.nn.Module] = None, training_data: Optional[torch.utils.data.Dataset] = None, lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None, + distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT, mpu=None, dist_init_required: Optional[bool] = None, collate_fn=None, @@ -95,6 +97,8 @@ def initialize(args=None, lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object. The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods + distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training + mpu: Optional: A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}() @@ -136,7 +140,9 @@ def initialize(args=None, global dist from deepspeed import comm as dist dist_backend = get_accelerator().communication_backend_name() - dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required) + dist.init_distributed(dist_backend=dist_backend, + distributed_port=distributed_port, + dist_init_required=dist_init_required) # Set config using config_params for backwards compat if config is None and config_params is not None: diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py index 15b6ce28b2fd..137607721ebf 100644 --- a/deepspeed/checkpoint/reshape_utils.py +++ b/deepspeed/checkpoint/reshape_utils.py @@ -4,9 +4,10 @@ # DeepSpeed Team import os +import re import torch from collections import OrderedDict -from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX) +from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX, MODEL_FILE_PREFIX) def basic_folder_validation(dir): @@ -38,12 +39,28 @@ def get_files(dir): return file_list +def sort_zero_files(files, prefix): + pattern = f"{prefix}([0-9]+)_{MODEL_FILE_PREFIX}([0-9]+)" + rank_pairs = [] + for f in files: + m = re.search(pattern, f) + if m: + dp_rank = int(m.group(1)) + mp_rank = int(m.group(2)) + rank_pairs.append((dp_rank, mp_rank, f)) + else: + raise ValueError(f"Cannot parse dp_rank and mp_rank from {f}") + + sorted_files = sorted(rank_pairs, key=lambda x: (x[0], x[1])) + return [f for _, _, f in sorted_files] + + def get_zero_files(dir): file_list = get_files(dir) for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]: zero_files = get_files_with_prefix(file_list, prefix) if len(zero_files) > 0: - return zero_files + return sort_zero_files(zero_files, prefix) return [] diff --git a/deepspeed/elasticity/elastic_agent.py b/deepspeed/elasticity/elastic_agent.py index c1e8932ecaba..6051529bfa24 100644 --- a/deepspeed/elasticity/elastic_agent.py +++ b/deepspeed/elasticity/elastic_agent.py @@ -6,7 +6,7 @@ from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent from typing import Any, Dict, Optional, Tuple from datetime import datetime -from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port +from torch.distributed.elastic.agent.server.api import _get_socket_with_port from torch.distributed.elastic.metrics import put_metric from torch.distributed.elastic.agent.server.api import ( RunResult, @@ -24,6 +24,10 @@ from contextlib import closing import subprocess +from torch.distributed.elastic.utils.logging import get_logger + +log = get_logger(__name__) + class DSElasticAgent(LocalElasticAgent): diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index d6c023ec11d3..e685a0f574f3 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -209,13 +209,15 @@ def top1gating(logits: Tensor, # if we don't want to drop any tokens if not drop_tokens: new_capacity = torch.max(exp_counts).to(logits.device) + # Communicate across all processes to pick the maximum capacity. dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group()) if groups._get_expert_model_parallel_world_size() == 1: # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'. # This is since we are going to activate drop_tokens() to drop duplicate tokens. tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size() new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) - capacity = new_capacity + # Make sure the capacity value does not exceed the number of tokens. + capacity = min(new_capacity, torch.tensor(mask1.size(0))) # Compute l_aux me = torch.mean(gates, dim=0) diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py new file mode 100644 index 000000000000..5575f3567185 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .quantize import FP_Quantize diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py new file mode 100644 index 000000000000..5dc3c190ae5d --- /dev/null +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +fp_quant_module = None + + +class FP_Quantize: + + def __init__(self, group_size=512) -> None: + global fp_quant_module + if fp_quant_module is None: + fp_quant_module = FPQuantizerBuilder().load() + + self.group_size = group_size + self.orig_dtype = None + + def quantize(self, + input, + q_bits=8, + q_mantisa_bits=3, + stochastic_mode=False, + return_meta_tensor=False) -> torch.Tensor: + assert input.dtype == torch.bfloat16, "only support bf16 for now" + if return_meta_tensor: + assert q_bits == 8, "meta tensor is only supported with q_bit=8" + + self.orig_dtype = input.dtype + self.orig_shape = input.shape + + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!" + + out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) + + if return_meta_tensor: + data, scale = out.split(self.group_size, dim=-1) + return data.contiguous().reshape(input.shape), scale.contiguous() + + return out + + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: + assert (self.orig_dtype is not None), \ + "[De-quantization Error]: you need to call quantize before dequantizing!" + fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype, + device=input_q.device) if fp_out is None else fp_out + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" + + if scale is not None: + assert input_q.numel() == fp_out.numel(), \ + f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' + input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + + fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) + return fp_out diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 7b98216c1cba..d076035604e3 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -517,7 +517,7 @@ def param_groups(self): def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): assert self.immediate_grad_update - self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False) + self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True) def create_grad_acc_hooks(self): self.grad_accs = [] diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index e068f4a48b4a..c1c2b6c61cfd 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -380,8 +380,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): for p in parameters: all_norms.append(p.grad.data.abs().max().float()) total_norm = torch.stack(all_norms).max() - origin_device = total_norm.device.type - total_norm = total_norm.to(get_accelerator().device_name()) + total_norm = total_norm.to(get_accelerator().current_device_name()) # Take max across all GPUs. if mpu is not None: dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) @@ -398,9 +397,8 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): if len(all_norms) > 0: total_norm = torch.stack(all_norms).square().sum().float() else: - total_norm = torch.FloatTensor([0.0]).to(parameters[0].device) - origin_device = total_norm.device.type - total_norm = total_norm.to(get_accelerator().device_name()) + total_norm = get_accelerator().FloatTensor([0.0]) + total_norm = total_norm.to(get_accelerator().current_device_name()) # Sum across all model parallel GPUs. if mpu is not None: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) @@ -413,11 +411,11 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): dist.all_reduce(scaled_norm_tensor, group=pg) total_norm = scaled_norm_tensor - total_norm = total_norm.to(origin_device) + total_norm = total_norm.to(parameters[0].device) - max_norm = torch.tensor([float(max_norm)], device=parameters[0].device) + max_norm = torch.tensor([float(max_norm)], device=total_norm.device) clip_coef = max_norm / (total_norm + 1e-6) - tmp_tensor = torch.tensor([1.0], device=parameters[0].device) + tmp_tensor = torch.tensor([1.0], device=clip_coef.device) clip_coef = torch.min(tmp_tensor, clip_coef) for p in parameters: p.grad.data.mul_(clip_coef) @@ -890,42 +888,48 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors' norm_type = float(norm_type) + all_norms = [] if norm_type == inf: - total_norm = max(t.data.abs().max() for t in input_tensors) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + for t in input_tensors: + all_norms.append(t.data.abs().max().float()) + total_norm = torch.stack(all_norms).max() + device_total_norm = total_norm.to(get_accelerator().current_device_name()) if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) + dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) if moe_ep_group is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group) - total_norm = total_norm_cuda[0].item() + dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group) + total_norm = device_total_norm.to(input_tensors[0].device) else: - if use_graph: - if 'norm_tensors_compute_buffer' not in graph_cache: - graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors] - compute_buffer = graph_cache['norm_tensors_compute_buffer'] - def _norm_tensors(tensor_list, _compute_buffer, _norm_type): - for i, t in enumerate(tensor_list): - _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type) - if i != 0: - _compute_buffer[0].data.add_(_compute_buffer[i].data) + if 'norm_tensors_compute_buffer' not in graph_cache or len( + graph_cache['norm_tensors_compute_buffer']) != len(input_tensors): + graph_cache['norm_tensors_compute_buffer'] = [ + torch.empty([], dtype=torch.float, device=get_accelerator().current_device_name()) + for t in input_tensors + ] + compute_buffer = graph_cache['norm_tensors_compute_buffer'] - graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type) + def _norm_tensors(tensor_list, _compute_buffer, _norm_type): + for i, t in enumerate(tensor_list): + _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type) + if i != 0: + _compute_buffer[0].data.add_(_compute_buffer[i].data) - total_norm = compute_buffer[0] + if use_graph: + graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type) else: - total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) + _norm_tensors(input_tensors, compute_buffer, norm_type) + + device_total_norm = compute_buffer[0].float().detach() - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach() if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) if moe_ep_group is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=moe_ep_group) + dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group) + total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) - - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan()) + total_norm.masked_fill_(inf_or_nan, -1) return total_norm diff --git a/op_builder/cpu/comm.py b/op_builder/cpu/comm.py index b26328341081..38e965530f43 100644 --- a/op_builder/cpu/comm.py +++ b/op_builder/cpu/comm.py @@ -19,7 +19,7 @@ def absolute_name(self): return f'deepspeed.ops.comm.{self.NAME}_op' def sources(self): - return ['csrc/cpu/comm/ccl.cpp'] + return ['csrc/cpu/comm/ccl.cpp', 'csrc/cpu/comm/shm.cpp'] def include_paths(self): includes = ['csrc/cpu/includes'] diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py new file mode 100644 index 000000000000..bafd3e0c33f6 --- /dev/null +++ b/op_builder/fp_quantizer.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class FPQuantizerBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FP_QUANTIZER" + NAME = "fp_quantizer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 8: + self.warning("NVIDIA Inference is only supported on Ampere and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 8: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def sources(self): + return [ + "csrc/fp_quantizer/quantize.cu", + "csrc/fp_quantizer/quantize.cpp", + ] + + def extra_ldflags(self): + return ['-lcurand'] + + def include_paths(self): + return ['csrc/fp_quantizer/includes', 'csrc/includes'] diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index f28c1ecb165c..dd13ac163517 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -10,6 +10,7 @@ pytest<=8.0.0 pytest-forked pytest-randomly pytest-xdist +qtorch==0.3.0 recommonmark sphinx sphinx-rtd-theme diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index 310a0df16381..dd340b1117c4 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -9,6 +9,9 @@ import gc from unit.common import DistributedTest from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader +import deepspeed.comm as dist +from deepspeed import get_accelerator +from deepspeed.moe.sharded_moe import top1gating from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param from deepspeed.runtime.utils import required_torch_version @@ -132,3 +135,23 @@ def test(self, ep_size, use_residual): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + + +class TestTopk(DistributedTest): + world_size = 2 + + def test(self): + device = get_accelerator().current_device() + if dist.get_rank() == 0: + logits = torch.rand(2, 2, device=device) + elif dist.get_rank() == 1: + logits = torch.rand(10, 2, device=device) + + output = top1gating(logits=logits, + capacity_factor=1, + min_capacity=0, + used_token=None, + noisy_gate_policy=None, + drop_tokens=False, + use_rts=True, + use_tutel=False) diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py new file mode 100644 index 000000000000..101f4cd69811 --- /dev/null +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.ops.fp_quantizer import FP_Quantize +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + +# warning: this import silently JIT builds a set of kernels and may take a minute +from qtorch.quant import float_quantize + + +def qtorch_quantize(input, exp_bits=4, man_bits=3, rounding="nearest", group_size=1024): + ori_dt = input.dtype + ori_shape = input.shape + last_dim = group_size + input = input.view(-1, last_dim) + + q_bits = exp_bits + man_bits + 1 + input_to_float = input.float() + if q_bits == 8: + q_range = 480. + elif q_bits == 6: + q_range = 28. + elif q_bits == 12: + q_range = 510. + else: + assert (0), \ + "Please specify the right quantization range for the selected precision!" + input_max = input_to_float.abs().amax(dim=-1, keepdim=True) + return ((float_quantize(input_to_float / input_max * q_range, exp_bits, man_bits, rounding=rounding) * \ + input_max / q_range).to(ori_dt)).reshape(ori_shape) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_fp_quant_meta(dtype): + group_size = 128 + q_bits = 8 + exp_bits = 4 + man_bits = 3 + + fpq = FP_Quantize(group_size=group_size) + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device='cuda') + + ds_x = x.clone() + x_quantized, meta_tensor = fpq.quantize(ds_x, q_bits=q_bits, return_meta_tensor=True) + x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits, scale=meta_tensor) + + qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) + qtorch_error = (qtorch_out - x).abs().sum() / x.numel() + ds_error = (x_dequantized - x).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +@pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"]) +def test_fp_quant(dtype, q_bits): + group_size = 128 + fpq = FP_Quantize(group_size=group_size) + + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device='cuda') + + ds_x = x.clone() + x_quantized = fpq.quantize(ds_x, q_bits=q_bits) + x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits) + + if q_bits == 8: + exp_bits = 4 + man_bits = 3 + elif q_bits == 6: + exp_bits = 3 + man_bits = 2 + elif q_bits == 12: + exp_bits = 4 + man_bits = 7 + else: + raise ValueError(f"unknown {q_bits=}") + + qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) + + qtorch_error = (qtorch_out - x).abs().sum() / x.numel() + ds_error = (x_dequantized - x).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"