New pytorch presets #46

Merged · 2 commits · Jul 27, 2023
9 changes: 4 additions & 5 deletions build.sbt
@@ -29,7 +29,8 @@ val cudaVersion = "12.1-8.9"
val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.3.0"
ThisBuild / javaCppVersion := "1.5.9"
ThisBuild / javaCppVersion := "1.5.10-SNAPSHOT"
ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots")

ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11"))

@@ -40,8 +41,7 @@ ThisBuild / enableGPU := false
lazy val commonSettings = Seq(
Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"),
javaCppVersion := (ThisBuild / javaCppVersion).value,
javaCppPlatform := Seq(),
resolvers ++= Resolver.sonatypeOssRepos("snapshots")
javaCppPlatform := Seq()
// This is a hack to avoid depending on the native libs when publishing
// but conveniently have them on the classpath during development.
// There's probably a cleaner way to do this.
@@ -75,8 +75,7 @@ lazy val core = project
(if (enableGPU.value) "pytorch-gpu" else "pytorch") -> pytorchVersion,
"mkl" -> mklVersion,
"openblas" -> openblasVersion
// TODO remove cuda (not cuda-redist) once https://github.com/bytedeco/javacpp-presets/issues/1376 is fixed
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion, "cuda" -> cudaVersion) else Seq()),
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion) else Seq()),
javaCppPlatform := org.bytedeco.sbt.javacpp.Platform.current,
fork := true,
Test / fork := true,
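For reference, here is a minimal sketch of the sbt settings a build needs in order to consume the snapshot presets this PR switches to. The pytorch preset version below is illustrative (it is not visible in this diff), and `javaCppPresetLibs` is the sbt-javacpp key that the `Seq` of name/version pairs above feeds into:

```scala
// build.sbt sketch, assuming the sbt-javacpp plugin is enabled.
ThisBuild / javaCppVersion := "1.5.10-SNAPSHOT"
// Snapshot presets are only published to the Sonatype snapshots repository:
ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots")

// CPU-only variant; with enableGPU the diff swaps in "pytorch-gpu" and adds "cuda-redist".
// The pytorch version string here is illustrative, not taken from this PR.
javaCppPresetLibs ++= Seq(
  "pytorch"  -> "2.0.1",
  "mkl"      -> "2023.1",
  "openblas" -> "0.3.23"
)
```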
6 changes: 2 additions & 4 deletions core/src/main/scala/torch/Tensor.scala
@@ -53,7 +53,6 @@ import spire.math.{Complex, UByte}
import scala.reflect.Typeable
import internal.NativeConverters
import internal.NativeConverters.toArray
import internal.LoadCusolver
import Device.CPU
import Layout.Strided
import org.bytedeco.pytorch.ByteArrayRef
@@ -342,7 +341,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
import ScalarType.*
val out = native.dtype().toScalarType().intern() match
case Byte => UByte(native.item_int())
case Char => native.item_byte()
case Char => native.item_char()
case Short => native.item_short()
case Int => native.item_int()
case Long => native.item_long()
@@ -357,7 +356,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
val b = native.contiguous.createBuffer[DoubleBuffer]
Complex(b.get(), b.get())
case Bool => native.item().toBool
case QInt8 => native.item_byte()
case QInt8 => native.item_char()
case QUInt8 => native.item_short()
case QInt32 => native.item_int()
case BFloat16 => native.item().toBFloat16.asFloat()
@@ -802,7 +801,6 @@ type IntTensor = UInt8Tensor | Int8Tensor | Int16Tensor | Int32Tensor | Int64Ten
type ComplexTensor = Complex32Tensor | Complex64Tensor | Complex128Tensor

object Tensor:
LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

def apply[D <: DType](native: pytorch.Tensor): Tensor[D] = (native.scalar_type().intern() match
case ScalarType.Byte => new UInt8Tensor(native)
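The `item` mapping above now routes `Char` and `QInt8` through `item_char()`, matching the accessor naming in the newer presets. A rough usage sketch follows; the `Tensor`-from-`Seq` factory and element indexing are assumed from the public storch API rather than shown in this diff:

```scala
import torch.*

// Hypothetical round-trip of an int8 scalar through Tensor#item,
// which now calls native.item_char() instead of native.item_byte().
val t = Tensor(Seq[Byte](42, -7))
val first: Byte = t(0).item
```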
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/cuda/package.scala
@@ -17,13 +17,11 @@
package torch

import org.bytedeco.pytorch.global.torch as torchNative
import torch.internal.LoadCusolver

/** This package adds support for CUDA tensor types, which implement the same functions as CPU
 * tensors but use GPUs for computation.
*/
package object cuda {
LoadCusolver

/** Returns a Boolean indicating if CUDA is currently available. */
def isAvailable: Boolean = torchNative.cuda_is_available()
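With the `LoadCusolver` hack gone, checking for CUDA at runtime is just a call into the native bindings. A small usage sketch; `Device.CUDA` is assumed to exist alongside the `Device.CPU` imported elsewhere in this diff:

```scala
import torch.*

// Choose a device based on runtime CUDA availability.
val device = if torch.cuda.isAvailable then Device.CUDA else Device.CPU
println(s"Running on $device")
```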
31 changes: 0 additions & 31 deletions core/src/main/scala/torch/internal/LoadCusolver.scala

This file was deleted.

3 changes: 0 additions & 3 deletions core/src/main/scala/torch/internal/NativeConverters.scala
@@ -35,12 +35,9 @@ import org.bytedeco.pytorch.GenericDictIterator
import spire.math.Complex
import spire.math.UByte
import scala.annotation.targetName
import internal.LoadCusolver

private[torch] object NativeConverters:

LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

inline def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
case i: Option[T] => i.map(f(_)).orNull
case i: T => f(i)
5 changes: 1 addition & 4 deletions core/src/main/scala/torch/nn/functional/package.scala
@@ -18,7 +18,6 @@ package torch
package nn

import functional.*
import torch.internal.LoadCusolver

/** @groupname nn_conv Convolution functions
* @groupname nn_pooling Pooling functions
@@ -38,6 +37,4 @@ package object functional
with Linear
with Loss
with Pooling
with Sparse {
LoadCusolver
}
with Sparse
6 changes: 2 additions & 4 deletions core/src/main/scala/torch/nn/init.scala
@@ -21,7 +21,7 @@ import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.FanModeType
import org.bytedeco.pytorch.kFanIn
import org.bytedeco.pytorch.kFanOut
import org.bytedeco.pytorch.NonlinearityType
import org.bytedeco.pytorch.Nonlinearity as NonlinearityNative
import org.bytedeco.pytorch.kLinear
import org.bytedeco.pytorch.kConv1D
import org.bytedeco.pytorch.kConv2D
@@ -33,11 +33,9 @@ import org.bytedeco.pytorch.kSigmoid
import org.bytedeco.pytorch.kReLU
import org.bytedeco.pytorch.kLeakyReLU
import org.bytedeco.pytorch.Scalar
import torch.internal.LoadCusolver

// TODO implement remaining init functions
object init:
LoadCusolver
def kaimingNormal_(
t: Tensor[?],
a: Double = 0,
@@ -56,7 +54,7 @@ object init:
enum NonLinearity:
case Linear, Conv1D, Conv2D, Conv3D, ConvTranspose1D, ConvTranspose2D, ConvTranspose3D, Sigmoid,
ReLU, LeakyReLU
private[torch] def toNative: NonlinearityType = NonlinearityType(this match
private[torch] def toNative: NonlinearityNative = NonlinearityNative(this match
case NonLinearity.Linear => kLinear()
case NonLinearity.Conv1D => kConv1D()
case NonLinearity.Conv2D => kConv2D()
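Only the native alias changes here (`Nonlinearity` replaces `NonlinearityType`), so Scala call sites stay the same. A sketch of such a call site; parameter names beyond `t` and `a` are assumed, since the signature is truncated above:

```scala
import torch.*

// Kaiming-normal initialization of a freshly created weight tensor.
// The NonLinearity argument (when given) is converted to the native Nonlinearity type via toNative.
val weight = torch.randn(Seq(64, 32))
nn.init.kaimingNormal_(weight)
```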
8 changes: 4 additions & 4 deletions core/src/main/scala/torch/nn/modules/Module.scala
@@ -72,7 +72,7 @@ abstract class Module {

def copy(): this.type =
val clone = super.clone().asInstanceOf[Module]
clone._nativeModule = _nativeModule.clone()
clone._nativeModule = _nativeModule.clone(null)
clone.asInstanceOf[this.type]

protected[torch] def registerWithParent[T <: pytorch.Module](parent: T)(using
Expand Down Expand Up @@ -100,15 +100,15 @@ abstract class Module {

def to(device: Device): this.type =
// val nativeCopy = nativeModule.clone()
nativeModule.asModule.to(device.toNative)
nativeModule.asModule.to(device.toNative, false)
// copy
// val clone: this.type = copy()
// clone.nativeModule = nativeCopy
this

def to(dtype: DType, nonBlocking: Boolean = false): this.type =
val nativeCopy = nativeModule.clone()
nativeCopy.asModule.to(dtype.toScalarType)
val nativeCopy = nativeModule.clone(null)
nativeCopy.asModule.to(dtype.toScalarType, false)
this

def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive)
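`Module#copy` and `Module#to` now pass the explicit arguments (`clone(null)`, `to(..., false)`) required by the updated native overloads; the Scala-facing API is unchanged. A usage sketch, assuming `Linear` is exported from `torch.nn` as in the storch docs:

```scala
import torch.*

// Moving a module's parameters to another dtype or device still goes through
// nativeModule.asModule.to(..., false) under the hood.
val linear = nn.Linear(10, 5)
linear.to(torch.float64)
linear.to(Device.CPU)
```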
@@ -119,7 +119,7 @@ final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default](
options.track_running_stats().put(trackRunningStats)

override private[torch] val nativeModule: BatchNorm2dImpl = BatchNorm2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
@@ -59,7 +59,7 @@ final class Conv2d[ParamType <: FloatNN | ComplexNN: Default](
options.padding_mode().put(paddingModeNative)

override private[torch] val nativeModule: Conv2dImpl = Conv2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/nn/modules/linear/Linear.scala
@@ -28,7 +28,7 @@ import torch.nn.modules.{HasParams, TensorModule}
*
* This module supports `TensorFloat32<tf32_on_ampere>`.
*
* * Example:
* Example:
*
* ```scala sc:nocompile
* import torch.*
@@ -57,7 +57,7 @@ final class Linear[ParamType <: FloatNN: Default](
private val options = new LinearOptions(inFeatures, outFeatures)
options.bias().put(bias)
override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[T <: pytorch.Module](parent: T)(using
name: sourcecode.Name
@@ -44,7 +44,7 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
options.ceil_mode().put(ceilMode)

override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
3 changes: 0 additions & 3 deletions core/src/main/scala/torch/nn/package.scala
@@ -16,16 +16,13 @@

package torch

import torch.internal.LoadCusolver

/** These are the basic building blocks for graphs.
*
* @groupname nn_conv Convolution Layers
* @groupname nn_linear Linear Layers
* @groupname nn_utilities Utilities
*/
package object nn {
LoadCusolver

export modules.Module
export modules.Default
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/nn/utils.scala
@@ -19,10 +19,8 @@ package nn

import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.TensorVector
import torch.internal.LoadCusolver

object utils:
LoadCusolver
def clipGradNorm_(
parameters: Seq[Tensor[?]],
max_norm: Double,
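A sketch of how `clipGradNorm_` is called; only `parameters` and `max_norm` are visible in the truncated signature above, so the remaining parameters are assumed to have defaults:

```scala
import torch.*

// Clip the combined norm of a list of gradient tensors in place.
val grads: Seq[Tensor[?]] = Seq(torch.randn(Seq(3, 3)), torch.randn(Seq(3)))
nn.utils.clipGradNorm_(grads, max_norm = 1.0)
```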
24 changes: 13 additions & 11 deletions core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
@@ -20,14 +20,18 @@ package ops
import internal.NativeConverters.*

import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.{TensorArrayRef, TensorVector}
import org.bytedeco.pytorch.TensorArrayRef
import org.bytedeco.pytorch.TensorVector

/** Indexing, Slicing, Joining, Mutating Ops
*
* https://pytorch.org/docs/stable/torch.html#indexing-slicing-joining-mutating-ops
*/
private[torch] trait IndexingSlicingJoiningOps {

private def toArrayRef(tensors: Seq[Tensor[?]]): TensorArrayRef =
new TensorArrayRef(new TensorVector(tensors.map(_.native)*))

/** Returns a view of the tensor conjugated and with the last two dimensions transposed.
*
* `x.adjoint()` is equivalent to `x.transpose(-2, -1).conj()` for complex tensors and to
@@ -120,7 +124,7 @@ private[torch] trait IndexingSlicingJoiningOps {
* @group indexing_slicing_joining_mutating_ops
*/
def cat[D <: DType](tensors: Seq[Tensor[D]], dim: Int = 0): Tensor[D] = Tensor(
torchNative.cat(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)), dim.toLong)
torchNative.cat(toArrayRef(tensors), dim.toLong)
)

/** Returns a view of `input` with a flipped conjugate bit. If `input` has a non-complex dtype,
@@ -292,9 +296,8 @@ private[torch] trait IndexingSlicingJoiningOps {
*
* @group indexing_slicing_joining_mutating_ops
*/
def columnStack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.column_stack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
)
def columnStack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] =
Tensor(torchNative.column_stack(toArrayRef(tensors)))

/** Stack tensors in sequence depthwise (along third axis).
*
@@ -323,7 +326,7 @@ private[torch] trait IndexingSlicingJoiningOps {
* @group indexing_slicing_joining_mutating_ops
*/
def dstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.dstack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
torchNative.dstack(toArrayRef(tensors))
)

/** Gathers values along an axis specified by `dim`.
@@ -427,9 +430,8 @@ private[torch] trait IndexingSlicingJoiningOps {
*
* @group indexing_slicing_joining_mutating_ops
*/
def hstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.hstack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
)
def hstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] =
Tensor(torchNative.hstack(toArrayRef(tensors)))

/** Accumulate the elements of `source` into the `input` tensor by adding to the indices in the
* order given in `index`.
@@ -1080,7 +1082,7 @@ private[torch] trait IndexingSlicingJoiningOps {
* tensors (inclusive)
*/
def stack[D <: DType](tensors: Seq[Tensor[D]], dim: Int = 0): Tensor[D] = Tensor(
torchNative.stack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)), dim)
torchNative.stack(toArrayRef(tensors), dim)
)

/** Alias for `torch.transpose`.
@@ -1491,7 +1493,7 @@ private[torch] trait IndexingSlicingJoiningOps {
* @group indexing_slicing_joining_mutating_ops
*/
def vstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.vstack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
torchNative.vstack(toArrayRef(tensors))
)

/** Return a tensor of elements selected from either `input` or `other`, depending on `condition`.
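All of the joining ops above now share the small `toArrayRef` helper instead of building `TensorArrayRef`s inline; behaviour is unchanged. A usage sketch with illustrative shapes:

```scala
import torch.*

// cat, stack, hstack, vstack, dstack and columnStack all funnel their
// Seq of tensors through the shared toArrayRef helper.
val a = torch.ones(Seq(2, 3))
val b = torch.zeros(Seq(2, 3))
val rows  = torch.cat(Seq(a, b), dim = 0) // shape (4, 3)
val batch = torch.stack(Seq(a, b))        // shape (2, 2, 3)
```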
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/ops/package.scala
@@ -19,7 +19,7 @@ package torch
import internal.NativeConverters
import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch
import org.bytedeco.pytorch.MemoryFormatOptional
import org.bytedeco.pytorch.{MemoryFormatOptional, TensorArrayRef, TensorVector}

package object ops {

3 changes: 0 additions & 3 deletions core/src/main/scala/torch/package.scala
@@ -14,8 +14,6 @@
* limitations under the License.
*/

import torch.internal.LoadCusolver

import scala.util.Using

/** The torch package contains data structures for multi-dimensional tensors and defines
@@ -37,7 +35,6 @@ package object torch
with ops.PointwiseOps
with ops.RandomSamplingOps
with ops.ReductionOps {
LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

/** Disable gradient calculation for [[op]].
*
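The doc comment above introduces the gradient-suppression helper in the torch package object; its name is cut off in this view, so the sketch below assumes it is `noGrad` (as in the published storch API):

```scala
import torch.*

// Run an op without recording gradients, e.g. for inference.
val x = torch.randn(Seq(4, 4))
val y = noGrad { x.matmul(x) }
```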