-
-
Notifications
You must be signed in to change notification settings - Fork 166
Open
Labels
bug — Something isn't working
Description
When diffrax is set up to track only roots in the positive direction (`direction=True`; this is on the dev branch), and the integration starts exactly at a root with negative direction while `diffrax.ClipStepSizeController` is also tracking the location of that same root, diffrax integrates backwards in time and reports that root as a root with positive direction (note that 9.7999999999999765e-01 < 9.7999999999999987e-01). The code example could probably be reduced further, but it was sufficient to identify `diffrax.ClipStepSizeController` as the culprit. Segment 1 runs the simulation up to the root; Segment 2 is where diffrax starts integrating backwards:
Segment 1: t_start = 0.0000000000000000e+00, t_end = 1.0000000000000000e+00
Segment 1: Event detected at t = 9.7999999999999987e-01
Segment 2: t_start = 9.7999999999999987e-01, t_end = 2.0000000000000000e+00
Segment 2: Event detected at t = 9.7999999999999765e-01
With
controller = (
diffrax.ClipStepSizeController(
controller,
jump_ts=[0.98],
)
)
removed, everything works as expected (the root is not re-detected in Segment 2):
Segment 1: t_start = 0.0000000000000000e+00, t_end = 1.0000000000000000e+00
Segment 1: Event detected at t = 9.7999999999999998e-01
Segment 2: t_start = 9.7999999999999998e-01, t_end = 2.0000000000000000e+00
Segment 2: No event detected
Full reproduction code:
import jax
from typing import Callable
from optimistix import AbstractRootFinder
# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import diffrax
import equinox as eqx
import optimistix
def _run_segment(
    t_start: float,
    t_end: float,
    y0,
    solver: diffrax.AbstractSolver,
    controller: diffrax.AbstractStepSizeController,
    root_finder: AbstractRootFinder,
    max_steps: int,
    adjoint: diffrax.AbstractAdjoint,
    cond_fns: Callable | list[Callable],
    cond_dirs: bool | None | list[bool | None],
    saveat: diffrax.SaveAt,
    term: diffrax.ODETerm,
    jump_ts: tuple[float, ...] | None = (0.98,),
) -> diffrax.Solution:
    """Integrate ``term`` over ``[t_start, t_end]`` with event detection.

    By default the given ``controller`` is wrapped in a
    ``diffrax.ClipStepSizeController`` that clips steps at ``jump_ts``. Per
    the accompanying bug report, this wrapper is what makes the second segment
    report an event before its own start time. Pass ``jump_ts=None`` to run
    with the bare ``controller`` for comparison.

    Args:
        t_start: Start time of the integration.
        t_end: End time of the integration.
        y0: Initial state.
        solver: Diffrax solver used for the integration.
        controller: Step size controller (wrapped unless ``jump_ts`` is None).
        root_finder: Root finder used by ``diffrax.Event`` to locate events.
        max_steps: Maximum number of solver steps.
        adjoint: Diffrax adjoint method.
        cond_fns: Event condition function, or a list of them.
        cond_dirs: Event direction(s), matching the structure of ``cond_fns``.
        saveat: What to save in the returned solution.
        term: ODE term to integrate.
        jump_ts: Times passed as ``jump_ts`` to the clipping controller, or
            ``None`` to skip the ``ClipStepSizeController`` wrapper entirely.

    Returns:
        The ``diffrax.Solution``. Because ``throw=False``, the outcome
        (including whether an event fired) is reported via ``sol.result``
        instead of being raised.
    """
    if jump_ts is not None:
        controller = diffrax.ClipStepSizeController(
            controller,
            jump_ts=list(jump_ts),
        )
    # combine all discontinuity conditions into a single diffrax.Event
    event = diffrax.Event(
        cond_fn=cond_fns,
        root_finder=root_finder,
        direction=cond_dirs,
    )
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=t_start,
        t1=t_end,
        dt0=None,
        y0=y0,
        stepsize_controller=controller,
        max_steps=max_steps,
        adjoint=adjoint,
        saveat=saveat,
        event=event,
        throw=False,
    )
    return sol
# ============================================================================
# Model 00979 functions (recreated from generated model code)
# ============================================================================
def _xdot(t, x, args):
"""ODE dynamics for model 00979 (all zero)."""
damici_ydt = 0.0
damici_xdt = 0.0
damici_pdt = 0.0
dqdt = 0.0
return jnp.array([damici_ydt, damici_xdt, damici_pdt, dqdt])
def _root_cond_fn_combined(t, y, args, **_):
"""Combined root condition function that returns all roots as a vector."""
eroot0 = t - 49/50 # t - 0.98
eroot1 = 49/50 - t # 0.98 - t
return jnp.array([eroot0, eroot1])
def _x0(t, p):
"""Initial state."""
x00 = 0.0
x01 = 0.0
x02 = 0.0
x03 = 0.0
return jnp.array([x00, x01, x02, x03])
# ============================================================================
# Main test
# ============================================================================
def _final_time(sol: diffrax.Solution) -> float:
    """Return the last saved time of ``sol`` as a Python float.

    Falls back to ``sol.t1`` when no times were saved (``sol.ts`` is ``None``
    or empty), so callers never index into a missing array.
    """
    if sol.ts is not None and len(sol.ts) > 0:
        return float(sol.ts[-1])
    return float(sol.t1)


def _report_event(label: str, sol: diffrax.Solution) -> None:
    """Print whether ``sol`` stopped on an event and, if so, at which time."""
    if diffrax.is_event(sol.result):
        print(f"{label}: Event detected at t = {_final_time(sol):.16e}")
    else:
        print(f"{label}: No event detected")


def main():
    """Run the two-segment reproduction and print the detected events.

    Segment 1 integrates over ``[0, 1]`` and watches root 0 (``t - 0.98``) in
    the positive direction. Segment 2 restarts at the time segment 1 stopped
    and watches root 1 (``0.98 - t``) in the positive direction. That function
    only decreases as ``t`` increases, so during forward integration the
    event is expected not to fire.

    Returns:
        ``(sol1, sol2)``, the ``diffrax.Solution`` of each segment.
    """
    # Model parameters
    x0 = _x0(0.0, None)

    # Create solver and controller
    atol = 1e-4
    rtol = 1e-4
    tol_factor = 1e2
    solver = diffrax.Kvaerno5()
    controller = diffrax.PIDController(
        rtol=rtol / tol_factor,
        atol=atol / tol_factor,
        pcoeff=0.0,
        icoeff=0.3,
        dcoeff=0.4,
    )
    root_finder = optimistix.Newton(atol=atol, rtol=rtol)

    # Create ODE term
    term = diffrax.ODETerm(_xdot)

    def _root_cond_fn_event(ie: int, t, y, args, **_):
        """Select component ``ie`` of the combined root function."""
        rval = _root_cond_fn_combined(t, y, args, **_)
        return rval.at[ie].get()

    # ========================================================================
    # Segment 1: t in [0.0, 1.0]
    # ========================================================================
    print(f"Segment 1: t_start = {0.0:.16e}, t_end = {1.0:.16e}")
    sol1 = _run_segment(
        t_start=0.0,
        t_end=1.0,
        y0=x0,
        solver=solver,
        controller=controller,
        root_finder=root_finder,
        max_steps=10,
        adjoint=diffrax.DirectAdjoint(),
        cond_fns=eqx.Partial(_root_cond_fn_event, 0),
        cond_dirs=True,
        saveat=diffrax.SaveAt(t1=True),
        term=term,
    )
    _report_event("Segment 1", sol1)

    # ========================================================================
    # Segment 2: t in [t0_next, 2.0]
    # ========================================================================
    # Use one value for both the printed and the actual start time, so the
    # log always matches what was integrated (and the fallback is honoured).
    t0_next = _final_time(sol1)
    print(f"Segment 2: t_start = {t0_next:.16e}, t_end = {2.0:.16e}")
    sol2 = _run_segment(
        t_start=t0_next,
        t_end=2.0,
        y0=sol1.ys[-1],
        solver=solver,
        controller=controller,
        root_finder=root_finder,
        max_steps=10,
        adjoint=diffrax.DirectAdjoint(),
        cond_fns=eqx.Partial(_root_cond_fn_event, 1),
        cond_dirs=True,
        saveat=diffrax.SaveAt(t1=True),
        term=term,
    )
    _report_event("Segment 2", sol2)
    return sol1, sol2
if __name__ == "__main__":
    # Run the reproduction only when executed as a script, not on import.
    main()
Metadata
Metadata
Assignees
Labels
bug — Something isn't working