Skip to content

Commit

Permalink
WIP: Add test for COOMatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Dec 6, 2023
1 parent d5d4f7c commit e898b7d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/pmhn/_trees/_backend_jax/_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,23 @@

class Values(NamedTuple):
start: Int[Array, " K"]
end: Int[Array, " K"] | None
end: Int[Array, " K"]
value: Float[Array, " K"]

@property
def size(self) -> int:
return self.start.shape[0]


class COOMatrix(NamedTuple):
diagonal: Float[Array, " n_subtrees"]
diagonal: Float[Array, " size"]
offdiagonal: Values
size: int
fill_value: float | Float

@property
def size(self) -> int:
return self.diagonal.shape[0]

def to_dense(self) -> Float[Array, "size size"]:
"""Converts a COO matrix to a dense matrix.
Expand All @@ -28,8 +32,6 @@ def to_dense(self) -> Float[Array, "size size"]:
Depending on the convention used, you may prefer
to transpose it.
"""
# TODO(Pawel): UNTESTED

# Fill the matrix with the fill value
a = jnp.full((self.size, self.size), fill_value=self.fill_value)

Expand All @@ -39,7 +41,7 @@ def _diag_loop_body(
) -> Float[Array, "size size"]:
return a.at[i, i].set(self.diagonal[i])

a = jax.lax.fori_loop(0, self.diagonal.shape[0], _diag_loop_body, a)
a = jax.lax.fori_loop(0, self.size, _diag_loop_body, a)

# Iterate over the off-diagonal terms
def _offdiag_loop_body(
Expand All @@ -49,5 +51,5 @@ def _offdiag_loop_body(
self.offdiagonal.value[i]
)

a = jax.lax.fori_loop(0, self.offdiagonal.size(), _offdiag_loop_body, a)
a = jax.lax.fori_loop(0, self.offdiagonal.size, _offdiag_loop_body, a)
return a
20 changes: 20 additions & 0 deletions tests/trees/backend_jax/test_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import jax.numpy as jnp
import numpy.testing as npt
import pmhn._trees._backend_jax._private_api as api


def test_to_dense() -> None:
mat = api.COOMatrix(
diagonal=1.0 * jnp.arange(1, 4),
offdiagonal=api.Values(
start=jnp.asarray([0, 1]),
end=jnp.asarray([1, 2]),
value=jnp.asarray([8.0, 11.0]),
),
fill_value=0.0,
)

npt.assert_allclose(
mat.to_dense(),
jnp.asarray([[1.0, 8.0, 0.0], [0.0, 2.0, 11.0], [0.0, 0.0, 3.0]]),
)

0 comments on commit e898b7d

Please sign in to comment.