
feat(tidy3d): FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint #3208

Open

marcorudolphflex wants to merge 1 commit into develop from
FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint

Conversation

@marcorudolphflex (Contributor) commented Jan 27, 2026

Note

Medium Risk
Touches core autograd forward/backward execution paths and adds new scheduling/fallback logic that changes how/when adjoint simulations are launched under local_gradient=True, so regressions could affect gradient correctness or runtime behavior despite strong test coverage.

Overview
Enables parallel adjoint scheduling for autograd local gradients: when config.adjoint.parallel_all_port=True, the forward run/run_async can launch a batch containing the forward solve plus canonical “unit” adjoint simulations for eligible monitor outputs, then reuse/scales those precomputed results during VJP evaluation.

Introduces new parallel-adjoint infrastructure: monitor/monitor-data expose supports_parallel_adjoint() and parallel_adjoint_bases() for mode amplitudes, diffraction amplitudes, and single-point field probes; new source factories deterministically build the corresponding adjoint sources; and a new parallel_adjoint web API module prepares payloads, applies precomputed basis maps, and falls back to sequential adjoints for any remaining VJP entries (with warnings and caps via max_adjoint_per_fwd).

Refactors and hardens the adjoint pipeline by centralizing adjoint-simulation construction (make_adjoint_simulation), adding shared VJP-map filtering/NaN checks and a reusable accumulate_field_map, improving test emulation determinism, and adding extensive tests/docs/config reference for the new flags (including parallel_adjoint_mode_direction_policy).
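A minimal usage sketch of the flow described above (the builder `make_simulation`, the monitor name `"mode"`, and `freq0`/`param0` are illustrative placeholders, not from this PR; the config flags are the ones the PR introduces):

```python
import autograd
import autograd.numpy as anp
import tidy3d as td
import tidy3d.web as web

td.config.adjoint.local_gradient = True     # parallel adjoint only applies to local gradients
td.config.adjoint.parallel_all_port = True  # enable parallel adjoint scheduling

def objective(param):
    sim = make_simulation(param)  # hypothetical builder tracing `param` into the simulation
    # With the flag on, this single call launches the forward solve plus
    # canonical "unit" adjoint simulations for eligible monitor outputs.
    sim_data = web.run(sim, task_name="fwd")
    amp = sim_data["mode"].amps.sel(direction="+", mode_index=0, f=freq0).values
    return anp.abs(amp) ** 2

# The backward pass scales and reuses the precomputed adjoint bases instead of
# launching new adjoint solves, falling back to sequential ones where needed.
grad = autograd.grad(objective)(param0)
```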

Written by Cursor Bugbot for commit cacd0bf. This will update automatically on new commits.

Greptile Summary

This PR implements parallel adjoint scheduling for autograd simulations, allowing eligible adjoint simulations to run concurrently with forward simulations when local_gradient=True. The feature launches canonical "unit" adjoint solves up front and scales them during the backward pass, reducing gradient computation wall-clock time.

Key changes:

  • Added config.adjoint.parallel_all_port configuration flag to enable the feature
  • Added config.adjoint.parallel_adjoint_mode_direction_policy to control mode direction handling
  • Created parallel adjoint basis classes (ModeAdjointBasis, DiffractionAdjointBasis, PointFieldAdjointBasis) for mode, diffraction, and point-field monitors
  • Implemented source factory functions for generating adjoint sources deterministically
  • Extended monitor data classes with supports_parallel_adjoint() and parallel_adjoint_bases() methods
  • Refactored adjoint simulation creation into reusable make_adjoint_simulation() function
  • Added comprehensive test suite verifying parallel vs sequential gradient equivalence
  • Updated documentation with detailed feature description

The implementation includes proper fallback mechanisms when monitors are unsupported or limits are exceeded, ensuring backward compatibility.

Confidence Score: 3/5

  • This PR introduces significant new functionality with good test coverage but has floating-point comparison issues that need addressing.
  • Score reflects well-architected feature with comprehensive tests and documentation, but critical floating-point equality comparisons (5 instances) need tolerance-based checks per project standards. The refactoring is clean and maintains backward compatibility with proper fallback mechanisms.
  • Pay close attention to tidy3d/web/api/autograd/parallel_adjoint.py (lines 322, 327, 329) and tidy3d/components/autograd/source_factory.py (lines 94, 207) for floating-point comparison fixes.

Important Files Changed

| Filename | Overview |
| --- | --- |
| tidy3d/web/api/autograd/parallel_adjoint.py | New file implementing parallel adjoint scheduling. Contains floating-point comparison issues (lines 322, 327, 329) that need tolerance-based checks. |
| tidy3d/components/autograd/parallel_adjoint_bases.py | New file with the adjoint basis classes for parallel adjoint; well-structured with proper error handling and type checking. |
| tidy3d/components/autograd/source_factory.py | New source factory utilities with floating-point equality issues (lines 94, 207) that should use tolerance-based comparisons. |
| tidy3d/web/api/autograd/autograd.py | Extended autograd pipeline with parallel adjoint integration; adds helper functions for VJP filtering, field map accumulation, and batch processing. |
| tidy3d/components/data/monitor_data.py | Refactored to support parallel adjoint via new supports_parallel_adjoint() and parallel_adjoint_bases() methods; extracted mode source creation to a factory. |
| tidy3d/components/data/sim_data.py | Extracted adjoint simulation creation into a standalone make_adjoint_simulation() function for reuse; clean refactoring with no logic changes. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant ParallelAdjoint as Parallel Adjoint
    participant Batch as Batch Executor
    participant Solver as FDTD Solver

    User->>AutogradAPI: run(sim, local_gradient=True)
    AutogradAPI->>ParallelAdjoint: prepare_parallel_adjoint(sim)
    ParallelAdjoint->>ParallelAdjoint: collect adjoint bases from monitors
    ParallelAdjoint->>ParallelAdjoint: filter by direction policy
    ParallelAdjoint->>ParallelAdjoint: create canonical adjoint sims
    ParallelAdjoint-->>AutogradAPI: ParallelAdjointPayload
    
    alt Parallel Adjoint Enabled
        AutogradAPI->>Batch: run_async({fwd, adj_1, adj_2, ...})
        Batch->>Solver: run forward sim
        Batch->>Solver: run adjoint sim 1
        Batch->>Solver: run adjoint sim 2
        Batch-->>AutogradAPI: BatchData
        AutogradAPI->>AutogradAPI: populate_parallel_adjoint_bases()
        AutogradAPI-->>User: SimulationData + aux_data
    else Parallel Adjoint Disabled
        AutogradAPI->>Solver: run forward sim only
        AutogradAPI-->>User: SimulationData
    end

    User->>AutogradAPI: backward pass (VJP)
    AutogradAPI->>ParallelAdjoint: apply_parallel_adjoint(vjp, bases)
    ParallelAdjoint->>ParallelAdjoint: compute coefficients from VJP
    ParallelAdjoint->>ParallelAdjoint: scale and accumulate basis maps
    ParallelAdjoint-->>AutogradAPI: vjp_parallel + vjp_fallback
    
    alt Has fallback VJPs
        AutogradAPI->>Solver: run sequential adjoint for remaining
        Solver-->>AutogradAPI: adjoint field data
        AutogradAPI->>AutogradAPI: combine vjp_parallel + sequential
    end
    
    AutogradAPI-->>User: gradient
```

@marcorudolphflex (Contributor Author) commented:

@greptile

@greptile-apps (bot) left a comment:

1 file reviewed, 1 comment

@marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch 4 times, most recently from 8386a06 to 0c9fe1b on January 27, 2026 at 12:37
@marcorudolphflex marked this pull request as ready for review on January 27, 2026 at 12:46
@marcorudolphflex (Contributor Author) commented:

Technically still semi-drafty; marked as ready for Cursor Bugbot. Will re-request review when it's really ready.

@greptile-apps (bot) left a comment:

5 files reviewed, 5 comments

@github-actions (bot) commented Jan 27, 2026

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/autograd/parallel_adjoint_bases.py (90.1%): Missing lines 22,61,72,76,105,117,120-121,125,160,164,204,233,237
  • tidy3d/components/autograd/source_factory.py (91.9%): Missing lines 24,79,81,91,120,136,145,159,202,234,238
  • tidy3d/components/autograd/utils.py (90.0%): Missing lines 66
  • tidy3d/components/data/monitor_data.py (85.3%): Missing lines 178,182,1467,1879,4110
  • tidy3d/components/data/sim_data.py (100%)
  • tidy3d/components/monitor.py (67.0%): Missing lines 129,135,1116,1857,1863,1869,1871-1874,1876-1879,1881-1887,1889,1892,1895-1898,1915,1991,2021
  • tidy3d/config/sections.py (100%)
  • tidy3d/web/api/autograd/autograd.py (100%)
  • tidy3d/web/api/autograd/backward.py (100%)
  • tidy3d/web/api/autograd/constants.py (100%)
  • tidy3d/web/api/autograd/parallel_adjoint.py (87.1%): Missing lines 44,63-65,107,111,136,141,146,160,228,314,317,374-375,378,382-383,385,388,415-416,418,421,441,448-455,465,480,484,522,527
  • tidy3d/web/api/autograd/utils.py (100%)

Summary

  • Total: 847 lines
  • Missing: 99 lines
  • Coverage: 88%

tidy3d/components/autograd/parallel_adjoint_bases.py

Lines 18-26

  18 
  19 def _coord_index(coord_values: np.ndarray, target: object) -> int:
  20     values = np.asarray(coord_values)
  21     if values.size == 0:
! 22         raise ValueError("No coordinate values available to index.")
  23     if values.dtype.kind in ("f", "c"):
  24         matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
  25     else:
  26         matches = np.where(values == target)[0]

Lines 57-65

  57         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  58     ) -> complex:
  59         vjp = data_fields_vjp.get(self.data_path)
  60         if vjp is None:
! 61             return 0.0 + 0.0j
  62         data_index = self._data_index_from_sim_data(sim_data_orig)
  63         vjp_array = np.asarray(vjp)
  64         value = complex(vjp_array[data_index])
  65         return value

Lines 68-80

  68         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  69     ) -> None:
  70         vjp = data_fields_vjp.get(self.data_path)
  71         if vjp is None:
! 72             return
  73         vjp_array = np.asarray(vjp)
  74         vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  75         if vjp_array is not vjp:
! 76             data_fields_vjp[self.data_path] = vjp_array
  77 
  78 
  79 @dataclass(frozen=True)
  80 class DiffractionAdjointBasis:

Lines 101-109

  101         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData, norm: np.ndarray
  102     ) -> complex:
  103         vjp = data_fields_vjp.get(self.data_path)
  104         if vjp is None:
! 105             return 0.0 + 0.0j
  106         try:
  107             data_index = self._data_index_from_sim_data(sim_data_orig)
  108         except ValueError:
  109             return 0.0 + 0.0j

Lines 113-129

  113         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  114     ) -> None:
  115         vjp = data_fields_vjp.get(self.data_path)
  116         if vjp is None:
! 117             return
  118         try:
  119             data_index = self._data_index_from_sim_data(sim_data_orig)
! 120         except ValueError:
! 121             return
  122         vjp_array = np.asarray(vjp)
  123         vjp_array[data_index] = 0.0
  124         if vjp_array is not vjp:
! 125             data_fields_vjp[self.data_path] = vjp_array
  126 
  127 
  128 @dataclass(frozen=True)
  129 class PointFieldAdjointBasis:

Lines 156-168

  156         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  157     ) -> None:
  158         vjp = data_fields_vjp.get(self.data_path)
  159         if vjp is None:
! 160             return
  161         vjp_array = np.asarray(vjp)
  162         vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  163         if vjp_array is not vjp:
! 164             data_fields_vjp[self.data_path] = vjp_array
  165 
  166 
  167 ParallelAdjointBasis = ModeAdjointBasis | DiffractionAdjointBasis | PointFieldAdjointBasis

Lines 200-208

  200 ) -> list[PointFieldAdjointBasis]:
  201     bases: list[PointFieldAdjointBasis] = []
  202     for component, freqs in component_freqs:
  203         if component not in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
! 204             continue
  205         for freq in freqs:
  206             bases.append(
  207                 PointFieldAdjointBasis(
  208                     monitor_index=monitor_index,

Lines 229-241

  229     for order_x in orders_x:
  230         for order_y in orders_y:
  231             angle_theta = float(theta_for(int(order_x), int(order_y)))
  232             if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 233                 continue
  234             for pol in pols:
  235                 pol_str = str(pol)
  236                 if pol_str not in ("s", "p"):
! 237                     continue
  238                 dataset_name = "Ephi" if pol_str == "s" else "Etheta"
  239                 bases.append(
  240                     DiffractionAdjointBasis(
  241                         monitor_index=monitor_index,

tidy3d/components/autograd/source_factory.py

Lines 20-28

  20 def flip_direction(direction: object) -> str:
  21     if hasattr(direction, "values"):
  22         direction = str(direction.values)
  23     if direction not in ("+", "-"):
! 24         raise ValueError(f"Direction must be in {('+', '-')}, got '{direction}'.")
  25     return "-" if direction == "+" else "+"
  26 
  27 
  28 def adjoint_fwidth_from_simulation(simulation: Simulation) -> float:

Lines 75-85

  75     coefficient: complex,
  76     fwidth: float,
  77 ) -> CustomCurrentSource | None:
  78     if any(simulation.symmetry):
! 79         raise ValueError("Point-field adjoint sources require symmetry to be disabled.")
  80     if not monitor.colocate:
! 81         raise ValueError("Point-field adjoint sources require colocated field monitors.")
  82 
  83     grid = simulation.discretize_monitor(monitor)
  84     coords = {}
  85     spatial_coords = grid.boundaries

Lines 87-95

  87     for axis, dim in enumerate("xyz"):
  88         if monitor.size[axis] == 0:
  89             coords[dim] = np.array([monitor.center[axis]])
  90         else:
! 91             coords[dim] = np.array(spatial_coords_dict[dim][:-1])
  92     values = (
  93         2
  94         * -1j
  95         * coefficient

Lines 116-124

  116     values *= scaling_factor
  117     values = np.nan_to_num(values, nan=0.0)
  118 
  119     if np.all(values == 0):
! 120         return None
  121 
  122     dataset = FieldDataset(**{component: ScalarFieldDataArray(values, coords=coords)})
  123     return CustomCurrentSource(
  124         center=monitor.geometry.center,

Lines 132-140

  132 def diffraction_monitor_medium(simulation: Simulation, monitor: DiffractionMonitor) -> object:
  133     structures = [simulation.scene.background_structure, *list(simulation.structures or ())]
  134     mediums = simulation.scene.intersecting_media(monitor, structures)
  135     if len(mediums) != 1:
! 136         raise ValueError("Diffraction monitor plane must be homogeneous to build adjoint sources.")
  137     return list(mediums)[0]
  138 
  139 
  140 def bloch_vec_for_axis(simulation: Simulation, axis_name: str) -> float:

Lines 141-149

  141     boundary = simulation.boundary_spec[axis_name]
  142     plus = boundary.plus
  143     if hasattr(plus, "bloch_vec"):
  144         return float(plus.bloch_vec)
! 145     return 0.0
  146 
  147 
  148 def diffraction_order_range(
  149     size: float, bloch_vec: float, freq: float, medium: object

Lines 155-163

  155     limit = abs(index) * freq * size / C_0
  156     order_min = int(np.ceil(-limit - bloch_vec))
  157     order_max = int(np.floor(limit - bloch_vec))
  158     if order_max < order_min:
! 159         return np.array([], dtype=int)
  160     return np.arange(order_min, order_max + 1, dtype=int)
  161 
  162 
  163 def diffraction_source_from_simulation(

Lines 198-206

  198     theta_vals, phi_vals = DiffractionData.compute_angles((ux, uy))
  199     angle_theta = float(theta_vals[0, 0, 0])
  200     angle_phi = float(phi_vals[0, 0, 0])
  201     if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 202         raise ValueError("Adjoint source not available for evanescent diffraction order.")
  203 
  204     pol_angle = 0.0 if polarization == "p" else np.pi / 2
  205     bck_eps = medium.eps_model(freq)
  206     return _diffraction_plane_wave(

Lines 230-242

  230     angle_theta = float(theta_data.sel(**angle_sel_kwargs))
  231     angle_phi = float(phi_data.sel(**angle_sel_kwargs))
  232 
  233     if np.isnan(angle_theta):
! 234         return None
  235 
  236     pol_str = str(polarization)
  237     if pol_str not in ("p", "s"):
! 238         raise ValueError(f"Something went wrong, given pol='{pol_str}' in adjoint source.")
  239 
  240     pol_angle = 0.0 if pol_str == "p" else np.pi / 2
  241     bck_eps = diff_data.medium.eps_model(freq)
  242     return _diffraction_plane_wave(

tidy3d/components/autograd/utils.py

Lines 62-70

  62         if k in target:
  63             val = target[k]
  64             if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
  65                 if len(val) != len(v):
! 66                     raise ValueError(
  67                         f"Cannot accumulate field map for key '{k}': "
  68                         f"length mismatch ({len(val)} vs {len(v)})."
  69                     )
  70                 target[k] = type(val)(x + y for x, y in zip(val, v))

tidy3d/components/data/monitor_data.py

Lines 174-186

  174         return []
  175 
  176     def supports_parallel_adjoint(self) -> bool:
  177         """Return ``True`` if this monitor data supports parallel adjoint sources."""
! 178         return False
  179 
  180     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  181         """Return parallel adjoint bases for this monitor data."""
! 182         return []
  183 
  184     @staticmethod
  185     def get_amplitude(x) -> complex:
  186         """Get the complex amplitude out of some data."""

Lines 1463-1471

  1463 
  1464     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  1465         """Return parallel adjoint bases for single-point field monitors."""
  1466         if not self.supports_parallel_adjoint():
! 1467             return []
  1468         component_freqs = [
  1469             (str(component), data_array.coords["f"].values)
  1470             for component, data_array in self.field_components.items()
  1471         ]

Lines 1875-1883

  1875         return val
  1876 
  1877     def supports_parallel_adjoint(self) -> bool:
  1878         """Return ``True`` for mode monitor amplitude adjoints."""
! 1879         return True
  1880 
  1881     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  1882         """Return parallel adjoint bases for mode monitor amplitudes."""
  1883         amps = self.amps

Lines 4106-4114

  4106         return DataArray(np.stack([amp_phi, amp_theta], axis=3), coords=coords)
  4107 
  4108     def supports_parallel_adjoint(self) -> bool:
  4109         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 4110         return True
  4111 
  4112     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  4113         """Return parallel adjoint bases for diffraction monitor amplitudes."""
  4114         amps = self.amps

tidy3d/components/monitor.py

Lines 125-133

  125         return self.storage_size(num_cells=num_cells, tmesh=tmesh)
  126 
  127     def supports_parallel_adjoint(self) -> bool:
  128         """Return ``True`` if this monitor can provide parallel adjoint bases."""
! 129         return False
  130 
  131     def parallel_adjoint_bases(
  132         self, simulation: Simulation, monitor_index: int
  133     ) -> list[ParallelAdjointBasis]:

Lines 131-139

  131     def parallel_adjoint_bases(
  132         self, simulation: Simulation, monitor_index: int
  133     ) -> list[ParallelAdjointBasis]:
  134         """Return parallel adjoint bases for this monitor."""
! 135         return []
  136 
  137 
  138 class FreqMonitor(Monitor, ABC):
  139     """:class:`Monitor` that records data in the frequency-domain."""

Lines 1112-1120

  1112         return amps_size + fields_size
  1113 
  1114     def supports_parallel_adjoint(self) -> bool:
  1115         """Return ``True`` for mode monitor amplitude adjoints."""
! 1116         return True
  1117 
  1118     def parallel_adjoint_bases(
  1119         self, simulation: Simulation, monitor_index: int
  1120     ) -> list[ParallelAdjointBasis]:

Lines 1853-1861

  1853         return BYTES_COMPLEX * len(self.ux) * len(self.uy) * len(self.freqs) * 6
  1854 
  1855     def supports_parallel_adjoint(self) -> bool:
  1856         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 1857         return True
  1858 
  1859     def parallel_adjoint_bases(
  1860         self, simulation: Simulation, monitor_index: int
  1861     ) -> list[ParallelAdjointBasis]:

Lines 1859-1867

  1859     def parallel_adjoint_bases(
  1860         self, simulation: Simulation, monitor_index: int
  1861     ) -> list[ParallelAdjointBasis]:
  1862         """Return parallel adjoint bases for diffraction monitor amplitudes."""
! 1863         from tidy3d.components.autograd.source_factory import (
  1864             bloch_vec_for_axis,
  1865             diffraction_monitor_medium,
  1866             diffraction_order_range,
  1867         )

Lines 1865-1902

  1865             diffraction_monitor_medium,
  1866             diffraction_order_range,
  1867         )
  1868 
! 1869         medium = diffraction_monitor_medium(simulation, self)
  1870 
! 1871         axis_names = ("x", "y", "z")
! 1872         normal_axis = self.normal_axis
! 1873         transverse_axes = [axis_names[i] for i in range(3) if i != normal_axis]
! 1874         axis_x, axis_y = transverse_axes
  1875 
! 1876         size_x = simulation.size[axis_names.index(axis_x)]
! 1877         size_y = simulation.size[axis_names.index(axis_y)]
! 1878         bloch_vec_x = bloch_vec_for_axis(simulation, axis_x)
! 1879         bloch_vec_y = bloch_vec_for_axis(simulation, axis_y)
  1880 
! 1881         bases: list[DiffractionAdjointBasis] = []
! 1882         freqs = [float(freq) for freq in self.freqs]
! 1883         for freq in freqs:
! 1884             orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
! 1885             orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
! 1886             if orders_x.size == 0 or orders_y.size == 0:
! 1887                 continue
  1888 
! 1889             ux = _reciprocal_coords(
  1890                 orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
  1891             )
! 1892             uy = _reciprocal_coords(
  1893                 orders=orders_y, size=size_y, bloch_vec=bloch_vec_y, freq=freq, medium=medium
  1894             )
! 1895             theta_vals, _ = _compute_angles((ux, uy))
! 1896             order_x_index = {int(val): idx for idx, val in enumerate(orders_x)}
! 1897             order_y_index = {int(val): idx for idx, val in enumerate(orders_y)}
! 1898             bases.extend(
  1899                 _build_diffraction_bases_for_freq(
  1900                     monitor_name=self.name,
  1901                     monitor_index=monitor_index,
  1902                     freq=freq,

Lines 1911-1919

  1911                         order_x_index[ox], order_y_index[oy], 0
  1912                     ],
  1913                 )
  1914             )
! 1915         return bases
  1916 
  1917 
  1918 class DiffractionMonitor(PlanarMonitor, FreqMonitor):
  1919     """:class:`Monitor` that uses a 2D Fourier transform to compute the

Lines 1987-1995

  1987         return BYTES_COMPLEX * len(self.freqs)
  1988 
  1989     def supports_parallel_adjoint(self) -> bool:
  1990         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 1991         return True
  1992 
  1993     def parallel_adjoint_bases(
  1994         self, simulation: Simulation, monitor_index: int
  1995     ) -> list[ParallelAdjointBasis]:

Lines 2017-2025

  2017         for freq in freqs:
  2018             orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
  2019             orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
  2020             if orders_x.size == 0 or orders_y.size == 0:
! 2021                 continue
  2022 
  2023             ux = _reciprocal_coords(
  2024                 orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
  2025             )

tidy3d/web/api/autograd/parallel_adjoint.py

Lines 40-48

  40 def _scale_field_map(field_map: AutogradFieldMap, scale: float) -> AutogradFieldMap:
  41     scaled = {}
  42     for k, v in field_map.items():
  43         if isinstance(v, (list, tuple)):
! 44             scaled[k] = type(v)(scale * x for x in v)
  45         else:
  46             scaled[k] = scale * v
  47     return scaled

Lines 59-69

  59     unsupported: list[str] = []
  60     for monitor_index, monitor in enumerate(simulation.monitors):
  61         try:
  62             bases_for_monitor = monitor.parallel_adjoint_bases(simulation, monitor_index)
! 63         except ValueError:
! 64             unsupported.append(monitor.name)
! 65             continue
  66         if bases_for_monitor:
  67             bases.extend(bases_for_monitor)
  68         elif not monitor.supports_parallel_adjoint():
  69             unsupported.append(monitor.name)

Lines 103-115

  103     basis_spec: object,
  104 ) -> object:
  105     post_norm = sim_data_adj.simulation.post_norm
  106     if not hasattr(basis_spec, "freq"):
! 107         return post_norm
  108     freqs = np.asarray(post_norm.coords["f"].values)
  109     idx = int(np.argmin(np.abs(freqs - basis_spec.freq)))
  110     if not np.isclose(freqs[idx], basis_spec.freq):
! 111         raise td.exceptions.AdjointError(
  112             "Parallel adjoint basis frequency not found in adjoint post-normalization."
  113         )
  114     return post_norm.isel(f=[idx])

Lines 132-150

  132         for key, data_array in monitor_data.field_components.items():
  133             if "f" in data_array.dims:
  134                 freqs = np.asarray(data_array.coords["f"].values)
  135                 if freqs.size == 0:
! 136                     raise td.exceptions.AdjointError(
  137                         "Parallel adjoint expected frequency data but no frequencies were found."
  138                     )
  139                 idx = int(np.argmin(np.abs(freqs - freq)))
  140                 if not np.isclose(freqs[idx], freq, rtol=1e-10, atol=0.0):
! 141                     raise td.exceptions.AdjointError(
  142                         "Parallel adjoint basis frequency not found in monitor data."
  143                     )
  144                 updates[key] = data_array.isel(f=[idx])
  145         return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False, **updates)
! 146     return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False)
  147 
  148 
  149 def _select_sim_data_freq(
  150     sim_data_adj: td.SimulationData,

Lines 156-164

  156     for monitor in sim.monitors:
  157         if hasattr(monitor, "freqs"):
  158             monitor_updated = monitor.updated_copy(freqs=[freq])
  159         else:
! 160             monitor_updated = monitor
  161         monitors.append(monitor_updated)
  162         monitor_map[monitor.name] = monitor_updated
  163     sim_updated = sim.updated_copy(monitors=monitors)

Lines 224-232

  224     simulation: td.Simulation,
  225     basis_sources: list[tuple[ParallelAdjointBasis, Any]],
  226 ) -> list[tuple[list[ParallelAdjointBasis], AdjointSourceInfo]]:
  227     if not basis_sources:
! 228         return []
  229 
  230     sim_data_stub = td.SimulationData(simulation=simulation, data=())
  231     sources = [source for _, source in basis_sources]
  232     sources_processed = td.SimulationData._adjoint_src_width_single(sources)

Lines 310-321

  310             coefficient=coefficient,
  311             fwidth=fwidth,
  312         )
  313         if source is None:
! 314             raise ValueError("Adjoint point source has zero amplitude.")
  315         return adjoint_source_info_single(source)
  316 
! 317     raise ValueError("Unsupported parallel adjoint basis.")
  318 
  319 
  320 @dataclass(frozen=True)
  321 class ParallelAdjointPayload:

Lines 370-392

  370                 simulation=simulation,
  371                 basis=basis,
  372                 coefficient=1.0 + 0.0j,
  373             )
! 374         except ValueError as exc:
! 375             td.log.info(
  376                 f"Skipping parallel adjoint basis for monitor '{basis.monitor_name}': {exc}"
  377             )
! 378             continue
  379         basis_sources.append((basis, source_info.sources[0]))
  380 
  381     if not basis_sources:
! 382         if basis_specs:
! 383             td.log.info("Parallel adjoint produced no simulations for this task.")
  384         else:
! 385             td.log.warning(
  386                 "Parallel adjoint disabled because no eligible monitor outputs were found."
  387             )
! 388         return None
  389 
  390     grouped = _group_parallel_adjoint_bases_by_port(simulation, basis_sources)
  391     if len(grouped) > max_num_adjoint_per_fwd:
  392         raise AdjointError(

Lines 411-425

  411         task_map[adj_task_name] = bases
  412         used_bases.extend(bases)
  413 
  414     if not sims_adj_dict:
! 415         if basis_specs:
! 416             td.log.info("Parallel adjoint produced no simulations for this task.")
  417         else:
! 418             td.log.warning(
  419                 "Parallel adjoint disabled because no eligible monitor outputs were found."
  420             )
! 421         return None
  422 
  423     td.log.info(
  424         "Parallel adjoint enabled: launched "
  425         f"{len(sims_adj_dict)} canonical adjoint simulations for task '{task_name}'."

Lines 437-445

  437     task_paths: dict[str, str],
  438     base_dir: PathLike,
  439 ) -> None:
  440     if not task_names:
! 441         return
  442     target_dir = Path(base_dir) / config.adjoint.local_adjoint_dir
  443     target_dir.mkdir(parents=True, exist_ok=True)
  444     for task_name in task_names:
  445         src_path = task_paths.get(task_name)

Lines 444-459

  444     for task_name in task_names:
  445         src_path = task_paths.get(task_name)
  446         if not src_path:
  447             continue
! 448         src = Path(src_path)
! 449         if not src.exists():
! 450             continue
! 451         dst = target_dir / src.name
! 452         if src.resolve() == dst.resolve():
! 453             continue
! 454         dst.parent.mkdir(parents=True, exist_ok=True)
! 455         src.replace(dst)
  456 
  457 
  458 def apply_parallel_adjoint(
  459     data_fields_vjp: AutogradFieldMap,

Lines 461-469

  461     sim_data_orig: td.SimulationData,
  462 ) -> tuple[AutogradFieldMap, AutogradFieldMap]:
  463     basis_maps = parallel_info.get("basis_maps")
  464     if basis_maps is None:
! 465         return {}, data_fields_vjp
  466 
  467     data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
  468     vjp_parallel: AutogradFieldMap = {}
  469     norm_cache: dict[int, np.ndarray] = {}

Lines 476-488

  476     used_bases = 0
  477     for basis in basis_specs:
  478         basis_map = basis_maps.get(basis)
  479         if basis_map is None:
! 480             continue
  481         basis_real = basis_map.get("real")
  482         basis_imag = basis_map.get("imag")
  483         if basis_real is None or basis_imag is None:
! 484             continue
  485         tracked_bases += 1
  486         if isinstance(basis, DiffractionAdjointBasis):
  487             norm = norm_cache.get(basis.monitor_index)
  488             if norm is None:

Lines 518-531

  518                     f"{unused_sims} simulations were unused. Disable parallel adjoint to avoid "
  519                     "unused precomputations."
  520                 )
  521             else:
! 522                 td.log.warning(
  523                     f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  524                     "evaluation. Disable parallel adjoint to avoid unused precomputations."
  525                 )
  526         else:
! 527             td.log.warning(
  528                 f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  529                 f"evaluation; {unused_bases} had zero VJP coefficients. Disable parallel adjoint "
  530                 "to avoid unused precomputations."
  531             )

@marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 0c9fe1b to 4d59014 on January 27, 2026 at 15:56
@marcorudolphflex force-pushed the branch from 4d59014 to 7cabb37 on January 27, 2026 at 16:10
@yaugenst-flex (Collaborator) left a comment:
Thanks @marcorudolphflex, this is pretty great. I had a cursory glance at the PR to try to understand a bit what's going on and left some questions/comments, but I'll look deeper into the implementation when I find some time. I guess one thing to note is that this introduces a lot of new code, even modules. Not a problem in itself, but I'd maybe take a closer look at whether any of this can be simplified.
Also, could you show some plots/verification against the non-parallel adjoint?

- Mode direction policy (for mode monitors): `config.adjoint.parallel_adjoint_mode_direction_policy`
- `"assume_outgoing"` (default): pick the mode direction based on monitor position relative to the simulation center and flip it for the adjoint.
- `"run_both_directions"`: launch parallel adjoint sources for both `+` and `-` directions.
- `"no_parallel"`: disable parallel adjoint entirely.
Collaborator:
Why do mode monitors separately have a flag to turn parallel adjoint off, in addition to the global config?

@marcorudolphflex (Author):
TBD if users need that in case they want to override the global toggle for this less-determined mode monitor... As we have a config field anyway, I think it doesn't hurt. Or could that be confusing for users regarding its effect alongside the global toggle?

Contributor:
By "less-determined", do you mean that it's harder to predict the adjoint sources to run in parallel, and that's why a user would want to turn it off?

Comment on lines +69 to +70
- Only effective when: `config.adjoint.local_gradient = True`
- If `local_gradient=False`, the flag is ignored and behavior remains unchanged.
Collaborator:
Why only local gradients? Couldn't this be supported in remote too? Maybe it's fine as an initial version, but I don't see how this couldn't be done for remote?

@marcorudolphflex (Author):
Probably yes, this was the "easy" start.


#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
Collaborator:
Are all parallel simulations counted as adjoint toward this cap?

@marcorudolphflex (Author):
Yes.

#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
- If enabling parallel adjoint would exceed the cap, the run logs a warning and proceeds with the sequential path for that forward run (or a safe subset, depending on policy).
Collaborator:
We might not want to proceed at all in that case, not sure. Since this a flag that we wouldn't turn on by default, it means that generally the user will have requested it, so they might want to choose to increase the cap instead of running sequentially.

@marcorudolphflex (Author):
True, changed it to raise an AdjointError, as we currently do for sequential adjoint.
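
For reference, a sketch of the resulting guardrail, consistent with the grouped-bases check visible in the parallel_adjoint.py excerpt above (the function wrapper and message wording are illustrative; AdjointError is the exception the adjoint pipeline already uses):

```python
from tidy3d.exceptions import AdjointError

def check_adjoint_cap(grouped: list, max_num_adjoint_per_fwd: int) -> None:
    """Raise instead of silently falling back when the parallel batch exceeds the cap."""
    if len(grouped) > max_num_adjoint_per_fwd:
        raise AdjointError(
            f"Parallel adjoint requires {len(grouped)} adjoint simulations, which exceeds "
            f"config.adjoint.max_adjoint_per_fwd ({max_num_adjoint_per_fwd}). "
            "Increase the cap or disable parallel adjoint."
        )
```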

Collaborator:
It would be important to explain/understand here in which scenarios how many adjoint simulations would get launched in the parallel case and what the edge cases are so there are no surprises.

@marcorudolphflex (Author):
Added a section on this in the readme.

@marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 7cabb37 to f351142 on January 28, 2026 at 08:19
@marcorudolphflex force-pushed the branch from f351142 to 9abfd55 on January 28, 2026 at 09:46
@marcorudolphflex force-pushed the branch from 9abfd55 to ac73e25 on January 28, 2026 at 11:34
@marcorudolphflex force-pushed the branch from ac73e25 to dac5685 on January 28, 2026 at 15:12
@marcorudolphflex force-pushed the branch from dac5685 to 29306de on January 28, 2026 at 16:16
@marcorudolphflex force-pushed the branch from 29306de to e05cfd7 on January 29, 2026 at 13:54
@marcorudolphflex force-pushed the branch from e05cfd7 to e13d19f on January 29, 2026 at 15:05
@marcorudolphflex force-pushed the branch from e13d19f to 6b77f75 on January 29, 2026 at 16:37

When enabled, Tidy3D launches eligible adjoint simulations in parallel with the forward simulation by running a set of canonical "unit" adjoint solves up front. During the backward pass, it reuses those precomputed results and scales them with the actual VJP coefficients from your objective.

Net effect: reduced gradient wall-clock time (often close to 2x faster in the "many deterministic adjoints" regime), at the cost of sometimes running adjoint solves that your objective ultimately does not use.
Contributor:
Are there cases where the number of unused simulations can be large? I guess this is where `max_adjoint_per_fwd` protects the user from accidentally running a bunch of sims? Would we want to also issue a warning if a large number of adjoint simulations are unused (maybe this is done already)?

@marcorudolphflex (Author):

Yes, this can happen if the objective does not use all modes or frequencies. Added a warning that reports the number of unused parallel simulations.

Parallel adjoint launches one canonical adjoint simulation per eligible “basis,” so the total
count is driven by how many distinct outputs your monitors expose:

- **Mode monitors**: one basis per `(freq, mode_index, direction)`. If
Contributor:

In the sequential version, we do some grouping for adjoint sources based on the number of ports versus the number of frequencies. Does the same happen for parallel adjoint?

I'm thinking of a multi-port optimization (like multiple mode monitors) and a single frequency optimization. In sequential when we know the vjp coming from the objective, we can launch all the adjoint sources at the same time provided we set the amplitude and phase for the single frequency. Thinking through this, we wouldn't be able to do the same for the parallel case right? Instead, if I'm understanding correctly, we would run them all as single frequency injections from each mode monitor and then combine results after computing the objective function grad.

@marcorudolphflex (Author):

No, I also think this is not possible, as we do not know the individual VJPs from the ports beforehand. Added a note in the readme to clarify that.

) -> None:
    if not parallel_info or not sims_adj:
        return
    td.log.warning(
Contributor:

Does this mean part of the process will happen with parallel and part with sequential?

@marcorudolphflex (Author):

Currently yes. But this does not really make sense at that point. Changed it such that we completely fall back to sequential in this case.

raise td.exceptions.AdjointError(
    "Parallel adjoint basis frequency not found in adjoint post-normalization."
)
return post_norm.isel(f=[idx])
Contributor:

Can this be called without the list, i.e. `post_norm.isel(f=idx)`?

@marcorudolphflex (Author):

Using a scalar index would drop the `f` dimension in xarray, which breaks downstream expectations for `post_norm.f` and frequency-aligned broadcasting. The list keeps a length-1 `f` dim and matches `_select_monitor_data_freq`.
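
A minimal, self-contained xarray demonstration of the difference (generic data, not from the PR):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(3.0), coords={"f": [1e14, 2e14, 3e14]}, dims="f")
print(da.isel(f=1).dims)    # () -> scalar index drops the "f" dimension
print(da.isel(f=[1]).dims)  # ('f',) -> list index keeps a length-1 "f" dimension
```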

monitor = simulation.monitors[basis.monitor_index]
fwidth = adjoint_fwidth_from_simulation(simulation)

if isinstance(basis, DiffractionAdjointBasis):
Contributor:

Is it expected that we would have cases of mismatched basis and monitor type?

@marcorudolphflex (Author):

Not in the current setup, but it was introduced to detect regressions. I guess this is over-defensive here, though; removed it.

@groberts-flex (Contributor) commented:

Thanks @marcorudolphflex for working on this, it is a really huge effort! I've been working through the doc you included here and the code and still have a good bit to go. I am curious to understand a bit more about the cases that are mostly being accelerated with the parallel adjoint approach and which cases we end up needing to run extra simulations that we don't end up needing as a result (or which may have been able to be grouped if running sequentially).
I think it would be worth us all talking through the approach a bit more and ways we could simplify some of the logic here. At the very least, for code review purposes, it might be easier if this can be broken up into smaller changes.
I'm excited to chat through, it's definitely a cool feature, thanks again for working on it!

@marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch 3 times, most recently from 2ace906 to 9cdf564 on February 3, 2026 at 11:21
@marcorudolphflex (Author) commented:
> Thanks @marcorudolphflex for working on this, it is a really huge effort! I've been working through the doc you included here and the code and still have a good bit to go. I am curious to understand a bit more about the cases that are mostly being accelerated with the parallel adjoint approach and which cases we end up needing to run extra simulations that we don't end up needing as a result (or which may have been able to be grouped if running sequentially). I think it would be worth us all talking through the approach a bit more and ways we could simplify some of the logic here. At the very least, for code review purposes, it might be easier if this can be broken up into smaller changes. I'm excited to chat through, it's definitely a cool feature, thanks again for working on it!

Thanks for the review!
I can understand that this is not easy to review. Unfortunately, I am not quite sure how exactly I could split it up, as most changes are related; do you have an idea?
We should definitely discuss how individual cases are handled and how this could be further configured.

I think the essential cases where we run more or unused simulations are frequency-grouped simulations and cases where monitor components have a zero VJP because the objective does not use them.

@marcorudolphflex force-pushed the branch from 9cdf564 to cacd0bf on February 3, 2026 at 11:59
@cursor (bot) left a comment:

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

values = np.nan_to_num(values, nan=0.0)

if np.all(values == 0):
    return None

Floating-point equality check for array values

Medium Severity

The check if np.all(values == 0): uses exact floating-point equality to determine if an adjoint source has zero amplitude. Due to floating-point arithmetic in the computation of values (involving scaling_factor, EPSILON_0, omega0, and size_element), the array may contain very small non-zero values even when the source is effectively zero. This could cause the function to return a non-None source when it should return None, or vice versa.
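
One possible tolerance-based variant (a sketch only; the atol value is an assumption, and the PR would need a scale-aware tolerance):

```python
import numpy as np

def is_effectively_zero(values: np.ndarray, atol: float = 1e-30) -> bool:
    """Tolerance-based replacement for `np.all(values == 0)`."""
    return np.allclose(values, 0.0, atol=atol)
```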


@groberts-flex (Contributor) commented:

> Thanks @marcorudolphflex for working on this, it is a really huge effort! I've been working through the doc you included here and the code and still have a good bit to go. I am curious to understand a bit more about the cases that are mostly being accelerated with the parallel adjoint approach and which cases we end up needing to run extra simulations that we don't end up needing as a result (or which may have been able to be grouped if running sequentially). I think it would be worth us all talking through the approach a bit more and ways we could simplify some of the logic here. At the very least, for code review purposes, it might be easier if this can be broken up into smaller changes. I'm excited to chat through, it's definitely a cool feature, thanks again for working on it!
>
> Thanks for the review! I can understand that this is not easy to review. Unfortunately, I am not quite sure how I could exactly split it up as most changes are related, do you have an idea? We should definitely discuss how individual cases are handled and how this could be further configured.
>
> I think the essential cases where we use more or unused simulations are frequency-grouped simulations and cases where monitor components have a 0-vjp as the objective does not use them.

Thanks for your reply on this and the extra information. I was intending to get to thinking through more how it might be broken up today, but ended up spending the day on the CustomMedium debugging. I'll get around to this tomorrow and maybe we could chat later Thursday sometime on the different cases here!
