Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tackle Typing and Linting Errors #379

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open

Tackle Typing and Linting Errors #379

wants to merge 27 commits into from

Conversation

gileshd
Copy link
Collaborator

@gileshd gileshd commented Sep 12, 2024

Summary

This PR is an initial step in an overhaul of the type annotations in the code base.

Details

The aim of the PR is to:

  • respond to some review comments
  • fix some central core typing problems
  • lay the groundwork for further typing improvements
  • update typing, particularly jaxtyping, in hmm code

Further updates to other parts of the codebase will come in subsequent PRs.

Discussion

The state of PRNGKey typing in this PR is a bit inconsistent

  • Sometimes keys are typed as dynamax.types.PRNGKeyT (which is an alias for jaxtyping.Array) and sometimes directly as jaxtyping.Array.
  • Perhaps it would be better pick one and stick with it.
  • Relevant discussion in jax docs here.

There are a lot of commmits here! After review we can perhaps squash many of the HMM typing commits into appropriate groups.

Commit Overview

This is a brief categorisation of the commits in this PR:

  • Fix linting errors:
  • General jaxtyping fix:
  • PRNGKey typing:
    • ca9a75c Fix jr.PRNGKey type hints
    • 75205f8 Rename and change PRNGKey Type
  • Dynamax typing:
  • Misc typing:
    • 7da57ef Minor arg and type changes in utils/utils.py
  • HMM Typing:
    • HMM Core:
      • 62df3ba Update HMM[Parameter|Property]Set protocols
      • 45134e8 Update type annotations in hmm base classes
      • 0d4ca74 Update type annotations in hmm inference code.
      • cfd3bfc Fix type annotations in hmm parallel inference
      • d0a0e0f Add further type annotations to hmm transitions class
      • 4b2c40b Add further type annotations to hmm initial base class
    • HMM Models:
      • 4c84ddc Add further type annotations to categorical hmm
      • 7e04c9e Add further type annotations to arhmm
      • b2fffa4 Add further type annotations to linreghmm
      • a45e61b Add further type annotations to Bernoulli HMM
      • 32db62a Add further type annotations to Gamma HMM
      • 81138f8 Add further type annotations to Gaussian HMMs
      • 4e36520 Add further type annotations to gmhmms
      • 4b7ee53 Add further type annotations to logreg hmm
      • e246606 Add further type annotations to multinomialhmm
      • e19748d Add further type annotations to poisson hmm
      • c2098cf Add further type annotations to categorical glm hmm
  • Other models:
    • 547c610 Fix LinearGaussianSSM.sample type hint
    • 58c9127 Change type hints to jaxtyping in slds code

@gileshd gileshd changed the title Tackle Typing Tackle Typing and Linting Errors Sep 12, 2024
@gileshd gileshd force-pushed the ghd/typing branch 8 times, most recently from 27008b7 to 47fe33f Compare September 24, 2024 15:20
- Rename to PRNGKeyT to differentiate from `jax.random` function
- Change type to Array
  - see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys
- Change default arguments to floats in `monotonically_increasing` function
- Help type-checkers parse `compute_state_overlap` function
- Update `keepdims` default value in `pytree_sum` to match `jnp.sum`
- Change attributes of these protocols to read-only.
  - This change means that instances of `NamedTuple` are compatible with
    these protocols
- Add runtime_checkable decorator to Protocols
  - This allows the Protocolcs to be checked by runtime type checkers
    such as beartype which can be used to verify jaxtyping annotations
Major changes:
- Fix an error in `_compute_all_transition_probs` which caused the
  return array to be too short.
  - The `filtered_probs` were being truncated twice instead of once.
- Replace `jaxtyping.Int` with `dynamax.typing.IntScalar` or `int`
  - this reflects when integer scalar arrays are accepted
  - `jaxtyping.[Dtype]` cannot be used directly for type checking
    instead they must be used as part of an array.
- Fix the shape of `transition_matrix`:
  - if transition_matrix has a leading timestep axis it should be of
    length T-1 not of length T.
- Add annotation indicating that `transition_matrix` is an optional argument
- Raise ValueError when neither `transition_matrix` or `transition_fn`
  provided.
Further, enforce that either key or initial_probs specified to
initialize method.
Further, enforce that either key or initial_probs specified in
initialize method.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant