merge 2nd order derivative of CutlassMLP #370
base: master
Changes from 6 commits: 8817d59, bb27a97, 01b11e9, d21a48c, bed60fc, 8c66716, c9163a7
@@ -38,6 +38,33 @@
 
 namespace tcnn {
 
+// element-wise convert float* to T*
+template <typename T>
+__global__ void element_wise_convert(uint32_t n_elements, float* in, T* out) {
+	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+	if (i >= n_elements) return;
+
+	out[i] = (T)in[i];
+}
+
+// element-wise convert T* to float and accumulate into *out
+template <typename T>
+__global__ void element_wise_convert_float(uint32_t n_elements, T* in, float* out) {
+	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+	if (i >= n_elements) return;
+
+	out[i] += (float)in[i];
+}
+
+// element-wise add
+template <typename T>
+__global__ void element_wise_add(uint32_t n_elements, T* in, T* out) {
+	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+	if (i >= n_elements) return;
+
+	out[i] += in[i];
+}
+
Comment on lines +41 to +67 (review): Use …
 template <typename T>
 class NetworkWithInputEncoding : public Network<float, T> {
 public:
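These three kernels follow the usual CUDA one-thread-per-element pattern; the PR launches them through tcnn's linear_kernel helper, as the later hunks show. For reference, here is a minimal, self-contained sketch (not part of the PR) that launches element_wise_convert by hand; the block size of 128 and the half-precision choice of T are illustrative assumptions.

#include <cuda_fp16.h>
#include <cstdint>

// Copied from the PR: element-wise convert float* to T*.
template <typename T>
__global__ void element_wise_convert(uint32_t n_elements, float* in, T* out) {
	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n_elements) return;

	out[i] = (T)in[i];
}

int main() {
	const uint32_t n = 1024;
	float* in;
	__half* out;
	cudaMalloc(&in, n * sizeof(float));
	cudaMalloc(&out, n * sizeof(__half));

	// One thread per element; round the grid size up so all n elements are covered.
	const uint32_t block_size = 128; // assumed; tcnn's linear_kernel picks its own configuration
	const uint32_t n_blocks = (n + block_size - 1) / block_size;
	element_wise_convert<__half><<<n_blocks, block_size>>>(n, in, out);
	cudaDeviceSynchronize();

	cudaFree(in);
	cudaFree(out);
	return 0;
}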
@@ -90,8 +117,8 @@ class NetworkWithInputEncoding : public Network<float, T> {
 		bool use_inference_params = false,
 		GradientMode param_gradients_mode = GradientMode::Overwrite
 	) override {
-		GPUMatrixDynamic<T> dL_dnetwork_input;
+		// dL_dnetwork_input becomes a member of the class instance
 		if (m_encoding->n_params() > 0 || dL_dinput) {
 			dL_dnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, m_encoding->preferred_output_layout()};
 		}
 
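This hunk is in backward_impl: the buffer that used to be a local is now the class member dL_dnetwork_input, because the double-backward added in the next hunk reads it as dL1_denc_output after the first-order backward has filled it. A hypothetical driver showing the implied call order; the top-level forward/backward/backward_backward_input wrappers are assumed to match tcnn's DifferentiableObject interface, so treat the exact signatures as assumptions.

#include <tiny-cuda-nn/network_with_input_encoding.h>

using namespace tcnn;

// Hypothetical driver: all matrices are assumed pre-allocated with the
// shapes used elsewhere in this file.
template <typename T>
void second_order_step(
	cudaStream_t stream,
	NetworkWithInputEncoding<T>& net,
	const GPUMatrixDynamic<float>& input,
	GPUMatrixDynamic<T>& output,
	const GPUMatrixDynamic<T>& dL_doutput,
	const GPUMatrixDynamic<float>& dL_ddLdinput,
	GPUMatrixDynamic<float>& dL_dinput,
	GPUMatrixDynamic<T>& dL_ddLdoutput,
	GPUMatrixDynamic<float>& dL_dinput_2nd
) {
	// Forward pass; keep the returned context so both backward passes can replay it.
	auto ctx = net.forward(stream, input, &output, /*use_inference_params=*/false, /*prepare_input_gradients=*/true);

	// First-order backward. With this PR it also fills the member
	// dL_dnetwork_input that backward_backward_input_impl consumes.
	net.backward(stream, *ctx, input, output, dL_doutput, &dL_dinput);

	// Second-order backward merged by this PR.
	net.backward_backward_input(stream, *ctx, input, dL_ddLdinput, dL_doutput, &dL_ddLdoutput, &dL_dinput_2nd);
}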
@@ -112,6 +139,89 @@ class NetworkWithInputEncoding : public Network<float, T> {
 		}
 	}
 
+	void backward_backward_input_impl(
+		cudaStream_t stream,
+		const Context& ctx,
+		const GPUMatrixDynamic<float>& input,
+		const GPUMatrixDynamic<float>& dL_ddLdinput,
+		const GPUMatrixDynamic<T>& dL_doutput,
+		GPUMatrixDynamic<T>* dL_ddLdoutput = nullptr,
+		GPUMatrixDynamic<float>* dL_dinput = nullptr,
+		bool use_inference_params = false,
+		GradientMode param_gradients_mode = GradientMode::Overwrite
+	) override {
+		const auto& forward = dynamic_cast<const ForwardContext&>(ctx);
+
+		// dL_ddLdinput of m_network->backward_backward_input equals dL_dLdencoding_output (same quantity, different name)
+		GPUMatrixDynamic<T> dL_dLdnetwork_input;
+
+		if (m_encoding->n_params() > 0) {
+			dL_dLdnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, dL_ddLdinput.layout()};
+			// zero dL_dLdnetwork_input so the accumulation below never reads uninitialized (potentially NaN) memory
+			CUDA_CHECK_THROW(cudaMemsetAsync(dL_dLdnetwork_input.data(), 0, dL_dLdnetwork_input.n() * dL_dLdnetwork_input.m() * sizeof(T), stream));
+
+			// encoding backward backward
+			m_encoding->backward_backward_input(
+				stream,
+				*forward.encoding_ctx,
+				input,
+				dL_ddLdinput,
+				dL_dnetwork_input, // dL1_denc_output
+				&dL_dLdnetwork_input, // dL2_ddL1_denc_output
+				dL_dinput,
+				use_inference_params,
+				param_gradients_mode
+			);
+		} else { // no encoding params: just convert dL_ddLdinput (float) to dL_dLdnetwork_input (T)
+			dL_dLdnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, dL_ddLdinput.layout()};
+			linear_kernel(element_wise_convert<T>, 0, stream, dL_dLdnetwork_input.n() * dL_dLdnetwork_input.m(), dL_ddLdinput.data(), dL_dLdnetwork_input.data());
+		}
+
+		// dL2_dinput of m_network->backward_backward_input
+		GPUMatrixDynamic<T> dL2_dnetwork_input;
+		if (m_encoding->n_params() > 0 || dL_dinput) {
+			dL2_dnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, m_encoding->preferred_output_layout()};
+		}
+
+		// network backward backward
+		m_network->backward_backward_input(
+			stream,
+			*forward.network_ctx,
+			forward.network_input, // enc_output, i.e. the network's input
+			dL_dLdnetwork_input, // dL2_ddL1dnetwork_input
+			dL_doutput,
+			dL_ddLdoutput,
+			dL2_dnetwork_input.data() ? &dL2_dnetwork_input : nullptr, // dL2_dinput of the network
+			use_inference_params,
+			param_gradients_mode
+		);
+
+		// pull dL2_dnetwork_input back to dL2_dinput with a first-order backward pass
+		GPUMatrixDynamic<float> dL2_dinput;
+		if (m_encoding->n_params() > 0 || dL2_dnetwork_input.data()) {
+			dL2_dinput = {m_encoding->input_width(), input.n(), stream, input.layout()};
+		}
+
+		if (m_encoding->n_params() > 0) {
+			// backward dL2_dnetwork_input to dL2_dinput
+			m_encoding->backward(
+				stream,
+				*forward.encoding_ctx,
+				input,
+				forward.network_input, // enc_output
+				dL2_dnetwork_input, // dL2_dencoding_output
+				&dL2_dinput,
+				use_inference_params,
+				GradientMode::Accumulate // dL2denc_w: accumulate this first-order term on top of the second-order one
+			);
+
+			if (dL_dinput) {
+				linear_kernel(element_wise_add<float>, 0, stream, dL_dinput->n() * dL_dinput->m(), dL2_dinput.data(), dL_dinput->data());
+			}
+		} else if (dL2_dnetwork_input.data()) {
+			linear_kernel(element_wise_convert_float<T>, 0, stream, dL_dinput->n() * dL_dinput->m(), dL2_dnetwork_input.data(), dL_dinput->data());
+		}
+	}
+
 	void set_params_impl(T* params, T* inference_params, T* gradients) override {
 		size_t offset = 0;
 		m_network->set_params(params + offset, inference_params + offset, gradients + offset);
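To see why the function above ends with an extra first-order m_encoding->backward, write the composition out (a sketch, ignoring padding and layouts). Let u = e(x) be the encoding output, y = f(u) the network output, g_u = \partial L_1 / \partial u, and g_x = \partial L_1 / \partial x = J_e(x)^\top g_u the first-order input gradient. Given the incoming v = \partial L_2 / \partial g_x (dL_ddLdinput):

	\frac{\partial L_2}{\partial g_u} = J_e(x)\, v

which is dL_dLdnetwork_input, produced by the encoding's backward_backward_input (or by a plain float-to-T conversion when the encoding has no parameters and J_e is effectively an identity/padding map). The network's backward_backward_input then turns this into \partial L_2 / \partial (\partial L_1 / \partial y) (dL_ddLdoutput) and a second-order term \partial L_2 / \partial u (dL2_dnetwork_input). The latter still depends on x through u = e(x), so it must be pulled back through the encoding once more:

	\frac{\partial L_2}{\partial x} \mathrel{+}= J_e(x)^\top \frac{\partial L_2}{\partial u}

which is exactly the trailing m_encoding->backward(..., GradientMode::Accumulate) followed by the element-wise add into dL_dinput. The terms where L_2 reaches x or the encoding parameters through J_e(x) itself are handled inside the encoding's own backward_backward_input.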
@@ -181,6 +291,7 @@ class NetworkWithInputEncoding : public Network<float, T> {
 private:
 	std::shared_ptr<Encoding<T>> m_encoding;
 	std::shared_ptr<Network<T>> m_network;
+	GPUMatrixDynamic<T> dL_dnetwork_input;

Comment on this line (review): GPU memory (or pointers to GPU memory) cannot be class members. They should be part of either the …
 
 	struct ForwardContext : public Context {
 		GPUMatrixDynamic<T> network_input;
Comment (review): This should be reverted to previous behavior.
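A minimal sketch of the direction the first review comment seems to point in (hypothetical; the truncated comment does not say which home the buffer should get): tie the scratch matrix to the per-call ForwardContext rather than to the class, so its lifetime matches the forward/backward pair that uses it and concurrent calls do not race on shared state. The existing members below mirror the struct shown in the diff; dL_dnetwork_input is the addition under discussion, and the struct name is changed only to mark it as a sketch.

#include <tiny-cuda-nn/gpu_matrix.h>
#include <tiny-cuda-nn/object.h>

#include <memory>

using namespace tcnn;

// Hypothetical restructuring, not part of the PR: the scratch buffer lives
// in the per-call context instead of on the class.
template <typename T>
struct ForwardContextSketch : public Context {
	GPUMatrixDynamic<T> network_input;
	GPUMatrixDynamic<T> dL_dnetwork_input; // filled by backward_impl, read by backward_backward_input_impl
	std::unique_ptr<Context> encoding_ctx;
	std::unique_ptr<Context> network_ctx;
};

One way to make this work with the const Context& passed to the backward entry points is to size the buffer in forward_impl, where the context is created and still mutable, so the later passes only write through the already-allocated device pointer.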