Skip to content

Commit

Permalink
NdArray ifft, rfft, irfft, conj implementation for PyTorch engine.
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Semenov committed Jul 13, 2024
1 parent bcf8fb3 commit 90305cb
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 0 deletions.
71 changes: 71 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -3375,6 +3375,69 @@ default NDArray fft(long length) {
*/
NDArray fft(long length, long axis);

/**
* Computes the one dimensional inverse discrete Fourier transform.
*
* @param length Length of the transformed axis of the output.
* @param axis Axis over which to compute the IFFT.
* @return The truncated or zero-padded input, transformed along the axis indicated by axis, or
* the last one if axis is not specified.
*/
NDArray ifft(long length, long axis);

/**
* Computes the one dimensional inverse discrete Fourier transform.
*
* @param length Length of the transformed axis of the output.
* @return The truncated or zero-padded input, transformed along the axis indicated by axis, or
* the last one if axis is not specified.
*/
default NDArray ifft(long length) {
return ifft(length, -1);
}

/**
* Computes the one dimensional Fourier transform of real-valued input.
*
* @param length Length of the transformed axis of the output.
* @return The truncated or zero-padded input, transformed along the axis indicated by axis, or
* the last one if axis is not specified.
*/
default NDArray rfft(long length) {
return rfft(length, -1);
}

/**
* Computes the one dimensional Fourier transform of real-valued input.
*
* @param length Length of the transformed axis of the output.
* @param axis Axis over which to compute the FFT.
* @return The truncated or transformed along the axis indicated by axis, or the last one if
* axis is not specified.
*/
NDArray rfft(long length, long axis);

/**
* Computes the one dimensional inverse Fourier transform of real-valued input.
*
* @param length Length of the transformed axis of the output.
* @param axis Axis over which to compute the IRFFT.
* @return The truncated or transformed along the axis indicated by axis, or the last one if
* axis is not specified.
*/
NDArray irfft(long length, long axis);

/**
* Computes the one dimensional inverse Fourier transform of real-valued input.
*
* @param length Length of the transformed axis of the output.
* @return The truncated or transformed along the axis indicated by axis, or the last one if
* axis is not specified.
*/
default NDArray irfft(long length) {
return irfft(length, -1);
}

/**
* Computes the Short Time Fourier Transform (STFT).
*
Expand Down Expand Up @@ -5404,4 +5467,12 @@ default NDArray oneHot(int depth, DataType dataType) {
* @return tje real NDArray
*/
NDArray real();

/**
* Conjugate complex array.
*
* @return Returns a view of input with a flipped conjugate bit. If input has a non-complex
* type, this function just returns input.
*/
NDArray conj();
}
24 changes: 24 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,24 @@ public NDArray fft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray rfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray irfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray stft(
Expand Down Expand Up @@ -1254,6 +1272,12 @@ public NDArray real() {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray conj() {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArrayEx getNDArrayInternal() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,24 @@ public NDArray fft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft(long length, long axis) {
return null;
}

/** {@inheritDoc} */
@Override
public NDArray irfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray rfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray stft(
Expand Down Expand Up @@ -1679,6 +1697,12 @@ public NDArray real() {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray conj() {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArrayEx getNDArrayInternal() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,24 @@ public NDArray fft(long length, long axis) {
return JniUtils.fft(this, length, axis);
}

/** {@inheritDoc} */
@Override
public NDArray rfft(long length, long axis) {
return JniUtils.rfft(this, length, axis);
}

/** {@inheritDoc} */
@Override
public NDArray ifft(long length, long axis) {
return JniUtils.ifft(this, length, axis);
}

/** {@inheritDoc} */
@Override
public NDArray irfft(long length, long axis) {
return JniUtils.irfft(this, length, axis);
}

/** {@inheritDoc} */
@Override
public NDArray stft(
Expand Down Expand Up @@ -1628,6 +1646,12 @@ public NDArray real() {
return JniUtils.real(this);
}

/** {@inheritDoc} */
@Override
public NDArray conj() {
return JniUtils.conj(this);
}

/** {@inheritDoc} */
@Override
public PtNDArrayEx getNDArrayInternal() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,24 @@ public static PtNDArray fft(PtNDArray ndArray, long length, long axis) {
PyTorchLibrary.LIB.torchFft(ndArray.getHandle(), length, axis));
}

public static PtNDArray ifft(PtNDArray ndArray, long length, long axis) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIfft(ndArray.getHandle(), length, axis));
}

public static PtNDArray rfft(PtNDArray ndArray, long length, long axis) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchRfft(ndArray.getHandle(), length, axis));
}

public static PtNDArray irfft(PtNDArray ndArray, long length, long axis) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIrfft(ndArray.getHandle(), length, axis));
}

public static PtNDArray stft(
PtNDArray ndArray,
long nFft,
Expand Down Expand Up @@ -1068,6 +1086,10 @@ public static PtNDArray complex(PtNDArray ndArray) {
return new PtNDArray(ndArray.getManager(), handle);
}

public static PtNDArray conj(PtNDArray ndArray) {
return new PtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.conj(ndArray.getHandle()));
}

public static PtNDArray abs(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ native void torchIndexPut(

native long torchFft(long handle, long length, long axis);

native long torchIfft(long handle, long length, long axis);

native long torchRfft(long handle, long length, long axis);

native long torchIrfft(long handle, long length, long axis);

native long torchStft(
long handle,
long nFft,
Expand All @@ -287,6 +293,8 @@ native long torchStft(

native long torchViewAsComplex(long handle);

native long conj(long handle);

native long[] torchSplit(long handle, long size, long dim);

native long[] torchSplit(long handle, long[] indices, long dim);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,33 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIfft(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jn, jlong jaxis) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(torch::fft_ifft(*tensor_ptr, jn, jaxis));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchRfft(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jn, jlong jaxis) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(torch::fft_rfft(*tensor_ptr, jn, jaxis));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIrfft(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jn, jlong jaxis) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(torch::fft_irfft(*tensor_ptr, jn, jaxis));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft2(
JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) {
API_BEGIN()
Expand Down Expand Up @@ -90,6 +117,15 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchViewAsReal(
#endif
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_conj(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(torch::conj(*tensor_ptr));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchViewAsComplex(
JNIEnv* env, jobject jthis, jlong jhandle) {
#ifdef V1_11_X
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,12 @@ public NDArray real() {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray conj() {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray add(Number n) {
Expand Down Expand Up @@ -1173,6 +1179,24 @@ public NDArray fft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray rfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray ifft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray irfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray stft(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,24 @@ public NDArray fft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray rfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray ifft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray irfft(long length, long axis) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDArray stft(
Expand Down Expand Up @@ -1589,6 +1607,11 @@ public NDArray real() {
return toArray(RustLibrary.real(getHandle()));
}

@Override
public NDArray conj() {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public RsNDArrayEx getNDArrayInternal() {
Expand Down
Loading

0 comments on commit 90305cb

Please sign in to comment.