Added JumpStepWrapper #484

Open
wants to merge 3 commits into
base: main

andyElking
Contributor

Hi Patrick,

I factored the jump_ts and step_ts out of the PIDController into JumpStepWrapper (I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:

  1. We always have t1-t0 <= prev_dt (this is checked via eqx.error_if), with strict inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).
  2. If the step was accepted, then next_dt must be >=prev_dt.
  3. If the step was rejected, then next_dt must be < t1-t0.

We achieve this in a very simple way here:

dt_proposal = next_t1 - next_t0
dt_proposal = jnp.where(
    keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal
)
new_prev_dt = dt_proposal
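A minimal plain-Python sketch of how the snippet above enforces rule 2, assuming scalar inputs instead of JAX arrays. The function name `enforce_accept_rule` is hypothetical and does not appear in the PR:

```python
def enforce_accept_rule(keep_step, next_t0, next_t1, prev_dt):
    """Scalar stand-in for the jnp.where/jnp.maximum logic quoted above."""
    dt_proposal = next_t1 - next_t0
    if keep_step:
        # Rule 2: after an accepted step the proposal may not shrink below prev_dt.
        dt_proposal = max(dt_proposal, prev_dt)
    # On a rejected step the inner controller's (smaller) proposal is kept as-is.
    return dt_proposal
```

With an accepted step and an inner proposal shorter than `prev_dt`, the result is raised back to `prev_dt`; with a rejected step the proposal passes through unchanged.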

The next step is to add a parameter JumpStepWrapper.revisit_rejected_steps which does what you expect. That will appear in a future commit in this same PR.

@andyElking
Contributor Author

I have now also added the functionality to revisit rejected steps. In addition, I improved the runtime of step_ts and jump_ts: the controller no longer searches the whole array each time, but keeps an index of where in the array it was previously.

Also I think there was a bug in the PID controller, where it would sometimes reject a step, but have factor>1. To remedy this I modified the following:

factormax = jnp.where(keep_step, self.factormax, self.safety)
factor = jnp.clip(
    self.safety * factor1 * factor2 * factor3,
    min=factormin,
    max=factormax,
)

I think possibly something smaller than just self.safety would make even more sense, I feel like if a step is rejected the next step should be at least 0.5x smaller. But I'm not an expert.
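To make the effect of this change concrete, here is a scalar plain-Python sketch. The function name `clipped_factor` is hypothetical, `raw_factor` stands in for `factor1 * factor2 * factor3`, and the default values are illustrative assumptions only:

```python
def clipped_factor(raw_factor, keep_step, safety=0.9, factormin=0.2, factormax=10.0):
    """Scalar stand-in for the modified clipping; defaults are illustrative."""
    # On a rejected step the upper bound drops to `safety` (< 1),
    # so the next step can never be larger than the rejected one.
    upper = factormax if keep_step else safety
    return min(max(safety * raw_factor, factormin), upper)
```

With these assumed values, a rejected step whose raw factor is 2.0 yields a factor of 0.9 instead of 1.8, so the step shrinks.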

I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of eqx.error_if statements to make sure the necessary invariants are always maintained. But this is a bit experimental, so maybe there are some bugs I didn't test for.

I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere.

P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper.

@andyElking andyElking force-pushed the jump_step_pr branch 2 times, most recently from d022ac1 to 4702380 Compare August 14, 2024 14:53
@andyElking
Contributor Author

Hi @patrick-kidger,
I got rid of some eqx.error_ifs that I added to my JumpStepWrapper and redid the timing benchmarks. My new implementation was already faster than the old PIDController before, but now this is way more significant, especially when step_ts is long (think >100). Surprisingly, it is faster even when it has to revisit rejected steps. See

# ======= RESULTS =======
# New controller: 0.22829 s, Old controller: 0.31039 s
# Revisiting controller: 0.23212 s

Owner

@patrick-kidger patrick-kidger left a comment

Okay, quick first pass at a review!

diffrax/_step_size_controller/adaptive.py
diffrax/_step_size_controller/jump_step_wrapper.py
"Maximum number of rejected steps reached. "
"Consider increasing JumpStepWrapper.rejected_step_buffer_len.",
)
rjct_buff = jnp.where(keep_step, rjct_buff, rjct_buff.at[i_rjct].set(t1))
Owner

Note that this is a very expensive way to describe this operation! You're copying the whole buffer. XLA will sometimes optimize this out -- because I added that optimization to it! -- but not always.

Better is to do rjct_buff.at[i_rjct].set(jnp.where(keep_step, rjct_buff[i_rjct], t1))
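A small JAX sketch comparing the two formulations, to show that they produce the same buffer. The function names are hypothetical; only `jnp.where` and the `.at[...].set(...)` update syntax are relied on:

```python
import jax.numpy as jnp


def record_rejection_copying(buf, i, t1, keep_step):
    # Selects between two whole buffers: the original and an updated copy.
    return jnp.where(keep_step, buf, buf.at[i].set(t1))


def record_rejection_scalar(buf, i, t1, keep_step):
    # Selects between two scalars, then performs a single in-place-style update.
    return buf.at[i].set(jnp.where(keep_step, buf[i], t1))
```

Both return identical arrays; the second only ever selects between scalars, which is the property the review comment is after.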

Owner

Other than that, I think we may need to extend the API here slightly -- we should be able to mark state like this as being a buffer for the purposes of:

final_state = outer_while_loop(

which is needed to avoid spurious copies during backpropagation.

(You can see that both of these comments are basically us having to work around limitations of the XLA compiler.)

Contributor Author

@andyElking andyElking Aug 18, 2024

The first one makes sense, I should have seen this.

I don't really know what you want me to do in your second comment. And frankly diffeqsolve is something I haven't even started digesting yet. Are you telling me rejected_buffer should be one of the outer_buffers, meaning that I should make it an instance of SaveState or something like that? I would appreciate a bit more guidance.

Also damn how you managed to write all this code is beyond me. Even trying to begin understanding it seems a lot! Very impressive!

Owner

@patrick-kidger patrick-kidger Nov 3, 2024

Haha, you're too kind!

As for my second comment -- I think I've realised that I was wrong. Let me explain. Our backpropagation involves saving copies of our state in checkpoints. Let's suppose we set RecursiveCheckpointAdjoint(checkpoints=max_steps), so that's O(max_steps) memory right? Well, not quite: our updating buffer here is potentially of length max_steps (as per the debate above), and we're saving a copy of it in every checkpoint, so we'd actually be using O(max_steps^2) memory! That's not acceptable.

The simple solution to this will just be to set the size of this buffer to e.g. 100 by default, and just allow those copies to be made. And given the behaviour you have here -- in which you potentially overwrite values -- then that is actually what's necessary as well.

As for the complicated solution that I was wrong about: let's consider the case of SaveAt(steps=True). This also involves a buffer of length max_steps, that we save into as we go along. Fortunately, this one has a useful extra property, which is that we never overwrite a value. That means we don't actually need to copy our buffer for every checkpoint! We can use a single buffer that is shared across all checkpoints, getting gradually filled in. To support this case then we actually have a special argument eqxi.while_loop(..., buffers=...), to declare which elements of our loop state have this behaviour. Unfortunately that's not the case here because we do overwrite the values. (And side-note the presence of this buffers parameter is the reason I've not made this public API in Equinox, because the buffer-ness is completely unchecked and it's very easy to shoot yourself in the foot.)
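The memory argument above can be written as back-of-the-envelope arithmetic. This is only a counting sketch: the function name `buffer_entries_stored` is hypothetical, and it counts stored buffer entries, not bytes:

```python
def buffer_entries_stored(num_checkpoints, buffer_len, shared_across_checkpoints):
    """Count buffer entries kept alive for backpropagation."""
    if shared_across_checkpoints:
        # Write-once buffers (the SaveAt(steps=True) case) need a single copy.
        return buffer_len
    # Buffers whose entries get overwritten are copied into every checkpoint.
    return num_checkpoints * buffer_len
```

For example, with 4096 checkpoints an overwritten buffer of length 4096 costs 4096 * 4096 entries, whereas capping the buffer at 100 costs 4096 * 100.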

Contributor Author

@andyElking andyElking Nov 3, 2024

Okay, I see. Thanks for the in-depth explanation! So let's see if I understand this correctly. If this was not getting rewritten, then I should make it register as a buffer in the outer while loop of diffeqsolve. But, because it does get rewritten, I should not do that(??). Still, I am curious, if I did want to register it as a buffer, how would I accomplish that? Is it indeed by making it an instance of SaveState, or is it something else entirely?

Other than that, should I keep it an Optional[Int] and just add something like this to the docstring:

For most SDEs, setting this to `100` should be sufficient, but if more consecutive steps are rejected, then an error will be raised.

?

Owner

Yup that's correct. If you have an array that is getting filled-in then registering it as a buffer will mean that a single copy is used across all checkpoints. (The checkpoints in eqxi.while_loop, for later backprop.)

But as we're overwriting values here then we actually must keep separate copies of it in each checkpoint, just like every other value that changes from step-to-step of the while loop.

To register something as a buffer it must be specified in eqxi.while_loop(..., buffers=....), see here:

cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers

As for the message you've suggested -- this looks good to me.

diffrax/_step_size_controller/jump_step_wrapper.py
@andyElking
Contributor Author

Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of prev_dt entirely, as you suggested in #483?

@patrick-kidger
Owner

> Also, should I get rid of prev_dt entirely, as you suggested in #483?

If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D

Owner

@patrick-kidger patrick-kidger left a comment

Okay, I'm really sorry for taking so long to get around to this one! Some other work projects got in the way for a bit. (But on the plus side I have a few more open source projects in the pipe, keep an eye out for those ;) ) This is a really useful PR that I very much want to see in.

I've just done another review, LMK what you think!

diffrax/_step_size_controller/adaptive_base.py
diffrax/_step_size_controller/jump_step_wrapper.py

def _get_t(i: IntScalarLike, ts: Array) -> RealScalarLike:
    i_min_len = jnp.minimum(i, len(ts) - 1)
    return jnp.where(i == len(ts), jnp.inf, ts[i_min_len])
Owner

Given the inf here -- can you add a test for using this with a backward solve with t0 > t1? Just to make sure that we're correctly handling that case.

Contributor Author

For each of the tests (except backprop) in test_adaptive_stepsize_controller.py I added @pytest.mark.parametrize("backwards", [False, True]) which just swaps the order of t0 and t1. This also revealed two other problems:

  1. step_ts and jump_ts need to be sorted and in particular need to be re-sorted after multiplying them with direction in wrap. I added sorting both in JSW.__init__ and JSW.wrap.
  2. VBT complains if t0 >= t1. I don't think this is necessarily a problem but it could be confusing to some users, so let me know if you want to revisit that design decision and I can try to do something about it.

But other than that it all seems to work perfectly when solving backwards. To be fair the way you dealt with the whole backwards solve business is very clean and makes this work perfectly without any alterations.
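A plain-Python sketch of why re-sorting is needed after multiplying by `direction`, and why the `inf` sentinel still works for a backward solve. It assumes `direction` is `+1` or `-1`; the names `wrap_ts` and `get_t` are hypothetical stand-ins for the logic in `wrap` and `_get_t`:

```python
import math


def wrap_ts(ts, direction):
    # Multiplying by direction = -1 reverses the order, so sort again.
    return sorted(direction * t for t in ts)


def get_t(i, ts):
    # Past-the-end index maps to +inf, i.e. "no further time to clip to".
    return math.inf if i == len(ts) else ts[min(i, len(ts) - 1)]
```

After wrapping, the times are increasing in the solver's internal time regardless of direction, so `+inf` is still "later than everything" when the index runs off the end.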

Owner

Okay, great! Then I'm very glad to have these tests :D

For VBT I think I'm happy with either approach -- autoswitching t0, t1 or just raising an error. No strong feelings on what is the better UX.

test/test_progress_meter.py
diffrax/_step_size_controller/pid.py
diffrax/_step_size_controller/jump_step_wrapper.py
keep_step,
next_t0,
next_t1,
_,
Owner

I don't think discarding here is correct. We should do the right thing even if we have a doubly-nested JumpStepWrapper(JumpStepWrapper(PIDController(...), ...), ...).

Contributor Author

@andyElking andyElking Nov 28, 2024

I've given this a hard think and I think it is nearly impossible to do this perfectly. We can maintain correctness of the DE solution by allowing made_jump to sometimes be a false positive (e.g. the inner JSW recorded jump_next_step=True in the previous step, but then the outer JSW further clipped the proposal, so the jump didn't actually happen). That just makes us do one unnecessary VF evaluation at the phantom jump point, but if I understand correctly the final solution should still be correct. WDYT?

Owner

Following on from my main comment below -- perhaps if this jump-next-step business is handled in integrate.py, then we can make each stepsize controller just report whether or not they clipped the step, with these wrappers just |ing things together on whether something was clipped.

Contributor Author

I handled this all inside the JSW in the new commit and I wrote a long comment about why I think the implementation is correct. But please do double check it.

if step_ts is not None:
    # If we stepped to `t1 == step_ts[i_step]` and kept the step, then we
    # increment i_step and move on to the next t in step_ts.
    step_inc_cond = keep_step & (t1 == _get_t(i_step, step_ts))
Owner

I don't think I'm comfortable with this == between floating point numbers.

More generally speaking I think there could be cases in which the times being passed here do not perfectly align with the times that the adaptive step size controller suggested on the previous step (e.g. because of further wrapping of the step size controller), so I think this kind of logic is wrong anyway. I think you need something more like the jump_ts branch below, where you just want to snap i_step to the correct value. (Noting that the correct value should be determinable statically, we only have state here for efficiency purposes.)

Contributor Author

@andyElking andyElking Dec 5, 2024

I implemented a linear search as you suggest and now use it to determine i_step and i_jump.

However, I think that for the rejected step business using linear search is incorrect, because t1 should never be greater than rejected_t = _get_t(i_reject, rejected_buffer). So depending on what you think is more appropriate we can either use == or jnp.isclose(t1, rejected_t, atol=1e-12). I also added an eqxi.error_if (see below), but I can remove it if you don't think it is necessary. I could also add a parameter that only activates this callback when we are doing tests. Let me know.

i_reject = eqx.error_if(
    i_reject,
    t1 > rejected_t,
    "Jumped over a rejected time. Please report this as a bug.",
)

Owner

FWIW if you want error_if here then I'd make it a testing-only branch. During regular user runtime I try to avoid using it for these kinds of asserts as it's actually quite slow.

Comment on lines 348 to 350
next_jump_t = _get_t(i_jump, jump_ts)
jump_inc_cond = keep_step & (t1 >= eqxi.prevbefore(next_jump_t))
i_jump = jnp.where(jump_inc_cond, i_jump + 1, i_jump)
Owner

Likewise, if someone else is fiddling with the times then this seems to me like it might be fragile.

Recall the previous implementation with its inefficient use of searchsorted. I think the robust approach here might be to write something with the same API as that, but whose implementation is just a simple linear search forwards or backwards from the current position (which is a 'hint' about where to start searching).

Most of the time that will just iterate once and be done, as here. But in the edge cases it should now do the right thing.
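The hinted search described here can be sketched in plain Python. This is an illustration under assumptions, not the PR's implementation: the name `hinted_search` is hypothetical, and it is written to return the same index as `bisect.bisect_right` (the first position whose time is strictly greater than `t`):

```python
import bisect


def hinted_search(ts, t, hint):
    """Index of the first element of sorted `ts` that is > t, starting at `hint`."""
    i = min(max(hint, 0), len(ts))
    # Walk forwards while the current element is not yet past t.
    while i < len(ts) and ts[i] <= t:
        i += 1
    # Walk backwards if the hint overshot (e.g. times were altered elsewhere).
    while i > 0 and ts[i - 1] > t:
        i -= 1
    return i


ts = [0.5, 1.0, 2.0, 3.5]
same_as_bisect = all(
    hinted_search(ts, t, hint) == bisect.bisect_right(ts, t)
    for t in (0.0, 1.0, 1.5, 4.0)
    for hint in range(len(ts) + 1)
)
```

When the hint is already correct, neither loop body runs; a stale hint costs only as many iterations as it is off by.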

@andyElking
Contributor Author

Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship.

@andyElking
Contributor Author

I am very confused about what the correct value of made_jump should be when the step was rejected. By my understanding the original PID controller also got this wrong. However, I'm not sure whether getting this wrong is actually a significant issue, so I want your thoughts on it.

Suppose there is a jump at t=2. I will present 2 possible scenarios, in both of which I think something goes wrong (although maybe diffeqsolve might correct for the issue in scenario B). I wrote them as if JSW and the controller are separate, but the same holds for just the old PID controller.

====== scenario A =======

  1. We start with a step from t0 = 0 to t1 = 1, the controller decides the step will be kept and makes a proposal next_t0 = 1, next_t1=2.5 which gets clipped to next_t1 = 2 and so the JSW writes jump_next_step=True in its state.
  2. We take the next step from t0 = 1 to t1 = 2. Suppose the step is rejected, so next_t0 = 1. But the way the code works now, it will still set made_jump = previous_state.jump_next_step, which has been set to True. But that is wrong because made_jump should reflect whether there is a jump at next_t0 so that the solver knows whether to reevaluate the VF at the start of the next step. Still this is not horrible, since it just causes one extra evaluation of the VF, but the full solution should still be correct. Right?

====== scenario B =======

  1. We start with a step from t0 = 0 to t1 = 1, the controller decides the step will be kept and makes a proposal next_t0 = 1, next_t1=2.5 which gets clipped to next_t1 = 2 and so the JSW writes jump_next_step=True in its state.
  2. We take the next step from t0 = 1 to t1 = 2. Suppose the step is accepted and the controller proposes next_t0 = 2 and next_t1 = 4. The controller returns (correctly) made_jump = previous_state.jump_next_step (=True) and ControllerState(..., jump_next_step=False).
  3. We take the next step from t0 = 2 to t1 = 4. Note that because made_jump was set to True last iteration, the FSAL VF was reevaluated at t = 2. Suppose the step is rejected and the new proposal is next_t0 = 2 and next_t1 = 3. The controller returns made_jump = previous_state.jump_next_step (=False). This is a problem because there is a jump at next_t0=2, so made_jump should have been True. However that might be fine if the value of the VF which was recomputed at the start of this step is kept for the next step (but I remember thinking that solver state is discarded when the step is rejected, so the recomputed VF also gets discarded).

Another way of seeing this all is through this:

  • jump_next_step = (there is a jump at next_t1)
  • previous_state.jump_next_step = (there is a jump at t1)
  • made_jump should be True iff there is a jump at next_t0

Hence setting made_jump = previous_state.jump_next_step only works when the current step is being accepted, so next_t0 = t1. When rejecting the step, the new made_jump should be equal to the previous step's made_jump which we no longer have access to. The solution would be to add yet another thing to the state, but first I wanted to confirm with you to see if I'm getting this right. What are your thoughts?

@patrick-kidger
Owner

patrick-kidger commented Nov 29, 2024

So I think the made_jump-on-rejected-step is handled through this line, outside the stepsize controller:

made_jump = static_select(keep_step, made_jump, state.made_jump)

I made the decision to handle some of the step-rejection logic in the main integrate.py loop, on the basis that for those pieces it should be the same for all stepsize controllers.

So I think this is fine? Do double-check my logic though! :p
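A plain-Python replay of scenario B from earlier in the thread, to check this logic. The helper `select_made_jump` is a hypothetical scalar stand-in for the `static_select` line quoted above:

```python
def select_made_jump(keep_step, controller_made_jump, previous_made_jump):
    # On a rejected step next_t0 is unchanged, so the flag recorded after the
    # last accepted step is still the right one.
    return controller_made_jump if keep_step else previous_made_jump


# Scenario B, with a jump at t=2.
made_jump = False                                       # initial state
made_jump = select_made_jump(True, False, made_jump)    # step 0 -> 1, accepted
made_jump = select_made_jump(True, True, made_jump)     # step 1 -> 2, accepted
after_accept = made_jump
made_jump = select_made_jump(False, False, made_jump)   # step 2 -> 4, rejected
after_reject = made_jump
```

After the rejected step the flag is still `True`, which matches the requirement that `made_jump` reflects a jump at `next_t0 = 2`. In scenario A the same selection discards the controller's stale `True` on rejection and keeps the previous `False`.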

Other than that, one thing I am noticing is that this next_made_jump business is pretty annoying to deal with in the stepsize controller, and I think could probably also be factored out to happen in the main integrate.py loop. (Not sure if that's doable in a backward-compatible way though.)

@andyElking
Contributor Author

andyElking commented Dec 3, 2024

Great, that's exactly the line I was looking for (I must admit I looked in _integrate.py only briefly). But yes this is precisely the piece of the puzzle that my logic was missing, so I think we both independently arrived at the same conclusion that this is probably correct.

Thinking about it now, the next_made_jump situation might be way easier than I thought it was. I'll try to write the code and then I'll write a proof that it does the right thing for any number of stacked JSWs. Hopefully I can be done with that tomorrow.

Edit: I already implemented what I mentioned above and wrote the proof in a comment. If you're curious and have extra time (yes I know that's a very far tail event :)) you can find it on my pr_correction branch here. I haven't fixed all the other things yet, so I'll push everything together once it's all done.

@andyElking
Contributor Author

andyElking commented Dec 5, 2024

Hi Patrick!

I just pushed a new version of this PR, rebased on top of the most current main. I think I addressed everything you asked me to fix.

As it stands this contains 3 commits, containing:

  1. Everything except 2. and 3.
  2. The changes to at_dtmin and factormax in pid.py
  3. Removing prev_dt from JSW.

I left some conversations unresolved. I did try to fix the things mentioned in those, but I am not sure whether what I did was the best way to tackle that so I wanted to hear your opinion.

Also the tests are failing because pyright doesn't know how to import typeguard, which has nothing to do with my changes.

PS: The linear search I added slows it down compared to the way I wrote it before, but it is still faster than the old implementation with binary search. In particular the times (as obtained by benchmarks/jump_step_timing.py) are as follows:

  • my previous implementation: 0.263 (without revisit rejected), 0.285 (with revisit rejected)
  • new implementation (linear search): 0.295 (without revisit rejected), 0.315 (with revisit rejected)
  • old pid with binary search: 0.332

Additionally, changing the length of rejected_buffer from 10 to 4096 results in a negligible slowdown. However, this might change depending on the problem setup.

Owner

@patrick-kidger patrick-kidger left a comment

Okay! I think I really like this.

First of all, I think I'm basically happy with pretty much everything outside of jump_step_wrapper.py. The changes here are pleasingly simple ^^

For jump_step_wrapper.py, I think my main question is around whether the rejected-step-buffer should actually be part of this wrapper at all -- since that handles SDEs with any kind of step rejection, which I think is completely orthogonal to clipping steps? (Not sure how I didn't notice this before!) I've also commented on a few other more minor points.

By the way, what did you think of the idea of moving next_made_jump into _integrate.py? It doesn't have to be now -- happy for that to be a separate PR -- just checking your thoughts on whether it is a generalisable thing.

Finally: merry Christmas, and a happy new year! :D

Comment on lines +26 to +36
```bibtex
@misc{foster2024convergenceadaptiveapproximationsstochastic,
    title={On the convergence of adaptive approximations for stochastic differential equations},
    author={James Foster and Andraž Jelinčič},
    year={2024},
    eprint={2311.14201},
    archivePrefix={arXiv},
    primaryClass={math.NA},
    url={https://arxiv.org/abs/2311.14201},
}
```
Owner

Nice! I'm really glad to be advertising this paper, it needs to be better-known.

(Nit: whilst the ```bibtex is fine, the contents are currently all indented one step too far.)

module_meta = type(eqx.Module)


class PIDMeta(module_meta):
Owner

Nit: this should be private _PIDMeta. (Also similar for module_meta, although you could also just inline class _PIDMeta(type(eqx.Module)))



_ControllerState = TypeVar("_ControllerState")
_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike])
Owner

You want TypeVar("_Dt0", bound=Optional[RealScalarLike]) here.

  1. When you have TypeVar("T", Foo, Bar), then it indicates that T must be filled by precisely Foo or Bar, and in particular not a subclass.

  2. Meanwhile TypeVar("T", bound=Union[Foo, Bar]) indicates that any subclass of Union[Foo, Bar] is acceptable -- in particular this includes both Foo (which is a subclass of the union type) and Bar (which is also a subclass of the union type), but also includes any subclass of Foo and Bar themselves.

The reason this matters is that RealScalarLike is itself a union (which if you haven't thought about it before is essentially an anonymous ABC), so by definition no instance can ever have that type! It means that version (1) above can essentially never be useful.

(Okay, I'm lying a little bit: ever-so-technically the static type checker could take a value with a known concrete type T and pretend that it has type Union[T, S], and it actually will do so sometimes... but this is pretty fragile so I never rely on this working in practice.)

Comment on lines +515 to 516
at_dtmin = at_dtmin | (prev_dt <= self.dtmin)
keep_step = keep_step | at_dtmin
Owner

Actually, does at_dtmin need to be state? (I'm not sure it ever did.) I think we might just be able to have keep_step = keep_step | (prev_dt <= self.dtmin)?

Comment on lines +133 to +140
The `step_ts` and `jump_ts` are used to force the solver to step to certain times.
They mostly act in the same way, except that when we hit an element of `jump_ts`,
the controller must return `made_jump = True`, so that the diffeqsolve function
knows that the vector field has a discontinuity at that point, in which case it
re-evaluates it right after the jump point. In addition, the
exact time of the jump will be skipped using eqxi.prevbefore and eqxi.nextafter.
So now to the explanation of the two (we will use `step_ts` as an example, but the
same applies to `jump_ts`):
Owner

I often rewrite parts of docs after merging anyway, so feel free to ignore this for now -- but just a heads-up that this part is discussing a lot of implementation details: made_jump = True and eqxi.{prevbefore,nextafter} are not details familiar to most users.

Comment on lines +93 to +94
i = jax.lax.while_loop(cond_up, lambda _i: _i + 1, i)
i = jax.lax.while_loop(cond_down, lambda _i: _i - 1, i)
Owner

Why do we have both of these loops? I think we only need a linear search in one direction: to find the next element of ts to clip to?

(And if we do need a bidirectional search, then given a hint n it's probably more efficient to search e.g. n / n+1 / n-1 / n+2 / n-2 / ... etc back and forth?)


# This is just a logging utility for testing purposes
if self.callback_on_reject is not None:
    jax.debug.callback(self.callback_on_reject, keep_step, t1)
Owner

I might suggest making this a pure_callback or io_callback, so that it will definitely be called in the right order across steps. JAX doesn't actually offer guarantees about the order in which multiple debug callbacks are called.

See for example how eqx.error_if works, which does the same thing by requiring a token.

(There is actually jax.debug.callback(..., ordered=True), but this works by having JAX sneakily rewrite the jaxpr to thread a dummy argument through as a token so as to order things... and I think that has edge cases, so I try to avoid it.)

next_t0, RealScalarLike
), f"type(next_t0) = {type(next_t0)}"
else:
isinstance(next_t0, get_args(RealScalarLike))
Owner

This line doesn't actually do anything? It just creates a False or True and then does nothing. Are you missing an assert?
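The difference is easy to demonstrate in isolation. In this plain-Python sketch `RealLike` is a hypothetical stand-in for `RealScalarLike` and both function names are invented:

```python
from typing import Union, get_args

# Hypothetical stand-in for diffrax's RealScalarLike union.
RealLike = Union[int, float]


def check_silently(x):
    # The boolean is computed and immediately discarded: nothing is checked.
    isinstance(x, get_args(RealLike))


def check_loudly(x):
    # With `assert`, a wrong type raises AssertionError with the message.
    assert isinstance(x, get_args(RealLike)), f"type(x) = {type(x)}"
```

Calling `check_silently("not a number")` returns normally, whereas `check_loudly("not a number")` raises (unless Python is run with `-O`, which strips asserts).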

if TYPE_CHECKING: # if i don't seperate this out pyright complains
    assert isinstance(
        next_t0, RealScalarLike
    ), f"type(next_t0) = {type(next_t0)}"
Owner

Note that the string here will never appear as it's inside a static-type-checking-only block.

Comment on lines +430 to +444
# Let's prove that the line below is correct. Say the inner controller is
# itself a JumpStepWrapper (JSW) with some inner_jump_ts. Then, given that
# it propsed (next_t0, original_next_t1), there cannot be any jumps in
# inner_jump_ts between next_t0 and original_next_t1. So if the next_t1
# proposed by the outer JSW is different from the original_next_t1 then
# next_t1 \in (next_t0, original_next_t1) and hence there cannot be a jump
# in inner_jump_ts at next_t1. So the jump_at_next_t1 only depends on
# jump_at_next_t1.
# On the other hand if original_next_t1 == next_t1, then we just take an
# OR of the two.
jump_at_next_t1 = jnp.where(
    next_t1 == original_next_t1,
    jump_at_original_next_t1,
    jump_at_next_t1 | jump_at_original_next_t1,
)
Owner

Hmm, I don't think I completely believe this. Can we have the following:

  • the PID controller proposes t1.
  • the inner JSW wants to clip to a jump b < t1.
  • the outer JSW wants to clip to a step (not a jump!) a < b?

In this case then we will have next_t1 != original_next_t1, and jump_at_original_next_t1 == True... but overall we want made_jump == False?

Owner

+can we have a test for two nested JSWs, including the above scenario? It doesn't need to be a full diffeqsolve, just directly calling adapt_step_size and checking that we get the right output.
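The expected outcome of the scenario above can be pinned down with a plain-Python model of nested clipping. This is a sketch of the desired behaviour only, not the PR's implementation and not a test of `adapt_step_size`; the name `combine_clips` and the `(time, is_jump)` representation are hypothetical, and exact float comparison is used purely for brevity:

```python
def combine_clips(proposed_t1, clips):
    """Apply wrappers from innermost to outermost.

    `clips` is a list of (clip_time, is_jump) pairs, one per wrapper.
    Returns (next_t1, jump_at_next_t1).
    """
    next_t1 = proposed_t1
    jump = False
    for clip_t, is_jump in clips:
        if clip_t < next_t1:
            # This wrapper clips further: any earlier jump flag no longer applies.
            next_t1 = clip_t
            jump = is_jump
        elif clip_t == next_t1:
            # Same time: a jump from either wrapper counts.
            jump = jump or is_jump
    return next_t1, jump
```

For the scenario in question (PID proposes 3.0, inner jump at b = 2.0, outer step at a = 1.0) this model gives `next_t1 = 1.0` with no jump, which is the output such a test would need to assert.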

This was referenced Jan 1, 2025