Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

AdOptimize #98

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ jobs:
contents: read
packages: write
steps:
- name: set up jdk 11
- name: set up jdk 8
uses: actions/setup-java@v2
with:
java-version: '11'
java-version: '8'
distribution: 'adopt'
- name: install diffkt dependencies
run: |
Expand Down
20 changes: 19 additions & 1 deletion kotlin/api/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.jetbrains.dokka.gradle.DokkaTask

plugins {
`maven-publish`
id("meta-diffkt-differentiable-api-preprocessor") version "0.0.1.3"
id("shapeKt") version "1.0"
id("org.jetbrains.dokka") version "1.6.0"
}
Expand All @@ -28,6 +29,23 @@ repositories {
mavenCentral()
}

differentiableApiPreprocessor {
this.stackImplAnnotation("org.diffkt.adOptimize.StackImpl")
this.boxedPrimitive("org.diffkt.adOptimize.BoxedPrimitive")
this.scalarRoot("org.diffkt.adOptimize.ScalarRoot")
this.primalAndPullbackAnnotation("org.diffkt.adOptimize.PrimalAndPullback")
this.reverseAnnotation("org.diffkt.adOptimize.ReverseDifferentiable")
this.unboxedFunction("org.diffkt.adOptimize.ToUnboxedFunction")
val userDir = System.getProperty("user.dir")
val pathToResources = "$userDir/api/src/main/resources"
this.resourcesPath(pathToResources)
this.toReverseAnnotation("org.diffkt.adOptimize.ToReverse")
this.dTensorAnnotation("org.diffkt.adOptimize.DTensorRoot")
this.reverseScalarOperationsAnnotation("org.diffkt.adOptimize.ReverseScalarOperations")
this.scalarNoop("org.diffkt.adOptimize.ScalarNoop")
this.forwardDifferentiable("org.diffkt.adOptimize.ForwardDifferentiable")
}

dependencies {
implementation(group = "net.bytebuddy", name = "byte-buddy", version="1.12.7")
compileOnly("shapeKt:annotations:1.0")
Expand Down Expand Up @@ -63,7 +81,7 @@ tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile>() {
publishing {
publications {
create<MavenPublication>("maven") {
groupId = "org.diffkt"
groupId = "org.diffkt.adopt"
artifactId = "api"
version = project.version.toString()
from(components["java"])
Expand Down
51 changes: 51 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/ADConfig.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.diffkt.adOptimize

annotation class StackImpl
annotation class BoxedPrimitive(val value:String)
annotation class ScalarRoot
annotation class ToUnboxedFunction(val functionName:String)
annotation class ToReverse(val fqClass:String)
annotation class DTensorRoot
annotation class ReverseScalarOperations
annotation class ForwardDifferentiable(val tangentProperty: String)

/**
* a function is marked as a scalar noop if it accepts exactly one potentially
* active operand and returns the same type. The return value MUST be the implicit receiver
*/
annotation class ScalarNoop

/**
* The reverse differentiable scalar type should be annotated with this annotation
* so the compiler plugin can build implementations of reverse nodes
*/
annotation class ReverseDifferentiable(val primalField:String, val upstreamField:String, val backpropogateMethod:String, val pushbackMethod:String, val derivativeID:String)

/**
* This is used by the compiler when it cannot inline a function call.
* The compiler expects the signature of the function annotated with this annotation
* to be (DTensor, (DTensor) -> DTensor) -> Pair<DTensor, (DTensor)->DTensor>
*/
annotation class PrimalAndPullback

@StackImpl
class CodeGenStack<T> {
val data = arrayListOf<T>()
fun push(d:T){
data.add(d)
}
fun pop():T {
val x = data.last()
data.removeLast()
return x
}
fun top():T = data.last()
fun notEmpty() = data.isNotEmpty()
}
2 changes: 2 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/DScalar.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.diffkt

import shapeTyping.annotations.SType
import org.diffkt.adOptimize.ScalarRoot

/**
* A differentiable scalar (float).
Expand All @@ -16,6 +17,7 @@ import shapeTyping.annotations.SType
* - a [ForwardScalar] for forward differentiation with a [DScalar] primal value, and a [DTensor] tangent, or
* - a [ReverseScalar] for reverse mode differentiation.
*/
@ScalarRoot
interface DScalar : @SType("[]") DTensor {

/**
Expand Down
2 changes: 2 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/DTensor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.diffkt

import shapeTyping.annotations.SType
import org.diffkt.adOptimize.DTensorRoot

/**
* Interface for a differentiable tensor.
Expand All @@ -17,6 +18,7 @@ import shapeTyping.annotations.SType
* - a [ForwardTensor] for forward differentiation, or
* - a [ReverseTensor] for reverse mode differentiation.
*/
@DTensorRoot
@SType("S: Shape")
interface DTensor: Differentiable<DTensor> {

Expand Down
3 changes: 3 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/FloatScalar.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

package org.diffkt

import org.diffkt.adOptimize.BoxedPrimitive

/**
* A differentiable tensor of rank 0 containing a single float (a FloatTensor wrapper around a float).
*
* @property value The floating point number to wrap with a FloatTensor.
* @constructor Creates a FloatScalar initialized to value.
*/
@BoxedPrimitive("value")
class FloatScalar(val value: Float) : FloatTensor(), DScalar {
override val derivativeID: DerivativeID get() = NoDerivativeID
override val operations: Operations
Expand Down
3 changes: 3 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/Power.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.diffkt

import org.diffkt.adOptimize.ToUnboxedFunction
import shapeTyping.annotations.SType

// Tensor powers
Expand All @@ -17,10 +18,12 @@ fun DTensor.pow(x: DScalar): DTensor = exp(x * ln(this))
// a^b == e^(b ln a)
fun DScalar.pow(x: DScalar): DScalar = exp(x * ln(this))

@ToUnboxedFunction("kotlin.math.pow")
fun DScalar.pow(x: Float): DScalar = (this as DTensor).pow(x) as DScalar

fun DScalar.pow(x: Int): DScalar = (this as DTensor).pow(x.toFloat()) as DScalar

@ToUnboxedFunction("kotlin.math.pow")
fun DTensor.pow(x: Int): DTensor = this.pow(x.toFloat())

@SType("S: Shape")
Expand Down
2 changes: 2 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/ReverseDerivative.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package org.diffkt
import org.diffkt.reverse.ReverseDerivativeID
import org.diffkt.reverse.ReverseScalar
import org.diffkt.reverse.ReverseTensor
import org.diffkt.adOptimize.PrimalAndPullback

// ********** General Reverse Derivative of Univariate Functions **********

Expand Down Expand Up @@ -54,6 +55,7 @@ fun vjp(
) = primalAndVjp(x, vf, f).second

// Also known as VJP (vector-Jacobian product)
@PrimalAndPullback
internal fun primalAndPullback(
x: DTensor,
f: (x: DTensor) -> DTensor
Expand Down
2 changes: 2 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/Sigmoid.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package org.diffkt

import kotlin.math.exp
import shapeTyping.annotations.SType
import org.diffkt.adOptimize.ToUnboxedFunction

/**
* Compute the sigmoid for a single floating-point value.
Expand All @@ -23,6 +24,7 @@ internal fun sigmoidElem(x: Float): Float {
}
}

@ToUnboxedFunction("org.diffkt.sigmoidElem")
fun sigmoid(x: DScalar): DScalar {
return x.operations.sigmoid(x) as DScalar
}
Expand Down
9 changes: 9 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/Transcendental.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,29 @@

package org.diffkt

import org.diffkt.adOptimize.ToUnboxedFunction

// ==========
// Transcendentals
// ==========

// Trig
@ToUnboxedFunction("kotlin.math.sin")
fun sin(x: DScalar): DScalar {
return x.operations.sin(x) as DScalar
}

@ToUnboxedFunction("kotlin.math.sin")
fun sin(x: DTensor): DTensor {
return x.operations.sin(x)
}

@ToUnboxedFunction("kotlin.math.cos")
fun cos(x: DScalar): DScalar {
return x.operations.cos(x) as DScalar
}

@ToUnboxedFunction("kotlin.math.cos")
fun cos(x: DTensor): DTensor {
return x.operations.cos(x)
}
Expand All @@ -45,6 +51,7 @@ fun atan(x: DTensor): DTensor {
return x.operations.atan(x)
}

@ToUnboxedFunction("kotlin.math.exp")
fun exp(x: DScalar): DScalar {
return x.operations.exp(x) as DScalar
}
Expand All @@ -53,10 +60,12 @@ fun exp(x: DTensor): DTensor {
return x.operations.exp(x)
}

@ToUnboxedFunction("kotlin.math.ln")
fun ln(x: DScalar): DScalar {
return x.operations.ln(x) as DScalar
}

@ToUnboxedFunction("kotlin.math.ln")
fun ln(x: DTensor): DTensor {
return x.operations.ln(x)
}
Expand Down
2 changes: 2 additions & 0 deletions kotlin/api/src/main/kotlin/org/diffkt/Utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package org.diffkt

import org.diffkt.external.ExternalLib
import org.diffkt.reverse.ReverseTensor
import org.diffkt.adOptimize.ScalarNoop

internal fun identityGradientofSameKind(x: DTensor, halfShape: Shape = x.shape): DTensor {
return x.operations.identityGradientOfSameKind(x, halfShape)
Expand Down Expand Up @@ -82,6 +83,7 @@ val IntArray.product get() = run {
product
}

@ScalarNoop
fun DTensor.expandToTangent(tangent: DTensor): DTensor {
if (this.shape == tangent.shape) return this
val ones = Shape(IntArray(tangent.shape.rank - this.shape.rank) { 1 })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ package org.diffkt.forward
import org.diffkt.DScalar
import org.diffkt.DTensor
import org.diffkt.Operations
import org.diffkt.adOptimize.ForwardDifferentiable

/**
* A differentiable dual scalar (for forward derivatives)
*/
@ForwardDifferentiable("tangent")
open class ForwardScalar protected constructor(
primal: DScalar,
derivativeID: ForwardDerivativeID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ package org.diffkt.reverse
import org.diffkt.DScalar
import org.diffkt.gpu.GpuFloatScalar
import org.diffkt.Operations
import org.diffkt.adOptimize.ReverseDifferentiable

/**
* A scalar for reverse mode differentiation.
*/
@ReverseDifferentiable("primal", "upstream", "backpropagate", "pushback", "derivativeID")
abstract class ReverseScalar(override val primal: DScalar, derivativeID: ReverseDerivativeID) : ReverseTensor(primal, derivativeID),
DScalar {
override val operations: Operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import org.diffkt.*
import org.diffkt.model.BatchNormResult
import org.diffkt.random.RandomKey
import shapeTyping.annotations.AllowUnreduced
import org.diffkt.adOptimize.ReverseScalarOperations

@ReverseScalarOperations
@AllowUnreduced
internal open class ReverseScalarOperationsImpl: Operations {
override val name get() = "ReverseScalar"
Expand Down
6 changes: 3 additions & 3 deletions kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ allprojects {
apply(plugin = "java")

java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}

tasks.withType<KotlinCompile>() {
kotlinOptions.jvmTarget = "11"
kotlinOptions.jvmTarget = "1.8"
kotlinOptions.freeCompilerArgs += "-XXLanguage:+ProperCheckAnnotationsTargetInTypeUsePositions"
}

Expand Down