Skip to content

Commit

Permalink
Replace use of cuda::device::current::scoped_override_t with plain CU…
Browse files Browse the repository at this point in the history
…DA (#407)

Remove cuda::device::current::scoped_override_t from CUDAScopedContext.
Replace cuda::device::current::scoped_override_t with cudautils::ScopedSetDevice.
  • Loading branch information
makortel authored and fwyzard committed Nov 8, 2019
1 parent 20aff22 commit 614ee0b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
7 changes: 6 additions & 1 deletion HeterogeneousCore/CUDACore/interface/CUDAScopedContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace impl {
const cudautils::SharedStreamPtr& streamPtr() const { return stream_; }

protected:
// The constructors set the current device, but the device
// is not set back to the previous value in the destructor. This
// should be sufficient (and a tiny bit faster) as all CUDA API
// functions relying on the current device should be called from
// the scope where this context is. The current device doesn't
// really matter between modules (or across TBB tasks).
explicit CUDAScopedContextBase(edm::StreamID streamID);

explicit CUDAScopedContextBase(const CUDAProductBase& data);
Expand All @@ -41,7 +47,6 @@ namespace impl {

private:
int currentDevice_;
cuda::device::current::scoped_override_t<> setDeviceForThisScope_;
cudautils::SharedStreamPtr stream_;
};

Expand Down
10 changes: 7 additions & 3 deletions HeterogeneousCore/CUDACore/src/CUDAScopedContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ namespace {

namespace impl {
CUDAScopedContextBase::CUDAScopedContextBase(edm::StreamID streamID)
: currentDevice_(cudacore::chooseCUDADevice(streamID)), setDeviceForThisScope_(currentDevice_) {
: currentDevice_(cudacore::chooseCUDADevice(streamID)) {
cudaCheck(cudaSetDevice(currentDevice_));
stream_ = cudautils::getCUDAStreamCache().getCUDAStream();
}

CUDAScopedContextBase::CUDAScopedContextBase(const CUDAProductBase& data)
: currentDevice_(data.device()), setDeviceForThisScope_(currentDevice_) {
: currentDevice_(data.device()) {
cudaCheck(cudaSetDevice(currentDevice_));
if (data.mayReuseStream()) {
stream_ = data.streamPtr();
} else {
Expand All @@ -53,7 +55,9 @@ namespace impl {
}

CUDAScopedContextBase::CUDAScopedContextBase(int device, cudautils::SharedStreamPtr stream)
: currentDevice_(device), setDeviceForThisScope_(device), stream_(std::move(stream)) {}
: currentDevice_(device), stream_(std::move(stream)) {
cudaCheck(cudaSetDevice(currentDevice_));
}

////////////////////

Expand Down
3 changes: 2 additions & 1 deletion HeterogeneousCore/CUDACore/test/test_CUDAScopedContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "HeterogeneousCore/CUDAUtilities/interface/exitSansCUDADevices.h"
#include "HeterogeneousCore/CUDAUtilities/interface/CUDAStreamCache.h"
#include "HeterogeneousCore/CUDAUtilities/interface/CUDAEventCache.h"
#include "HeterogeneousCore/CUDAUtilities/interface/ScopedSetDevice.h"

#include "test_CUDAScopedContextKernels.h"

Expand Down Expand Up @@ -84,7 +85,7 @@ TEST_CASE("Use of CUDAScopedContext", "[CUDACore]") {
}

SECTION("Joining multiple CUDA streams") {
cuda::device::current::scoped_override_t<> setDeviceForThisScope(defaultDevice);
cudautils::ScopedSetDevice setDeviceForThisScope(defaultDevice);
auto current_device = cuda::device::current::get();

// Mimic a producer on the first CUDA stream
Expand Down
5 changes: 3 additions & 2 deletions HeterogeneousCore/CUDAUtilities/src/allocate_device.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "HeterogeneousCore/CUDAUtilities/interface/allocate_device.h"
#include "HeterogeneousCore/CUDAUtilities/interface/cudaCheck.h"
#include "HeterogeneousCore/CUDAUtilities/interface/ScopedSetDevice.h"
#include "FWCore/Utilities/interface/Likely.h"

#include "getCachingDeviceAllocator.h"
Expand All @@ -23,7 +24,7 @@ namespace cudautils {
}
cuda::throw_if_error(cudautils::allocator::getCachingDeviceAllocator().DeviceAllocate(dev, &ptr, nbytes, stream));
} else {
cuda::device::current::scoped_override_t<> setDeviceForThisScope(dev);
ScopedSetDevice setDeviceForThisScope(dev);
cuda::throw_if_error(cudaMalloc(&ptr, nbytes));
}
return ptr;
Expand All @@ -33,7 +34,7 @@ namespace cudautils {
if constexpr (cudautils::allocator::useCaching) {
cuda::throw_if_error(cudautils::allocator::getCachingDeviceAllocator().DeviceFree(device, ptr));
} else {
cuda::device::current::scoped_override_t<> setDeviceForThisScope(device);
ScopedSetDevice setDeviceForThisScope(device);
cuda::throw_if_error(cudaFree(ptr));
}
}
Expand Down

0 comments on commit 614ee0b

Please sign in to comment.