Skip to content

BUG: integrator reverses direction when diffrax.ClipStepSizeController and event detection collide #713

@FFroehlich

Description

@FFroehlich

When diffrax is set up to track roots with positive direction only (this is on the dev branch), and the integration starts exactly at a root with negative direction while diffrax.ClipStepSizeController is also tracking the location of that same root, diffrax integrates backwards in time and reports that root as a root with positive direction (notice that 9.7999999999999765e-01 < 9.7999999999999987e-01). The code example could probably be reduced further, but it was sufficient to identify diffrax.ClipStepSizeController as the culprit. Segment 1 runs the simulation up to the root; Segment 2 is where diffrax starts integrating backwards:

Segment 1: t_start = 0.0000000000000000e+00, t_end = 1.0000000000000000e+00
Segment 1: Event detected at t = 9.7999999999999987e-01
Segment 2: t_start = 9.7999999999999987e-01, t_end = 2.0000000000000000e+00
Segment 2: Event detected at t = 9.7999999999999765e-01

With

controller = (
        diffrax.ClipStepSizeController(
            controller,
            jump_ts=[0.98],
        )
    )

removed, everything works as expected:

Segment 1: t_start = 0.0000000000000000e+00, t_end = 1.0000000000000000e+00
Segment 1: Event detected at t = 9.7999999999999998e-01
Segment 2: t_start = 9.7999999999999998e-01, t_end = 2.0000000000000000e+00
Segment 2: No event detected

Code:

import jax

from typing import Callable
from optimistix import AbstractRootFinder

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import diffrax
import equinox as eqx
import optimistix

def _run_segment(
    t_start: float | jax.Array,
    t_end: float,
    y0,
    solver: diffrax.AbstractSolver,
    controller: diffrax.AbstractStepSizeController,
    root_finder: AbstractRootFinder,
    max_steps: int,
    adjoint: diffrax.AbstractAdjoint,
    cond_fns: Callable | list[Callable],
    cond_dirs: None | bool | list[None | bool],
    saveat: diffrax.SaveAt,
    term: diffrax.ODETerm,
    jump_ts: tuple[float, ...] | list[float] | None = (0.98,),
) -> diffrax.Solution:
    """Integrate ``term`` from ``t_start`` to ``t_end`` with event detection.

    Args:
        t_start: Initial time of the segment (a Python float or a JAX scalar,
            e.g. the last saved time of a previous segment).
        t_end: Final time of the segment.
        y0: Initial state.
        solver: diffrax solver used for the integration.
        controller: Base step-size controller.
        root_finder: Root finder handed to ``diffrax.Event`` to locate events.
        max_steps: Maximum number of solver steps.
        adjoint: Adjoint method passed through to ``diffrax.diffeqsolve``.
        cond_fns: Event condition function (or list of them), passed to
            ``diffrax.Event`` as ``cond_fn``.
        cond_dirs: Event direction(s), passed to ``diffrax.Event`` as
            ``direction``.
        saveat: Save configuration for the solve.
        term: ODE term to integrate.
        jump_ts: Times handed to ``diffrax.ClipStepSizeController`` (which
            wraps ``controller``) as ``jump_ts``. Defaults to ``(0.98,)``, the
            root time used by this reproducer. Pass ``None`` to skip the
            wrapper and use ``controller`` unchanged.

    Returns:
        The ``diffrax.Solution``. ``throw=False`` is passed, so solver
        failures are not raised here; inspect ``sol.result`` instead.
    """
    if jump_ts is not None:
        # Clip steps onto the known discontinuity. Per the issue report, this
        # wrapper is what triggers the backwards integration in segment 2.
        controller = diffrax.ClipStepSizeController(
            controller,
            jump_ts=list(jump_ts),
        )

    # Combine all discontinuity conditions into a single diffrax.Event.
    event = diffrax.Event(
        cond_fn=cond_fns,
        root_finder=root_finder,
        direction=cond_dirs,
    )

    return diffrax.diffeqsolve(
        term,
        solver,
        t0=t_start,
        t1=t_end,
        dt0=None,  # let the step-size controller choose the initial step
        y0=y0,
        stepsize_controller=controller,
        max_steps=max_steps,
        adjoint=adjoint,
        saveat=saveat,
        event=event,
        throw=False,
    )


# ============================================================================
# Model 00979 functions (recreated from generated model code)
# ============================================================================

def _xdot(t, x, args):
    """ODE dynamics for model 00979 (all zero)."""

    damici_ydt = 0.0
    damici_xdt = 0.0
    damici_pdt = 0.0
    dqdt = 0.0

    return jnp.array([damici_ydt, damici_xdt, damici_pdt, dqdt])


def _root_cond_fn_combined(t, y, args, **_):
    """Combined root condition function that returns all roots as a vector."""

    eroot0 = t - 49/50      # t - 0.98
    eroot1 = 49/50 - t      # 0.98 - t

    return jnp.array([eroot0, eroot1])


def _x0(t, p):
    """Initial state."""
    x00 = 0.0
    x01 = 0.0
    x02 = 0.0
    x03 = 0.0
    return jnp.array([x00, x01, x02, x03])


# ============================================================================
# Main test
# ============================================================================

def _segment_end_time(sol: diffrax.Solution) -> float:
    """Return the last saved time of ``sol`` as a Python float.

    Falls back to ``sol.t1`` when no times were saved. With
    ``SaveAt(t1=True)`` this is where the solve stopped, and therefore where
    the next segment starts.
    """
    if sol.ts is not None and len(sol.ts) > 0:
        return float(sol.ts[-1])
    return float(sol.t1)


def _report_event(label: str, sol: diffrax.Solution) -> None:
    """Print whether the solve for segment ``label`` terminated on an event."""
    if diffrax.is_event(sol.result):
        print(f"{label}: Event detected at t = {_segment_end_time(sol):.16e}")
    else:
        print(f"{label}: No event detected")


def main():
    """Run the two-segment reproducer and return ``(sol1, sol2)``.

    Segment 1 integrates over [0, 1] watching root 0 (``t - 0.98``) with
    ``direction=True``. Segment 2 restarts where segment 1 stopped and
    integrates to 2.0, watching root 1 (``0.98 - t``), also with
    ``direction=True``.
    """
    # Model parameters
    x0 = _x0(0.0, None)

    # Create solver and controller
    atol = 1e-4
    rtol = 1e-4
    tol_factor = 1e2
    solver = diffrax.Kvaerno5()
    controller = diffrax.PIDController(
        rtol=rtol / tol_factor,
        atol=atol / tol_factor,
        pcoeff=0.0,
        icoeff=0.3,
        dcoeff=0.4,
    )
    root_finder = optimistix.Newton(atol=atol, rtol=rtol)

    # Create ODE term
    term = diffrax.ODETerm(_xdot)

    def _root_cond_fn_event(ie: int, t: float, y, args: tuple, **kwargs):
        """Return component ``ie`` of the combined root-condition vector."""
        rval = _root_cond_fn_combined(t, y, args, **kwargs)
        return rval.at[ie].get()

    # ========================================================================
    # Segment 1: t in [0.0, 1.0]
    # ========================================================================

    t_start_1, t_end_1 = 0.0, 1.0
    print(f"Segment 1: t_start = {t_start_1:.16e}, t_end = {t_end_1:.16e}")

    sol1 = _run_segment(
        t_start=t_start_1,
        t_end=t_end_1,
        y0=x0,
        solver=solver,
        controller=controller,
        root_finder=root_finder,
        max_steps=10,
        adjoint=diffrax.DirectAdjoint(),
        cond_fns=eqx.Partial(_root_cond_fn_event, 0),
        cond_dirs=True,
        saveat=diffrax.SaveAt(t1=True),
        term=term,
    )
    _report_event("Segment 1", sol1)

    # ========================================================================
    # Segment 2: t in [t0_next, 2.0]
    # ========================================================================

    # Use the same start time that is printed, so the log matches the solve.
    # float() of a float64 scalar is exact, so no precision is lost (x64 is on).
    t0_next = _segment_end_time(sol1)
    t_end_2 = 2.0
    print(f"Segment 2: t_start = {t0_next:.16e}, t_end = {t_end_2:.16e}")

    sol2 = _run_segment(
        t_start=t0_next,
        t_end=t_end_2,
        y0=sol1.ys[-1],
        solver=solver,
        controller=controller,
        root_finder=root_finder,
        max_steps=10,
        adjoint=diffrax.DirectAdjoint(),
        cond_fns=eqx.Partial(_root_cond_fn_event, 1),
        cond_dirs=True,
        saveat=diffrax.SaveAt(t1=True),
        term=term,
    )
    _report_event("Segment 2", sol2)

    return sol1, sol2


if __name__ == "__main__":
    # Run the reproducer only when executed as a script, not on import.
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions