Skip to content

Commit 6987d91

Browse files
fix
Signed-off-by: Arjun Jagdish Ram <arjun.ram@tier4.jp>
1 parent 314d1b7 commit 6987d91

18 files changed

Lines changed: 1688 additions & 215 deletions

planning/autoware_trajectory_optimizer/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ project(autoware_trajectory_optimizer)
33

44
find_package(autoware_cmake REQUIRED)
55
find_package(acados REQUIRED)
6+
find_package(tf2_geometry_msgs REQUIRED)
67

78
autoware_package()
89

@@ -39,7 +40,11 @@ ament_auto_add_library(autoware_trajectory_optimizer_plugins SHARED
3940
src/trajectory_optimizer_plugins/plugin_utils/trajectory_spline_smoother_utils.cpp
4041
src/trajectory_optimizer_plugins/plugin_utils/trajectory_velocity_optimizer_utils.cpp
4142
)
42-
target_link_libraries(autoware_trajectory_optimizer_plugins acados_interface_temporal)
43+
target_link_libraries(
44+
autoware_trajectory_optimizer_plugins
45+
acados_interface_temporal
46+
tf2_geometry_msgs::tf2_geometry_msgs
47+
)
4348

4449
if(BUILD_TESTING)
4550
find_package(ament_lint_auto REQUIRED)
@@ -70,6 +75,7 @@ install(FILES
7075

7176
install(PROGRAMS
7277
scripts/temporal_mpt_debug_visualizer.py
78+
scripts/temporal_mpt_python_reference.py
7379
DESTINATION lib/${PROJECT_NAME}
7480
)
7581

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
/**:
22
ros__parameters:
33
trajectory_temporal_mpt_optimizer:
4-
# Acados temporal solver (SQP). If res_stat/res_eq stay ~O(1), check reference vs odometry
5-
# alignment and steering reference (see TrajectoryTemporalMPTOptimizer).
4+
# SQP. Horizon yref uses closest path index to ego + stage k (see TrajectoryTemporalMPTOptimizer).
65
max_iter: 100
76
tol: 0.001
8-
# Bicycle geometry: should match vehicle_info (affects dynamics and inferred path curvature → δ_ref).
7+
# Bicycle geometry: should match vehicle_info (dynamics only).
98
lf: 1.0
109
lr: 1.0
11-
tau: 0.1
1210
min_points_for_optimization: 10
1311
output_points: 30
14-
enable_debug_info: false
12+
enable_debug_info: true
1513
publish_debug_topics: true

planning/autoware_trajectory_optimizer/config/trajectory_optimizer.param.yaml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
/**:
22
ros__parameters:
3-
# Plugin execution order - plugins will run in this order
3+
# Plugin execution order — this list is the ONLY control for which plugins load and run.
4+
# To disable temporal MPC entirely, omit TrajectoryTemporalMPTOptimizer from this list.
5+
#
6+
# use_temporal_mpt_optimizer only skips optimize_trajectory() when that plugin IS loaded; it
7+
# does not unload the plugin or stop acados init. For a full off, remove from plugin_names.
48
plugin_names:
5-
# - "autoware::trajectory_optimizer::plugin::TrajectoryPointFixer"
6-
# - "autoware::trajectory_optimizer::plugin::TrajectoryKinematicFeasibilityEnforcer"
7-
# - "autoware::trajectory_optimizer::plugin::TrajectoryQPSmoother"
8-
# - "autoware::trajectory_optimizer::plugin::TrajectoryKinematicFeasibilityEnforcer"
9-
# - "autoware::trajectory_optimizer::plugin::TrajectoryVelocityOptimizer"
10-
# - "autoware::trajectory_optimizer::plugin::TrajectorySplineSmoother"
11-
- "autoware::trajectory_optimizer::plugin::TrajectoryTemporalMPTOptimizer"
9+
- "autoware::trajectory_optimizer::plugin::TrajectoryPointFixer"
10+
- "autoware::trajectory_optimizer::plugin::TrajectoryKinematicFeasibilityEnforcer"
11+
- "autoware::trajectory_optimizer::plugin::TrajectoryQPSmoother"
12+
- "autoware::trajectory_optimizer::plugin::TrajectoryKinematicFeasibilityEnforcer"
13+
- "autoware::trajectory_optimizer::plugin::TrajectoryVelocityOptimizer"
14+
- "autoware::trajectory_optimizer::plugin::TrajectorySplineSmoother"
15+
# - "autoware::trajectory_optimizer::plugin::TrajectoryTemporalMPTOptimizer"
1216

1317
# Plugin activation flags - control runtime enable/disable
1418
use_akima_spline_interpolation: true

planning/autoware_trajectory_optimizer/include/autoware/trajectory_optimizer/trajectory_optimizer_plugins/trajectory_temporal_mpt_optimizer.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ struct TemporalMPTParams
3939
double tol{1.0e-3};
4040
double lf{1.0};
4141
double lr{1.0};
42-
double tau{0.1};
4342
size_t min_points_for_optimization{10};
4443
size_t output_points{30};
4544
bool enable_debug_info{false};

planning/autoware_trajectory_optimizer/scripts/temporal_mpt_debug_visualizer.py

Lines changed: 146 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
Debug topics (under ``--topic-prefix``):
1212
1313
input/reference_trajectory, input/initial_state, output/trajectory, output/solve_status
14-
output/control_acceleration_mps2, output/control_delta_cmd_rad (Float64MultiArray, u = [a, delta_cmd];
14+
output/control_acceleration_mps2, output/control_delta_cmd_rad (Float64MultiArray, u = [a, delta];
1515
also published when solve fails, from the solver's last iterate)
16+
17+
With ``--compare-python``, also subscribes to the Python reference node under
18+
``<prefix>/python/output/...`` (same message types) for side-by-side comparison with C++.
1619
"""
1720

1821
# Copyright 2026 TIER IV, Inc.
@@ -97,10 +100,17 @@ class DebugFrame:
97100
control_delta_cmd: List[float] = field(default_factory=list)
98101
initial_xy_yaw: Optional[Tuple[float, float, float]] = None
99102
solve_status: Optional[int] = None
103+
# Optional: temporal_mpt_python_reference.py replay (same inputs as C++ debug I/O)
104+
python_output_xy: Optional[Tuple[List[float], List[float]]] = None
105+
python_output_vel: List[float] = field(default_factory=list)
106+
python_output_delta: List[float] = field(default_factory=list)
107+
python_control_accel: List[float] = field(default_factory=list)
108+
python_control_delta_cmd: List[float] = field(default_factory=list)
109+
python_solve_status: Optional[int] = None
100110

101111

102112
class TemporalMptDebugVisualizer(Node):
103-
def __init__(self, *, topic_prefix: str, update_hz: float) -> None:
113+
def __init__(self, *, topic_prefix: str, update_hz: float, compare_python: bool) -> None:
104114
super().__init__("temporal_mpt_debug_visualizer")
105115

106116
update_hz = max(update_hz, 1.0)
@@ -146,6 +156,34 @@ def __init__(self, *, topic_prefix: str, update_hz: float) -> None:
146156
TEMPORAL_MPT_DEBUG_QOS,
147157
)
148158

159+
self._compare_python = compare_python
160+
if compare_python:
161+
py = f"{topic_prefix}/python"
162+
self.create_subscription(
163+
Trajectory,
164+
f"{py}/output/trajectory",
165+
self.on_python_output_trajectory,
166+
TEMPORAL_MPT_DEBUG_QOS,
167+
)
168+
self.create_subscription(
169+
Int32,
170+
f"{py}/output/solve_status",
171+
self.on_python_solve_status,
172+
TEMPORAL_MPT_DEBUG_QOS,
173+
)
174+
self.create_subscription(
175+
Float64MultiArray,
176+
f"{py}/output/control_acceleration_mps2",
177+
self.on_python_control_accel,
178+
TEMPORAL_MPT_DEBUG_QOS,
179+
)
180+
self.create_subscription(
181+
Float64MultiArray,
182+
f"{py}/output/control_delta_cmd_rad",
183+
self.on_python_control_delta_cmd,
184+
TEMPORAL_MPT_DEBUG_QOS,
185+
)
186+
149187
self._fig = plt.figure(figsize=(14, 8))
150188
gs = gridspec.GridSpec(
151189
3, 2, figure=self._fig, width_ratios=[1.15, 1.0], wspace=0.28, hspace=0.35
@@ -162,6 +200,8 @@ def __init__(self, *, topic_prefix: str, update_hz: float) -> None:
162200

163201
self.get_logger().info("Temporal MPT debug visualizer started.")
164202
self.get_logger().info(f"Listening under prefix: {topic_prefix}")
203+
if compare_python:
204+
self.get_logger().info(f"Python comparison enabled: {topic_prefix}/python/output/...")
165205
self.get_logger().info(
166206
"Subscriptions use BEST_EFFORT QoS (must match temporal MPT debug publishers)."
167207
)
@@ -228,6 +268,25 @@ def on_control_delta_cmd(self, msg: Float64MultiArray) -> None:
228268
with self._lock:
229269
self._frame.control_delta_cmd = list(msg.data)
230270

271+
def on_python_output_trajectory(self, msg: Trajectory) -> None:
272+
with self._lock:
273+
self._frame.python_output_xy = trajectory_to_xy(msg.points)
274+
self._frame.python_output_vel, self._frame.python_output_delta = (
275+
trajectory_velocity_steering(msg.points)
276+
)
277+
278+
def on_python_solve_status(self, msg: Int32) -> None:
279+
with self._lock:
280+
self._frame.python_solve_status = msg.data
281+
282+
def on_python_control_accel(self, msg: Float64MultiArray) -> None:
283+
with self._lock:
284+
self._frame.python_control_accel = list(msg.data)
285+
286+
def on_python_control_delta_cmd(self, msg: Float64MultiArray) -> None:
287+
with self._lock:
288+
self._frame.python_control_delta_cmd = list(msg.data)
289+
231290
def on_timer(self) -> None:
232291
with self._lock:
233292
frame = DebugFrame(
@@ -241,6 +300,12 @@ def on_timer(self) -> None:
241300
control_delta_cmd=list(self._frame.control_delta_cmd),
242301
initial_xy_yaw=self._frame.initial_xy_yaw,
243302
solve_status=self._frame.solve_status,
303+
python_output_xy=self._frame.python_output_xy,
304+
python_output_vel=list(self._frame.python_output_vel),
305+
python_output_delta=list(self._frame.python_output_delta),
306+
python_control_accel=list(self._frame.python_control_accel),
307+
python_control_delta_cmd=list(self._frame.python_control_delta_cmd),
308+
python_solve_status=self._frame.python_solve_status,
244309
)
245310

246311
# --- Left: XY path ---
@@ -261,7 +326,20 @@ def on_timer(self) -> None:
261326
)
262327
if frame.output_xy is not None and len(frame.output_xy[0]) > 0:
263328
self._ax_xy.plot(
264-
frame.output_xy[0], frame.output_xy[1], "r-", linewidth=2, label="output optimized"
329+
frame.output_xy[0],
330+
frame.output_xy[1],
331+
"r-",
332+
linewidth=2,
333+
label="output (C++)",
334+
)
335+
if frame.python_output_xy is not None and len(frame.python_output_xy[0]) > 0:
336+
self._ax_xy.plot(
337+
frame.python_output_xy[0],
338+
frame.python_output_xy[1],
339+
color="darkorange",
340+
linestyle="--",
341+
linewidth=2,
342+
label="output (Python)",
265343
)
266344

267345
if frame.initial_xy_yaw is not None:
@@ -278,23 +356,31 @@ def on_timer(self) -> None:
278356
ec="k",
279357
)
280358

281-
if frame.solve_status is not None:
282-
status_text = f"solve_status: {frame.solve_status}"
283-
status_color = "green" if frame.solve_status == 0 else "red"
359+
if frame.solve_status is not None or frame.python_solve_status is not None:
360+
parts = []
361+
if frame.solve_status is not None:
362+
parts.append(f"C++ status: {frame.solve_status}")
363+
if frame.python_solve_status is not None:
364+
parts.append(f"Py status: {frame.python_solve_status}")
365+
status_text = " | ".join(parts)
366+
ok_cpp = frame.solve_status is None or frame.solve_status == 0
367+
ok_py = frame.python_solve_status is None or frame.python_solve_status == 0
368+
status_color = "green" if (ok_cpp and ok_py) else "red"
284369
self._ax_xy.text(
285370
0.02,
286371
0.98,
287372
status_text,
288373
transform=self._ax_xy.transAxes,
289374
verticalalignment="top",
290375
color=status_color,
291-
fontsize=11,
376+
fontsize=10,
292377
bbox={"facecolor": "white", "alpha": 0.8, "edgecolor": status_color},
293378
)
294379

295380
if (
296381
(frame.input_xy and len(frame.input_xy[0]) > 0)
297382
or (frame.output_xy and len(frame.output_xy[0]) > 0)
383+
or (frame.python_output_xy and len(frame.python_output_xy[0]) > 0)
298384
or frame.initial_xy_yaw is not None
299385
):
300386
self._ax_xy.relim()
@@ -317,8 +403,18 @@ def on_timer(self) -> None:
317403
idx, frame.input_vel[:n_compare], "b.--", markersize=5, label="input ref v"
318404
)
319405
self._ax_v.plot(
320-
idx, frame.output_vel[:n_compare], "r-", linewidth=2, label="output state v"
406+
idx, frame.output_vel[:n_compare], "r-", linewidth=2, label="output v (C++)"
321407
)
408+
if frame.python_output_vel:
409+
n_py = min(n_compare, len(frame.python_output_vel))
410+
self._ax_v.plot(
411+
list(range(n_py)),
412+
frame.python_output_vel[:n_py],
413+
color="darkorange",
414+
linestyle="--",
415+
linewidth=2,
416+
label="output v (Python)",
417+
)
322418
self._ax_v.legend(loc="best")
323419
elif frame.input_vel:
324420
self._ax_v.plot(
@@ -351,8 +447,18 @@ def on_timer(self) -> None:
351447
idx, frame.input_delta[:n_compare], "b.--", markersize=5, label="input ref δ"
352448
)
353449
self._ax_delta.plot(
354-
idx, frame.output_delta[:n_compare], "r-", linewidth=2, label="output state δ"
450+
idx, frame.output_delta[:n_compare], "r-", linewidth=2, label="output δ (C++)"
355451
)
452+
if frame.python_output_delta:
453+
n_py = min(n_compare, len(frame.python_output_delta))
454+
self._ax_delta.plot(
455+
list(range(n_py)),
456+
frame.python_output_delta[:n_py],
457+
color="darkorange",
458+
linestyle="--",
459+
linewidth=2,
460+
label="output δ (Python)",
461+
)
356462
self._ax_delta.legend(loc="best")
357463
elif frame.input_delta:
358464
self._ax_delta.plot(
@@ -389,9 +495,20 @@ def on_timer(self) -> None:
389495
self._ax_u.set_ylabel("a [m/s²]", color="tab:green")
390496
if na >= n_stages:
391497
self._ax_u.plot(
392-
stages, frame.control_accel[:n_stages], "g-", linewidth=1.5, label="a (u₀)"
498+
stages, frame.control_accel[:n_stages], "g-", linewidth=1.5, label="a C++"
393499
)
394500
self._ax_u.tick_params(axis="y", labelcolor="tab:green")
501+
npa = len(frame.python_control_accel)
502+
if npa > 0:
503+
n_ap = min(n_stages, npa)
504+
self._ax_u.plot(
505+
stages[:n_ap],
506+
frame.python_control_accel[:n_ap],
507+
color="darkgoldenrod",
508+
linestyle="--",
509+
linewidth=1.5,
510+
label="a Python",
511+
)
395512

396513
self._ax_u_twin.set_ylabel("δ_cmd [rad]", color="tab:purple")
397514
if nd >= n_stages:
@@ -401,9 +518,20 @@ def on_timer(self) -> None:
401518
color="tab:purple",
402519
linestyle="-",
403520
linewidth=1.5,
404-
label="δ_cmd (u₁)",
521+
label="δ_cmd C++",
405522
)
406523
self._ax_u_twin.tick_params(axis="y", labelcolor="tab:purple")
524+
npd = len(frame.python_control_delta_cmd)
525+
if npd > 0:
526+
n_dp = min(n_stages, npd)
527+
self._ax_u_twin.plot(
528+
stages[:n_dp],
529+
frame.python_control_delta_cmd[:n_dp],
530+
color="darkorange",
531+
linestyle="--",
532+
linewidth=1.5,
533+
label="δ_cmd Python",
534+
)
407535

408536
lines = self._ax_u.get_lines() + self._ax_u_twin.get_lines()
409537
labels = [ln.get_label() for ln in lines]
@@ -449,6 +577,12 @@ def parse_args(argv: list[str]) -> argparse.Namespace:
449577
default=10.0,
450578
help="Matplotlib refresh rate",
451579
)
580+
parser.add_argument(
581+
"--compare-python",
582+
action=argparse.BooleanOptionalAction,
583+
default=True,
584+
help="Also subscribe to <prefix>/python/output/... from temporal_mpt_python_reference.py",
585+
)
452586
return parser.parse_args(argv)
453587

454588

@@ -460,6 +594,7 @@ def main() -> None:
460594
node = TemporalMptDebugVisualizer(
461595
topic_prefix=cli.topic_prefix.rstrip("/"),
462596
update_hz=cli.update_hz,
597+
compare_python=cli.compare_python,
463598
)
464599
try:
465600
while rclpy.ok():

0 commit comments

Comments
 (0)