Warp #58

Merged · 2 commits · Aug 28, 2024
Changes from all commits
40 changes: 23 additions & 17 deletions libNeonPy/src/Neon/py/CudaDriver.cpp
@@ -63,22 +63,23 @@ auto CudaDriver::run_kernel(
     Neon::StreamIdx streamIdx) -> void
 {
     [[maybe_unused]] auto& streamSet = backend.streamSet(streamIdx);
 
-    int const ndevs = backend.getDeviceCount();
-    // #pragma omp parallel for num_threads(ndevs)
+    int const ndevs = backend.getDeviceCount();
+#pragma omp parallel for num_threads(ndevs)
     for (int setIdx = 0; setIdx < ndevs; setIdx++) {
         backend.devSet().setActiveDevContext(setIdx);
 
         cudaStream_t const& cuda_stream = streamSet.cudaStream(setIdx);
         CUstream driverStream = (CUstream)cuda_stream;
         CUfunction function = static_cast<CUfunction>(kernelSet[setIdx]);
-        auto& launch_info = launch_params[setIdx];
-
+        auto& launch_info = launch_params[setIdx];
         //std::cout << "setIdx " << setIdx << " function " << function << std::endl;
         // auto const cudaGrid = launch_info.cudaGrid();
         // auto const cudaBlock = launch_info.cudaBlock();
 
-        // Set the created context as the current context
-        CUresult res = cuCtxSetCurrent(cu_contexts[setIdx]);
-        check_cuda_res(res, "cuCtxSetCurrent");
+        // CUresult res = cuCtxSetCurrent(cu_contexts[setIdx]);
+        // check_cuda_res(res, "cuCtxSetCurrent");
+        // std::cout << "Current CUDA context ID (handle): " << (cu_contexts[setIdx]) << std::endl;
+        // int64_t pywarp_size = 1;
+        // std::cout << "pywarp_size" << pywarp_size << std::endl;
         const int LAUNCH_MAX_DIMS = 4; // should match types.py
@@ -99,23 +99,28 @@ auto CudaDriver::run_kernel(
         std::vector<void*> args;
         args.push_back(&bounds);
 
-        [[maybe_unused]] auto devset = backend.devSet();
-        devset.setActiveDevContext(setIdx);
-        [[maybe_unused]] auto const& gpuDev = devset.gpuDev(setIdx);
-        [[maybe_unused]] auto kinfo = launch_params.operator[](setIdx);
+        // [[maybe_unused]] auto devset = backend.devSet();
+        // [[maybe_unused]] auto const& gpuDev = devset.gpuDev(setIdx);
+        // [[maybe_unused]] auto kinfo = launch_params.operator[](setIdx);
         // try {
         //     gpuDev.kernel.cudaLaunchKernel<Neon::run_et::sync>(streamSet[setIdx], kinfo, function, args.data());
         // } catch (...) {
         //
         // }
         // int block_dim = 256;
         // int grid_dim = (n + block_dim - 1) / block_dim;
-        // std::cout << "block_dim " << launch_info.toString()<< std::endl;
-        // std::cout << "grid_dim " << launch_info << std::endl;
-        // std::cout << "n " << n << std::endl;
-        // std::cout << "cuLaunchKernel" << std::endl;
+        // std::cout << "block_dim " << launch_info.domainGrid() << std::endl;
+        // // std::cout << "grid_dim " << launch_info << std::endl;
+        // std::cout << "n " << n << std::endl;
+        // std::cout << "cuLaunchKernel" << std::endl;
+        // int deviceId;
+        // cudaError_t status = cudaGetDevice(&deviceId);
+        // if (status != cudaSuccess) {
+        //     std::cerr << "Failed to get current device ID: " << cudaGetErrorString(status) << std::endl;
+        // }
 
-        res = cuLaunchKernel(
+        //std::cout << "Current CUDA device ID: " << deviceId << std::endl;
+        auto res = cuLaunchKernel(
             function,
             launch_info.cudaGrid().x,
             launch_info.cudaGrid().y,
@@ -129,7 +135,7 @@
             0);
 
         check_cuda_res(res, "cuLaunchKernel");
-        //cuCtxSynchronize();
+        // cuCtxSynchronize();
     }
 }
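Note on the change above: the launch path no longer calls cuCtxSetCurrent explicitly; it relies on setActiveDevContext per device and then issues the kernel through the raw driver entry point cuLaunchKernel, with OpenMP providing one thread per device. For readers working from the Python side of this repo, a minimal ctypes sketch of the same driver call follows; the library path and the launch_raw helper are illustrative assumptions, not part of this PR.

    import ctypes

    # Sketch only: bind the CUDA driver's launch entry point directly.
    libcuda = ctypes.CDLL("libcuda.so.1")          # assumed Linux driver library name
    libcuda.cuLaunchKernel.restype = ctypes.c_int  # CUresult

    def launch_raw(function_handle, stream_handle, grid, block, arg_pointers):
        # kernelParams is a void** array; each entry points at one kernel argument.
        params = (ctypes.c_void_p * len(arg_pointers))(*arg_pointers)
        res = libcuda.cuLaunchKernel(
            ctypes.c_void_p(function_handle),                                 # CUfunction
            ctypes.c_uint(grid[0]), ctypes.c_uint(grid[1]), ctypes.c_uint(grid[2]),
            ctypes.c_uint(block[0]), ctypes.c_uint(block[1]), ctypes.c_uint(block[2]),
            ctypes.c_uint(0),                                                 # sharedMemBytes
            ctypes.c_void_p(stream_handle),                                   # CUstream
            params,                                                           # void** kernelParams
            None)                                                             # void** extra (unused)
        if res != 0:
            raise RuntimeError(f"cuLaunchKernel failed: CUresult {res}")

The argument order mirrors the C++ call above: function, grid dims, block dims, shared memory bytes, stream, kernel parameters, extra.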
26 changes: 17 additions & 9 deletions py_neon/backend.py
@@ -46,11 +46,13 @@ def __del__(self):
     def help_load_api(self):
         # ------------------------------------------------------------------
         # backend_new
-        self.py_neon.lib.dBackend_new.argtypes = [ctypes.POINTER(self.py_neon.handle_type),
-                                                  ctypes.c_int,
-                                                  ctypes.c_int,
-                                                  ctypes.POINTER(ctypes.c_int)]
-        self.py_neon.lib.dBackend_new.restype = ctypes.c_int
+        lib_obj = self.py_neon.lib
+        self.api_new = lib_obj.dBackend_new
+        self.api_new.argtypes = [ctypes.POINTER(self.py_neon.handle_type),
+                                 ctypes.c_int,
+                                 ctypes.c_int,
+                                 ctypes.POINTER(ctypes.c_int)]
+        self.api_new.restype = ctypes.c_int
         # ------------------------------------------------------------------
         # backend_delete
         self.py_neon.lib.dBackend_delete.argtypes = [ctypes.POINTER(self.py_neon.handle_type)]
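The rewrite above caches the bound function object in self.api_new, so the foreign signature is declared once and the call sites stay short. A standalone sketch of the pattern, with a hypothetical libfoo/foo_new standing in for libNeonPy's dBackend_new and the 64-bit handle type as an assumption:

    import ctypes

    lib = ctypes.CDLL("libfoo.so")                     # hypothetical library
    handle_t = ctypes.c_uint64                         # assumed handle representation

    foo_new = lib.foo_new                              # cache the bound symbol once
    foo_new.argtypes = [ctypes.POINTER(handle_t),      # out-param: receives the new handle
                        ctypes.c_int,                  # runtime enum value
                        ctypes.c_int,                  # number of devices
                        ctypes.POINTER(ctypes.c_int)]  # device-index array
    foo_new.restype = ctypes.c_int                     # 0 on success

    handle = handle_t(0)
    devs = (ctypes.c_int * 2)(0, 1)
    status = foo_new(ctypes.byref(handle), 1, 2, devs)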
@@ -82,17 +84,17 @@ def help_backend_new(self):
             raise Exception(f'DBackend: Invalid handle {self.backend_handle}')
 
         if self.n_dev > len(self.dev_idx_list):
-            dev_idx_list = list(range(self.n_dev))
+            self.dev_idx_list = list(range(self.n_dev))
         else:
             self.n_dev = len(self.dev_idx_list)
 
-        dev_idx_np = np.array(self.dev_idx_list, dtype=int)
-        dev_idx_ptr = dev_idx_np.ctypes.data_as(ctypes.POINTER(ctypes.c_int))
+        # Loading the device list into a contiguous array
+        dev_array = (ctypes.c_int * self.n_dev)(*self.dev_idx_list)
 
         res = self.py_neon.lib.dBackend_new(ctypes.pointer(self.backend_handle),
                                             self.runtime.value,
                                             self.n_dev,
-                                            dev_idx_ptr)
+                                            dev_array)
 
         print(f"NEON PYTHON self.backend_handle: {hex(self.backend_handle.value)}")
         if res != 0:
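Two separate fixes land in this hunk: the fallback device list is now stored on self.dev_idx_list rather than in a local that was immediately discarded, and the pointer handed to dBackend_new now comes from a ctypes-owned array instead of a temporary NumPy array. The NumPy version was likely doubly fragile: dtype=int is typically 64-bit on 64-bit platforms while the declared argtype reads 32-bit ints, and nothing guaranteed the temporary stayed alive. A self-contained sketch of the difference (illustrative, not repo code):

    import ctypes
    import numpy as np

    dev_idx_list = [0, 1]

    # Before: the buffer holds int64 values on most 64-bit platforms, so reading
    # it through an int32 pointer yields wrong entries.
    dev_idx_np = np.array(dev_idx_list, dtype=int)
    dev_idx_ptr = dev_idx_np.ctypes.data_as(ctypes.POINTER(ctypes.c_int))
    print(dev_idx_ptr[0], dev_idx_ptr[1])  # 0 0 on little-endian int64 data

    # After: element type and count match the declared argtypes, and ctypes
    # keeps the array alive while the foreign call runs.
    dev_array = (ctypes.c_int * len(dev_idx_list))(*dev_idx_list)
    print(dev_array[0], dev_array[1])      # 0 1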
@@ -102,6 +104,7 @@ def help_backend_new(self):
                                                   self.backend_handle)
         pass
 
+
     def help_backend_delete(self):
         if self.backend_handle == 0:
             return
@@ -112,21 +115,26 @@
         if res != 0:
             raise Exception('Failed to delete backend')
 
+
     def get_num_devices(self):
         return self.n_dev
 
+
     def get_warp_device_name(self):
         if self.runtime == Backend.Runtime.stream:
             return 'cuda'
         else:
             return 'cpu'
 
+
     def __str__(self):
         return ctypes.cast(self.py_neon.lib.get_string(self.backend_handle), ctypes.c_char_p).value.decode('utf-8')
 
+
     def sync(self):
         return self.py_neon.lib.dBackend_sync(self.backend_handle)
 
+
     def get_device_name(self, dev_idx: int):
         if self.runtime == Backend.Runtime.stream:
             dev_id = self.dev_idx_list[dev_idx]