Skip to content

Commit

Permalink
Vectorize differenced process (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Sep 18, 2024
1 parent afbb340 commit 6e9fd3f
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 210 deletions.
117 changes: 115 additions & 2 deletions pyrenew/math.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
Helper functions for doing analytical
and/or numerical calculations about
a given renewal process.
and/or numerical calculations.
"""

from __future__ import annotations

import jax.numpy as jnp
from jax.lax import broadcast_shapes, scan
from jax.typing import ArrayLike

from pyrenew.distutil import validate_discrete_dist_vector
Expand Down Expand Up @@ -172,3 +172,116 @@ def get_asymptotic_growth_rate(
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
0
]


def integrate_discrete(
init_diff_vals: ArrayLike, highest_order_diff_vals: ArrayLike
) -> ArrayLike:
"""
Integrate (de-difference) the differenced process,
obtaining the process values :math:`X(t=0), X(t=1), ... X(t)`
from the :math:`n^{th}` differences and a set of
initial process / difference values
:math:`X(t=0), X^1(t=1), X^2(t=2), ... X^{(n-1)}(t=n-1)`,
where :math:`X^k(t)` is the value of the :math:`n^{th}`
difference at index :math:`t` of the process,
obtaining a sequence of length equal to the length of
the provided `highest_order_diff_vals` vector plus
the order of the process.
Parameters
----------
init_diff_vals : ArrayLike
Values of
:math:`X(t=0), X^1(t=1), X^2(t=2) ... X^{(n-1)}(t=n-1)`.
highest_order_diff_vals : ArrayLike
Array of differences at the highest order of
differencing, i.e. the order of the overall process,
starting with :math:`X^{n}(t=n)`
Returns
-------
ArrayLike
The integrated (de-differenced) sequence of values,
of length n_diffs + order, where n_diffs is the
number of highest_order_diff_vals and order is the
order of the process.
"""
inits_by_order = jnp.atleast_1d(init_diff_vals)
highest_diffs = jnp.atleast_1d(highest_order_diff_vals)
order = inits_by_order.shape[0]
n_diffs = highest_diffs.shape[0]

try:
batch_shape = broadcast_shapes(
highest_diffs.shape[1:], inits_by_order.shape[1:]
)
except Exception as e:
raise ValueError(
"Non-time dimensions "
"(i.e. dimensions after the first) "
"for highest_order_diff_vals and init_diff_vals "
"must be broadcastable together. "
"Got highest_order_diff_vals of shape "
f"{highest_diffs.shape} and "
"init_diff_vals of shape "
f"{inits_by_order.shape}"
) from e

highest_diffs = jnp.broadcast_to(highest_diffs, (n_diffs,) + batch_shape)
inits_by_order = jnp.broadcast_to(inits_by_order, (order,) + batch_shape)

highest_diffs = jnp.concatenate(
[jnp.zeros_like(inits_by_order), highest_diffs],
axis=0,
)

scan_arrays = (
jnp.arange(start=order - 1, stop=-1, step=-1),
jnp.flip(inits_by_order, axis=0),
)

integrated, _ = scan(
f=_integrate_one_step, init=highest_diffs, xs=scan_arrays
)

return integrated


def _integrate_one_step(
current_diffs: ArrayLike,
next_order_and_init: tuple[int, ArrayLike],
) -> tuple[ArrayLike, None]:
"""
Perform one step of integration
(de-differencing) for integrate_discrete().
Helper function passed to :func:`jax.lax.scan()`.
Parameters
----------
current_diffs: ArrayLike
Array of differences at the current
de-differencing order
next_order_and_init: tuple
Tuple containing with two entries.
First entry: the next order of de-differencing
(the current order - 1) as an integer.
Second entry: the initial value at
that the next order of de-differencing
as an ArrayLike of appropriate shape.
Returns
-------
tuple[ArrayLike, None]
A tuple whose first entry contains the
values at the next order of (de)-differencing
and whose second entry is None.
"""
next_order, next_init = next_order_and_init
next_diffs = jnp.cumsum(
current_diffs.at[next_order, ...].set(next_init), axis=0
)
return next_diffs, None
180 changes: 56 additions & 124 deletions pyrenew/process/differencedprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import jax.numpy as jnp
from jax.typing import ArrayLike
from numpyro.contrib.control_flow import scan

from pyrenew.math import integrate_discrete
from pyrenew.metaclass import RandomVariable


Expand All @@ -18,6 +18,15 @@ class DifferencedProcess(RandomVariable):
https://otexts.com/fpp3/stationarity.html
for a discussion of differencing in the
context of discrete timeseries data.
Notes
-----
The order of differencing is the discrete
analogue of the order of a derivative in single
variable calculus. A first difference (derivative)
represents a rate of change. A second difference
(derivative) represents the rate of change of that
rate of change, et cetera.
"""

def __init__(
Expand All @@ -33,7 +42,7 @@ def __init__(
----------
fundamental_process : RandomVariable
Stochastic process for the
differences. Should accept an
differences. Must accept an
`n` argument specifying the number
of samples to draw.
differencing_order : int
Expand All @@ -45,104 +54,49 @@ def __init__(
of change), 2 a process on the
2nd differences (rate of change of
the rate of change), et cetera.
**kwargs :
Additional keyword arguments passed to
the parent class constructor.
Returns
-------
None
Notes
-----
The order of differencing is the discrete
analogue of the order of a derivative in single
variable calculus. A first difference (derivative)
represents a rate of change. A second difference
(derivative) represents the rate of change of that
rate of change, et cetera.
"""
self.assert_valid_differencing_order(differencing_order)
self.differencing_order = differencing_order
self.fundamental_process = fundamental_process
self.differencing_order = differencing_order
super().__init__(**kwargs)

def integrate(
self, init_diff_vals: ArrayLike, highest_order_diff_vals: ArrayLike
):
@staticmethod
def assert_valid_differencing_order(differencing_order: any):
"""
Integrate (de-difference) the differenced process,
obtaining the process values :math:`X(t=0), X(t=1), ... X(t)`
from the :math:`n^{th}` differences and a set of
initial process / difference values
:math:`X(t=0), X^1(t=1), X^2(t=2), ... X^{(n-1)}(t=n-1)`,
where :math:`X^k(t)` is the value of the :math:`n^{th}`
difference at index :math:`t` of the process,
obtaining a sequence of length equal to the length of
the provided `highest_order_diff_vals` vector plus
the order of the process.
Parameters
----------
init_diff_vals : ArrayLike
Values of
:math:`X(t=0), X^1(t=1), X^2(t=2) ... X^{(n-1)}(t=n-1)`.
highest_order_diff_vals : ArrayLike
Array of differences at the highest order of
differencing, i.e. the order of the overall process,
starting with :math:`X^{n}(t=n)`
To be valid, a differencing order must
be an integer and must be strictly positive.
This function raises a value error if its
argument is not a valid differencing order.
Parameter
---------
differcing_order : any
Potential differencing order to validate.
Returns
-------
The integrated (de-differenced) sequence of values,
of length n_diffs + order, where n_diffs is the
number of highest_order_diff_vals and order is the
order of the process.
None or raises a ValueError
"""
init_arr = jnp.atleast_1d(init_diff_vals)
diff_arr = jnp.atleast_1d(highest_order_diff_vals)
if not init_arr.ndim == 1:
raise ValueError(
"init_diff_vals must be 1-dimensional "
"array or a scalar. "
f"Got {init_diff_vals}"
)
if not diff_arr.ndim == 1:
if not isinstance(differencing_order, int):
raise ValueError(
"highest_order_diff_vals must be a "
"1-dimensional array or a scalar "
f"Got {highest_order_diff_vals}"
"differencing_order must be an integer. "
f"got type {type(differencing_order)} "
f"and value {differencing_order}"
)
n_inits = init_arr.size
if not n_inits == self.differencing_order:
if not differencing_order >= 1:
raise ValueError(
"Must have exactly as many "
"initial difference values as "
"the differencing order, given "
"in the sequence :math:`X(t=0), X^1(t=1),` "
"et cetera. "
f"Got {n_inits} values "
"for a process of order "
f"{self.differencing_order}"
"differencing_order must be an integer "
"greater than or equal to 1. Got "
f"{differencing_order}"
)

def _integrate_one_step(diffs, scanned):
# numpydoc ignore=GL08
order, init = scanned
new_diffs = jnp.cumsum(diffs.at[order].set(init))
return (new_diffs, None)

integrated, _ = scan(
_integrate_one_step,
init=jnp.pad(diff_arr, (self.differencing_order, 0)),
xs=(
jnp.flip(jnp.arange(self.differencing_order)),
jnp.flip(init_arr),
),
)

return integrated
def validate(self):
"""
Empty validation method.
"""
pass

def sample(
self,
Expand All @@ -161,22 +115,22 @@ def sample(
initial values for the :math:`0^{th}` through
:math:`(n-1)^{st}` differences, passed as the
``init_diff_vals`` argument to
:meth:`DifferencedProcess.integrate()`
:func:`integrate_discrete()`
n : int
Number of values to sample. Will sample
``n - self.differencing_order`` values from
:code:`n - differencing_order` values from
:meth:`self.fundamental_process` to ensure
that the de-differenced output is of length
``n``.
:code`n`.
*args :
Additional positional arguments passed to
:meth:`self.fundamental_process.sample()`
fundamental_process_init_vals : ArrayLike
Initial values for the fundamental process.
Passed as the ``init_vals`` keyword argument
Passed as the :arg:`init_vals` keyword argument
to :meth:`self.fundamental_process.sample()`.
**kwargs : dict, optional
Expand All @@ -193,7 +147,22 @@ def sample(
if n < 1:
raise ValueError("n must be positive. " f"Got {n}")

init_vals = jnp.atleast_1d(init_vals)
n_inits = init_vals.shape[0]

if not n_inits == self.differencing_order:
raise ValueError(
"Must have exactly as many "
"initial difference values as "
"the differencing order, given "
"in the sequence :math:`X(t=0), X^1(t=1),` "
"et cetera. "
f"Got {n_inits} values "
"for a process of order "
f"{self.differencing_order}."
)
n_diffs = n - self.differencing_order

if n_diffs > 0:
diff_samp = self.fundamental_process.sample(
*args,
Expand All @@ -204,42 +173,5 @@ def sample(
diffs = diff_samp
else:
diffs = jnp.array([])
integrated_ts = self.integrate(init_vals, diffs)[:n]
integrated_ts = integrate_discrete(init_vals, diffs)[:n]
return integrated_ts

@staticmethod
def validate():
"""
Validates input parameters, implementation pending.
"""
return None

@staticmethod
def assert_valid_differencing_order(differencing_order: any):
"""
To be valid, a differencing order must
be an integer and must be strictly positive.
This function raises a value error if its
argument is not a valid differencing order.
Parameter
---------
differcing_order : any
Potential differencing order to validate.
Returns
-------
None or raises a ValueError
"""
if not isinstance(differencing_order, int):
raise ValueError(
"differencing_order must be an integer. "
f"got type {type(differencing_order)} "
f"and value {differencing_order}"
)
if not differencing_order >= 1:
raise ValueError(
"differencing_order must be an integer "
"greater than or equal to 1. Got "
f"{differencing_order}"
)
Loading

0 comments on commit 6e9fd3f

Please sign in to comment.