From 4b6c099c6f515da9a1b33b59c6ccc9a9034465d8 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Tue, 22 Oct 2024 06:30:16 -0500 Subject: [PATCH 01/55] WIP mgx llama2-7b example --- .../transformers/mgx_llama2/CMakeLists.txt | 33 ++++ .../mgx_llama2/harness/buffer.hpp | 159 ++++++++++++++++++ .../mgx_llama2/harness/common.hpp | 99 +++++++++++ .../mgx_llama2/harness/logging.hpp | 49 ++++++ .../transformers/mgx_llama2/harness/numa.hpp | 158 +++++++++++++++++ .../transformers/mgx_llama2/harness/numpy.hpp | 111 ++++++++++++ .../transformers/mgx_llama2/harness/timer.hpp | 69 ++++++++ examples/transformers/mgx_llama2/mgxllama2.cc | 116 +++++++++++++ 8 files changed, 794 insertions(+) create mode 100644 examples/transformers/mgx_llama2/CMakeLists.txt create mode 100644 examples/transformers/mgx_llama2/harness/buffer.hpp create mode 100644 examples/transformers/mgx_llama2/harness/common.hpp create mode 100644 examples/transformers/mgx_llama2/harness/logging.hpp create mode 100644 examples/transformers/mgx_llama2/harness/numa.hpp create mode 100644 examples/transformers/mgx_llama2/harness/numpy.hpp create mode 100644 examples/transformers/mgx_llama2/harness/timer.hpp create mode 100644 examples/transformers/mgx_llama2/mgxllama2.cc diff --git a/examples/transformers/mgx_llama2/CMakeLists.txt b/examples/transformers/mgx_llama2/CMakeLists.txt new file mode 100644 index 00000000000..ec27b8ff203 --- /dev/null +++ b/examples/transformers/mgx_llama2/CMakeLists.txt @@ -0,0 +1,33 @@ +project(MGXLlama2) +cmake_minimum_required(VERSION 3.22) + +set(TARGET_NAME mgxllama2) + +set(HARNESS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/harness) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CXX /opt/rocm/llvm/bin/clang++) + +list (APPEND CMAKE_PREFIX_PATH /opt/rocm ${HIP_PATH}) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3 -W -Wall -pthread -D__HIP_PLATFORM_HCC__=1") + +find_package(migraphx REQUIRED) +find_package(hip REQUIRED) + +include_directories(${HARNESS_DIR}) + +add_executable(${TARGET_NAME} + mgxllama2.cc +) + +target_include_directories(${TARGET_NAME} + PUBLIC ${HARNESS_DIR} +) + +target_link_libraries(${TARGET_NAME} + migraphx::c + hip::device + pthread +) + diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp new file mode 100644 index 00000000000..3cdb65f9176 --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -0,0 +1,159 @@ +#pragma once + +#include "common.hpp" + +namespace mlinfer +{ + template + struct IBuffer : public INoCopy + { + AllocFunc alloc_fn; + FreeFunc free_fn; + }; + + template + struct GenericBuffer : public IBuffer + { + GenericBuffer() + : size_in_bytes{0}, stride_in_bytes{0}, tensor_ptr{nullptr} + { + } + + explicit GenericBuffer(size_t size_in_bytes_, size_t stride_in_bytes_ = 0) + : size_in_bytes{size_in_bytes_}, stride_in_bytes{stride_in_bytes_} + { + if (stride_in_bytes == 0) + { + stride_in_bytes = size_in_bytes; + } + this->alloc_fn(&tensor_ptr, size_in_bytes); + } + + GenericBuffer(GenericBuffer &&buf) + : size_in_bytes{buf.size_in_bytes}, stride_in_bytes{buf.stride_in_bytes}, tensor_ptr{buf.tensor_ptr} + { + buf.size_in_bytes = 0; + buf.stride_in_bytes = 0; + buf.tensor_ptr = nullptr; + } + + GenericBuffer &operator=(GenericBuffer &&buf) + { + if (this != &buf) + { + this->free_fn(tensor_ptr); + size_in_bytes = buf.size_in_bytes; + stride_in_bytes = buf.stride_in_bytes; + tensor_ptr = buf.tensor_ptr; + buf.size_in_bytes = 0; + buf.stride_in_bytes = 0; + buf.tensor_ptr = 
nullptr; + } + return *this; + } + + GenericBuffer(const GenericBuffer &buf) = delete; + GenericBuffer &operator=(const GenericBuffer &buf) = delete; + + ~GenericBuffer() + { + this->free_fn(tensor_ptr); + } + + size_t size_in_bytes; + size_t stride_in_bytes; + void *tensor_ptr; + }; + + struct DeviceAllocator + { + void operator()(void **ptr, size_t size) const + { + LOG_INFO("Malloc " << size << " bytes on device"); + TIMED(hipMalloc, check_hip_status(hipMalloc(ptr, size))); + TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); + } + }; + + struct DeviceFree + { + void operator()(void *ptr) const + { + TIMED(hipFree, check_hip_status_non_throwing(hipFree(ptr))); + ptr = nullptr; + } + }; + + struct HostAllocator + { + void operator()(void **ptr, size_t size) const + { + LOG_INFO("Malloc " << size << " bytes on host"); + TIMED(hipHostMalloc, check_hip_status(hipHostMalloc(ptr, size))); + TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); + } + }; + + struct HostFree + { + void operator()(void *ptr) const + { + TIMED(hipHostFree, check_hip_status_non_throwing(hipHostFree(ptr))); + ptr = nullptr; + } + }; + + using DeviceBuffer = GenericBuffer; + using HostBuffer = GenericBuffer; + + struct ManagedBuffer + { + + explicit ManagedBuffer(size_t size_in_bytes_, size_t stride_in_bytes_ = 0) + { + dbuff = DeviceBuffer(size_in_bytes_, stride_in_bytes_); + hbuff = HostBuffer(size_in_bytes_, stride_in_bytes_); + } + + template + T get_host_ptr() + { + return static_cast(hbuff.tensor_ptr); + } + + template + T get_device_ptr() + { + return static_cast(dbuff.tensor_ptr); + } + + void upload_to_device(void* data, size_t size_in_bytes) + { + memcpy(get_host_ptr(), data, size_in_bytes); + check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); + } + + template + std::vector download_from_device(size_t size_in_bytes) + { + check_hip_status(hipMemcpy(get_host_ptr(), get_device_ptr(), size_in_bytes, hipMemcpyKind::hipMemcpyDeviceToHost)); + return std::vector(get_host_ptr(), get_host_ptr() + (size_in_bytes / sizeof(T))); + } + + template + void update_device_data(T data, size_t position) + { + T* host_data = get_host_ptr(); + host_data[position] = data; + // TODO: don't copy over the entire buffer just the changed range + check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); + } + + ManagedBuffer() = delete; + ManagedBuffer(const ManagedBuffer &buf) = delete; + ManagedBuffer &operator=(const ManagedBuffer &buf) = delete; + + DeviceBuffer dbuff; + HostBuffer hbuff; + }; +} diff --git a/examples/transformers/mgx_llama2/harness/common.hpp b/examples/transformers/mgx_llama2/harness/common.hpp new file mode 100644 index 00000000000..db155af820a --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/common.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include + +#include "logging.hpp" +#include "timer.hpp" + +#include +#include +#include +#include + +#define TIMER_ON 0 +#define TRACE_ON 0 + +#define assertm(exp, msg) assert(((void)msg, exp)) + +namespace mlinfer +{ + struct INoCopy + { + INoCopy() = default; + virtual ~INoCopy() = default; + INoCopy(const INoCopy &) = delete; + INoCopy &operator=(const INoCopy &) = delete; + }; + + /* Helper function to split a string based on a delimiting character */ + inline std::vector + splitString(const std::string &input, const std::string &delimiter) + { + std::vector result; + size_t start = 0; + size_t next = 0; + while 
(next != std::string::npos) + { + next = input.find(delimiter, start); + result.emplace_back(input, start, next - start); + start = next + 1; + } + return result; + } + +#define check_hip_status(hip_call) \ + do \ + { \ + int status = (hip_call); \ + if (status != hipSuccess) \ + { \ + throw std::runtime_error("hip error (" + std::to_string(status) + "): " + std::string(hipGetErrorString(static_cast(status)))); \ + } \ + } while (0); + +#define check_hip_status_non_throwing(hip_call) \ + do \ + { \ + int status = (hip_call); \ + if (status != hipSuccess) \ + { \ + LOG_INFO("hip error (" + std::to_string(status) + "): " + std::string(hipGetErrorString(static_cast(status)))); \ + } \ + } while (0); + + +#define CHECK(condition, error) \ + do \ + { \ + if (!(condition)) \ + { \ + std::cerr << error << std::endl; \ + } \ + } while (0); + +#if TIMER_ON +#define TIMER_STARTV(s) \ + static Timer timer##s(#s, true); \ + auto start##s = std::chrono::high_resolution_clock::now(); +#define TIMER_START(s) \ + static Timer timer##s(#s); \ + auto start##s = std::chrono::high_resolution_clock::now(); +#define TIMER_END(s) timer##s.add(std::chrono::high_resolution_clock::now() - start##s); +#else +#define TIMER_START(s) +#define TIMER_STARTV(s) +#define TIMER_END(s) +#endif + +#define TIMED(s, call) \ + do \ + { \ + TIMER_START(s); \ + { \ + call; \ + } \ + TIMER_END(s); \ + } while (0); + +} // namespace mlinfer + diff --git a/examples/transformers/mgx_llama2/harness/logging.hpp b/examples/transformers/mgx_llama2/harness/logging.hpp new file mode 100644 index 00000000000..a25e73498a2 --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/logging.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include + +namespace mlinfer +{ + +#define LOGGING_OFF 0 +#define ENABLE_TIMED_LOGGING 0 +#define ENABLE_DEBUG_LOGGING 0 + +#if (!LOGGING_OFF) +#define LOG_INFO(...) \ + do \ + { \ + std::cout << __VA_ARGS__ << std::endl; \ + } while (0) +#define LOG_ERROR(...) \ + do \ + { \ + std::cerr << __VA_ARGS__ << std::endl; \ + } while (0) +#define LOG_STATE(...) \ + do \ + { \ + std::cout << "================================================" << std::endl; \ + std::cout << __VA_ARGS__ << std::endl; \ + std::cout << "================================================" << std::endl; \ + } while (0) +#else +#define LOG_INFO(...) (void)0 +#define LOG_ERROR(...) (void)0 +#define LOG_STATE(...) (void)0 +#endif + +#if (ENABLE_TIMED_LOGGING || ENABLE_DEBUG_LOGGING) +#define LOG_TIMED(...) LOG_INFO(__VA_ARGS__) +#else +#define LOG_TIMED(...) (void)0 +#endif + +#if ENABLE_DEBUG_LOGGING +#define LOG_DEBUG(...) LOG_INFO(__VA_ARGS__) +#else +#define LOG_DEBUG(...) (void)0 +#endif + +} // namespace mlinfer + diff --git a/examples/transformers/mgx_llama2/harness/numa.hpp b/examples/transformers/mgx_llama2/harness/numa.hpp new file mode 100644 index 00000000000..5c0ee561efe --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/numa.hpp @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.hpp" + +namespace mlinfer +{ + // NUMA config. Each NUMA node contains a pair of GPU indices and CPU indices. + using NumaConfig = std::vector, std::vector>>; + + // The NUMA node idx for each GPU. 
+ using GpuToNumaMap = std::vector; + + struct NumaSettings + { + NumaConfig numa_config; + GpuToNumaMap gpu_to_numa_map; + }; + + struct Numa final + { + NumaSettings numa_settings; + + explicit Numa(const NumaSettings &numa_settings) : numa_settings{numa_settings} {} + + inline bool UseNuma() const + { + return not numa_settings.numa_config.empty(); + } + + inline size_t GetNumaCount() const + { + return numa_settings.numa_config.size(); + }; + + inline int GetNumaIdx(const int deviceId) const + { + return UseNuma() ? numa_settings.gpu_to_numa_map.at(deviceId) : 0; + } + + inline std::vector GetClosestCpus(const int deviceId) const + { + assertm(UseNuma(), "GetClosestCpus only available for NUMA"); + return numa_settings.numa_config.at(GetNumaIdx(deviceId)).second; + } + }; + + // Restrict mem allocation to specific NUMA node. + inline void + bindNumaMemPolicy(const int32_t numaIdx, const int32_t nbNumas) + { + unsigned long nodeMask = 1UL << numaIdx; + long ret = set_mempolicy(MPOL_BIND, &nodeMask, nbNumas + 1); + CHECK(ret >= 0, std::strerror(errno)); + } + + // Reset mem allocation setting. + inline void resetNumaMemPolicy() + { + long ret = set_mempolicy(MPOL_DEFAULT, nullptr, 0); + CHECK(ret >= 0, std::strerror(errno)); + } + + // Limit a thread to be on specific cpus. + inline void bindThreadToCpus(std::thread &th, const std::vector &cpus, const bool ignore_esrch = false) + { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + for (int cpu : cpus) + { + CPU_SET(cpu, &cpuset); + } + int ret = pthread_setaffinity_np(th.native_handle(), sizeof(cpu_set_t), &cpuset); + bool noerr = ignore_esrch ? ret == 0 || ret == ESRCH : ret == 0; + CHECK(noerr, std::strerror(ret)); + } + + // Helper to converts the range string (like "0,2-5,13-17") to a vector of ints. + inline std::vector parseRange(const std::string &s) + { + std::vector results; + auto ranges = splitString(s, ","); + for (const auto &range : ranges) + { + auto startEnd = splitString(range, "-"); + CHECK((startEnd.size() <= 2), "Invalid numa_config setting. Expects zero or one '-'."); + if (startEnd.size() == 1) + { + results.push_back(std::stoi(startEnd[0])); + } + else + { + size_t start = std::stoi(startEnd[0]); + size_t last = std::stoi(startEnd[1]); + for (size_t i = start; i <= last; ++i) + { + results.push_back(i); + } + } + } + return results; + } + + // Example of the format: "0,2:0-63&1,3:64-127" for 4 GPUs, 128 CPU, 2 NUMA node system. + inline NumaConfig parseNumaConfig(const std::string &numa_file) + { + std::string numa_str; + std::ifstream file(numa_file.c_str()); + if (file.is_open()) + { + getline(file, numa_str); + file.close(); + } + + NumaConfig config; + if (!numa_str.empty()) + { + auto nodes = splitString(numa_str, "&"); + for (const auto &node : nodes) + { + auto pair = splitString(node, ":"); + CHECK((pair.size() == 2), "Invalid numa_config setting. Expects one ':'."); + auto gpus = parseRange(pair[0]); + auto cpus = parseRange(pair[1]); + config.emplace_back(std::make_pair(gpus, cpus)); + } + } + return config; + } + + // Convert NumaConfig to GpuToNumaMap for easier look-up. 
+ inline GpuToNumaMap getGpuToNumaMap(const NumaConfig &config) + { + std::vector map; + for (size_t numaIdx = 0; numaIdx < config.size(); numaIdx++) + { + for (const auto gpuIdx : config[numaIdx].first) + { + if (gpuIdx >= map.size()) + { + map.resize(gpuIdx + 1); + } + map[gpuIdx] = numaIdx; + } + } + return map; + } +} // namespace mlinfer + diff --git a/examples/transformers/mgx_llama2/harness/numpy.hpp b/examples/transformers/mgx_llama2/harness/numpy.hpp new file mode 100644 index 00000000000..8cc5a15db67 --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/numpy.hpp @@ -0,0 +1,111 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "common.hpp" +#include "logging.hpp" + +namespace mlinfer +{ + namespace npy + { + class NpyFile + { + private: + std::string m_Path; + std::ifstream m_FStream; + size_t m_HeaderSize; + std::string m_Header; + size_t m_TensorSize; + size_t m_ElementSize; + std::vector m_TensorDims; + + public: + explicit NpyFile(const std::string &path) + : m_Path(path), m_FStream(m_Path) + { + LOG_INFO("Npy file from " << path); + // magic and fixed header + char b[256]; + m_FStream.read(b, 10); + CHECK(m_FStream, "Unable to parse: " << m_Path); + + // check magic + CHECK(static_cast(b[0]) == 0x93 && b[1] == 'N' && b[2] == 'U' && b[3] == 'M' && b[4] == 'P' && b[5] == 'Y', "Bad magic: " << m_Path); + + // get header + auto major = static_cast(b[6]); + // auto minor = static_cast(b[7]); + CHECK(major == 1, "Only npy version 1 is supported: " << m_Path); + m_HeaderSize = static_cast(b[8]); + m_Header.resize(m_HeaderSize); + m_FStream.read(static_cast(m_Header.data()), m_HeaderSize); + + // get file size + auto cur = m_FStream.tellg(); + m_FStream.seekg(0, std::ios::end); + auto size = m_FStream.tellg(); + m_TensorSize = size - cur; + + // parse header + std::regex re(R"re(\{'descr': '[<|][fi]([\d])', 'fortran_order': False, 'shape': \(([\d, ]*)\), \} +\n)re"); + std::smatch matches; + CHECK(std::regex_match(m_Header, matches, re), "Cannot parse numpy header: " << m_Path); + CHECK(matches.size() == 3, "Cannot parse numpy header: " << m_Path); + m_ElementSize = std::stoi(matches[1]); + std::vector dims = splitString(matches[2], ", "); + m_TensorDims.resize(dims.size()); + std::transform( + dims.begin(), dims.end(), m_TensorDims.begin(), [](const std::string &s) + { return std::stoi(s); }); + + // check header sanity + size_t tensorSize = std::accumulate(m_TensorDims.begin(), m_TensorDims.end(), m_ElementSize, std::multiplies()); + CHECK(tensorSize == m_TensorSize, "Header description does not match file size: " << m_Path); + LOG_DEBUG(" Input num=" << m_TensorDims[0] << " | Sample size=" << (tensorSize / m_TensorDims[0]) << " | Full size=" << m_TensorSize); + } + ~NpyFile() + { + m_FStream.close(); + }; + std::string GetPath() const + { + return m_Path; + } + std::vector GetDims() const + { + return m_TensorDims; + } + size_t GetTensorSize() const + { + return m_TensorSize; + } + // load the entire tensor + void LoadAll(void *dst) + { + m_FStream.seekg(10 + m_HeaderSize, std::ios::beg); + m_FStream.read(static_cast(dst), m_TensorSize); + CHECK(m_FStream, "Unable to parse: " << m_Path); + CHECK(m_FStream.peek() == EOF, "Did not consume full file: " << m_Path); + } + + // load only selected indices from the Tensor, assuming that the first dim is batch dim. 
+ void LoadSamples(void *dst, const std::vector &indices) + { + size_t sampleSize = std::accumulate(m_TensorDims.begin() + 1, m_TensorDims.end(), m_ElementSize, std::multiplies()); + for (size_t i = 0; i < indices.size(); i++) + { + m_FStream.seekg(10 + m_HeaderSize + indices[i] * sampleSize, std::ios::beg); + m_FStream.read(static_cast(dst) + i * sampleSize, sampleSize); + } + } + }; + } // namespace npy +} // namespace mlinfer + diff --git a/examples/transformers/mgx_llama2/harness/timer.hpp b/examples/transformers/mgx_llama2/harness/timer.hpp new file mode 100644 index 00000000000..517cffb148e --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/timer.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// For debugging the timing of each part +class Timer +{ +public: + explicit Timer(const std::string &tag_, bool verbose_ = false) + : tag(tag_), verbose(verbose_) + { + std::cout << "Timer " << tag << " created." << std::endl; + } + void add(const std::chrono::duration &in) + { + std::thread::id id = std::this_thread::get_id(); + count[id] += 1; + total[id] += in; + if (verbose) + measurements[id].emplace_back(in); + } + ~Timer() + { + auto total_accum = std::accumulate( + std::begin(total), + std::end(total), + 0, + [](int64_t value, std::pair> p) + { return value + p.second.count(); }); + + auto count_accum = std::accumulate( + std::begin(count), + std::end(count), + 0, + [](size_t value, std::pair p) + { return value + p.second; }); + + std::cout << "Timer " << tag << " reports " << (double)total_accum / count_accum << " ms per call for " << count_accum + << " times." << std::endl; + if (verbose) + { + std::cout << " Measurements=["; + for (const auto &m : measurements) + { + std::cout << " Thread " << m.first << ": {"; + for (const auto &d : m.second) + { + std::cout << d.count() << ","; + } + + std::cout << "},"; + } + std::cout << "]" << std::endl; + } + } + +private: + std::string tag; + bool verbose; + std::unordered_map> total; + std::unordered_map>> measurements; + std::unordered_map count; +}; diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc new file mode 100644 index 00000000000..720c5669366 --- /dev/null +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -0,0 +1,116 @@ +#include "buffer.hpp" +#include "numpy.hpp" +#include + +#include +#include +#include +#include +#include + +using namespace mlinfer; + +const std::string MODEL_FILE_PATH = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/model-256.mxr"; +std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; +// sequence length from model config +const size_t SEQ_SIZE = 256; +const size_t VOCAB_SIZE = 32000; +// EOS token from model config +const size_t EOS = 2; + +static migraphx::program loadProgram() +{ + std::filesystem::path compiled_path(MODEL_FILE_PATH); + + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + + migraphx::program prog; + std::ifstream f(compiled_path.c_str()); + if (f.good()) + { + prog = migraphx::load(compiled_path.c_str(), file_options); + } + else + { + std::cout << "model is not good.\n"; + } + return prog; +} + +int main() { + std::cout << "Loading model ..." 
<< std::endl; + migraphx::program prog = loadProgram(); + std::cout << "Model loaded" << std::endl; + + prog.print(); + + auto output_tokens = SAMPLE_IDS; + SAMPLE_IDS.resize(SEQ_SIZE, 0); + std::vector attention_mask = SAMPLE_IDS; + std::transform(std::begin(attention_mask), std::end(attention_mask), std::begin(attention_mask), [](auto i){ + return (i != 0) ? 1 : 0; + }); + + std::vector position_ids; + for (int64_t i=0; i < SEQ_SIZE; ++i) + { + position_ids.emplace_back(i); + } + + migraphx::program_parameters prog_args; + auto param_shapes = prog.get_parameter_shapes(); + + size_t alloc_size = SEQ_SIZE * sizeof(int64_t); + + std::cout << "Uploading input ids to the GPU" << std::endl; + auto name = "input_ids"; + auto input_ids_buffer = ManagedBuffer(alloc_size); + input_ids_buffer.upload_to_device(static_cast(SAMPLE_IDS.data()), alloc_size); + prog_args.add(name, migraphx::argument(param_shapes[name], input_ids_buffer.get_device_ptr())); + + std::cout << "Uploading attention mask to the GPU" << std::endl; + name = "attention_mask"; + auto attention_mask_buffer = ManagedBuffer(alloc_size); + attention_mask_buffer.upload_to_device(static_cast(attention_mask.data()), alloc_size); + prog_args.add(name, migraphx::argument(param_shapes[name], attention_mask_buffer.get_device_ptr())); + + std::cout << "Uploading position ids to the GPU" << std::endl; + name = "position_ids"; + auto position_ids_buffer = ManagedBuffer(alloc_size); + position_ids_buffer.upload_to_device(static_cast(position_ids.data()), alloc_size); + prog_args.add(name, migraphx::argument(param_shapes[name], position_ids_buffer.get_device_ptr())); + + // Handle output tensors + std::cout << "Creating output buffer" << std::endl; + const size_t output_size = SEQ_SIZE * VOCAB_SIZE * sizeof(float); + name = "@return"; + auto output_buffer = ManagedBuffer(output_size); + migraphx::shape outShape{migraphx_shape_float_type, {1, 256, 32000}}; + prog_args.add(name, migraphx::argument(outShape, output_buffer.get_device_ptr())); + + std::cout << "Starting evaluation" << std::endl; + for (int i = 5; i < SEQ_SIZE; ++i) + { + prog.eval(prog_args); + // TODO: Only download the relevant data range + std::vector logits = output_buffer.download_from_device(output_size); + std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); + int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); + output_tokens.push_back(new_token); + if (new_token == EOS) + { + break; + } + input_ids_buffer.update_device_data(new_token, i + 1); + attention_mask_buffer.update_device_data(1, i + 1); + } + + std::cout << "######### Output token ids #########" << std::endl; + // print output tokens + for (auto tok: output_tokens){ + std::cout << tok << ", "; + } + std::cout << std::endl; + return 0; +} From 1edf198ce52d4f6822e938e8d6fea7afe6b45973 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Tue, 22 Oct 2024 07:21:30 -0500 Subject: [PATCH 02/55] Code works with offload copy --- examples/transformers/mgx_llama2/mgxllama2.cc | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 720c5669366..2021770e757 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -64,37 +64,40 @@ int main() { size_t alloc_size = SEQ_SIZE * sizeof(int64_t); std::cout << "Uploading input ids to the GPU" << 
std::endl; - auto name = "input_ids"; - auto input_ids_buffer = ManagedBuffer(alloc_size); - input_ids_buffer.upload_to_device(static_cast(SAMPLE_IDS.data()), alloc_size); - prog_args.add(name, migraphx::argument(param_shapes[name], input_ids_buffer.get_device_ptr())); + auto input_ids_str = "input_ids"; + // auto input_ids_buffer = ManagedBuffer(alloc_size); + // input_ids_buffer.upload_to_device(static_cast(SAMPLE_IDS.data()), alloc_size); + prog_args.add(input_ids_str, migraphx::argument(param_shapes[input_ids_str], SAMPLE_IDS.data())); std::cout << "Uploading attention mask to the GPU" << std::endl; - name = "attention_mask"; - auto attention_mask_buffer = ManagedBuffer(alloc_size); - attention_mask_buffer.upload_to_device(static_cast(attention_mask.data()), alloc_size); - prog_args.add(name, migraphx::argument(param_shapes[name], attention_mask_buffer.get_device_ptr())); + auto attention_mask_str = "attention_mask"; + // auto attention_mask_buffer = ManagedBuffer(alloc_size); + // attention_mask_buffer.upload_to_device(static_cast(attention_mask.data()), alloc_size); + prog_args.add(attention_mask_str, migraphx::argument(param_shapes[attention_mask_str], attention_mask.data())); std::cout << "Uploading position ids to the GPU" << std::endl; - name = "position_ids"; - auto position_ids_buffer = ManagedBuffer(alloc_size); - position_ids_buffer.upload_to_device(static_cast(position_ids.data()), alloc_size); - prog_args.add(name, migraphx::argument(param_shapes[name], position_ids_buffer.get_device_ptr())); + auto position_ids_str = "position_ids"; + // auto position_ids_buffer = ManagedBuffer(alloc_size); + // position_ids_buffer.upload_to_device(static_cast(position_ids.data()), alloc_size); + prog_args.add(position_ids_str, migraphx::argument(param_shapes[position_ids_str], position_ids.data())); // Handle output tensors - std::cout << "Creating output buffer" << std::endl; - const size_t output_size = SEQ_SIZE * VOCAB_SIZE * sizeof(float); - name = "@return"; - auto output_buffer = ManagedBuffer(output_size); - migraphx::shape outShape{migraphx_shape_float_type, {1, 256, 32000}}; - prog_args.add(name, migraphx::argument(outShape, output_buffer.get_device_ptr())); + // std::cout << "Creating output buffer" << std::endl; + const size_t output_size = SEQ_SIZE * VOCAB_SIZE; + // name = "@return"; + // auto output_buffer = ManagedBuffer(output_size); + // migraphx::shape outShape{migraphx_shape_float_type, {1, 256, 32000}}; + // prog_args.add(name, migraphx::argument(outShape, output_buffer.get_device_ptr())); std::cout << "Starting evaluation" << std::endl; for (int i = 5; i < SEQ_SIZE; ++i) { - prog.eval(prog_args); + std::cout << "# iter: " << i << std::endl; + auto outputs = prog.eval(prog_args); // TODO: Only download the relevant data range - std::vector logits = output_buffer.download_from_device(output_size); + float* results = reinterpret_cast(outputs[0].data()); + std::vector logits(results, results + output_size); + // std::cout << "## logits size: " << logits.size() << std::endl; std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); output_tokens.push_back(new_token); @@ -102,8 +105,12 @@ int main() { { break; } - input_ids_buffer.update_device_data(new_token, i + 1); - attention_mask_buffer.update_device_data(1, i + 1); + SAMPLE_IDS[i + 1] = new_token; + prog_args.add(input_ids_str, 
migraphx::argument(param_shapes[input_ids_str], SAMPLE_IDS.data())); + attention_mask[i + 1] = 1; + prog_args.add(attention_mask_str, migraphx::argument(param_shapes[attention_mask_str], attention_mask.data())); + // input_ids_buffer.update_device_data(new_token, i + 1); + // attention_mask_buffer.update_device_data(1, i + 1); } std::cout << "######### Output token ids #########" << std::endl; From 7cd6a0cb4eecc4513e5a7db57ed0c05727504498 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 24 Oct 2024 03:47:49 -0500 Subject: [PATCH 03/55] Add support to load onnx file --- examples/transformers/mgx_llama2/mgxllama2.cc | 44 +++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 2021770e757..f009e2941f2 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -10,17 +10,21 @@ using namespace mlinfer; -const std::string MODEL_FILE_PATH = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/model-256.mxr"; +const std::string MODEL_PATH = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/"; +const std::string MODEL_FILE_NAME = "model-256.mxr"; +const std::string ONNX_FILE_NAME = "model.onnx"; std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; // sequence length from model config const size_t SEQ_SIZE = 256; const size_t VOCAB_SIZE = 32000; // EOS token from model config const size_t EOS = 2; +// Onnx config +const bool USE_ONNX = true; static migraphx::program loadProgram() { - std::filesystem::path compiled_path(MODEL_FILE_PATH); + std::filesystem::path compiled_path(MODEL_PATH + MODEL_FILE_NAME); migraphx::file_options file_options; file_options.set_file_format("msgpack"); @@ -38,9 +42,43 @@ static migraphx::program loadProgram() return prog; } +static migraphx::program loadOnnx() +{ + std::filesystem::path onnx_path(MODEL_PATH + ONNX_FILE_NAME); + + migraphx::program prog; + std::ifstream f(onnx_path.c_str()); + if (f.good()) + { + migraphx::onnx_options onnx_opts; + std::vector dims = {1, SEQ_SIZE}; + onnx_opts.set_input_parameter_shape("input_ids", dims); + onnx_opts.set_input_parameter_shape("attention_mask", dims); + onnx_opts.set_input_parameter_shape("position_ids", dims); + std::cout << "Parsing onnx file ..." << std::endl; + prog = parse_onnx(onnx_path.c_str(), onnx_opts); + + std::string target_str = "gpu"; + migraphx::target targ = migraphx::target(target_str.c_str()); + + std::cout << "Quantize FP16 ..." << std::endl; + migraphx::quantize_fp16(prog); + + migraphx::compile_options comp_opts; + comp_opts.set_offload_copy(); + std::cout << "Compile to target..." << std::endl; + prog.compile(targ, comp_opts); + } + else + { + std::cout << "Onnx file is not available.\n"; + } + return prog; +} + int main() { std::cout << "Loading model ..." << std::endl; - migraphx::program prog = loadProgram(); + migraphx::program prog = USE_ONNX ? 
loadOnnx() : loadProgram(); std::cout << "Model loaded" << std::endl; prog.print(); From 595771f8b795a6e412593ba4cabc9abd77eda636 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 24 Oct 2024 04:07:35 -0500 Subject: [PATCH 04/55] Add dockerization for mgx_llama2 example --- examples/transformers/mgx_llama2/Dockerfile | 20 +++++++++++++++++++ .../transformers/mgx_llama2/build_docker.sh | 3 +++ examples/transformers/mgx_llama2/mgxllama2.cc | 2 +- .../transformers/mgx_llama2/run_docker.sh | 14 +++++++++++++ 4 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 examples/transformers/mgx_llama2/Dockerfile create mode 100755 examples/transformers/mgx_llama2/build_docker.sh create mode 100755 examples/transformers/mgx_llama2/run_docker.sh diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile new file mode 100644 index 00000000000..387ac4ac80a --- /dev/null +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -0,0 +1,20 @@ +FROM rocm/dev-ubuntu-22.04:6.2 + +ENV DEBIAN_FRONTEND=noninteractive + +SHELL ["/bin/bash", "-c"] + +RUN apt-get update && apt-get install -y --allow-unauthenticated \ + apt-utils \ + cmake \ + migraphx && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN mkdir /mgx_llama2 + +COPY . /mgx_llama2 + +RUN rm -rf /mgx_llama2/build && mkdir /mgx_llama2/build + +RUN cd /mgx_llama2/build && CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/build_docker.sh b/examples/transformers/mgx_llama2/build_docker.sh new file mode 100755 index 00000000000..c6eb64304ec --- /dev/null +++ b/examples/transformers/mgx_llama2/build_docker.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +docker build --platform linux/amd64 --tag mgx_llama2:v0.1 --file Dockerfile . diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index f009e2941f2..5e7ff8e2071 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -10,7 +10,7 @@ using namespace mlinfer; -const std::string MODEL_PATH = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/"; +const std::string MODEL_PATH = "/model/"; const std::string MODEL_FILE_NAME = "model-256.mxr"; const std::string ONNX_FILE_NAME = "model.onnx"; std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; diff --git a/examples/transformers/mgx_llama2/run_docker.sh b/examples/transformers/mgx_llama2/run_docker.sh new file mode 100755 index 00000000000..73d41181c26 --- /dev/null +++ b/examples/transformers/mgx_llama2/run_docker.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +if [[ -z "${MODEL_DIR_PATH}" ]]; then + echo "MODEL_DIR_PATH is not set, please provide the path to model before running docker." 
+ exit 1 +else + MODEL_DIR="${MODEL_DIR_PATH}" +fi + +docker run --device='/dev/kfd' --device='/dev/dri' --group-add video \ +-v $MODEL_DIR:/model \ +-w /mgx_llama2/build \ +-it mgx_llama2:v0.1 + From fdd16edd8ffa23913251b65a3a00d5bd6e9126c0 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Thu, 24 Oct 2024 06:09:07 -0500 Subject: [PATCH 05/55] Rework buffer allocation so offload_copy can be turned on/off --- .../mgx_llama2/harness/buffer.hpp | 66 +++--- examples/transformers/mgx_llama2/mgxllama2.cc | 193 ++++++++++-------- 2 files changed, 147 insertions(+), 112 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index 3cdb65f9176..17391091a70 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -106,54 +106,60 @@ namespace mlinfer using DeviceBuffer = GenericBuffer; using HostBuffer = GenericBuffer; - struct ManagedBuffer + template + struct ManagedBuffer_v2 { - explicit ManagedBuffer(size_t size_in_bytes_, size_t stride_in_bytes_ = 0) + explicit ManagedBuffer_v2(std::vector&& host_data, bool with_offload_copy=true): with_offload_copy(with_offload_copy) { - dbuff = DeviceBuffer(size_in_bytes_, stride_in_bytes_); - hbuff = HostBuffer(size_in_bytes_, stride_in_bytes_); - } - - template - T get_host_ptr() - { - return static_cast(hbuff.tensor_ptr); + size_in_bytes = host_data.size() * sizeof(T); + hbuff = std::move(host_data); + if (not with_offload_copy) + { + dbuff = DeviceBuffer(size_in_bytes, 0); + } } - template - T get_device_ptr() + void* data() { - return static_cast(dbuff.tensor_ptr); + return with_offload_copy ? static_cast(hbuff.data()) : dbuff.tensor_ptr; } - void upload_to_device(void* data, size_t size_in_bytes) + void upload_to_device() { - memcpy(get_host_ptr(), data, size_in_bytes); - check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); + assert(not with_offload_copy); + check_hip_status(hipMemcpyHtoD(dbuff.tensor_ptr, static_cast(hbuff.data()), size_in_bytes)); } - template - std::vector download_from_device(size_t size_in_bytes) + void download_from_device() { - check_hip_status(hipMemcpy(get_host_ptr(), get_device_ptr(), size_in_bytes, hipMemcpyKind::hipMemcpyDeviceToHost)); - return std::vector(get_host_ptr(), get_host_ptr() + (size_in_bytes / sizeof(T))); + assert(not with_offload_copy); + // TODO: use a separate stream for eval and upload download, so we don't have to sync here + check_hip_status(hipDeviceSynchronize()); + check_hip_status(hipMemcpyDtoH(static_cast(hbuff.data()), dbuff.tensor_ptr, size_in_bytes)); } - template - void update_device_data(T data, size_t position) + void update_data(T data, size_t position) { - T* host_data = get_host_ptr(); - host_data[position] = data; - // TODO: don't copy over the entire buffer just the changed range - check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); + hbuff.at(position) = data; + if (not with_offload_copy) + { + // TODO: don't copy over the entire buffer just the changed range + // check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); + upload_to_device(); + } } - ManagedBuffer() = delete; - ManagedBuffer(const ManagedBuffer &buf) = delete; - ManagedBuffer &operator=(const ManagedBuffer &buf) = delete; + ManagedBuffer_v2() = delete; + ManagedBuffer_v2(const 
ManagedBuffer_v2 &buf) = delete; + ManagedBuffer_v2 &operator=(const ManagedBuffer_v2 &buf) = delete; DeviceBuffer dbuff; - HostBuffer hbuff; + std::vector hbuff; + size_t size_in_bytes; + bool with_offload_copy; }; + + using LLama2InputBuffer = ManagedBuffer_v2; + using LLama2OutputBuffer = ManagedBuffer_v2; } diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 5e7ff8e2071..d0f4b4103bc 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -10,41 +10,23 @@ using namespace mlinfer; -const std::string MODEL_PATH = "/model/"; -const std::string MODEL_FILE_NAME = "model-256.mxr"; -const std::string ONNX_FILE_NAME = "model.onnx"; +// TODO: fix paths +const std::string MODEL_PATH = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/"; +const std::string MXR_FILE = "model-256.mxr"; +const std::string MXR_FILE_NO_OFFLOAD = "model-256_fp32_nooffload.mxr"; +const std::string ONNX_FILE = "model.onnx"; std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; // sequence length from model config const size_t SEQ_SIZE = 256; +// vocab size from model config const size_t VOCAB_SIZE = 32000; // EOS token from model config const size_t EOS = 2; -// Onnx config -const bool USE_ONNX = true; -static migraphx::program loadProgram() +// TODO: enable fp16 quant and fast math +static migraphx::program loadOnnx(std::string& model_path, bool offload_copy, bool quantize_fp16 = false) { - std::filesystem::path compiled_path(MODEL_PATH + MODEL_FILE_NAME); - - migraphx::file_options file_options; - file_options.set_file_format("msgpack"); - - migraphx::program prog; - std::ifstream f(compiled_path.c_str()); - if (f.good()) - { - prog = migraphx::load(compiled_path.c_str(), file_options); - } - else - { - std::cout << "model is not good.\n"; - } - return prog; -} - -static migraphx::program loadOnnx() -{ - std::filesystem::path onnx_path(MODEL_PATH + ONNX_FILE_NAME); + std::filesystem::path onnx_path(model_path); migraphx::program prog; std::ifstream f(onnx_path.c_str()); @@ -62,93 +44,140 @@ static migraphx::program loadOnnx() migraphx::target targ = migraphx::target(target_str.c_str()); std::cout << "Quantize FP16 ..." << std::endl; - migraphx::quantize_fp16(prog); + if (quantize_fp16) + migraphx::quantize_fp16(prog); migraphx::compile_options comp_opts; comp_opts.set_offload_copy(); std::cout << "Compile to target..." 
<< std::endl; prog.compile(targ, comp_opts); + // TODO: save model to mxr } else { - std::cout << "Onnx file is not available.\n"; + std::cerr << "Onnx file is not available on path: " << model_path << std::endl; + exit(1); } return prog; -} +}; + +static migraphx::program loadProgram(std::string& mxr_path, std::string& model_path, bool offload_copy) +{ + std::filesystem::path compiled_path(mxr_path); + + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + + migraphx::program prog; + std::ifstream f(compiled_path.c_str()); + if (f.good()) + { + std::cout << "Loadind model from MXR ...\n"; + prog = migraphx::load(compiled_path.c_str(), file_options); + } + else + { + std::cout << "MXR file can't be loaded try to load ONNX\n"; + prog = loadOnnx(model_path, offload_copy); + } + return prog; +}; + +struct LLama2Inputs +{ + LLama2Inputs(migraphx::program& prog, migraphx::program_parameters& prog_args, bool offload_copy=true): offload_copy(offload_copy) + { + auto input_ids = SAMPLE_IDS; + input_ids.resize(SEQ_SIZE, 0); + std::vector attention_mask = input_ids; + std::transform(std::begin(attention_mask), std::end(attention_mask), std::begin(attention_mask), [](auto i){ + return (i != 0) ? 1 : 0; + }); + + std::vector position_ids; + for (int64_t i=0; i < SEQ_SIZE; ++i) + { + position_ids.emplace_back(i); + } + + auto param_shapes = prog.get_parameter_shapes(); + auto input_ids_str = "input_ids"; + input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + prog_args.add(input_ids_str, migraphx::argument(param_shapes[input_ids_str], input_ids_buffer->data())); + + auto attention_mask_str = "attention_mask"; + attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); + prog_args.add(attention_mask_str, migraphx::argument(param_shapes[attention_mask_str], attention_mask_buffer->data())); + + auto position_ids_str = "position_ids"; + position_ids_buffer = std::make_unique(std::move(position_ids), offload_copy); + prog_args.add(position_ids_str, migraphx::argument(param_shapes[position_ids_str], position_ids_buffer->data())); + }; + + void upload_to_device() + { + assert(not offload_copy); + input_ids_buffer->upload_to_device(); + attention_mask_buffer->upload_to_device(); + position_ids_buffer->upload_to_device(); + } + + LLama2Inputs() = delete; + LLama2Inputs(const LLama2Inputs &buf) = delete; + LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; + + std::unique_ptr input_ids_buffer; + std::unique_ptr attention_mask_buffer; + std::unique_ptr position_ids_buffer; + bool offload_copy; +}; int main() { - std::cout << "Loading model ..." << std::endl; - migraphx::program prog = USE_ONNX ? loadOnnx() : loadProgram(); + bool offload_copy = false; + std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; + auto mxr_path = offload_copy ? MODEL_PATH + MXR_FILE : MODEL_PATH + MXR_FILE_NO_OFFLOAD; + auto onnx_path = MODEL_PATH + ONNX_FILE; + migraphx::program prog = loadProgram(mxr_path, onnx_path, offload_copy); std::cout << "Model loaded" << std::endl; - prog.print(); - + // Setup model inputs auto output_tokens = SAMPLE_IDS; - SAMPLE_IDS.resize(SEQ_SIZE, 0); - std::vector attention_mask = SAMPLE_IDS; - std::transform(std::begin(attention_mask), std::end(attention_mask), std::begin(attention_mask), [](auto i){ - return (i != 0) ? 
1 : 0; - }); - - std::vector position_ids; - for (int64_t i=0; i < SEQ_SIZE; ++i) + migraphx::program_parameters prog_args; + auto model_inputs = LLama2Inputs(prog, prog_args, offload_copy); + if (not offload_copy) { - position_ids.emplace_back(i); + model_inputs.upload_to_device(); } - migraphx::program_parameters prog_args; - auto param_shapes = prog.get_parameter_shapes(); - - size_t alloc_size = SEQ_SIZE * sizeof(int64_t); - - std::cout << "Uploading input ids to the GPU" << std::endl; - auto input_ids_str = "input_ids"; - // auto input_ids_buffer = ManagedBuffer(alloc_size); - // input_ids_buffer.upload_to_device(static_cast(SAMPLE_IDS.data()), alloc_size); - prog_args.add(input_ids_str, migraphx::argument(param_shapes[input_ids_str], SAMPLE_IDS.data())); - - std::cout << "Uploading attention mask to the GPU" << std::endl; - auto attention_mask_str = "attention_mask"; - // auto attention_mask_buffer = ManagedBuffer(alloc_size); - // attention_mask_buffer.upload_to_device(static_cast(attention_mask.data()), alloc_size); - prog_args.add(attention_mask_str, migraphx::argument(param_shapes[attention_mask_str], attention_mask.data())); - - std::cout << "Uploading position ids to the GPU" << std::endl; - auto position_ids_str = "position_ids"; - // auto position_ids_buffer = ManagedBuffer(alloc_size); - // position_ids_buffer.upload_to_device(static_cast(position_ids.data()), alloc_size); - prog_args.add(position_ids_str, migraphx::argument(param_shapes[position_ids_str], position_ids.data())); - - // Handle output tensors - // std::cout << "Creating output buffer" << std::endl; + // Setup model output for non-offload copy const size_t output_size = SEQ_SIZE * VOCAB_SIZE; - // name = "@return"; - // auto output_buffer = ManagedBuffer(output_size); - // migraphx::shape outShape{migraphx_shape_float_type, {1, 256, 32000}}; - // prog_args.add(name, migraphx::argument(outShape, output_buffer.get_device_ptr())); + auto output_name = "main:#output_0"; + auto output_buffer = LLama2OutputBuffer(std::vector(output_size), offload_copy); + migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); std::cout << "Starting evaluation" << std::endl; - for (int i = 5; i < SEQ_SIZE; ++i) + for (int i = SAMPLE_IDS.size() - 1; i < SEQ_SIZE - 1; ++i) { - std::cout << "# iter: " << i << std::endl; + // std::cout << "# iter: " << i << std::endl; auto outputs = prog.eval(prog_args); // TODO: Only download the relevant data range - float* results = reinterpret_cast(outputs[0].data()); + if (not offload_copy) + { + output_buffer.download_from_device(); + } + float* results = offload_copy ? 
reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); std::vector logits(results, results + output_size); - // std::cout << "## logits size: " << logits.size() << std::endl; std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); + // std::cout << "New token: " << new_token << std::endl; output_tokens.push_back(new_token); if (new_token == EOS) { break; } - SAMPLE_IDS[i + 1] = new_token; - prog_args.add(input_ids_str, migraphx::argument(param_shapes[input_ids_str], SAMPLE_IDS.data())); - attention_mask[i + 1] = 1; - prog_args.add(attention_mask_str, migraphx::argument(param_shapes[attention_mask_str], attention_mask.data())); - // input_ids_buffer.update_device_data(new_token, i + 1); - // attention_mask_buffer.update_device_data(1, i + 1); + model_inputs.input_ids_buffer->update_data(new_token, i +1); + model_inputs.attention_mask_buffer->update_data(1, i +1); } std::cout << "######### Output token ids #########" << std::endl; From 1b78a1c1f5c33d38c7cdd555888f6dcce89a167a Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Thu, 24 Oct 2024 11:55:26 -0500 Subject: [PATCH 06/55] Use dedicated hipStream for synchronization --- .../mgx_llama2/harness/buffer.hpp | 14 ++++---- examples/transformers/mgx_llama2/mgxllama2.cc | 32 ++++++++++++------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index 17391091a70..f29ce054585 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -125,28 +125,26 @@ namespace mlinfer return with_offload_copy ? 
static_cast(hbuff.data()) : dbuff.tensor_ptr; } - void upload_to_device() + void upload_to_device(hipStream_t stream) { assert(not with_offload_copy); - check_hip_status(hipMemcpyHtoD(dbuff.tensor_ptr, static_cast(hbuff.data()), size_in_bytes)); + check_hip_status(hipMemcpyHtoDAsync(dbuff.tensor_ptr, static_cast(hbuff.data()), size_in_bytes, stream)); } - void download_from_device() + void download_from_device(hipStream_t stream) { assert(not with_offload_copy); - // TODO: use a separate stream for eval and upload download, so we don't have to sync here - check_hip_status(hipDeviceSynchronize()); - check_hip_status(hipMemcpyDtoH(static_cast(hbuff.data()), dbuff.tensor_ptr, size_in_bytes)); + check_hip_status(hipMemcpyDtoHAsync(static_cast(hbuff.data()), dbuff.tensor_ptr, size_in_bytes, stream)); } - void update_data(T data, size_t position) + void update_data(T data, size_t position, hipStream_t stream) { hbuff.at(position) = data; if (not with_offload_copy) { // TODO: don't copy over the entire buffer just the changed range // check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); - upload_to_device(); + upload_to_device(stream); } } diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index d0f4b4103bc..ff85a447ef0 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -72,7 +72,7 @@ static migraphx::program loadProgram(std::string& mxr_path, std::string& model_p std::ifstream f(compiled_path.c_str()); if (f.good()) { - std::cout << "Loadind model from MXR ...\n"; + std::cout << "Loading model from MXR ...\n"; prog = migraphx::load(compiled_path.c_str(), file_options); } else @@ -85,7 +85,11 @@ static migraphx::program loadProgram(std::string& mxr_path, std::string& model_p struct LLama2Inputs { - LLama2Inputs(migraphx::program& prog, migraphx::program_parameters& prog_args, bool offload_copy=true): offload_copy(offload_copy) + LLama2Inputs( + migraphx::program& prog, + migraphx::program_parameters& prog_args, + bool offload_copy) + : offload_copy(offload_copy) { auto input_ids = SAMPLE_IDS; input_ids.resize(SEQ_SIZE, 0); @@ -114,12 +118,12 @@ struct LLama2Inputs prog_args.add(position_ids_str, migraphx::argument(param_shapes[position_ids_str], position_ids_buffer->data())); }; - void upload_to_device() + void upload_to_device(hipStream_t stream) { assert(not offload_copy); - input_ids_buffer->upload_to_device(); - attention_mask_buffer->upload_to_device(); - position_ids_buffer->upload_to_device(); + input_ids_buffer->upload_to_device(stream); + attention_mask_buffer->upload_to_device(stream); + position_ids_buffer->upload_to_device(stream); } LLama2Inputs() = delete; @@ -143,10 +147,12 @@ int main() { // Setup model inputs auto output_tokens = SAMPLE_IDS; migraphx::program_parameters prog_args; + hipStream_t stream; + check_hip_status(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); auto model_inputs = LLama2Inputs(prog, prog_args, offload_copy); if (not offload_copy) { - model_inputs.upload_to_device(); + model_inputs.upload_to_device(stream); } // Setup model output for non-offload copy @@ -157,15 +163,17 @@ int main() { prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); std::cout << "Starting evaluation" << std::endl; - for (int i = SAMPLE_IDS.size() - 1; i < SEQ_SIZE - 1; ++i) + for (int i = SAMPLE_IDS.size() - 1; i < SEQ_SIZE; ++i) { // std::cout << "# iter: " << i 
<< std::endl; - auto outputs = prog.eval(prog_args); + auto outputs = prog.run_async(prog_args, stream); // TODO: Only download the relevant data range if (not offload_copy) { - output_buffer.download_from_device(); + output_buffer.download_from_device(stream); } + + check_hip_status(hipStreamSynchronize(stream)); float* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); std::vector logits(results, results + output_size); std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); @@ -176,8 +184,8 @@ int main() { { break; } - model_inputs.input_ids_buffer->update_data(new_token, i +1); - model_inputs.attention_mask_buffer->update_data(1, i +1); + model_inputs.input_ids_buffer->update_data(new_token, i +1, stream); + model_inputs.attention_mask_buffer->update_data(1, i +1, stream); } std::cout << "######### Output token ids #########" << std::endl; From 0731e262f00ef94b1955cd8715312cf5959b03ef Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Fri, 25 Oct 2024 06:32:34 -0500 Subject: [PATCH 07/55] Save onnx model to mxr file --- examples/transformers/mgx_llama2/mgxllama2.cc | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index ff85a447ef0..1993a944546 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -12,7 +12,7 @@ using namespace mlinfer; // TODO: fix paths const std::string MODEL_PATH = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/"; -const std::string MXR_FILE = "model-256.mxr"; +const std::string MXR_FILE = "model-256_fp32_offload.mxr"; const std::string MXR_FILE_NO_OFFLOAD = "model-256_fp32_nooffload.mxr"; const std::string ONNX_FILE = "model.onnx"; std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; @@ -23,6 +23,22 @@ const size_t VOCAB_SIZE = 32000; // EOS token from model config const size_t EOS = 2; +static std::string getModelPath(bool offload_copy, bool quantize_fp16) +{ + std::string path{MODEL_PATH + "model-" + std::to_string(SEQ_SIZE)}; + + path += "_fp"; + path += quantize_fp16 ? "16" : "32"; + + path += "_"; + if (!offload_copy) + { + path += "no"; + } + path += "offload"; + return path; +} + // TODO: enable fp16 quant and fast math static migraphx::program loadOnnx(std::string& model_path, bool offload_copy, bool quantize_fp16 = false) { @@ -51,7 +67,12 @@ static migraphx::program loadOnnx(std::string& model_path, bool offload_copy, bo comp_opts.set_offload_copy(); std::cout << "Compile to target..." 
<< std::endl; prog.compile(targ, comp_opts); - // TODO: save model to mxr + + std::string modelPath = getModelPath(offload_copy, quantize_fp16) + ".mxr"; + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + std::cout << "Saving mxr file to: " << modelPath << "\n"; + migraphx::save(prog, modelPath.c_str(), file_options); } else { From 2ab665907673ca9568563f1614285ad94cf8a533 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 25 Oct 2024 06:12:54 -0500 Subject: [PATCH 08/55] Only copy changed data --- .../mgx_llama2/harness/buffer.hpp | 34 ++++++++++++++++--- examples/transformers/mgx_llama2/mgxllama2.cc | 4 +-- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index f29ce054585..c5d3694133e 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -125,16 +125,40 @@ namespace mlinfer return with_offload_copy ? static_cast(hbuff.data()) : dbuff.tensor_ptr; } - void upload_to_device(hipStream_t stream) + void upload_to_device(hipStream_t stream, size_t start_idx=0, size_t end_idx=0) { assert(not with_offload_copy); - check_hip_status(hipMemcpyHtoDAsync(dbuff.tensor_ptr, static_cast(hbuff.data()), size_in_bytes, stream)); + char* src_addr = reinterpret_cast(hbuff.data()); + char* dst_addr = static_cast(dbuff.tensor_ptr); + size_t copy_size_in_bytes = size_in_bytes; + + size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); + if (range_size_in_bytes > 0) + { + size_t offset = start_idx * sizeof(T); + src_addr += offset; + dst_addr += offset; + copy_size_in_bytes = range_size_in_bytes; + } + check_hip_status(hipMemcpyHtoDAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); } - void download_from_device(hipStream_t stream) + void download_from_device(hipStream_t stream, size_t start_idx=0, size_t end_idx=0) { assert(not with_offload_copy); - check_hip_status(hipMemcpyDtoHAsync(static_cast(hbuff.data()), dbuff.tensor_ptr, size_in_bytes, stream)); + char* src_addr = static_cast(dbuff.tensor_ptr); + char* dst_addr = reinterpret_cast(hbuff.data()); + size_t copy_size_in_bytes = size_in_bytes; + + size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); + if (range_size_in_bytes > 0) + { + size_t offset = start_idx * sizeof(T); + src_addr += offset; + dst_addr += offset; + copy_size_in_bytes = range_size_in_bytes; + } + check_hip_status(hipMemcpyDtoHAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); } void update_data(T data, size_t position, hipStream_t stream) @@ -144,7 +168,7 @@ namespace mlinfer { // TODO: don't copy over the entire buffer just the changed range // check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); - upload_to_device(stream); + upload_to_device(stream, position, position + 1); } } diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 1993a944546..e3bd63d7187 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -186,12 +186,10 @@ int main() { std::cout << "Starting evaluation" << std::endl; for (int i = SAMPLE_IDS.size() - 1; i < SEQ_SIZE; ++i) { - // std::cout << "# iter: " << i << std::endl; auto outputs = prog.run_async(prog_args, stream); - // TODO: Only download the relevant data range if (not offload_copy) { - 
output_buffer.download_from_device(stream); + output_buffer.download_from_device(stream, i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE); } check_hip_status(hipStreamSynchronize(stream)); From e7f84c2b94ca186dbe33948c1758cc62879a79b5 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 25 Oct 2024 08:20:20 -0500 Subject: [PATCH 09/55] Extend model loading options with fast_math --- examples/transformers/mgx_llama2/mgxllama2.cc | 71 +++++++++++-------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index e3bd63d7187..0d218a63b05 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -11,9 +11,7 @@ using namespace mlinfer; // TODO: fix paths -const std::string MODEL_PATH = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/"; -const std::string MXR_FILE = "model-256_fp32_offload.mxr"; -const std::string MXR_FILE_NO_OFFLOAD = "model-256_fp32_nooffload.mxr"; +const std::string MODEL_FOLDER = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/"; const std::string ONNX_FILE = "model.onnx"; std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; // sequence length from model config @@ -23,26 +21,34 @@ const size_t VOCAB_SIZE = 32000; // EOS token from model config const size_t EOS = 2; -static std::string getModelPath(bool offload_copy, bool quantize_fp16) +struct ModelLoadSettings { - std::string path{MODEL_PATH + "model-" + std::to_string(SEQ_SIZE)}; - - path += "_fp"; - path += quantize_fp16 ? "16" : "32"; + size_t sequnce_length; + bool quantize_fp16; + bool offload_copy; + bool fast_math; +}; - path += "_"; - if (!offload_copy) +static std::string getModelPath(ModelLoadSettings& s) +{ + std::stringstream path; + path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" << (s.quantize_fp16 ? "16" : "32") << "_"; + if (!s.offload_copy) + { + path << "no"; + } + path << "offload_"; + if (!s.fast_math) { - path += "no"; + path << "no"; } - path += "offload"; - return path; + path << "fastmath.mxr"; + return path.str(); } -// TODO: enable fp16 quant and fast math -static migraphx::program loadOnnx(std::string& model_path, bool offload_copy, bool quantize_fp16 = false) +static migraphx::program loadOnnx(ModelLoadSettings& settings) { - std::filesystem::path onnx_path(model_path); + std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); migraphx::program prog; std::ifstream f(onnx_path.c_str()); @@ -60,15 +66,23 @@ static migraphx::program loadOnnx(std::string& model_path, bool offload_copy, bo migraphx::target targ = migraphx::target(target_str.c_str()); std::cout << "Quantize FP16 ..." << std::endl; - if (quantize_fp16) + if (settings.quantize_fp16) migraphx::quantize_fp16(prog); migraphx::compile_options comp_opts; - comp_opts.set_offload_copy(); - std::cout << "Compile to target..." << std::endl; + + if (settings.offload_copy) + comp_opts.set_offload_copy(); + + if (settings.fast_math) + comp_opts.set_fast_math(); + + comp_opts.set_exhaustive_tune_flag(); + + std::cout << "Compile to target ..." 
<< std::endl; prog.compile(targ, comp_opts); - std::string modelPath = getModelPath(offload_copy, quantize_fp16) + ".mxr"; + std::string modelPath = getModelPath(settings); migraphx::file_options file_options; file_options.set_file_format("msgpack"); std::cout << "Saving mxr file to: " << modelPath << "\n"; @@ -76,15 +90,15 @@ static migraphx::program loadOnnx(std::string& model_path, bool offload_copy, bo } else { - std::cerr << "Onnx file is not available on path: " << model_path << std::endl; + std::cerr << "Onnx file is not available on path: " << onnx_path << std::endl; exit(1); } return prog; }; -static migraphx::program loadProgram(std::string& mxr_path, std::string& model_path, bool offload_copy) +static migraphx::program loadProgram(ModelLoadSettings& settings) { - std::filesystem::path compiled_path(mxr_path); + std::filesystem::path compiled_path(getModelPath(settings)); migraphx::file_options file_options; file_options.set_file_format("msgpack"); @@ -93,13 +107,13 @@ static migraphx::program loadProgram(std::string& mxr_path, std::string& model_p std::ifstream f(compiled_path.c_str()); if (f.good()) { - std::cout << "Loading model from MXR ...\n"; + std::cout << "Loading model from " << compiled_path << " ...\n"; prog = migraphx::load(compiled_path.c_str(), file_options); } else { std::cout << "MXR file can't be loaded try to load ONNX\n"; - prog = loadOnnx(model_path, offload_copy); + prog = loadOnnx(settings); } return prog; }; @@ -112,7 +126,7 @@ struct LLama2Inputs bool offload_copy) : offload_copy(offload_copy) { - auto input_ids = SAMPLE_IDS; + auto input_ids = SAMPLE_IDS; input_ids.resize(SEQ_SIZE, 0); std::vector attention_mask = input_ids; std::transform(std::begin(attention_mask), std::end(attention_mask), std::begin(attention_mask), [](auto i){ @@ -160,9 +174,8 @@ struct LLama2Inputs int main() { bool offload_copy = false; std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - auto mxr_path = offload_copy ? MODEL_PATH + MXR_FILE : MODEL_PATH + MXR_FILE_NO_OFFLOAD; - auto onnx_path = MODEL_PATH + ONNX_FILE; - migraphx::program prog = loadProgram(mxr_path, onnx_path, offload_copy); + ModelLoadSettings settings = {SEQ_SIZE, true /*quantize_fp16*/, false /*offload_copy*/, true /*fast_math*/}; + migraphx::program prog = loadProgram(settings); std::cout << "Model loaded" << std::endl; // Setup model inputs From 0fe79245a75966b7c325032a9f9b085377205d8d Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 25 Oct 2024 08:35:52 -0500 Subject: [PATCH 10/55] Fix quant message --- examples/transformers/mgx_llama2/mgxllama2.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 0d218a63b05..d334b2df969 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -65,9 +65,11 @@ static migraphx::program loadOnnx(ModelLoadSettings& settings) std::string target_str = "gpu"; migraphx::target targ = migraphx::target(target_str.c_str()); - std::cout << "Quantize FP16 ..." << std::endl; if (settings.quantize_fp16) + { + std::cout << "Quantize FP16 ..." 
<< std::endl; migraphx::quantize_fp16(prog); + } migraphx::compile_options comp_opts; From 984e0dcbbd55a5de844841715f4ba22d65b6be1f Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Mon, 28 Oct 2024 03:12:17 -0500 Subject: [PATCH 11/55] Basic tokens/sec counting --- examples/transformers/mgx_llama2/mgxllama2.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index d334b2df969..3f6af07d02f 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -199,7 +199,9 @@ int main() { prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); std::cout << "Starting evaluation" << std::endl; - for (int i = SAMPLE_IDS.size() - 1; i < SEQ_SIZE; ++i) + size_t token_count = 0; + auto start = std::chrono::steady_clock::now(); + for (int i = SAMPLE_IDS.size() - 1; i < SEQ_SIZE - 1; ++i) { auto outputs = prog.run_async(prog_args, stream); if (not offload_copy) @@ -212,6 +214,7 @@ int main() { std::vector logits(results, results + output_size); std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); + token_count++; // std::cout << "New token: " << new_token << std::endl; output_tokens.push_back(new_token); if (new_token == EOS) @@ -221,6 +224,10 @@ int main() { model_inputs.input_ids_buffer->update_data(new_token, i +1, stream); model_inputs.attention_mask_buffer->update_data(1, i +1, stream); } + float dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() / 1000.f; + std::cout << "Duration: " << dur << " seconds." << std::endl; + std::cout << "Completed " << token_count << " tokens." 
<< std::endl; + std::cout << "Tokens/sec: " << token_count / dur << std::endl; std::cout << "######### Output token ids #########" << std::endl; // print output tokens From 4134505dd10069af5b738bf21768ff3161b265d0 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Mon, 28 Oct 2024 04:38:48 -0500 Subject: [PATCH 12/55] Add preprocess dataset script --- .../mgx_llama2/preprocess_dataset.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 examples/transformers/mgx_llama2/preprocess_dataset.py diff --git a/examples/transformers/mgx_llama2/preprocess_dataset.py b/examples/transformers/mgx_llama2/preprocess_dataset.py new file mode 100644 index 00000000000..56411ccae23 --- /dev/null +++ b/examples/transformers/mgx_llama2/preprocess_dataset.py @@ -0,0 +1,36 @@ +import numpy as np +import pickle +from pathlib import Path +import os + +G_MAX_TOK_LEN = 1024 +G_LLAMA2_EOS = 2 + +DATASET_PATH = "/dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl" +OUTPUT_PATH = "/dataset/" + +_p = Path(DATASET_PATH) +if _p.exists(): + with _p.open(mode="rb") as f: + d = pickle.load(f) + +toks = d['tok_input'].to_list() +#toks = [toks[0]] + +toks_np = np.ones((len(toks), G_MAX_TOK_LEN), dtype=np.int64) * G_LLAMA2_EOS +mask_np = np.zeros((len(toks), G_MAX_TOK_LEN), dtype=np.int64) +position_nps = [np.arange(0, G_MAX_TOK_LEN, dtype=np.int64) for _ in range(len(toks))] + + +for i, q in enumerate(toks): + toks_np[i, :len(q)] = q + mask_np[i, :len(q)] = np.ones_like(q) + + +token_size = len(toks) + +np.save(f"{OUTPUT_PATH}input_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", toks_np) +np.save(f"{OUTPUT_PATH}attention_mask_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", mask_np) +np.save(f"{OUTPUT_PATH}position_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", position_nps) + +print("Npy filed are created") From 6735534e350d0f9d8460ff7cc9c11f34f31f9480 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Mon, 28 Oct 2024 07:39:06 -0500 Subject: [PATCH 13/55] Use dataset from numpy files if available --- examples/transformers/mgx_llama2/mgxllama2.cc | 132 +++++++++++++++--- .../transformers/mgx_llama2/run_docker.sh | 8 ++ 2 files changed, 123 insertions(+), 17 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 3f6af07d02f..e1b5e4a6030 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -11,8 +11,9 @@ using namespace mlinfer; // TODO: fix paths -const std::string MODEL_FOLDER = "/code/AMDMIGraphX/examples/transformers/python_llama2/models/llama-2-7b-chat-hf/"; +const std::string MODEL_FOLDER = "/model/"; const std::string ONNX_FILE = "model.onnx"; +const std::string DATASET_FOLDER = "/dataset/"; std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; // sequence length from model config const size_t SEQ_SIZE = 256; @@ -120,6 +121,108 @@ static migraphx::program loadProgram(ModelLoadSettings& settings) return prog; }; +using NumpyVector = std::vector>; + +struct Dataset +{ + Dataset() = default; + + void initialize() + { + std::string input_file_path = DATASET_FOLDER + "input_ids_size_1_seq_256.npy"; + std::string attention_mask_file_path = DATASET_FOLDER + "attention_mask_size_1_seq_256.npy"; + std::string position_ids_file_path = DATASET_FOLDER + "position_ids_size_1_seq_256.npy"; + std::ifstream input_file(input_file_path.c_str()); + std::ifstream attention_mask_file(attention_mask_file_path.c_str()); + std::ifstream position_ids_file(position_ids_file_path.c_str()); + 
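Stepping back to the throughput counters added in patch 11: the tokens/sec figure is a plain `steady_clock` measurement around the generation loop, with the millisecond count converted to seconds before dividing. A minimal standalone sketch of the same accounting, with the per-token work faked by a short sleep so it runs anywhere:

```cpp
#include <chrono>
#include <iostream>
#include <thread>

int main()
{
    size_t token_count = 0;
    auto start = std::chrono::steady_clock::now();

    // Stand-in for the generation loop: each iteration "produces" one token.
    for (int i = 0; i < 20; ++i)
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(5));
        ++token_count;
    }

    float dur = std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::steady_clock::now() - start).count() / 1000.f;
    std::cout << "Duration: " << dur << " seconds." << std::endl;
    std::cout << "Completed " << token_count << " tokens." << std::endl;
    std::cout << "Tokens/sec: " << token_count / dur << std::endl;
    return 0;
}
```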
if (input_file.good() && attention_mask_file.good() && position_ids_file.good()) + { + npy::NpyFile input_ids_npy{input_file_path}; + npy::NpyFile attention_mask_npy{attention_mask_file_path}; + npy::NpyFile position_ids_npy{position_ids_file_path}; + input_ids = loadNumpy(input_ids_npy); + attention_mask = loadNumpy(attention_mask_npy); + position_ids = loadNumpy(position_ids_npy); + + if (input_ids.size() == attention_mask.size() == position_ids.size()) + { + std::cout << "Loaded numpy files\n"; + npy_files_loaded = true; + } + else + { + std::cout << "Numpy files do not have the same size\n"; + input_ids.clear(); + attention_mask.clear(); + position_ids.clear(); + } + } + + if (!npy_files_loaded) + { + std::cout << "Numpy files are not loaded, using dummy data\n"; + auto input_ids_sample = SAMPLE_IDS; + input_ids_sample.resize(SEQ_SIZE, EOS); + input_ids.emplace_back(input_ids_sample); + std::vector attention_mask_sample = input_ids_sample; + std::transform(std::begin(attention_mask_sample), std::end(attention_mask_sample), std::begin(attention_mask_sample), [](auto i){ + return (i != EOS) ? 1 : 0; + }); + attention_mask.emplace_back(attention_mask_sample); + + std::vector position_ids_sample; + for (int64_t i=0; i < SEQ_SIZE; ++i) + { + position_ids_sample.emplace_back(i); + } + position_ids.emplace_back(std::move(position_ids_sample)); + } + + } + + NumpyVector loadNumpy(npy::NpyFile& file) + { + NumpyVector numpyData; + auto load_size = file.GetTensorSize()/sizeof(int64_t); + numpyData.push_back(std::vector(load_size)); + file.LoadAll(numpyData.back().data()); + + #ifdef TRACE + for (auto& vec: numpyData) + { + for (auto val: vec) + { + std::cout << val << " "; + } + std::cout << "\n"; + } + #endif + return numpyData; + } + + size_t getLastIdx() const + { + auto res = std::find_if(std::rbegin(attention_mask[current_idx]), std::rend(attention_mask[current_idx]), [](uint64_t val) { return 1 == val;}); + size_t last_idx = std::distance(res, std::rend(attention_mask[current_idx])); + //std::cout << "Last input idx: " << last_idx << std::endl; + return last_idx; + } + + std::vector getInputIds() { return input_ids[current_idx]; } + std::vector getAttentionMask() { return attention_mask[current_idx]; } + std::vector getPositionIds() { return position_ids[current_idx]; } + + Dataset(const Dataset &buf) = delete; + Dataset &operator=(const Dataset &buf) = delete; + + NumpyVector input_ids; + NumpyVector attention_mask; + NumpyVector position_ids; + + size_t current_idx = 0; + bool npy_files_loaded = false; +}; + struct LLama2Inputs { LLama2Inputs( @@ -128,28 +231,20 @@ struct LLama2Inputs bool offload_copy) : offload_copy(offload_copy) { - auto input_ids = SAMPLE_IDS; - input_ids.resize(SEQ_SIZE, 0); - std::vector attention_mask = input_ids; - std::transform(std::begin(attention_mask), std::end(attention_mask), std::begin(attention_mask), [](auto i){ - return (i != 0) ? 
1 : 0; - }); - - std::vector position_ids; - for (int64_t i=0; i < SEQ_SIZE; ++i) - { - position_ids.emplace_back(i); - } + data.initialize(); + auto input_ids = data.getInputIds(); auto param_shapes = prog.get_parameter_shapes(); auto input_ids_str = "input_ids"; input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); prog_args.add(input_ids_str, migraphx::argument(param_shapes[input_ids_str], input_ids_buffer->data())); + auto attention_mask = data.getAttentionMask(); auto attention_mask_str = "attention_mask"; attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); prog_args.add(attention_mask_str, migraphx::argument(param_shapes[attention_mask_str], attention_mask_buffer->data())); + auto position_ids = data.getPositionIds(); auto position_ids_str = "position_ids"; position_ids_buffer = std::make_unique(std::move(position_ids), offload_copy); prog_args.add(position_ids_str, migraphx::argument(param_shapes[position_ids_str], position_ids_buffer->data())); @@ -163,6 +258,8 @@ struct LLama2Inputs position_ids_buffer->upload_to_device(stream); } + size_t getLastInputIndex() { return data.getLastIdx(); } + LLama2Inputs() = delete; LLama2Inputs(const LLama2Inputs &buf) = delete; LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; @@ -170,18 +267,19 @@ struct LLama2Inputs std::unique_ptr input_ids_buffer; std::unique_ptr attention_mask_buffer; std::unique_ptr position_ids_buffer; + Dataset data; bool offload_copy; }; int main() { - bool offload_copy = false; + bool offload_copy = true; std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - ModelLoadSettings settings = {SEQ_SIZE, true /*quantize_fp16*/, false /*offload_copy*/, true /*fast_math*/}; + ModelLoadSettings settings = {SEQ_SIZE, true /*quantize_fp16*/, offload_copy /*offload_copy*/, true /*fast_math*/}; migraphx::program prog = loadProgram(settings); std::cout << "Model loaded" << std::endl; // Setup model inputs - auto output_tokens = SAMPLE_IDS; + std::vector output_tokens; migraphx::program_parameters prog_args; hipStream_t stream; check_hip_status(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -201,7 +299,7 @@ int main() { std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; auto start = std::chrono::steady_clock::now(); - for (int i = SAMPLE_IDS.size() - 1; i < SEQ_SIZE - 1; ++i) + for (int i = model_inputs.getLastInputIndex(); i < SEQ_SIZE - 1; ++i) { auto outputs = prog.run_async(prog_args, stream); if (not offload_copy) diff --git a/examples/transformers/mgx_llama2/run_docker.sh b/examples/transformers/mgx_llama2/run_docker.sh index 73d41181c26..a0c7e2da7ed 100755 --- a/examples/transformers/mgx_llama2/run_docker.sh +++ b/examples/transformers/mgx_llama2/run_docker.sh @@ -7,8 +7,16 @@ else MODEL_DIR="${MODEL_DIR_PATH}" fi +if [[ -z "${DATA_DIR_PATH}" ]]; then + echo "DATA_DIR_PATH is not set, please provide the path to dataset before running docker." 
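An aside on the input layout these hunks rely on: each sample is a `SEQ_SIZE`-long `int64_t` row where unused slots are padding, `attention_mask` holds 1 for real tokens and 0 over the padding, and `position_ids` is simply 0..SEQ_SIZE-1. A small self-contained sketch of preparing one prompt that way, using the same transform/iota pattern as the dummy-data path above; the token ids here are made up for illustration:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

int main()
{
    const size_t SEQ_SIZE = 256;

    // Hypothetical prompt token ids; the real ones come from the tokenizer/dataset.
    std::vector<int64_t> input_ids = {1, 6804, 5207, 387, 287, 29973};
    input_ids.resize(SEQ_SIZE, 0); // pad the rest of the row

    // attention_mask: 1 where there is a real token, 0 over the padding.
    std::vector<int64_t> attention_mask(SEQ_SIZE);
    std::transform(input_ids.begin(), input_ids.end(), attention_mask.begin(),
                   [](int64_t id) { return id != 0 ? 1 : 0; });

    // position_ids: 0, 1, 2, ..., SEQ_SIZE-1.
    std::vector<int64_t> position_ids(SEQ_SIZE);
    std::iota(position_ids.begin(), position_ids.end(), 0);

    // Number of real prompt tokens = index of the last 1 in the mask + 1.
    auto it = std::find_if(attention_mask.rbegin(), attention_mask.rend(),
                           [](int64_t v) { return v == 1; });
    std::cout << "prompt length: " << std::distance(it, attention_mask.rend()) << "\n";
    return 0;
}
```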
+ exit 1 +else + DATA_DIR="${DATA_DIR_PATH}" +fi + docker run --device='/dev/kfd' --device='/dev/dri' --group-add video \ -v $MODEL_DIR:/model \ +-v $DATA_DIR:/dataset \ -w /mgx_llama2/build \ -it mgx_llama2:v0.1 From f1960f7e6b85535b83532b10fcbeff345b84861b Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Tue, 29 Oct 2024 07:06:53 -0500 Subject: [PATCH 14/55] Support npy dataset with multiple samples --- .../mgx_llama2/harness/buffer.hpp | 5 + examples/transformers/mgx_llama2/mgxllama2.cc | 277 ++++++++++++------ 2 files changed, 191 insertions(+), 91 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index c5d3694133e..fc7fa064b6e 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -125,6 +125,11 @@ namespace mlinfer return with_offload_copy ? static_cast(hbuff.data()) : dbuff.tensor_ptr; } + void update(std::vector&& host_data) + { + hbuff = std::move(host_data); + } + void upload_to_device(hipStream_t stream, size_t start_idx=0, size_t end_idx=0) { assert(not with_offload_copy); diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index e1b5e4a6030..47048c7991a 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -14,7 +14,7 @@ using namespace mlinfer; const std::string MODEL_FOLDER = "/model/"; const std::string ONNX_FILE = "model.onnx"; const std::string DATASET_FOLDER = "/dataset/"; -std::vector SAMPLE_IDS = {1,6804,5207,387,287,29973}; +const size_t DATASET_SIZE = 3; // sequence length from model config const size_t SEQ_SIZE = 256; // vocab size from model config @@ -129,9 +129,88 @@ struct Dataset void initialize() { - std::string input_file_path = DATASET_FOLDER + "input_ids_size_1_seq_256.npy"; - std::string attention_mask_file_path = DATASET_FOLDER + "attention_mask_size_1_seq_256.npy"; - std::string position_ids_file_path = DATASET_FOLDER + "position_ids_size_1_seq_256.npy"; + loadDataset(); + if (!_npy_files_loaded) + { + prepareSampleDataset(); + } + } + + NumpyVector loadNumpy(npy::NpyFile& file) + { + NumpyVector numpyDataAll; + auto load_size = file.GetTensorSize()/sizeof(int64_t); + numpyDataAll.push_back(std::vector(load_size)); + file.LoadAll(numpyDataAll.back().data()); + + NumpyVector numpyData; + for(size_t i = 0; i < numpyDataAll.back().size(); i += SEQ_SIZE) + { + auto last = std::min(numpyDataAll.back().size(), i + SEQ_SIZE); + numpyData.emplace_back(numpyDataAll.back().begin() + i, numpyDataAll.back().begin() + last); + } + +#ifdef TRACE + for (auto& vec: numpyData) + { + std::cout << "Vector size: " << vec.size() << std::endl; + for (auto val: vec) + { + std::cout << val << " "; + } + std::cout << "\n"; + } +#endif + return numpyData; + } + + size_t getLastIdx() const + { + auto res = std::find_if(std::rbegin(attention_mask[_current_idx]), std::rend(attention_mask[_current_idx]), [](uint64_t val) { return 1 == val;}); + size_t last_idx = std::distance(res, std::rend(attention_mask[_current_idx])); + #ifdef TRACE + std::cout << "Last input idx: " << last_idx << std::endl; + #endif + return last_idx; + } + + std::vector getInputIds() { return input_ids[_current_idx]; } + std::vector getAttentionMask() { return attention_mask[_current_idx]; } + std::vector getPositionIds() { return position_ids[_current_idx]; } + + size_t size() const { return _size; } + size_t currentIdx() const { return 
_current_idx; } + size_t getNext() + { + if (_current_idx < size() - 1) + { + ++_current_idx; + } + #ifdef TRACE + std::cout << "Current idx: " << _current_idx << std::endl; + #endif + return _current_idx; + } + + Dataset(const Dataset &buf) = delete; + Dataset &operator=(const Dataset &buf) = delete; +private: + + // e.g.: /dataset/input_ids_size_3_seq_256.npy + std::string getDatasetPath(const std::string& datasetName) + { + std::stringstream path; + path << DATASET_FOLDER << datasetName << "_size_" << std::to_string(DATASET_SIZE) << "_seq_" << std::to_string(SEQ_SIZE) << ".npy"; + return path.str(); + } + + void loadDataset() + { + std::string input_file_path = getDatasetPath("input_ids"); + std::string attention_mask_file_path = getDatasetPath("attention_mask"); + std::string position_ids_file_path = getDatasetPath("position_ids"); + + std::cout << "Input ids file: " << input_file_path << std::endl; std::ifstream input_file(input_file_path.c_str()); std::ifstream attention_mask_file(attention_mask_file_path.c_str()); std::ifstream position_ids_file(position_ids_file_path.c_str()); @@ -144,10 +223,12 @@ struct Dataset attention_mask = loadNumpy(attention_mask_npy); position_ids = loadNumpy(position_ids_npy); - if (input_ids.size() == attention_mask.size() == position_ids.size()) + _size = input_ids.size(); + + if ((input_ids.size() == attention_mask.size()) && (attention_mask.size() == position_ids.size())) { std::cout << "Loaded numpy files\n"; - npy_files_loaded = true; + _npy_files_loaded = true; } else { @@ -157,70 +238,37 @@ struct Dataset position_ids.clear(); } } - - if (!npy_files_loaded) - { - std::cout << "Numpy files are not loaded, using dummy data\n"; - auto input_ids_sample = SAMPLE_IDS; - input_ids_sample.resize(SEQ_SIZE, EOS); - input_ids.emplace_back(input_ids_sample); - std::vector attention_mask_sample = input_ids_sample; - std::transform(std::begin(attention_mask_sample), std::end(attention_mask_sample), std::begin(attention_mask_sample), [](auto i){ - return (i != EOS) ? 1 : 0; - }); - attention_mask.emplace_back(attention_mask_sample); - - std::vector position_ids_sample; - for (int64_t i=0; i < SEQ_SIZE; ++i) - { - position_ids_sample.emplace_back(i); - } - position_ids.emplace_back(std::move(position_ids_sample)); - } - } - NumpyVector loadNumpy(npy::NpyFile& file) + void prepareSampleDataset() { - NumpyVector numpyData; - auto load_size = file.GetTensorSize()/sizeof(int64_t); - numpyData.push_back(std::vector(load_size)); - file.LoadAll(numpyData.back().data()); - - #ifdef TRACE - for (auto& vec: numpyData) + std::cout << "Numpy files are not loaded, using dummy data\n"; + std::vector input_ids_sample = {1,6804,5207,387,287,29973}; + input_ids_sample.resize(SEQ_SIZE, EOS); + std::vector attention_mask_sample = input_ids_sample; + input_ids.emplace_back(std::move(input_ids_sample)); + std::transform(std::begin(attention_mask_sample), std::end(attention_mask_sample), std::begin(attention_mask_sample), [](auto i){ + return (i != EOS) ? 
1 : 0; + }); + attention_mask.emplace_back(std::move(attention_mask_sample)); + + std::vector position_ids_sample; + for (int64_t i=0; i < SEQ_SIZE; ++i) { - for (auto val: vec) - { - std::cout << val << " "; - } - std::cout << "\n"; + position_ids_sample.emplace_back(i); } - #endif - return numpyData; - } + position_ids.emplace_back(std::move(position_ids_sample)); - size_t getLastIdx() const - { - auto res = std::find_if(std::rbegin(attention_mask[current_idx]), std::rend(attention_mask[current_idx]), [](uint64_t val) { return 1 == val;}); - size_t last_idx = std::distance(res, std::rend(attention_mask[current_idx])); - //std::cout << "Last input idx: " << last_idx << std::endl; - return last_idx; + _size = 1; } - std::vector getInputIds() { return input_ids[current_idx]; } - std::vector getAttentionMask() { return attention_mask[current_idx]; } - std::vector getPositionIds() { return position_ids[current_idx]; } - - Dataset(const Dataset &buf) = delete; - Dataset &operator=(const Dataset &buf) = delete; - NumpyVector input_ids; NumpyVector attention_mask; NumpyVector position_ids; - size_t current_idx = 0; - bool npy_files_loaded = false; + size_t _size = 0; + size_t _current_idx = 0; + bool _npy_files_loaded = false; }; struct LLama2Inputs @@ -233,21 +281,19 @@ struct LLama2Inputs { data.initialize(); - auto input_ids = data.getInputIds(); auto param_shapes = prog.get_parameter_shapes(); - auto input_ids_str = "input_ids"; + + auto input_ids = data.getInputIds(); input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); - prog_args.add(input_ids_str, migraphx::argument(param_shapes[input_ids_str], input_ids_buffer->data())); + prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); auto attention_mask = data.getAttentionMask(); - auto attention_mask_str = "attention_mask"; attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); - prog_args.add(attention_mask_str, migraphx::argument(param_shapes[attention_mask_str], attention_mask_buffer->data())); + prog_args.add(ATTENTION_MASK_STR, migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); auto position_ids = data.getPositionIds(); - auto position_ids_str = "position_ids"; position_ids_buffer = std::make_unique(std::move(position_ids), offload_copy); - prog_args.add(position_ids_str, migraphx::argument(param_shapes[position_ids_str], position_ids_buffer->data())); + prog_args.add(POSITION_IDS_STR, migraphx::argument(param_shapes[POSITION_IDS_STR], position_ids_buffer->data())); }; void upload_to_device(hipStream_t stream) @@ -258,7 +304,38 @@ struct LLama2Inputs position_ids_buffer->upload_to_device(stream); } - size_t getLastInputIndex() { return data.getLastIdx(); } + void updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) + { + auto currentIdx = data.currentIdx(); + if (currentIdx != data.getNext()) + { + auto param_shapes = prog.get_parameter_shapes(); + + auto input_ids = data.getInputIds(); + input_ids_buffer->update(std::move(input_ids)); + if (offload_copy) + { + prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); + } + + auto attention_mask = data.getAttentionMask(); + attention_mask_buffer->update(std::move(attention_mask)); + if (offload_copy) + { + prog_args.add(ATTENTION_MASK_STR, migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); + } + + auto position_ids = data.getPositionIds(); + 
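A quick illustration of the row-chunking that the reworked `Dataset::loadNumpy` above performs: the whole `.npy` tensor is read into one flat `int64_t` buffer and then sliced into `SEQ_SIZE`-long rows, one per sample. A standalone sketch of that split, using plain vectors instead of `npy::NpyFile` and a toy row length so the output stays readable:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Split a flat buffer into rows of at most row_len elements each.
std::vector<std::vector<int64_t>> chunk_rows(const std::vector<int64_t>& flat, size_t row_len)
{
    std::vector<std::vector<int64_t>> rows;
    for (size_t i = 0; i < flat.size(); i += row_len)
    {
        size_t last = std::min(flat.size(), i + row_len);
        rows.emplace_back(flat.begin() + i, flat.begin() + last);
    }
    return rows;
}

int main()
{
    const size_t SEQ_SIZE = 4; // toy value; the harness uses the model's sequence length
    std::vector<int64_t> flat = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

    auto rows = chunk_rows(flat, SEQ_SIZE);
    std::cout << "samples: " << rows.size() << "\n";
    for (const auto& row : rows)
    {
        for (auto v : row) std::cout << v << ' ';
        std::cout << '\n';
    }
    return 0;
}
```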
position_ids_buffer->update(std::move(position_ids)); + if (offload_copy) + { + prog_args.add(POSITION_IDS_STR, migraphx::argument(param_shapes[POSITION_IDS_STR], position_ids_buffer->data())); + } + } + } + + size_t getLastInputIndex() const { return data.getLastIdx(); } + size_t dataSize() const { return data.size(); } LLama2Inputs() = delete; LLama2Inputs(const LLama2Inputs &buf) = delete; @@ -269,10 +346,16 @@ struct LLama2Inputs std::unique_ptr position_ids_buffer; Dataset data; bool offload_copy; + + const char* INPUTS_ID_STR = "input_ids"; + const char* ATTENTION_MASK_STR = "attention_mask"; + const char* POSITION_IDS_STR = "position_ids"; }; + + int main() { - bool offload_copy = true; + bool offload_copy = false; std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; ModelLoadSettings settings = {SEQ_SIZE, true /*quantize_fp16*/, offload_copy /*offload_copy*/, true /*fast_math*/}; migraphx::program prog = loadProgram(settings); @@ -299,39 +382,51 @@ int main() { std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; auto start = std::chrono::steady_clock::now(); - for (int i = model_inputs.getLastInputIndex(); i < SEQ_SIZE - 1; ++i) + std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; + for (size_t i = 0; i < model_inputs.dataSize(); ++i) { - auto outputs = prog.run_async(prog_args, stream); - if (not offload_copy) + #ifdef TRACE + std::cout << "Iter #" << i << std::endl; + #endif + for (size_t i = model_inputs.getLastInputIndex(); i < SEQ_SIZE - 1; ++i) { - output_buffer.download_from_device(stream, i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE); + auto outputs = prog.run_async(prog_args, stream); + if (not offload_copy) + { + output_buffer.download_from_device(stream, i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE); + } + + check_hip_status(hipStreamSynchronize(stream)); + float* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); + std::vector logits(results, results + output_size); + std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); + int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); + token_count++; + // std::cout << "New token: " << new_token << std::endl; + output_tokens.push_back(new_token); + if (new_token == EOS) + { + break; + } + model_inputs.input_ids_buffer->update_data(new_token, i +1, stream); + model_inputs.attention_mask_buffer->update_data(1, i +1, stream); } - check_hip_status(hipStreamSynchronize(stream)); - float* results = offload_copy ? 
reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); - std::vector logits(results, results + output_size); - std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); - int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); - token_count++; - // std::cout << "New token: " << new_token << std::endl; - output_tokens.push_back(new_token); - if (new_token == EOS) - { - break; +#ifdef TRACE + std::cout << "######### Output token ids for #" << i << " #########" << std::endl; + // print output tokens + for (auto tok: output_tokens){ + std::cout << tok << ", "; } - model_inputs.input_ids_buffer->update_data(new_token, i +1, stream); - model_inputs.attention_mask_buffer->update_data(1, i +1, stream); + std::cout << std::endl; +#endif + model_inputs.updateData(prog, prog_args); + + output_tokens.clear(); } float dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() / 1000.f; std::cout << "Duration: " << dur << " seconds." << std::endl; std::cout << "Completed " << token_count << " tokens." << std::endl; std::cout << "Tokens/sec: " << token_count / dur << std::endl; - - std::cout << "######### Output token ids #########" << std::endl; - // print output tokens - for (auto tok: output_tokens){ - std::cout << tok << ", "; - } - std::cout << std::endl; return 0; } From 55e41f4bcecffa8dd28a3ae0c1383919bd3d78d3 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Tue, 29 Oct 2024 07:29:56 -0500 Subject: [PATCH 15/55] Add missing upload to device for multiple samples --- examples/transformers/mgx_llama2/mgxllama2.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 47048c7991a..de73e212380 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -304,7 +304,7 @@ struct LLama2Inputs position_ids_buffer->upload_to_device(stream); } - void updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) + bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) { auto currentIdx = data.currentIdx(); if (currentIdx != data.getNext()) @@ -331,7 +331,9 @@ struct LLama2Inputs { prog_args.add(POSITION_IDS_STR, migraphx::argument(param_shapes[POSITION_IDS_STR], position_ids_buffer->data())); } + return true; } + return false; } size_t getLastInputIndex() const { return data.getLastIdx(); } @@ -420,7 +422,12 @@ int main() { } std::cout << std::endl; #endif - model_inputs.updateData(prog, prog_args); + auto updated = model_inputs.updateData(prog, prog_args); + + if (updated && not offload_copy) + { + model_inputs.upload_to_device(stream); + } output_tokens.clear(); } From 4d25c4932c0e185539d805c6df3ad43d63838f38 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Mon, 4 Nov 2024 06:03:00 -0600 Subject: [PATCH 16/55] Add accuracy calculation for mgx_llama2 example --- .../transformers/mgx_llama2/eval_accuracy.py | 118 ++++++++++++++++++ examples/transformers/mgx_llama2/mgxllama2.cc | 33 ++++- 2 files changed, 147 insertions(+), 4 deletions(-) create mode 100644 examples/transformers/mgx_llama2/eval_accuracy.py diff --git a/examples/transformers/mgx_llama2/eval_accuracy.py b/examples/transformers/mgx_llama2/eval_accuracy.py new file mode 100644 index 00000000000..041ce0ea500 --- /dev/null +++ b/examples/transformers/mgx_llama2/eval_accuracy.py @@ -0,0 
+1,118 @@ +from argparse import ArgumentParser +import numpy as np +import pickle +from pathlib import Path +import os +import evaluate +import nltk +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM + + +MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf" + +G_MAX_TOK_LEN = 1024 +G_LLAMA2_EOS = 2 +SAMPLE_SIZE = 10 + +DATASET_PATH = "/dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl" +RESULT_PATH = "build/results.txt" + +def main(dataset_path, result_path, sample_size, sequence_size): + tokenizer = AutoTokenizer.from_pretrained( + MODEL_NAME, + model_max_length=sequence_size, + padding_side="left", + use_fast=False,) + + metric = evaluate.load("rouge") + nltk.download("punkt") + + _p = Path(DATASET_PATH) + if _p.exists(): + with _p.open(mode="rb") as f: + d = pickle.load(f) + + + target = d['output'].to_list() + targets = target[0:sample_size] + results, gen_tok_len = readResult(result_path) + + preds = tokenizer.batch_decode( + results, skip_special_tokens=True + ) + + postprocess_text(preds, target) + + result = metric.compute( + predictions=preds, references=targets, use_stemmer=True, use_aggregator=False + ) + + result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} + prediction_lens = [len(pred) for pred in preds] + gen_num = len(preds) + + result = { + **result, + "gen_len": np.sum(prediction_lens), + "gen_num": gen_num, + "gen_tok_len": gen_tok_len, + "tokens_per_sample": round(gen_tok_len / gen_num, 1), + } + + print("\nResults\n") + print(result) + +def readResult(path): + results = [] + tok_len = 0 + f = open(path, "r") + for res in f: + result = res.split(",") + result = [int(num_res) for num_res in result] + results.append(result) + tok_len += len(result) + return results, tok_len + +def postprocess_text(preds, targets): + preds = [pred.strip() for pred in preds] + targets = [target.strip() for target in targets] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] + + return preds, targets + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument( + "-d", + "--dataset-path", + help="Path to the dataset pickle file", + default=DATASET_PATH + ) + + parser.add_argument( + "-r", + "--result_path", + help="Path to output tokens result file", + default=RESULT_PATH + ) + + parser.add_argument( + "-size", + "--sample-size", + help="Sample size of dataset", + type=int, + default=SAMPLE_SIZE + ) + + parser.add_argument( + "-seq_size", + "--sequence_size", + help="Size of sequence", + type=int, + default=G_MAX_TOK_LEN + ) + + main(**vars(parser.parse_args())) \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index de73e212380..06640860ffe 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -14,13 +14,15 @@ using namespace mlinfer; const std::string MODEL_FOLDER = "/model/"; const std::string ONNX_FILE = "model.onnx"; const std::string DATASET_FOLDER = "/dataset/"; -const size_t DATASET_SIZE = 3; +const size_t DATASET_SIZE = 10; // sequence length from model config -const size_t SEQ_SIZE = 256; +const size_t SEQ_SIZE = 1024; // vocab size from model config const size_t VOCAB_SIZE = 32000; // EOS token from model config const size_t EOS = 2; +// Write output tokens to file +const bool WRITE_RESULT_FILE = true; struct ModelLoadSettings { @@ 
-354,7 +356,23 @@ struct LLama2Inputs const char* POSITION_IDS_STR = "position_ids"; }; - +void writeResults(const std::vector>& results) +{ + std::string RESULT_FILE = "result.txt"; + std::ofstream outFile(RESULT_FILE); + for (auto& resVec : results) + { + for (auto& res : resVec) + { + outFile << res; + if (&res != &resVec.back()) + { + outFile << ", "; + } + } + outFile << "\n"; + } +} int main() { bool offload_copy = false; @@ -364,6 +382,7 @@ int main() { std::cout << "Model loaded" << std::endl; // Setup model inputs + std::vector> results; std::vector output_tokens; migraphx::program_parameters prog_args; hipStream_t stream; @@ -428,12 +447,18 @@ int main() { { model_inputs.upload_to_device(stream); } - + results.emplace_back(output_tokens); output_tokens.clear(); } + float dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() / 1000.f; std::cout << "Duration: " << dur << " seconds." << std::endl; std::cout << "Completed " << token_count << " tokens." << std::endl; std::cout << "Tokens/sec: " << token_count / dur << std::endl; + + if (WRITE_RESULT_FILE) + { + writeResults(results); + } return 0; } From 3dd55b5e4ac0627ada2fba92ad23519ed9d5b051 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Tue, 5 Nov 2024 10:21:04 -0600 Subject: [PATCH 17/55] Use MIGraphX from develop branch in Dockerfile --- examples/transformers/mgx_llama2/Dockerfile | 20 +++++++++++++++++-- .../transformers/mgx_llama2/build_docker.sh | 2 +- .../transformers/mgx_llama2/run_docker.sh | 3 ++- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index 387ac4ac80a..f120cfbce96 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -7,14 +7,30 @@ SHELL ["/bin/bash", "-c"] RUN apt-get update && apt-get install -y --allow-unauthenticated \ apt-utils \ cmake \ - migraphx && \ + git && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* +RUN mkdir /migraphx && cd /migraphx && git clone --branch develop https://github.com/ROCm/AMDMIGraphX.git + +WORKDIR /migraphx/AMDMIGraphX + +RUN ./tools/install_prereqs.sh + +RUN export test=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') + +#TODO: use $(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') for GPU_TARGETS +RUN mkdir build && cd build && \ + CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DGPU_TARGETS='gfx942' && \ + make -j$(nproc) && \ + make install + RUN mkdir /mgx_llama2 COPY . /mgx_llama2 RUN rm -rf /mgx_llama2/build && mkdir /mgx_llama2/build -RUN cd /mgx_llama2/build && CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make \ No newline at end of file +WORKDIR /mgx_llama2/build + +RUN CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/build_docker.sh b/examples/transformers/mgx_llama2/build_docker.sh index c6eb64304ec..4aafe5bea88 100755 --- a/examples/transformers/mgx_llama2/build_docker.sh +++ b/examples/transformers/mgx_llama2/build_docker.sh @@ -1,3 +1,3 @@ #!/bin/bash -docker build --platform linux/amd64 --tag mgx_llama2:v0.1 --file Dockerfile . +docker build --platform linux/amd64 --tag mgx_llama2:v0.2 --file Dockerfile . 
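For completeness, the decode step shared by all of the evaluation loops in these patches is a plain greedy argmax: take the `VOCAB_SIZE`-wide slice of the logits that corresponds to the current position and pick the index of its maximum as the next token id. A minimal standalone sketch with made-up float logits; the harness does the same with `std::max_element`/`std::distance`, and with `half` once the quantized model is used:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Greedy selection of the next token from logits laid out as [SEQ_SIZE, VOCAB_SIZE].
int64_t greedy_next_token(const std::vector<float>& logits, size_t position, size_t vocab_size)
{
    auto row_begin = logits.begin() + position * vocab_size;
    auto row_end   = row_begin + vocab_size;
    return std::distance(row_begin, std::max_element(row_begin, row_end));
}

int main()
{
    const size_t SEQ_SIZE = 2, VOCAB_SIZE = 5; // toy sizes
    std::vector<float> logits = {
        0.1f, 0.7f, 0.2f, 0.0f, 0.0f,  // position 0 -> token 1
        0.0f, 0.1f, 0.1f, 0.1f, 0.9f   // position 1 -> token 4
    };

    for (size_t pos = 0; pos < SEQ_SIZE; ++pos)
        std::cout << "position " << pos << " -> token "
                  << greedy_next_token(logits, pos, VOCAB_SIZE) << "\n";
    return 0;
}
```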
diff --git a/examples/transformers/mgx_llama2/run_docker.sh b/examples/transformers/mgx_llama2/run_docker.sh index a0c7e2da7ed..329ab350873 100755 --- a/examples/transformers/mgx_llama2/run_docker.sh +++ b/examples/transformers/mgx_llama2/run_docker.sh @@ -15,8 +15,9 @@ else fi docker run --device='/dev/kfd' --device='/dev/dri' --group-add video \ +-v $(pwd):/mgx_llama2 \ -v $MODEL_DIR:/model \ -v $DATA_DIR:/dataset \ -w /mgx_llama2/build \ --it mgx_llama2:v0.1 +-it mgx_llama2:v0.2 From 24167860c7e3b80c4e666c5b701d6702d5ef337a Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Thu, 7 Nov 2024 08:38:33 -0600 Subject: [PATCH 18/55] Fix dataset loading --- examples/transformers/mgx_llama2/Dockerfile | 3 ++- examples/transformers/mgx_llama2/preprocess_dataset.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index f120cfbce96..fc8fb5fc2e8 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -33,4 +33,5 @@ RUN rm -rf /mgx_llama2/build && mkdir /mgx_llama2/build WORKDIR /mgx_llama2/build -RUN CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make \ No newline at end of file +RUN CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make +RUN pip install pandas diff --git a/examples/transformers/mgx_llama2/preprocess_dataset.py b/examples/transformers/mgx_llama2/preprocess_dataset.py index 56411ccae23..0d43f890102 100644 --- a/examples/transformers/mgx_llama2/preprocess_dataset.py +++ b/examples/transformers/mgx_llama2/preprocess_dataset.py @@ -13,6 +13,8 @@ if _p.exists(): with _p.open(mode="rb") as f: d = pickle.load(f) +else: + raise RuntimeError(f"Missing dataset from {DATASET_PATH}") toks = d['tok_input'].to_list() #toks = [toks[0]] From 0877f720ffaaf8367a9575160adccc8f5408a40e Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Thu, 7 Nov 2024 08:39:26 -0600 Subject: [PATCH 19/55] Add README to C++ LLama2 example --- examples/transformers/mgx_llama2/README.md | 92 ++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 examples/transformers/mgx_llama2/README.md diff --git a/examples/transformers/mgx_llama2/README.md b/examples/transformers/mgx_llama2/README.md new file mode 100644 index 00000000000..de3bcd7d50b --- /dev/null +++ b/examples/transformers/mgx_llama2/README.md @@ -0,0 +1,92 @@ +## Getting the model + +### Getting the pre-quantized model from HuggingFace +```bash +pip install -U "huggingface_hub[cli]" +huggingface-cli login YOUR_HF_TOKEN +hugginggface-cli download https://huggingface.co/amd/Llama-2-7b-chat-hf-awq-int4-asym-gs128-onnx +``` +Alternatively you can quantize the model yourself. + +### Quantizing the model + +**If you are using the pre-quantized model you can skip this section.** + +Get the latest quark quantizer version from https://xcoartifactory/ui/native/uai-pip-local/com/amd/quark/main/nightly/ . Donwloading the zip is recommended because it contains the required scripts. The quark version used when this was created: quark-1.0.0.dev20241028+eb46b7438 (28-10-24). + +Also we will need to install the onnxruntime-genai (OGA) tool to convert the quark_safetensors format to onnx format properly. + +#### Installing quark and it's dependencies: +```bash +# install OGA tool +pip install onnxruntime-genai + +# Quark dependencies according to https://quark.docs.amd.com/latest/install.html, we assume pytorch is already installed. 
You can use the following base docker image which has torch installed: rocm/pytorch:rocm6.2.2_ubuntu20.04_py3.9_pytorch_release_2.2.1 +pip install onnxruntime onnx + +# Install the whl +unzip quark-1.0.0.dev20241028+eb46b7438.zip -d quark +cd quark +RUN pip install quark-1.0.0.dev20241028+eb46b7438-py3-none-any.whl +``` + +#### Quantizing the model and converting to ONNX +```bash +cd examples/torch/language_modeling/llm_ptq + +export MODEL_DIR = [local model checkpoint folder] or meta-llama/Llama-2-7b-chat-hf or meta-llama/Llama-2-70b-chat-hf +export QUANTIZED_MODEL_DIR = [output model checkpoint folder] + +python3 quantize_quark.py --model_dir $MODEL_DIR \ + --data_type float16 \ + --quant_scheme w_uint4_per_group_asym \ + --num_calib_data 128 \ + --quant_algo awq \ + --dataset pileval_for_awq_benchmark \ + --seq_len 1024 \ + --output_dir $MODEL_DIR-awq-uint4-asym-g128-f16 \ + --model_export quark_safetensors \ + --custom_mode awq + +python3 -m onnxruntime_genai.models.builder \ + -i "$QUANTIZED_MODEL_DIR" \ + -o "$QUANTIZED_MODEL_DIR-onnx" \ + -p int4 \ + -e cpu +``` + +## Getting the dataset + +Download the preprocessed open-orca dataset files using the instructions in https://github.com/mlcommons/inference/tree/master/language/llama2-70b#preprocessed + +### Running the example + +#### Starting migraphx docker + +```bash + +./build_docker.sh + +export MODEL_DIR_PATH=path/to/quantized/llama2-7[0]b-model +export DATA_DIR_PATH=path/to/open_orca_dataset +./run_docker.sh +``` + +#### Building and running the example + +```bash +# Convert dataset to numpy format +./prepocess_dataset.py + +# Builidng the example +cd mgx_llama2 +mkdir build && cd build +CXX=/opt/rocm/llvm/bin/clang++ cmake .. +make -j + +# Running the example +./mgxllama2 + +# Test the accuracy of the output +python3 eval_accuracy.py +``` From 513cc0d176bb71ab2d2aa8e3426b14c0e449fd78 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 7 Nov 2024 09:31:21 -0600 Subject: [PATCH 20/55] Add buffers for llama2 7b quantized models --- examples/transformers/mgx_llama2/Dockerfile | 1 + .../mgx_llama2/harness/buffer.hpp | 1 + .../mgx_llama2/harness/common.hpp | 4 ++ examples/transformers/mgx_llama2/mgxllama2.cc | 53 +++++++++++++++++-- 4 files changed, 55 insertions(+), 4 deletions(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index fc8fb5fc2e8..ed252385bdf 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -7,6 +7,7 @@ SHELL ["/bin/bash", "-c"] RUN apt-get update && apt-get install -y --allow-unauthenticated \ apt-utils \ cmake \ + half \ git && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index fc7fa064b6e..fc1f91daa59 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -189,4 +189,5 @@ namespace mlinfer using LLama2InputBuffer = ManagedBuffer_v2; using LLama2OutputBuffer = ManagedBuffer_v2; + using LLama2PastKeyValueBuffer = ManagedBuffer_v2; } diff --git a/examples/transformers/mgx_llama2/harness/common.hpp b/examples/transformers/mgx_llama2/harness/common.hpp index db155af820a..dbf65af459c 100644 --- a/examples/transformers/mgx_llama2/harness/common.hpp +++ b/examples/transformers/mgx_llama2/harness/common.hpp @@ -9,12 +9,16 @@ #include #include #include +#include #define TIMER_ON 0 #define TRACE_ON 0 #define 
assertm(exp, msg) assert(((void)msg, exp)) +using half = half_float::half; +using namespace half_float::literal; + namespace mlinfer { struct INoCopy diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 06640860ffe..e1a47afcd2a 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -1,4 +1,5 @@ #include "buffer.hpp" +#include "common.hpp" #include "numpy.hpp" #include @@ -22,7 +23,9 @@ const size_t VOCAB_SIZE = 32000; // EOS token from model config const size_t EOS = 2; // Write output tokens to file -const bool WRITE_RESULT_FILE = true; +const bool WRITE_RESULT_FILE = false; + +const size_t HIDDEN_LAYERS_NUM = 32; struct ModelLoadSettings { @@ -53,6 +56,10 @@ static migraphx::program loadOnnx(ModelLoadSettings& settings) { std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); + #ifdef TRACE + std::cout << "Using model: " << MODEL_FOLDER + ONNX_FILE << std::endl; + #endif + migraphx::program prog; std::ifstream f(onnx_path.c_str()); if (f.good()) @@ -285,17 +292,47 @@ struct LLama2Inputs auto param_shapes = prog.get_parameter_shapes(); + auto inputShape = param_shapes[INPUTS_ID_STR]; auto input_ids = data.getInputIds(); input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); - prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); + prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); + + auto attShape = param_shapes[ATTENTION_MASK_STR]; auto attention_mask = data.getAttentionMask(); attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); - prog_args.add(ATTENTION_MASK_STR, migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); + prog_args.add(ATTENTION_MASK_STR, migraphx::argument(attShape, attention_mask_buffer->data())); + //auto positionShape = param_shapes[POSITION_IDS_STR]; + auto positionShape = inputShape; auto position_ids = data.getPositionIds(); position_ids_buffer = std::make_unique(std::move(position_ids), offload_copy); - prog_args.add(POSITION_IDS_STR, migraphx::argument(param_shapes[POSITION_IDS_STR], position_ids_buffer->data())); + prog_args.add(POSITION_IDS_STR, migraphx::argument(inputShape, position_ids_buffer->data())); + + const size_t HEAD_SIZE = 128; + + const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE; + + // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} + // past_key_values.0.value = @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + std::stringstream past_key; + past_key << "past_key_values." << std::to_string(i) << ".key"; + auto past_keyStr = past_key.str(); + auto past_keyString = past_keyStr.c_str(); + past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + auto pastKeyShape = param_shapes[past_keyString]; + prog_args.add(past_keyString, migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); + + std::stringstream past_val; + past_val << "past_key_values." 
<< std::to_string(i) << ".value"; + auto past_valueStr = past_val.str(); + auto past_valueString = past_valueStr.c_str(); + past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + auto pastValueShape = param_shapes[past_valueString]; + prog_args.add(past_valueString, migraphx::argument(pastValueShape, past_value_buffers[i]->data())); + } }; void upload_to_device(hipStream_t stream) @@ -348,6 +385,8 @@ struct LLama2Inputs std::unique_ptr input_ids_buffer; std::unique_ptr attention_mask_buffer; std::unique_ptr position_ids_buffer; + std::vector> past_key_buffers; + std::vector> past_value_buffers; Dataset data; bool offload_copy; @@ -400,6 +439,12 @@ int main() { migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); + const size_t logit_size = 4096 * VOCAB_SIZE; + auto logit_name = "logits"; + auto logit_buffer = LLama2PastKeyValueBuffer(std::vector(logit_size, 0.0_h), offload_copy); + migraphx::shape logit_shape{migraphx_shape_float_type, {1, 4096, VOCAB_SIZE}}; + prog_args.add(logit_name, migraphx::argument(logit_shape, output_buffer.data())); + std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; auto start = std::chrono::steady_clock::now(); From 048e58797e789f2fa20c58f9f2a5cb438d2fbb77 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Fri, 8 Nov 2024 09:29:02 -0600 Subject: [PATCH 21/55] Fix Llama2-7b model file parse and input buffers --- examples/transformers/mgx_llama2/mgxllama2.cc | 65 ++++++++++++------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index e1a47afcd2a..12786906aa7 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -17,7 +17,7 @@ const std::string ONNX_FILE = "model.onnx"; const std::string DATASET_FOLDER = "/dataset/"; const size_t DATASET_SIZE = 10; // sequence length from model config -const size_t SEQ_SIZE = 1024; +const size_t SEQ_SIZE = 4096; // vocab size from model config const size_t VOCAB_SIZE = 32000; // EOS token from model config @@ -26,6 +26,7 @@ const size_t EOS = 2; const bool WRITE_RESULT_FILE = false; const size_t HIDDEN_LAYERS_NUM = 32; +const size_t HEAD_SIZE = 128; struct ModelLoadSettings { @@ -52,6 +53,20 @@ static std::string getModelPath(ModelLoadSettings& s) return path.str(); } +std::string getPastKeyString(size_t i) +{ + std::stringstream past_key; + past_key << "past_key_values." << std::to_string(i) << ".key"; + return past_key.str(); +} + +std::string getPastValueStr(size_t i) +{ + std::stringstream past_val; + past_val << "past_key_values." 
<< std::to_string(i) << ".value"; + return past_val.str(); +} + static migraphx::program loadOnnx(ModelLoadSettings& settings) { std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); @@ -66,9 +81,15 @@ static migraphx::program loadOnnx(ModelLoadSettings& settings) { migraphx::onnx_options onnx_opts; std::vector dims = {1, SEQ_SIZE}; + std::vector dimsPastKey = {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; onnx_opts.set_input_parameter_shape("input_ids", dims); onnx_opts.set_input_parameter_shape("attention_mask", dims); onnx_opts.set_input_parameter_shape("position_ids", dims); + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); + onnx_opts.set_input_parameter_shape(getPastValueStr(i), dimsPastKey); + } std::cout << "Parsing onnx file ..." << std::endl; prog = parse_onnx(onnx_path.c_str(), onnx_opts); @@ -252,12 +273,12 @@ struct Dataset void prepareSampleDataset() { std::cout << "Numpy files are not loaded, using dummy data\n"; - std::vector input_ids_sample = {1,6804,5207,387,287,29973}; - input_ids_sample.resize(SEQ_SIZE, EOS); + std::vector input_ids_sample = {1,6804,338,5207,387,287,29973}; + input_ids_sample.resize(SEQ_SIZE, 0); std::vector attention_mask_sample = input_ids_sample; input_ids.emplace_back(std::move(input_ids_sample)); std::transform(std::begin(attention_mask_sample), std::end(attention_mask_sample), std::begin(attention_mask_sample), [](auto i){ - return (i != EOS) ? 1 : 0; + return (i != 0) ? 1 : 0; }); attention_mask.emplace_back(std::move(attention_mask_sample)); @@ -307,27 +328,21 @@ struct LLama2Inputs auto positionShape = inputShape; auto position_ids = data.getPositionIds(); position_ids_buffer = std::make_unique(std::move(position_ids), offload_copy); - prog_args.add(POSITION_IDS_STR, migraphx::argument(inputShape, position_ids_buffer->data())); + //prog_args.add(POSITION_IDS_STR, migraphx::argument(inputShape, position_ids_buffer->data())); - const size_t HEAD_SIZE = 128; - - const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE; + const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} // past_key_values.0.value = @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { - std::stringstream past_key; - past_key << "past_key_values." << std::to_string(i) << ".key"; - auto past_keyStr = past_key.str(); + auto past_keyStr = getPastKeyString(i); auto past_keyString = past_keyStr.c_str(); past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); auto pastKeyShape = param_shapes[past_keyString]; prog_args.add(past_keyString, migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); - std::stringstream past_val; - past_val << "past_key_values." 
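One thing worth keeping in mind with the fixed `{1, 32, SEQ_SIZE, 128}` past key/value shapes used here is how much memory they pin: every one of the 32 hidden layers gets a key buffer and a value buffer of that size in half precision. A small sketch of the arithmetic; the example reuses `HIDDEN_LAYERS_NUM` for the 32 in the second dimension, and splitting it out as a separate `NUM_HEADS` constant below is an assumption made for readability:

```cpp
#include <cstddef>
#include <iostream>

int main()
{
    // Constants mirrored from the example (llama2-7b): 32 hidden layers,
    // head size 128, half precision (2 bytes); NUM_HEADS assumed to be the
    // 32 appearing in the {1, 32, SEQ_SIZE, 128} buffer shape.
    const size_t HIDDEN_LAYERS_NUM = 32;
    const size_t NUM_HEADS         = 32;
    const size_t HEAD_SIZE         = 128;
    const size_t SEQ_SIZE          = 4096;
    const size_t HALF_BYTES        = 2;

    // One past key (or value) buffer holds NUM_HEADS * SEQ_SIZE * HEAD_SIZE halfs.
    size_t elems_per_buffer = NUM_HEADS * SEQ_SIZE * HEAD_SIZE;
    size_t bytes_per_buffer = elems_per_buffer * HALF_BYTES;

    // Each layer holds both a key and a value buffer.
    size_t bytes_per_layer = 2 * bytes_per_buffer;
    size_t total_bytes     = HIDDEN_LAYERS_NUM * bytes_per_layer;

    std::cout << "per buffer: " << bytes_per_buffer / (1024.0 * 1024.0) << " MiB\n";
    std::cout << "per layer (K+V): " << bytes_per_layer / (1024.0 * 1024.0) << " MiB\n";
    std::cout << "total KV cache: " << total_bytes / (1024.0 * 1024.0 * 1024.0) << " GiB\n";
    return 0;
}
```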
<< std::to_string(i) << ".value"; - auto past_valueStr = past_val.str(); + auto past_valueStr = getPastValueStr(i); auto past_valueString = past_valueStr.c_str(); past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); auto pastValueShape = param_shapes[past_valueString]; @@ -433,18 +448,18 @@ int main() { } // Setup model output for non-offload copy + // const size_t output_size = SEQ_SIZE * VOCAB_SIZE; + // auto output_name = "main:#output_0"; + // auto output_buffer = LLama2OutputBuffer(std::vector(output_size), offload_copy); + // migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + // prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); + const size_t output_size = SEQ_SIZE * VOCAB_SIZE; - auto output_name = "main:#output_0"; - auto output_buffer = LLama2OutputBuffer(std::vector(output_size), offload_copy); + auto output_name = "logits"; + auto output_buffer = LLama2PastKeyValueBuffer(std::vector(output_size), offload_copy); migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - const size_t logit_size = 4096 * VOCAB_SIZE; - auto logit_name = "logits"; - auto logit_buffer = LLama2PastKeyValueBuffer(std::vector(logit_size, 0.0_h), offload_copy); - migraphx::shape logit_shape{migraphx_shape_float_type, {1, 4096, VOCAB_SIZE}}; - prog_args.add(logit_name, migraphx::argument(logit_shape, output_buffer.data())); - std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; auto start = std::chrono::steady_clock::now(); @@ -463,9 +478,9 @@ int main() { } check_hip_status(hipStreamSynchronize(stream)); - float* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); - std::vector logits(results, results + output_size); - std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); + half* results = offload_copy ? 
reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); + std::vector logits(results, results + output_size); + std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); token_count++; // std::cout << "New token: " << new_token << std::endl; From 71b75224fc8e68e91d00ffb233dd8cc28470c5ca Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 13 Nov 2024 07:57:48 -0600 Subject: [PATCH 22/55] Fix llama 7b quantized model evaluation step, use 2 models --- examples/transformers/mgx_llama2/mgxllama2.cc | 120 ++++++++++++++++-- 1 file changed, 106 insertions(+), 14 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 12786906aa7..57dd5a12391 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -27,6 +27,7 @@ const bool WRITE_RESULT_FILE = false; const size_t HIDDEN_LAYERS_NUM = 32; const size_t HEAD_SIZE = 128; +const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; struct ModelLoadSettings { @@ -34,6 +35,7 @@ struct ModelLoadSettings bool quantize_fp16; bool offload_copy; bool fast_math; + bool input_one_dim; }; static std::string getModelPath(ModelLoadSettings& s) @@ -49,7 +51,12 @@ static std::string getModelPath(ModelLoadSettings& s) { path << "no"; } - path << "fastmath.mxr"; + path << "fastmath"; + if (s.input_one_dim) + { + path << "_inputonedim"; + } + path << ".mxr"; return path.str(); } @@ -67,6 +74,20 @@ std::string getPastValueStr(size_t i) return past_val.str(); } +std::string getPresentKeyString(size_t i) +{ + std::stringstream past_key; + past_key << "present." << std::to_string(i) << ".key"; + return past_key.str(); +} + +std::string getPresentValueStr(size_t i) +{ + std::stringstream past_val; + past_val << "present." 
<< std::to_string(i) << ".value"; + return past_val.str(); +} + static migraphx::program loadOnnx(ModelLoadSettings& settings) { std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); @@ -82,7 +103,16 @@ static migraphx::program loadOnnx(ModelLoadSettings& settings) migraphx::onnx_options onnx_opts; std::vector dims = {1, SEQ_SIZE}; std::vector dimsPastKey = {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; - onnx_opts.set_input_parameter_shape("input_ids", dims); + std::vector inputDim; + if (settings.input_one_dim) + { + inputDim = {1,1}; + } + else + { + inputDim = dims; + } + onnx_opts.set_input_parameter_shape("input_ids", inputDim); onnx_opts.set_input_parameter_shape("attention_mask", dims); onnx_opts.set_input_parameter_shape("position_ids", dims); for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) @@ -328,9 +358,7 @@ struct LLama2Inputs auto positionShape = inputShape; auto position_ids = data.getPositionIds(); position_ids_buffer = std::make_unique(std::move(position_ids), offload_copy); - //prog_args.add(POSITION_IDS_STR, migraphx::argument(inputShape, position_ids_buffer->data())); - - const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; + prog_args.add(POSITION_IDS_STR, migraphx::argument(inputShape, position_ids_buffer->data())); // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} // past_key_values.0.value = @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} @@ -429,12 +457,17 @@ void writeResults(const std::vector>& results) } int main() { - bool offload_copy = false; + bool offload_copy = true; std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - ModelLoadSettings settings = {SEQ_SIZE, true /*quantize_fp16*/, offload_copy /*offload_copy*/, true /*fast_math*/}; + ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; migraphx::program prog = loadProgram(settings); std::cout << "Model loaded" << std::endl; + // Load {1,1} input_ids model + settings.input_one_dim = true; + migraphx::program progSimpleInput = loadProgram(settings); + std::cout << "Model 1 dim input loaded" << std::endl; + // Setup model inputs std::vector> results; std::vector output_tokens; @@ -454,12 +487,28 @@ int main() { // migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; // prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - const size_t output_size = SEQ_SIZE * VOCAB_SIZE; + size_t output_size = SEQ_SIZE * VOCAB_SIZE; auto output_name = "logits"; auto output_buffer = LLama2PastKeyValueBuffer(std::vector(output_size), offload_copy); migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); + // std::vector> present_key_buffers; + // std::vector> present_value_buffers; + // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + // { + // migraphx::shape present_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; + // auto present_keyStr = getPresentKeyString(i); + // auto present_keyString = present_keyStr.c_str(); + // present_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + // prog_args.add(present_keyString, migraphx::argument(present_shape, present_key_buffers[i]->data())); + + // auto present_valueStr = getPresentValueStr(i); + // auto present_valueString = 
present_valueStr.c_str(); + // present_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + // prog_args.add(present_valueString, migraphx::argument(present_shape, present_value_buffers[i]->data())); + // } + std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; auto start = std::chrono::steady_clock::now(); @@ -469,7 +518,8 @@ int main() { #ifdef TRACE std::cout << "Iter #" << i << std::endl; #endif - for (size_t i = model_inputs.getLastInputIndex(); i < SEQ_SIZE - 1; ++i) + auto lastInputIdx = model_inputs.getLastInputIndex(); + for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) { auto outputs = prog.run_async(prog_args, stream); if (not offload_copy) @@ -480,17 +530,59 @@ int main() { check_hip_status(hipStreamSynchronize(stream)); half* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); std::vector logits(results, results + output_size); - std::vector::iterator max = std::max_element(std::begin(logits) + (i * VOCAB_SIZE), std::begin(logits) + ((i + 1) * VOCAB_SIZE)); - int64_t new_token = std::distance(std::begin(logits) + (i * VOCAB_SIZE), max); + + bool firstIter = (i == lastInputIdx); + auto logits_begin = firstIter ? std::begin(logits) + (i * VOCAB_SIZE) : std::begin(logits); + auto logits_end = firstIter ? std::begin(logits) + ((i + 1) * VOCAB_SIZE) : std::end(logits); + std::vector::iterator max = std::max_element(logits_begin, logits_end); + int64_t new_token = std::distance(logits_begin, max); + token_count++; - // std::cout << "New token: " << new_token << std::endl; + #ifdef TRACE + std::cout << "New token: " << new_token << std::endl; + #endif output_tokens.push_back(new_token); + + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + migraphx::shape past_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; + half* res = reinterpret_cast(outputs[2*i+1].data()); + std::vector present_key(res, res + PAST_KEY_VAL_SIZE); + + auto past_keyStr = getPastKeyString(i); + model_inputs.past_key_buffers[i]->update(std::move(present_key)); + prog_args.add(past_keyStr.c_str(), migraphx::argument(past_shape, model_inputs.past_key_buffers[i]->data())); + + res = reinterpret_cast(outputs[2*i+2].data()); + std::vector present_value(res, res + PAST_KEY_VAL_SIZE); + + auto past_valueStr = getPastValueStr(i); + model_inputs.past_value_buffers[i]->update(std::move(present_value)); + prog_args.add(past_valueStr.c_str(), migraphx::argument(past_shape, model_inputs.past_value_buffers[i]->data())); + } + if (new_token == EOS) { break; } - model_inputs.input_ids_buffer->update_data(new_token, i +1, stream); - model_inputs.attention_mask_buffer->update_data(1, i +1, stream); + + model_inputs.attention_mask_buffer->update_data(1, i + 1, stream); + + if (firstIter) + { + prog = progSimpleInput; + output_size = VOCAB_SIZE; + + auto param_shapes = prog.get_parameter_shapes(); + auto inputShape = param_shapes[model_inputs.INPUTS_ID_STR]; + std::vector input_ids = {new_token}; + model_inputs.input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + prog_args.add(model_inputs.INPUTS_ID_STR, migraphx::argument(inputShape, model_inputs.input_ids_buffer->data())); + } + else + { + model_inputs.input_ids_buffer->update_data(new_token, 0, stream); + } } #ifdef TRACE From 9220cf13c24be103610ee18c860b1c684010c5e8 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 13 Nov 2024 09:41:56 -0600 Subject: [PATCH 23/55] Support llama 7b quantized model 
without offload copy --- examples/transformers/mgx_llama2/mgxllama2.cc | 74 +++++++++++++------ 1 file changed, 51 insertions(+), 23 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 57dd5a12391..d3eaa320465 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -457,7 +457,7 @@ void writeResults(const std::vector>& results) } int main() { - bool offload_copy = true; + bool offload_copy = false; std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; migraphx::program prog = loadProgram(settings); @@ -488,26 +488,32 @@ int main() { // prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); size_t output_size = SEQ_SIZE * VOCAB_SIZE; - auto output_name = "logits"; + auto output_name = "main:#output_0"; auto output_buffer = LLama2PastKeyValueBuffer(std::vector(output_size), offload_copy); - migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + auto output_buffer_oneDim = LLama2PastKeyValueBuffer(std::vector(VOCAB_SIZE), offload_copy); + migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - // std::vector> present_key_buffers; - // std::vector> present_value_buffers; - // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - // { - // migraphx::shape present_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; - // auto present_keyStr = getPresentKeyString(i); - // auto present_keyString = present_keyStr.c_str(); - // present_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - // prog_args.add(present_keyString, migraphx::argument(present_shape, present_key_buffers[i]->data())); - - // auto present_valueStr = getPresentValueStr(i); - // auto present_valueString = present_valueStr.c_str(); - // present_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - // prog_args.add(present_valueString, migraphx::argument(present_shape, present_value_buffers[i]->data())); - // } + std::vector> present_key_buffers; + std::vector> present_value_buffers; + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + migraphx::shape present_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; + auto present_keyStr = getPresentKeyString(i); + auto present_keyString = present_keyStr.c_str(); + present_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + prog_args.add(present_keyString, migraphx::argument(present_shape, present_key_buffers[i]->data())); + + auto present_valueStr = getPresentValueStr(i); + auto present_valueString = present_valueStr.c_str(); + present_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + prog_args.add(present_valueString, migraphx::argument(present_shape, present_value_buffers[i]->data())); + if (offload_copy) + { + model_inputs.past_key_buffers[i]->upload_to_device(stream); + model_inputs.past_value_buffers[i]->upload_to_device(stream); + } + } std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; @@ -521,17 +527,30 @@ int main() { auto lastInputIdx = model_inputs.getLastInputIndex(); 
for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) { + bool firstIter = (i == lastInputIdx); auto outputs = prog.run_async(prog_args, stream); if (not offload_copy) { - output_buffer.download_from_device(stream, i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE); + if (firstIter) + { + output_buffer.download_from_device(stream, i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE); + } + else + { + output_buffer_oneDim.download_from_device(stream); + } + + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + present_key_buffers[i]->download_from_device(stream); + present_value_buffers[i]->download_from_device(stream); + } } check_hip_status(hipStreamSynchronize(stream)); - half* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast(output_buffer.hbuff.data()); + half* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? output_buffer.hbuff.data() : output_buffer_oneDim.hbuff.data()); std::vector logits(results, results + output_size); - bool firstIter = (i == lastInputIdx); auto logits_begin = firstIter ? std::begin(logits) + (i * VOCAB_SIZE) : std::begin(logits); auto logits_end = firstIter ? std::begin(logits) + ((i + 1) * VOCAB_SIZE) : std::end(logits); std::vector::iterator max = std::max_element(logits_begin, logits_end); @@ -546,19 +565,25 @@ int main() { for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { migraphx::shape past_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; - half* res = reinterpret_cast(outputs[2*i+1].data()); + half* res = offload_copy ? reinterpret_cast(outputs[2*i+1].data()) : reinterpret_cast(present_key_buffers[i]->hbuff.data()); std::vector present_key(res, res + PAST_KEY_VAL_SIZE); auto past_keyStr = getPastKeyString(i); model_inputs.past_key_buffers[i]->update(std::move(present_key)); prog_args.add(past_keyStr.c_str(), migraphx::argument(past_shape, model_inputs.past_key_buffers[i]->data())); - res = reinterpret_cast(outputs[2*i+2].data()); + res = offload_copy ? 
reinterpret_cast(outputs[2*i+2].data()) : reinterpret_cast(present_value_buffers[i]->hbuff.data()); std::vector present_value(res, res + PAST_KEY_VAL_SIZE); auto past_valueStr = getPastValueStr(i); model_inputs.past_value_buffers[i]->update(std::move(present_value)); prog_args.add(past_valueStr.c_str(), migraphx::argument(past_shape, model_inputs.past_value_buffers[i]->data())); + + if (offload_copy) + { + model_inputs.past_key_buffers[i]->upload_to_device(stream); + model_inputs.past_value_buffers[i]->upload_to_device(stream); + } } if (new_token == EOS) @@ -572,12 +597,15 @@ int main() { { prog = progSimpleInput; output_size = VOCAB_SIZE; + migraphx::shape out_shape{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; + prog_args.add(output_name, migraphx::argument(out_shape, output_buffer_oneDim.data())); auto param_shapes = prog.get_parameter_shapes(); auto inputShape = param_shapes[model_inputs.INPUTS_ID_STR]; std::vector input_ids = {new_token}; model_inputs.input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); prog_args.add(model_inputs.INPUTS_ID_STR, migraphx::argument(inputShape, model_inputs.input_ids_buffer->data())); + model_inputs.input_ids_buffer->upload_to_device(stream); } else { From 3842b54bc80c31afefd89b4bdc40237597c3f17c Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 14 Nov 2024 05:41:49 -0600 Subject: [PATCH 24/55] Connect dataset to llama 7b quantized model --- examples/transformers/mgx_llama2/mgxllama2.cc | 86 ++++++++++--------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index d3eaa320465..799eff678a7 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -17,7 +17,7 @@ const std::string ONNX_FILE = "model.onnx"; const std::string DATASET_FOLDER = "/dataset/"; const size_t DATASET_SIZE = 10; // sequence length from model config -const size_t SEQ_SIZE = 4096; +const size_t SEQ_SIZE = 1024; // vocab size from model config const size_t VOCAB_SIZE = 32000; // EOS token from model config @@ -354,12 +354,6 @@ struct LLama2Inputs attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); prog_args.add(ATTENTION_MASK_STR, migraphx::argument(attShape, attention_mask_buffer->data())); - //auto positionShape = param_shapes[POSITION_IDS_STR]; - auto positionShape = inputShape; - auto position_ids = data.getPositionIds(); - position_ids_buffer = std::make_unique(std::move(position_ids), offload_copy); - prog_args.add(POSITION_IDS_STR, migraphx::argument(inputShape, position_ids_buffer->data())); - // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} // past_key_values.0.value = @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) @@ -383,7 +377,7 @@ struct LLama2Inputs assert(not offload_copy); input_ids_buffer->upload_to_device(stream); attention_mask_buffer->upload_to_device(stream); - position_ids_buffer->upload_to_device(stream); + //position_ids_buffer->upload_to_device(stream); } bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) @@ -393,12 +387,9 @@ struct LLama2Inputs { auto param_shapes = prog.get_parameter_shapes(); - auto input_ids = data.getInputIds(); - input_ids_buffer->update(std::move(input_ids)); - if (offload_copy) - { - prog_args.add(INPUTS_ID_STR, 
migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); - } + std::vector input_ids = data.getInputIds(); + input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); auto attention_mask = data.getAttentionMask(); attention_mask_buffer->update(std::move(attention_mask)); @@ -407,17 +398,31 @@ struct LLama2Inputs prog_args.add(ATTENTION_MASK_STR, migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); } - auto position_ids = data.getPositionIds(); - position_ids_buffer->update(std::move(position_ids)); - if (offload_copy) - { - prog_args.add(POSITION_IDS_STR, migraphx::argument(param_shapes[POSITION_IDS_STR], position_ids_buffer->data())); - } + // auto position_ids = data.getPositionIds(); + // position_ids_buffer->update(std::move(position_ids)); + // if (offload_copy) + // { + // prog_args.add(POSITION_IDS_STR, migraphx::argument(param_shapes[POSITION_IDS_STR], position_ids_buffer->data())); + // } return true; } return false; } + void resetPastKeyValueBuffers(hipStream_t stream) + { + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + past_value_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + if (not offload_copy) + { + past_key_buffers[i]->upload_to_device(stream); + past_value_buffers[i]->upload_to_device(stream); + } + } + } + size_t getLastInputIndex() const { return data.getLastIdx(); } size_t dataSize() const { return data.size(); } @@ -427,7 +432,7 @@ struct LLama2Inputs std::unique_ptr input_ids_buffer; std::unique_ptr attention_mask_buffer; - std::unique_ptr position_ids_buffer; + //std::unique_ptr position_ids_buffer; std::vector> past_key_buffers; std::vector> past_value_buffers; Dataset data; @@ -460,7 +465,7 @@ int main() { bool offload_copy = false; std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; - migraphx::program prog = loadProgram(settings); + migraphx::program progMultipleInputDim = loadProgram(settings); std::cout << "Model loaded" << std::endl; // Load {1,1} input_ids model @@ -468,13 +473,15 @@ int main() { migraphx::program progSimpleInput = loadProgram(settings); std::cout << "Model 1 dim input loaded" << std::endl; + migraphx::program *prog = &progMultipleInputDim; + // Setup model inputs std::vector> results; std::vector output_tokens; migraphx::program_parameters prog_args; hipStream_t stream; check_hip_status(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - auto model_inputs = LLama2Inputs(prog, prog_args, offload_copy); + auto model_inputs = LLama2Inputs(*prog, prog_args, offload_copy); if (not offload_copy) { model_inputs.upload_to_device(stream); @@ -508,17 +515,12 @@ int main() { auto present_valueString = present_valueStr.c_str(); present_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); prog_args.add(present_valueString, migraphx::argument(present_shape, present_value_buffers[i]->data())); - if (offload_copy) - { - model_inputs.past_key_buffers[i]->upload_to_device(stream); - model_inputs.past_value_buffers[i]->upload_to_device(stream); - } } + std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; std::cout << "Starting evaluation" << std::endl; size_t 
token_count = 0; auto start = std::chrono::steady_clock::now(); - std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; for (size_t i = 0; i < model_inputs.dataSize(); ++i) { #ifdef TRACE @@ -528,7 +530,7 @@ int main() { for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) { bool firstIter = (i == lastInputIdx); - auto outputs = prog.run_async(prog_args, stream); + auto outputs = prog->run_async(prog_args, stream); if (not offload_copy) { if (firstIter) @@ -578,12 +580,6 @@ int main() { auto past_valueStr = getPastValueStr(i); model_inputs.past_value_buffers[i]->update(std::move(present_value)); prog_args.add(past_valueStr.c_str(), migraphx::argument(past_shape, model_inputs.past_value_buffers[i]->data())); - - if (offload_copy) - { - model_inputs.past_key_buffers[i]->upload_to_device(stream); - model_inputs.past_value_buffers[i]->upload_to_device(stream); - } } if (new_token == EOS) @@ -595,17 +591,20 @@ int main() { if (firstIter) { - prog = progSimpleInput; + prog = &progSimpleInput; output_size = VOCAB_SIZE; migraphx::shape out_shape{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer_oneDim.data())); - auto param_shapes = prog.get_parameter_shapes(); + auto param_shapes = prog->get_parameter_shapes(); auto inputShape = param_shapes[model_inputs.INPUTS_ID_STR]; std::vector input_ids = {new_token}; model_inputs.input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); prog_args.add(model_inputs.INPUTS_ID_STR, migraphx::argument(inputShape, model_inputs.input_ids_buffer->data())); - model_inputs.input_ids_buffer->upload_to_device(stream); + if (not offload_copy) + { + model_inputs.input_ids_buffer->upload_to_device(stream); + } } else { @@ -621,7 +620,14 @@ int main() { } std::cout << std::endl; #endif - auto updated = model_inputs.updateData(prog, prog_args); + prog = &progMultipleInputDim; + output_size = SEQ_SIZE * VOCAB_SIZE; + migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + prog_args.add(output_name, migraphx::argument(out_shape, output_buffer_oneDim.data())); + + model_inputs.resetPastKeyValueBuffers(stream); + + auto updated = model_inputs.updateData(*prog, prog_args); if (updated && not offload_copy) { From d8a85cdd0789c92aff9b4c60b044eb3063c8ec1b Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 14 Nov 2024 07:03:14 -0600 Subject: [PATCH 25/55] Fix output buffer usage for new sample --- examples/transformers/mgx_llama2/mgxllama2.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 799eff678a7..bb8cec03285 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -623,7 +623,7 @@ int main() { prog = &progMultipleInputDim; output_size = SEQ_SIZE * VOCAB_SIZE; migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; - prog_args.add(output_name, migraphx::argument(out_shape, output_buffer_oneDim.data())); + prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); model_inputs.resetPastKeyValueBuffers(stream); From 37ed15b5ddb2dd8a59ce12a1499f0f57d189c327 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 14 Nov 2024 10:57:58 -0600 Subject: [PATCH 26/55] Comment out past/present_key_value binding --- examples/transformers/mgx_llama2/mgxllama2.cc | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git 
a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index bb8cec03285..779b4d45c95 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -501,21 +501,21 @@ int main() { migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - std::vector> present_key_buffers; - std::vector> present_value_buffers; - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - migraphx::shape present_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; - auto present_keyStr = getPresentKeyString(i); - auto present_keyString = present_keyStr.c_str(); - present_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - prog_args.add(present_keyString, migraphx::argument(present_shape, present_key_buffers[i]->data())); - - auto present_valueStr = getPresentValueStr(i); - auto present_valueString = present_valueStr.c_str(); - present_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - prog_args.add(present_valueString, migraphx::argument(present_shape, present_value_buffers[i]->data())); - } + // std::vector> present_key_buffers; + // std::vector> present_value_buffers; + // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + // { + // migraphx::shape present_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; + // auto present_keyStr = getPresentKeyString(i); + // auto present_keyString = present_keyStr.c_str(); + // present_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + // prog_args.add(present_keyString, migraphx::argument(present_shape, present_key_buffers[i]->data())); + + // auto present_valueStr = getPresentValueStr(i); + // auto present_valueString = present_valueStr.c_str(); + // present_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + // prog_args.add(present_valueString, migraphx::argument(present_shape, present_value_buffers[i]->data())); + // } std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; std::cout << "Starting evaluation" << std::endl; @@ -542,11 +542,11 @@ int main() { output_buffer_oneDim.download_from_device(stream); } - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - present_key_buffers[i]->download_from_device(stream); - present_value_buffers[i]->download_from_device(stream); - } + // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + // { + // present_key_buffers[i]->download_from_device(stream); + // present_value_buffers[i]->download_from_device(stream); + // } } check_hip_status(hipStreamSynchronize(stream)); @@ -564,23 +564,23 @@ int main() { #endif output_tokens.push_back(new_token); - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - migraphx::shape past_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; - half* res = offload_copy ? reinterpret_cast(outputs[2*i+1].data()) : reinterpret_cast(present_key_buffers[i]->hbuff.data()); - std::vector present_key(res, res + PAST_KEY_VAL_SIZE); + // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + // { + // migraphx::shape past_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; + // half* res = offload_copy ? 
reinterpret_cast(outputs[2*i+1].data()) : reinterpret_cast(present_key_buffers[i]->hbuff.data()); + // std::vector present_key(res, res + PAST_KEY_VAL_SIZE); - auto past_keyStr = getPastKeyString(i); - model_inputs.past_key_buffers[i]->update(std::move(present_key)); - prog_args.add(past_keyStr.c_str(), migraphx::argument(past_shape, model_inputs.past_key_buffers[i]->data())); + // auto past_keyStr = getPastKeyString(i); + // model_inputs.past_key_buffers[i]->update(std::move(present_key)); + // prog_args.add(past_keyStr.c_str(), migraphx::argument(past_shape, present_key_buffers[i]->data())); - res = offload_copy ? reinterpret_cast(outputs[2*i+2].data()) : reinterpret_cast(present_value_buffers[i]->hbuff.data()); - std::vector present_value(res, res + PAST_KEY_VAL_SIZE); + // res = offload_copy ? reinterpret_cast(outputs[2*i+2].data()) : reinterpret_cast(present_value_buffers[i]->hbuff.data()); + // std::vector present_value(res, res + PAST_KEY_VAL_SIZE); - auto past_valueStr = getPastValueStr(i); - model_inputs.past_value_buffers[i]->update(std::move(present_value)); - prog_args.add(past_valueStr.c_str(), migraphx::argument(past_shape, model_inputs.past_value_buffers[i]->data())); - } + // auto past_valueStr = getPastValueStr(i); + // model_inputs.past_value_buffers[i]->update(std::move(present_value)); + // prog_args.add(past_valueStr.c_str(), migraphx::argument(past_shape, present_value_buffers[i]->data())); + // } if (new_token == EOS) { @@ -625,7 +625,7 @@ int main() { migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - model_inputs.resetPastKeyValueBuffers(stream); + //model_inputs.resetPastKeyValueBuffers(stream); auto updated = model_inputs.updateData(*prog, prog_args); From d970b304d7642758ceb10d4c2bef729a1c6e717e Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 15 Nov 2024 01:00:00 -0600 Subject: [PATCH 27/55] Add new migraphx public API functions: replace_return and get_last_instruction --- src/api/api.cpp | 29 +++++++++++++++++++++++++++ src/api/include/migraphx/migraphx.h | 7 +++++++ src/api/include/migraphx/migraphx.hpp | 17 ++++++++++++++++ src/include/migraphx/module.hpp | 2 ++ src/module.cpp | 10 +++++++++ 5 files changed, 65 insertions(+) diff --git a/src/api/api.cpp b/src/api/api.cpp index 4ecd0763225..811fafd3f5a 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -1522,6 +1522,21 @@ extern "C" migraphx_status migraphx_module_add_instruction(migraphx_instruction_ return api_error_result; } +extern "C" migraphx_status migraphx_module_get_last_instruction(migraphx_instruction_t* out, + migraphx_module_t module) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + { + std::cout << "# migraphx_module_get_last_instruction nullptr" << std::endl; + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + } + *out = allocate( + (module->object).get_last_instruction()); + }); + return api_error_result; +} + extern "C" migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, migraphx_module_t module, @@ -1590,6 +1605,20 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou return api_error_result; } +extern "C" migraphx_status migraphx_module_replace_return(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_instructions_t args) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + 
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + if(args == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer"); + *out = allocate((module->object).replace_return((args->object))); + }); + return api_error_result; +} + extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out, migraphx_module_t module, const_migraphx_shape_t s) diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index 90ba7c3e017..ea829424aff 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -401,6 +401,9 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction(migraphx_instr migraphx_operation_t op, migraphx_instructions_t args); +MIGRAPHX_C_EXPORT migraphx_status migraphx_module_get_last_instruction(migraphx_instruction_t* out, + migraphx_module_t module); + MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, migraphx_module_t module, @@ -422,6 +425,10 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_return(migraphx_instructio migraphx_module_t module, migraphx_instructions_t args); +MIGRAPHX_C_EXPORT migraphx_status migraphx_module_replace_return(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_instructions_t args); + MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out, migraphx_module_t module, const_migraphx_shape_t s); diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index fa6339b4389..1870992e999 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1049,6 +1049,16 @@ struct module return instruction(op_ins, own{}); } + instruction get_last_instruction() + { + std::cout << "# get_last_instruction called" << std::endl; + migraphx_instruction_t op_ins; + call(&migraphx_module_get_last_instruction, + &op_ins, + mm.get()); + return instruction(op_ins, own{}); + } + instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args, const migraphx::modules& module_args) @@ -1087,6 +1097,13 @@ struct module return instruction(ret_ins, own{}); } + instruction replace_return(const migraphx::instructions& args) + { + migraphx_instruction_t ret_ins; + call(&migraphx_module_replace_return, &ret_ins, mm.get(), args.get_handle_ptr()); + return instruction(ret_ins, own{}); + } + instruction add_allocation(const migraphx::shape& s) { migraphx_instruction_t ret_ins; diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 68734c82baa..e4a0529fcd8 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -86,6 +86,8 @@ struct MIGRAPHX_EXPORT module return add_instruction(op, {args...}); } + instruction_ref get_last_instruction(); + instruction_ref add_instruction(const operation& op, std::vector args); instruction_ref add_instruction(const operation& op, diff --git a/src/module.cpp b/src/module.cpp index 7e02478b385..2f1739a8578 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -288,6 +288,16 @@ instruction_ref module::add_instruction(const operation& op, std::vectorinstructions.end(), op, std::move(args)); } + +instruction_ref module::get_last_instruction() +{ + auto last_instr = std::prev(this->end()); + if (last_instr->name() == "@return") + last_instr = std::prev(last_instr); + return last_instr; +} + + instruction_ref 
module::insert_instruction(instruction_ref ins, const operation& op, std::vector args) From a7801b045d99e67c52c7433739fd5d649aad641c Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 15 Nov 2024 01:05:06 -0600 Subject: [PATCH 28/55] Add romxProfileData to the example container --- examples/transformers/mgx_llama2/Dockerfile | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index ed252385bdf..db6bb1ac46c 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -8,12 +8,18 @@ RUN apt-get update && apt-get install -y --allow-unauthenticated \ apt-utils \ cmake \ half \ + sqlite3 \ + libsqlite3-dev \ + libfmt-dev \ git && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -RUN mkdir /migraphx && cd /migraphx && git clone --branch develop https://github.com/ROCm/AMDMIGraphX.git +RUN mkdir /app && cd /app && git clone https://github.com/ROCm/rocmProfileData +WORKDIR /app/rocmProfileData +RUN make; make install +RUN mkdir /migraphx && cd /migraphx && git clone --branch develop https://github.com/ROCm/AMDMIGraphX.git WORKDIR /migraphx/AMDMIGraphX RUN ./tools/install_prereqs.sh From 9152cd15b0f33bd62145935a5a983bce2beb4321 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 15 Nov 2024 01:06:29 -0600 Subject: [PATCH 29/55] Update MGX branch and en variables in the example docker --- examples/transformers/mgx_llama2/Dockerfile | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index db6bb1ac46c..d6e5bc53aba 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -19,12 +19,13 @@ RUN mkdir /app && cd /app && git clone https://github.com/ROCm/rocmProfileData WORKDIR /app/rocmProfileData RUN make; make install -RUN mkdir /migraphx && cd /migraphx && git clone --branch develop https://github.com/ROCm/AMDMIGraphX.git +RUN mkdir /migraphx && cd /migraphx && git clone --branch htec/mgx-llama2-7b-example https://github.com/ROCm/AMDMIGraphX.git WORKDIR /migraphx/AMDMIGraphX RUN ./tools/install_prereqs.sh - -RUN export test=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') +RUN export MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1 +RUN export MIGRAPHX_USE_HIPBLASLT=1 +RUN export MIGRAPHX_USE_MIOPEN=1 #TODO: use $(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') for GPU_TARGETS RUN mkdir build && cd build && \ From 14df5db561a9cdf6c5916f06371f6ed46ac43c37 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 15 Nov 2024 01:07:36 -0600 Subject: [PATCH 30/55] Fix example readme --- examples/transformers/mgx_llama2/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/transformers/mgx_llama2/README.md b/examples/transformers/mgx_llama2/README.md index de3bcd7d50b..a832a236a03 100644 --- a/examples/transformers/mgx_llama2/README.md +++ b/examples/transformers/mgx_llama2/README.md @@ -12,7 +12,7 @@ Alternatively you can quantize the model yourself. **If you are using the pre-quantized model you can skip this section.** -Get the latest quark quantizer version from https://xcoartifactory/ui/native/uai-pip-local/com/amd/quark/main/nightly/ . Donwloading the zip is recommended because it contains the required scripts. The quark version used when this was created: quark-1.0.0.dev20241028+eb46b7438 (28-10-24). 
+Get the latest quark quantizer version from https://xcoartifactory/ui/native/uai-pip-local/com/amd/quark/main/nightly/ . Downloading the zip is recommended because it contains the required scripts. The quark version used when this was created: quark-1.0.0.dev20241028+eb46b7438 (28-10-24). Also we will need to install the onnxruntime-genai (OGA) tool to convert the quark_safetensors format to onnx format properly. @@ -32,7 +32,7 @@ RUN pip install quark-1.0.0.dev20241028+eb46b7438-py3-none-any.whl #### Quantizing the model and converting to ONNX ```bash -cd examples/torch/language_modeling/llm_ptq +cd quark/examples/torch/language_modeling/llm_ptq export MODEL_DIR = [local model checkpoint folder] or meta-llama/Llama-2-7b-chat-hf or meta-llama/Llama-2-70b-chat-hf export QUANTIZED_MODEL_DIR = [output model checkpoint folder] From f7d3fbf7816fc31915f6912a11363c4f545406f7 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 15 Nov 2024 01:09:48 -0600 Subject: [PATCH 31/55] Fix typo in example preproc script --- examples/transformers/mgx_llama2/preprocess_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/preprocess_dataset.py b/examples/transformers/mgx_llama2/preprocess_dataset.py index 0d43f890102..f3d096f0592 100644 --- a/examples/transformers/mgx_llama2/preprocess_dataset.py +++ b/examples/transformers/mgx_llama2/preprocess_dataset.py @@ -35,4 +35,4 @@ np.save(f"{OUTPUT_PATH}attention_mask_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", mask_np) np.save(f"{OUTPUT_PATH}position_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", position_nps) -print("Npy filed are created") +print("Npy files are created") From 58bc2c2bc7816f693ce7df670509b64e652db382 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 15 Nov 2024 01:11:16 -0600 Subject: [PATCH 32/55] Update LLama2 example to target specific device and enable fast_math by default --- examples/transformers/mgx_llama2/mgxllama2.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 779b4d45c95..121b2ab4b54 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -29,6 +29,8 @@ const size_t HIDDEN_LAYERS_NUM = 32; const size_t HEAD_SIZE = 128; const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; +const int DEVICE_ID = 4; + struct ModelLoadSettings { size_t sequnce_length; @@ -298,6 +300,10 @@ struct Dataset position_ids.clear(); } } + else + { + std::cout << "Unable to open numpy files\n"; + } } void prepareSampleDataset() @@ -463,8 +469,9 @@ void writeResults(const std::vector>& results) int main() { bool offload_copy = false; + check_hip_status(hipSetDevice(DEVICE_ID)); std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; + ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, true /*fast_math*/, false /*input_one_dim*/}; migraphx::program progMultipleInputDim = loadProgram(settings); std::cout << "Model loaded" << std::endl; From e2856a7bfa37b288f56b722043d851ed11958d6e Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Fri, 15 Nov 2024 03:28:57 -0600 Subject: [PATCH 33/55] Add eval_accuracy script dependencies to Dockerfile --- examples/transformers/mgx_llama2/Dockerfile | 4 +++- 
examples/transformers/mgx_llama2/eval_accuracy.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index d6e5bc53aba..f5feb7f7569 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -42,4 +42,6 @@ RUN rm -rf /mgx_llama2/build && mkdir /mgx_llama2/build WORKDIR /mgx_llama2/build RUN CXX=/opt/rocm/llvm/bin/clang++ cmake .. && make -RUN pip install pandas + +RUN pip install pandas evaluate nltk transformers sentencepiece rouge_score + diff --git a/examples/transformers/mgx_llama2/eval_accuracy.py b/examples/transformers/mgx_llama2/eval_accuracy.py index 041ce0ea500..803c55dd2f0 100644 --- a/examples/transformers/mgx_llama2/eval_accuracy.py +++ b/examples/transformers/mgx_llama2/eval_accuracy.py @@ -15,7 +15,7 @@ SAMPLE_SIZE = 10 DATASET_PATH = "/dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl" -RESULT_PATH = "build/results.txt" +RESULT_PATH = "build/result.txt" def main(dataset_path, result_path, sample_size, sequence_size): tokenizer = AutoTokenizer.from_pretrained( @@ -25,7 +25,7 @@ def main(dataset_path, result_path, sample_size, sequence_size): use_fast=False,) metric = evaluate.load("rouge") - nltk.download("punkt") + nltk.download("punkt_tab") _p = Path(DATASET_PATH) if _p.exists(): From 632c02699cc73ebf25438e8d57acecc4d4c9503c Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 15 Nov 2024 05:49:46 -0600 Subject: [PATCH 34/55] Add argmax program to lessen DToH copy + CPU computation overhead --- .../mgx_llama2/harness/buffer.hpp | 1 + examples/transformers/mgx_llama2/mgxllama2.cc | 135 +++++++++--------- 2 files changed, 70 insertions(+), 66 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index fc1f91daa59..fc13d73aa04 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -190,4 +190,5 @@ namespace mlinfer using LLama2InputBuffer = ManagedBuffer_v2; using LLama2OutputBuffer = ManagedBuffer_v2; using LLama2PastKeyValueBuffer = ManagedBuffer_v2; + using ArgMaxOutputBuffer = ManagedBuffer_v2; } diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 121b2ab4b54..ee97f1c72be 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -183,6 +183,47 @@ static migraphx::program loadProgram(ModelLoadSettings& settings) return prog; }; +static migraphx::program create_argmax_program(ModelLoadSettings& settings) +{ + migraphx::program prog; + std::vector dims {1, SEQ_SIZE, VOCAB_SIZE}; + if (settings.input_one_dim) + { + dims[1] = 1; + } + migraphx::shape s{migraphx_shape_half_type, dims}; + migraphx::module m = prog.get_main_module(); + auto x = m.add_parameter("x", s); + auto argmax_ins = m.add_instruction(migraphx::operation("argmax", "{axis: 2}"), {x}); + m.add_return({argmax_ins}); + + std::cout << "Creating ArgMax program ..." << std::endl; + + std::string target_str = "gpu"; + migraphx::target targ = migraphx::target(target_str.c_str()); + + if (settings.quantize_fp16) + { + std::cout << "Quantize FP16 ..." 
<< std::endl; + migraphx::quantize_fp16(prog); + } + + migraphx::compile_options comp_opts; + + if (settings.offload_copy) + comp_opts.set_offload_copy(); + + if (settings.fast_math) + comp_opts.set_fast_math(); + + comp_opts.set_exhaustive_tune_flag(); + + std::cout << "Compile to target ..." << std::endl; + prog.compile(targ, comp_opts); + + return prog; +} + using NumpyVector = std::vector>; struct Dataset @@ -474,13 +515,18 @@ int main() { ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, true /*fast_math*/, false /*input_one_dim*/}; migraphx::program progMultipleInputDim = loadProgram(settings); std::cout << "Model loaded" << std::endl; + migraphx::program progArgMaxMultipleInputDim = create_argmax_program(settings); + std::cout << "ArgMax model created" << std::endl; // Load {1,1} input_ids model settings.input_one_dim = true; migraphx::program progSimpleInput = loadProgram(settings); std::cout << "Model 1 dim input loaded" << std::endl; + migraphx::program progArgMaxSimpleInput = create_argmax_program(settings); + std::cout << "ArgMax model for 1 dim model created" << std::endl; migraphx::program *prog = &progMultipleInputDim; + migraphx::program *progArgMax = &progArgMaxMultipleInputDim; // Setup model inputs std::vector> results; @@ -494,35 +540,28 @@ int main() { model_inputs.upload_to_device(stream); } - // Setup model output for non-offload copy - // const size_t output_size = SEQ_SIZE * VOCAB_SIZE; - // auto output_name = "main:#output_0"; - // auto output_buffer = LLama2OutputBuffer(std::vector(output_size), offload_copy); - // migraphx::shape out_shape{migraphx_shape_float_type, {1, SEQ_SIZE, VOCAB_SIZE}}; - // prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); + auto output_name = "main:#output_0"; size_t output_size = SEQ_SIZE * VOCAB_SIZE; - auto output_name = "main:#output_0"; auto output_buffer = LLama2PastKeyValueBuffer(std::vector(output_size), offload_copy); auto output_buffer_oneDim = LLama2PastKeyValueBuffer(std::vector(VOCAB_SIZE), offload_copy); migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - // std::vector> present_key_buffers; - // std::vector> present_value_buffers; - // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - // { - // migraphx::shape present_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; - // auto present_keyStr = getPresentKeyString(i); - // auto present_keyString = present_keyStr.c_str(); - // present_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - // prog_args.add(present_keyString, migraphx::argument(present_shape, present_key_buffers[i]->data())); - - // auto present_valueStr = getPresentValueStr(i); - // auto present_valueString = present_valueStr.c_str(); - // present_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - // prog_args.add(present_valueString, migraphx::argument(present_shape, present_value_buffers[i]->data())); - // } + // setting up argmax arguments + migraphx::program_parameters prog_args_argmax; + migraphx::shape x_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer.data())); + auto argm_output_buffer = ArgMaxOutputBuffer(std::vector(VOCAB_SIZE), offload_copy); + migraphx::shape argm_out_shape{migraphx_shape_int64_type, {1, SEQ_SIZE, 
1}}; + prog_args_argmax.add(output_name, migraphx::argument(argm_out_shape, argm_output_buffer.data())); + + migraphx::program_parameters prog_args_argmax_one_dim; + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; + prog_args_argmax_one_dim.add("x", migraphx::argument(x_shape_one_dim, output_buffer_oneDim.data())); + auto argm_output_buffer_one_dim = ArgMaxOutputBuffer(std::vector(1), offload_copy); + migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {1, 1, 1}}; + prog_args_argmax_one_dim.add(output_name, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim.data())); std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; std::cout << "Starting evaluation" << std::endl; @@ -537,33 +576,17 @@ int main() { for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) { bool firstIter = (i == lastInputIdx); - auto outputs = prog->run_async(prog_args, stream); + prog->run_async(prog_args, stream); + auto outputs = progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); if (not offload_copy) { - if (firstIter) - { - output_buffer.download_from_device(stream, i * VOCAB_SIZE, (i + 1) * VOCAB_SIZE); - } - else - { - output_buffer_oneDim.download_from_device(stream); - } - - // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - // { - // present_key_buffers[i]->download_from_device(stream); - // present_value_buffers[i]->download_from_device(stream); - // } + firstIter ? argm_output_buffer.download_from_device(stream, i, i + 1) : argm_output_buffer_one_dim.download_from_device(stream); } check_hip_status(hipStreamSynchronize(stream)); - half* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? output_buffer.hbuff.data() : output_buffer_oneDim.hbuff.data()); - std::vector logits(results, results + output_size); - - auto logits_begin = firstIter ? std::begin(logits) + (i * VOCAB_SIZE) : std::begin(logits); - auto logits_end = firstIter ? std::begin(logits) + ((i + 1) * VOCAB_SIZE) : std::end(logits); - std::vector::iterator max = std::max_element(logits_begin, logits_end); - int64_t new_token = std::distance(logits_begin, max); + int64_t* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? argm_output_buffer.hbuff.data() : argm_output_buffer_one_dim.hbuff.data()); + auto new_token_idx = firstIter ? i : 0; + int64_t new_token = results[new_token_idx]; token_count++; #ifdef TRACE @@ -571,24 +594,6 @@ int main() { #endif output_tokens.push_back(new_token); - // for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - // { - // migraphx::shape past_shape{migraphx_shape_half_type, {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}}; - // half* res = offload_copy ? reinterpret_cast(outputs[2*i+1].data()) : reinterpret_cast(present_key_buffers[i]->hbuff.data()); - // std::vector present_key(res, res + PAST_KEY_VAL_SIZE); - - // auto past_keyStr = getPastKeyString(i); - // model_inputs.past_key_buffers[i]->update(std::move(present_key)); - // prog_args.add(past_keyStr.c_str(), migraphx::argument(past_shape, present_key_buffers[i]->data())); - - // res = offload_copy ? 
reinterpret_cast(outputs[2*i+2].data()) : reinterpret_cast(present_value_buffers[i]->hbuff.data()); - // std::vector present_value(res, res + PAST_KEY_VAL_SIZE); - - // auto past_valueStr = getPastValueStr(i); - // model_inputs.past_value_buffers[i]->update(std::move(present_value)); - // prog_args.add(past_valueStr.c_str(), migraphx::argument(past_shape, present_value_buffers[i]->data())); - // } - if (new_token == EOS) { break; @@ -599,7 +604,7 @@ int main() { if (firstIter) { prog = &progSimpleInput; - output_size = VOCAB_SIZE; + progArgMax = &progArgMaxSimpleInput; migraphx::shape out_shape{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer_oneDim.data())); @@ -628,12 +633,10 @@ int main() { std::cout << std::endl; #endif prog = &progMultipleInputDim; - output_size = SEQ_SIZE * VOCAB_SIZE; + progArgMax = &progArgMaxMultipleInputDim; migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - //model_inputs.resetPastKeyValueBuffers(stream); - auto updated = model_inputs.updateData(*prog, prog_args); if (updated && not offload_copy) From 1929b44bb365217cc2da4ee92d1724de1ebcd617 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Sun, 17 Nov 2024 05:24:42 -0600 Subject: [PATCH 35/55] Update LLama2 example docker file ENVs --- examples/transformers/mgx_llama2/Dockerfile | 8 ++++---- examples/transformers/mgx_llama2/README.md | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index f5feb7f7569..6525fb81619 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -19,13 +19,13 @@ RUN mkdir /app && cd /app && git clone https://github.com/ROCm/rocmProfileData WORKDIR /app/rocmProfileData RUN make; make install -RUN mkdir /migraphx && cd /migraphx && git clone --branch htec/mgx-llama2-7b-example https://github.com/ROCm/AMDMIGraphX.git +RUN mkdir /migraphx && cd /migraphx && git clone --branch develop https://github.com/ROCm/AMDMIGraphX.git WORKDIR /migraphx/AMDMIGraphX RUN ./tools/install_prereqs.sh -RUN export MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1 -RUN export MIGRAPHX_USE_HIPBLASLT=1 -RUN export MIGRAPHX_USE_MIOPEN=1 +ENV MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1 +ENV MIGRAPHX_USE_HIPBLASLT=1 +ENV MIGRAPHX_USE_MIOPEN=1 #TODO: use $(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') for GPU_TARGETS RUN mkdir build && cd build && \ diff --git a/examples/transformers/mgx_llama2/README.md b/examples/transformers/mgx_llama2/README.md index a832a236a03..a470919365a 100644 --- a/examples/transformers/mgx_llama2/README.md +++ b/examples/transformers/mgx_llama2/README.md @@ -85,6 +85,7 @@ CXX=/opt/rocm/llvm/bin/clang++ cmake .. 
make -j # Running the example +export MIOPEN_FIND_ENFORCE=3 ./mgxllama2 # Test the accuracy of the output From 19f36b4a1d801fa63cec6293b2af5ed5e9887ead Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Sun, 17 Nov 2024 05:25:15 -0600 Subject: [PATCH 36/55] Disable fast_math --- examples/transformers/mgx_llama2/mgxllama2.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index ee97f1c72be..d777f727ff2 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -512,7 +512,7 @@ int main() { bool offload_copy = false; check_hip_status(hipSetDevice(DEVICE_ID)); std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, true /*fast_math*/, false /*input_one_dim*/}; + ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; migraphx::program progMultipleInputDim = loadProgram(settings); std::cout << "Model loaded" << std::endl; migraphx::program progArgMaxMultipleInputDim = create_argmax_program(settings); From 74c9f4df8c82013eb29a6bdb565b99c68199419f Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Tue, 19 Nov 2024 04:31:52 -0600 Subject: [PATCH 37/55] Improve input_ids buffer handling, use different program arguments for two models --- examples/transformers/mgx_llama2/mgxllama2.cc | 98 +++++++++---------- 1 file changed, 44 insertions(+), 54 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index d777f727ff2..5768b4c7ca3 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -116,7 +116,6 @@ static migraphx::program loadOnnx(ModelLoadSettings& settings) } onnx_opts.set_input_parameter_shape("input_ids", inputDim); onnx_opts.set_input_parameter_shape("attention_mask", dims); - onnx_opts.set_input_parameter_shape("position_ids", dims); for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); @@ -279,7 +278,6 @@ struct Dataset std::vector getInputIds() { return input_ids[_current_idx]; } std::vector getAttentionMask() { return attention_mask[_current_idx]; } - std::vector getPositionIds() { return position_ids[_current_idx]; } size_t size() const { return _size; } size_t currentIdx() const { return _current_idx; } @@ -311,24 +309,20 @@ struct Dataset { std::string input_file_path = getDatasetPath("input_ids"); std::string attention_mask_file_path = getDatasetPath("attention_mask"); - std::string position_ids_file_path = getDatasetPath("position_ids"); std::cout << "Input ids file: " << input_file_path << std::endl; std::ifstream input_file(input_file_path.c_str()); std::ifstream attention_mask_file(attention_mask_file_path.c_str()); - std::ifstream position_ids_file(position_ids_file_path.c_str()); - if (input_file.good() && attention_mask_file.good() && position_ids_file.good()) + if (input_file.good() && attention_mask_file.good()) { npy::NpyFile input_ids_npy{input_file_path}; npy::NpyFile attention_mask_npy{attention_mask_file_path}; - npy::NpyFile position_ids_npy{position_ids_file_path}; input_ids = loadNumpy(input_ids_npy); attention_mask = loadNumpy(attention_mask_npy); - position_ids = loadNumpy(position_ids_npy); _size = 
input_ids.size(); - if ((input_ids.size() == attention_mask.size()) && (attention_mask.size() == position_ids.size())) + if (input_ids.size() == attention_mask.size()) { std::cout << "Loaded numpy files\n"; _npy_files_loaded = true; @@ -338,7 +332,6 @@ struct Dataset std::cout << "Numpy files do not have the same size\n"; input_ids.clear(); attention_mask.clear(); - position_ids.clear(); } } else @@ -359,19 +352,11 @@ struct Dataset }); attention_mask.emplace_back(std::move(attention_mask_sample)); - std::vector position_ids_sample; - for (int64_t i=0; i < SEQ_SIZE; ++i) - { - position_ids_sample.emplace_back(i); - } - position_ids.emplace_back(std::move(position_ids_sample)); - _size = 1; } NumpyVector input_ids; NumpyVector attention_mask; - NumpyVector position_ids; size_t _size = 0; size_t _current_idx = 0; @@ -387,18 +372,27 @@ struct LLama2Inputs : offload_copy(offload_copy) { data.initialize(); + prepareProgArgs(prog, prog_args); + } + void prepareProgArgs(migraphx::program& prog, migraphx::program_parameters& prog_args, bool simple = false) + { auto param_shapes = prog.get_parameter_shapes(); - - auto inputShape = param_shapes[INPUTS_ID_STR]; - auto input_ids = data.getInputIds(); - input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); - prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); + if (!simple) + { + auto inputShape = param_shapes[INPUTS_ID_STR]; + auto input_ids = data.getInputIds(); + input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); + } auto attShape = param_shapes[ATTENTION_MASK_STR]; auto attention_mask = data.getAttentionMask(); - attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); + if (!simple) + { + attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); + } prog_args.add(ATTENTION_MASK_STR, migraphx::argument(attShape, attention_mask_buffer->data())); // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} @@ -407,24 +401,29 @@ struct LLama2Inputs { auto past_keyStr = getPastKeyString(i); auto past_keyString = past_keyStr.c_str(); - past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + if (!simple) + { + past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + } auto pastKeyShape = param_shapes[past_keyString]; prog_args.add(past_keyString, migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); auto past_valueStr = getPastValueStr(i); auto past_valueString = past_valueStr.c_str(); - past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + if (!simple) + { + past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + } auto pastValueShape = param_shapes[past_valueString]; prog_args.add(past_valueString, migraphx::argument(pastValueShape, past_value_buffers[i]->data())); } - }; + } void upload_to_device(hipStream_t stream) { assert(not offload_copy); input_ids_buffer->upload_to_device(stream); attention_mask_buffer->upload_to_device(stream); - //position_ids_buffer->upload_to_device(stream); } bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) @@ -445,12 +444,6 @@ struct LLama2Inputs prog_args.add(ATTENTION_MASK_STR, 
migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); } - // auto position_ids = data.getPositionIds(); - // position_ids_buffer->update(std::move(position_ids)); - // if (offload_copy) - // { - // prog_args.add(POSITION_IDS_STR, migraphx::argument(param_shapes[POSITION_IDS_STR], position_ids_buffer->data())); - // } return true; } return false; @@ -479,7 +472,6 @@ struct LLama2Inputs std::unique_ptr input_ids_buffer; std::unique_ptr attention_mask_buffer; - //std::unique_ptr position_ids_buffer; std::vector> past_key_buffers; std::vector> past_value_buffers; Dataset data; @@ -487,7 +479,6 @@ struct LLama2Inputs const char* INPUTS_ID_STR = "input_ids"; const char* ATTENTION_MASK_STR = "attention_mask"; - const char* POSITION_IDS_STR = "position_ids"; }; void writeResults(const std::vector>& results) @@ -563,6 +554,21 @@ int main() { migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {1, 1, 1}}; prog_args_argmax_one_dim.add(output_name, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim.data())); + + migraphx::program_parameters prog_args_one_dim; + model_inputs.prepareProgArgs(progSimpleInput, prog_args_one_dim, true); + auto param_shapes = progSimpleInput.get_parameter_shapes(); + auto inputShape = param_shapes[model_inputs.INPUTS_ID_STR]; + std::vector oneDimInput = {0}; + std::unique_ptr one_dim_input_buffer = std::make_unique(std::move(oneDimInput), offload_copy); + prog_args_one_dim.add(model_inputs.INPUTS_ID_STR, migraphx::argument(inputShape, one_dim_input_buffer->data())); + prog_args_one_dim.add(output_name, migraphx::argument(x_shape_one_dim, output_buffer_oneDim.data())); + + if (not offload_copy) + { + one_dim_input_buffer->upload_to_device(stream); + } + std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; @@ -576,7 +582,7 @@ int main() { for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) { bool firstIter = (i == lastInputIdx); - prog->run_async(prog_args, stream); + prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); auto outputs = progArgMax->run_async(firstIter ? 
prog_args_argmax : prog_args_argmax_one_dim, stream); if (not offload_copy) { @@ -605,23 +611,9 @@ int main() { { prog = &progSimpleInput; progArgMax = &progArgMaxSimpleInput; - migraphx::shape out_shape{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; - prog_args.add(output_name, migraphx::argument(out_shape, output_buffer_oneDim.data())); - - auto param_shapes = prog->get_parameter_shapes(); - auto inputShape = param_shapes[model_inputs.INPUTS_ID_STR]; - std::vector input_ids = {new_token}; - model_inputs.input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); - prog_args.add(model_inputs.INPUTS_ID_STR, migraphx::argument(inputShape, model_inputs.input_ids_buffer->data())); - if (not offload_copy) - { - model_inputs.input_ids_buffer->upload_to_device(stream); - } - } - else - { - model_inputs.input_ids_buffer->update_data(new_token, 0, stream); } + + one_dim_input_buffer->update_data(new_token, 0, stream); } #ifdef TRACE @@ -634,8 +626,6 @@ int main() { #endif prog = &progMultipleInputDim; progArgMax = &progArgMaxMultipleInputDim; - migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; - prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); auto updated = model_inputs.updateData(*prog, prog_args); From 9f16a146c3fad9f6636cc2460caf1da75fb033d8 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Tue, 19 Nov 2024 10:46:01 -0600 Subject: [PATCH 38/55] Move mgxllama2 components to multiple files --- .../mgx_llama2/harness/config.hpp | 21 + .../mgx_llama2/harness/dataset.hpp | 149 ++++++ .../mgx_llama2/harness/llama2inputs.hpp | 122 +++++ .../transformers/mgx_llama2/harness/utils.hpp | 220 ++++++++ examples/transformers/mgx_llama2/mgxllama2.cc | 499 +----------------- 5 files changed, 515 insertions(+), 496 deletions(-) create mode 100644 examples/transformers/mgx_llama2/harness/config.hpp create mode 100644 examples/transformers/mgx_llama2/harness/dataset.hpp create mode 100644 examples/transformers/mgx_llama2/harness/llama2inputs.hpp create mode 100644 examples/transformers/mgx_llama2/harness/utils.hpp diff --git a/examples/transformers/mgx_llama2/harness/config.hpp b/examples/transformers/mgx_llama2/harness/config.hpp new file mode 100644 index 00000000000..1adf21e3385 --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/config.hpp @@ -0,0 +1,21 @@ +#pragma once + +// TODO: fix paths +const std::string MODEL_FOLDER = "/model/"; +const std::string ONNX_FILE = "model.onnx"; +const std::string DATASET_FOLDER = "/dataset/"; +const size_t DATASET_SIZE = 10; +// sequence length from model config +const size_t SEQ_SIZE = 1024; +// vocab size from model config +const size_t VOCAB_SIZE = 32000; +// EOS token from model config +const size_t EOS = 2; +// Write output tokens to file +const bool WRITE_RESULT_FILE = false; + +const int DEVICE_ID = 4; + +const size_t HIDDEN_LAYERS_NUM = 32; +const size_t HEAD_SIZE = 128; +const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/dataset.hpp b/examples/transformers/mgx_llama2/harness/dataset.hpp new file mode 100644 index 00000000000..5d4722f8197 --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/dataset.hpp @@ -0,0 +1,149 @@ +#pragma once + +#include "config.hpp" +#include "numpy.hpp" + +#include +#include + +using namespace mlinfer; + +using NumpyVector = std::vector>; + +struct Dataset +{ + Dataset() = default; + + void initialize() + { + loadDataset(); + if (!_npy_files_loaded) 
+ { + prepareSampleDataset(); + } + } + + NumpyVector loadNumpy(npy::NpyFile& file) + { + NumpyVector numpyDataAll; + auto load_size = file.GetTensorSize()/sizeof(int64_t); + numpyDataAll.push_back(std::vector(load_size)); + file.LoadAll(numpyDataAll.back().data()); + + NumpyVector numpyData; + for(size_t i = 0; i < numpyDataAll.back().size(); i += SEQ_SIZE) + { + auto last = std::min(numpyDataAll.back().size(), i + SEQ_SIZE); + numpyData.emplace_back(numpyDataAll.back().begin() + i, numpyDataAll.back().begin() + last); + } + +#ifdef TRACE + for (auto& vec: numpyData) + { + std::cout << "Vector size: " << vec.size() << std::endl; + for (auto val: vec) + { + std::cout << val << " "; + } + std::cout << "\n"; + } +#endif + return numpyData; + } + + size_t getLastIdx() const + { + auto res = std::find_if(std::rbegin(attention_mask[_current_idx]), std::rend(attention_mask[_current_idx]), [](uint64_t val) { return 1 == val;}); + size_t last_idx = std::distance(res, std::rend(attention_mask[_current_idx])); + #ifdef TRACE + std::cout << "Last input idx: " << last_idx << std::endl; + #endif + return last_idx; + } + + std::vector getInputIds() { return input_ids[_current_idx]; } + std::vector getAttentionMask() { return attention_mask[_current_idx]; } + + size_t size() const { return _size; } + size_t currentIdx() const { return _current_idx; } + size_t getNext() + { + if (_current_idx < size() - 1) + { + ++_current_idx; + } + #ifdef TRACE + std::cout << "Current idx: " << _current_idx << std::endl; + #endif + return _current_idx; + } + + Dataset(const Dataset &buf) = delete; + Dataset &operator=(const Dataset &buf) = delete; +private: + + // e.g.: /dataset/input_ids_size_3_seq_256.npy + std::string getDatasetPath(const std::string& datasetName) + { + std::stringstream path; + path << DATASET_FOLDER << datasetName << "_size_" << std::to_string(DATASET_SIZE) << "_seq_" << std::to_string(SEQ_SIZE) << ".npy"; + return path.str(); + } + + void loadDataset() + { + std::string input_file_path = getDatasetPath("input_ids"); + std::string attention_mask_file_path = getDatasetPath("attention_mask"); + + std::cout << "Input ids file: " << input_file_path << std::endl; + std::ifstream input_file(input_file_path.c_str()); + std::ifstream attention_mask_file(attention_mask_file_path.c_str()); + if (input_file.good() && attention_mask_file.good()) + { + npy::NpyFile input_ids_npy{input_file_path}; + npy::NpyFile attention_mask_npy{attention_mask_file_path}; + input_ids = loadNumpy(input_ids_npy); + attention_mask = loadNumpy(attention_mask_npy); + + _size = input_ids.size(); + + if (input_ids.size() == attention_mask.size()) + { + std::cout << "Loaded numpy files\n"; + _npy_files_loaded = true; + } + else + { + std::cout << "Numpy files do not have the same size\n"; + input_ids.clear(); + attention_mask.clear(); + } + } + else + { + std::cout << "Unable to open numpy files\n"; + } + } + + void prepareSampleDataset() + { + std::cout << "Numpy files are not loaded, using dummy data\n"; + std::vector input_ids_sample = {1,6804,338,5207,387,287,29973}; + input_ids_sample.resize(SEQ_SIZE, 0); + std::vector attention_mask_sample = input_ids_sample; + input_ids.emplace_back(std::move(input_ids_sample)); + std::transform(std::begin(attention_mask_sample), std::end(attention_mask_sample), std::begin(attention_mask_sample), [](auto i){ + return (i != 0) ? 
1 : 0; + }); + attention_mask.emplace_back(std::move(attention_mask_sample)); + + _size = 1; + } + + NumpyVector input_ids; + NumpyVector attention_mask; + + size_t _size = 0; + size_t _current_idx = 0; + bool _npy_files_loaded = false; +}; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp new file mode 100644 index 00000000000..092dd0b95b0 --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp @@ -0,0 +1,122 @@ +#pragma once + +#include "config.hpp" +#include "utils.hpp" + +struct LLama2Inputs +{ + LLama2Inputs( + migraphx::program& prog, + migraphx::program_parameters& prog_args, + bool offload_copy) + : offload_copy(offload_copy) + { + data.initialize(); + prepareProgArgs(prog, prog_args); + } + + void prepareProgArgs(migraphx::program& prog, migraphx::program_parameters& prog_args, bool simple = false) + { + auto param_shapes = prog.get_parameter_shapes(); + if (!simple) + { + auto inputShape = param_shapes[INPUTS_ID_STR]; + auto input_ids = data.getInputIds(); + input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); + } + + + auto attShape = param_shapes[ATTENTION_MASK_STR]; + auto attention_mask = data.getAttentionMask(); + if (!simple) + { + attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); + } + prog_args.add(ATTENTION_MASK_STR, migraphx::argument(attShape, attention_mask_buffer->data())); + + // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} + // past_key_values.0.value = @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + auto past_keyStr = getPastKeyString(i); + auto past_keyString = past_keyStr.c_str(); + if (!simple) + { + past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + } + auto pastKeyShape = param_shapes[past_keyString]; + prog_args.add(past_keyString, migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); + + auto past_valueStr = getPastValueStr(i); + auto past_valueString = past_valueStr.c_str(); + if (!simple) + { + past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + } + auto pastValueShape = param_shapes[past_valueString]; + prog_args.add(past_valueString, migraphx::argument(pastValueShape, past_value_buffers[i]->data())); + } + } + + void upload_to_device(hipStream_t stream) + { + assert(not offload_copy); + input_ids_buffer->upload_to_device(stream); + attention_mask_buffer->upload_to_device(stream); + } + + bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) + { + auto currentIdx = data.currentIdx(); + if (currentIdx != data.getNext()) + { + auto param_shapes = prog.get_parameter_shapes(); + + std::vector input_ids = data.getInputIds(); + input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); + + auto attention_mask = data.getAttentionMask(); + attention_mask_buffer->update(std::move(attention_mask)); + if (offload_copy) + { + prog_args.add(ATTENTION_MASK_STR, migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); + } + + return true; + } + return false; + } + + 
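// Zero-fills every layer's past key/value (KV-cache) buffer; when offload copy is
+    // disabled, the cleared host buffers are copied back to the device as well.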
void resetPastKeyValueBuffers(hipStream_t stream) + { + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + past_value_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + if (not offload_copy) + { + past_key_buffers[i]->upload_to_device(stream); + past_value_buffers[i]->upload_to_device(stream); + } + } + } + + size_t getLastInputIndex() const { return data.getLastIdx(); } + size_t dataSize() const { return data.size(); } + + LLama2Inputs() = delete; + LLama2Inputs(const LLama2Inputs &buf) = delete; + LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; + + std::unique_ptr input_ids_buffer; + std::unique_ptr attention_mask_buffer; + std::vector> past_key_buffers; + std::vector> past_value_buffers; + Dataset data; + bool offload_copy; + + const char* INPUTS_ID_STR = "input_ids"; + const char* ATTENTION_MASK_STR = "attention_mask"; +}; diff --git a/examples/transformers/mgx_llama2/harness/utils.hpp b/examples/transformers/mgx_llama2/harness/utils.hpp new file mode 100644 index 00000000000..f34b14095cf --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/utils.hpp @@ -0,0 +1,220 @@ +#pragma once + +#include "config.hpp" + +#include +#include +#include + +#include + + +struct ModelLoadSettings +{ + size_t sequnce_length; + bool quantize_fp16; + bool offload_copy; + bool fast_math; + bool input_one_dim; +}; + +static std::string getModelPath(ModelLoadSettings& s) +{ + std::stringstream path; + path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" << (s.quantize_fp16 ? "16" : "32") << "_"; + if (!s.offload_copy) + { + path << "no"; + } + path << "offload_"; + if (!s.fast_math) + { + path << "no"; + } + path << "fastmath"; + if (s.input_one_dim) + { + path << "_inputonedim"; + } + path << ".mxr"; + return path.str(); +} + +[[maybe_unused]] static std::string getPastKeyString(size_t i) +{ + std::stringstream past_key; + past_key << "past_key_values." << std::to_string(i) << ".key"; + return past_key.str(); +} + +[[maybe_unused]] static std::string getPastValueStr(size_t i) +{ + std::stringstream past_val; + past_val << "past_key_values." << std::to_string(i) << ".value"; + return past_val.str(); +} + +[[maybe_unused]] static std::string getPresentKeyString(size_t i) +{ + std::stringstream past_key; + past_key << "present." << std::to_string(i) << ".key"; + return past_key.str(); +} + +[[maybe_unused]] static std::string getPresentValueStr(size_t i) +{ + std::stringstream past_val; + past_val << "present." 
<< std::to_string(i) << ".value"; + return past_val.str(); +} + +static migraphx::program loadOnnx(ModelLoadSettings& settings) +{ + std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); + + #ifdef TRACE + std::cout << "Using model: " << MODEL_FOLDER + ONNX_FILE << std::endl; + #endif + + migraphx::program prog; + std::ifstream f(onnx_path.c_str()); + if (f.good()) + { + migraphx::onnx_options onnx_opts; + std::vector dims = {1, SEQ_SIZE}; + std::vector dimsPastKey = {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; + std::vector inputDim; + if (settings.input_one_dim) + { + inputDim = {1,1}; + } + else + { + inputDim = dims; + } + onnx_opts.set_input_parameter_shape("input_ids", inputDim); + onnx_opts.set_input_parameter_shape("attention_mask", dims); + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); + onnx_opts.set_input_parameter_shape(getPastValueStr(i), dimsPastKey); + } + std::cout << "Parsing onnx file ..." << std::endl; + prog = parse_onnx(onnx_path.c_str(), onnx_opts); + + std::string target_str = "gpu"; + migraphx::target targ = migraphx::target(target_str.c_str()); + + if (settings.quantize_fp16) + { + std::cout << "Quantize FP16 ..." << std::endl; + migraphx::quantize_fp16(prog); + } + + migraphx::compile_options comp_opts; + + if (settings.offload_copy) + comp_opts.set_offload_copy(); + + if (settings.fast_math) + comp_opts.set_fast_math(); + + comp_opts.set_exhaustive_tune_flag(); + + std::cout << "Compile to target ..." << std::endl; + prog.compile(targ, comp_opts); + + std::string modelPath = getModelPath(settings); + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + std::cout << "Saving mxr file to: " << modelPath << "\n"; + migraphx::save(prog, modelPath.c_str(), file_options); + } + else + { + std::cerr << "Onnx file is not available on path: " << onnx_path << std::endl; + exit(1); + } + return prog; +}; + +static migraphx::program loadProgram(ModelLoadSettings& settings) +{ + std::filesystem::path compiled_path(getModelPath(settings)); + + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + + migraphx::program prog; + std::ifstream f(compiled_path.c_str()); + if (f.good()) + { + std::cout << "Loading model from " << compiled_path << " ...\n"; + prog = migraphx::load(compiled_path.c_str(), file_options); + } + else + { + std::cout << "MXR file can't be loaded try to load ONNX\n"; + prog = loadOnnx(settings); + } + return prog; +}; + +static migraphx::program create_argmax_program(ModelLoadSettings& settings) +{ + migraphx::program prog; + std::vector dims {1, SEQ_SIZE, VOCAB_SIZE}; + if (settings.input_one_dim) + { + dims[1] = 1; + } + migraphx::shape s{migraphx_shape_half_type, dims}; + migraphx::module m = prog.get_main_module(); + auto x = m.add_parameter("x", s); + auto argmax_ins = m.add_instruction(migraphx::operation("argmax", "{axis: 2}"), {x}); + m.add_return({argmax_ins}); + + std::cout << "Creating ArgMax program ..." << std::endl; + + std::string target_str = "gpu"; + migraphx::target targ = migraphx::target(target_str.c_str()); + + if (settings.quantize_fp16) + { + std::cout << "Quantize FP16 ..." << std::endl; + migraphx::quantize_fp16(prog); + } + + migraphx::compile_options comp_opts; + + if (settings.offload_copy) + comp_opts.set_offload_copy(); + + if (settings.fast_math) + comp_opts.set_fast_math(); + + comp_opts.set_exhaustive_tune_flag(); + + std::cout << "Compile to target ..." 
<< std::endl; + prog.compile(targ, comp_opts); + + return prog; +} + +static void writeResults(const std::vector>& results) +{ + std::string RESULT_FILE = "result.txt"; + std::ofstream outFile(RESULT_FILE); + for (auto& resVec : results) + { + for (auto& res : resVec) + { + outFile << res; + if (&res != &resVec.back()) + { + outFile << ", "; + } + } + outFile << "\n"; + } +} diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 5768b4c7ca3..f62c6274ebf 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -1,503 +1,10 @@ #include "buffer.hpp" #include "common.hpp" -#include "numpy.hpp" +#include "dataset.hpp" +#include "llama2inputs.hpp" +#include "utils.hpp" #include -#include -#include -#include -#include -#include - -using namespace mlinfer; - -// TODO: fix paths -const std::string MODEL_FOLDER = "/model/"; -const std::string ONNX_FILE = "model.onnx"; -const std::string DATASET_FOLDER = "/dataset/"; -const size_t DATASET_SIZE = 10; -// sequence length from model config -const size_t SEQ_SIZE = 1024; -// vocab size from model config -const size_t VOCAB_SIZE = 32000; -// EOS token from model config -const size_t EOS = 2; -// Write output tokens to file -const bool WRITE_RESULT_FILE = false; - -const size_t HIDDEN_LAYERS_NUM = 32; -const size_t HEAD_SIZE = 128; -const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; - -const int DEVICE_ID = 4; - -struct ModelLoadSettings -{ - size_t sequnce_length; - bool quantize_fp16; - bool offload_copy; - bool fast_math; - bool input_one_dim; -}; - -static std::string getModelPath(ModelLoadSettings& s) -{ - std::stringstream path; - path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" << (s.quantize_fp16 ? "16" : "32") << "_"; - if (!s.offload_copy) - { - path << "no"; - } - path << "offload_"; - if (!s.fast_math) - { - path << "no"; - } - path << "fastmath"; - if (s.input_one_dim) - { - path << "_inputonedim"; - } - path << ".mxr"; - return path.str(); -} - -std::string getPastKeyString(size_t i) -{ - std::stringstream past_key; - past_key << "past_key_values." << std::to_string(i) << ".key"; - return past_key.str(); -} - -std::string getPastValueStr(size_t i) -{ - std::stringstream past_val; - past_val << "past_key_values." << std::to_string(i) << ".value"; - return past_val.str(); -} - -std::string getPresentKeyString(size_t i) -{ - std::stringstream past_key; - past_key << "present." << std::to_string(i) << ".key"; - return past_key.str(); -} - -std::string getPresentValueStr(size_t i) -{ - std::stringstream past_val; - past_val << "present." 
<< std::to_string(i) << ".value"; - return past_val.str(); -} - -static migraphx::program loadOnnx(ModelLoadSettings& settings) -{ - std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); - - #ifdef TRACE - std::cout << "Using model: " << MODEL_FOLDER + ONNX_FILE << std::endl; - #endif - - migraphx::program prog; - std::ifstream f(onnx_path.c_str()); - if (f.good()) - { - migraphx::onnx_options onnx_opts; - std::vector dims = {1, SEQ_SIZE}; - std::vector dimsPastKey = {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; - std::vector inputDim; - if (settings.input_one_dim) - { - inputDim = {1,1}; - } - else - { - inputDim = dims; - } - onnx_opts.set_input_parameter_shape("input_ids", inputDim); - onnx_opts.set_input_parameter_shape("attention_mask", dims); - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); - onnx_opts.set_input_parameter_shape(getPastValueStr(i), dimsPastKey); - } - std::cout << "Parsing onnx file ..." << std::endl; - prog = parse_onnx(onnx_path.c_str(), onnx_opts); - - std::string target_str = "gpu"; - migraphx::target targ = migraphx::target(target_str.c_str()); - - if (settings.quantize_fp16) - { - std::cout << "Quantize FP16 ..." << std::endl; - migraphx::quantize_fp16(prog); - } - - migraphx::compile_options comp_opts; - - if (settings.offload_copy) - comp_opts.set_offload_copy(); - - if (settings.fast_math) - comp_opts.set_fast_math(); - - comp_opts.set_exhaustive_tune_flag(); - - std::cout << "Compile to target ..." << std::endl; - prog.compile(targ, comp_opts); - - std::string modelPath = getModelPath(settings); - migraphx::file_options file_options; - file_options.set_file_format("msgpack"); - std::cout << "Saving mxr file to: " << modelPath << "\n"; - migraphx::save(prog, modelPath.c_str(), file_options); - } - else - { - std::cerr << "Onnx file is not available on path: " << onnx_path << std::endl; - exit(1); - } - return prog; -}; - -static migraphx::program loadProgram(ModelLoadSettings& settings) -{ - std::filesystem::path compiled_path(getModelPath(settings)); - - migraphx::file_options file_options; - file_options.set_file_format("msgpack"); - - migraphx::program prog; - std::ifstream f(compiled_path.c_str()); - if (f.good()) - { - std::cout << "Loading model from " << compiled_path << " ...\n"; - prog = migraphx::load(compiled_path.c_str(), file_options); - } - else - { - std::cout << "MXR file can't be loaded try to load ONNX\n"; - prog = loadOnnx(settings); - } - return prog; -}; - -static migraphx::program create_argmax_program(ModelLoadSettings& settings) -{ - migraphx::program prog; - std::vector dims {1, SEQ_SIZE, VOCAB_SIZE}; - if (settings.input_one_dim) - { - dims[1] = 1; - } - migraphx::shape s{migraphx_shape_half_type, dims}; - migraphx::module m = prog.get_main_module(); - auto x = m.add_parameter("x", s); - auto argmax_ins = m.add_instruction(migraphx::operation("argmax", "{axis: 2}"), {x}); - m.add_return({argmax_ins}); - - std::cout << "Creating ArgMax program ..." << std::endl; - - std::string target_str = "gpu"; - migraphx::target targ = migraphx::target(target_str.c_str()); - - if (settings.quantize_fp16) - { - std::cout << "Quantize FP16 ..." << std::endl; - migraphx::quantize_fp16(prog); - } - - migraphx::compile_options comp_opts; - - if (settings.offload_copy) - comp_opts.set_offload_copy(); - - if (settings.fast_math) - comp_opts.set_fast_math(); - - comp_opts.set_exhaustive_tune_flag(); - - std::cout << "Compile to target ..." 
<< std::endl; - prog.compile(targ, comp_opts); - - return prog; -} - -using NumpyVector = std::vector>; - -struct Dataset -{ - Dataset() = default; - - void initialize() - { - loadDataset(); - if (!_npy_files_loaded) - { - prepareSampleDataset(); - } - } - - NumpyVector loadNumpy(npy::NpyFile& file) - { - NumpyVector numpyDataAll; - auto load_size = file.GetTensorSize()/sizeof(int64_t); - numpyDataAll.push_back(std::vector(load_size)); - file.LoadAll(numpyDataAll.back().data()); - - NumpyVector numpyData; - for(size_t i = 0; i < numpyDataAll.back().size(); i += SEQ_SIZE) - { - auto last = std::min(numpyDataAll.back().size(), i + SEQ_SIZE); - numpyData.emplace_back(numpyDataAll.back().begin() + i, numpyDataAll.back().begin() + last); - } - -#ifdef TRACE - for (auto& vec: numpyData) - { - std::cout << "Vector size: " << vec.size() << std::endl; - for (auto val: vec) - { - std::cout << val << " "; - } - std::cout << "\n"; - } -#endif - return numpyData; - } - - size_t getLastIdx() const - { - auto res = std::find_if(std::rbegin(attention_mask[_current_idx]), std::rend(attention_mask[_current_idx]), [](uint64_t val) { return 1 == val;}); - size_t last_idx = std::distance(res, std::rend(attention_mask[_current_idx])); - #ifdef TRACE - std::cout << "Last input idx: " << last_idx << std::endl; - #endif - return last_idx; - } - - std::vector getInputIds() { return input_ids[_current_idx]; } - std::vector getAttentionMask() { return attention_mask[_current_idx]; } - - size_t size() const { return _size; } - size_t currentIdx() const { return _current_idx; } - size_t getNext() - { - if (_current_idx < size() - 1) - { - ++_current_idx; - } - #ifdef TRACE - std::cout << "Current idx: " << _current_idx << std::endl; - #endif - return _current_idx; - } - - Dataset(const Dataset &buf) = delete; - Dataset &operator=(const Dataset &buf) = delete; -private: - - // e.g.: /dataset/input_ids_size_3_seq_256.npy - std::string getDatasetPath(const std::string& datasetName) - { - std::stringstream path; - path << DATASET_FOLDER << datasetName << "_size_" << std::to_string(DATASET_SIZE) << "_seq_" << std::to_string(SEQ_SIZE) << ".npy"; - return path.str(); - } - - void loadDataset() - { - std::string input_file_path = getDatasetPath("input_ids"); - std::string attention_mask_file_path = getDatasetPath("attention_mask"); - - std::cout << "Input ids file: " << input_file_path << std::endl; - std::ifstream input_file(input_file_path.c_str()); - std::ifstream attention_mask_file(attention_mask_file_path.c_str()); - if (input_file.good() && attention_mask_file.good()) - { - npy::NpyFile input_ids_npy{input_file_path}; - npy::NpyFile attention_mask_npy{attention_mask_file_path}; - input_ids = loadNumpy(input_ids_npy); - attention_mask = loadNumpy(attention_mask_npy); - - _size = input_ids.size(); - - if (input_ids.size() == attention_mask.size()) - { - std::cout << "Loaded numpy files\n"; - _npy_files_loaded = true; - } - else - { - std::cout << "Numpy files do not have the same size\n"; - input_ids.clear(); - attention_mask.clear(); - } - } - else - { - std::cout << "Unable to open numpy files\n"; - } - } - - void prepareSampleDataset() - { - std::cout << "Numpy files are not loaded, using dummy data\n"; - std::vector input_ids_sample = {1,6804,338,5207,387,287,29973}; - input_ids_sample.resize(SEQ_SIZE, 0); - std::vector attention_mask_sample = input_ids_sample; - input_ids.emplace_back(std::move(input_ids_sample)); - std::transform(std::begin(attention_mask_sample), std::end(attention_mask_sample), 
std::begin(attention_mask_sample), [](auto i){ - return (i != 0) ? 1 : 0; - }); - attention_mask.emplace_back(std::move(attention_mask_sample)); - - _size = 1; - } - - NumpyVector input_ids; - NumpyVector attention_mask; - - size_t _size = 0; - size_t _current_idx = 0; - bool _npy_files_loaded = false; -}; - -struct LLama2Inputs -{ - LLama2Inputs( - migraphx::program& prog, - migraphx::program_parameters& prog_args, - bool offload_copy) - : offload_copy(offload_copy) - { - data.initialize(); - prepareProgArgs(prog, prog_args); - } - - void prepareProgArgs(migraphx::program& prog, migraphx::program_parameters& prog_args, bool simple = false) - { - auto param_shapes = prog.get_parameter_shapes(); - if (!simple) - { - auto inputShape = param_shapes[INPUTS_ID_STR]; - auto input_ids = data.getInputIds(); - input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); - prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); - } - - - auto attShape = param_shapes[ATTENTION_MASK_STR]; - auto attention_mask = data.getAttentionMask(); - if (!simple) - { - attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); - } - prog_args.add(ATTENTION_MASK_STR, migraphx::argument(attShape, attention_mask_buffer->data())); - - // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} - // past_key_values.0.value = @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - auto past_keyStr = getPastKeyString(i); - auto past_keyString = past_keyStr.c_str(); - if (!simple) - { - past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - } - auto pastKeyShape = param_shapes[past_keyString]; - prog_args.add(past_keyString, migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); - - auto past_valueStr = getPastValueStr(i); - auto past_valueString = past_valueStr.c_str(); - if (!simple) - { - past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); - } - auto pastValueShape = param_shapes[past_valueString]; - prog_args.add(past_valueString, migraphx::argument(pastValueShape, past_value_buffers[i]->data())); - } - } - - void upload_to_device(hipStream_t stream) - { - assert(not offload_copy); - input_ids_buffer->upload_to_device(stream); - attention_mask_buffer->upload_to_device(stream); - } - - bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) - { - auto currentIdx = data.currentIdx(); - if (currentIdx != data.getNext()) - { - auto param_shapes = prog.get_parameter_shapes(); - - std::vector input_ids = data.getInputIds(); - input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); - prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); - - auto attention_mask = data.getAttentionMask(); - attention_mask_buffer->update(std::move(attention_mask)); - if (offload_copy) - { - prog_args.add(ATTENTION_MASK_STR, migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); - } - - return true; - } - return false; - } - - void resetPastKeyValueBuffers(hipStream_t stream) - { - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); - past_value_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); - if (not offload_copy) - { - 
past_key_buffers[i]->upload_to_device(stream); - past_value_buffers[i]->upload_to_device(stream); - } - } - } - - size_t getLastInputIndex() const { return data.getLastIdx(); } - size_t dataSize() const { return data.size(); } - - LLama2Inputs() = delete; - LLama2Inputs(const LLama2Inputs &buf) = delete; - LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; - - std::unique_ptr input_ids_buffer; - std::unique_ptr attention_mask_buffer; - std::vector> past_key_buffers; - std::vector> past_value_buffers; - Dataset data; - bool offload_copy; - - const char* INPUTS_ID_STR = "input_ids"; - const char* ATTENTION_MASK_STR = "attention_mask"; -}; - -void writeResults(const std::vector>& results) -{ - std::string RESULT_FILE = "result.txt"; - std::ofstream outFile(RESULT_FILE); - for (auto& resVec : results) - { - for (auto& res : resVec) - { - outFile << res; - if (&res != &resVec.back()) - { - outFile << ", "; - } - } - outFile << "\n"; - } -} int main() { bool offload_copy = false; From e9d0a81d13f774280eea9a1c5bbedc82f3eddb3f Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 20 Nov 2024 06:18:36 -0600 Subject: [PATCH 39/55] Move one dim input ids to input class --- .../transformers/mgx_llama2/harness/llama2inputs.hpp | 11 +++++++++++ examples/transformers/mgx_llama2/mgxllama2.cc | 11 +++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp index 092dd0b95b0..09f8a1f69d5 100644 --- a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp @@ -59,6 +59,16 @@ struct LLama2Inputs } } + void prepareOneDimProgArgs(migraphx::program& prog, migraphx::program_parameters& prog_args) + { + prepareProgArgs(prog, prog_args, true); + auto param_shapes = prog.get_parameter_shapes(); + auto inputShape = param_shapes[INPUTS_ID_STR]; + std::vector oneDimInput = {0}; + one_dim_input_buffer = std::make_unique(std::move(oneDimInput), offload_copy); + prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, one_dim_input_buffer->data())); + } + void upload_to_device(hipStream_t stream) { assert(not offload_copy); @@ -111,6 +121,7 @@ struct LLama2Inputs LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; std::unique_ptr input_ids_buffer; + std::unique_ptr one_dim_input_buffer; std::unique_ptr attention_mask_buffer; std::vector> past_key_buffers; std::vector> past_value_buffers; diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index f62c6274ebf..1195e5965fc 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -63,17 +63,12 @@ int main() { migraphx::program_parameters prog_args_one_dim; - model_inputs.prepareProgArgs(progSimpleInput, prog_args_one_dim, true); - auto param_shapes = progSimpleInput.get_parameter_shapes(); - auto inputShape = param_shapes[model_inputs.INPUTS_ID_STR]; - std::vector oneDimInput = {0}; - std::unique_ptr one_dim_input_buffer = std::make_unique(std::move(oneDimInput), offload_copy); - prog_args_one_dim.add(model_inputs.INPUTS_ID_STR, migraphx::argument(inputShape, one_dim_input_buffer->data())); + model_inputs.prepareOneDimProgArgs(progSimpleInput, prog_args_one_dim); prog_args_one_dim.add(output_name, migraphx::argument(x_shape_one_dim, output_buffer_oneDim.data())); if (not offload_copy) { - one_dim_input_buffer->upload_to_device(stream); + 
model_inputs.one_dim_input_buffer->upload_to_device(stream); } std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; @@ -120,7 +115,7 @@ int main() { progArgMax = &progArgMaxSimpleInput; } - one_dim_input_buffer->update_data(new_token, 0, stream); + model_inputs.one_dim_input_buffer->update_data(new_token, 0, stream); } #ifdef TRACE From 10596c1cd64b3720189170d63d1a516962e0fffe Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 21 Nov 2024 08:18:27 -0600 Subject: [PATCH 40/55] Refactor outputs to llama2outputs file and move main to MGXLlama2 struct --- .../mgx_llama2/harness/llama2outputs.hpp | 55 ++++ examples/transformers/mgx_llama2/mgxllama2.cc | 243 +++++++++--------- 2 files changed, 182 insertions(+), 116 deletions(-) create mode 100644 examples/transformers/mgx_llama2/harness/llama2outputs.hpp diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp new file mode 100644 index 00000000000..3ab0e24b0ec --- /dev/null +++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "buffer.hpp" +#include "config.hpp" + +#include + +struct LLama2Outputs +{ + LLama2Outputs(bool offload_copy) + : offload_copy(offload_copy) + { + } + + void prepareProgArgs(migraphx::program_parameters& prog_args, migraphx::program_parameters& prog_args_one_dim) + { + output_buffer = std::make_unique(std::vector(OUTPUT_SIZE), offload_copy); + one_dim_output_buffer = std::make_unique(std::vector(VOCAB_SIZE), offload_copy); + migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + prog_args.add(OUTPUT_NAME, migraphx::argument(out_shape, output_buffer->data())); + + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; + prog_args_one_dim.add(OUTPUT_NAME, migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); + } + + void prepareProgArgsArgMax(migraphx::program_parameters& prog_args_argmax, migraphx::program_parameters& prog_args_argmax_one_dim) + { + // setting up argmax arguments + migraphx::shape x_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer->data())); + argm_output_buffer = std::make_unique(std::vector(VOCAB_SIZE), offload_copy); + migraphx::shape argm_out_shape{migraphx_shape_int64_type, {1, SEQ_SIZE, 1}}; + prog_args_argmax.add(OUTPUT_NAME, migraphx::argument(argm_out_shape, argm_output_buffer->data())); + + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; + prog_args_argmax_one_dim.add("x", migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); + argm_output_buffer_one_dim = std::make_unique(std::vector(1), offload_copy); + migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {1, 1, 1}}; + prog_args_argmax_one_dim.add(OUTPUT_NAME, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim->data())); + } + + LLama2Outputs() = delete; + LLama2Outputs(const LLama2Outputs &buf) = delete; + LLama2Outputs &operator=(const LLama2Outputs &buf) = delete; + + std::unique_ptr output_buffer; + std::unique_ptr one_dim_output_buffer; + std::unique_ptr argm_output_buffer; + std::unique_ptr argm_output_buffer_one_dim; + + bool offload_copy = false; + + const char* OUTPUT_NAME = "main:#output_0"; + const size_t OUTPUT_SIZE = SEQ_SIZE * VOCAB_SIZE; +}; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc 
b/examples/transformers/mgx_llama2/mgxllama2.cc index 1195e5965fc..9c4a39f16ae 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -2,151 +2,162 @@ #include "common.hpp" #include "dataset.hpp" #include "llama2inputs.hpp" +#include "llama2outputs.hpp" #include "utils.hpp" #include -int main() { - bool offload_copy = false; - check_hip_status(hipSetDevice(DEVICE_ID)); - std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; - migraphx::program progMultipleInputDim = loadProgram(settings); - std::cout << "Model loaded" << std::endl; - migraphx::program progArgMaxMultipleInputDim = create_argmax_program(settings); - std::cout << "ArgMax model created" << std::endl; - - // Load {1,1} input_ids model - settings.input_one_dim = true; - migraphx::program progSimpleInput = loadProgram(settings); - std::cout << "Model 1 dim input loaded" << std::endl; - migraphx::program progArgMaxSimpleInput = create_argmax_program(settings); - std::cout << "ArgMax model for 1 dim model created" << std::endl; - - migraphx::program *prog = &progMultipleInputDim; - migraphx::program *progArgMax = &progArgMaxMultipleInputDim; - - // Setup model inputs - std::vector> results; - std::vector output_tokens; - migraphx::program_parameters prog_args; - hipStream_t stream; - check_hip_status(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - auto model_inputs = LLama2Inputs(*prog, prog_args, offload_copy); - if (not offload_copy) +struct MGXLlama2 +{ + MGXLlama2() { - model_inputs.upload_to_device(stream); - } + check_hip_status(hipSetDevice(DEVICE_ID)); + std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; + ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; + progMultipleInputDim = loadProgram(settings); + std::cout << "Model loaded" << std::endl; + progArgMaxMultipleInputDim = create_argmax_program(settings); + std::cout << "ArgMax model created" << std::endl; + + // Load {1,1} input_ids model + settings.input_one_dim = true; + progSimpleInput = loadProgram(settings); + std::cout << "Model 1 dim input loaded" << std::endl; + progArgMaxSimpleInput = create_argmax_program(settings); + std::cout << "ArgMax model for 1 dim model created" << std::endl; - auto output_name = "main:#output_0"; - - size_t output_size = SEQ_SIZE * VOCAB_SIZE; - auto output_buffer = LLama2PastKeyValueBuffer(std::vector(output_size), offload_copy); - auto output_buffer_oneDim = LLama2PastKeyValueBuffer(std::vector(VOCAB_SIZE), offload_copy); - migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; - prog_args.add(output_name, migraphx::argument(out_shape, output_buffer.data())); - - // setting up argmax arguments - migraphx::program_parameters prog_args_argmax; - migraphx::shape x_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; - prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer.data())); - auto argm_output_buffer = ArgMaxOutputBuffer(std::vector(VOCAB_SIZE), offload_copy); - migraphx::shape argm_out_shape{migraphx_shape_int64_type, {1, SEQ_SIZE, 1}}; - prog_args_argmax.add(output_name, migraphx::argument(argm_out_shape, argm_output_buffer.data())); + prog = &progMultipleInputDim; + progArgMax = &progArgMaxMultipleInputDim; - 
migraphx::program_parameters prog_args_argmax_one_dim; - migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; - prog_args_argmax_one_dim.add("x", migraphx::argument(x_shape_one_dim, output_buffer_oneDim.data())); - auto argm_output_buffer_one_dim = ArgMaxOutputBuffer(std::vector(1), offload_copy); - migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {1, 1, 1}}; - prog_args_argmax_one_dim.add(output_name, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim.data())); + check_hip_status(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + model_inputs = std::make_unique(*prog, prog_args, offload_copy); + if (not offload_copy) + { + model_inputs->upload_to_device(stream); + } + model_inputs->prepareOneDimProgArgs(progSimpleInput, prog_args_one_dim); - migraphx::program_parameters prog_args_one_dim; - model_inputs.prepareOneDimProgArgs(progSimpleInput, prog_args_one_dim); - prog_args_one_dim.add(output_name, migraphx::argument(x_shape_one_dim, output_buffer_oneDim.data())); + model_outputs = std::make_unique(offload_copy); + model_outputs->prepareProgArgs(prog_args,prog_args_one_dim); + // setting up argmax arguments + model_outputs->prepareProgArgsArgMax(prog_args_argmax, prog_args_argmax_one_dim); - if (not offload_copy) - { - model_inputs.one_dim_input_buffer->upload_to_device(stream); + if (not offload_copy) + { + model_inputs->one_dim_input_buffer->upload_to_device(stream); + } } - std::cout << "Dataset size: " << model_inputs.dataSize() << std::endl; - std::cout << "Starting evaluation" << std::endl; - size_t token_count = 0; - auto start = std::chrono::steady_clock::now(); - for (size_t i = 0; i < model_inputs.dataSize(); ++i) + void run() { - #ifdef TRACE - std::cout << "Iter #" << i << std::endl; - #endif - auto lastInputIdx = model_inputs.getLastInputIndex(); - for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) + std::cout << "Dataset size: " << model_inputs->dataSize() << std::endl; + std::cout << "Starting evaluation" << std::endl; + size_t token_count = 0; + auto start = std::chrono::steady_clock::now(); + for (size_t i = 0; i < model_inputs->dataSize(); ++i) { - bool firstIter = (i == lastInputIdx); - prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); - auto outputs = progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); - if (not offload_copy) - { - firstIter ? argm_output_buffer.download_from_device(stream, i, i + 1) : argm_output_buffer_one_dim.download_from_device(stream); - } - - check_hip_status(hipStreamSynchronize(stream)); - int64_t* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? argm_output_buffer.hbuff.data() : argm_output_buffer_one_dim.hbuff.data()); - auto new_token_idx = firstIter ? i : 0; - int64_t new_token = results[new_token_idx]; - - token_count++; #ifdef TRACE - std::cout << "New token: " << new_token << std::endl; + std::cout << "Iter #" << i << std::endl; #endif - output_tokens.push_back(new_token); - - if (new_token == EOS) + auto lastInputIdx = model_inputs->getLastInputIndex(); + for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) { - break; + bool firstIter = (i == lastInputIdx); + prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); + auto outputs = progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); + if (not offload_copy) + { + firstIter ? 
model_outputs->argm_output_buffer->download_from_device(stream, i, i + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); + } + + check_hip_status(hipStreamSynchronize(stream)); + int64_t* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); + auto new_token_idx = firstIter ? i : 0; + int64_t new_token = results[new_token_idx]; + + token_count++; + #ifdef TRACE + std::cout << "New token: " << new_token << std::endl; + #endif + output_tokens.push_back(new_token); + + if (new_token == EOS) + { + break; + } + + model_inputs->attention_mask_buffer->update_data(1, i + 1, stream); + + if (firstIter) + { + prog = &progSimpleInput; + progArgMax = &progArgMaxSimpleInput; + } + + model_inputs->one_dim_input_buffer->update_data(new_token, 0, stream); } - model_inputs.attention_mask_buffer->update_data(1, i + 1, stream); + #ifdef TRACE + std::cout << "######### Output token ids for #" << i << " #########" << std::endl; + // print output tokens + for (auto tok: output_tokens){ + std::cout << tok << ", "; + } + std::cout << std::endl; + #endif + prog = &progMultipleInputDim; + progArgMax = &progArgMaxMultipleInputDim; + + auto updated = model_inputs->updateData(*prog, prog_args); - if (firstIter) + if (updated && not offload_copy) { - prog = &progSimpleInput; - progArgMax = &progArgMaxSimpleInput; + model_inputs->upload_to_device(stream); } - - model_inputs.one_dim_input_buffer->update_data(new_token, 0, stream); + results.emplace_back(output_tokens); + output_tokens.clear(); } -#ifdef TRACE - std::cout << "######### Output token ids for #" << i << " #########" << std::endl; - // print output tokens - for (auto tok: output_tokens){ - std::cout << tok << ", "; - } - std::cout << std::endl; -#endif - prog = &progMultipleInputDim; - progArgMax = &progArgMaxMultipleInputDim; + float dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() / 1000.f; + std::cout << "Duration: " << dur << " seconds." << std::endl; + std::cout << "Completed " << token_count << " tokens." << std::endl; + std::cout << "Tokens/sec: " << token_count / dur << std::endl; - auto updated = model_inputs.updateData(*prog, prog_args); - - if (updated && not offload_copy) + if (WRITE_RESULT_FILE) { - model_inputs.upload_to_device(stream); + writeResults(results); } - results.emplace_back(output_tokens); - output_tokens.clear(); } - float dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() / 1000.f; - std::cout << "Duration: " << dur << " seconds." << std::endl; - std::cout << "Completed " << token_count << " tokens." 
<< std::endl; - std::cout << "Tokens/sec: " << token_count / dur << std::endl; + MGXLlama2(const MGXLlama2 &buf) = delete; + MGXLlama2 &operator=(const MGXLlama2 &buf) = delete; - if (WRITE_RESULT_FILE) - { - writeResults(results); - } + migraphx::program progMultipleInputDim; + migraphx::program progArgMaxMultipleInputDim; + migraphx::program progSimpleInput; + migraphx::program progArgMaxSimpleInput; + migraphx::program *prog = nullptr; + migraphx::program *progArgMax = nullptr; + + migraphx::program_parameters prog_args; + migraphx::program_parameters prog_args_one_dim; + migraphx::program_parameters prog_args_argmax; + migraphx::program_parameters prog_args_argmax_one_dim; + + std::vector> results; + std::vector output_tokens; + hipStream_t stream; + bool offload_copy = false; + + std::unique_ptr model_inputs; + std::unique_ptr model_outputs; +}; + + +int main() +{ + MGXLlama2 mgxllama2; + mgxllama2.run(); return 0; } From d9c2746a36a11bd8bf765471e11601bd9a77557a Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Mon, 25 Nov 2024 09:39:56 -0600 Subject: [PATCH 41/55] Refactor mgxllama2 --- examples/transformers/mgx_llama2/mgxllama2.cc | 158 ++++++++++-------- 1 file changed, 84 insertions(+), 74 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 9c4a39f16ae..ff1e686db17 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -12,41 +12,22 @@ struct MGXLlama2 MGXLlama2() { check_hip_status(hipSetDevice(DEVICE_ID)); - std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; - progMultipleInputDim = loadProgram(settings); - std::cout << "Model loaded" << std::endl; - progArgMaxMultipleInputDim = create_argmax_program(settings); - std::cout << "ArgMax model created" << std::endl; - - // Load {1,1} input_ids model - settings.input_one_dim = true; - progSimpleInput = loadProgram(settings); - std::cout << "Model 1 dim input loaded" << std::endl; - progArgMaxSimpleInput = create_argmax_program(settings); - std::cout << "ArgMax model for 1 dim model created" << std::endl; + check_hip_status(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - prog = &progMultipleInputDim; - progArgMax = &progArgMaxMultipleInputDim; + loadPrograms(); - check_hip_status(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); model_inputs = std::make_unique(*prog, prog_args, offload_copy); + model_inputs->prepareOneDimProgArgs(progSimpleInput, prog_args_one_dim); if (not offload_copy) { model_inputs->upload_to_device(stream); + model_inputs->one_dim_input_buffer->upload_to_device(stream); } - model_inputs->prepareOneDimProgArgs(progSimpleInput, prog_args_one_dim); - model_outputs = std::make_unique(offload_copy); model_outputs->prepareProgArgs(prog_args,prog_args_one_dim); // setting up argmax arguments model_outputs->prepareProgArgsArgMax(prog_args_argmax, prog_args_argmax_one_dim); - - if (not offload_copy) - { - model_inputs->one_dim_input_buffer->upload_to_device(stream); - } } void run() @@ -57,46 +38,7 @@ struct MGXLlama2 auto start = std::chrono::steady_clock::now(); for (size_t i = 0; i < model_inputs->dataSize(); ++i) { - #ifdef TRACE - std::cout << "Iter #" << i << std::endl; - #endif - auto lastInputIdx = model_inputs->getLastInputIndex(); - for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) - { - 
bool firstIter = (i == lastInputIdx); - prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); - auto outputs = progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); - if (not offload_copy) - { - firstIter ? model_outputs->argm_output_buffer->download_from_device(stream, i, i + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); - } - - check_hip_status(hipStreamSynchronize(stream)); - int64_t* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); - auto new_token_idx = firstIter ? i : 0; - int64_t new_token = results[new_token_idx]; - - token_count++; - #ifdef TRACE - std::cout << "New token: " << new_token << std::endl; - #endif - output_tokens.push_back(new_token); - - if (new_token == EOS) - { - break; - } - - model_inputs->attention_mask_buffer->update_data(1, i + 1, stream); - - if (firstIter) - { - prog = &progSimpleInput; - progArgMax = &progArgMaxSimpleInput; - } - - model_inputs->one_dim_input_buffer->update_data(new_token, 0, stream); - } + evaluateSample(i, token_count); #ifdef TRACE std::cout << "######### Output token ids for #" << i << " #########" << std::endl; @@ -106,17 +48,7 @@ struct MGXLlama2 } std::cout << std::endl; #endif - prog = &progMultipleInputDim; - progArgMax = &progArgMaxMultipleInputDim; - - auto updated = model_inputs->updateData(*prog, prog_args); - - if (updated && not offload_copy) - { - model_inputs->upload_to_device(stream); - } - results.emplace_back(output_tokens); - output_tokens.clear(); + prepareNextSample(); } float dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() / 1000.f; @@ -130,6 +62,84 @@ struct MGXLlama2 } } + void loadPrograms() + { + std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; + ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; + progMultipleInputDim = loadProgram(settings); + std::cout << "Model loaded" << std::endl; + progArgMaxMultipleInputDim = create_argmax_program(settings); + std::cout << "ArgMax model created" << std::endl; + + // Load {1,1} input_ids model + settings.input_one_dim = true; + progSimpleInput = loadProgram(settings); + std::cout << "Model 1 dim input loaded" << std::endl; + progArgMaxSimpleInput = create_argmax_program(settings); + std::cout << "ArgMax model for 1 dim model created" << std::endl; + + prog = &progMultipleInputDim; + progArgMax = &progArgMaxMultipleInputDim; + } + + void evaluateSample(size_t sample_id, size_t& token_count) + { + #ifdef TRACE + std::cout << "Iter #" << sample_id << std::endl; + #endif + auto lastInputIdx = model_inputs->getLastInputIndex(); + for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) + { + bool firstIter = (i == lastInputIdx); + prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); + auto outputs = progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); + if (not offload_copy) + { + firstIter ? model_outputs->argm_output_buffer->download_from_device(stream, i, i + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); + } + + check_hip_status(hipStreamSynchronize(stream)); + int64_t* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? 
model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); + auto new_token_idx = firstIter ? i : 0; + int64_t new_token = results[new_token_idx]; + + token_count++; + #ifdef TRACE + std::cout << "New token: " << new_token << std::endl; + #endif + output_tokens.push_back(new_token); + + if (new_token == EOS) + { + break; + } + + model_inputs->attention_mask_buffer->update_data(1, i + 1, stream); + model_inputs->one_dim_input_buffer->update_data(new_token, 0, stream); + + if (firstIter) + { + prog = &progSimpleInput; + progArgMax = &progArgMaxSimpleInput; + } + } + } + + void prepareNextSample() + { + prog = &progMultipleInputDim; + progArgMax = &progArgMaxMultipleInputDim; + + auto updated = model_inputs->updateData(*prog, prog_args); + + if (updated && not offload_copy) + { + model_inputs->upload_to_device(stream); + } + results.emplace_back(output_tokens); + output_tokens.clear(); + } + MGXLlama2(const MGXLlama2 &buf) = delete; MGXLlama2 &operator=(const MGXLlama2 &buf) = delete; From fbd46b8462ff300f3aca67a906c7a958c687a5c5 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 27 Nov 2024 03:18:56 -0600 Subject: [PATCH 42/55] Use batch size from config --- .../transformers/mgx_llama2/harness/config.hpp | 3 ++- .../mgx_llama2/harness/llama2outputs.hpp | 14 +++++++------- examples/transformers/mgx_llama2/harness/utils.hpp | 8 ++++---- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/config.hpp b/examples/transformers/mgx_llama2/harness/config.hpp index 1adf21e3385..693f3c90f18 100644 --- a/examples/transformers/mgx_llama2/harness/config.hpp +++ b/examples/transformers/mgx_llama2/harness/config.hpp @@ -5,6 +5,7 @@ const std::string MODEL_FOLDER = "/model/"; const std::string ONNX_FILE = "model.onnx"; const std::string DATASET_FOLDER = "/dataset/"; const size_t DATASET_SIZE = 10; +const size_t BATCH_SIZE = 1; // sequence length from model config const size_t SEQ_SIZE = 1024; // vocab size from model config @@ -18,4 +19,4 @@ const int DEVICE_ID = 4; const size_t HIDDEN_LAYERS_NUM = 32; const size_t HEAD_SIZE = 128; -const size_t PAST_KEY_VAL_SIZE = HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; \ No newline at end of file +const size_t PAST_KEY_VAL_SIZE = BATCH_SIZE*HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp index 3ab0e24b0ec..0a0ca3b0a1b 100644 --- a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp @@ -16,26 +16,26 @@ struct LLama2Outputs { output_buffer = std::make_unique(std::vector(OUTPUT_SIZE), offload_copy); one_dim_output_buffer = std::make_unique(std::vector(VOCAB_SIZE), offload_copy); - migraphx::shape out_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + migraphx::shape out_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(OUTPUT_NAME, migraphx::argument(out_shape, output_buffer->data())); - migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; prog_args_one_dim.add(OUTPUT_NAME, migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); } void prepareProgArgsArgMax(migraphx::program_parameters& prog_args_argmax, migraphx::program_parameters& prog_args_argmax_one_dim) { // setting up argmax 
arguments - migraphx::shape x_shape{migraphx_shape_half_type, {1, SEQ_SIZE, VOCAB_SIZE}}; + migraphx::shape x_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer->data())); argm_output_buffer = std::make_unique(std::vector(VOCAB_SIZE), offload_copy); - migraphx::shape argm_out_shape{migraphx_shape_int64_type, {1, SEQ_SIZE, 1}}; + migraphx::shape argm_out_shape{migraphx_shape_int64_type, {BATCH_SIZE, SEQ_SIZE, 1}}; prog_args_argmax.add(OUTPUT_NAME, migraphx::argument(argm_out_shape, argm_output_buffer->data())); - migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {1, 1, VOCAB_SIZE}}; + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; prog_args_argmax_one_dim.add("x", migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); argm_output_buffer_one_dim = std::make_unique(std::vector(1), offload_copy); - migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {1, 1, 1}}; + migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {BATCH_SIZE, 1, 1}}; prog_args_argmax_one_dim.add(OUTPUT_NAME, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim->data())); } @@ -51,5 +51,5 @@ struct LLama2Outputs bool offload_copy = false; const char* OUTPUT_NAME = "main:#output_0"; - const size_t OUTPUT_SIZE = SEQ_SIZE * VOCAB_SIZE; + const size_t OUTPUT_SIZE = BATCH_SIZE * SEQ_SIZE * VOCAB_SIZE; }; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/utils.hpp b/examples/transformers/mgx_llama2/harness/utils.hpp index f34b14095cf..1575e104540 100644 --- a/examples/transformers/mgx_llama2/harness/utils.hpp +++ b/examples/transformers/mgx_llama2/harness/utils.hpp @@ -81,12 +81,12 @@ static migraphx::program loadOnnx(ModelLoadSettings& settings) if (f.good()) { migraphx::onnx_options onnx_opts; - std::vector dims = {1, SEQ_SIZE}; - std::vector dimsPastKey = {1, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; + std::vector dims = {BATCH_SIZE, SEQ_SIZE}; + std::vector dimsPastKey = {BATCH_SIZE, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; std::vector inputDim; if (settings.input_one_dim) { - inputDim = {1,1}; + inputDim = {BATCH_SIZE,1}; } else { @@ -163,7 +163,7 @@ static migraphx::program loadProgram(ModelLoadSettings& settings) static migraphx::program create_argmax_program(ModelLoadSettings& settings) { migraphx::program prog; - std::vector dims {1, SEQ_SIZE, VOCAB_SIZE}; + std::vector dims {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}; if (settings.input_one_dim) { dims[1] = 1; From 583a43942b6bda005922c80c0f94156aec180dee Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 27 Nov 2024 05:01:31 -0600 Subject: [PATCH 43/55] Remove offload copy option --- .../mgx_llama2/harness/buffer.hpp | 21 ++++---------- .../mgx_llama2/harness/llama2inputs.hpp | 29 ++++++------------- .../mgx_llama2/harness/llama2outputs.hpp | 14 ++++----- .../transformers/mgx_llama2/harness/utils.hpp | 13 +-------- examples/transformers/mgx_llama2/mgxllama2.cc | 26 ++++++----------- 5 files changed, 30 insertions(+), 73 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index fc13d73aa04..5216c7023f9 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -110,19 +110,16 @@ namespace mlinfer struct ManagedBuffer_v2 { - explicit ManagedBuffer_v2(std::vector&& host_data, bool with_offload_copy=true): 
with_offload_copy(with_offload_copy) + explicit ManagedBuffer_v2(std::vector&& host_data) { size_in_bytes = host_data.size() * sizeof(T); hbuff = std::move(host_data); - if (not with_offload_copy) - { - dbuff = DeviceBuffer(size_in_bytes, 0); - } + dbuff = DeviceBuffer(size_in_bytes, 0); } void* data() { - return with_offload_copy ? static_cast(hbuff.data()) : dbuff.tensor_ptr; + return dbuff.tensor_ptr; } void update(std::vector&& host_data) @@ -132,7 +129,6 @@ namespace mlinfer void upload_to_device(hipStream_t stream, size_t start_idx=0, size_t end_idx=0) { - assert(not with_offload_copy); char* src_addr = reinterpret_cast(hbuff.data()); char* dst_addr = static_cast(dbuff.tensor_ptr); size_t copy_size_in_bytes = size_in_bytes; @@ -150,7 +146,6 @@ namespace mlinfer void download_from_device(hipStream_t stream, size_t start_idx=0, size_t end_idx=0) { - assert(not with_offload_copy); char* src_addr = static_cast(dbuff.tensor_ptr); char* dst_addr = reinterpret_cast(hbuff.data()); size_t copy_size_in_bytes = size_in_bytes; @@ -169,12 +164,9 @@ namespace mlinfer void update_data(T data, size_t position, hipStream_t stream) { hbuff.at(position) = data; - if (not with_offload_copy) - { - // TODO: don't copy over the entire buffer just the changed range - // check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); - upload_to_device(stream, position, position + 1); - } + // TODO: don't copy over the entire buffer just the changed range + // check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); + upload_to_device(stream, position, position + 1); } ManagedBuffer_v2() = delete; @@ -184,7 +176,6 @@ namespace mlinfer DeviceBuffer dbuff; std::vector hbuff; size_t size_in_bytes; - bool with_offload_copy; }; using LLama2InputBuffer = ManagedBuffer_v2; diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp index 09f8a1f69d5..8f236f3373b 100644 --- a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp @@ -7,9 +7,7 @@ struct LLama2Inputs { LLama2Inputs( migraphx::program& prog, - migraphx::program_parameters& prog_args, - bool offload_copy) - : offload_copy(offload_copy) + migraphx::program_parameters& prog_args) { data.initialize(); prepareProgArgs(prog, prog_args); @@ -22,7 +20,7 @@ struct LLama2Inputs { auto inputShape = param_shapes[INPUTS_ID_STR]; auto input_ids = data.getInputIds(); - input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + input_ids_buffer = std::make_unique(std::move(input_ids)); prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); } @@ -31,7 +29,7 @@ struct LLama2Inputs auto attention_mask = data.getAttentionMask(); if (!simple) { - attention_mask_buffer = std::make_unique(std::move(attention_mask), offload_copy); + attention_mask_buffer = std::make_unique(std::move(attention_mask)); } prog_args.add(ATTENTION_MASK_STR, migraphx::argument(attShape, attention_mask_buffer->data())); @@ -43,7 +41,7 @@ struct LLama2Inputs auto past_keyString = past_keyStr.c_str(); if (!simple) { - past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); } auto pastKeyShape = param_shapes[past_keyString]; prog_args.add(past_keyString, 
migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); @@ -52,7 +50,7 @@ struct LLama2Inputs auto past_valueString = past_valueStr.c_str(); if (!simple) { - past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h), offload_copy)); + past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); } auto pastValueShape = param_shapes[past_valueString]; prog_args.add(past_valueString, migraphx::argument(pastValueShape, past_value_buffers[i]->data())); @@ -65,13 +63,12 @@ struct LLama2Inputs auto param_shapes = prog.get_parameter_shapes(); auto inputShape = param_shapes[INPUTS_ID_STR]; std::vector oneDimInput = {0}; - one_dim_input_buffer = std::make_unique(std::move(oneDimInput), offload_copy); + one_dim_input_buffer = std::make_unique(std::move(oneDimInput)); prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, one_dim_input_buffer->data())); } void upload_to_device(hipStream_t stream) { - assert(not offload_copy); input_ids_buffer->upload_to_device(stream); attention_mask_buffer->upload_to_device(stream); } @@ -84,15 +81,11 @@ struct LLama2Inputs auto param_shapes = prog.get_parameter_shapes(); std::vector input_ids = data.getInputIds(); - input_ids_buffer = std::make_unique(std::move(input_ids), offload_copy); + input_ids_buffer = std::make_unique(std::move(input_ids)); prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); auto attention_mask = data.getAttentionMask(); attention_mask_buffer->update(std::move(attention_mask)); - if (offload_copy) - { - prog_args.add(ATTENTION_MASK_STR, migraphx::argument(param_shapes[ATTENTION_MASK_STR], attention_mask_buffer->data())); - } return true; } @@ -105,11 +98,8 @@ struct LLama2Inputs { past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); past_value_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); - if (not offload_copy) - { - past_key_buffers[i]->upload_to_device(stream); - past_value_buffers[i]->upload_to_device(stream); - } + past_key_buffers[i]->upload_to_device(stream); + past_value_buffers[i]->upload_to_device(stream); } } @@ -126,7 +116,6 @@ struct LLama2Inputs std::vector> past_key_buffers; std::vector> past_value_buffers; Dataset data; - bool offload_copy; const char* INPUTS_ID_STR = "input_ids"; const char* ATTENTION_MASK_STR = "attention_mask"; diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp index 0a0ca3b0a1b..696483d9d68 100644 --- a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp @@ -7,15 +7,14 @@ struct LLama2Outputs { - LLama2Outputs(bool offload_copy) - : offload_copy(offload_copy) + LLama2Outputs() { } void prepareProgArgs(migraphx::program_parameters& prog_args, migraphx::program_parameters& prog_args_one_dim) { - output_buffer = std::make_unique(std::vector(OUTPUT_SIZE), offload_copy); - one_dim_output_buffer = std::make_unique(std::vector(VOCAB_SIZE), offload_copy); + output_buffer = std::make_unique(std::vector(OUTPUT_SIZE)); + one_dim_output_buffer = std::make_unique(std::vector(VOCAB_SIZE)); migraphx::shape out_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(OUTPUT_NAME, migraphx::argument(out_shape, output_buffer->data())); @@ -28,18 +27,17 @@ struct LLama2Outputs // setting up argmax arguments migraphx::shape x_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; 
prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer->data())); - argm_output_buffer = std::make_unique(std::vector(VOCAB_SIZE), offload_copy); + argm_output_buffer = std::make_unique(std::vector(VOCAB_SIZE)); migraphx::shape argm_out_shape{migraphx_shape_int64_type, {BATCH_SIZE, SEQ_SIZE, 1}}; prog_args_argmax.add(OUTPUT_NAME, migraphx::argument(argm_out_shape, argm_output_buffer->data())); migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; prog_args_argmax_one_dim.add("x", migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); - argm_output_buffer_one_dim = std::make_unique(std::vector(1), offload_copy); + argm_output_buffer_one_dim = std::make_unique(std::vector(1)); migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {BATCH_SIZE, 1, 1}}; prog_args_argmax_one_dim.add(OUTPUT_NAME, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim->data())); } - LLama2Outputs() = delete; LLama2Outputs(const LLama2Outputs &buf) = delete; LLama2Outputs &operator=(const LLama2Outputs &buf) = delete; @@ -48,8 +46,6 @@ struct LLama2Outputs std::unique_ptr argm_output_buffer; std::unique_ptr argm_output_buffer_one_dim; - bool offload_copy = false; - const char* OUTPUT_NAME = "main:#output_0"; const size_t OUTPUT_SIZE = BATCH_SIZE * SEQ_SIZE * VOCAB_SIZE; }; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/utils.hpp b/examples/transformers/mgx_llama2/harness/utils.hpp index 1575e104540..b51b8a52d13 100644 --- a/examples/transformers/mgx_llama2/harness/utils.hpp +++ b/examples/transformers/mgx_llama2/harness/utils.hpp @@ -13,7 +13,6 @@ struct ModelLoadSettings { size_t sequnce_length; bool quantize_fp16; - bool offload_copy; bool fast_math; bool input_one_dim; }; @@ -22,11 +21,7 @@ static std::string getModelPath(ModelLoadSettings& s) { std::stringstream path; path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" << (s.quantize_fp16 ? 
"16" : "32") << "_"; - if (!s.offload_copy) - { - path << "no"; - } - path << "offload_"; + path << "nooffload_"; if (!s.fast_math) { path << "no"; @@ -113,9 +108,6 @@ static migraphx::program loadOnnx(ModelLoadSettings& settings) migraphx::compile_options comp_opts; - if (settings.offload_copy) - comp_opts.set_offload_copy(); - if (settings.fast_math) comp_opts.set_fast_math(); @@ -187,9 +179,6 @@ static migraphx::program create_argmax_program(ModelLoadSettings& settings) migraphx::compile_options comp_opts; - if (settings.offload_copy) - comp_opts.set_offload_copy(); - if (settings.fast_math) comp_opts.set_fast_math(); diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index ff1e686db17..7bd1fd5c12c 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -16,15 +16,12 @@ struct MGXLlama2 loadPrograms(); - model_inputs = std::make_unique(*prog, prog_args, offload_copy); + model_inputs = std::make_unique(*prog, prog_args); model_inputs->prepareOneDimProgArgs(progSimpleInput, prog_args_one_dim); - if (not offload_copy) - { - model_inputs->upload_to_device(stream); - model_inputs->one_dim_input_buffer->upload_to_device(stream); - } + model_inputs->upload_to_device(stream); + model_inputs->one_dim_input_buffer->upload_to_device(stream); - model_outputs = std::make_unique(offload_copy); + model_outputs = std::make_unique(); model_outputs->prepareProgArgs(prog_args,prog_args_one_dim); // setting up argmax arguments model_outputs->prepareProgArgsArgMax(prog_args_argmax, prog_args_argmax_one_dim); @@ -64,8 +61,7 @@ struct MGXLlama2 void loadPrograms() { - std::cout << "Offload copy: " << std::boolalpha << offload_copy << std::endl; - ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, offload_copy /*offload_copy*/, false /*fast_math*/, false /*input_one_dim*/}; + ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, false /*fast_math*/, false /*input_one_dim*/}; progMultipleInputDim = loadProgram(settings); std::cout << "Model loaded" << std::endl; progArgMaxMultipleInputDim = create_argmax_program(settings); @@ -92,14 +88,11 @@ struct MGXLlama2 { bool firstIter = (i == lastInputIdx); prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); - auto outputs = progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); - if (not offload_copy) - { - firstIter ? model_outputs->argm_output_buffer->download_from_device(stream, i, i + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); - } + progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); + firstIter ? model_outputs->argm_output_buffer->download_from_device(stream, i, i + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); check_hip_status(hipStreamSynchronize(stream)); - int64_t* results = offload_copy ? reinterpret_cast(outputs[0].data()) : reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); + int64_t* results = reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); auto new_token_idx = firstIter ? 
i : 0; int64_t new_token = results[new_token_idx]; @@ -132,7 +125,7 @@ struct MGXLlama2 auto updated = model_inputs->updateData(*prog, prog_args); - if (updated && not offload_copy) + if (updated) { model_inputs->upload_to_device(stream); } @@ -158,7 +151,6 @@ struct MGXLlama2 std::vector> results; std::vector output_tokens; hipStream_t stream; - bool offload_copy = false; std::unique_ptr model_inputs; std::unique_ptr model_outputs; From a95ee80b1f85591988a8405e5c4565b38054e789 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 28 Nov 2024 07:44:04 -0600 Subject: [PATCH 44/55] Implement batching for mgxllama2 example --- .../mgx_llama2/harness/dataset.hpp | 47 ++++++-- .../mgx_llama2/harness/llama2inputs.hpp | 9 +- .../mgx_llama2/harness/llama2outputs.hpp | 6 +- .../transformers/mgx_llama2/harness/utils.hpp | 3 +- examples/transformers/mgx_llama2/mgxllama2.cc | 110 ++++++++++++------ 5 files changed, 122 insertions(+), 53 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/dataset.hpp b/examples/transformers/mgx_llama2/harness/dataset.hpp index 5d4722f8197..24a9c92a49e 100644 --- a/examples/transformers/mgx_llama2/harness/dataset.hpp +++ b/examples/transformers/mgx_llama2/harness/dataset.hpp @@ -51,31 +51,56 @@ struct Dataset return numpyData; } - size_t getLastIdx() const + size_t getLastIdx(int current_batch_idx) const { - auto res = std::find_if(std::rbegin(attention_mask[_current_idx]), std::rend(attention_mask[_current_idx]), [](uint64_t val) { return 1 == val;}); - size_t last_idx = std::distance(res, std::rend(attention_mask[_current_idx])); + auto idx = _current_batch * BATCH_SIZE + current_batch_idx; + auto res = std::find_if(std::rbegin(attention_mask[idx]), std::rend(attention_mask[idx]), [](uint64_t val) { return 1 == val;}); + size_t last_idx = std::distance(res, std::rend(attention_mask[idx])); #ifdef TRACE std::cout << "Last input idx: " << last_idx << std::endl; #endif return last_idx; } - std::vector getInputIds() { return input_ids[_current_idx]; } - std::vector getAttentionMask() { return attention_mask[_current_idx]; } + std::vector getInputIds() + { + std::vector inputIdsBatch; + inputIdsBatch.reserve(SEQ_SIZE*BATCH_SIZE); + for (size_t i = 0; i < BATCH_SIZE; ++i) + { + auto inputVec = input_ids[BATCH_SIZE*_current_batch + i]; + std::copy(inputVec.begin(), inputVec.end(), std::back_inserter(inputIdsBatch)); + } + return inputIdsBatch; + } + + std::vector getAttentionMask() + { + std::vector attentionMaskBatch; + attentionMaskBatch.reserve(SEQ_SIZE*BATCH_SIZE); + for (size_t i = 0; i < BATCH_SIZE; ++i) + { + auto attVec = attention_mask[BATCH_SIZE*_current_batch + i]; + std::copy(attVec.begin(), attVec.end(), std::back_inserter(attentionMaskBatch)); + } + return attentionMaskBatch; + } size_t size() const { return _size; } - size_t currentIdx() const { return _current_idx; } + size_t currentBatchIdx() const { return _current_batch; } + size_t batchNum() const { + return _size / BATCH_SIZE + (_size % BATCH_SIZE != 0); + } size_t getNext() { - if (_current_idx < size() - 1) + if (_current_batch < batchNum() - 1) { - ++_current_idx; + ++_current_batch; } #ifdef TRACE - std::cout << "Current idx: " << _current_idx << std::endl; + std::cout << "Current batch: " << _current_batch << std::endl; #endif - return _current_idx; + return _current_batch; } Dataset(const Dataset &buf) = delete; @@ -144,6 +169,6 @@ struct Dataset NumpyVector attention_mask; size_t _size = 0; - size_t _current_idx = 0; + size_t _current_batch = 0; bool _npy_files_loaded = false; }; \ 
No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp index 8f236f3373b..8c8fc1eda75 100644 --- a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp @@ -62,7 +62,7 @@ struct LLama2Inputs prepareProgArgs(prog, prog_args, true); auto param_shapes = prog.get_parameter_shapes(); auto inputShape = param_shapes[INPUTS_ID_STR]; - std::vector oneDimInput = {0}; + std::vector oneDimInput(BATCH_SIZE, 0); one_dim_input_buffer = std::make_unique(std::move(oneDimInput)); prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, one_dim_input_buffer->data())); } @@ -75,8 +75,8 @@ struct LLama2Inputs bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) { - auto currentIdx = data.currentIdx(); - if (currentIdx != data.getNext()) + auto batchIdx = data.currentBatchIdx(); + if (batchIdx != data.getNext()) { auto param_shapes = prog.get_parameter_shapes(); @@ -103,8 +103,9 @@ struct LLama2Inputs } } - size_t getLastInputIndex() const { return data.getLastIdx(); } + size_t getLastInputIndex(int current_batch_idx) const { return data.getLastIdx(current_batch_idx); } size_t dataSize() const { return data.size(); } + size_t batchNum() const { return data.batchNum(); } LLama2Inputs() = delete; LLama2Inputs(const LLama2Inputs &buf) = delete; diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp index 696483d9d68..87bc0da80a3 100644 --- a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp @@ -14,7 +14,7 @@ struct LLama2Outputs void prepareProgArgs(migraphx::program_parameters& prog_args, migraphx::program_parameters& prog_args_one_dim) { output_buffer = std::make_unique(std::vector(OUTPUT_SIZE)); - one_dim_output_buffer = std::make_unique(std::vector(VOCAB_SIZE)); + one_dim_output_buffer = std::make_unique(std::vector(BATCH_SIZE*VOCAB_SIZE)); migraphx::shape out_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; prog_args.add(OUTPUT_NAME, migraphx::argument(out_shape, output_buffer->data())); @@ -27,13 +27,13 @@ struct LLama2Outputs // setting up argmax arguments migraphx::shape x_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer->data())); - argm_output_buffer = std::make_unique(std::vector(VOCAB_SIZE)); + argm_output_buffer = std::make_unique(std::vector(BATCH_SIZE*SEQ_SIZE)); migraphx::shape argm_out_shape{migraphx_shape_int64_type, {BATCH_SIZE, SEQ_SIZE, 1}}; prog_args_argmax.add(OUTPUT_NAME, migraphx::argument(argm_out_shape, argm_output_buffer->data())); migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; prog_args_argmax_one_dim.add("x", migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); - argm_output_buffer_one_dim = std::make_unique(std::vector(1)); + argm_output_buffer_one_dim = std::make_unique(std::vector(BATCH_SIZE)); migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {BATCH_SIZE, 1, 1}}; prog_args_argmax_one_dim.add(OUTPUT_NAME, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim->data())); } diff --git a/examples/transformers/mgx_llama2/harness/utils.hpp b/examples/transformers/mgx_llama2/harness/utils.hpp index b51b8a52d13..7d1a9287921 100644 --- 
a/examples/transformers/mgx_llama2/harness/utils.hpp +++ b/examples/transformers/mgx_llama2/harness/utils.hpp @@ -15,13 +15,14 @@ struct ModelLoadSettings bool quantize_fp16; bool fast_math; bool input_one_dim; + size_t batch_size; }; static std::string getModelPath(ModelLoadSettings& s) { std::stringstream path; path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" << (s.quantize_fp16 ? "16" : "32") << "_"; - path << "nooffload_"; + path << "batch_" << std::to_string(s.batch_size) << "_"; if (!s.fast_math) { path << "no"; diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 7bd1fd5c12c..8c1666a76f8 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -1,3 +1,4 @@ +#include "config.hpp" #include "buffer.hpp" #include "common.hpp" #include "dataset.hpp" @@ -25,27 +26,25 @@ struct MGXLlama2 model_outputs->prepareProgArgs(prog_args,prog_args_one_dim); // setting up argmax arguments model_outputs->prepareProgArgsArgMax(prog_args_argmax, prog_args_argmax_one_dim); + output_tokens.resize(BATCH_SIZE, std::vector()); } void run() { std::cout << "Dataset size: " << model_inputs->dataSize() << std::endl; + std::cout << "Number of batches: " << model_inputs->batchNum() << std::endl; std::cout << "Starting evaluation" << std::endl; size_t token_count = 0; auto start = std::chrono::steady_clock::now(); - for (size_t i = 0; i < model_inputs->dataSize(); ++i) + for (size_t i = 0; i < model_inputs->batchNum(); ++i) { - evaluateSample(i, token_count); + evaluateBatch(i, token_count); #ifdef TRACE std::cout << "######### Output token ids for #" << i << " #########" << std::endl; - // print output tokens - for (auto tok: output_tokens){ - std::cout << tok << ", "; - } - std::cout << std::endl; + printOutputTokens(); #endif - prepareNextSample(); + prepareNextBatch(); } float dur = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() / 1000.f; @@ -61,7 +60,7 @@ struct MGXLlama2 void loadPrograms() { - ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, false /*fast_math*/, false /*input_one_dim*/}; + ModelLoadSettings settings = {SEQ_SIZE, false /*quantize_fp16*/, false /*fast_math*/, false /*input_one_dim*/, BATCH_SIZE}; progMultipleInputDim = loadProgram(settings); std::cout << "Model loaded" << std::endl; progArgMaxMultipleInputDim = create_argmax_program(settings); @@ -78,37 +77,62 @@ struct MGXLlama2 progArgMax = &progArgMaxMultipleInputDim; } - void evaluateSample(size_t sample_id, size_t& token_count) + void evaluateBatch(size_t batch_idx, size_t& token_count) { #ifdef TRACE - std::cout << "Iter #" << sample_id << std::endl; + std::cout << "Iter #" << batch_idx << std::endl; #endif - auto lastInputIdx = model_inputs->getLastInputIndex(); - for (size_t i = lastInputIdx; i < SEQ_SIZE - 1; ++i) + + //auto lastInputIdx = model_inputs->getLastInputIndex(); + std::vector sampleLastInputIdx; + std::vector finishedBatches; + for (size_t i = 0; i < SEQ_SIZE - 1; ++i) { - bool firstIter = (i == lastInputIdx); + bool firstIter = (i == 0); prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); - firstIter ? 
model_outputs->argm_output_buffer->download_from_device(stream, i, i + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); - - check_hip_status(hipStreamSynchronize(stream)); - int64_t* results = reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); - auto new_token_idx = firstIter ? i : 0; - int64_t new_token = results[new_token_idx]; - token_count++; - #ifdef TRACE - std::cout << "New token: " << new_token << std::endl; - #endif - output_tokens.push_back(new_token); - - if (new_token == EOS) + for (size_t b = 0; b < BATCH_SIZE; ++b) { - break; + if (firstIter) + { + sampleLastInputIdx.emplace_back(model_inputs->getLastInputIndex(b) + (b * SEQ_SIZE)); + } + else + { + if (std::find(std::begin(finishedBatches), std::end(finishedBatches), b) != std::end(finishedBatches)) + { + continue; + } + } + + firstIter ? model_outputs->argm_output_buffer->download_from_device(stream, sampleLastInputIdx[b], sampleLastInputIdx[b] + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); + + check_hip_status(hipStreamSynchronize(stream)); + int64_t* results = reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); + auto new_token_idx = firstIter ? sampleLastInputIdx[b] : b; + int64_t new_token = results[new_token_idx]; + + token_count++; + #ifdef TRACE + std::cout << "New token for batch (" << b << "): " << new_token << std::endl; + #endif + output_tokens[b].push_back(new_token); + + if (new_token == EOS) + { + #ifdef TRACE + std::cout << b << " batch is added to finished" << std::endl; + #endif + finishedBatches.emplace_back(b); + } + + model_inputs->attention_mask_buffer->update_data(1, sampleLastInputIdx[b] + i + 1, stream); + model_inputs->one_dim_input_buffer->update_data(new_token, b, stream); } - model_inputs->attention_mask_buffer->update_data(1, i + 1, stream); - model_inputs->one_dim_input_buffer->update_data(new_token, 0, stream); + if (BATCH_SIZE == finishedBatches.size()) + break; if (firstIter) { @@ -118,7 +142,7 @@ struct MGXLlama2 } } - void prepareNextSample() + void prepareNextBatch() { prog = &progMultipleInputDim; progArgMax = &progArgMaxMultipleInputDim; @@ -129,8 +153,26 @@ struct MGXLlama2 { model_inputs->upload_to_device(stream); } - results.emplace_back(output_tokens); - output_tokens.clear(); + for (auto& tokens : output_tokens) + { + results.emplace_back(tokens); + tokens.clear(); + } + } + + void printOutputTokens() const + { + // print output tokens + for (size_t b = 0; b < output_tokens.size(); ++b) + { + std::cout << "######### Batch #" << b << " #########" << std::endl; + for (auto tok: output_tokens[b]) + { + std::cout << tok << ", "; + } + std::cout << std::endl; + } + std::cout << std::endl; } MGXLlama2(const MGXLlama2 &buf) = delete; @@ -149,7 +191,7 @@ struct MGXLlama2 migraphx::program_parameters prog_args_argmax_one_dim; std::vector> results; - std::vector output_tokens; + std::vector> output_tokens; hipStream_t stream; std::unique_ptr model_inputs; From 2142a40c31a17da62042311e41370abd17f39e91 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Fri, 29 Nov 2024 07:10:22 -0600 Subject: [PATCH 45/55] Fix last batch when there is not enough sample --- .../mgx_llama2/harness/dataset.hpp | 39 ++++++++++++------- examples/transformers/mgx_llama2/mgxllama2.cc | 39 +++++++++++++------ 2 files changed, 52 insertions(+), 26 deletions(-) diff --git 
a/examples/transformers/mgx_llama2/harness/dataset.hpp b/examples/transformers/mgx_llama2/harness/dataset.hpp index 24a9c92a49e..43e39350e56 100644 --- a/examples/transformers/mgx_llama2/harness/dataset.hpp +++ b/examples/transformers/mgx_llama2/harness/dataset.hpp @@ -64,26 +64,37 @@ struct Dataset std::vector getInputIds() { - std::vector inputIdsBatch; - inputIdsBatch.reserve(SEQ_SIZE*BATCH_SIZE); - for (size_t i = 0; i < BATCH_SIZE; ++i) - { - auto inputVec = input_ids[BATCH_SIZE*_current_batch + i]; - std::copy(inputVec.begin(), inputVec.end(), std::back_inserter(inputIdsBatch)); - } - return inputIdsBatch; + return getBatchedBuffer(input_ids); } std::vector getAttentionMask() { - std::vector attentionMaskBatch; - attentionMaskBatch.reserve(SEQ_SIZE*BATCH_SIZE); - for (size_t i = 0; i < BATCH_SIZE; ++i) + return getBatchedBuffer(attention_mask, 1); + } + + std::vector getBatchedBuffer(NumpyVector& buffer, size_t value = 0) + { + std::vector batchedBuffer; + const size_t buffer_size = SEQ_SIZE*BATCH_SIZE; + batchedBuffer.reserve(buffer_size); + auto batchSize = BATCH_SIZE; + if (_current_batch == batchNum() -1) + { + batchSize = _size % BATCH_SIZE; + } + + for (size_t i = 0; i < batchSize; ++i) + { + auto buffVec = buffer[BATCH_SIZE*_current_batch + i]; + std::copy(buffVec.begin(), buffVec.end(), std::back_inserter(batchedBuffer)); + } + + if (batchSize != BATCH_SIZE) { - auto attVec = attention_mask[BATCH_SIZE*_current_batch + i]; - std::copy(attVec.begin(), attVec.end(), std::back_inserter(attentionMaskBatch)); + // For last batch, setting buffer values if no sample is available + batchedBuffer.resize(buffer_size, value); } - return attentionMaskBatch; + return batchedBuffer; } size_t size() const { return _size; } diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 8c1666a76f8..8512dcf6bce 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -83,28 +83,23 @@ struct MGXLlama2 std::cout << "Iter #" << batch_idx << std::endl; #endif - //auto lastInputIdx = model_inputs->getLastInputIndex(); std::vector sampleLastInputIdx; - std::vector finishedBatches; + std::vector batches = getBatchesToProcess(batch_idx); + for (size_t i = 0; i < SEQ_SIZE - 1; ++i) { bool firstIter = (i == 0); prog->run_async(firstIter ? prog_args : prog_args_one_dim, stream); progArgMax->run_async(firstIter ? prog_args_argmax : prog_args_argmax_one_dim, stream); - for (size_t b = 0; b < BATCH_SIZE; ++b) + auto batchIt = std::begin(batches); + while (batchIt != std::end(batches)) { + auto b = *batchIt; if (firstIter) { sampleLastInputIdx.emplace_back(model_inputs->getLastInputIndex(b) + (b * SEQ_SIZE)); } - else - { - if (std::find(std::begin(finishedBatches), std::end(finishedBatches), b) != std::end(finishedBatches)) - { - continue; - } - } firstIter ? 
model_outputs->argm_output_buffer->download_from_device(stream, sampleLastInputIdx[b], sampleLastInputIdx[b] + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); @@ -124,14 +119,18 @@ struct MGXLlama2 #ifdef TRACE std::cout << b << " batch is added to finished" << std::endl; #endif - finishedBatches.emplace_back(b); + batchIt = batches.erase(batchIt); + } + else + { + ++batchIt; } model_inputs->attention_mask_buffer->update_data(1, sampleLastInputIdx[b] + i + 1, stream); model_inputs->one_dim_input_buffer->update_data(new_token, b, stream); } - if (BATCH_SIZE == finishedBatches.size()) + if (batches.empty()) break; if (firstIter) @@ -160,11 +159,27 @@ struct MGXLlama2 } } + std::vector getBatchesToProcess(size_t batch_idx) + { + std::vector batches; + size_t batchSizeRem = BATCH_SIZE; + if (batch_idx == model_inputs->batchNum() - 1) + { + batchSizeRem = model_inputs->dataSize() % BATCH_SIZE; + } + batches.resize(batchSizeRem); + std::iota(std::begin(batches), std::end(batches), 0); + return batches; + } + void printOutputTokens() const { // print output tokens for (size_t b = 0; b < output_tokens.size(); ++b) { + if (output_tokens[b].empty()) + continue; + std::cout << "######### Batch #" << b << " #########" << std::endl; for (auto tok: output_tokens[b]) { From b919b0f62a706a4c7f8c3607a918471db1e989b0 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 4 Dec 2024 04:11:26 -0600 Subject: [PATCH 46/55] Revert "Add new migraphx public API functions: replace_return and get_last_instruction" This reverts commit d970b304d7642758ceb10d4c2bef729a1c6e717e. --- src/api/api.cpp | 29 --------------------------- src/api/include/migraphx/migraphx.h | 7 ------- src/api/include/migraphx/migraphx.hpp | 17 ---------------- src/include/migraphx/module.hpp | 2 -- src/module.cpp | 10 --------- 5 files changed, 65 deletions(-) diff --git a/src/api/api.cpp b/src/api/api.cpp index 811fafd3f5a..4ecd0763225 100644 --- a/src/api/api.cpp +++ b/src/api/api.cpp @@ -1522,21 +1522,6 @@ extern "C" migraphx_status migraphx_module_add_instruction(migraphx_instruction_ return api_error_result; } -extern "C" migraphx_status migraphx_module_get_last_instruction(migraphx_instruction_t* out, - migraphx_module_t module) -{ - auto api_error_result = migraphx::try_([&] { - if(module == nullptr) - { - std::cout << "# migraphx_module_get_last_instruction nullptr" << std::endl; - MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); - } - *out = allocate( - (module->object).get_last_instruction()); - }); - return api_error_result; -} - extern "C" migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, migraphx_module_t module, @@ -1605,20 +1590,6 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou return api_error_result; } -extern "C" migraphx_status migraphx_module_replace_return(migraphx_instruction_t* out, - migraphx_module_t module, - migraphx_instructions_t args) -{ - auto api_error_result = migraphx::try_([&] { - if(module == nullptr) - MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); - if(args == nullptr) - MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer"); - *out = allocate((module->object).replace_return((args->object))); - }); - return api_error_result; -} - extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out, migraphx_module_t module, const_migraphx_shape_t s) diff --git a/src/api/include/migraphx/migraphx.h 
b/src/api/include/migraphx/migraphx.h index ea829424aff..90ba7c3e017 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -401,9 +401,6 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction(migraphx_instr migraphx_operation_t op, migraphx_instructions_t args); -MIGRAPHX_C_EXPORT migraphx_status migraphx_module_get_last_instruction(migraphx_instruction_t* out, - migraphx_module_t module); - MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, migraphx_module_t module, @@ -425,10 +422,6 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_return(migraphx_instructio migraphx_module_t module, migraphx_instructions_t args); -MIGRAPHX_C_EXPORT migraphx_status migraphx_module_replace_return(migraphx_instruction_t* out, - migraphx_module_t module, - migraphx_instructions_t args); - MIGRAPHX_C_EXPORT migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out, migraphx_module_t module, const_migraphx_shape_t s); diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp index 1870992e999..fa6339b4389 100644 --- a/src/api/include/migraphx/migraphx.hpp +++ b/src/api/include/migraphx/migraphx.hpp @@ -1049,16 +1049,6 @@ struct module return instruction(op_ins, own{}); } - instruction get_last_instruction() - { - std::cout << "# get_last_instruction called" << std::endl; - migraphx_instruction_t op_ins; - call(&migraphx_module_get_last_instruction, - &op_ins, - mm.get()); - return instruction(op_ins, own{}); - } - instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args, const migraphx::modules& module_args) @@ -1097,13 +1087,6 @@ struct module return instruction(ret_ins, own{}); } - instruction replace_return(const migraphx::instructions& args) - { - migraphx_instruction_t ret_ins; - call(&migraphx_module_replace_return, &ret_ins, mm.get(), args.get_handle_ptr()); - return instruction(ret_ins, own{}); - } - instruction add_allocation(const migraphx::shape& s) { migraphx_instruction_t ret_ins; diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index e4a0529fcd8..68734c82baa 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -86,8 +86,6 @@ struct MIGRAPHX_EXPORT module return add_instruction(op, {args...}); } - instruction_ref get_last_instruction(); - instruction_ref add_instruction(const operation& op, std::vector args); instruction_ref add_instruction(const operation& op, diff --git a/src/module.cpp b/src/module.cpp index 2f1739a8578..7e02478b385 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -288,16 +288,6 @@ instruction_ref module::add_instruction(const operation& op, std::vectorinstructions.end(), op, std::move(args)); } - -instruction_ref module::get_last_instruction() -{ - auto last_instr = std::prev(this->end()); - if (last_instr->name() == "@return") - last_instr = std::prev(last_instr); - return last_instr; -} - - instruction_ref module::insert_instruction(instruction_ref ins, const operation& op, std::vector args) From 41447b08ad85562fc59d111a6488d17d7daa5c58 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 4 Dec 2024 04:57:24 -0600 Subject: [PATCH 47/55] Remove unused python imports --- examples/transformers/mgx_llama2/eval_accuracy.py | 3 +-- examples/transformers/mgx_llama2/preprocess_dataset.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/transformers/mgx_llama2/eval_accuracy.py 
b/examples/transformers/mgx_llama2/eval_accuracy.py index 803c55dd2f0..ae175df991a 100644 --- a/examples/transformers/mgx_llama2/eval_accuracy.py +++ b/examples/transformers/mgx_llama2/eval_accuracy.py @@ -2,10 +2,9 @@ import numpy as np import pickle from pathlib import Path -import os import evaluate import nltk -from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM +from transformers import AutoTokenizer MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf" diff --git a/examples/transformers/mgx_llama2/preprocess_dataset.py b/examples/transformers/mgx_llama2/preprocess_dataset.py index f3d096f0592..8632355372f 100644 --- a/examples/transformers/mgx_llama2/preprocess_dataset.py +++ b/examples/transformers/mgx_llama2/preprocess_dataset.py @@ -1,7 +1,6 @@ import numpy as np import pickle from pathlib import Path -import os G_MAX_TOK_LEN = 1024 G_LLAMA2_EOS = 2 From 91f368b71ff9a8787569acd683778c8c4a516bd7 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 4 Dec 2024 05:38:32 -0600 Subject: [PATCH 48/55] Format files with clang-format --- .../mgx_llama2/harness/buffer.hpp | 333 ++++++++---------- .../mgx_llama2/harness/common.hpp | 138 ++++---- .../mgx_llama2/harness/config.hpp | 3 +- .../mgx_llama2/harness/dataset.hpp | 299 ++++++++-------- .../mgx_llama2/harness/llama2inputs.hpp | 232 ++++++------ .../mgx_llama2/harness/llama2outputs.hpp | 104 +++--- .../mgx_llama2/harness/logging.hpp | 37 +- .../transformers/mgx_llama2/harness/numa.hpp | 260 +++++++------- .../transformers/mgx_llama2/harness/numpy.hpp | 171 +++++---- .../transformers/mgx_llama2/harness/timer.hpp | 99 +++--- .../transformers/mgx_llama2/harness/utils.hpp | 311 ++++++++-------- 11 files changed, 947 insertions(+), 1040 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index 5216c7023f9..0d04c1d4d00 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -2,184 +2,155 @@ #include "common.hpp" -namespace mlinfer -{ - template - struct IBuffer : public INoCopy - { - AllocFunc alloc_fn; - FreeFunc free_fn; - }; - - template - struct GenericBuffer : public IBuffer - { - GenericBuffer() - : size_in_bytes{0}, stride_in_bytes{0}, tensor_ptr{nullptr} - { - } - - explicit GenericBuffer(size_t size_in_bytes_, size_t stride_in_bytes_ = 0) - : size_in_bytes{size_in_bytes_}, stride_in_bytes{stride_in_bytes_} - { - if (stride_in_bytes == 0) - { - stride_in_bytes = size_in_bytes; - } - this->alloc_fn(&tensor_ptr, size_in_bytes); - } - - GenericBuffer(GenericBuffer &&buf) - : size_in_bytes{buf.size_in_bytes}, stride_in_bytes{buf.stride_in_bytes}, tensor_ptr{buf.tensor_ptr} - { - buf.size_in_bytes = 0; - buf.stride_in_bytes = 0; - buf.tensor_ptr = nullptr; - } - - GenericBuffer &operator=(GenericBuffer &&buf) - { - if (this != &buf) - { - this->free_fn(tensor_ptr); - size_in_bytes = buf.size_in_bytes; - stride_in_bytes = buf.stride_in_bytes; - tensor_ptr = buf.tensor_ptr; - buf.size_in_bytes = 0; - buf.stride_in_bytes = 0; - buf.tensor_ptr = nullptr; - } - return *this; - } - - GenericBuffer(const GenericBuffer &buf) = delete; - GenericBuffer &operator=(const GenericBuffer &buf) = delete; - - ~GenericBuffer() - { - this->free_fn(tensor_ptr); - } - - size_t size_in_bytes; - size_t stride_in_bytes; - void *tensor_ptr; - }; - - struct DeviceAllocator - { - void operator()(void **ptr, size_t size) const - { - LOG_INFO("Malloc " << size << " bytes on device"); - TIMED(hipMalloc, 
check_hip_status(hipMalloc(ptr, size))); - TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); - } - }; - - struct DeviceFree - { - void operator()(void *ptr) const - { - TIMED(hipFree, check_hip_status_non_throwing(hipFree(ptr))); - ptr = nullptr; - } - }; - - struct HostAllocator - { - void operator()(void **ptr, size_t size) const - { - LOG_INFO("Malloc " << size << " bytes on host"); - TIMED(hipHostMalloc, check_hip_status(hipHostMalloc(ptr, size))); - TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); - } - }; - - struct HostFree - { - void operator()(void *ptr) const - { - TIMED(hipHostFree, check_hip_status_non_throwing(hipHostFree(ptr))); - ptr = nullptr; - } - }; - - using DeviceBuffer = GenericBuffer; - using HostBuffer = GenericBuffer; - - template - struct ManagedBuffer_v2 - { - - explicit ManagedBuffer_v2(std::vector&& host_data) - { - size_in_bytes = host_data.size() * sizeof(T); - hbuff = std::move(host_data); - dbuff = DeviceBuffer(size_in_bytes, 0); - } - - void* data() - { - return dbuff.tensor_ptr; - } - - void update(std::vector&& host_data) - { - hbuff = std::move(host_data); - } - - void upload_to_device(hipStream_t stream, size_t start_idx=0, size_t end_idx=0) - { - char* src_addr = reinterpret_cast(hbuff.data()); - char* dst_addr = static_cast(dbuff.tensor_ptr); - size_t copy_size_in_bytes = size_in_bytes; - - size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); - if (range_size_in_bytes > 0) - { - size_t offset = start_idx * sizeof(T); - src_addr += offset; - dst_addr += offset; - copy_size_in_bytes = range_size_in_bytes; - } - check_hip_status(hipMemcpyHtoDAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); - } - - void download_from_device(hipStream_t stream, size_t start_idx=0, size_t end_idx=0) - { - char* src_addr = static_cast(dbuff.tensor_ptr); - char* dst_addr = reinterpret_cast(hbuff.data()); - size_t copy_size_in_bytes = size_in_bytes; - - size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); - if (range_size_in_bytes > 0) - { - size_t offset = start_idx * sizeof(T); - src_addr += offset; - dst_addr += offset; - copy_size_in_bytes = range_size_in_bytes; - } - check_hip_status(hipMemcpyDtoHAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); - } - - void update_data(T data, size_t position, hipStream_t stream) - { - hbuff.at(position) = data; - // TODO: don't copy over the entire buffer just the changed range - // check_hip_status(hipMemcpy(get_device_ptr(), get_host_ptr(), dbuff.size_in_bytes, hipMemcpyKind::hipMemcpyHostToDevice)); - upload_to_device(stream, position, position + 1); - } - - ManagedBuffer_v2() = delete; - ManagedBuffer_v2(const ManagedBuffer_v2 &buf) = delete; - ManagedBuffer_v2 &operator=(const ManagedBuffer_v2 &buf) = delete; - - DeviceBuffer dbuff; - std::vector hbuff; - size_t size_in_bytes; - }; - - using LLama2InputBuffer = ManagedBuffer_v2; - using LLama2OutputBuffer = ManagedBuffer_v2; - using LLama2PastKeyValueBuffer = ManagedBuffer_v2; - using ArgMaxOutputBuffer = ManagedBuffer_v2; -} +namespace mlinfer { +template +struct IBuffer : public INoCopy { + AllocFunc alloc_fn; + FreeFunc free_fn; +}; + +template +struct GenericBuffer : public IBuffer { + GenericBuffer() : size_in_bytes{0}, stride_in_bytes{0}, tensor_ptr{nullptr} {} + + explicit GenericBuffer(size_t size_in_bytes_, size_t stride_in_bytes_ = 0) + : size_in_bytes{size_in_bytes_}, stride_in_bytes{stride_in_bytes_} { + if (stride_in_bytes == 0) { + stride_in_bytes = size_in_bytes; + } + this->alloc_fn(&tensor_ptr, 
size_in_bytes); + } + + GenericBuffer(GenericBuffer &&buf) + : size_in_bytes{buf.size_in_bytes}, stride_in_bytes{buf.stride_in_bytes}, + tensor_ptr{buf.tensor_ptr} { + buf.size_in_bytes = 0; + buf.stride_in_bytes = 0; + buf.tensor_ptr = nullptr; + } + + GenericBuffer &operator=(GenericBuffer &&buf) { + if (this != &buf) { + this->free_fn(tensor_ptr); + size_in_bytes = buf.size_in_bytes; + stride_in_bytes = buf.stride_in_bytes; + tensor_ptr = buf.tensor_ptr; + buf.size_in_bytes = 0; + buf.stride_in_bytes = 0; + buf.tensor_ptr = nullptr; + } + return *this; + } + + GenericBuffer(const GenericBuffer &buf) = delete; + GenericBuffer &operator=(const GenericBuffer &buf) = delete; + + ~GenericBuffer() { this->free_fn(tensor_ptr); } + + size_t size_in_bytes; + size_t stride_in_bytes; + void *tensor_ptr; +}; + +struct DeviceAllocator { + void operator()(void **ptr, size_t size) const { + LOG_INFO("Malloc " << size << " bytes on device"); + TIMED(hipMalloc, check_hip_status(hipMalloc(ptr, size))); + TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); + } +}; + +struct DeviceFree { + void operator()(void *ptr) const { + TIMED(hipFree, check_hip_status_non_throwing(hipFree(ptr))); + ptr = nullptr; + } +}; + +struct HostAllocator { + void operator()(void **ptr, size_t size) const { + LOG_INFO("Malloc " << size << " bytes on host"); + TIMED(hipHostMalloc, check_hip_status(hipHostMalloc(ptr, size))); + TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); + } +}; + +struct HostFree { + void operator()(void *ptr) const { + TIMED(hipHostFree, check_hip_status_non_throwing(hipHostFree(ptr))); + ptr = nullptr; + } +}; + +using DeviceBuffer = GenericBuffer; +using HostBuffer = GenericBuffer; + +template struct ManagedBuffer_v2 { + + explicit ManagedBuffer_v2(std::vector &&host_data) { + size_in_bytes = host_data.size() * sizeof(T); + hbuff = std::move(host_data); + dbuff = DeviceBuffer(size_in_bytes, 0); + } + + void *data() { return dbuff.tensor_ptr; } + + void update(std::vector &&host_data) { hbuff = std::move(host_data); } + + void upload_to_device(hipStream_t stream, size_t start_idx = 0, + size_t end_idx = 0) { + char *src_addr = reinterpret_cast(hbuff.data()); + char *dst_addr = static_cast(dbuff.tensor_ptr); + size_t copy_size_in_bytes = size_in_bytes; + + size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); + if (range_size_in_bytes > 0) { + size_t offset = start_idx * sizeof(T); + src_addr += offset; + dst_addr += offset; + copy_size_in_bytes = range_size_in_bytes; + } + check_hip_status( + hipMemcpyHtoDAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); + } + + void download_from_device(hipStream_t stream, size_t start_idx = 0, + size_t end_idx = 0) { + char *src_addr = static_cast(dbuff.tensor_ptr); + char *dst_addr = reinterpret_cast(hbuff.data()); + size_t copy_size_in_bytes = size_in_bytes; + + size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); + if (range_size_in_bytes > 0) { + size_t offset = start_idx * sizeof(T); + src_addr += offset; + dst_addr += offset; + copy_size_in_bytes = range_size_in_bytes; + } + check_hip_status( + hipMemcpyDtoHAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); + } + + void update_data(T data, size_t position, hipStream_t stream) { + hbuff.at(position) = data; + // TODO: don't copy over the entire buffer just the changed range + // check_hip_status(hipMemcpy(get_device_ptr(), + // get_host_ptr(), dbuff.size_in_bytes, + // hipMemcpyKind::hipMemcpyHostToDevice)); + upload_to_device(stream, position, position + 1); + 
} + + ManagedBuffer_v2() = delete; + ManagedBuffer_v2(const ManagedBuffer_v2 &buf) = delete; + ManagedBuffer_v2 &operator=(const ManagedBuffer_v2 &buf) = delete; + + DeviceBuffer dbuff; + std::vector hbuff; + size_t size_in_bytes; +}; + +using LLama2InputBuffer = ManagedBuffer_v2; +using LLama2OutputBuffer = ManagedBuffer_v2; +using LLama2PastKeyValueBuffer = ManagedBuffer_v2; +using ArgMaxOutputBuffer = ManagedBuffer_v2; +} // namespace mlinfer diff --git a/examples/transformers/mgx_llama2/harness/common.hpp b/examples/transformers/mgx_llama2/harness/common.hpp index dbf65af459c..1df0159f0b0 100644 --- a/examples/transformers/mgx_llama2/harness/common.hpp +++ b/examples/transformers/mgx_llama2/harness/common.hpp @@ -6,10 +6,10 @@ #include "timer.hpp" #include -#include +#include #include #include -#include +#include #define TIMER_ON 0 #define TRACE_ON 0 @@ -19,85 +19,75 @@ using half = half_float::half; using namespace half_float::literal; -namespace mlinfer -{ - struct INoCopy - { - INoCopy() = default; - virtual ~INoCopy() = default; - INoCopy(const INoCopy &) = delete; - INoCopy &operator=(const INoCopy &) = delete; - }; - - /* Helper function to split a string based on a delimiting character */ - inline std::vector - splitString(const std::string &input, const std::string &delimiter) - { - std::vector result; - size_t start = 0; - size_t next = 0; - while (next != std::string::npos) - { - next = input.find(delimiter, start); - result.emplace_back(input, start, next - start); - start = next + 1; - } - return result; - } - -#define check_hip_status(hip_call) \ - do \ - { \ - int status = (hip_call); \ - if (status != hipSuccess) \ - { \ - throw std::runtime_error("hip error (" + std::to_string(status) + "): " + std::string(hipGetErrorString(static_cast(status)))); \ - } \ - } while (0); - -#define check_hip_status_non_throwing(hip_call) \ - do \ - { \ - int status = (hip_call); \ - if (status != hipSuccess) \ - { \ - LOG_INFO("hip error (" + std::to_string(status) + "): " + std::string(hipGetErrorString(static_cast(status)))); \ - } \ - } while (0); - - -#define CHECK(condition, error) \ - do \ - { \ - if (!(condition)) \ - { \ - std::cerr << error << std::endl; \ - } \ - } while (0); +namespace mlinfer { +struct INoCopy { + INoCopy() = default; + virtual ~INoCopy() = default; + INoCopy(const INoCopy &) = delete; + INoCopy &operator=(const INoCopy &) = delete; +}; + +/* Helper function to split a string based on a delimiting character */ +inline std::vector splitString(const std::string &input, + const std::string &delimiter) { + std::vector result; + size_t start = 0; + size_t next = 0; + while (next != std::string::npos) { + next = input.find(delimiter, start); + result.emplace_back(input, start, next - start); + start = next + 1; + } + return result; +} + +#define check_hip_status(hip_call) \ + do { \ + int status = (hip_call); \ + if (status != hipSuccess) { \ + throw std::runtime_error( \ + "hip error (" + std::to_string(status) + "): " + \ + std::string(hipGetErrorString(static_cast(status)))); \ + } \ + } while (0); + +#define check_hip_status_non_throwing(hip_call) \ + do { \ + int status = (hip_call); \ + if (status != hipSuccess) { \ + LOG_INFO( \ + "hip error (" + std::to_string(status) + "): " + \ + std::string(hipGetErrorString(static_cast(status)))); \ + } \ + } while (0); + +#define CHECK(condition, error) \ + do { \ + if (!(condition)) { \ + std::cerr << error << std::endl; \ + } \ + } while (0); #if TIMER_ON -#define TIMER_STARTV(s) \ - static Timer timer##s(#s, true); 
\ - auto start##s = std::chrono::high_resolution_clock::now(); -#define TIMER_START(s) \ - static Timer timer##s(#s); \ - auto start##s = std::chrono::high_resolution_clock::now(); -#define TIMER_END(s) timer##s.add(std::chrono::high_resolution_clock::now() - start##s); +#define TIMER_STARTV(s) \ + static Timer timer##s(#s, true); \ + auto start##s = std::chrono::high_resolution_clock::now(); +#define TIMER_START(s) \ + static Timer timer##s(#s); \ + auto start##s = std::chrono::high_resolution_clock::now(); +#define TIMER_END(s) \ + timer##s.add(std::chrono::high_resolution_clock::now() - start##s); #else #define TIMER_START(s) #define TIMER_STARTV(s) #define TIMER_END(s) #endif -#define TIMED(s, call) \ - do \ - { \ - TIMER_START(s); \ - { \ - call; \ - } \ - TIMER_END(s); \ - } while (0); +#define TIMED(s, call) \ + do { \ + TIMER_START(s); \ + { call; } \ + TIMER_END(s); \ + } while (0); } // namespace mlinfer - diff --git a/examples/transformers/mgx_llama2/harness/config.hpp b/examples/transformers/mgx_llama2/harness/config.hpp index 693f3c90f18..883ac47df7c 100644 --- a/examples/transformers/mgx_llama2/harness/config.hpp +++ b/examples/transformers/mgx_llama2/harness/config.hpp @@ -19,4 +19,5 @@ const int DEVICE_ID = 4; const size_t HIDDEN_LAYERS_NUM = 32; const size_t HEAD_SIZE = 128; -const size_t PAST_KEY_VAL_SIZE = BATCH_SIZE*HIDDEN_LAYERS_NUM*HEAD_SIZE*SEQ_SIZE; \ No newline at end of file +const size_t PAST_KEY_VAL_SIZE = + BATCH_SIZE * HIDDEN_LAYERS_NUM * HEAD_SIZE * SEQ_SIZE; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/dataset.hpp b/examples/transformers/mgx_llama2/harness/dataset.hpp index 43e39350e56..5a6a739f953 100644 --- a/examples/transformers/mgx_llama2/harness/dataset.hpp +++ b/examples/transformers/mgx_llama2/harness/dataset.hpp @@ -3,183 +3,164 @@ #include "config.hpp" #include "numpy.hpp" -#include #include +#include using namespace mlinfer; using NumpyVector = std::vector>; -struct Dataset -{ - Dataset() = default; - - void initialize() - { - loadDataset(); - if (!_npy_files_loaded) - { - prepareSampleDataset(); - } - } - - NumpyVector loadNumpy(npy::NpyFile& file) - { - NumpyVector numpyDataAll; - auto load_size = file.GetTensorSize()/sizeof(int64_t); - numpyDataAll.push_back(std::vector(load_size)); - file.LoadAll(numpyDataAll.back().data()); - - NumpyVector numpyData; - for(size_t i = 0; i < numpyDataAll.back().size(); i += SEQ_SIZE) - { - auto last = std::min(numpyDataAll.back().size(), i + SEQ_SIZE); - numpyData.emplace_back(numpyDataAll.back().begin() + i, numpyDataAll.back().begin() + last); - } +struct Dataset { + Dataset() = default; -#ifdef TRACE - for (auto& vec: numpyData) - { - std::cout << "Vector size: " << vec.size() << std::endl; - for (auto val: vec) - { - std::cout << val << " "; - } - std::cout << "\n"; - } -#endif - return numpyData; + void initialize() { + loadDataset(); + if (!_npy_files_loaded) { + prepareSampleDataset(); } - - size_t getLastIdx(int current_batch_idx) const - { - auto idx = _current_batch * BATCH_SIZE + current_batch_idx; - auto res = std::find_if(std::rbegin(attention_mask[idx]), std::rend(attention_mask[idx]), [](uint64_t val) { return 1 == val;}); - size_t last_idx = std::distance(res, std::rend(attention_mask[idx])); - #ifdef TRACE - std::cout << "Last input idx: " << last_idx << std::endl; - #endif - return last_idx; + } + + NumpyVector loadNumpy(npy::NpyFile &file) { + NumpyVector numpyDataAll; + auto load_size = file.GetTensorSize() / sizeof(int64_t); + 
numpyDataAll.push_back(std::vector(load_size)); + file.LoadAll(numpyDataAll.back().data()); + + NumpyVector numpyData; + for (size_t i = 0; i < numpyDataAll.back().size(); i += SEQ_SIZE) { + auto last = std::min(numpyDataAll.back().size(), i + SEQ_SIZE); + numpyData.emplace_back(numpyDataAll.back().begin() + i, + numpyDataAll.back().begin() + last); } - std::vector getInputIds() - { - return getBatchedBuffer(input_ids); +#ifdef TRACE + for (auto &vec : numpyData) { + std::cout << "Vector size: " << vec.size() << std::endl; + for (auto val : vec) { + std::cout << val << " "; + } + std::cout << "\n"; } - - std::vector getAttentionMask() - { - return getBatchedBuffer(attention_mask, 1); +#endif + return numpyData; + } + + size_t getLastIdx(int current_batch_idx) const { + auto idx = _current_batch * BATCH_SIZE + current_batch_idx; + auto res = std::find_if(std::rbegin(attention_mask[idx]), + std::rend(attention_mask[idx]), + [](uint64_t val) { return 1 == val; }); + size_t last_idx = std::distance(res, std::rend(attention_mask[idx])); +#ifdef TRACE + std::cout << "Last input idx: " << last_idx << std::endl; +#endif + return last_idx; + } + + std::vector getInputIds() { return getBatchedBuffer(input_ids); } + + std::vector getAttentionMask() { + return getBatchedBuffer(attention_mask, 1); + } + + std::vector getBatchedBuffer(NumpyVector &buffer, size_t value = 0) { + std::vector batchedBuffer; + const size_t buffer_size = SEQ_SIZE * BATCH_SIZE; + batchedBuffer.reserve(buffer_size); + auto batchSize = BATCH_SIZE; + if (_current_batch == batchNum() - 1) { + batchSize = _size % BATCH_SIZE; } - std::vector getBatchedBuffer(NumpyVector& buffer, size_t value = 0) - { - std::vector batchedBuffer; - const size_t buffer_size = SEQ_SIZE*BATCH_SIZE; - batchedBuffer.reserve(buffer_size); - auto batchSize = BATCH_SIZE; - if (_current_batch == batchNum() -1) - { - batchSize = _size % BATCH_SIZE; - } - - for (size_t i = 0; i < batchSize; ++i) - { - auto buffVec = buffer[BATCH_SIZE*_current_batch + i]; - std::copy(buffVec.begin(), buffVec.end(), std::back_inserter(batchedBuffer)); - } - - if (batchSize != BATCH_SIZE) - { - // For last batch, setting buffer values if no sample is available - batchedBuffer.resize(buffer_size, value); - } - return batchedBuffer; + for (size_t i = 0; i < batchSize; ++i) { + auto buffVec = buffer[BATCH_SIZE * _current_batch + i]; + std::copy(buffVec.begin(), buffVec.end(), + std::back_inserter(batchedBuffer)); } - size_t size() const { return _size; } - size_t currentBatchIdx() const { return _current_batch; } - size_t batchNum() const { - return _size / BATCH_SIZE + (_size % BATCH_SIZE != 0); - } - size_t getNext() - { - if (_current_batch < batchNum() - 1) - { - ++_current_batch; - } - #ifdef TRACE - std::cout << "Current batch: " << _current_batch << std::endl; - #endif - return _current_batch; + if (batchSize != BATCH_SIZE) { + // For last batch, setting buffer values if no sample is available + batchedBuffer.resize(buffer_size, value); } - - Dataset(const Dataset &buf) = delete; - Dataset &operator=(const Dataset &buf) = delete; -private: - - // e.g.: /dataset/input_ids_size_3_seq_256.npy - std::string getDatasetPath(const std::string& datasetName) - { - std::stringstream path; - path << DATASET_FOLDER << datasetName << "_size_" << std::to_string(DATASET_SIZE) << "_seq_" << std::to_string(SEQ_SIZE) << ".npy"; - return path.str(); + return batchedBuffer; + } + + size_t size() const { return _size; } + size_t currentBatchIdx() const { return _current_batch; } + size_t 
batchNum() const { + return _size / BATCH_SIZE + (_size % BATCH_SIZE != 0); + } + size_t getNext() { + if (_current_batch < batchNum() - 1) { + ++_current_batch; } +#ifdef TRACE + std::cout << "Current batch: " << _current_batch << std::endl; +#endif + return _current_batch; + } - void loadDataset() - { - std::string input_file_path = getDatasetPath("input_ids"); - std::string attention_mask_file_path = getDatasetPath("attention_mask"); - - std::cout << "Input ids file: " << input_file_path << std::endl; - std::ifstream input_file(input_file_path.c_str()); - std::ifstream attention_mask_file(attention_mask_file_path.c_str()); - if (input_file.good() && attention_mask_file.good()) - { - npy::NpyFile input_ids_npy{input_file_path}; - npy::NpyFile attention_mask_npy{attention_mask_file_path}; - input_ids = loadNumpy(input_ids_npy); - attention_mask = loadNumpy(attention_mask_npy); - - _size = input_ids.size(); - - if (input_ids.size() == attention_mask.size()) - { - std::cout << "Loaded numpy files\n"; - _npy_files_loaded = true; - } - else - { - std::cout << "Numpy files do not have the same size\n"; - input_ids.clear(); - attention_mask.clear(); - } - } - else - { - std::cout << "Unable to open numpy files\n"; - } - } + Dataset(const Dataset &buf) = delete; + Dataset &operator=(const Dataset &buf) = delete; - void prepareSampleDataset() - { - std::cout << "Numpy files are not loaded, using dummy data\n"; - std::vector input_ids_sample = {1,6804,338,5207,387,287,29973}; - input_ids_sample.resize(SEQ_SIZE, 0); - std::vector attention_mask_sample = input_ids_sample; - input_ids.emplace_back(std::move(input_ids_sample)); - std::transform(std::begin(attention_mask_sample), std::end(attention_mask_sample), std::begin(attention_mask_sample), [](auto i){ - return (i != 0) ? 
1 : 0; - }); - attention_mask.emplace_back(std::move(attention_mask_sample)); - - _size = 1; +private: + // e.g.: /dataset/input_ids_size_3_seq_256.npy + std::string getDatasetPath(const std::string &datasetName) { + std::stringstream path; + path << DATASET_FOLDER << datasetName << "_size_" + << std::to_string(DATASET_SIZE) << "_seq_" << std::to_string(SEQ_SIZE) + << ".npy"; + return path.str(); + } + + void loadDataset() { + std::string input_file_path = getDatasetPath("input_ids"); + std::string attention_mask_file_path = getDatasetPath("attention_mask"); + + std::cout << "Input ids file: " << input_file_path << std::endl; + std::ifstream input_file(input_file_path.c_str()); + std::ifstream attention_mask_file(attention_mask_file_path.c_str()); + if (input_file.good() && attention_mask_file.good()) { + npy::NpyFile input_ids_npy{input_file_path}; + npy::NpyFile attention_mask_npy{attention_mask_file_path}; + input_ids = loadNumpy(input_ids_npy); + attention_mask = loadNumpy(attention_mask_npy); + + _size = input_ids.size(); + + if (input_ids.size() == attention_mask.size()) { + std::cout << "Loaded numpy files\n"; + _npy_files_loaded = true; + } else { + std::cout << "Numpy files do not have the same size\n"; + input_ids.clear(); + attention_mask.clear(); + } + } else { + std::cout << "Unable to open numpy files\n"; } - - NumpyVector input_ids; - NumpyVector attention_mask; - - size_t _size = 0; - size_t _current_batch = 0; - bool _npy_files_loaded = false; + } + + void prepareSampleDataset() { + std::cout << "Numpy files are not loaded, using dummy data\n"; + std::vector input_ids_sample = {1, 6804, 338, 5207, + 387, 287, 29973}; + input_ids_sample.resize(SEQ_SIZE, 0); + std::vector attention_mask_sample = input_ids_sample; + input_ids.emplace_back(std::move(input_ids_sample)); + std::transform(std::begin(attention_mask_sample), + std::end(attention_mask_sample), + std::begin(attention_mask_sample), + [](auto i) { return (i != 0) ? 
1 : 0; }); + attention_mask.emplace_back(std::move(attention_mask_sample)); + + _size = 1; + } + + NumpyVector input_ids; + NumpyVector attention_mask; + + size_t _size = 0; + size_t _current_batch = 0; + bool _npy_files_loaded = false; }; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp index 8c8fc1eda75..29605d78ee2 100644 --- a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp @@ -3,121 +3,131 @@ #include "config.hpp" #include "utils.hpp" -struct LLama2Inputs -{ - LLama2Inputs( - migraphx::program& prog, - migraphx::program_parameters& prog_args) - { - data.initialize(); - prepareProgArgs(prog, prog_args); +struct LLama2Inputs { + LLama2Inputs(migraphx::program &prog, + migraphx::program_parameters &prog_args) { + data.initialize(); + prepareProgArgs(prog, prog_args); + } + + void prepareProgArgs(migraphx::program &prog, + migraphx::program_parameters &prog_args, + bool simple = false) { + auto param_shapes = prog.get_parameter_shapes(); + if (!simple) { + auto inputShape = param_shapes[INPUTS_ID_STR]; + auto input_ids = data.getInputIds(); + input_ids_buffer = + std::make_unique(std::move(input_ids)); + prog_args.add(INPUTS_ID_STR, + migraphx::argument(inputShape, input_ids_buffer->data())); } - void prepareProgArgs(migraphx::program& prog, migraphx::program_parameters& prog_args, bool simple = false) - { - auto param_shapes = prog.get_parameter_shapes(); - if (!simple) - { - auto inputShape = param_shapes[INPUTS_ID_STR]; - auto input_ids = data.getInputIds(); - input_ids_buffer = std::make_unique(std::move(input_ids)); - prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); - } - - - auto attShape = param_shapes[ATTENTION_MASK_STR]; - auto attention_mask = data.getAttentionMask(); - if (!simple) - { - attention_mask_buffer = std::make_unique(std::move(attention_mask)); - } - prog_args.add(ATTENTION_MASK_STR, migraphx::argument(attShape, attention_mask_buffer->data())); - - // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} - // past_key_values.0.value = @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, 128, 1} - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - auto past_keyStr = getPastKeyString(i); - auto past_keyString = past_keyStr.c_str(); - if (!simple) - { - past_key_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); - } - auto pastKeyShape = param_shapes[past_keyString]; - prog_args.add(past_keyString, migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); - - auto past_valueStr = getPastValueStr(i); - auto past_valueString = past_valueStr.c_str(); - if (!simple) - { - past_value_buffers.emplace_back(std::make_unique(std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); - } - auto pastValueShape = param_shapes[past_valueString]; - prog_args.add(past_valueString, migraphx::argument(pastValueShape, past_value_buffers[i]->data())); - } + auto attShape = param_shapes[ATTENTION_MASK_STR]; + auto attention_mask = data.getAttentionMask(); + if (!simple) { + attention_mask_buffer = + std::make_unique(std::move(attention_mask)); } - - void prepareOneDimProgArgs(migraphx::program& prog, migraphx::program_parameters& prog_args) - { - prepareProgArgs(prog, prog_args, true); - auto param_shapes = prog.get_parameter_shapes(); - auto inputShape = param_shapes[INPUTS_ID_STR]; 
- std::vector oneDimInput(BATCH_SIZE, 0); - one_dim_input_buffer = std::make_unique(std::move(oneDimInput)); - prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, one_dim_input_buffer->data())); - } - - void upload_to_device(hipStream_t stream) - { - input_ids_buffer->upload_to_device(stream); - attention_mask_buffer->upload_to_device(stream); + prog_args.add(ATTENTION_MASK_STR, + migraphx::argument(attShape, attention_mask_buffer->data())); + + // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, + // 32, 1, 128}, {4096, 128, 128, 1} past_key_values.0.value = + // @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, + // 128, 1} + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { + auto past_keyStr = getPastKeyString(i); + auto past_keyString = past_keyStr.c_str(); + if (!simple) { + past_key_buffers.emplace_back( + std::make_unique( + std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); + } + auto pastKeyShape = param_shapes[past_keyString]; + prog_args.add( + past_keyString, + migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); + + auto past_valueStr = getPastValueStr(i); + auto past_valueString = past_valueStr.c_str(); + if (!simple) { + past_value_buffers.emplace_back( + std::make_unique( + std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); + } + auto pastValueShape = param_shapes[past_valueString]; + prog_args.add( + past_valueString, + migraphx::argument(pastValueShape, past_value_buffers[i]->data())); } - - bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) - { - auto batchIdx = data.currentBatchIdx(); - if (batchIdx != data.getNext()) - { - auto param_shapes = prog.get_parameter_shapes(); - - std::vector input_ids = data.getInputIds(); - input_ids_buffer = std::make_unique(std::move(input_ids)); - prog_args.add(INPUTS_ID_STR, migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); - - auto attention_mask = data.getAttentionMask(); - attention_mask_buffer->update(std::move(attention_mask)); - - return true; - } - return false; + } + + void prepareOneDimProgArgs(migraphx::program &prog, + migraphx::program_parameters &prog_args) { + prepareProgArgs(prog, prog_args, true); + auto param_shapes = prog.get_parameter_shapes(); + auto inputShape = param_shapes[INPUTS_ID_STR]; + std::vector oneDimInput(BATCH_SIZE, 0); + one_dim_input_buffer = + std::make_unique(std::move(oneDimInput)); + prog_args.add(INPUTS_ID_STR, + migraphx::argument(inputShape, one_dim_input_buffer->data())); + } + + void upload_to_device(hipStream_t stream) { + input_ids_buffer->upload_to_device(stream); + attention_mask_buffer->upload_to_device(stream); + } + + bool updateData(migraphx::program &prog, + migraphx::program_parameters &prog_args) { + auto batchIdx = data.currentBatchIdx(); + if (batchIdx != data.getNext()) { + auto param_shapes = prog.get_parameter_shapes(); + + std::vector input_ids = data.getInputIds(); + input_ids_buffer = + std::make_unique(std::move(input_ids)); + prog_args.add(INPUTS_ID_STR, + migraphx::argument(param_shapes[INPUTS_ID_STR], + input_ids_buffer->data())); + + auto attention_mask = data.getAttentionMask(); + attention_mask_buffer->update(std::move(attention_mask)); + + return true; } - - void resetPastKeyValueBuffers(hipStream_t stream) - { - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); - past_value_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); - past_key_buffers[i]->upload_to_device(stream); - 
past_value_buffers[i]->upload_to_device(stream); - } + return false; + } + + void resetPastKeyValueBuffers(hipStream_t stream) { + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { + past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + past_value_buffers[i]->update( + std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + past_key_buffers[i]->upload_to_device(stream); + past_value_buffers[i]->upload_to_device(stream); } - - size_t getLastInputIndex(int current_batch_idx) const { return data.getLastIdx(current_batch_idx); } - size_t dataSize() const { return data.size(); } - size_t batchNum() const { return data.batchNum(); } - - LLama2Inputs() = delete; - LLama2Inputs(const LLama2Inputs &buf) = delete; - LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; - - std::unique_ptr input_ids_buffer; - std::unique_ptr one_dim_input_buffer; - std::unique_ptr attention_mask_buffer; - std::vector> past_key_buffers; - std::vector> past_value_buffers; - Dataset data; - - const char* INPUTS_ID_STR = "input_ids"; - const char* ATTENTION_MASK_STR = "attention_mask"; + } + + size_t getLastInputIndex(int current_batch_idx) const { + return data.getLastIdx(current_batch_idx); + } + size_t dataSize() const { return data.size(); } + size_t batchNum() const { return data.batchNum(); } + + LLama2Inputs() = delete; + LLama2Inputs(const LLama2Inputs &buf) = delete; + LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; + + std::unique_ptr input_ids_buffer; + std::unique_ptr one_dim_input_buffer; + std::unique_ptr attention_mask_buffer; + std::vector> past_key_buffers; + std::vector> past_value_buffers; + Dataset data; + + const char *INPUTS_ID_STR = "input_ids"; + const char *ATTENTION_MASK_STR = "attention_mask"; }; diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp index 87bc0da80a3..ed2cab2566f 100644 --- a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp @@ -5,47 +5,65 @@ #include -struct LLama2Outputs -{ - LLama2Outputs() - { - } - - void prepareProgArgs(migraphx::program_parameters& prog_args, migraphx::program_parameters& prog_args_one_dim) - { - output_buffer = std::make_unique(std::vector(OUTPUT_SIZE)); - one_dim_output_buffer = std::make_unique(std::vector(BATCH_SIZE*VOCAB_SIZE)); - migraphx::shape out_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; - prog_args.add(OUTPUT_NAME, migraphx::argument(out_shape, output_buffer->data())); - - migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; - prog_args_one_dim.add(OUTPUT_NAME, migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); - } - - void prepareProgArgsArgMax(migraphx::program_parameters& prog_args_argmax, migraphx::program_parameters& prog_args_argmax_one_dim) - { - // setting up argmax arguments - migraphx::shape x_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; - prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer->data())); - argm_output_buffer = std::make_unique(std::vector(BATCH_SIZE*SEQ_SIZE)); - migraphx::shape argm_out_shape{migraphx_shape_int64_type, {BATCH_SIZE, SEQ_SIZE, 1}}; - prog_args_argmax.add(OUTPUT_NAME, migraphx::argument(argm_out_shape, argm_output_buffer->data())); - - migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; - prog_args_argmax_one_dim.add("x", migraphx::argument(x_shape_one_dim, 
one_dim_output_buffer->data())); - argm_output_buffer_one_dim = std::make_unique(std::vector(BATCH_SIZE)); - migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {BATCH_SIZE, 1, 1}}; - prog_args_argmax_one_dim.add(OUTPUT_NAME, migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim->data())); - } - - LLama2Outputs(const LLama2Outputs &buf) = delete; - LLama2Outputs &operator=(const LLama2Outputs &buf) = delete; - - std::unique_ptr output_buffer; - std::unique_ptr one_dim_output_buffer; - std::unique_ptr argm_output_buffer; - std::unique_ptr argm_output_buffer_one_dim; - - const char* OUTPUT_NAME = "main:#output_0"; - const size_t OUTPUT_SIZE = BATCH_SIZE * SEQ_SIZE * VOCAB_SIZE; +struct LLama2Outputs { + LLama2Outputs() {} + + void prepareProgArgs(migraphx::program_parameters &prog_args, + migraphx::program_parameters &prog_args_one_dim) { + output_buffer = std::make_unique( + std::vector(OUTPUT_SIZE)); + one_dim_output_buffer = std::make_unique( + std::vector(BATCH_SIZE * VOCAB_SIZE)); + migraphx::shape out_shape{migraphx_shape_half_type, + {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; + prog_args.add(OUTPUT_NAME, + migraphx::argument(out_shape, output_buffer->data())); + + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, + {BATCH_SIZE, 1, VOCAB_SIZE}}; + prog_args_one_dim.add( + OUTPUT_NAME, + migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); + } + + void prepareProgArgsArgMax( + migraphx::program_parameters &prog_args_argmax, + migraphx::program_parameters &prog_args_argmax_one_dim) { + // setting up argmax arguments + migraphx::shape x_shape{migraphx_shape_half_type, + {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; + prog_args_argmax.add("x", + migraphx::argument(x_shape, output_buffer->data())); + argm_output_buffer = std::make_unique( + std::vector(BATCH_SIZE * SEQ_SIZE)); + migraphx::shape argm_out_shape{migraphx_shape_int64_type, + {BATCH_SIZE, SEQ_SIZE, 1}}; + prog_args_argmax.add( + OUTPUT_NAME, + migraphx::argument(argm_out_shape, argm_output_buffer->data())); + + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, + {BATCH_SIZE, 1, VOCAB_SIZE}}; + prog_args_argmax_one_dim.add( + "x", + migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); + argm_output_buffer_one_dim = + std::make_unique(std::vector(BATCH_SIZE)); + migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, + {BATCH_SIZE, 1, 1}}; + prog_args_argmax_one_dim.add( + OUTPUT_NAME, migraphx::argument(argm_out_shape_one_dim, + argm_output_buffer_one_dim->data())); + } + + LLama2Outputs(const LLama2Outputs &buf) = delete; + LLama2Outputs &operator=(const LLama2Outputs &buf) = delete; + + std::unique_ptr output_buffer; + std::unique_ptr one_dim_output_buffer; + std::unique_ptr argm_output_buffer; + std::unique_ptr argm_output_buffer_one_dim; + + const char *OUTPUT_NAME = "main:#output_0"; + const size_t OUTPUT_SIZE = BATCH_SIZE * SEQ_SIZE * VOCAB_SIZE; }; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/logging.hpp b/examples/transformers/mgx_llama2/harness/logging.hpp index a25e73498a2..2dee1ec530a 100644 --- a/examples/transformers/mgx_llama2/harness/logging.hpp +++ b/examples/transformers/mgx_llama2/harness/logging.hpp @@ -2,31 +2,29 @@ #include -namespace mlinfer -{ +namespace mlinfer { #define LOGGING_OFF 0 #define ENABLE_TIMED_LOGGING 0 #define ENABLE_DEBUG_LOGGING 0 #if (!LOGGING_OFF) -#define LOG_INFO(...) \ - do \ - { \ - std::cout << __VA_ARGS__ << std::endl; \ - } while (0) -#define LOG_ERROR(...) 
\ - do \ - { \ - std::cerr << __VA_ARGS__ << std::endl; \ - } while (0) -#define LOG_STATE(...) \ - do \ - { \ - std::cout << "================================================" << std::endl; \ - std::cout << __VA_ARGS__ << std::endl; \ - std::cout << "================================================" << std::endl; \ - } while (0) +#define LOG_INFO(...) \ + do { \ + std::cout << __VA_ARGS__ << std::endl; \ + } while (0) +#define LOG_ERROR(...) \ + do { \ + std::cerr << __VA_ARGS__ << std::endl; \ + } while (0) +#define LOG_STATE(...) \ + do { \ + std::cout << "================================================" \ + << std::endl; \ + std::cout << __VA_ARGS__ << std::endl; \ + std::cout << "================================================" \ + << std::endl; \ + } while (0) #else #define LOG_INFO(...) (void)0 #define LOG_ERROR(...) (void)0 @@ -46,4 +44,3 @@ namespace mlinfer #endif } // namespace mlinfer - diff --git a/examples/transformers/mgx_llama2/harness/numa.hpp b/examples/transformers/mgx_llama2/harness/numa.hpp index 5c0ee561efe..13b46e84d67 100644 --- a/examples/transformers/mgx_llama2/harness/numa.hpp +++ b/examples/transformers/mgx_llama2/harness/numa.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -8,151 +9,128 @@ #include #include #include -#include #include "common.hpp" -namespace mlinfer -{ - // NUMA config. Each NUMA node contains a pair of GPU indices and CPU indices. - using NumaConfig = std::vector, std::vector>>; - - // The NUMA node idx for each GPU. - using GpuToNumaMap = std::vector; - - struct NumaSettings - { - NumaConfig numa_config; - GpuToNumaMap gpu_to_numa_map; - }; - - struct Numa final - { - NumaSettings numa_settings; - - explicit Numa(const NumaSettings &numa_settings) : numa_settings{numa_settings} {} - - inline bool UseNuma() const - { - return not numa_settings.numa_config.empty(); - } - - inline size_t GetNumaCount() const - { - return numa_settings.numa_config.size(); - }; - - inline int GetNumaIdx(const int deviceId) const - { - return UseNuma() ? numa_settings.gpu_to_numa_map.at(deviceId) : 0; - } - - inline std::vector GetClosestCpus(const int deviceId) const - { - assertm(UseNuma(), "GetClosestCpus only available for NUMA"); - return numa_settings.numa_config.at(GetNumaIdx(deviceId)).second; - } - }; - - // Restrict mem allocation to specific NUMA node. - inline void - bindNumaMemPolicy(const int32_t numaIdx, const int32_t nbNumas) - { - unsigned long nodeMask = 1UL << numaIdx; - long ret = set_mempolicy(MPOL_BIND, &nodeMask, nbNumas + 1); - CHECK(ret >= 0, std::strerror(errno)); - } - - // Reset mem allocation setting. - inline void resetNumaMemPolicy() - { - long ret = set_mempolicy(MPOL_DEFAULT, nullptr, 0); - CHECK(ret >= 0, std::strerror(errno)); +namespace mlinfer { +// NUMA config. Each NUMA node contains a pair of GPU indices and CPU indices. +using NumaConfig = + std::vector, std::vector>>; + +// The NUMA node idx for each GPU. +using GpuToNumaMap = std::vector; + +struct NumaSettings { + NumaConfig numa_config; + GpuToNumaMap gpu_to_numa_map; +}; + +struct Numa final { + NumaSettings numa_settings; + + explicit Numa(const NumaSettings &numa_settings) + : numa_settings{numa_settings} {} + + inline bool UseNuma() const { return not numa_settings.numa_config.empty(); } + + inline size_t GetNumaCount() const { + return numa_settings.numa_config.size(); + }; + + inline int GetNumaIdx(const int deviceId) const { + return UseNuma() ? 
numa_settings.gpu_to_numa_map.at(deviceId) : 0; + } + + inline std::vector GetClosestCpus(const int deviceId) const { + assertm(UseNuma(), "GetClosestCpus only available for NUMA"); + return numa_settings.numa_config.at(GetNumaIdx(deviceId)).second; + } +}; + +// Restrict mem allocation to specific NUMA node. +inline void bindNumaMemPolicy(const int32_t numaIdx, const int32_t nbNumas) { + unsigned long nodeMask = 1UL << numaIdx; + long ret = set_mempolicy(MPOL_BIND, &nodeMask, nbNumas + 1); + CHECK(ret >= 0, std::strerror(errno)); +} + +// Reset mem allocation setting. +inline void resetNumaMemPolicy() { + long ret = set_mempolicy(MPOL_DEFAULT, nullptr, 0); + CHECK(ret >= 0, std::strerror(errno)); +} + +// Limit a thread to be on specific cpus. +inline void bindThreadToCpus(std::thread &th, const std::vector &cpus, + const bool ignore_esrch = false) { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + for (int cpu : cpus) { + CPU_SET(cpu, &cpuset); + } + int ret = + pthread_setaffinity_np(th.native_handle(), sizeof(cpu_set_t), &cpuset); + bool noerr = ignore_esrch ? ret == 0 || ret == ESRCH : ret == 0; + CHECK(noerr, std::strerror(ret)); +} + +// Helper to converts the range string (like "0,2-5,13-17") to a vector of ints. +inline std::vector parseRange(const std::string &s) { + std::vector results; + auto ranges = splitString(s, ","); + for (const auto &range : ranges) { + auto startEnd = splitString(range, "-"); + CHECK((startEnd.size() <= 2), + "Invalid numa_config setting. Expects zero or one '-'."); + if (startEnd.size() == 1) { + results.push_back(std::stoi(startEnd[0])); + } else { + size_t start = std::stoi(startEnd[0]); + size_t last = std::stoi(startEnd[1]); + for (size_t i = start; i <= last; ++i) { + results.push_back(i); + } } - - // Limit a thread to be on specific cpus. - inline void bindThreadToCpus(std::thread &th, const std::vector &cpus, const bool ignore_esrch = false) - { - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - for (int cpu : cpus) - { - CPU_SET(cpu, &cpuset); - } - int ret = pthread_setaffinity_np(th.native_handle(), sizeof(cpu_set_t), &cpuset); - bool noerr = ignore_esrch ? ret == 0 || ret == ESRCH : ret == 0; - CHECK(noerr, std::strerror(ret)); + } + return results; +} + +// Example of the format: "0,2:0-63&1,3:64-127" for 4 GPUs, 128 CPU, 2 NUMA node +// system. +inline NumaConfig parseNumaConfig(const std::string &numa_file) { + std::string numa_str; + std::ifstream file(numa_file.c_str()); + if (file.is_open()) { + getline(file, numa_str); + file.close(); + } + + NumaConfig config; + if (!numa_str.empty()) { + auto nodes = splitString(numa_str, "&"); + for (const auto &node : nodes) { + auto pair = splitString(node, ":"); + CHECK((pair.size() == 2), + "Invalid numa_config setting. Expects one ':'."); + auto gpus = parseRange(pair[0]); + auto cpus = parseRange(pair[1]); + config.emplace_back(std::make_pair(gpus, cpus)); } - - // Helper to converts the range string (like "0,2-5,13-17") to a vector of ints. - inline std::vector parseRange(const std::string &s) - { - std::vector results; - auto ranges = splitString(s, ","); - for (const auto &range : ranges) - { - auto startEnd = splitString(range, "-"); - CHECK((startEnd.size() <= 2), "Invalid numa_config setting. 
Expects zero or one '-'."); - if (startEnd.size() == 1) - { - results.push_back(std::stoi(startEnd[0])); - } - else - { - size_t start = std::stoi(startEnd[0]); - size_t last = std::stoi(startEnd[1]); - for (size_t i = start; i <= last; ++i) - { - results.push_back(i); - } - } - } - return results; - } - - // Example of the format: "0,2:0-63&1,3:64-127" for 4 GPUs, 128 CPU, 2 NUMA node system. - inline NumaConfig parseNumaConfig(const std::string &numa_file) - { - std::string numa_str; - std::ifstream file(numa_file.c_str()); - if (file.is_open()) - { - getline(file, numa_str); - file.close(); - } - - NumaConfig config; - if (!numa_str.empty()) - { - auto nodes = splitString(numa_str, "&"); - for (const auto &node : nodes) - { - auto pair = splitString(node, ":"); - CHECK((pair.size() == 2), "Invalid numa_config setting. Expects one ':'."); - auto gpus = parseRange(pair[0]); - auto cpus = parseRange(pair[1]); - config.emplace_back(std::make_pair(gpus, cpus)); - } - } - return config; - } - - // Convert NumaConfig to GpuToNumaMap for easier look-up. - inline GpuToNumaMap getGpuToNumaMap(const NumaConfig &config) - { - std::vector map; - for (size_t numaIdx = 0; numaIdx < config.size(); numaIdx++) - { - for (const auto gpuIdx : config[numaIdx].first) - { - if (gpuIdx >= map.size()) - { - map.resize(gpuIdx + 1); - } - map[gpuIdx] = numaIdx; - } - } - return map; + } + return config; +} + +// Convert NumaConfig to GpuToNumaMap for easier look-up. +inline GpuToNumaMap getGpuToNumaMap(const NumaConfig &config) { + std::vector map; + for (size_t numaIdx = 0; numaIdx < config.size(); numaIdx++) { + for (const auto gpuIdx : config[numaIdx].first) { + if (gpuIdx >= map.size()) { + map.resize(gpuIdx + 1); + } + map[gpuIdx] = numaIdx; } + } + return map; +} } // namespace mlinfer - diff --git a/examples/transformers/mgx_llama2/harness/numpy.hpp b/examples/transformers/mgx_llama2/harness/numpy.hpp index 8cc5a15db67..3aeca0e44f5 100644 --- a/examples/transformers/mgx_llama2/harness/numpy.hpp +++ b/examples/transformers/mgx_llama2/harness/numpy.hpp @@ -11,101 +11,92 @@ #include "common.hpp" #include "logging.hpp" -namespace mlinfer -{ - namespace npy - { - class NpyFile - { - private: - std::string m_Path; - std::ifstream m_FStream; - size_t m_HeaderSize; - std::string m_Header; - size_t m_TensorSize; - size_t m_ElementSize; - std::vector m_TensorDims; +namespace mlinfer { +namespace npy { +class NpyFile { +private: + std::string m_Path; + std::ifstream m_FStream; + size_t m_HeaderSize; + std::string m_Header; + size_t m_TensorSize; + size_t m_ElementSize; + std::vector m_TensorDims; - public: - explicit NpyFile(const std::string &path) - : m_Path(path), m_FStream(m_Path) - { - LOG_INFO("Npy file from " << path); - // magic and fixed header - char b[256]; - m_FStream.read(b, 10); - CHECK(m_FStream, "Unable to parse: " << m_Path); +public: + explicit NpyFile(const std::string &path) : m_Path(path), m_FStream(m_Path) { + LOG_INFO("Npy file from " << path); + // magic and fixed header + char b[256]; + m_FStream.read(b, 10); + CHECK(m_FStream, "Unable to parse: " << m_Path); - // check magic - CHECK(static_cast(b[0]) == 0x93 && b[1] == 'N' && b[2] == 'U' && b[3] == 'M' && b[4] == 'P' && b[5] == 'Y', "Bad magic: " << m_Path); + // check magic + CHECK(static_cast(b[0]) == 0x93 && b[1] == 'N' && + b[2] == 'U' && b[3] == 'M' && b[4] == 'P' && b[5] == 'Y', + "Bad magic: " << m_Path); - // get header - auto major = static_cast(b[6]); - // auto minor = static_cast(b[7]); - CHECK(major == 1, "Only npy version 1 is 
supported: " << m_Path); - m_HeaderSize = static_cast(b[8]); - m_Header.resize(m_HeaderSize); - m_FStream.read(static_cast(m_Header.data()), m_HeaderSize); + // get header + auto major = static_cast(b[6]); + // auto minor = static_cast(b[7]); + CHECK(major == 1, "Only npy version 1 is supported: " << m_Path); + m_HeaderSize = static_cast(b[8]); + m_Header.resize(m_HeaderSize); + m_FStream.read(static_cast(m_Header.data()), m_HeaderSize); - // get file size - auto cur = m_FStream.tellg(); - m_FStream.seekg(0, std::ios::end); - auto size = m_FStream.tellg(); - m_TensorSize = size - cur; + // get file size + auto cur = m_FStream.tellg(); + m_FStream.seekg(0, std::ios::end); + auto size = m_FStream.tellg(); + m_TensorSize = size - cur; - // parse header - std::regex re(R"re(\{'descr': '[<|][fi]([\d])', 'fortran_order': False, 'shape': \(([\d, ]*)\), \} +\n)re"); - std::smatch matches; - CHECK(std::regex_match(m_Header, matches, re), "Cannot parse numpy header: " << m_Path); - CHECK(matches.size() == 3, "Cannot parse numpy header: " << m_Path); - m_ElementSize = std::stoi(matches[1]); - std::vector dims = splitString(matches[2], ", "); - m_TensorDims.resize(dims.size()); - std::transform( - dims.begin(), dims.end(), m_TensorDims.begin(), [](const std::string &s) - { return std::stoi(s); }); + // parse header + std::regex re( + R"re(\{'descr': '[<|][fi]([\d])', 'fortran_order': False, 'shape': \(([\d, ]*)\), \} +\n)re"); + std::smatch matches; + CHECK(std::regex_match(m_Header, matches, re), + "Cannot parse numpy header: " << m_Path); + CHECK(matches.size() == 3, "Cannot parse numpy header: " << m_Path); + m_ElementSize = std::stoi(matches[1]); + std::vector dims = splitString(matches[2], ", "); + m_TensorDims.resize(dims.size()); + std::transform(dims.begin(), dims.end(), m_TensorDims.begin(), + [](const std::string &s) { return std::stoi(s); }); - // check header sanity - size_t tensorSize = std::accumulate(m_TensorDims.begin(), m_TensorDims.end(), m_ElementSize, std::multiplies()); - CHECK(tensorSize == m_TensorSize, "Header description does not match file size: " << m_Path); - LOG_DEBUG(" Input num=" << m_TensorDims[0] << " | Sample size=" << (tensorSize / m_TensorDims[0]) << " | Full size=" << m_TensorSize); - } - ~NpyFile() - { - m_FStream.close(); - }; - std::string GetPath() const - { - return m_Path; - } - std::vector GetDims() const - { - return m_TensorDims; - } - size_t GetTensorSize() const - { - return m_TensorSize; - } - // load the entire tensor - void LoadAll(void *dst) - { - m_FStream.seekg(10 + m_HeaderSize, std::ios::beg); - m_FStream.read(static_cast(dst), m_TensorSize); - CHECK(m_FStream, "Unable to parse: " << m_Path); - CHECK(m_FStream.peek() == EOF, "Did not consume full file: " << m_Path); - } + // check header sanity + size_t tensorSize = + std::accumulate(m_TensorDims.begin(), m_TensorDims.end(), m_ElementSize, + std::multiplies()); + CHECK(tensorSize == m_TensorSize, + "Header description does not match file size: " << m_Path); + LOG_DEBUG(" Input num=" << m_TensorDims[0] << " | Sample size=" + << (tensorSize / m_TensorDims[0]) + << " | Full size=" << m_TensorSize); + } + ~NpyFile() { m_FStream.close(); }; + std::string GetPath() const { return m_Path; } + std::vector GetDims() const { return m_TensorDims; } + size_t GetTensorSize() const { return m_TensorSize; } + // load the entire tensor + void LoadAll(void *dst) { + m_FStream.seekg(10 + m_HeaderSize, std::ios::beg); + m_FStream.read(static_cast(dst), m_TensorSize); + CHECK(m_FStream, "Unable to parse: " << 
m_Path); + CHECK(m_FStream.peek() == EOF, "Did not consume full file: " << m_Path); + } - // load only selected indices from the Tensor, assuming that the first dim is batch dim. - void LoadSamples(void *dst, const std::vector &indices) - { - size_t sampleSize = std::accumulate(m_TensorDims.begin() + 1, m_TensorDims.end(), m_ElementSize, std::multiplies()); - for (size_t i = 0; i < indices.size(); i++) - { - m_FStream.seekg(10 + m_HeaderSize + indices[i] * sampleSize, std::ios::beg); - m_FStream.read(static_cast(dst) + i * sampleSize, sampleSize); - } - } - }; - } // namespace npy + // load only selected indices from the Tensor, assuming that the first dim is + // batch dim. + void LoadSamples(void *dst, const std::vector &indices) { + size_t sampleSize = + std::accumulate(m_TensorDims.begin() + 1, m_TensorDims.end(), + m_ElementSize, std::multiplies()); + for (size_t i = 0; i < indices.size(); i++) { + m_FStream.seekg(10 + m_HeaderSize + indices[i] * sampleSize, + std::ios::beg); + m_FStream.read(static_cast(dst) + i * sampleSize, sampleSize); + } + } +}; +} // namespace npy } // namespace mlinfer - diff --git a/examples/transformers/mgx_llama2/harness/timer.hpp b/examples/transformers/mgx_llama2/harness/timer.hpp index 517cffb148e..cfcc20a6ad1 100644 --- a/examples/transformers/mgx_llama2/harness/timer.hpp +++ b/examples/transformers/mgx_llama2/harness/timer.hpp @@ -2,68 +2,63 @@ #include #include +#include #include #include -#include #include -#include +#include // For debugging the timing of each part -class Timer -{ +class Timer { public: - explicit Timer(const std::string &tag_, bool verbose_ = false) - : tag(tag_), verbose(verbose_) - { - std::cout << "Timer " << tag << " created." << std::endl; - } - void add(const std::chrono::duration &in) - { - std::thread::id id = std::this_thread::get_id(); - count[id] += 1; - total[id] += in; - if (verbose) - measurements[id].emplace_back(in); - } - ~Timer() - { - auto total_accum = std::accumulate( - std::begin(total), - std::end(total), - 0, - [](int64_t value, std::pair> p) - { return value + p.second.count(); }); + explicit Timer(const std::string &tag_, bool verbose_ = false) + : tag(tag_), verbose(verbose_) { + std::cout << "Timer " << tag << " created." << std::endl; + } + void add(const std::chrono::duration &in) { + std::thread::id id = std::this_thread::get_id(); + count[id] += 1; + total[id] += in; + if (verbose) + measurements[id].emplace_back(in); + } + ~Timer() { + auto total_accum = std::accumulate( + std::begin(total), std::end(total), 0, + [](int64_t value, + std::pair> + p) { return value + p.second.count(); }); - auto count_accum = std::accumulate( - std::begin(count), - std::end(count), - 0, - [](size_t value, std::pair p) - { return value + p.second; }); + auto count_accum = + std::accumulate(std::begin(count), std::end(count), 0, + [](size_t value, std::pair p) { + return value + p.second; + }); - std::cout << "Timer " << tag << " reports " << (double)total_accum / count_accum << " ms per call for " << count_accum - << " times." << std::endl; - if (verbose) - { - std::cout << " Measurements=["; - for (const auto &m : measurements) - { - std::cout << " Thread " << m.first << ": {"; - for (const auto &d : m.second) - { - std::cout << d.count() << ","; - } - - std::cout << "},"; - } - std::cout << "]" << std::endl; + std::cout << "Timer " << tag << " reports " + << (double)total_accum / count_accum << " ms per call for " + << count_accum << " times." 
<< std::endl; + if (verbose) { + std::cout << " Measurements=["; + for (const auto &m : measurements) { + std::cout << " Thread " << m.first << ": {"; + for (const auto &d : m.second) { + std::cout << d.count() << ","; } + + std::cout << "},"; + } + std::cout << "]" << std::endl; } + } private: - std::string tag; - bool verbose; - std::unordered_map> total; - std::unordered_map>> measurements; - std::unordered_map count; + std::string tag; + bool verbose; + std::unordered_map> + total; + std::unordered_map>> + measurements; + std::unordered_map count; }; diff --git a/examples/transformers/mgx_llama2/harness/utils.hpp b/examples/transformers/mgx_llama2/harness/utils.hpp index 7d1a9287921..794f0bf8ae3 100644 --- a/examples/transformers/mgx_llama2/harness/utils.hpp +++ b/examples/transformers/mgx_llama2/harness/utils.hpp @@ -2,209 +2,184 @@ #include "config.hpp" +#include #include #include -#include #include - -struct ModelLoadSettings -{ - size_t sequnce_length; - bool quantize_fp16; - bool fast_math; - bool input_one_dim; - size_t batch_size; +struct ModelLoadSettings { + size_t sequnce_length; + bool quantize_fp16; + bool fast_math; + bool input_one_dim; + size_t batch_size; }; -static std::string getModelPath(ModelLoadSettings& s) -{ - std::stringstream path; - path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" << (s.quantize_fp16 ? "16" : "32") << "_"; - path << "batch_" << std::to_string(s.batch_size) << "_"; - if (!s.fast_math) - { - path << "no"; - } - path << "fastmath"; - if (s.input_one_dim) - { - path << "_inputonedim"; - } - path << ".mxr"; - return path.str(); +static std::string getModelPath(ModelLoadSettings &s) { + std::stringstream path; + path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" + << (s.quantize_fp16 ? "16" : "32") << "_"; + path << "batch_" << std::to_string(s.batch_size) << "_"; + if (!s.fast_math) { + path << "no"; + } + path << "fastmath"; + if (s.input_one_dim) { + path << "_inputonedim"; + } + path << ".mxr"; + return path.str(); } -[[maybe_unused]] static std::string getPastKeyString(size_t i) -{ - std::stringstream past_key; - past_key << "past_key_values." << std::to_string(i) << ".key"; - return past_key.str(); +[[maybe_unused]] static std::string getPastKeyString(size_t i) { + std::stringstream past_key; + past_key << "past_key_values." << std::to_string(i) << ".key"; + return past_key.str(); } -[[maybe_unused]] static std::string getPastValueStr(size_t i) -{ - std::stringstream past_val; - past_val << "past_key_values." << std::to_string(i) << ".value"; - return past_val.str(); +[[maybe_unused]] static std::string getPastValueStr(size_t i) { + std::stringstream past_val; + past_val << "past_key_values." << std::to_string(i) << ".value"; + return past_val.str(); } -[[maybe_unused]] static std::string getPresentKeyString(size_t i) -{ - std::stringstream past_key; - past_key << "present." << std::to_string(i) << ".key"; - return past_key.str(); +[[maybe_unused]] static std::string getPresentKeyString(size_t i) { + std::stringstream past_key; + past_key << "present." << std::to_string(i) << ".key"; + return past_key.str(); } -[[maybe_unused]] static std::string getPresentValueStr(size_t i) -{ - std::stringstream past_val; - past_val << "present." << std::to_string(i) << ".value"; - return past_val.str(); +[[maybe_unused]] static std::string getPresentValueStr(size_t i) { + std::stringstream past_val; + past_val << "present." 
<< std::to_string(i) << ".value"; + return past_val.str(); } -static migraphx::program loadOnnx(ModelLoadSettings& settings) -{ - std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); - - #ifdef TRACE - std::cout << "Using model: " << MODEL_FOLDER + ONNX_FILE << std::endl; - #endif - - migraphx::program prog; - std::ifstream f(onnx_path.c_str()); - if (f.good()) - { - migraphx::onnx_options onnx_opts; - std::vector dims = {BATCH_SIZE, SEQ_SIZE}; - std::vector dimsPastKey = {BATCH_SIZE, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; - std::vector inputDim; - if (settings.input_one_dim) - { - inputDim = {BATCH_SIZE,1}; - } - else - { - inputDim = dims; - } - onnx_opts.set_input_parameter_shape("input_ids", inputDim); - onnx_opts.set_input_parameter_shape("attention_mask", dims); - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) - { - onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); - onnx_opts.set_input_parameter_shape(getPastValueStr(i), dimsPastKey); - } - std::cout << "Parsing onnx file ..." << std::endl; - prog = parse_onnx(onnx_path.c_str(), onnx_opts); - - std::string target_str = "gpu"; - migraphx::target targ = migraphx::target(target_str.c_str()); - - if (settings.quantize_fp16) - { - std::cout << "Quantize FP16 ..." << std::endl; - migraphx::quantize_fp16(prog); - } - - migraphx::compile_options comp_opts; - - if (settings.fast_math) - comp_opts.set_fast_math(); - - comp_opts.set_exhaustive_tune_flag(); - - std::cout << "Compile to target ..." << std::endl; - prog.compile(targ, comp_opts); - - std::string modelPath = getModelPath(settings); - migraphx::file_options file_options; - file_options.set_file_format("msgpack"); - std::cout << "Saving mxr file to: " << modelPath << "\n"; - migraphx::save(prog, modelPath.c_str(), file_options); +static migraphx::program loadOnnx(ModelLoadSettings &settings) { + std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); + +#ifdef TRACE + std::cout << "Using model: " << MODEL_FOLDER + ONNX_FILE << std::endl; +#endif + + migraphx::program prog; + std::ifstream f(onnx_path.c_str()); + if (f.good()) { + migraphx::onnx_options onnx_opts; + std::vector dims = {BATCH_SIZE, SEQ_SIZE}; + std::vector dimsPastKey = {BATCH_SIZE, HIDDEN_LAYERS_NUM, + SEQ_SIZE, HEAD_SIZE}; + std::vector inputDim; + if (settings.input_one_dim) { + inputDim = {BATCH_SIZE, 1}; + } else { + inputDim = dims; } - else - { - std::cerr << "Onnx file is not available on path: " << onnx_path << std::endl; - exit(1); + onnx_opts.set_input_parameter_shape("input_ids", inputDim); + onnx_opts.set_input_parameter_shape("attention_mask", dims); + for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { + onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); + onnx_opts.set_input_parameter_shape(getPastValueStr(i), dimsPastKey); } - return prog; -}; + std::cout << "Parsing onnx file ..." 
<< std::endl; + prog = parse_onnx(onnx_path.c_str(), onnx_opts); -static migraphx::program loadProgram(ModelLoadSettings& settings) -{ - std::filesystem::path compiled_path(getModelPath(settings)); - - migraphx::file_options file_options; - file_options.set_file_format("msgpack"); - - migraphx::program prog; - std::ifstream f(compiled_path.c_str()); - if (f.good()) - { - std::cout << "Loading model from " << compiled_path << " ...\n"; - prog = migraphx::load(compiled_path.c_str(), file_options); - } - else - { - std::cout << "MXR file can't be loaded try to load ONNX\n"; - prog = loadOnnx(settings); - } - return prog; -}; - -static migraphx::program create_argmax_program(ModelLoadSettings& settings) -{ - migraphx::program prog; - std::vector dims {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}; - if (settings.input_one_dim) - { - dims[1] = 1; - } - migraphx::shape s{migraphx_shape_half_type, dims}; - migraphx::module m = prog.get_main_module(); - auto x = m.add_parameter("x", s); - auto argmax_ins = m.add_instruction(migraphx::operation("argmax", "{axis: 2}"), {x}); - m.add_return({argmax_ins}); - - std::cout << "Creating ArgMax program ..." << std::endl; - std::string target_str = "gpu"; migraphx::target targ = migraphx::target(target_str.c_str()); - if (settings.quantize_fp16) - { - std::cout << "Quantize FP16 ..." << std::endl; - migraphx::quantize_fp16(prog); + if (settings.quantize_fp16) { + std::cout << "Quantize FP16 ..." << std::endl; + migraphx::quantize_fp16(prog); } migraphx::compile_options comp_opts; if (settings.fast_math) - comp_opts.set_fast_math(); + comp_opts.set_fast_math(); comp_opts.set_exhaustive_tune_flag(); std::cout << "Compile to target ..." << std::endl; prog.compile(targ, comp_opts); - return prog; + std::string modelPath = getModelPath(settings); + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + std::cout << "Saving mxr file to: " << modelPath << "\n"; + migraphx::save(prog, modelPath.c_str(), file_options); + } else { + std::cerr << "Onnx file is not available on path: " << onnx_path + << std::endl; + exit(1); + } + return prog; +}; + +static migraphx::program loadProgram(ModelLoadSettings &settings) { + std::filesystem::path compiled_path(getModelPath(settings)); + + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + + migraphx::program prog; + std::ifstream f(compiled_path.c_str()); + if (f.good()) { + std::cout << "Loading model from " << compiled_path << " ...\n"; + prog = migraphx::load(compiled_path.c_str(), file_options); + } else { + std::cout << "MXR file can't be loaded try to load ONNX\n"; + prog = loadOnnx(settings); + } + return prog; +}; + +static migraphx::program create_argmax_program(ModelLoadSettings &settings) { + migraphx::program prog; + std::vector dims{BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}; + if (settings.input_one_dim) { + dims[1] = 1; + } + migraphx::shape s{migraphx_shape_half_type, dims}; + migraphx::module m = prog.get_main_module(); + auto x = m.add_parameter("x", s); + auto argmax_ins = + m.add_instruction(migraphx::operation("argmax", "{axis: 2}"), {x}); + m.add_return({argmax_ins}); + + std::cout << "Creating ArgMax program ..." << std::endl; + + std::string target_str = "gpu"; + migraphx::target targ = migraphx::target(target_str.c_str()); + + if (settings.quantize_fp16) { + std::cout << "Quantize FP16 ..." 
<< std::endl; + migraphx::quantize_fp16(prog); + } + + migraphx::compile_options comp_opts; + + if (settings.fast_math) + comp_opts.set_fast_math(); + + comp_opts.set_exhaustive_tune_flag(); + + std::cout << "Compile to target ..." << std::endl; + prog.compile(targ, comp_opts); + + return prog; } -static void writeResults(const std::vector>& results) -{ - std::string RESULT_FILE = "result.txt"; - std::ofstream outFile(RESULT_FILE); - for (auto& resVec : results) - { - for (auto& res : resVec) - { - outFile << res; - if (&res != &resVec.back()) - { - outFile << ", "; - } - } - outFile << "\n"; +static void writeResults(const std::vector> &results) { + std::string RESULT_FILE = "result.txt"; + std::ofstream outFile(RESULT_FILE); + for (auto &resVec : results) { + for (auto &res : resVec) { + outFile << res; + if (&res != &resVec.back()) { + outFile << ", "; + } } + outFile << "\n"; + } } From e5c0d6278dbba1b2a4e12135fc2fb24586513d4e Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 4 Dec 2024 06:00:09 -0600 Subject: [PATCH 49/55] Format files with clang-format --- .../mgx_llama2/harness/buffer.hpp | 260 ++++++++------- .../mgx_llama2/harness/common.hpp | 117 +++---- .../mgx_llama2/harness/config.hpp | 13 +- .../mgx_llama2/harness/dataset.hpp | 296 +++++++++-------- .../mgx_llama2/harness/llama2inputs.hpp | 243 +++++++------- .../mgx_llama2/harness/llama2outputs.hpp | 112 +++---- .../mgx_llama2/harness/logging.hpp | 37 +-- .../transformers/mgx_llama2/harness/numa.hpp | 185 ++++++----- .../transformers/mgx_llama2/harness/numpy.hpp | 154 ++++----- .../transformers/mgx_llama2/harness/timer.hpp | 100 +++--- .../transformers/mgx_llama2/harness/utils.hpp | 301 ++++++++++-------- 11 files changed, 951 insertions(+), 867 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index 0d04c1d4d00..645c1dd748b 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -4,153 +4,173 @@ namespace mlinfer { template -struct IBuffer : public INoCopy { - AllocFunc alloc_fn; - FreeFunc free_fn; +struct IBuffer : public INoCopy +{ + AllocFunc alloc_fn; + FreeFunc free_fn; }; template -struct GenericBuffer : public IBuffer { - GenericBuffer() : size_in_bytes{0}, stride_in_bytes{0}, tensor_ptr{nullptr} {} +struct GenericBuffer : public IBuffer +{ + GenericBuffer() : size_in_bytes{0}, stride_in_bytes{0}, tensor_ptr{nullptr} {} + + explicit GenericBuffer(size_t size_in_bytes_, size_t stride_in_bytes_ = 0) + : size_in_bytes{size_in_bytes_}, stride_in_bytes{stride_in_bytes_} + { + if(stride_in_bytes == 0) + { + stride_in_bytes = size_in_bytes; + } + this->alloc_fn(&tensor_ptr, size_in_bytes); + } - explicit GenericBuffer(size_t size_in_bytes_, size_t stride_in_bytes_ = 0) - : size_in_bytes{size_in_bytes_}, stride_in_bytes{stride_in_bytes_} { - if (stride_in_bytes == 0) { - stride_in_bytes = size_in_bytes; + GenericBuffer(GenericBuffer&& buf) + : size_in_bytes{buf.size_in_bytes}, + stride_in_bytes{buf.stride_in_bytes}, + tensor_ptr{buf.tensor_ptr} + { + buf.size_in_bytes = 0; + buf.stride_in_bytes = 0; + buf.tensor_ptr = nullptr; } - this->alloc_fn(&tensor_ptr, size_in_bytes); - } - - GenericBuffer(GenericBuffer &&buf) - : size_in_bytes{buf.size_in_bytes}, stride_in_bytes{buf.stride_in_bytes}, - tensor_ptr{buf.tensor_ptr} { - buf.size_in_bytes = 0; - buf.stride_in_bytes = 0; - buf.tensor_ptr = nullptr; - } - - GenericBuffer &operator=(GenericBuffer &&buf) { - if (this != 
&buf) { - this->free_fn(tensor_ptr); - size_in_bytes = buf.size_in_bytes; - stride_in_bytes = buf.stride_in_bytes; - tensor_ptr = buf.tensor_ptr; - buf.size_in_bytes = 0; - buf.stride_in_bytes = 0; - buf.tensor_ptr = nullptr; + + GenericBuffer& operator=(GenericBuffer&& buf) + { + if(this != &buf) + { + this->free_fn(tensor_ptr); + size_in_bytes = buf.size_in_bytes; + stride_in_bytes = buf.stride_in_bytes; + tensor_ptr = buf.tensor_ptr; + buf.size_in_bytes = 0; + buf.stride_in_bytes = 0; + buf.tensor_ptr = nullptr; + } + return *this; } - return *this; - } - GenericBuffer(const GenericBuffer &buf) = delete; - GenericBuffer &operator=(const GenericBuffer &buf) = delete; + GenericBuffer(const GenericBuffer& buf) = delete; + GenericBuffer& operator=(const GenericBuffer& buf) = delete; - ~GenericBuffer() { this->free_fn(tensor_ptr); } + ~GenericBuffer() { this->free_fn(tensor_ptr); } - size_t size_in_bytes; - size_t stride_in_bytes; - void *tensor_ptr; + size_t size_in_bytes; + size_t stride_in_bytes; + void* tensor_ptr; }; -struct DeviceAllocator { - void operator()(void **ptr, size_t size) const { - LOG_INFO("Malloc " << size << " bytes on device"); - TIMED(hipMalloc, check_hip_status(hipMalloc(ptr, size))); - TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); - } +struct DeviceAllocator +{ + void operator()(void** ptr, size_t size) const + { + LOG_INFO("Malloc " << size << " bytes on device"); + TIMED(hipMalloc, check_hip_status(hipMalloc(ptr, size))); + TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); + } }; -struct DeviceFree { - void operator()(void *ptr) const { - TIMED(hipFree, check_hip_status_non_throwing(hipFree(ptr))); - ptr = nullptr; - } +struct DeviceFree +{ + void operator()(void* ptr) const + { + TIMED(hipFree, check_hip_status_non_throwing(hipFree(ptr))); + ptr = nullptr; + } }; -struct HostAllocator { - void operator()(void **ptr, size_t size) const { - LOG_INFO("Malloc " << size << " bytes on host"); - TIMED(hipHostMalloc, check_hip_status(hipHostMalloc(ptr, size))); - TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); - } +struct HostAllocator +{ + void operator()(void** ptr, size_t size) const + { + LOG_INFO("Malloc " << size << " bytes on host"); + TIMED(hipHostMalloc, check_hip_status(hipHostMalloc(ptr, size))); + TIMED(hipMemset, check_hip_status(hipMemset(*ptr, 0, size))); + } }; -struct HostFree { - void operator()(void *ptr) const { - TIMED(hipHostFree, check_hip_status_non_throwing(hipHostFree(ptr))); - ptr = nullptr; - } +struct HostFree +{ + void operator()(void* ptr) const + { + TIMED(hipHostFree, check_hip_status_non_throwing(hipHostFree(ptr))); + ptr = nullptr; + } }; using DeviceBuffer = GenericBuffer; -using HostBuffer = GenericBuffer; - -template struct ManagedBuffer_v2 { +using HostBuffer = GenericBuffer; - explicit ManagedBuffer_v2(std::vector &&host_data) { - size_in_bytes = host_data.size() * sizeof(T); - hbuff = std::move(host_data); - dbuff = DeviceBuffer(size_in_bytes, 0); - } +template +struct ManagedBuffer_v2 +{ - void *data() { return dbuff.tensor_ptr; } - - void update(std::vector &&host_data) { hbuff = std::move(host_data); } + explicit ManagedBuffer_v2(std::vector&& host_data) + { + size_in_bytes = host_data.size() * sizeof(T); + hbuff = std::move(host_data); + dbuff = DeviceBuffer(size_in_bytes, 0); + } - void upload_to_device(hipStream_t stream, size_t start_idx = 0, - size_t end_idx = 0) { - char *src_addr = reinterpret_cast(hbuff.data()); - char *dst_addr = static_cast(dbuff.tensor_ptr); - size_t 
copy_size_in_bytes = size_in_bytes; + void* data() { return dbuff.tensor_ptr; } + + void update(std::vector&& host_data) { hbuff = std::move(host_data); } + + void upload_to_device(hipStream_t stream, size_t start_idx = 0, size_t end_idx = 0) + { + char* src_addr = reinterpret_cast(hbuff.data()); + char* dst_addr = static_cast(dbuff.tensor_ptr); + size_t copy_size_in_bytes = size_in_bytes; + + size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); + if(range_size_in_bytes > 0) + { + size_t offset = start_idx * sizeof(T); + src_addr += offset; + dst_addr += offset; + copy_size_in_bytes = range_size_in_bytes; + } + check_hip_status(hipMemcpyHtoDAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); + } - size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); - if (range_size_in_bytes > 0) { - size_t offset = start_idx * sizeof(T); - src_addr += offset; - dst_addr += offset; - copy_size_in_bytes = range_size_in_bytes; + void download_from_device(hipStream_t stream, size_t start_idx = 0, size_t end_idx = 0) + { + char* src_addr = static_cast(dbuff.tensor_ptr); + char* dst_addr = reinterpret_cast(hbuff.data()); + size_t copy_size_in_bytes = size_in_bytes; + + size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); + if(range_size_in_bytes > 0) + { + size_t offset = start_idx * sizeof(T); + src_addr += offset; + dst_addr += offset; + copy_size_in_bytes = range_size_in_bytes; + } + check_hip_status(hipMemcpyDtoHAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); } - check_hip_status( - hipMemcpyHtoDAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); - } - - void download_from_device(hipStream_t stream, size_t start_idx = 0, - size_t end_idx = 0) { - char *src_addr = static_cast(dbuff.tensor_ptr); - char *dst_addr = reinterpret_cast(hbuff.data()); - size_t copy_size_in_bytes = size_in_bytes; - - size_t range_size_in_bytes = (end_idx - start_idx) * sizeof(T); - if (range_size_in_bytes > 0) { - size_t offset = start_idx * sizeof(T); - src_addr += offset; - dst_addr += offset; - copy_size_in_bytes = range_size_in_bytes; + + void update_data(T data, size_t position, hipStream_t stream) + { + hbuff.at(position) = data; + // TODO: don't copy over the entire buffer just the changed range + // check_hip_status(hipMemcpy(get_device_ptr(), + // get_host_ptr(), dbuff.size_in_bytes, + // hipMemcpyKind::hipMemcpyHostToDevice)); + upload_to_device(stream, position, position + 1); } - check_hip_status( - hipMemcpyDtoHAsync(dst_addr, src_addr, copy_size_in_bytes, stream)); - } - - void update_data(T data, size_t position, hipStream_t stream) { - hbuff.at(position) = data; - // TODO: don't copy over the entire buffer just the changed range - // check_hip_status(hipMemcpy(get_device_ptr(), - // get_host_ptr(), dbuff.size_in_bytes, - // hipMemcpyKind::hipMemcpyHostToDevice)); - upload_to_device(stream, position, position + 1); - } - - ManagedBuffer_v2() = delete; - ManagedBuffer_v2(const ManagedBuffer_v2 &buf) = delete; - ManagedBuffer_v2 &operator=(const ManagedBuffer_v2 &buf) = delete; - - DeviceBuffer dbuff; - std::vector hbuff; - size_t size_in_bytes; + + ManagedBuffer_v2() = delete; + ManagedBuffer_v2(const ManagedBuffer_v2& buf) = delete; + ManagedBuffer_v2& operator=(const ManagedBuffer_v2& buf) = delete; + + DeviceBuffer dbuff; + std::vector hbuff; + size_t size_in_bytes; }; -using LLama2InputBuffer = ManagedBuffer_v2; -using LLama2OutputBuffer = ManagedBuffer_v2; +using LLama2InputBuffer = ManagedBuffer_v2; +using LLama2OutputBuffer = ManagedBuffer_v2; using 
LLama2PastKeyValueBuffer = ManagedBuffer_v2; -using ArgMaxOutputBuffer = ManagedBuffer_v2; +using ArgMaxOutputBuffer = ManagedBuffer_v2; } // namespace mlinfer diff --git a/examples/transformers/mgx_llama2/harness/common.hpp b/examples/transformers/mgx_llama2/harness/common.hpp index 1df0159f0b0..0f7ba9f92f8 100644 --- a/examples/transformers/mgx_llama2/harness/common.hpp +++ b/examples/transformers/mgx_llama2/harness/common.hpp @@ -20,74 +20,83 @@ using half = half_float::half; using namespace half_float::literal; namespace mlinfer { -struct INoCopy { - INoCopy() = default; - virtual ~INoCopy() = default; - INoCopy(const INoCopy &) = delete; - INoCopy &operator=(const INoCopy &) = delete; +struct INoCopy +{ + INoCopy() = default; + virtual ~INoCopy() = default; + INoCopy(const INoCopy&) = delete; + INoCopy& operator=(const INoCopy&) = delete; }; /* Helper function to split a string based on a delimiting character */ -inline std::vector splitString(const std::string &input, - const std::string &delimiter) { - std::vector result; - size_t start = 0; - size_t next = 0; - while (next != std::string::npos) { - next = input.find(delimiter, start); - result.emplace_back(input, start, next - start); - start = next + 1; - } - return result; +inline std::vector splitString(const std::string& input, const std::string& delimiter) +{ + std::vector result; + size_t start = 0; + size_t next = 0; + while(next != std::string::npos) + { + next = input.find(delimiter, start); + result.emplace_back(input, start, next - start); + start = next + 1; + } + return result; } -#define check_hip_status(hip_call) \ - do { \ - int status = (hip_call); \ - if (status != hipSuccess) { \ - throw std::runtime_error( \ - "hip error (" + std::to_string(status) + "): " + \ - std::string(hipGetErrorString(static_cast(status)))); \ - } \ - } while (0); +#define check_hip_status(hip_call) \ + do \ + { \ + int status = (hip_call); \ + if(status != hipSuccess) \ + { \ + throw std::runtime_error( \ + "hip error (" + std::to_string(status) + \ + "): " + std::string(hipGetErrorString(static_cast(status)))); \ + } \ + } while(0); -#define check_hip_status_non_throwing(hip_call) \ - do { \ - int status = (hip_call); \ - if (status != hipSuccess) { \ - LOG_INFO( \ - "hip error (" + std::to_string(status) + "): " + \ - std::string(hipGetErrorString(static_cast(status)))); \ - } \ - } while (0); +#define check_hip_status_non_throwing(hip_call) \ + do \ + { \ + int status = (hip_call); \ + if(status != hipSuccess) \ + { \ + LOG_INFO("hip error (" + std::to_string(status) + \ + "): " + std::string(hipGetErrorString(static_cast(status)))); \ + } \ + } while(0); -#define CHECK(condition, error) \ - do { \ - if (!(condition)) { \ - std::cerr << error << std::endl; \ - } \ - } while (0); +#define CHECK(condition, error) \ + do \ + { \ + if(!(condition)) \ + { \ + std::cerr << error << std::endl; \ + } \ + } while(0); #if TIMER_ON -#define TIMER_STARTV(s) \ - static Timer timer##s(#s, true); \ - auto start##s = std::chrono::high_resolution_clock::now(); -#define TIMER_START(s) \ - static Timer timer##s(#s); \ - auto start##s = std::chrono::high_resolution_clock::now(); -#define TIMER_END(s) \ - timer##s.add(std::chrono::high_resolution_clock::now() - start##s); +#define TIMER_STARTV(s) \ + static Timer timer##s(#s, true); \ + auto start##s = std::chrono::high_resolution_clock::now(); +#define TIMER_START(s) \ + static Timer timer##s(#s); \ + auto start##s = std::chrono::high_resolution_clock::now(); +#define TIMER_END(s) 
timer##s.add(std::chrono::high_resolution_clock::now() - start##s); #else #define TIMER_START(s) #define TIMER_STARTV(s) #define TIMER_END(s) #endif -#define TIMED(s, call) \ - do { \ - TIMER_START(s); \ - { call; } \ - TIMER_END(s); \ - } while (0); +#define TIMED(s, call) \ + do \ + { \ + TIMER_START(s); \ + { \ + call; \ + } \ + TIMER_END(s); \ + } while(0); } // namespace mlinfer diff --git a/examples/transformers/mgx_llama2/harness/config.hpp b/examples/transformers/mgx_llama2/harness/config.hpp index 883ac47df7c..69867402296 100644 --- a/examples/transformers/mgx_llama2/harness/config.hpp +++ b/examples/transformers/mgx_llama2/harness/config.hpp @@ -1,11 +1,11 @@ #pragma once // TODO: fix paths -const std::string MODEL_FOLDER = "/model/"; -const std::string ONNX_FILE = "model.onnx"; +const std::string MODEL_FOLDER = "/model/"; +const std::string ONNX_FILE = "model.onnx"; const std::string DATASET_FOLDER = "/dataset/"; -const size_t DATASET_SIZE = 10; -const size_t BATCH_SIZE = 1; +const size_t DATASET_SIZE = 10; +const size_t BATCH_SIZE = 1; // sequence length from model config const size_t SEQ_SIZE = 1024; // vocab size from model config @@ -18,6 +18,5 @@ const bool WRITE_RESULT_FILE = false; const int DEVICE_ID = 4; const size_t HIDDEN_LAYERS_NUM = 32; -const size_t HEAD_SIZE = 128; -const size_t PAST_KEY_VAL_SIZE = - BATCH_SIZE * HIDDEN_LAYERS_NUM * HEAD_SIZE * SEQ_SIZE; \ No newline at end of file +const size_t HEAD_SIZE = 128; +const size_t PAST_KEY_VAL_SIZE = BATCH_SIZE * HIDDEN_LAYERS_NUM * HEAD_SIZE * SEQ_SIZE; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/dataset.hpp b/examples/transformers/mgx_llama2/harness/dataset.hpp index 5a6a739f953..850862ff86f 100644 --- a/examples/transformers/mgx_llama2/harness/dataset.hpp +++ b/examples/transformers/mgx_llama2/harness/dataset.hpp @@ -10,157 +10,173 @@ using namespace mlinfer; using NumpyVector = std::vector>; -struct Dataset { - Dataset() = default; - - void initialize() { - loadDataset(); - if (!_npy_files_loaded) { - prepareSampleDataset(); - } - } - - NumpyVector loadNumpy(npy::NpyFile &file) { - NumpyVector numpyDataAll; - auto load_size = file.GetTensorSize() / sizeof(int64_t); - numpyDataAll.push_back(std::vector(load_size)); - file.LoadAll(numpyDataAll.back().data()); - - NumpyVector numpyData; - for (size_t i = 0; i < numpyDataAll.back().size(); i += SEQ_SIZE) { - auto last = std::min(numpyDataAll.back().size(), i + SEQ_SIZE); - numpyData.emplace_back(numpyDataAll.back().begin() + i, - numpyDataAll.back().begin() + last); +struct Dataset +{ + Dataset() = default; + + void initialize() + { + loadDataset(); + if(!_npy_files_loaded) + { + prepareSampleDataset(); + } } + NumpyVector loadNumpy(npy::NpyFile& file) + { + NumpyVector numpyDataAll; + auto load_size = file.GetTensorSize() / sizeof(int64_t); + numpyDataAll.push_back(std::vector(load_size)); + file.LoadAll(numpyDataAll.back().data()); + + NumpyVector numpyData; + for(size_t i = 0; i < numpyDataAll.back().size(); i += SEQ_SIZE) + { + auto last = std::min(numpyDataAll.back().size(), i + SEQ_SIZE); + numpyData.emplace_back(numpyDataAll.back().begin() + i, + numpyDataAll.back().begin() + last); + } + #ifdef TRACE - for (auto &vec : numpyData) { - std::cout << "Vector size: " << vec.size() << std::endl; - for (auto val : vec) { - std::cout << val << " "; - } - std::cout << "\n"; - } + for(auto& vec : numpyData) + { + std::cout << "Vector size: " << vec.size() << std::endl; + for(auto val : vec) + { + std::cout << val << " "; + } + 
std::cout << "\n"; + } #endif - return numpyData; - } - - size_t getLastIdx(int current_batch_idx) const { - auto idx = _current_batch * BATCH_SIZE + current_batch_idx; - auto res = std::find_if(std::rbegin(attention_mask[idx]), - std::rend(attention_mask[idx]), - [](uint64_t val) { return 1 == val; }); - size_t last_idx = std::distance(res, std::rend(attention_mask[idx])); + return numpyData; + } + + size_t getLastIdx(int current_batch_idx) const + { + auto idx = _current_batch * BATCH_SIZE + current_batch_idx; + auto res = std::find_if(std::rbegin(attention_mask[idx]), + std::rend(attention_mask[idx]), + [](uint64_t val) { return 1 == val; }); + size_t last_idx = std::distance(res, std::rend(attention_mask[idx])); #ifdef TRACE - std::cout << "Last input idx: " << last_idx << std::endl; + std::cout << "Last input idx: " << last_idx << std::endl; #endif - return last_idx; - } - - std::vector getInputIds() { return getBatchedBuffer(input_ids); } - - std::vector getAttentionMask() { - return getBatchedBuffer(attention_mask, 1); - } - - std::vector getBatchedBuffer(NumpyVector &buffer, size_t value = 0) { - std::vector batchedBuffer; - const size_t buffer_size = SEQ_SIZE * BATCH_SIZE; - batchedBuffer.reserve(buffer_size); - auto batchSize = BATCH_SIZE; - if (_current_batch == batchNum() - 1) { - batchSize = _size % BATCH_SIZE; + return last_idx; } - for (size_t i = 0; i < batchSize; ++i) { - auto buffVec = buffer[BATCH_SIZE * _current_batch + i]; - std::copy(buffVec.begin(), buffVec.end(), - std::back_inserter(batchedBuffer)); + std::vector getInputIds() { return getBatchedBuffer(input_ids); } + + std::vector getAttentionMask() { return getBatchedBuffer(attention_mask, 1); } + + std::vector getBatchedBuffer(NumpyVector& buffer, size_t value = 0) + { + std::vector batchedBuffer; + const size_t buffer_size = SEQ_SIZE * BATCH_SIZE; + batchedBuffer.reserve(buffer_size); + auto batchSize = BATCH_SIZE; + if(_current_batch == batchNum() - 1) + { + batchSize = _size % BATCH_SIZE; + } + + for(size_t i = 0; i < batchSize; ++i) + { + auto buffVec = buffer[BATCH_SIZE * _current_batch + i]; + std::copy(buffVec.begin(), buffVec.end(), std::back_inserter(batchedBuffer)); + } + + if(batchSize != BATCH_SIZE) + { + // For last batch, setting buffer values if no sample is available + batchedBuffer.resize(buffer_size, value); + } + return batchedBuffer; } - if (batchSize != BATCH_SIZE) { - // For last batch, setting buffer values if no sample is available - batchedBuffer.resize(buffer_size, value); - } - return batchedBuffer; - } - - size_t size() const { return _size; } - size_t currentBatchIdx() const { return _current_batch; } - size_t batchNum() const { - return _size / BATCH_SIZE + (_size % BATCH_SIZE != 0); - } - size_t getNext() { - if (_current_batch < batchNum() - 1) { - ++_current_batch; - } + size_t size() const { return _size; } + size_t currentBatchIdx() const { return _current_batch; } + size_t batchNum() const { return _size / BATCH_SIZE + (_size % BATCH_SIZE != 0); } + size_t getNext() + { + if(_current_batch < batchNum() - 1) + { + ++_current_batch; + } #ifdef TRACE - std::cout << "Current batch: " << _current_batch << std::endl; + std::cout << "Current batch: " << _current_batch << std::endl; #endif - return _current_batch; - } - - Dataset(const Dataset &buf) = delete; - Dataset &operator=(const Dataset &buf) = delete; - -private: - // e.g.: /dataset/input_ids_size_3_seq_256.npy - std::string getDatasetPath(const std::string &datasetName) { - std::stringstream path; - path << DATASET_FOLDER << 
datasetName << "_size_" - << std::to_string(DATASET_SIZE) << "_seq_" << std::to_string(SEQ_SIZE) - << ".npy"; - return path.str(); - } - - void loadDataset() { - std::string input_file_path = getDatasetPath("input_ids"); - std::string attention_mask_file_path = getDatasetPath("attention_mask"); - - std::cout << "Input ids file: " << input_file_path << std::endl; - std::ifstream input_file(input_file_path.c_str()); - std::ifstream attention_mask_file(attention_mask_file_path.c_str()); - if (input_file.good() && attention_mask_file.good()) { - npy::NpyFile input_ids_npy{input_file_path}; - npy::NpyFile attention_mask_npy{attention_mask_file_path}; - input_ids = loadNumpy(input_ids_npy); - attention_mask = loadNumpy(attention_mask_npy); - - _size = input_ids.size(); - - if (input_ids.size() == attention_mask.size()) { - std::cout << "Loaded numpy files\n"; - _npy_files_loaded = true; - } else { - std::cout << "Numpy files do not have the same size\n"; - input_ids.clear(); - attention_mask.clear(); - } - } else { - std::cout << "Unable to open numpy files\n"; + return _current_batch; } - } - - void prepareSampleDataset() { - std::cout << "Numpy files are not loaded, using dummy data\n"; - std::vector input_ids_sample = {1, 6804, 338, 5207, - 387, 287, 29973}; - input_ids_sample.resize(SEQ_SIZE, 0); - std::vector attention_mask_sample = input_ids_sample; - input_ids.emplace_back(std::move(input_ids_sample)); - std::transform(std::begin(attention_mask_sample), - std::end(attention_mask_sample), - std::begin(attention_mask_sample), - [](auto i) { return (i != 0) ? 1 : 0; }); - attention_mask.emplace_back(std::move(attention_mask_sample)); - - _size = 1; - } - - NumpyVector input_ids; - NumpyVector attention_mask; - - size_t _size = 0; - size_t _current_batch = 0; - bool _npy_files_loaded = false; + + Dataset(const Dataset& buf) = delete; + Dataset& operator=(const Dataset& buf) = delete; + + private: + // e.g.: /dataset/input_ids_size_3_seq_256.npy + std::string getDatasetPath(const std::string& datasetName) + { + std::stringstream path; + path << DATASET_FOLDER << datasetName << "_size_" << std::to_string(DATASET_SIZE) << "_seq_" + << std::to_string(SEQ_SIZE) << ".npy"; + return path.str(); + } + + void loadDataset() + { + std::string input_file_path = getDatasetPath("input_ids"); + std::string attention_mask_file_path = getDatasetPath("attention_mask"); + + std::cout << "Input ids file: " << input_file_path << std::endl; + std::ifstream input_file(input_file_path.c_str()); + std::ifstream attention_mask_file(attention_mask_file_path.c_str()); + if(input_file.good() && attention_mask_file.good()) + { + npy::NpyFile input_ids_npy{input_file_path}; + npy::NpyFile attention_mask_npy{attention_mask_file_path}; + input_ids = loadNumpy(input_ids_npy); + attention_mask = loadNumpy(attention_mask_npy); + + _size = input_ids.size(); + + if(input_ids.size() == attention_mask.size()) + { + std::cout << "Loaded numpy files\n"; + _npy_files_loaded = true; + } + else + { + std::cout << "Numpy files do not have the same size\n"; + input_ids.clear(); + attention_mask.clear(); + } + } + else + { + std::cout << "Unable to open numpy files\n"; + } + } + + void prepareSampleDataset() + { + std::cout << "Numpy files are not loaded, using dummy data\n"; + std::vector input_ids_sample = {1, 6804, 338, 5207, 387, 287, 29973}; + input_ids_sample.resize(SEQ_SIZE, 0); + std::vector attention_mask_sample = input_ids_sample; + input_ids.emplace_back(std::move(input_ids_sample)); + 
std::transform(std::begin(attention_mask_sample), + std::end(attention_mask_sample), + std::begin(attention_mask_sample), + [](auto i) { return (i != 0) ? 1 : 0; }); + attention_mask.emplace_back(std::move(attention_mask_sample)); + + _size = 1; + } + + NumpyVector input_ids; + NumpyVector attention_mask; + + size_t _size = 0; + size_t _current_batch = 0; + bool _npy_files_loaded = false; }; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp index 29605d78ee2..d33934f342e 100644 --- a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp @@ -3,131 +3,132 @@ #include "config.hpp" #include "utils.hpp" -struct LLama2Inputs { - LLama2Inputs(migraphx::program &prog, - migraphx::program_parameters &prog_args) { - data.initialize(); - prepareProgArgs(prog, prog_args); - } - - void prepareProgArgs(migraphx::program &prog, - migraphx::program_parameters &prog_args, - bool simple = false) { - auto param_shapes = prog.get_parameter_shapes(); - if (!simple) { - auto inputShape = param_shapes[INPUTS_ID_STR]; - auto input_ids = data.getInputIds(); - input_ids_buffer = - std::make_unique(std::move(input_ids)); - prog_args.add(INPUTS_ID_STR, - migraphx::argument(inputShape, input_ids_buffer->data())); +struct LLama2Inputs +{ + LLama2Inputs(migraphx::program& prog, migraphx::program_parameters& prog_args) + { + data.initialize(); + prepareProgArgs(prog, prog_args); } - auto attShape = param_shapes[ATTENTION_MASK_STR]; - auto attention_mask = data.getAttentionMask(); - if (!simple) { - attention_mask_buffer = - std::make_unique(std::move(attention_mask)); + void prepareProgArgs(migraphx::program& prog, + migraphx::program_parameters& prog_args, + bool simple = false) + { + auto param_shapes = prog.get_parameter_shapes(); + if(!simple) + { + auto inputShape = param_shapes[INPUTS_ID_STR]; + auto input_ids = data.getInputIds(); + input_ids_buffer = std::make_unique(std::move(input_ids)); + prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, input_ids_buffer->data())); + } + + auto attShape = param_shapes[ATTENTION_MASK_STR]; + auto attention_mask = data.getAttentionMask(); + if(!simple) + { + attention_mask_buffer = std::make_unique(std::move(attention_mask)); + } + prog_args.add(ATTENTION_MASK_STR, + migraphx::argument(attShape, attention_mask_buffer->data())); + + // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, + // 32, 1, 128}, {4096, 128, 128, 1} past_key_values.0.value = + // @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, + // 128, 1} + for(size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + auto past_keyStr = getPastKeyString(i); + auto past_keyString = past_keyStr.c_str(); + if(!simple) + { + past_key_buffers.emplace_back(std::make_unique( + std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); + } + auto pastKeyShape = param_shapes[past_keyString]; + prog_args.add(past_keyString, + migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); + + auto past_valueStr = getPastValueStr(i); + auto past_valueString = past_valueStr.c_str(); + if(!simple) + { + past_value_buffers.emplace_back(std::make_unique( + std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); + } + auto pastValueShape = param_shapes[past_valueString]; + prog_args.add(past_valueString, + migraphx::argument(pastValueShape, past_value_buffers[i]->data())); + } } - prog_args.add(ATTENTION_MASK_STR, - migraphx::argument(attShape, 
attention_mask_buffer->data())); - - // past_key_values.0.key = @param:past_key_values.0.key -> half_type, {1, - // 32, 1, 128}, {4096, 128, 128, 1} past_key_values.0.value = - // @param:past_key_values.0.value -> half_type, {1, 32, 1, 128}, {4096, 128, - // 128, 1} - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { - auto past_keyStr = getPastKeyString(i); - auto past_keyString = past_keyStr.c_str(); - if (!simple) { - past_key_buffers.emplace_back( - std::make_unique( - std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); - } - auto pastKeyShape = param_shapes[past_keyString]; - prog_args.add( - past_keyString, - migraphx::argument(pastKeyShape, past_key_buffers[i]->data())); - - auto past_valueStr = getPastValueStr(i); - auto past_valueString = past_valueStr.c_str(); - if (!simple) { - past_value_buffers.emplace_back( - std::make_unique( - std::vector(PAST_KEY_VAL_SIZE, 0.0_h))); - } - auto pastValueShape = param_shapes[past_valueString]; - prog_args.add( - past_valueString, - migraphx::argument(pastValueShape, past_value_buffers[i]->data())); + + void prepareOneDimProgArgs(migraphx::program& prog, migraphx::program_parameters& prog_args) + { + prepareProgArgs(prog, prog_args, true); + auto param_shapes = prog.get_parameter_shapes(); + auto inputShape = param_shapes[INPUTS_ID_STR]; + std::vector oneDimInput(BATCH_SIZE, 0); + one_dim_input_buffer = std::make_unique(std::move(oneDimInput)); + prog_args.add(INPUTS_ID_STR, migraphx::argument(inputShape, one_dim_input_buffer->data())); + } + + void upload_to_device(hipStream_t stream) + { + input_ids_buffer->upload_to_device(stream); + attention_mask_buffer->upload_to_device(stream); } - } - - void prepareOneDimProgArgs(migraphx::program &prog, - migraphx::program_parameters &prog_args) { - prepareProgArgs(prog, prog_args, true); - auto param_shapes = prog.get_parameter_shapes(); - auto inputShape = param_shapes[INPUTS_ID_STR]; - std::vector oneDimInput(BATCH_SIZE, 0); - one_dim_input_buffer = - std::make_unique(std::move(oneDimInput)); - prog_args.add(INPUTS_ID_STR, - migraphx::argument(inputShape, one_dim_input_buffer->data())); - } - - void upload_to_device(hipStream_t stream) { - input_ids_buffer->upload_to_device(stream); - attention_mask_buffer->upload_to_device(stream); - } - - bool updateData(migraphx::program &prog, - migraphx::program_parameters &prog_args) { - auto batchIdx = data.currentBatchIdx(); - if (batchIdx != data.getNext()) { - auto param_shapes = prog.get_parameter_shapes(); - - std::vector input_ids = data.getInputIds(); - input_ids_buffer = - std::make_unique(std::move(input_ids)); - prog_args.add(INPUTS_ID_STR, - migraphx::argument(param_shapes[INPUTS_ID_STR], - input_ids_buffer->data())); - - auto attention_mask = data.getAttentionMask(); - attention_mask_buffer->update(std::move(attention_mask)); - - return true; + + bool updateData(migraphx::program& prog, migraphx::program_parameters& prog_args) + { + auto batchIdx = data.currentBatchIdx(); + if(batchIdx != data.getNext()) + { + auto param_shapes = prog.get_parameter_shapes(); + + std::vector input_ids = data.getInputIds(); + input_ids_buffer = std::make_unique(std::move(input_ids)); + prog_args.add( + INPUTS_ID_STR, + migraphx::argument(param_shapes[INPUTS_ID_STR], input_ids_buffer->data())); + + auto attention_mask = data.getAttentionMask(); + attention_mask_buffer->update(std::move(attention_mask)); + + return true; + } + return false; } - return false; - } - - void resetPastKeyValueBuffers(hipStream_t stream) { - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { - 
past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); - past_value_buffers[i]->update( - std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); - past_key_buffers[i]->upload_to_device(stream); - past_value_buffers[i]->upload_to_device(stream); + + void resetPastKeyValueBuffers(hipStream_t stream) + { + for(size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + past_key_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + past_value_buffers[i]->update(std::vector(PAST_KEY_VAL_SIZE, 0.0_h)); + past_key_buffers[i]->upload_to_device(stream); + past_value_buffers[i]->upload_to_device(stream); + } + } + + size_t getLastInputIndex(int current_batch_idx) const + { + return data.getLastIdx(current_batch_idx); } - } - - size_t getLastInputIndex(int current_batch_idx) const { - return data.getLastIdx(current_batch_idx); - } - size_t dataSize() const { return data.size(); } - size_t batchNum() const { return data.batchNum(); } - - LLama2Inputs() = delete; - LLama2Inputs(const LLama2Inputs &buf) = delete; - LLama2Inputs &operator=(const LLama2Inputs &buf) = delete; - - std::unique_ptr input_ids_buffer; - std::unique_ptr one_dim_input_buffer; - std::unique_ptr attention_mask_buffer; - std::vector> past_key_buffers; - std::vector> past_value_buffers; - Dataset data; - - const char *INPUTS_ID_STR = "input_ids"; - const char *ATTENTION_MASK_STR = "attention_mask"; + size_t dataSize() const { return data.size(); } + size_t batchNum() const { return data.batchNum(); } + + LLama2Inputs() = delete; + LLama2Inputs(const LLama2Inputs& buf) = delete; + LLama2Inputs& operator=(const LLama2Inputs& buf) = delete; + + std::unique_ptr input_ids_buffer; + std::unique_ptr one_dim_input_buffer; + std::unique_ptr attention_mask_buffer; + std::vector> past_key_buffers; + std::vector> past_value_buffers; + Dataset data; + + const char* INPUTS_ID_STR = "input_ids"; + const char* ATTENTION_MASK_STR = "attention_mask"; }; diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp index ed2cab2566f..7fb5ebe0d6a 100644 --- a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp @@ -5,65 +5,55 @@ #include -struct LLama2Outputs { - LLama2Outputs() {} - - void prepareProgArgs(migraphx::program_parameters &prog_args, - migraphx::program_parameters &prog_args_one_dim) { - output_buffer = std::make_unique( - std::vector(OUTPUT_SIZE)); - one_dim_output_buffer = std::make_unique( - std::vector(BATCH_SIZE * VOCAB_SIZE)); - migraphx::shape out_shape{migraphx_shape_half_type, - {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; - prog_args.add(OUTPUT_NAME, - migraphx::argument(out_shape, output_buffer->data())); - - migraphx::shape x_shape_one_dim{migraphx_shape_half_type, - {BATCH_SIZE, 1, VOCAB_SIZE}}; - prog_args_one_dim.add( - OUTPUT_NAME, - migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); - } - - void prepareProgArgsArgMax( - migraphx::program_parameters &prog_args_argmax, - migraphx::program_parameters &prog_args_argmax_one_dim) { - // setting up argmax arguments - migraphx::shape x_shape{migraphx_shape_half_type, - {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; - prog_args_argmax.add("x", - migraphx::argument(x_shape, output_buffer->data())); - argm_output_buffer = std::make_unique( - std::vector(BATCH_SIZE * SEQ_SIZE)); - migraphx::shape argm_out_shape{migraphx_shape_int64_type, - {BATCH_SIZE, SEQ_SIZE, 1}}; - prog_args_argmax.add( - OUTPUT_NAME, - migraphx::argument(argm_out_shape, 
argm_output_buffer->data())); - - migraphx::shape x_shape_one_dim{migraphx_shape_half_type, - {BATCH_SIZE, 1, VOCAB_SIZE}}; - prog_args_argmax_one_dim.add( - "x", - migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); - argm_output_buffer_one_dim = - std::make_unique(std::vector(BATCH_SIZE)); - migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, - {BATCH_SIZE, 1, 1}}; - prog_args_argmax_one_dim.add( - OUTPUT_NAME, migraphx::argument(argm_out_shape_one_dim, - argm_output_buffer_one_dim->data())); - } - - LLama2Outputs(const LLama2Outputs &buf) = delete; - LLama2Outputs &operator=(const LLama2Outputs &buf) = delete; - - std::unique_ptr output_buffer; - std::unique_ptr one_dim_output_buffer; - std::unique_ptr argm_output_buffer; - std::unique_ptr argm_output_buffer_one_dim; - - const char *OUTPUT_NAME = "main:#output_0"; - const size_t OUTPUT_SIZE = BATCH_SIZE * SEQ_SIZE * VOCAB_SIZE; +struct LLama2Outputs +{ + LLama2Outputs() {} + + void prepareProgArgs(migraphx::program_parameters& prog_args, + migraphx::program_parameters& prog_args_one_dim) + { + output_buffer = std::make_unique(std::vector(OUTPUT_SIZE)); + one_dim_output_buffer = + std::make_unique(std::vector(BATCH_SIZE * VOCAB_SIZE)); + migraphx::shape out_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; + prog_args.add(OUTPUT_NAME, migraphx::argument(out_shape, output_buffer->data())); + + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; + prog_args_one_dim.add(OUTPUT_NAME, + migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); + } + + void prepareProgArgsArgMax(migraphx::program_parameters& prog_args_argmax, + migraphx::program_parameters& prog_args_argmax_one_dim) + { + // setting up argmax arguments + migraphx::shape x_shape{migraphx_shape_half_type, {BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}}; + prog_args_argmax.add("x", migraphx::argument(x_shape, output_buffer->data())); + argm_output_buffer = + std::make_unique(std::vector(BATCH_SIZE * SEQ_SIZE)); + migraphx::shape argm_out_shape{migraphx_shape_int64_type, {BATCH_SIZE, SEQ_SIZE, 1}}; + prog_args_argmax.add(OUTPUT_NAME, + migraphx::argument(argm_out_shape, argm_output_buffer->data())); + + migraphx::shape x_shape_one_dim{migraphx_shape_half_type, {BATCH_SIZE, 1, VOCAB_SIZE}}; + prog_args_argmax_one_dim.add( + "x", migraphx::argument(x_shape_one_dim, one_dim_output_buffer->data())); + argm_output_buffer_one_dim = + std::make_unique(std::vector(BATCH_SIZE)); + migraphx::shape argm_out_shape_one_dim{migraphx_shape_int64_type, {BATCH_SIZE, 1, 1}}; + prog_args_argmax_one_dim.add( + OUTPUT_NAME, + migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim->data())); + } + + LLama2Outputs(const LLama2Outputs& buf) = delete; + LLama2Outputs& operator=(const LLama2Outputs& buf) = delete; + + std::unique_ptr output_buffer; + std::unique_ptr one_dim_output_buffer; + std::unique_ptr argm_output_buffer; + std::unique_ptr argm_output_buffer_one_dim; + + const char* OUTPUT_NAME = "main:#output_0"; + const size_t OUTPUT_SIZE = BATCH_SIZE * SEQ_SIZE * VOCAB_SIZE; }; \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/harness/logging.hpp b/examples/transformers/mgx_llama2/harness/logging.hpp index 2dee1ec530a..9770090e079 100644 --- a/examples/transformers/mgx_llama2/harness/logging.hpp +++ b/examples/transformers/mgx_llama2/harness/logging.hpp @@ -8,30 +8,31 @@ namespace mlinfer { #define ENABLE_TIMED_LOGGING 0 #define ENABLE_DEBUG_LOGGING 0 -#if (!LOGGING_OFF) 
-#define LOG_INFO(...) \ - do { \ - std::cout << __VA_ARGS__ << std::endl; \ - } while (0) -#define LOG_ERROR(...) \ - do { \ - std::cerr << __VA_ARGS__ << std::endl; \ - } while (0) -#define LOG_STATE(...) \ - do { \ - std::cout << "================================================" \ - << std::endl; \ - std::cout << __VA_ARGS__ << std::endl; \ - std::cout << "================================================" \ - << std::endl; \ - } while (0) +#if(!LOGGING_OFF) +#define LOG_INFO(...) \ + do \ + { \ + std::cout << __VA_ARGS__ << std::endl; \ + } while(0) +#define LOG_ERROR(...) \ + do \ + { \ + std::cerr << __VA_ARGS__ << std::endl; \ + } while(0) +#define LOG_STATE(...) \ + do \ + { \ + std::cout << "================================================" << std::endl; \ + std::cout << __VA_ARGS__ << std::endl; \ + std::cout << "================================================" << std::endl; \ + } while(0) #else #define LOG_INFO(...) (void)0 #define LOG_ERROR(...) (void)0 #define LOG_STATE(...) (void)0 #endif -#if (ENABLE_TIMED_LOGGING || ENABLE_DEBUG_LOGGING) +#if(ENABLE_TIMED_LOGGING || ENABLE_DEBUG_LOGGING) #define LOG_TIMED(...) LOG_INFO(__VA_ARGS__) #else #define LOG_TIMED(...) (void)0 diff --git a/examples/transformers/mgx_llama2/harness/numa.hpp b/examples/transformers/mgx_llama2/harness/numa.hpp index 13b46e84d67..4aa60ddeccb 100644 --- a/examples/transformers/mgx_llama2/harness/numa.hpp +++ b/examples/transformers/mgx_llama2/harness/numa.hpp @@ -14,123 +14,138 @@ namespace mlinfer { // NUMA config. Each NUMA node contains a pair of GPU indices and CPU indices. -using NumaConfig = - std::vector, std::vector>>; +using NumaConfig = std::vector, std::vector>>; // The NUMA node idx for each GPU. using GpuToNumaMap = std::vector; -struct NumaSettings { - NumaConfig numa_config; - GpuToNumaMap gpu_to_numa_map; +struct NumaSettings +{ + NumaConfig numa_config; + GpuToNumaMap gpu_to_numa_map; }; -struct Numa final { - NumaSettings numa_settings; +struct Numa final +{ + NumaSettings numa_settings; - explicit Numa(const NumaSettings &numa_settings) - : numa_settings{numa_settings} {} + explicit Numa(const NumaSettings& numa_settings) : numa_settings{numa_settings} {} - inline bool UseNuma() const { return not numa_settings.numa_config.empty(); } + inline bool UseNuma() const { return not numa_settings.numa_config.empty(); } - inline size_t GetNumaCount() const { - return numa_settings.numa_config.size(); - }; + inline size_t GetNumaCount() const { return numa_settings.numa_config.size(); }; - inline int GetNumaIdx(const int deviceId) const { - return UseNuma() ? numa_settings.gpu_to_numa_map.at(deviceId) : 0; - } + inline int GetNumaIdx(const int deviceId) const + { + return UseNuma() ? numa_settings.gpu_to_numa_map.at(deviceId) : 0; + } - inline std::vector GetClosestCpus(const int deviceId) const { - assertm(UseNuma(), "GetClosestCpus only available for NUMA"); - return numa_settings.numa_config.at(GetNumaIdx(deviceId)).second; - } + inline std::vector GetClosestCpus(const int deviceId) const + { + assertm(UseNuma(), "GetClosestCpus only available for NUMA"); + return numa_settings.numa_config.at(GetNumaIdx(deviceId)).second; + } }; // Restrict mem allocation to specific NUMA node. 
-inline void bindNumaMemPolicy(const int32_t numaIdx, const int32_t nbNumas) { - unsigned long nodeMask = 1UL << numaIdx; - long ret = set_mempolicy(MPOL_BIND, &nodeMask, nbNumas + 1); - CHECK(ret >= 0, std::strerror(errno)); +inline void bindNumaMemPolicy(const int32_t numaIdx, const int32_t nbNumas) +{ + unsigned long nodeMask = 1UL << numaIdx; + long ret = set_mempolicy(MPOL_BIND, &nodeMask, nbNumas + 1); + CHECK(ret >= 0, std::strerror(errno)); } // Reset mem allocation setting. -inline void resetNumaMemPolicy() { - long ret = set_mempolicy(MPOL_DEFAULT, nullptr, 0); - CHECK(ret >= 0, std::strerror(errno)); +inline void resetNumaMemPolicy() +{ + long ret = set_mempolicy(MPOL_DEFAULT, nullptr, 0); + CHECK(ret >= 0, std::strerror(errno)); } // Limit a thread to be on specific cpus. -inline void bindThreadToCpus(std::thread &th, const std::vector &cpus, - const bool ignore_esrch = false) { - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - for (int cpu : cpus) { - CPU_SET(cpu, &cpuset); - } - int ret = - pthread_setaffinity_np(th.native_handle(), sizeof(cpu_set_t), &cpuset); - bool noerr = ignore_esrch ? ret == 0 || ret == ESRCH : ret == 0; - CHECK(noerr, std::strerror(ret)); +inline void +bindThreadToCpus(std::thread& th, const std::vector& cpus, const bool ignore_esrch = false) +{ + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + for(int cpu : cpus) + { + CPU_SET(cpu, &cpuset); + } + int ret = pthread_setaffinity_np(th.native_handle(), sizeof(cpu_set_t), &cpuset); + bool noerr = ignore_esrch ? ret == 0 || ret == ESRCH : ret == 0; + CHECK(noerr, std::strerror(ret)); } // Helper to converts the range string (like "0,2-5,13-17") to a vector of ints. -inline std::vector parseRange(const std::string &s) { - std::vector results; - auto ranges = splitString(s, ","); - for (const auto &range : ranges) { - auto startEnd = splitString(range, "-"); - CHECK((startEnd.size() <= 2), - "Invalid numa_config setting. Expects zero or one '-'."); - if (startEnd.size() == 1) { - results.push_back(std::stoi(startEnd[0])); - } else { - size_t start = std::stoi(startEnd[0]); - size_t last = std::stoi(startEnd[1]); - for (size_t i = start; i <= last; ++i) { - results.push_back(i); - } +inline std::vector parseRange(const std::string& s) +{ + std::vector results; + auto ranges = splitString(s, ","); + for(const auto& range : ranges) + { + auto startEnd = splitString(range, "-"); + CHECK((startEnd.size() <= 2), "Invalid numa_config setting. Expects zero or one '-'."); + if(startEnd.size() == 1) + { + results.push_back(std::stoi(startEnd[0])); + } + else + { + size_t start = std::stoi(startEnd[0]); + size_t last = std::stoi(startEnd[1]); + for(size_t i = start; i <= last; ++i) + { + results.push_back(i); + } + } } - } - return results; + return results; } // Example of the format: "0,2:0-63&1,3:64-127" for 4 GPUs, 128 CPU, 2 NUMA node // system. -inline NumaConfig parseNumaConfig(const std::string &numa_file) { - std::string numa_str; - std::ifstream file(numa_file.c_str()); - if (file.is_open()) { - getline(file, numa_str); - file.close(); - } - - NumaConfig config; - if (!numa_str.empty()) { - auto nodes = splitString(numa_str, "&"); - for (const auto &node : nodes) { - auto pair = splitString(node, ":"); - CHECK((pair.size() == 2), - "Invalid numa_config setting. 
Expects one ':'."); - auto gpus = parseRange(pair[0]); - auto cpus = parseRange(pair[1]); - config.emplace_back(std::make_pair(gpus, cpus)); +inline NumaConfig parseNumaConfig(const std::string& numa_file) +{ + std::string numa_str; + std::ifstream file(numa_file.c_str()); + if(file.is_open()) + { + getline(file, numa_str); + file.close(); + } + + NumaConfig config; + if(!numa_str.empty()) + { + auto nodes = splitString(numa_str, "&"); + for(const auto& node : nodes) + { + auto pair = splitString(node, ":"); + CHECK((pair.size() == 2), "Invalid numa_config setting. Expects one ':'."); + auto gpus = parseRange(pair[0]); + auto cpus = parseRange(pair[1]); + config.emplace_back(std::make_pair(gpus, cpus)); + } } - } - return config; + return config; } // Convert NumaConfig to GpuToNumaMap for easier look-up. -inline GpuToNumaMap getGpuToNumaMap(const NumaConfig &config) { - std::vector map; - for (size_t numaIdx = 0; numaIdx < config.size(); numaIdx++) { - for (const auto gpuIdx : config[numaIdx].first) { - if (gpuIdx >= map.size()) { - map.resize(gpuIdx + 1); - } - map[gpuIdx] = numaIdx; +inline GpuToNumaMap getGpuToNumaMap(const NumaConfig& config) +{ + std::vector map; + for(size_t numaIdx = 0; numaIdx < config.size(); numaIdx++) + { + for(const auto gpuIdx : config[numaIdx].first) + { + if(gpuIdx >= map.size()) + { + map.resize(gpuIdx + 1); + } + map[gpuIdx] = numaIdx; + } } - } - return map; + return map; } } // namespace mlinfer diff --git a/examples/transformers/mgx_llama2/harness/numpy.hpp b/examples/transformers/mgx_llama2/harness/numpy.hpp index 3aeca0e44f5..5864a48b179 100644 --- a/examples/transformers/mgx_llama2/harness/numpy.hpp +++ b/examples/transformers/mgx_llama2/harness/numpy.hpp @@ -13,90 +13,92 @@ namespace mlinfer { namespace npy { -class NpyFile { -private: - std::string m_Path; - std::ifstream m_FStream; - size_t m_HeaderSize; - std::string m_Header; - size_t m_TensorSize; - size_t m_ElementSize; - std::vector m_TensorDims; +class NpyFile +{ + private: + std::string m_Path; + std::ifstream m_FStream; + size_t m_HeaderSize; + std::string m_Header; + size_t m_TensorSize; + size_t m_ElementSize; + std::vector m_TensorDims; -public: - explicit NpyFile(const std::string &path) : m_Path(path), m_FStream(m_Path) { - LOG_INFO("Npy file from " << path); - // magic and fixed header - char b[256]; - m_FStream.read(b, 10); - CHECK(m_FStream, "Unable to parse: " << m_Path); + public: + explicit NpyFile(const std::string& path) : m_Path(path), m_FStream(m_Path) + { + LOG_INFO("Npy file from " << path); + // magic and fixed header + char b[256]; + m_FStream.read(b, 10); + CHECK(m_FStream, "Unable to parse: " << m_Path); - // check magic - CHECK(static_cast(b[0]) == 0x93 && b[1] == 'N' && - b[2] == 'U' && b[3] == 'M' && b[4] == 'P' && b[5] == 'Y', - "Bad magic: " << m_Path); + // check magic + CHECK(static_cast(b[0]) == 0x93 && b[1] == 'N' && b[2] == 'U' && + b[3] == 'M' && b[4] == 'P' && b[5] == 'Y', + "Bad magic: " << m_Path); - // get header - auto major = static_cast(b[6]); - // auto minor = static_cast(b[7]); - CHECK(major == 1, "Only npy version 1 is supported: " << m_Path); - m_HeaderSize = static_cast(b[8]); - m_Header.resize(m_HeaderSize); - m_FStream.read(static_cast(m_Header.data()), m_HeaderSize); + // get header + auto major = static_cast(b[6]); + // auto minor = static_cast(b[7]); + CHECK(major == 1, "Only npy version 1 is supported: " << m_Path); + m_HeaderSize = static_cast(b[8]); + m_Header.resize(m_HeaderSize); + m_FStream.read(static_cast(m_Header.data()), 
m_HeaderSize); - // get file size - auto cur = m_FStream.tellg(); - m_FStream.seekg(0, std::ios::end); - auto size = m_FStream.tellg(); - m_TensorSize = size - cur; + // get file size + auto cur = m_FStream.tellg(); + m_FStream.seekg(0, std::ios::end); + auto size = m_FStream.tellg(); + m_TensorSize = size - cur; - // parse header - std::regex re( - R"re(\{'descr': '[<|][fi]([\d])', 'fortran_order': False, 'shape': \(([\d, ]*)\), \} +\n)re"); - std::smatch matches; - CHECK(std::regex_match(m_Header, matches, re), - "Cannot parse numpy header: " << m_Path); - CHECK(matches.size() == 3, "Cannot parse numpy header: " << m_Path); - m_ElementSize = std::stoi(matches[1]); - std::vector dims = splitString(matches[2], ", "); - m_TensorDims.resize(dims.size()); - std::transform(dims.begin(), dims.end(), m_TensorDims.begin(), - [](const std::string &s) { return std::stoi(s); }); + // parse header + std::regex re( + R"re(\{'descr': '[<|][fi]([\d])', 'fortran_order': False, 'shape': \(([\d, ]*)\), \} +\n)re"); + std::smatch matches; + CHECK(std::regex_match(m_Header, matches, re), "Cannot parse numpy header: " << m_Path); + CHECK(matches.size() == 3, "Cannot parse numpy header: " << m_Path); + m_ElementSize = std::stoi(matches[1]); + std::vector dims = splitString(matches[2], ", "); + m_TensorDims.resize(dims.size()); + std::transform(dims.begin(), dims.end(), m_TensorDims.begin(), [](const std::string& s) { + return std::stoi(s); + }); - // check header sanity - size_t tensorSize = - std::accumulate(m_TensorDims.begin(), m_TensorDims.end(), m_ElementSize, - std::multiplies()); - CHECK(tensorSize == m_TensorSize, - "Header description does not match file size: " << m_Path); - LOG_DEBUG(" Input num=" << m_TensorDims[0] << " | Sample size=" - << (tensorSize / m_TensorDims[0]) - << " | Full size=" << m_TensorSize); - } - ~NpyFile() { m_FStream.close(); }; - std::string GetPath() const { return m_Path; } - std::vector GetDims() const { return m_TensorDims; } - size_t GetTensorSize() const { return m_TensorSize; } - // load the entire tensor - void LoadAll(void *dst) { - m_FStream.seekg(10 + m_HeaderSize, std::ios::beg); - m_FStream.read(static_cast(dst), m_TensorSize); - CHECK(m_FStream, "Unable to parse: " << m_Path); - CHECK(m_FStream.peek() == EOF, "Did not consume full file: " << m_Path); - } + // check header sanity + size_t tensorSize = std::accumulate( + m_TensorDims.begin(), m_TensorDims.end(), m_ElementSize, std::multiplies()); + CHECK(tensorSize == m_TensorSize, + "Header description does not match file size: " << m_Path); + LOG_DEBUG(" Input num=" << m_TensorDims[0] + << " | Sample size=" << (tensorSize / m_TensorDims[0]) + << " | Full size=" << m_TensorSize); + } + ~NpyFile() { m_FStream.close(); }; + std::string GetPath() const { return m_Path; } + std::vector GetDims() const { return m_TensorDims; } + size_t GetTensorSize() const { return m_TensorSize; } + // load the entire tensor + void LoadAll(void* dst) + { + m_FStream.seekg(10 + m_HeaderSize, std::ios::beg); + m_FStream.read(static_cast(dst), m_TensorSize); + CHECK(m_FStream, "Unable to parse: " << m_Path); + CHECK(m_FStream.peek() == EOF, "Did not consume full file: " << m_Path); + } - // load only selected indices from the Tensor, assuming that the first dim is - // batch dim. 
- void LoadSamples(void *dst, const std::vector &indices) { - size_t sampleSize = - std::accumulate(m_TensorDims.begin() + 1, m_TensorDims.end(), - m_ElementSize, std::multiplies()); - for (size_t i = 0; i < indices.size(); i++) { - m_FStream.seekg(10 + m_HeaderSize + indices[i] * sampleSize, - std::ios::beg); - m_FStream.read(static_cast(dst) + i * sampleSize, sampleSize); + // load only selected indices from the Tensor, assuming that the first dim is + // batch dim. + void LoadSamples(void* dst, const std::vector& indices) + { + size_t sampleSize = std::accumulate( + m_TensorDims.begin() + 1, m_TensorDims.end(), m_ElementSize, std::multiplies()); + for(size_t i = 0; i < indices.size(); i++) + { + m_FStream.seekg(10 + m_HeaderSize + indices[i] * sampleSize, std::ios::beg); + m_FStream.read(static_cast(dst) + i * sampleSize, sampleSize); + } } - } }; } // namespace npy } // namespace mlinfer diff --git a/examples/transformers/mgx_llama2/harness/timer.hpp b/examples/transformers/mgx_llama2/harness/timer.hpp index cfcc20a6ad1..91a68373333 100644 --- a/examples/transformers/mgx_llama2/harness/timer.hpp +++ b/examples/transformers/mgx_llama2/harness/timer.hpp @@ -9,56 +9,62 @@ #include // For debugging the timing of each part -class Timer { -public: - explicit Timer(const std::string &tag_, bool verbose_ = false) - : tag(tag_), verbose(verbose_) { - std::cout << "Timer " << tag << " created." << std::endl; - } - void add(const std::chrono::duration &in) { - std::thread::id id = std::this_thread::get_id(); - count[id] += 1; - total[id] += in; - if (verbose) - measurements[id].emplace_back(in); - } - ~Timer() { - auto total_accum = std::accumulate( - std::begin(total), std::end(total), 0, - [](int64_t value, - std::pair> - p) { return value + p.second.count(); }); +class Timer +{ + public: + explicit Timer(const std::string& tag_, bool verbose_ = false) : tag(tag_), verbose(verbose_) + { + std::cout << "Timer " << tag << " created." << std::endl; + } + void add(const std::chrono::duration& in) + { + std::thread::id id = std::this_thread::get_id(); + count[id] += 1; + total[id] += in; + if(verbose) + measurements[id].emplace_back(in); + } + ~Timer() + { + auto total_accum = std::accumulate( + std::begin(total), + std::end(total), + 0, + [](int64_t value, + std::pair> p) { + return value + p.second.count(); + }); - auto count_accum = - std::accumulate(std::begin(count), std::end(count), 0, - [](size_t value, std::pair p) { - return value + p.second; - }); + auto count_accum = std::accumulate( + std::begin(count), + std::end(count), + 0, + [](size_t value, std::pair p) { return value + p.second; }); - std::cout << "Timer " << tag << " reports " - << (double)total_accum / count_accum << " ms per call for " - << count_accum << " times." << std::endl; - if (verbose) { - std::cout << " Measurements=["; - for (const auto &m : measurements) { - std::cout << " Thread " << m.first << ": {"; - for (const auto &d : m.second) { - std::cout << d.count() << ","; - } + std::cout << "Timer " << tag << " reports " << (double)total_accum / count_accum + << " ms per call for " << count_accum << " times." 
<< std::endl; + if(verbose) + { + std::cout << " Measurements=["; + for(const auto& m : measurements) + { + std::cout << " Thread " << m.first << ": {"; + for(const auto& d : m.second) + { + std::cout << d.count() << ","; + } - std::cout << "},"; - } - std::cout << "]" << std::endl; + std::cout << "},"; + } + std::cout << "]" << std::endl; + } } - } -private: - std::string tag; - bool verbose; - std::unordered_map> - total; - std::unordered_map>> - measurements; - std::unordered_map count; + private: + std::string tag; + bool verbose; + std::unordered_map> total; + std::unordered_map>> + measurements; + std::unordered_map count; }; diff --git a/examples/transformers/mgx_llama2/harness/utils.hpp b/examples/transformers/mgx_llama2/harness/utils.hpp index 794f0bf8ae3..151f75acd12 100644 --- a/examples/transformers/mgx_llama2/harness/utils.hpp +++ b/examples/transformers/mgx_llama2/harness/utils.hpp @@ -8,178 +8,203 @@ #include -struct ModelLoadSettings { - size_t sequnce_length; - bool quantize_fp16; - bool fast_math; - bool input_one_dim; - size_t batch_size; +struct ModelLoadSettings +{ + size_t sequnce_length; + bool quantize_fp16; + bool fast_math; + bool input_one_dim; + size_t batch_size; }; -static std::string getModelPath(ModelLoadSettings &s) { - std::stringstream path; - path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" - << (s.quantize_fp16 ? "16" : "32") << "_"; - path << "batch_" << std::to_string(s.batch_size) << "_"; - if (!s.fast_math) { - path << "no"; - } - path << "fastmath"; - if (s.input_one_dim) { - path << "_inputonedim"; - } - path << ".mxr"; - return path.str(); +static std::string getModelPath(ModelLoadSettings& s) +{ + std::stringstream path; + path << MODEL_FOLDER << "model-" << std::to_string(s.sequnce_length) << "_fp" + << (s.quantize_fp16 ? "16" : "32") << "_"; + path << "batch_" << std::to_string(s.batch_size) << "_"; + if(!s.fast_math) + { + path << "no"; + } + path << "fastmath"; + if(s.input_one_dim) + { + path << "_inputonedim"; + } + path << ".mxr"; + return path.str(); } -[[maybe_unused]] static std::string getPastKeyString(size_t i) { - std::stringstream past_key; - past_key << "past_key_values." << std::to_string(i) << ".key"; - return past_key.str(); +[[maybe_unused]] static std::string getPastKeyString(size_t i) +{ + std::stringstream past_key; + past_key << "past_key_values." << std::to_string(i) << ".key"; + return past_key.str(); } -[[maybe_unused]] static std::string getPastValueStr(size_t i) { - std::stringstream past_val; - past_val << "past_key_values." << std::to_string(i) << ".value"; - return past_val.str(); +[[maybe_unused]] static std::string getPastValueStr(size_t i) +{ + std::stringstream past_val; + past_val << "past_key_values." << std::to_string(i) << ".value"; + return past_val.str(); } -[[maybe_unused]] static std::string getPresentKeyString(size_t i) { - std::stringstream past_key; - past_key << "present." << std::to_string(i) << ".key"; - return past_key.str(); +[[maybe_unused]] static std::string getPresentKeyString(size_t i) +{ + std::stringstream past_key; + past_key << "present." << std::to_string(i) << ".key"; + return past_key.str(); } -[[maybe_unused]] static std::string getPresentValueStr(size_t i) { - std::stringstream past_val; - past_val << "present." << std::to_string(i) << ".value"; - return past_val.str(); +[[maybe_unused]] static std::string getPresentValueStr(size_t i) +{ + std::stringstream past_val; + past_val << "present." 
<< std::to_string(i) << ".value"; + return past_val.str(); } -static migraphx::program loadOnnx(ModelLoadSettings &settings) { - std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); +static migraphx::program loadOnnx(ModelLoadSettings& settings) +{ + std::filesystem::path onnx_path(MODEL_FOLDER + ONNX_FILE); #ifdef TRACE - std::cout << "Using model: " << MODEL_FOLDER + ONNX_FILE << std::endl; + std::cout << "Using model: " << MODEL_FOLDER + ONNX_FILE << std::endl; #endif - migraphx::program prog; - std::ifstream f(onnx_path.c_str()); - if (f.good()) { - migraphx::onnx_options onnx_opts; - std::vector dims = {BATCH_SIZE, SEQ_SIZE}; - std::vector dimsPastKey = {BATCH_SIZE, HIDDEN_LAYERS_NUM, - SEQ_SIZE, HEAD_SIZE}; - std::vector inputDim; - if (settings.input_one_dim) { - inputDim = {BATCH_SIZE, 1}; - } else { - inputDim = dims; - } - onnx_opts.set_input_parameter_shape("input_ids", inputDim); - onnx_opts.set_input_parameter_shape("attention_mask", dims); - for (size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) { - onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); - onnx_opts.set_input_parameter_shape(getPastValueStr(i), dimsPastKey); + migraphx::program prog; + std::ifstream f(onnx_path.c_str()); + if(f.good()) + { + migraphx::onnx_options onnx_opts; + std::vector dims = {BATCH_SIZE, SEQ_SIZE}; + std::vector dimsPastKey = {BATCH_SIZE, HIDDEN_LAYERS_NUM, SEQ_SIZE, HEAD_SIZE}; + std::vector inputDim; + if(settings.input_one_dim) + { + inputDim = {BATCH_SIZE, 1}; + } + else + { + inputDim = dims; + } + onnx_opts.set_input_parameter_shape("input_ids", inputDim); + onnx_opts.set_input_parameter_shape("attention_mask", dims); + for(size_t i = 0; i < HIDDEN_LAYERS_NUM; ++i) + { + onnx_opts.set_input_parameter_shape(getPastKeyString(i), dimsPastKey); + onnx_opts.set_input_parameter_shape(getPastValueStr(i), dimsPastKey); + } + std::cout << "Parsing onnx file ..." << std::endl; + prog = parse_onnx(onnx_path.c_str(), onnx_opts); + + std::string target_str = "gpu"; + migraphx::target targ = migraphx::target(target_str.c_str()); + + if(settings.quantize_fp16) + { + std::cout << "Quantize FP16 ..." << std::endl; + migraphx::quantize_fp16(prog); + } + + migraphx::compile_options comp_opts; + + if(settings.fast_math) + comp_opts.set_fast_math(); + + comp_opts.set_exhaustive_tune_flag(); + + std::cout << "Compile to target ..." << std::endl; + prog.compile(targ, comp_opts); + + std::string modelPath = getModelPath(settings); + migraphx::file_options file_options; + file_options.set_file_format("msgpack"); + std::cout << "Saving mxr file to: " << modelPath << "\n"; + migraphx::save(prog, modelPath.c_str(), file_options); } - std::cout << "Parsing onnx file ..." << std::endl; - prog = parse_onnx(onnx_path.c_str(), onnx_opts); - - std::string target_str = "gpu"; - migraphx::target targ = migraphx::target(target_str.c_str()); - - if (settings.quantize_fp16) { - std::cout << "Quantize FP16 ..." << std::endl; - migraphx::quantize_fp16(prog); + else + { + std::cerr << "Onnx file is not available on path: " << onnx_path << std::endl; + exit(1); } + return prog; +}; - migraphx::compile_options comp_opts; - - if (settings.fast_math) - comp_opts.set_fast_math(); - - comp_opts.set_exhaustive_tune_flag(); - - std::cout << "Compile to target ..." 
<< std::endl; - prog.compile(targ, comp_opts); +static migraphx::program loadProgram(ModelLoadSettings& settings) +{ + std::filesystem::path compiled_path(getModelPath(settings)); - std::string modelPath = getModelPath(settings); migraphx::file_options file_options; file_options.set_file_format("msgpack"); - std::cout << "Saving mxr file to: " << modelPath << "\n"; - migraphx::save(prog, modelPath.c_str(), file_options); - } else { - std::cerr << "Onnx file is not available on path: " << onnx_path - << std::endl; - exit(1); - } - return prog; -}; -static migraphx::program loadProgram(ModelLoadSettings &settings) { - std::filesystem::path compiled_path(getModelPath(settings)); - - migraphx::file_options file_options; - file_options.set_file_format("msgpack"); - - migraphx::program prog; - std::ifstream f(compiled_path.c_str()); - if (f.good()) { - std::cout << "Loading model from " << compiled_path << " ...\n"; - prog = migraphx::load(compiled_path.c_str(), file_options); - } else { - std::cout << "MXR file can't be loaded try to load ONNX\n"; - prog = loadOnnx(settings); - } - return prog; + migraphx::program prog; + std::ifstream f(compiled_path.c_str()); + if(f.good()) + { + std::cout << "Loading model from " << compiled_path << " ...\n"; + prog = migraphx::load(compiled_path.c_str(), file_options); + } + else + { + std::cout << "MXR file can't be loaded try to load ONNX\n"; + prog = loadOnnx(settings); + } + return prog; }; -static migraphx::program create_argmax_program(ModelLoadSettings &settings) { - migraphx::program prog; - std::vector dims{BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}; - if (settings.input_one_dim) { - dims[1] = 1; - } - migraphx::shape s{migraphx_shape_half_type, dims}; - migraphx::module m = prog.get_main_module(); - auto x = m.add_parameter("x", s); - auto argmax_ins = - m.add_instruction(migraphx::operation("argmax", "{axis: 2}"), {x}); - m.add_return({argmax_ins}); +static migraphx::program create_argmax_program(ModelLoadSettings& settings) +{ + migraphx::program prog; + std::vector dims{BATCH_SIZE, SEQ_SIZE, VOCAB_SIZE}; + if(settings.input_one_dim) + { + dims[1] = 1; + } + migraphx::shape s{migraphx_shape_half_type, dims}; + migraphx::module m = prog.get_main_module(); + auto x = m.add_parameter("x", s); + auto argmax_ins = m.add_instruction(migraphx::operation("argmax", "{axis: 2}"), {x}); + m.add_return({argmax_ins}); - std::cout << "Creating ArgMax program ..." << std::endl; + std::cout << "Creating ArgMax program ..." << std::endl; - std::string target_str = "gpu"; - migraphx::target targ = migraphx::target(target_str.c_str()); + std::string target_str = "gpu"; + migraphx::target targ = migraphx::target(target_str.c_str()); - if (settings.quantize_fp16) { - std::cout << "Quantize FP16 ..." << std::endl; - migraphx::quantize_fp16(prog); - } + if(settings.quantize_fp16) + { + std::cout << "Quantize FP16 ..." << std::endl; + migraphx::quantize_fp16(prog); + } - migraphx::compile_options comp_opts; + migraphx::compile_options comp_opts; - if (settings.fast_math) - comp_opts.set_fast_math(); + if(settings.fast_math) + comp_opts.set_fast_math(); - comp_opts.set_exhaustive_tune_flag(); + comp_opts.set_exhaustive_tune_flag(); - std::cout << "Compile to target ..." << std::endl; - prog.compile(targ, comp_opts); + std::cout << "Compile to target ..." 
<< std::endl; + prog.compile(targ, comp_opts); - return prog; + return prog; } -static void writeResults(const std::vector> &results) { - std::string RESULT_FILE = "result.txt"; - std::ofstream outFile(RESULT_FILE); - for (auto &resVec : results) { - for (auto &res : resVec) { - outFile << res; - if (&res != &resVec.back()) { - outFile << ", "; - } +static void writeResults(const std::vector>& results) +{ + std::string RESULT_FILE = "result.txt"; + std::ofstream outFile(RESULT_FILE); + for(auto& resVec : results) + { + for(auto& res : resVec) + { + outFile << res; + if(&res != &resVec.back()) + { + outFile << ", "; + } + } + outFile << "\n"; } - outFile << "\n"; - } } From 7d0bb2f60719a3c6f76e36d3b5c25b921d62edd9 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 4 Dec 2024 07:42:26 -0600 Subject: [PATCH 50/55] Fix delete modifier ident --- examples/transformers/mgx_llama2/harness/buffer.hpp | 6 +++--- examples/transformers/mgx_llama2/harness/common.hpp | 6 +++--- examples/transformers/mgx_llama2/harness/dataset.hpp | 2 +- examples/transformers/mgx_llama2/harness/llama2inputs.hpp | 4 ++-- examples/transformers/mgx_llama2/harness/llama2outputs.hpp | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index 645c1dd748b..83d0c7a75b1 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -50,7 +50,7 @@ struct GenericBuffer : public IBuffer return *this; } - GenericBuffer(const GenericBuffer& buf) = delete; + GenericBuffer(const GenericBuffer& buf) = delete; GenericBuffer& operator=(const GenericBuffer& buf) = delete; ~GenericBuffer() { this->free_fn(tensor_ptr); } @@ -160,8 +160,8 @@ struct ManagedBuffer_v2 upload_to_device(stream, position, position + 1); } - ManagedBuffer_v2() = delete; - ManagedBuffer_v2(const ManagedBuffer_v2& buf) = delete; + ManagedBuffer_v2() = delete; + ManagedBuffer_v2(const ManagedBuffer_v2& buf) = delete; ManagedBuffer_v2& operator=(const ManagedBuffer_v2& buf) = delete; DeviceBuffer dbuff; diff --git a/examples/transformers/mgx_llama2/harness/common.hpp b/examples/transformers/mgx_llama2/harness/common.hpp index 0f7ba9f92f8..29454f7704f 100644 --- a/examples/transformers/mgx_llama2/harness/common.hpp +++ b/examples/transformers/mgx_llama2/harness/common.hpp @@ -22,9 +22,9 @@ using namespace half_float::literal; namespace mlinfer { struct INoCopy { - INoCopy() = default; - virtual ~INoCopy() = default; - INoCopy(const INoCopy&) = delete; + INoCopy() = default; + virtual ~INoCopy() = default; + INoCopy(const INoCopy&) = delete; INoCopy& operator=(const INoCopy&) = delete; }; diff --git a/examples/transformers/mgx_llama2/harness/dataset.hpp b/examples/transformers/mgx_llama2/harness/dataset.hpp index 850862ff86f..3aeba8ef252 100644 --- a/examples/transformers/mgx_llama2/harness/dataset.hpp +++ b/examples/transformers/mgx_llama2/harness/dataset.hpp @@ -109,7 +109,7 @@ struct Dataset return _current_batch; } - Dataset(const Dataset& buf) = delete; + Dataset(const Dataset& buf) = delete; Dataset& operator=(const Dataset& buf) = delete; private: diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp index d33934f342e..3ddfff6b294 100644 --- a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp @@ -118,8 +118,8 @@ struct LLama2Inputs 
size_t dataSize() const { return data.size(); } size_t batchNum() const { return data.batchNum(); } - LLama2Inputs() = delete; - LLama2Inputs(const LLama2Inputs& buf) = delete; + LLama2Inputs() = delete; + LLama2Inputs(const LLama2Inputs& buf) = delete; LLama2Inputs& operator=(const LLama2Inputs& buf) = delete; std::unique_ptr input_ids_buffer; diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp index 7fb5ebe0d6a..a6729dfc3b1 100644 --- a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp +++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp @@ -46,7 +46,7 @@ struct LLama2Outputs migraphx::argument(argm_out_shape_one_dim, argm_output_buffer_one_dim->data())); } - LLama2Outputs(const LLama2Outputs& buf) = delete; + LLama2Outputs(const LLama2Outputs& buf) = delete; LLama2Outputs& operator=(const LLama2Outputs& buf) = delete; std::unique_ptr output_buffer; From 43364a103b0a3b7dd0cb39b394ffa026eaca593e Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 4 Dec 2024 07:59:12 -0600 Subject: [PATCH 51/55] Fix python files format issues --- .../transformers/mgx_llama2/eval_accuracy.py | 87 +++++++++---------- .../mgx_llama2/preprocess_dataset.py | 15 ++-- 2 files changed, 49 insertions(+), 53 deletions(-) diff --git a/examples/transformers/mgx_llama2/eval_accuracy.py b/examples/transformers/mgx_llama2/eval_accuracy.py index ae175df991a..a4cd85776ff 100644 --- a/examples/transformers/mgx_llama2/eval_accuracy.py +++ b/examples/transformers/mgx_llama2/eval_accuracy.py @@ -6,7 +6,6 @@ import nltk from transformers import AutoTokenizer - MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf" G_MAX_TOK_LEN = 1024 @@ -16,12 +15,14 @@ DATASET_PATH = "/dataset/open_orca_gpt4_tokenized_llama.sampled_24576.pkl" RESULT_PATH = "build/result.txt" + def main(dataset_path, result_path, sample_size, sequence_size): tokenizer = AutoTokenizer.from_pretrained( - MODEL_NAME, - model_max_length=sequence_size, - padding_side="left", - use_fast=False,) + MODEL_NAME, + model_max_length=sequence_size, + padding_side="left", + use_fast=False, + ) metric = evaluate.load("rouge") nltk.download("punkt_tab") @@ -31,36 +32,35 @@ def main(dataset_path, result_path, sample_size, sequence_size): with _p.open(mode="rb") as f: d = pickle.load(f) - target = d['output'].to_list() targets = target[0:sample_size] results, gen_tok_len = readResult(result_path) - preds = tokenizer.batch_decode( - results, skip_special_tokens=True - ) + preds = tokenizer.batch_decode(results, skip_special_tokens=True) postprocess_text(preds, target) - result = metric.compute( - predictions=preds, references=targets, use_stemmer=True, use_aggregator=False - ) + result = metric.compute(predictions=preds, + references=targets, + use_stemmer=True, + use_aggregator=False) result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} prediction_lens = [len(pred) for pred in preds] gen_num = len(preds) result = { - **result, - "gen_len": np.sum(prediction_lens), - "gen_num": gen_num, - "gen_tok_len": gen_tok_len, - "tokens_per_sample": round(gen_tok_len / gen_num, 1), - } + **result, + "gen_len": np.sum(prediction_lens), + "gen_num": gen_num, + "gen_tok_len": gen_tok_len, + "tokens_per_sample": round(gen_tok_len / gen_num, 1), + } print("\nResults\n") print(result) + def readResult(path): results = [] tok_len = 0 @@ -72,6 +72,7 @@ def readResult(path): tok_len += len(result) return results, tok_len + def postprocess_text(preds, targets): preds = [pred.strip() for pred 
in preds] targets = [target.strip() for target in targets] @@ -84,34 +85,26 @@ def postprocess_text(preds, targets): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument( - "-d", - "--dataset-path", - help="Path to the dataset pickle file", - default=DATASET_PATH - ) - - parser.add_argument( - "-r", - "--result_path", - help="Path to output tokens result file", - default=RESULT_PATH - ) - - parser.add_argument( - "-size", - "--sample-size", - help="Sample size of dataset", - type=int, - default=SAMPLE_SIZE - ) - - parser.add_argument( - "-seq_size", - "--sequence_size", - help="Size of sequence", - type=int, - default=G_MAX_TOK_LEN - ) + parser.add_argument("-d", + "--dataset-path", + help="Path to the dataset pickle file", + default=DATASET_PATH) + + parser.add_argument("-r", + "--result_path", + help="Path to output tokens result file", + default=RESULT_PATH) + + parser.add_argument("-size", + "--sample-size", + help="Sample size of dataset", + type=int, + default=SAMPLE_SIZE) + + parser.add_argument("-seq_size", + "--sequence_size", + help="Size of sequence", + type=int, + default=G_MAX_TOK_LEN) main(**vars(parser.parse_args())) \ No newline at end of file diff --git a/examples/transformers/mgx_llama2/preprocess_dataset.py b/examples/transformers/mgx_llama2/preprocess_dataset.py index 8632355372f..578d2ba8bd5 100644 --- a/examples/transformers/mgx_llama2/preprocess_dataset.py +++ b/examples/transformers/mgx_llama2/preprocess_dataset.py @@ -16,22 +16,25 @@ raise RuntimeError(f"Missing dataset from {DATASET_PATH}") toks = d['tok_input'].to_list() -#toks = [toks[0]] toks_np = np.ones((len(toks), G_MAX_TOK_LEN), dtype=np.int64) * G_LLAMA2_EOS mask_np = np.zeros((len(toks), G_MAX_TOK_LEN), dtype=np.int64) -position_nps = [np.arange(0, G_MAX_TOK_LEN, dtype=np.int64) for _ in range(len(toks))] +position_nps = [ + np.arange(0, G_MAX_TOK_LEN, dtype=np.int64) for _ in range(len(toks)) +] for i, q in enumerate(toks): toks_np[i, :len(q)] = q mask_np[i, :len(q)] = np.ones_like(q) - token_size = len(toks) -np.save(f"{OUTPUT_PATH}input_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", toks_np) -np.save(f"{OUTPUT_PATH}attention_mask_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", mask_np) -np.save(f"{OUTPUT_PATH}position_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", position_nps) +np.save(f"{OUTPUT_PATH}input_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", + toks_np) +np.save(f"{OUTPUT_PATH}attention_mask_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", + mask_np) +np.save(f"{OUTPUT_PATH}position_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", + position_nps) print("Npy files are created") From 5f2190fdbf68127c52eac02ad289e0721fddff2f Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Wed, 4 Dec 2024 08:13:13 -0600 Subject: [PATCH 52/55] Fix python files format issues 2 --- examples/transformers/mgx_llama2/eval_accuracy.py | 1 + examples/transformers/mgx_llama2/preprocess_dataset.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/transformers/mgx_llama2/eval_accuracy.py b/examples/transformers/mgx_llama2/eval_accuracy.py index a4cd85776ff..2bdf080e0b8 100644 --- a/examples/transformers/mgx_llama2/eval_accuracy.py +++ b/examples/transformers/mgx_llama2/eval_accuracy.py @@ -83,6 +83,7 @@ def postprocess_text(preds, targets): return preds, targets + if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("-d", diff --git a/examples/transformers/mgx_llama2/preprocess_dataset.py b/examples/transformers/mgx_llama2/preprocess_dataset.py index 
578d2ba8bd5..ed2692299af 100644 --- a/examples/transformers/mgx_llama2/preprocess_dataset.py +++ b/examples/transformers/mgx_llama2/preprocess_dataset.py @@ -23,7 +23,6 @@ np.arange(0, G_MAX_TOK_LEN, dtype=np.int64) for _ in range(len(toks)) ] - for i, q in enumerate(toks): toks_np[i, :len(q)] = q mask_np[i, :len(q)] = np.ones_like(q) @@ -32,8 +31,9 @@ np.save(f"{OUTPUT_PATH}input_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", toks_np) -np.save(f"{OUTPUT_PATH}attention_mask_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", - mask_np) +np.save( + f"{OUTPUT_PATH}attention_mask_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", + mask_np) np.save(f"{OUTPUT_PATH}position_ids_size_{token_size}_seq_{G_MAX_TOK_LEN}.npy", position_nps) From 0640809a9e68127f7619895722f0eff501213ea7 Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 5 Dec 2024 06:45:22 -0600 Subject: [PATCH 53/55] Make output results const for cppcheck --- examples/transformers/mgx_llama2/mgxllama2.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc index 8512dcf6bce..a1785101cdd 100644 --- a/examples/transformers/mgx_llama2/mgxllama2.cc +++ b/examples/transformers/mgx_llama2/mgxllama2.cc @@ -104,9 +104,9 @@ struct MGXLlama2 firstIter ? model_outputs->argm_output_buffer->download_from_device(stream, sampleLastInputIdx[b], sampleLastInputIdx[b] + 1) : model_outputs->argm_output_buffer_one_dim->download_from_device(stream); check_hip_status(hipStreamSynchronize(stream)); - int64_t* results = reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); + const int64_t* output_results = reinterpret_cast( firstIter ? model_outputs->argm_output_buffer->hbuff.data() : model_outputs->argm_output_buffer_one_dim->hbuff.data()); auto new_token_idx = firstIter ? sampleLastInputIdx[b] : b; - int64_t new_token = results[new_token_idx]; + int64_t new_token = output_results[new_token_idx]; token_count++; #ifdef TRACE From 20fb8bc3db7cbeab642fa94b1fd341f8b949128b Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Thu, 5 Dec 2024 10:25:44 -0600 Subject: [PATCH 54/55] Pass GPU_TARGET from build_docker script to Dockerfile --- examples/transformers/mgx_llama2/Dockerfile | 5 +++-- examples/transformers/mgx_llama2/build_docker.sh | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/transformers/mgx_llama2/Dockerfile b/examples/transformers/mgx_llama2/Dockerfile index 6525fb81619..193c3ed400e 100644 --- a/examples/transformers/mgx_llama2/Dockerfile +++ b/examples/transformers/mgx_llama2/Dockerfile @@ -2,6 +2,8 @@ FROM rocm/dev-ubuntu-22.04:6.2 ENV DEBIAN_FRONTEND=noninteractive +ARG GPU_TARGET + SHELL ["/bin/bash", "-c"] RUN apt-get update && apt-get install -y --allow-unauthenticated \ @@ -27,9 +29,8 @@ ENV MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1 ENV MIGRAPHX_USE_HIPBLASLT=1 ENV MIGRAPHX_USE_MIOPEN=1 -#TODO: use $(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') for GPU_TARGETS RUN mkdir build && cd build && \ - CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DGPU_TARGETS='gfx942' && \ + CXX=/opt/rocm/llvm/bin/clang++ cmake .. 
-DGPU_TARGETS=${GPU_TARGET} && \ make -j$(nproc) && \ make install diff --git a/examples/transformers/mgx_llama2/build_docker.sh b/examples/transformers/mgx_llama2/build_docker.sh index 4aafe5bea88..051d8007ebe 100755 --- a/examples/transformers/mgx_llama2/build_docker.sh +++ b/examples/transformers/mgx_llama2/build_docker.sh @@ -1,3 +1,5 @@ #!/bin/bash -docker build --platform linux/amd64 --tag mgx_llama2:v0.2 --file Dockerfile . +GPU_TARGET=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') + +docker build --platform linux/amd64 --tag mgx_llama2:v0.2 --build-arg GPU_TARGET=$GPU_TARGET --file Dockerfile . From b499638d01392182bcbf5155c051fdef576ef24f Mon Sep 17 00:00:00 2001 From: Oliver Toth Date: Fri, 6 Dec 2024 03:21:37 -0600 Subject: [PATCH 55/55] Add license --- .../transformers/mgx_llama2/CMakeLists.txt | 23 +++++++++++++++++ .../transformers/mgx_llama2/build_docker.sh | 25 +++++++++++++++++++ .../transformers/mgx_llama2/eval_accuracy.py | 23 +++++++++++++++++ .../mgx_llama2/harness/buffer.hpp | 23 +++++++++++++++++ .../mgx_llama2/harness/common.hpp | 23 +++++++++++++++++ .../mgx_llama2/harness/config.hpp | 24 +++++++++++++++++- .../mgx_llama2/harness/dataset.hpp | 23 +++++++++++++++++ .../mgx_llama2/harness/llama2inputs.hpp | 23 +++++++++++++++++ .../mgx_llama2/harness/llama2outputs.hpp | 23 +++++++++++++++++ .../mgx_llama2/harness/logging.hpp | 23 +++++++++++++++++ .../transformers/mgx_llama2/harness/numa.hpp | 23 +++++++++++++++++ .../transformers/mgx_llama2/harness/numpy.hpp | 23 +++++++++++++++++ .../transformers/mgx_llama2/harness/timer.hpp | 23 +++++++++++++++++ .../transformers/mgx_llama2/harness/utils.hpp | 23 +++++++++++++++++ examples/transformers/mgx_llama2/mgxllama2.cc | 23 +++++++++++++++++ .../mgx_llama2/preprocess_dataset.py | 23 +++++++++++++++++ .../transformers/mgx_llama2/run_docker.sh | 25 +++++++++++++++++++ 17 files changed, 395 insertions(+), 1 deletion(-) diff --git a/examples/transformers/mgx_llama2/CMakeLists.txt b/examples/transformers/mgx_llama2/CMakeLists.txt index ec27b8ff203..233a58f2d90 100644 --- a/examples/transformers/mgx_llama2/CMakeLists.txt +++ b/examples/transformers/mgx_llama2/CMakeLists.txt @@ -1,3 +1,26 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+#####################################################################################
 project(MGXLlama2)
 cmake_minimum_required(VERSION 3.22)
 
diff --git a/examples/transformers/mgx_llama2/build_docker.sh b/examples/transformers/mgx_llama2/build_docker.sh
index 051d8007ebe..82eed2f85f1 100755
--- a/examples/transformers/mgx_llama2/build_docker.sh
+++ b/examples/transformers/mgx_llama2/build_docker.sh
@@ -1,5 +1,30 @@
 #!/bin/bash
 
+#####################################################################################
+# The MIT License (MIT)
+#
+# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+#####################################################################################
+
 GPU_TARGET=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*')
 
 docker build --platform linux/amd64 --tag mgx_llama2:v0.2 --build-arg GPU_TARGET=$GPU_TARGET --file Dockerfile .
diff --git a/examples/transformers/mgx_llama2/eval_accuracy.py b/examples/transformers/mgx_llama2/eval_accuracy.py
index 2bdf080e0b8..8244fee3e41 100644
--- a/examples/transformers/mgx_llama2/eval_accuracy.py
+++ b/examples/transformers/mgx_llama2/eval_accuracy.py
@@ -1,3 +1,26 @@
+#####################################################################################
+# The MIT License (MIT)
+#
+# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+##################################################################################### from argparse import ArgumentParser import numpy as np import pickle diff --git a/examples/transformers/mgx_llama2/harness/buffer.hpp b/examples/transformers/mgx_llama2/harness/buffer.hpp index 83d0c7a75b1..dd1acff970a 100644 --- a/examples/transformers/mgx_llama2/harness/buffer.hpp +++ b/examples/transformers/mgx_llama2/harness/buffer.hpp @@ -1,3 +1,26 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ #pragma once #include "common.hpp" diff --git a/examples/transformers/mgx_llama2/harness/common.hpp b/examples/transformers/mgx_llama2/harness/common.hpp index 29454f7704f..bdb497f539d 100644 --- a/examples/transformers/mgx_llama2/harness/common.hpp +++ b/examples/transformers/mgx_llama2/harness/common.hpp @@ -1,3 +1,26 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ #pragma once #include diff --git a/examples/transformers/mgx_llama2/harness/config.hpp b/examples/transformers/mgx_llama2/harness/config.hpp index 69867402296..e70ad79f334 100644 --- a/examples/transformers/mgx_llama2/harness/config.hpp +++ b/examples/transformers/mgx_llama2/harness/config.hpp @@ -1,6 +1,28 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ #pragma once -// TODO: fix paths const std::string MODEL_FOLDER = "/model/"; const std::string ONNX_FILE = "model.onnx"; const std::string DATASET_FOLDER = "/dataset/"; diff --git a/examples/transformers/mgx_llama2/harness/dataset.hpp b/examples/transformers/mgx_llama2/harness/dataset.hpp index 3aeba8ef252..5b7f2a6a851 100644 --- a/examples/transformers/mgx_llama2/harness/dataset.hpp +++ b/examples/transformers/mgx_llama2/harness/dataset.hpp @@ -1,3 +1,26 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
 #pragma once
 
 #include "config.hpp"
diff --git a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp
index 3ddfff6b294..a49344af3c7 100644
--- a/examples/transformers/mgx_llama2/harness/llama2inputs.hpp
+++ b/examples/transformers/mgx_llama2/harness/llama2inputs.hpp
@@ -1,3 +1,26 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 #pragma once
 
 #include "config.hpp"
diff --git a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp
index a6729dfc3b1..0116825672e 100644
--- a/examples/transformers/mgx_llama2/harness/llama2outputs.hpp
+++ b/examples/transformers/mgx_llama2/harness/llama2outputs.hpp
@@ -1,3 +1,26 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 #pragma once
 
 #include "buffer.hpp"
diff --git a/examples/transformers/mgx_llama2/harness/logging.hpp b/examples/transformers/mgx_llama2/harness/logging.hpp
index 9770090e079..4e926f4cc6d 100644
--- a/examples/transformers/mgx_llama2/harness/logging.hpp
+++ b/examples/transformers/mgx_llama2/harness/logging.hpp
@@ -1,3 +1,26 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 #pragma once
 
 #include
diff --git a/examples/transformers/mgx_llama2/harness/numa.hpp b/examples/transformers/mgx_llama2/harness/numa.hpp
index 4aa60ddeccb..83eeaeeed08 100644
--- a/examples/transformers/mgx_llama2/harness/numa.hpp
+++ b/examples/transformers/mgx_llama2/harness/numa.hpp
@@ -1,3 +1,26 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 #pragma once
 
 #include
diff --git a/examples/transformers/mgx_llama2/harness/numpy.hpp b/examples/transformers/mgx_llama2/harness/numpy.hpp
index 5864a48b179..00db711ed87 100644
--- a/examples/transformers/mgx_llama2/harness/numpy.hpp
+++ b/examples/transformers/mgx_llama2/harness/numpy.hpp
@@ -1,3 +1,26 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ #pragma once #include diff --git a/examples/transformers/mgx_llama2/harness/timer.hpp b/examples/transformers/mgx_llama2/harness/timer.hpp index 91a68373333..20b2ca42f69 100644 --- a/examples/transformers/mgx_llama2/harness/timer.hpp +++ b/examples/transformers/mgx_llama2/harness/timer.hpp @@ -1,3 +1,26 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ #pragma once #include diff --git a/examples/transformers/mgx_llama2/harness/utils.hpp b/examples/transformers/mgx_llama2/harness/utils.hpp index 151f75acd12..965a46bc751 100644 --- a/examples/transformers/mgx_llama2/harness/utils.hpp +++ b/examples/transformers/mgx_llama2/harness/utils.hpp @@ -1,3 +1,26 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 #pragma once
 
 #include "config.hpp"
diff --git a/examples/transformers/mgx_llama2/mgxllama2.cc b/examples/transformers/mgx_llama2/mgxllama2.cc
index a1785101cdd..adb1ac0ac26 100644
--- a/examples/transformers/mgx_llama2/mgxllama2.cc
+++ b/examples/transformers/mgx_llama2/mgxllama2.cc
@@ -1,3 +1,26 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 #include "config.hpp"
 #include "buffer.hpp"
 #include "common.hpp"
diff --git a/examples/transformers/mgx_llama2/preprocess_dataset.py b/examples/transformers/mgx_llama2/preprocess_dataset.py
index ed2692299af..eb1a7c8d5b9 100644
--- a/examples/transformers/mgx_llama2/preprocess_dataset.py
+++ b/examples/transformers/mgx_llama2/preprocess_dataset.py
@@ -1,3 +1,26 @@
+#####################################################################################
+# The MIT License (MIT)
+#
+# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +##################################################################################### import numpy as np import pickle from pathlib import Path diff --git a/examples/transformers/mgx_llama2/run_docker.sh b/examples/transformers/mgx_llama2/run_docker.sh index 329ab350873..c8d02b7363b 100755 --- a/examples/transformers/mgx_llama2/run_docker.sh +++ b/examples/transformers/mgx_llama2/run_docker.sh @@ -1,5 +1,30 @@ #!/bin/bash +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +##################################################################################### + if [[ -z "${MODEL_DIR_PATH}" ]]; then echo "MODEL_DIR_PATH is not set, please provide the path to model before running docker." exit 1