Skip to content

Commit

Permalink
WIP: add Vector inner and outer products.
Browse files Browse the repository at this point in the history
I think these methods are worth adding to the main API.
The way in which we do this is SuiteSparse-specific, and we can add recipes
for other implementations if/when that time comes.

Note that `semiring(w @ v)` is sugar for `w.inner(v)`.

TODO:
- [ ] document
- [ ] test
- [ ] allow `v.outer` to accept a `BinaryOp`, `Monoid`, or `Semiring`

xref: DrTimothyAldenDavis/GraphBLAS#57
  • Loading branch information
eriknw committed Jul 13, 2021
1 parent 7095fd1 commit 1f0fecf
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 3 deletions.
14 changes: 12 additions & 2 deletions grblas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,16 @@ def _update(self, delayed, mask=None, accum=None, replace=False, input_mask=None
output_replace=replace,
)
if self._is_scalar:
args = [_Pointer(self), accum]
cfunc_name = delayed.cfunc_name.format(output_dtype=self.dtype)
is_fake_scalar = delayed.method_name == "inner"
if is_fake_scalar:
from .vector import Vector

fake_self = Vector.new(self.dtype, size=1)
args = [fake_self, mask, accum]
cfunc_name = delayed.cfunc_name
else:
args = [_Pointer(self), accum]
cfunc_name = delayed.cfunc_name.format(output_dtype=self.dtype)
else:
args = [self, mask, accum]
cfunc_name = delayed.cfunc_name
Expand All @@ -400,6 +408,8 @@ def _update(self, delayed, mask=None, accum=None, replace=False, input_mask=None
# Make the GraphBLAS call
call(cfunc_name, args)
if self._is_scalar:
if is_fake_scalar:
self.value = fake_self[0].value
self._is_empty = False

@property
Expand Down
14 changes: 13 additions & 1 deletion grblas/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ def ncols(self):
return self._ncols


class ScalarMatMulExpr(InfixExprBase):
__slots__ = ()
method_name = "inner"
output_type = None # ScalarExpression
_infix = "@"
_example_op = "plus_times"


def _ewise_infix_expr(left, right, *, method, within):
from .vector import Vector
from .matrix import Matrix, TransposedMatrix
Expand Down Expand Up @@ -518,6 +526,8 @@ def _matmul_infix_expr(left, right, *, within):
if left_type is Vector:
if right_type is Matrix or right_type is TransposedMatrix:
method = "vxm"
elif right_type is Vector:
method = "inner"
else:
left._expect_type(
right,
Expand Down Expand Up @@ -562,4 +572,6 @@ def _matmul_infix_expr(left, right, *, within):
expr = getattr(left, method)(right, any_pair[bool])
if expr.output_type is Vector:
return VectorMatMulExpr(left, right, method_name=method, size=expr._size)
return MatrixMatMulExpr(left, right, nrows=expr.nrows, ncols=expr.ncols)
elif expr.output_type is Matrix:
return MatrixMatMulExpr(left, right, nrows=expr.nrows, ncols=expr.ncols)
return ScalarMatMulExpr(left, right)
51 changes: 51 additions & 0 deletions grblas/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,41 @@ def reduce(self, op=monoid.plus):
op=op, # to be determined later
)

# Unofficial methods
def inner(self, other, op=semiring.plus_times):
method_name = "inner"
self._expect_type(other, Vector, within=method_name, argname="other")
op = get_typed_op(op, self.dtype, other.dtype)
self._expect_op(op, "Semiring", within=method_name, argname="op")
expr = ScalarExpression(
method_name,
"GrB_vxm",
[self, _VectorAsMatrix(other)],
op=op,
# bt=other._is_transposed,
)
if self._size != other._size:
expr.new(name="") # incompatible shape; raise now
return expr

def outer(self, other, op=semiring.plus_times):
from .matrix import MatrixExpression

method_name = "outer"
self._expect_type(other, Vector, within=method_name, argname="other")
op = get_typed_op(op, self.dtype, other.dtype)
self._expect_op(op, "Semiring", within=method_name, argname="op")
expr = MatrixExpression(
method_name,
"GrB_mxm",
[_VectorAsMatrix(self), _VectorAsMatrix(other)],
op=op,
nrows=self._size,
ncols=other._size,
bt=True,
)
return expr

##################################
# Extract and Assign index methods
##################################
Expand Down Expand Up @@ -702,6 +737,22 @@ def size(self):
return self._size


class _VectorAsMatrix:
__slots__ = "vector"

def __init__(self, vector):
self.vector = vector

@property
def _carg(self):
# SS, SuiteSparse-specific: casting Vector to Matrix
return ffi.cast("GrB_Matrix*", self.vector.gb_obj)[0]

@property
def name(self):
return f"(GrB_Matrix){self.vector.name}"


expr.VectorEwiseAddExpr.output_type = VectorExpression
expr.VectorEwiseMultExpr.output_type = VectorExpression
expr.VectorMatMulExpr.output_type = VectorExpression

0 comments on commit 1f0fecf

Please sign in to comment.