Skip to content

Commit

Permalink
TiarkRompf#1 Allow more heterogenous types in numeric operations
Browse files Browse the repository at this point in the history
  • Loading branch information
julienrf committed Oct 4, 2012
1 parent 0c2d15e commit a806f1d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 17 deletions.
51 changes: 35 additions & 16 deletions src/common/NumericOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,38 @@ import java.io.PrintWriter
import scala.reflect.SourceContext

trait LiftNumeric {
this: Base =>
this: Base with NumericOps =>

// HACK The Numeric context bound is not *required* but it is useful to reduce the applicability of this implicit conversion
implicit def numericToNumericRep[T:Numeric:Manifest](x: T) = unit(x)
// Explicit `1 + unit(1)` support because it needs two implicit conversions (FIXME Doesn’t work)
implicit def anyToNumericOps[A](a: A)(implicit lift: A => Rep[A]) = new NumericOpsCls(lift(a))
// implicit def numericToNumericOps[A : Numeric : Manifest](a: A) = new NumericOpsCls(unit(a))
}

trait NumericOps extends Variables {

// workaround for infix not working with manifests
implicit def numericToNumericOps[T:Numeric:Manifest](n: T) = new NumericOpsCls(unit(n))
implicit def repNumericToNumericOps[T:Numeric:Manifest](n: Rep[T]) = new NumericOpsCls(n)
implicit def varNumericToNumericOps[T:Numeric:Manifest](n: Var[T]) = new NumericOpsCls(readVar(n))

class NumericOpsCls[T:Numeric:Manifest](lhs: Rep[T]){
def +[A](rhs: A)(implicit c: A => T, pos: SourceContext) = numeric_plus(lhs,unit(c(rhs)))
def +(rhs: Rep[T])(implicit pos: SourceContext) = numeric_plus(lhs,rhs)
def -(rhs: Rep[T])(implicit pos: SourceContext) = numeric_minus(lhs,rhs)
def *(rhs: Rep[T])(implicit pos: SourceContext) = numeric_times(lhs,rhs)
def /(rhs: Rep[T])(implicit pos: SourceContext) = numeric_divide(lhs,rhs)
// Type constraints allowing an eventual type promotion (e.g. Int to Float) before performing the numeric operation
object NumericOpsTypes {
trait Args { type Lhs; type Rhs }
trait ~[A, B] extends Args { type Lhs = A; type Rhs = B }
class :=[A <: Args, B](val lhs: Rep[A#Lhs] => Rep[B], val rhs: Rep[A#Rhs] => Rep[B])(implicit val Numeric: Numeric[B])
}

//def infix_+[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T]) = numeric_plus(lhs,rhs)
//def infix_-[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T]) = numeric_minus(lhs,rhs)
//def infix_*[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T]) = numeric_times(lhs,rhs)
import NumericOpsTypes._
implicit def numericSameArgs[A : Numeric] = new (A ~ A := A) (identity, identity)

/* FIXME
* I’d like to define numeric operators as follows:
* def infix_+[A, B](lhs: A, rhs: B)(implicit someAdditionnalConstraints...)
* But this signature leads to an ambiguous reference to overloaded definition with an infix_+(s: String, a: Any) method defined in EmbeddedControls (?)
*/
implicit class NumericOpsCls[A](lhs: Rep[A]) {
def + [B, C](rhs: Rep[B])(implicit op: (A ~ B := C), mC: Manifest[C], sc: SourceContext) = numeric_plus(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
def - [B, C](rhs: Rep[B])(implicit op: (A ~ B := C), mC: Manifest[C], sc: SourceContext) = numeric_minus(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
def * [B, C](rhs: Rep[B])(implicit op: (A ~ B := C), mC: Manifest[C], sc: SourceContext) = numeric_times(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
def / [B, C](rhs: Rep[B])(implicit op: (A ~ B := C), mC: Manifest[C], sc: SourceContext) = numeric_divide(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
}
implicit def varNumericToNumericOps[T : Numeric : Manifest](n: Var[T]) = new NumericOpsCls(readVar(n))

def numeric_plus[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
def numeric_minus[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
Expand All @@ -38,6 +47,16 @@ trait NumericOps extends Variables {
//def numeric_signum[T:Numeric](x: T): Rep[Int]
}

/*
* Enable promotion of arguments involved in a numeric operation provided there exists an implicit conversion to perform the promotion.
* For instance, it allows to mix Int values and Double values in a numeric operation.
*/
trait NumericPromotions { this: ImplicitOps with NumericOps =>
import NumericOpsTypes._
implicit def numericPromoteLhs[A : Manifest, B : Numeric : Manifest](implicit aToB: A => B) = new (A ~ B := B) (lhs = implicit_convert[A, B](_), rhs = identity)
implicit def numericPromoteRhs[A : Manifest, B : Numeric : Manifest](implicit aToB: A => B) = new (B ~ A := B) (lhs = identity, rhs = implicit_convert[A, B](_))
}

trait NumericOpsExp extends NumericOps with VariablesExp with BaseFatExp {
abstract class DefMN[A:Manifest:Numeric] extends Def[A] {
def mev = manifest[A]
Expand Down
2 changes: 1 addition & 1 deletion test-src/epfl/test11-shonan/TestHMM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class TestHMM extends FileDiffSuite {
if ((a(i) filter (_ != 0)).length < 3) {
for (j <- 0 until n: Range) {
if (a(i)(j) != 0)
v1(i) = v1(i) + a(i)(j) * v(j)
v1(i) = v1(i) + unit(a(i)(j)) * v(j)
}
} else {
for (j <- 0 until n: Rep[Range]) {
Expand Down
64 changes: 64 additions & 0 deletions test-src/epfl/test13-numeric/TestNumeric.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package scala.virtualization.lms
package epfl
package test13

import common._

// Does nothing but checks the code compiles
trait TestNumeric {

def typed[A](a: => A) {}

trait Usage { this: Base with NumericOps =>

val a = unit(1) + unit(1)
typed[Rep[Int]](a)

val b = unit(1.0) + unit(1.0)
typed[Rep[Double]](b)
}

trait UsageWithLift { this: Base with NumericOps with LiftNumeric =>

// val a = 1 + unit(1)
val a = anyToNumericOps(1) + unit(1)
typed[Rep[Int]](a)

// val b = 1.0 + unit(1.0)
val b = anyToNumericOps(1.0) + unit(1.0)
typed[Rep[Double]](b)

val c = unit(1) + 1
typed[Rep[Int]](c)

val d = unit(1.0) + 1.0
typed[Rep[Double]](d)
}

trait UsageWithPromotions { this: Base with NumericOps with NumericPromotions =>

val a = unit(1) + unit(1.0)
typed[Rep[Double]](a)

val b = unit(1.0) + unit(1)
typed[Rep[Double]](b)
}

trait UsageWithPromotionsAndLift { this: Base with NumericOps with NumericPromotions with LiftNumeric =>

val a = unit(1) + 1.0
typed[Rep[Double]](a)

// val b = 1 + unit(1.0)
val b = anyToNumericOps(1) + unit(1.0)
typed[Rep[Double]](b)

val c = unit(1.0) + 1
typed[Rep[Double]](c)
// val d = 1.0 + unit(1)

val d = anyToNumericOps(1.0) + unit(1)
typed[Rep[Double]](d)
}

}

0 comments on commit a806f1d

Please sign in to comment.