Merge pull request #46 from sbrunk/new-pytorch-presets
New pytorch presets
sbrunk authored Jul 27, 2023
2 parents e8d6b3f + e2c2825 commit 795485b
Showing 19 changed files with 38 additions and 96 deletions.
9 changes: 4 additions & 5 deletions build.sbt
@@ -29,7 +29,8 @@ val cudaVersion = "12.1-8.9"
val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.3.0"
ThisBuild / javaCppVersion := "1.5.9"
ThisBuild / javaCppVersion := "1.5.10-SNAPSHOT"
ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots")

ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11"))

@@ -40,8 +41,7 @@ ThisBuild / enableGPU := false
lazy val commonSettings = Seq(
Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"),
javaCppVersion := (ThisBuild / javaCppVersion).value,
javaCppPlatform := Seq(),
resolvers ++= Resolver.sonatypeOssRepos("snapshots")
javaCppPlatform := Seq()
// This is a hack to avoid depending on the native libs when publishing
// but conveniently have them on the classpath during development.
// There's probably a cleaner way to do this.
@@ -75,8 +75,7 @@ lazy val core = project
(if (enableGPU.value) "pytorch-gpu" else "pytorch") -> pytorchVersion,
"mkl" -> mklVersion,
"openblas" -> openblasVersion
// TODO remove cuda (not cuda-redist) once https://github.com/bytedeco/javacpp-presets/issues/1376 is fixed
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion, "cuda" -> cudaVersion) else Seq()),
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion) else Seq()),
javaCppPlatform := org.bytedeco.sbt.javacpp.Platform.current,
fork := true,
Test / fork := true,
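The build change swaps the released JavaCPP version for the 1.5.10-SNAPSHOT presets, moves the snapshots resolver up to the build level, and drops both the explicit `cuda` preset and its TODO now that `cuda-redist` alone is enough. A minimal sketch of the resulting preset selection; treating `javaCppPresetLibs` as the sbt-javacpp setting these name/version pairs feed into is an assumption:

```scala
// Sketch (assumption: these pairs feed the sbt-javacpp javaCppPresetLibs key).
javaCppPresetLibs ++= Seq(
  // CPU-only builds pull the "pytorch" preset, GPU builds "pytorch-gpu".
  (if (enableGPU.value) "pytorch-gpu" else "pytorch") -> pytorchVersion,
  "mkl"      -> mklVersion,
  "openblas" -> openblasVersion
  // The explicit "cuda" preset (and the LoadCusolver workaround it required)
  // is gone; with the fixed upstream presets, "cuda-redist" is sufficient.
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion) else Seq())
```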
6 changes: 2 additions & 4 deletions core/src/main/scala/torch/Tensor.scala
@@ -53,7 +53,6 @@ import spire.math.{Complex, UByte}
import scala.reflect.Typeable
import internal.NativeConverters
import internal.NativeConverters.toArray
import internal.LoadCusolver
import Device.CPU
import Layout.Strided
import org.bytedeco.pytorch.ByteArrayRef
@@ -342,7 +341,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
import ScalarType.*
val out = native.dtype().toScalarType().intern() match
case Byte => UByte(native.item_int())
case Char => native.item_byte()
case Char => native.item_char()
case Short => native.item_short()
case Int => native.item_int()
case Long => native.item_long()
@@ -357,7 +356,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
val b = native.contiguous.createBuffer[DoubleBuffer]
Complex(b.get(), b.get())
case Bool => native.item().toBool
case QInt8 => native.item_byte()
case QInt8 => native.item_char()
case QUInt8 => native.item_short()
case QInt32 => native.item_int()
case BFloat16 => native.item().toBFloat16.asFloat()
@@ -802,7 +801,6 @@ type IntTensor = UInt8Tensor | Int8Tensor | Int16Tensor | Int32Tensor | Int64Ten
type ComplexTensor = Complex32Tensor | Complex64Tensor | Complex128Tensor

object Tensor:
LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

def apply[D <: DType](native: pytorch.Tensor): Tensor[D] = (native.scalar_type().intern() match
case ScalarType.Byte => new UInt8Tensor(native)
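The `item` branches track an accessor rename in the new presets: signed 8-bit values (`Char`, `QInt8`) are now read with `item_char` instead of the removed `item_byte`. A self-contained sketch of that accessor choice, using only calls visible in this diff (the `ScalarType` import path is an assumption):

```scala
// Illustrative sketch: extract a signed 8-bit scalar through the accessor
// exposed by the new presets.
import org.bytedeco.pytorch
import org.bytedeco.pytorch.global.torch.ScalarType

def int8Item(native: pytorch.Tensor): Byte =
  native.dtype().toScalarType().intern() match
    case ScalarType.Char | ScalarType.QInt8 => native.item_char()
    case other => throw new IllegalArgumentException(s"not an 8-bit tensor: $other")
```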
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/cuda/package.scala
@@ -17,13 +17,11 @@
package torch

import org.bytedeco.pytorch.global.torch as torchNative
import torch.internal.LoadCusolver

/** This package adds support for CUDA tensor types, that implement the same function as CPU
* tensors, but they utilize GPUs for computation.
*/
package object cuda {
LoadCusolver

/** Returns a Boolean indicating if CUDA is currently available. */
def isAvailable: Boolean = torchNative.cuda_is_available()
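With the eager `LoadCusolver` reference gone, the `cuda` package object has no initialization side effects left and only exposes queries like `isAvailable`. A short usage sketch; `Device.CUDA` as a public accessor is an assumption:

```scala
// Usage sketch: choose a device based on CUDA availability.
import torch.*

val device = if cuda.isAvailable then Device.CUDA else Device.CPU
println(s"Using $device")
```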
31 changes: 0 additions & 31 deletions core/src/main/scala/torch/internal/LoadCusolver.scala

This file was deleted.

3 changes: 0 additions & 3 deletions core/src/main/scala/torch/internal/NativeConverters.scala
@@ -35,12 +35,9 @@ import org.bytedeco.pytorch.GenericDictIterator
import spire.math.Complex
import spire.math.UByte
import scala.annotation.targetName
import internal.LoadCusolver

private[torch] object NativeConverters:

LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

inline def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
case i: Option[T] => i.map(f(_)).orNull
case i: T => f(i)
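Only the workaround is removed here; `convertToOptional` itself is unchanged. For context, a simplified, self-contained copy of the pattern it implements, collapsing a `T | Option[T]` union into a nullable value for the native API:

```scala
// Simplified copy of the pattern (the original is inline and private[torch]).
def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
  case o: Option[T] => o.map(f(_)).orNull
  case t: T         => f(t)

// e.g. boxing an optional dimension, with None becoming null for JavaCPP.
val box: Int => Integer = d => Integer.valueOf(d)
assert(convertToOptional(Option(1): Int | Option[Int], box) == 1)
assert(convertToOptional(None: Int | Option[Int], box) == null)
```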
5 changes: 1 addition & 4 deletions core/src/main/scala/torch/nn/functional/package.scala
@@ -18,7 +18,6 @@ package torch
package nn

import functional.*
import torch.internal.LoadCusolver

/** @groupname nn_conv Convolution functions
* @groupname nn_pooling Pooling functions
@@ -38,6 +37,4 @@ package object functional
with Linear
with Loss
with Pooling
with Sparse {
LoadCusolver
}
with Sparse
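Dropping `LoadCusolver` leaves the `functional` package object with no body, hence the brace removal. Calls through it are unaffected; a small usage sketch (the `rand` and `relu` signatures are assumptions about the public API):

```scala
// Usage sketch: functional ops work without any eager native preloading.
import torch.*
import torch.nn.functional as F

val x = torch.rand(Seq(2, 3))
val y = F.relu(x)
```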
6 changes: 2 additions & 4 deletions core/src/main/scala/torch/nn/init.scala
@@ -21,7 +21,7 @@ import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.FanModeType
import org.bytedeco.pytorch.kFanIn
import org.bytedeco.pytorch.kFanOut
import org.bytedeco.pytorch.NonlinearityType
import org.bytedeco.pytorch.Nonlinearity as NonlinearityNative
import org.bytedeco.pytorch.kLinear
import org.bytedeco.pytorch.kConv1D
import org.bytedeco.pytorch.kConv2D
@@ -33,11 +33,9 @@ import org.bytedeco.pytorch.kSigmoid
import org.bytedeco.pytorch.kReLU
import org.bytedeco.pytorch.kLeakyReLU
import org.bytedeco.pytorch.Scalar
import torch.internal.LoadCusolver

// TODO implement remaining init functions
object init:
LoadCusolver
def kaimingNormal_(
t: Tensor[?],
a: Double = 0,
@@ -56,7 +54,7 @@ object init:
enum NonLinearity:
case Linear, Conv1D, Conv2D, Conv3D, ConvTranspose1D, ConvTranspose2D, ConvTranspose3D, Sigmoid,
ReLU, LeakyReLU
private[torch] def toNative: NonlinearityType = NonlinearityType(this match
private[torch] def toNative: NonlinearityNative = NonlinearityNative(this match
case NonLinearity.Linear => kLinear()
case NonLinearity.Conv1D => kConv1D()
case NonLinearity.Conv2D => kConv2D()
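Here the presets renamed the native `NonlinearityType` to `Nonlinearity`; the Scala-facing `NonLinearity` enum and the init functions keep their signatures. A usage sketch restricted to the parameters visible above (`t` and `a`); that the remaining parameters have defaults, and the `rand` signature, are assumptions:

```scala
// Usage sketch: in-place Kaiming-normal initialization of a weight tensor.
// Only the private toNative bridge changed; callers keep passing the
// NonLinearity enum wherever a nonlinearity is required.
import torch.*
import torch.nn.init

val weight = torch.rand(Seq(16, 8))
init.kaimingNormal_(weight, a = math.sqrt(5.0))
```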
8 changes: 4 additions & 4 deletions core/src/main/scala/torch/nn/modules/Module.scala
@@ -72,7 +72,7 @@ abstract class Module {

def copy(): this.type =
val clone = super.clone().asInstanceOf[Module]
clone._nativeModule = _nativeModule.clone()
clone._nativeModule = _nativeModule.clone(null)
clone.asInstanceOf[this.type]

protected[torch] def registerWithParent[T <: pytorch.Module](parent: T)(using
@@ -100,15 +100,15 @@ abstract class Module {

def to(device: Device): this.type =
// val nativeCopy = nativeModule.clone()
nativeModule.asModule.to(device.toNative)
nativeModule.asModule.to(device.toNative, false)
// copy
// val clone: this.type = copy()
// clone.nativeModule = nativeCopy
this

def to(dtype: DType, nonBlocking: Boolean = false): this.type =
val nativeCopy = nativeModule.clone()
nativeCopy.asModule.to(dtype.toScalarType)
val nativeCopy = nativeModule.clone(null)
nativeCopy.asModule.to(dtype.toScalarType, false)
this

def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive)
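`clone` and `to` on the native module now take the extra arguments the new presets require (an explicit null device override and a `non_blocking` flag); the public `Module` methods are unchanged. A usage sketch, where `nn.Linear` construction and the `float64` dtype value are assumptions about the public API:

```scala
// Usage sketch: device and dtype moves still go through the same public methods.
import torch.*
import torch.nn

val model  = nn.Linear(4, 2)
val placed = if cuda.isAvailable then model.to(Device.CUDA) else model
val asF64  = model.to(float64) // uses the default nonBlocking = false
```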
@@ -119,7 +119,7 @@ final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default](
options.track_running_stats().put(trackRunningStats)

override private[torch] val nativeModule: BatchNorm2dImpl = BatchNorm2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
@@ -59,7 +59,7 @@ final class Conv2d[ParamType <: FloatNN | ComplexNN: Default](
options.padding_mode().put(paddingModeNative)

override private[torch] val nativeModule: Conv2dImpl = Conv2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/nn/modules/linear/Linear.scala
@@ -28,7 +28,7 @@ import torch.nn.modules.{HasParams, TensorModule}
*
* This module supports `TensorFloat32<tf32_on_ampere>`.
*
* * Example:
* Example:
*
* ```scala sc:nocompile
* import torch.*
@@ -57,7 +57,7 @@ final class Linear[ParamType <: FloatNN: Default](
private val options = new LinearOptions(inFeatures, outFeatures)
options.bias().put(bias)
override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[T <: pytorch.Module](parent: T)(using
name: sourcecode.Name
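Besides the native `to` call, this hunk also fixes a stray `*` before "Example:" in the scaladoc. A usage sketch in the spirit of that example; the `rand` signature and `shape` accessor are assumptions:

```scala
// Usage sketch: the last dimension (20) is mapped to 30, leading dimensions pass through.
import torch.*
import torch.nn

val m      = nn.Linear(20, 30)
val input  = torch.rand(Seq(128, 20))
val output = m(input)
println(output.shape) // expected: a Seq-like (128, 30)
```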
@@ -44,7 +44,7 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
options.ceil_mode().put(ceilMode)

override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
3 changes: 0 additions & 3 deletions core/src/main/scala/torch/nn/package.scala
@@ -16,16 +16,13 @@

package torch

import torch.internal.LoadCusolver

/** These are the basic building blocks for graphs.
*
* @groupname nn_conv Convolution Layers
* @groupname nn_linear Linear Layers
* @groupname nn_utilities Utilities
*/
package object nn {
LoadCusolver

export modules.Module
export modules.Default
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/nn/utils.scala
@@ -19,10 +19,8 @@ package nn

import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.TensorVector
import torch.internal.LoadCusolver

object utils:
LoadCusolver
def clipGradNorm_(
parameters: Seq[Tensor[?]],
max_norm: Double,
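`clipGradNorm_` only loses the eager `LoadCusolver` reference. A usage sketch limited to the parameters shown above (`parameters`, `max_norm`); how a module exposes its parameter list is an assumption:

```scala
// Usage sketch: clip gradient norms in place before an optimizer step.
import torch.*
import torch.nn

val model = nn.Linear(4, 2)
nn.utils.clipGradNorm_(model.parameters, max_norm = 2.0)
```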
24 changes: 13 additions & 11 deletions core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
@@ -20,14 +20,18 @@ package ops
import internal.NativeConverters.*

import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.{TensorArrayRef, TensorVector}
import org.bytedeco.pytorch.TensorArrayRef
import org.bytedeco.pytorch.TensorVector

/** Indexing, Slicing, Joining, Mutating Ops
*
* https://pytorch.org/docs/stable/torch.html#indexing-slicing-joining-mutating-ops
*/
private[torch] trait IndexingSlicingJoiningOps {

private def toArrayRef(tensors: Seq[Tensor[?]]): TensorArrayRef =
new TensorArrayRef(new TensorVector(tensors.map(_.native)*))

/** Returns a view of the tensor conjugated and with the last two dimensions transposed.
*
* `x.adjoint()` is equivalent to `x.transpose(-2, -1).conj()` for complex tensors and to
@@ -120,7 +124,7 @@ private[torch] trait IndexingSlicingJoiningOps {
* @group indexing_slicing_joining_mutating_ops
*/
def cat[D <: DType](tensors: Seq[Tensor[D]], dim: Int = 0): Tensor[D] = Tensor(
torchNative.cat(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)), dim.toLong)
torchNative.cat(toArrayRef(tensors), dim.toLong)
)

/** Returns a view of `input` with a flipped conjugate bit. If `input` has a non-complex dtype,
@@ -292,9 +296,8 @@ *
*
* @group indexing_slicing_joining_mutating_ops
*/
def columnStack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.column_stack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
)
def columnStack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] =
Tensor(torchNative.column_stack(toArrayRef(tensors)))

/** Stack tensors in sequence depthwise (along third axis).
*
@@ -323,7 +326,7 @@
* @group indexing_slicing_joining_mutating_ops
*/
def dstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.dstack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
torchNative.dstack(toArrayRef(tensors))
)

/** Gathers values along an axis specified by `dim`.
@@ -427,9 +430,8 @@
*
* @group indexing_slicing_joining_mutating_ops
*/
def hstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.hstack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
)
def hstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] =
Tensor(torchNative.hstack(toArrayRef(tensors)))

/** Accumulate the elements of `source` into the `input` tensor by adding to the indices in the
* order given in `index`.
@@ -1080,7 +1082,7 @@
* tensors (inclusive)
*/
def stack[D <: DType](tensors: Seq[Tensor[D]], dim: Int = 0): Tensor[D] = Tensor(
torchNative.stack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)), dim)
torchNative.stack(toArrayRef(tensors), dim)
)

/** Alias for `torch.transpose`.
@@ -1491,7 +1493,7 @@
* @group indexing_slicing_joining_mutating_ops
*/
def vstack[D <: DType](tensors: Seq[Tensor[D]]): Tensor[D] = Tensor(
torchNative.vstack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)))
torchNative.vstack(toArrayRef(tensors))
)

/** Return a tensor of elements selected from either `input` or `other`, depending on `condition`.
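The new private `toArrayRef` helper only factors out the repeated `new TensorArrayRef(new TensorVector(...))` construction; every public op keeps its signature. A usage sketch against the signatures shown above (the `ones`/`zeros` creation ops are assumptions):

```scala
// Usage sketch: joining ops behave exactly as before the refactoring.
import torch.*

val a = torch.ones(Seq(2, 3))
val b = torch.zeros(Seq(2, 3))

val catted  = torch.cat(Seq(a, b), dim = 0)   // shape (4, 3)
val stacked = torch.stack(Seq(a, b), dim = 0) // shape (2, 2, 3)
```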
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/ops/package.scala
@@ -19,7 +19,7 @@ package torch
import internal.NativeConverters
import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch
import org.bytedeco.pytorch.MemoryFormatOptional
import org.bytedeco.pytorch.{MemoryFormatOptional, TensorArrayRef, TensorVector}

package object ops {

3 changes: 0 additions & 3 deletions core/src/main/scala/torch/package.scala
@@ -14,8 +14,6 @@
* limitations under the License.
*/

import torch.internal.LoadCusolver

import scala.util.Using

/** The torch package contains data structures for multi-dimensional tensors and defines
@@ -37,7 +35,6 @@ package object torch
with ops.PointwiseOps
with ops.RandomSamplingOps
with ops.ReductionOps {
LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

/** Disable gradient calculation for [[op]].
*
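This removes the last eager `LoadCusolver` call, so `import torch.*` no longer triggers any CUDA-related class loading. The surrounding scaladoc introduces `noGrad`; a hedged usage sketch, where the `noGrad` name and the `requiresGrad` flag on `rand` are assumptions from that doc comment and the public API:

```scala
// Usage sketch: run ops without recording them on the autograd graph.
import torch.*

val w = torch.rand(Seq(3, 3), requiresGrad = true)
val y = noGrad { w + w } // no grad_fn should be attached inside the block
```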