From e898b7dd1e28ae560e8933a629c55034142bb6c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Wed, 6 Dec 2023 19:40:05 +0100 Subject: [PATCH] WIP: Add test for COOMatrix --- src/pmhn/_trees/_backend_jax/_sparse.py | 16 +++++++++------- tests/trees/backend_jax/test_sparse.py | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/pmhn/_trees/_backend_jax/_sparse.py b/src/pmhn/_trees/_backend_jax/_sparse.py index 3748711..3116b4a 100644 --- a/src/pmhn/_trees/_backend_jax/_sparse.py +++ b/src/pmhn/_trees/_backend_jax/_sparse.py @@ -7,19 +7,23 @@ class Values(NamedTuple): start: Int[Array, " K"] - end: Int[Array, " K"] | None + end: Int[Array, " K"] value: Float[Array, " K"] + @property def size(self) -> int: return self.start.shape[0] class COOMatrix(NamedTuple): - diagonal: Float[Array, " n_subtrees"] + diagonal: Float[Array, " size"] offdiagonal: Values - size: int fill_value: float | Float + @property + def size(self) -> int: + return self.diagonal.shape[0] + def to_dense(self) -> Float[Array, "size size"]: """Converts a COO matrix to a dense matrix. @@ -28,8 +32,6 @@ def to_dense(self) -> Float[Array, "size size"]: Depending on the convention used, you may prefer to transpose it. """ - # TODO(Pawel): UNTESTED - # Fill the matrix with the fill value a = jnp.full((self.size, self.size), fill_value=self.fill_value) @@ -39,7 +41,7 @@ def _diag_loop_body( ) -> Float[Array, "size size"]: return a.at[i, i].set(self.diagonal[i]) - a = jax.lax.fori_loop(0, self.diagonal.shape[0], _diag_loop_body, a) + a = jax.lax.fori_loop(0, self.size, _diag_loop_body, a) # Iterate over the off-diagonal terms def _offdiag_loop_body( @@ -49,5 +51,5 @@ def _offdiag_loop_body( self.offdiagonal.value[i] ) - a = jax.lax.fori_loop(0, self.offdiagonal.size(), _offdiag_loop_body, a) + a = jax.lax.fori_loop(0, self.offdiagonal.size, _offdiag_loop_body, a) return a diff --git a/tests/trees/backend_jax/test_sparse.py b/tests/trees/backend_jax/test_sparse.py index e69de29..d316572 100644 --- a/tests/trees/backend_jax/test_sparse.py +++ b/tests/trees/backend_jax/test_sparse.py @@ -0,0 +1,20 @@ +import jax.numpy as jnp +import numpy.testing as npt +import pmhn._trees._backend_jax._private_api as api + + +def test_to_dense() -> None: + mat = api.COOMatrix( + diagonal=1.0 * jnp.arange(1, 4), + offdiagonal=api.Values( + start=jnp.asarray([0, 1]), + end=jnp.asarray([1, 2]), + value=jnp.asarray([8.0, 11.0]), + ), + fill_value=0.0, + ) + + npt.assert_allclose( + mat.to_dense(), + jnp.asarray([[1.0, 8.0, 0.0], [0.0, 2.0, 11.0], [0.0, 0.0, 3.0]]), + )