Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CWT operator #4860

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
937b963
add MotherWavelet helper and WaveletGpu kernel
May 18, 2023
cf7b6a6
Cwt WIP
mwdowski May 18, 2023
68bb330
Merge branch 'NVIDIA:main' into wavelet-computing
kubo11 May 18, 2023
9d6e0b0
Merge pull request #2 from mwdowski/wavelet-computing
mwdowski May 18, 2023
359d79c
Merge pull request #1 from mwdowski/mwdowski
mwdowski May 18, 2023
b034619
Rename namespace
mwdowski May 18, 2023
6bb49f5
Merge branch 'main' into mwdowski
mwdowski May 18, 2023
5eed0c5
add WaveletArgs class
May 22, 2023
09196c6
Merge pull request #3 from mwdowski/wavelet-computing
kubo11 May 29, 2023
279e61b
Improve wavelet computing kernel
Jun 5, 2023
c4814f9
Optimize and remove discrete wavelets
Jun 7, 2023
11df6aa
Merge pull request #4 from mwdowski/wavelet-computing-improvements
kubo11 Jun 7, 2023
d3a8d6a
add DALIWaveletName enum
Jun 11, 2023
27cedd3
fix linting errors
Jun 11, 2023
2875c95
replace MeyerWavelet with GaussianWavelet
Jun 13, 2023
20d5d7e
Merge pull request #5 from mwdowski/wavelet-computing-improvements
kubo11 Jun 13, 2023
0efec3d
Fix wavelet exceptions
Jul 3, 2023
1ed22bc
Add CWT operator docstr
Jul 4, 2023
3c36192
Merge pull request #6 from mwdowski/wavelet-fixes
kubo11 Jul 6, 2023
1cdc5e7
WIP
mwdowski Sep 8, 2023
e99099e
Merge branch 'NVIDIA:main' into main
mwdowski Sep 8, 2023
15ce332
Merge branch 'main' into mwdowski2
mwdowski Sep 8, 2023
101efc4
Good size but full of zeros
mwdowski Sep 12, 2023
276f87e
WIP
mwdowski Sep 12, 2023
1849a30
Merge pull request #7 from mwdowski/mwdowski2
mwdowski Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix linting errors
JakubO committed Jun 11, 2023
commit 27cedd3642fb28abfa8f4d8b58d394590a368829
18 changes: 12 additions & 6 deletions dali/kernels/signal/wavelet/mother_wavelet.cu
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
// limitations under the License.

#include <cmath>
#include <vector>
#include "dali/kernels/signal/wavelet/mother_wavelet.cuh"
#include "dali/core/math_util.h"

@@ -51,8 +52,10 @@ MeyerWavelet<T>::MeyerWavelet(const std::vector<T> &args) {
/// Evaluates the Meyer mother wavelet at `t` using its closed-form time-domain expression.
///
/// The argument is first shifted by 0.5 (tt = t - 0.5); the result is the sum of two
/// rational-trigonometric terms, psi1(tt) and psi2(tt).
///
/// @param t  point at which the wavelet is sampled
/// @return   psi1 + psi2 evaluated at tt = t - 0.5
///
/// NOTE(review): at t == 0.5 (tt == 0) both numerators and both denominators are zero, so the
/// expression evaluates to 0/0. No special-casing is done here - confirm whether callers can
/// sample exactly that point.
template <typename T>
__device__ T MeyerWavelet<T>::operator()(const T &t) const {
  const T tt = t - 0.5;
  const T psi1 =
      (4.0 / (3.0 * M_PI) * tt * std::cos((2.0 * M_PI) / 3.0 * tt) -
       1.0 / M_PI * std::sin((4.0 * M_PI) / 3.0 * tt)) /
      (tt - 16.0 / 9.0 * std::pow(tt, 3.0));
  // The sine term takes (4*pi/3) * tt as its argument, exactly as in psi1. Previously tt was
  // placed outside the call (sin(4*pi/3) * tt), which turned the term into a linear ramp.
  const T psi2 =
      (8.0 / (3.0 * M_PI) * tt * std::cos(8.0 * M_PI / 3.0 * tt) +
       1.0 / M_PI * std::sin((4.0 * M_PI) / 3.0 * tt)) /
      (tt - 64.0 / 9.0 * std::pow(tt, 3.0));
  return psi1 + psi2;
}

@@ -69,7 +72,8 @@ MexicanHatWavelet<T>::MexicanHatWavelet(const std::vector<T> &args) {

/// Samples the Mexican hat wavelet at `t`:
///   2 / (sqrt(3 * sigma) * pi^(1/4)) * (1 - (t / sigma)^2) * exp(-t^2 / (2 * sigma^2))
///
/// @param t  point at which the wavelet is sampled
/// @return   the product of the normalization, polynomial and Gaussian factors
template <typename T>
__device__ T MexicanHatWavelet<T>::operator()(const T &t) const {
  // `auto` keeps every factor in the arithmetic type the single-expression form produced;
  // the factors are multiplied in the same left-to-right order.
  const auto normalization = 2.0 / (std::sqrt(3.0 * sigma) * std::pow(M_PI, 0.25));
  const auto polynomial = 1.0 - std::pow(t / sigma, 2.0);
  const auto gaussian = std::exp(-std::pow(t, 2.0) / (2.0 * std::pow(sigma, 2.0)));
  return normalization * polynomial * gaussian;
}

template class MexicanHatWavelet<float>;
@@ -94,7 +98,8 @@ template class MorletWavelet<double>;
template <typename T>
ShannonWavelet<T>::ShannonWavelet(const std::vector<T> &args) {
if (args.size() != 2) {
throw new std::invalid_argument("ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order.");
throw new std::invalid_argument(
"ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order.");
}
this->fb = args[0];
this->fc = args[1];
@@ -112,7 +117,8 @@ template class ShannonWavelet<double>;
template <typename T>
FbspWavelet<T>::FbspWavelet(const std::vector<T> &args) {
if (args.size() != 3) {
throw new std::invalid_argument("FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
throw new std::invalid_argument(
"FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
}
this->m = args[0];
this->fb = args[1];
@@ -129,5 +135,5 @@ template class FbspWavelet<float>;
template class FbspWavelet<double>;

} // namespace signal
} // namespace kernel
} // namespace kernels
} // namespace dali
18 changes: 9 additions & 9 deletions dali/kernels/signal/wavelet/mother_wavelet.cuh
Original file line number Diff line number Diff line change
@@ -15,14 +15,14 @@
#ifndef DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_
#define DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_

#include <vector>

#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/core/util.h"
#include "dali/kernels/kernel.h"

#include <vector>

namespace dali {
namespace kernels {
namespace signal {
@@ -37,7 +37,7 @@ class HaarWavelet {
"Data type should be floating point");
public:
HaarWavelet() = default;
HaarWavelet(const std::vector<T> &args);
explicit HaarWavelet(const std::vector<T> &args);
~HaarWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -49,7 +49,7 @@ class MeyerWavelet {
"Data type should be floating point");
public:
MeyerWavelet() = default;
MeyerWavelet(const std::vector<T> &args);
explicit MeyerWavelet(const std::vector<T> &args);
~MeyerWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -61,7 +61,7 @@ class MexicanHatWavelet {
"Data type should be floating point");
public:
MexicanHatWavelet() = default;
MexicanHatWavelet(const std::vector<T> &args);
explicit MexicanHatWavelet(const std::vector<T> &args);
~MexicanHatWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -76,7 +76,7 @@ class MorletWavelet {
"Data type should be floating point");
public:
MorletWavelet() = default;
MorletWavelet(const std::vector<T> &args);
explicit MorletWavelet(const std::vector<T> &args);
~MorletWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -91,7 +91,7 @@ class ShannonWavelet {
"Data type should be floating point");
public:
ShannonWavelet() = default;
ShannonWavelet(const std::vector<T> &args);
explicit ShannonWavelet(const std::vector<T> &args);
~ShannonWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -107,7 +107,7 @@ class FbspWavelet {
"Data type should be floating point");
public:
FbspWavelet() = default;
FbspWavelet(const std::vector<T> &args);
explicit FbspWavelet(const std::vector<T> &args);
~FbspWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -119,7 +119,7 @@ class FbspWavelet {
};

} // namespace signal
} // namespace kernel
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_
18 changes: 10 additions & 8 deletions dali/kernels/signal/wavelet/wavelet_gpu.cu
Original file line number Diff line number Diff line change
@@ -42,8 +42,7 @@ __global__ void ComputeWavelet(const SampleDesc<T>* sample_data, W<T> wavelet) {
auto x = std::pow(2.0, a);
if (a == 0.0) {
shm[b_id] = sample.in[t_id];
}
else {
} else {
shm[b_id] = x * sample.in[t_id];
shm[1024] = std::pow(2.0, a / 2.0);
}
@@ -53,8 +52,7 @@ __global__ void ComputeWavelet(const SampleDesc<T>* sample_data, W<T> wavelet) {
auto b = sample.b[i];
if (b == 0.0) {
sample.out[out_id] = wavelet(shm[b_id]);
}
else {
} else {
sample.out[out_id] = wavelet(shm[b_id] - b);
}
if (a != 0.0) {
@@ -66,7 +64,8 @@ __global__ void ComputeWavelet(const SampleDesc<T>* sample_data, W<T> wavelet) {
// translate input range information to input samples
template <typename T>
__global__ void ComputeInputSamples(const SampleDesc<T>* sample_data) {
const int64_t t_id = blockDim.x * blockDim.y * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x;
const int64_t block_size = blockDim.x * blockDim.y;
const int64_t t_id = block_size * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x;
auto& sample = sample_data[blockIdx.y];
if (t_id >= sample.size_in) return;
sample.in[t_id] = sample.span.begin + (T)t_id / sample.span.sampling_rate;
@@ -107,7 +106,8 @@ DLL_PUBLIC void WaveletGpu<T, W>::Run(KernelContext &ctx,
sample.b = b.tensor_data(i);
sample.size_b = b.shape.tensor_size(i);
sample.span = span;
sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1;
sample.size_in =
std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1;
sample.in = ctx.scratchpad->AllocateGPU<T>(sample.size_in);
max_size_in = std::max(max_size_in, sample.size_in);
}
@@ -133,9 +133,11 @@ TensorListShape<> WaveletGpu<T, W>::GetOutputShape(const TensorListShape<> &a_sh
TensorListShape<> out_shape(N, 3);
TensorShape<> tshape;
for (int i = 0; i < N; i++) {
// output tensor will be 3-dimensional of shape:
// output tensor will be 3-dimensional of shape:
// a coeffs x b coeffs x signal samples
tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), b_shape.tensor_shape(i).num_elements(), in_size});
tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(),
b_shape.tensor_shape(i).num_elements(),
in_size});
out_shape.set_tensor_shape(i, tshape);
}
return out_shape;
25 changes: 15 additions & 10 deletions dali/kernels/signal/wavelet/wavelet_gpu.cuh
Original file line number Diff line number Diff line change
@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_
#define DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_
#ifndef DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_
#define DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_

#include <memory>
#include <string>
#include <vector>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
@@ -26,13 +27,16 @@

// Validates the shapes of the `a` and `b` coefficient tensor lists: both must hold the same
// number of samples and every sample in both must be one-dimensional. Each condition is
// checked with DALI_ENFORCE.
//
// The definition intentionally does NOT end with a semicolon: the do { ... } while (0) wrapper
// makes the macro behave as a single statement, so the caller supplies the `;` and the macro
// can be used safely as the body of an unbraced if/else. The arguments are parenthesized so
// that arbitrary expressions can be passed.
#define ENFORCE_SHAPES(a_shape, b_shape) \
  do { \
    DALI_ENFORCE((a_shape).num_samples() == (b_shape).num_samples(), \
                 "a and b tensors must have the same amount of samples."); \
    for (int i = 0; i < (a_shape).num_samples(); ++i) { \
      DALI_ENFORCE((a_shape).tensor_shape(i).size() == 1, \
                   "Tensor of a coeffs should be 1-dimensional."); \
      DALI_ENFORCE((b_shape).tensor_shape(i).size() == 1, \
                   "Tensor of b coeffs should be 1-dimensional."); \
    } \
  } while (0)

namespace dali {
namespace kernels {
@@ -90,6 +94,7 @@ class DLL_PUBLIC WaveletGpu {
static TensorListShape<> GetOutputShape(const TensorListShape<> &a_shape,
const TensorListShape<> &b_shape,
const WaveletSpan<T> &span);

private:
W<T> wavelet_;
};
@@ -98,4 +103,4 @@ class DLL_PUBLIC WaveletGpu {
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_
#endif // DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_
4 changes: 2 additions & 2 deletions dali/operators/signal/wavelet/wavelet_name.h
Original file line number Diff line number Diff line change
@@ -29,6 +29,6 @@ enum DALIWaveletName {
DALI_FBSP = 5
};

} // namespace dali
} // namespace dali

#endif // DALI_OPERATORS_SIGNAL_WAVELET_NAME_H_
#endif // DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_NAME_H_
22 changes: 14 additions & 8 deletions dali/operators/signal/wavelet/wavelet_run.h
Original file line number Diff line number Diff line change
@@ -57,28 +57,34 @@ void RunForName(const DALIWaveletName &name,
const std::vector<T> &args) {
switch (name) {
case DALIWaveletName::DALI_HAAR:
RunWaveletKernel<T, kernels::signal::HaarWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
using kernels::signal::HaarWavelet;
RunWaveletKernel<T, HaarWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
case DALIWaveletName::DALI_MEY:
RunWaveletKernel<T, kernels::signal::MeyerWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
using kernels::signal::MeyerWavelet;
RunWaveletKernel<T, MeyerWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
case DALIWaveletName::DALI_MEXH:
RunWaveletKernel<T, kernels::signal::MexicanHatWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
using kernels::signal::MexicanHatWavelet;
RunWaveletKernel<T, MexicanHatWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
case DALIWaveletName::DALI_MORL:
RunWaveletKernel<T, kernels::signal::MorletWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
using kernels::signal::MorletWavelet;
RunWaveletKernel<T, MorletWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
case DALIWaveletName::DALI_SHAN:
RunWaveletKernel<T, kernels::signal::ShannonWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
using kernels::signal::ShannonWavelet;
RunWaveletKernel<T, ShannonWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
case DALIWaveletName::DALI_FBSP:
RunWaveletKernel<T, kernels::signal::FbspWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
using kernels::signal::FbspWavelet;
RunWaveletKernel<T, FbspWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
default:
throw new std::invalid_argument("Unknown wavelet name.");
}
}

} // namespace dali
} // namespace dali

#endif // DALI_OPERATORS_SIGNAL_WAVELET_RUN_H_
#endif // DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_RUN_H_
6 changes: 4 additions & 2 deletions dali/python/nvidia/dali/types.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,8 @@
from enum import Enum, unique
import re

from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, DALIInterpType, DALIWaveletName
from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, \
DALIInterpType, DALIWaveletName

# TODO: Handle forwarding imports from backend_impl
from nvidia.dali.backend_impl.types import * # noqa: F401, F403
@@ -63,7 +64,8 @@ def _not_implemented(val):
DALIDataType.DATA_TYPE: ("nvidia.dali.types.DALIDataType", lambda x: DALIDataType(int(x))),
DALIDataType.INTERP_TYPE:
("nvidia.dali.types.DALIInterpType", lambda x: DALIInterpType(int(x))),
DALIDataType.WAVELET_NAME: ("nvidia.dali.types.DALIWaveletName", lambda x: DALIWaveletName(int(x))),
DALIDataType.WAVELET_NAME:
("nvidia.dali.types.DALIWaveletName", lambda x: DALIWaveletName(int(x))),
DALIDataType.TENSOR_LAYOUT: (":ref:`layout str<layout_str_doc>`", lambda x: str(x)),
DALIDataType.PYTHON_OBJECT: ("object", lambda x: x),
DALIDataType._TENSOR_LAYOUT_VEC: