// ScopedContext.cc (forked from cms-patatrack/pixeltrack-standalone)
#include "CUDACore/ScopedContext.h"
#include "CUDACore/StreamCache.h"
#include "CUDACore/cudaCheck.h"
#include "chooseDevice.h"
namespace {
  struct CallbackData {
    edm::WaitingTaskWithArenaHolder holder;
    int device;
  };
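
  // Called by the CUDA runtime once all work queued on the stream before the
  // callback has finished. On success it releases the waiting task; on failure
  // it forwards the CUDA error to the framework as an exception. The CallbackData
  // allocated in enqueueCallback() is freed here through the unique_ptr guard.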
  void CUDART_CB cudaScopedContextCallback(cudaStream_t streamId, cudaError_t status, void* data) {
    std::unique_ptr<CallbackData> guard{reinterpret_cast<CallbackData*>(data)};
    edm::WaitingTaskWithArenaHolder& waitingTaskHolder = guard->holder;
    int device = guard->device;
    if (status == cudaSuccess) {
      //std::cout << " GPU kernel finished (in callback) device " << device << " CUDA stream "
      //          << streamId << std::endl;
      waitingTaskHolder.doneWaiting(nullptr);
    } else {
      // wrap the exception in a try-catch block to let GDB "catch throw" break on it
      try {
        auto error = cudaGetErrorName(status);
        auto message = cudaGetErrorString(status);
        throw std::runtime_error("Callback of CUDA stream " +
                                 std::to_string(reinterpret_cast<unsigned long>(streamId)) + " in device " +
                                 std::to_string(device) + " error " + std::string(error) + ": " +
                                 std::string(message));
      } catch (std::exception&) {
        waitingTaskHolder.doneWaiting(std::current_exception());
      }
    }
  }
}  // namespace

namespace cms::cuda {
  namespace impl {
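    // Each constructor selects the CUDA device this context operates on and
    // obtains its CUDA stream: a fresh stream from the cache, the stream of an
    // existing data product when it may be reused, or an externally provided one.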
    ScopedContextBase::ScopedContextBase(edm::StreamID streamID) : currentDevice_(chooseDevice(streamID)) {
      cudaCheck(cudaSetDevice(currentDevice_));
      stream_ = getStreamCache().get();
    }

    ScopedContextBase::ScopedContextBase(const ProductBase& data) : currentDevice_(data.device()) {
      cudaCheck(cudaSetDevice(currentDevice_));
      if (data.mayReuseStream()) {
        stream_ = data.streamPtr();
      } else {
        stream_ = getStreamCache().get();
      }
    }

    ScopedContextBase::ScopedContextBase(int device, SharedStreamPtr stream)
        : currentDevice_(device), stream_(std::move(stream)) {
      cudaCheck(cudaSetDevice(currentDevice_));
    }

    ////////////////////
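
    // Make the stream of this context wait until a data product, produced on a
    // (possibly) different stream of the same device, has become available.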
    void ScopedContextGetterBase::synchronizeStreams(int dataDevice,
                                                     cudaStream_t dataStream,
                                                     bool available,
                                                     cudaEvent_t dataEvent) {
      if (dataDevice != device()) {
        // Eventually replace with a prefetch to the current device (assuming unified memory works).
        // If we won't go to unified memory, we need to figure out something else...
        throw std::runtime_error("Handling data from multiple devices is not yet supported");
      }

      if (dataStream != stream()) {
        // Different streams, need to synchronize
        if (not available) {
          // The event has not occurred yet, so synchronization is needed here.
          // Synchronization is done by making the CUDA stream wait for the
          // event, so all subsequent work in the stream will run only after
          // the event has "occurred" (i.e. the data product became available).
          cudaCheck(cudaStreamWaitEvent(stream(), dataEvent, 0), "Failed to make a stream wait for an event");
        }
      }
    }
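
    // Enqueue cudaScopedContextCallback on the stream; the heap-allocated
    // CallbackData is released by the callback itself (via its unique_ptr guard).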
    void ScopedContextHolderHelper::enqueueCallback(int device, cudaStream_t stream) {
      cudaCheck(
          cudaStreamAddCallback(stream, cudaScopedContextCallback, new CallbackData{waitingTaskHolder_, device}, 0));
    }
  }  // namespace impl

  ////////////////////
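
  // On destruction, notify the waiting framework task once all work queued on the
  // stream has finished, and (if a ContextState was given) hand the device and
  // stream over to it so a later context can reuse them.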
  ScopedContextAcquire::~ScopedContextAcquire() {
    holderHelper_.enqueueCallback(device(), stream());
    if (contextState_) {
      contextState_->set(device(), std::move(streamPtr()));
    }
  }

  void ScopedContextAcquire::throwNoState() {
    throw std::runtime_error(
        "Calling ScopedContextAcquire::insertNextTask() requires ScopedContextAcquire to be constructed with "
        "ContextState, but that was not the case");
  }

  ////////////////////
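
  // Record the event marking that the data produced in this context are ready;
  // consumers on other streams can wait on it via synchronizeStreams().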
  ScopedContextProduce::~ScopedContextProduce() {
    // Intentionally not checking the return value to avoid throwing
    // exceptions. If this call would fail, we should get failures
    // elsewhere as well.
    cudaEventRecord(event_.get(), stream());
  }

  ////////////////////
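
  // As for ScopedContextAcquire: notify the waiting task once the stream's queued work completes.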
  ScopedContextTask::~ScopedContextTask() { holderHelper_.enqueueCallback(device(), stream()); }
}  // namespace cms::cuda