Parameter Identification with JAX
Rumoca compiles a Modelica model into a JAX-compatible Python module. Because the resulting rollout is fully differentiable, you can fit unknown physical parameters — mass, inertia, friction — directly against measured data using ordinary gradient descent. No finite differences, no surrogate.
The notebook walks through the workflow end-to-end: emit the JAX module from a Modelica source, run a differentiable ODE rollout with diffrax, and minimize a trajectory-matching loss with Adam.
JAX gradient-based parameter identification
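Identify the physical parameters (m, J) of the PVTOL (planar vertical
take-off and landing) model from a reference trajectory by gradient descent
(Adam) on (log m̂, log Ĵ), differentiating through the entire diffrax
integration with jax.value_and_grad. Here the reference trajectory is
simulated from the model's default (ground-truth) parameters, so it is
noise-free, and g is held fixed at its default value.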
Headline: starting with m̂ 50% and Ĵ 40% below their true values, we recover both parameters to within 0.5% in 60 Adam steps.
Export step — Rumoca compiles the Modelica source into the JAX module this notebook imports (run once, already done for you):
rumoca compile model/pvtol.mo -m PVTOL --target jax -o _generated/jax
That writes _generated/jax/PVTOL_jax.py — machine-generated, never hand-edited.
%matplotlib inline
import os, sys, time
from pathlib import Path
# Select the CPU backend; this must run before jax is imported. setdefault
# leaves any JAX_PLATFORMS value already present in the environment untouched.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt
# Make the Rumoca-generated module (_generated/jax/PVTOL_jax.py) importable.
# The path is resolved against the current working directory, so the notebook
# must be run from the directory that contains _generated/.
sys.path.insert(0, str(Path.cwd() / "_generated" / "jax"))
import PVTOL_jax as plant
# Names reported by the generated module. The plotting cell indexes state
# columns with STATE_NAMES.index(...), so STATE_NAMES follows the state
# vector's order; the input/parameter lists presumably do the same for their
# vectors (verify against PVTOL_jax).
STATE_NAMES, INPUT_NAMES, PARAM_NAMES = (
    plant.get_state_names(),
    plant.get_input_names(),
    plant.get_param_names(),
)
print(f"states = {STATE_NAMES}")
print(f"inputs = {INPUT_NAMES}")
print(f"params = {PARAM_NAMES}")
states = ['x', 'z', 'theta', 'x_dot', 'z_dot', 'theta_dot']
inputs = ['T', 'M']
params = ['m', 'J', 'g']
Build a rollout that takes params as a JAX array
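We can't use plant.simulate() directly because it bakes in the
default parameter vector p0; the optimizer needs the parameters as a
JAX array it can differentiate through. The rollout below holds each
input sample constant over one grid interval, integrates that interval
with diffrax, and chains the intervals with jax.lax.scan.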
def rollout(params, x0, u_seq, t_grid):
"""Piecewise-constant input over each [t_k, t_{k+1}] interval."""
cd = {}
def step(carry, k):
x, _ = carry
u_k = u_seq[k]
sol = diffeqsolve(
ODETerm(plant.ode_fn),
Tsit5(),
t0=t_grid[k], t1=t_grid[k + 1], dt0=1e-3,
y0=x, args=(params, u_k, cd),
saveat=SaveAt(ts=jnp.array([t_grid[k + 1]])),
)
return (sol.ys[-1], None), sol.ys[-1]
init = (x0, None)
(_, _), xs = jax.lax.scan(init=init, xs=jnp.arange(len(t_grid) - 1), f=step)
return jnp.concatenate([x0[None, :], xs], axis=0) Reference trajectory under ground truth
p0 = np.asarray(plant.get_default_values()["p0"]).flatten()
m_true, J_true, g = float(p0[0]), float(p0[1]), float(p0[2])
T_hover = m_true * g
print(f"ground truth: m = {m_true}, J = {J_true}, g = {g}, T_hover = {T_hover:.4f} N")
t_grid = jnp.linspace(0.0, 3.0, 61)
t_centers = 0.5 * (t_grid[:-1] + t_grid[1:])
T = jnp.where(t_centers < 1.0, T_hover,
jnp.where(t_centers < 2.0, 1.10 * T_hover, T_hover))
M = 0.05 * jnp.sin(2.0 * jnp.pi * t_centers)
u_seq = jnp.stack([T, M], axis=1)
x0 = jnp.zeros(len(STATE_NAMES))
params_true = jnp.array([m_true, J_true, g])
y_ref = rollout(params_true, x0, u_seq, t_grid)
print(f"reference final state: x={float(y_ref[-1,0]):+.3f} z={float(y_ref[-1,1]):+.3f} "
f"theta={float(y_ref[-1,2]):+.3f}") ground truth: m = 2.496, J = 0.0344, g = 9.80665, T_hover = 24.4774 N reference final state: x=-10.306 z=-0.386 theta=+0.697
Loss function
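Mean-squared error between the predicted and reference state trajectories (all six states at every grid time, equally weighted). The optimizer works on (log m̂, log Ĵ) instead of (m̂, Ĵ), which keeps m̂, Ĵ > 0 automatically.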
def loss_fn(theta_log, x0, u_seq, t_grid, y_ref, g_const):
    """Trajectory-matching loss as a function of log-parameters.

    Args:
        theta_log: ``[log m, log J]``; exponentiating keeps both positive.
        x0, u_seq, t_grid: initial state, inputs and time grid for ``rollout``.
        y_ref: reference trajectory (expected to match the rollout's shape).
        g_const: gravity, held fixed rather than optimized.

    Returns:
        Scalar mean of the squared state errors over every grid time and state.
    """
    mass, inertia = jnp.exp(theta_log[0]), jnp.exp(theta_log[1])
    params = jnp.array([mass, inertia, g_const])
    residual = rollout(params, x0, u_seq, t_grid) - y_ref
    return jnp.mean(residual ** 2)
@jax.jit
def loss_and_grad(theta_log):
return jax.value_and_grad(loss_fn)(theta_log, x0, u_seq, t_grid, y_ref, g) Adam loop, 60 steps, track best
def adam_update(theta, grad, state, lr=0.05, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam step (Kingma & Ba) to a parameter array.

    Args:
        theta: current parameters.
        grad: gradient of the loss with respect to ``theta`` (same shape).
        state: ``(m, v, t)``, i.e. the first-moment estimate, the
            second-moment estimate, and the number of steps taken so far.
        lr, beta1, beta2, eps: standard Adam hyperparameters.

    Returns:
        ``(new_theta, (m, v, t))`` with the step count ``t`` incremented by one.
    """
    first, second, step = state
    step += 1
    first = beta1 * first + (1 - beta1) * grad
    second = beta2 * second + (1 - beta2) * grad ** 2
    # Bias correction: both moments start at zero, so early estimates are
    # scaled back up by 1 / (1 - beta ** step).
    first_hat = first / (1 - beta1 ** step)
    second_hat = second / (1 - beta2 ** step)
    new_theta = theta - lr * first_hat / (jnp.sqrt(second_hat) + eps)
    return new_theta, (first, second, step)
m_init = 0.5 * m_true
J_init = 0.6 * J_true
theta = jnp.array([jnp.log(m_init), jnp.log(J_init)])
print(f"start: m = {m_init:.4f} ({100*(m_init-m_true)/m_true:+.1f}%), "
f"J = {J_init:.6f} ({100*(J_init-J_true)/J_true:+.1f}%)")
state = (jnp.zeros_like(theta), jnp.zeros_like(theta), 0)
loss_history = []
theta_history = [np.asarray(theta)]
best_loss = float("inf"); best_theta = theta
t0 = time.perf_counter()
for k in range(60):
loss, grad = loss_and_grad(theta)
loss_history.append(float(loss))
if float(loss) < best_loss:
best_loss = float(loss); best_theta = theta
theta, state = adam_update(theta, grad, state)
theta_history.append(np.asarray(theta))
if k < 4 or (k + 1) % 10 == 0:
print(f" iter {k+1:3d} loss = {float(loss):.3e} "
f"m = {float(jnp.exp(theta[0])):.4f} J = {float(jnp.exp(theta[1])):.6f}")
print(f"\nwall time = {time.perf_counter()-t0:.2f} s")
m_est = float(jnp.exp(best_theta[0])); J_est = float(jnp.exp(best_theta[1]))
print(f"\nbest: m = {m_est:.4f} ({100*(m_est-m_true)/m_true:+.3f}%)")
print(f" J = {J_est:.6f} ({100*(J_est-J_true)/J_true:+.3f}%)")
print(f"loss {loss_history[0]:.3e} -> {best_loss:.3e} "
f"({loss_history[0]/best_loss:.1e}x reduction)") start: m = 1.2480 (-50.0%), J = 0.020640 (-40.0%)
iter 1 loss = 1.156e+02 m = 1.3120 J = 0.019633
iter 2 loss = 9.454e+01 m = 1.3787 J = 0.019716
iter 3 loss = 7.672e+01 m = 1.4480 J = 0.020274
iter 4 loss = 6.156e+01 m = 1.5196 J = 0.021029
iter 10 loss = 1.195e+01 m = 1.9725 J = 0.027394
iter 20 loss = 8.015e-02 m = 2.6007 J = 0.037452
iter 30 loss = 1.768e+00 m = 2.8467 J = 0.039180
iter 40 loss = 1.412e+00 m = 2.7872 J = 0.036517
iter 50 loss = 3.384e-01 m = 2.6199 J = 0.034785
iter 60 loss = 2.113e-04 m = 2.4893 J = 0.034528
wall time = 4.43 s
best: m = 2.4985 (+0.101%)
J = 0.034544 (+0.419%)
loss 1.156e+02 -> 2.113e-04 (5.5e+05x reduction)
Plot loss curve, parameter trajectories, and before/after rollouts
# Optimizer history in physical units (theta_history stores log-parameters).
theta_hist = np.array(theta_history)
m_hist, J_hist = np.exp(theta_hist[:, 0]), np.exp(theta_hist[:, 1])

# Rollouts at the initial guess and at the best fit, for comparison with y_ref.
y_init = rollout(jnp.array([m_init, J_init, g]), x0, u_seq, t_grid)
y_fit = rollout(jnp.array([m_est, J_est, g]), x0, u_seq, t_grid)
t = np.asarray(t_grid)
y_ref_np, y_init_np, y_fit_np = (np.asarray(arr) for arr in (y_ref, y_init, y_fit))

fig = plt.figure(figsize=(11, 7))
gs = fig.add_gridspec(2, 3)

# Top row: loss curve and the two parameter estimates versus iteration.
ax = fig.add_subplot(gs[0, 0])
ax.semilogy(loss_history)
ax.set_xlabel("iter")
ax.set_ylabel("MSE")
ax.set_title("Loss")
ax.grid(alpha=0.3, which="both")

ax = fig.add_subplot(gs[0, 1])
ax.plot(m_hist)
ax.axhline(m_true, ls="--", alpha=0.5)
ax.set_xlabel("iter")
ax.set_ylabel("m")
ax.set_title("Mass estimate")
ax.grid(alpha=0.3)

ax = fig.add_subplot(gs[0, 2])
ax.plot(J_hist, color="C1")
ax.axhline(J_true, ls="--", color="C1", alpha=0.5)
ax.set_xlabel("iter")
ax.set_ylabel("J")
ax.set_title("Inertia estimate")
ax.grid(alpha=0.3)

# Bottom row: reference vs. initial guess vs. fitted trajectory for x, z, theta.
for i, n in enumerate(("x", "z", "theta")):
    ax = fig.add_subplot(gs[1, i])
    idx = STATE_NAMES.index(n)
    ax.plot(t, y_ref_np[:, idx], "k-", lw=2, alpha=0.6, label="ref")
    ax.plot(t, y_init_np[:, idx], "C3--", lw=1, label="wrong init")
    ax.plot(t, y_fit_np[:, idx], "C2-", lw=1, label="fitted")
    ax.set_xlabel("Time [s]")
    ax.set_ylabel(n)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

fig.suptitle("JAX parameter ID via differentiable rollout of Rumoca-emitted PVTOL")
fig.tight_layout()
plt.show()