Why are step_ts and jump_ts treated differently here? #483
So I don't think this should ever be an infinite loop, as the next time around then …

I do take your point that if the previous step was rejected, we shouldn't use the …

As for why we continue to use …

Whilst we're here, I will note that there is one other difference between …
Oh right, I missed that.
Sounds good, I'll include that change in the …
I understand and I agree. However:
Thanks, that is very useful to know.
I would summarise the complete behaviour in 3 rules:
These can be implemented in a very simple way:

```python
import equinox as eqx
import jax.numpy as jnp

dt_proposal = factor * (t1 - t0)  # note that if the step is rejected, then factor < 1
# Here comes the clipping between dt_min and dt_max
# eqx.error_if returns prev_dt; reassign it so the check stays in the computation
prev_dt = eqx.error_if(prev_dt, prev_dt < t1 - t0, "prev_dt must be >= t1-t0")
dt_proposal = jnp.where(keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal)
new_prev_dt = dt_proposal  # this goes into controller state as prev_dt
# Here comes the clipping due to step_ts and jump_ts and the whole nextafter(nextafter()) business
```

This has the nice property that it factors well into a controller (which does the first two lines after the imports) and a … WDYT?
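To see how the three-rule proposal behaves on concrete numbers, here is a scalar, plain-Python toy version. It is a sketch under assumptions, not Diffrax code: `propose_next_dt` is a hypothetical helper, ordinary `if`/`max` stand in for `jnp.where`/`jnp.maximum`, a `ValueError` stands in for `eqx.error_if`, and the dt_min/dt_max and step_ts/jump_ts clipping are left out.

```python
def propose_next_dt(t0, t1, prev_dt, factor, keep_step):
    """Toy scalar version of the three-rule proposal (hypothetical helper).

    Returns (dt_proposal, new_prev_dt). Clipping to dt_min/dt_max and to
    step_ts/jump_ts is deliberately omitted.
    """
    dt_taken = t1 - t0
    if prev_dt < dt_taken:
        # Stand-in for the eqx.error_if check: prev_dt must be >= t1 - t0.
        raise ValueError("prev_dt must be >= t1-t0")
    # Base the proposal on the step actually taken, not on prev_dt.
    dt_proposal = factor * dt_taken  # factor < 1 if the step was rejected
    if keep_step:
        # After an accepted step, never propose less than prev_dt.
        dt_proposal = max(dt_proposal, prev_dt)
    new_prev_dt = dt_proposal  # stored in controller state as prev_dt
    return dt_proposal, new_prev_dt


# Rejected step that had been clipped from 0.1 down to 0.01 by a jump.
print(propose_next_dt(0.0, 0.01, prev_dt=0.1, factor=0.5, keep_step=False))
# Accepted step clipped from 0.1 down to 0.001, with a large factor.
print(propose_next_dt(0.0, 0.001, prev_dt=0.1, factor=10.0, keep_step=True))
```

In the rejected case the retry shrinks relative to the step actually taken (0.005 rather than 0.05); in the accepted case the proposal falls back to the intended `prev_dt` instead of shrinking to about 0.01.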
On … I think what you've got sounds reasonable. Mulling it over, I think it should be possible to do something even simpler: change this line … with

```diff
- prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)
+ prev_dt = jnp.where(made_jump & keep_step, prev_dt, t1 - t0)
```

(and if need be I think this still factors apart, as you describe).
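The one-line change above only matters for some combinations of `made_jump` and `keep_step`. The toy sketch below evaluates the old and new versions of that line for all four combinations, assuming an intended step of 0.1 that was clipped to 0.01 (by a jump when `made_jump` is true, otherwise by `step_ts`). `stored_prev_dt` is a hypothetical scalar stand-in for the `jnp.where` expression.

```python
def stored_prev_dt(prev_dt, t0, t1, made_jump, keep_step, patched):
    """Value written back to controller state as prev_dt (toy scalar model).

    patched=False mirrors `jnp.where(made_jump, prev_dt, t1 - t0)`;
    patched=True mirrors `jnp.where(made_jump & keep_step, prev_dt, t1 - t0)`.
    """
    keep_old = (made_jump and keep_step) if patched else made_jump
    return prev_dt if keep_old else t1 - t0


for made_jump in (False, True):
    for keep_step in (False, True):
        old = stored_prev_dt(0.1, 0.0, 0.01, made_jump, keep_step, patched=False)
        new = stored_prev_dt(0.1, 0.0, 0.01, made_jump, keep_step, patched=True)
        print(f"made_jump={made_jump}, keep_step={keep_step}: old={old}, new={new}")
```

Only the row `made_jump=True, keep_step=False` changes: a rejected step that was clipped by a jump now stores the step actually taken (0.01) instead of the intended 0.1. After an accepted jump-clipped step, both versions still keep the intended `prev_dt`.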
Thanks, that's good to know. I will keep the number of …

In practice it seems that your proposal doesn't lead to desirable behaviour. I compared our two approaches on a very simple example ODE and I was surprised how precisely the experiment echoed the issue I described in my first comment: …

In addition, the experiment shows that my solution completely fixes this issue. You can find the experiment here. And here you can see why my proposal makes it easier to separate the …
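The linked experiment is not reproduced here, but the failure mode it is said to echo (the second case in the issue description: a step clipped far below `prev_dt`, then accepted with a large `factor`) can be illustrated with a toy calculation. The numbers are hypothetical, and the sketch assumes, as the issue description does, that the next proposal is `factor * prev_dt` for the variants that keep `prev_dt` in state.

```python
prev_dt = 0.1        # step size the controller intended (hypothetical)
t0, t1 = 0.0, 0.001  # step actually taken after clipping to a jump
factor = 10.0        # large, because the tiny step had a small error
keep_step = True     # the step was accepted
made_jump = True

# Current line and patched line both keep prev_dt here (jump made, step
# accepted), so the proposal scales the intended step, not the step taken.
stored = prev_dt if (made_jump and keep_step) else t1 - t0
proposal_patched = factor * stored

# Issue description: always base the proposal on t1 - t0.
proposal_t1_minus_t0 = factor * (t1 - t0)

# Three-rule proposal: factor * (t1 - t0), but never below prev_dt after acceptance.
base = factor * (t1 - t0)
proposal_three_rules = max(base, prev_dt) if keep_step else base

print(proposal_patched, proposal_t1_minus_t0, proposal_three_rules)
```

Under these assumptions the patched line proposes 1.0, ten times the step the controller originally wanted, whereas the `t1 - t0` rule gives about 0.01 and the three-rule version returns to the intended 0.1.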
Ah, I see what you mean. Okay, in this case I think maybe what we should do is simply remove …

I believe your suggestion amounts to preventing the step size from shrinking after an accepted step. For context, some PID implementations exhibit this behaviour (e.g. torchdiffeq does this), but I recall deciding against this for Diffrax. It's a heuristic that I think helps some problems but hurts others.
Fair enough, I'll get rid of it then. On the flip side, as I mentioned in #484, it seems like it was possible for the step size to increase after a rejected step (I discovered this in some unrelated experiment, where it just seemed to go on forever until max_steps was reached). Was this intentional, or is it good that I now capped …
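The thread does not show exactly what was capped in #484. As a hedged illustration only, one way to guarantee that a rejected step never grows is to clamp the controller's factor to at most 1 whenever the step is rejected. `clamp_factor` is a hypothetical helper; in a JAX controller the same logic would be written with `jnp.where`.

```python
def clamp_factor(factor, keep_step):
    """Hypothetical guard: after a rejection the step may shrink but never grow."""
    if keep_step:
        return factor  # accepted: leave the controller's factor untouched
    return min(factor, 1.0)  # rejected: cap the factor at 1


# A factor above 1 on a rejected step, as reported in the thread, gets capped.
print(clamp_factor(1.3, keep_step=False))
# Accepted steps are unaffected.
print(clamp_factor(1.3, keep_step=True))
```

Capping at exactly 1 still lets a rejected step be retried at the same size; a stricter variant would cap at some value below 1.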
That is definitely pretty weird! I'm willing to believe that it happens, though. In fact, here's something interesting I came across whilst looking at this just now:
It seems we do prevent step shrinking after an accepted step after all! 😅 In light of this, maybe we should fix the case you just mentioned by also adding …

(EDIT: I've now seen that you mentioned this in #484. Ignore me, you're way ahead of me!)
Well, technically this prevents it only from shrinking below …

Haha yes, I was just about to point you to #484 😊
Hi Patrick,

Am I correct in saying that the only differences between `step_ts` and `jump_ts` are the following:

- `jump_ts` cannot be integers, but must be floats (so that you can do `prevbefore` and `nextafter`).
- `_clip_jump_ts` also returns `made_jump`, which is used to determine whether we need to do `_t1 = nextafter(nextafter(t1))`.

But in addition to those discrepancies, it seems Diffrax treats them differently in one other way as well, which I am not sure I understand. Namely, the line below uses `prev_dt=prev_dt` if the step was clipped due to a jump, but `prev_dt=t1-t0` if the step was clipped due to `step_ts`. I don't see why we should make a distinction between these two cases.

diffrax/diffrax/_step_size_controller/adaptive.py, line 561 in b977dce
I would go even further and say that the line should just say `prev_dt=t1-t0` in all cases. This is because the error of the current step depends on `t1-t0` rather than on `prev_dt`, so I feel like keeping `prev_dt` in controller state is not needed. Here is what could go wrong with the current setup:

Say `prev_dt=0.1`, but due to `jump_ts` it was clipped to `t1-t0 = 0.01`. Also assume that the error was large, so the step gets rejected, and that the controller computes `factor=0.5`. Then the next step-size proposal will be `prev_dt*factor = 0.05`, which is bigger than the step that was just taken, so it will again be clipped by `jump_ts` to `0.01`. The same step is then retried with the same error every time, resulting in an infinite loop. Instead, the new step proposal should just be `(t1-t0)*factor = 0.005`, which would presumably result in a smaller error and allow the solve to move forward.

On the other hand, if the step was clipped to a much smaller size than was intended (i.e. `t1-t0 << prev_dt`), then this will usually be reflected in a correspondingly small error, resulting in a large `factor`. This means that `(t1-t0)*factor` would again be a reasonably large step-size proposal, whereas `prev_dt*factor` would be disproportionately large.

Let me know if I missed something.
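The infinite-loop scenario above can be reproduced with a small scalar simulation. This is a toy model under stated assumptions, not Diffrax's controller: a single jump at `t = 0.01`, a hypothetical error test (steps larger than `0.006` are rejected and the controller then asks for `factor = 0.5`), the next proposal computed as `factor * prev_dt` as in the example, no `prevbefore`/`nextafter` adjustments, and `max_steps` as a cap on the loop. The names `simulate`, `current_rule` and `always_t1_minus_t0` are hypothetical.

```python
def simulate(update_prev_dt, prev_dt=0.1, jump=0.01, max_steps=20):
    """Toy model of repeated attempts to step from t0 = 0 towards a jump.

    update_prev_dt(prev_dt, t0, t1, made_jump) returns the prev_dt stored
    after a rejected step. Returns the number of attempts until a step is
    accepted, or None if max_steps is exhausted.
    """
    t0 = 0.0
    factor = 1.0  # hypothetical: the very first proposal is just prev_dt
    for attempt in range(1, max_steps + 1):
        proposal = factor * prev_dt
        t1 = min(t0 + proposal, jump)  # clip to the jump
        made_jump = t1 == jump
        if t1 - t0 <= 0.006:  # hypothetical error test: small steps pass
            return attempt
        factor = 0.5  # rejected: the controller asks for half the step
        prev_dt = update_prev_dt(prev_dt, t0, t1, made_jump)
    return None


def current_rule(prev_dt, t0, t1, made_jump):
    return prev_dt if made_jump else t1 - t0  # keep prev_dt after a jump


def always_t1_minus_t0(prev_dt, t0, t1, made_jump):
    return t1 - t0  # the rule proposed in this issue


print(simulate(current_rule))        # None: 0.05 is clipped back to 0.01 every time
print(simulate(always_t1_minus_t0))  # 2: the retry uses 0.5 * 0.01 = 0.005
```

In this toy loop every step before acceptance is rejected, so the `made_jump & keep_step` variant from the thread would store `t1 - t0` here and behave like `always_t1_minus_t0`.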