feat(tidy3d): FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint #3208
Conversation
@greptile
8386a06 to 0c9fe1b Compare
Technically still semi-drafty; marked as ready for Cursor Bugbot. Will re-request review when really ready.
Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes. Lines marked `!` are flagged as uncovered.

Summary

**tidy3d/components/autograd/parallel_adjoint_bases.py**

```
Lines 18-26
   18
   19  def _coord_index(coord_values: np.ndarray, target: object) -> int:
   20      values = np.asarray(coord_values)
   21      if values.size == 0:
!  22          raise ValueError("No coordinate values available to index.")
   23      if values.dtype.kind in ("f", "c"):
   24          matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
   25      else:
   26          matches = np.where(values == target)[0]

Lines 57-65
   57          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
   58      ) -> complex:
   59          vjp = data_fields_vjp.get(self.data_path)
   60          if vjp is None:
!  61              return 0.0 + 0.0j
   62          data_index = self._data_index_from_sim_data(sim_data_orig)
   63          vjp_array = np.asarray(vjp)
   64          value = complex(vjp_array[data_index])
   65          return value

Lines 68-80
   68          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
   69      ) -> None:
   70          vjp = data_fields_vjp.get(self.data_path)
   71          if vjp is None:
!  72              return
   73          vjp_array = np.asarray(vjp)
   74          vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
   75          if vjp_array is not vjp:
!  76              data_fields_vjp[self.data_path] = vjp_array
   77
   78
   79  @dataclass(frozen=True)
   80  class DiffractionAdjointBasis:

Lines 101-109
  101          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData, norm: np.ndarray
  102      ) -> complex:
  103          vjp = data_fields_vjp.get(self.data_path)
  104          if vjp is None:
! 105              return 0.0 + 0.0j
  106          try:
  107              data_index = self._data_index_from_sim_data(sim_data_orig)
  108          except ValueError:
  109              return 0.0 + 0.0j

Lines 113-129
  113          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  114      ) -> None:
  115          vjp = data_fields_vjp.get(self.data_path)
  116          if vjp is None:
! 117              return
  118          try:
  119              data_index = self._data_index_from_sim_data(sim_data_orig)
! 120          except ValueError:
! 121              return
  122          vjp_array = np.asarray(vjp)
  123          vjp_array[data_index] = 0.0
  124          if vjp_array is not vjp:
! 125              data_fields_vjp[self.data_path] = vjp_array
  126
  127
  128  @dataclass(frozen=True)
  129  class PointFieldAdjointBasis:

Lines 156-168
  156          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  157      ) -> None:
  158          vjp = data_fields_vjp.get(self.data_path)
  159          if vjp is None:
! 160              return
  161          vjp_array = np.asarray(vjp)
  162          vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  163          if vjp_array is not vjp:
! 164              data_fields_vjp[self.data_path] = vjp_array
  165
  166
  167  ParallelAdjointBasis = ModeAdjointBasis | DiffractionAdjointBasis | PointFieldAdjointBasis

Lines 200-208
  200  ) -> list[PointFieldAdjointBasis]:
  201      bases: list[PointFieldAdjointBasis] = []
  202      for component, freqs in component_freqs:
  203          if component not in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
! 204              continue
  205          for freq in freqs:
  206              bases.append(
  207                  PointFieldAdjointBasis(
  208                      monitor_index=monitor_index,

Lines 229-241
  229      for order_x in orders_x:
  230          for order_y in orders_y:
  231              angle_theta = float(theta_for(int(order_x), int(order_y)))
  232              if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 233                  continue
  234              for pol in pols:
  235                  pol_str = str(pol)
  236                  if pol_str not in ("s", "p"):
! 237                      continue
  238                  dataset_name = "Ephi" if pol_str == "s" else "Etheta"
  239                  bases.append(
  240                      DiffractionAdjointBasis(
  241                          monitor_index=monitor_index,
```

**tidy3d/components/autograd/source_factory.py**

```
Lines 20-28
   20  def flip_direction(direction: object) -> str:
   21      if hasattr(direction, "values"):
   22          direction = str(direction.values)
   23      if direction not in ("+", "-"):
!  24          raise ValueError(f"Direction must be in {('+', '-')}, got '{direction}'.")
   25      return "-" if direction == "+" else "+"
   26
   27
   28  def adjoint_fwidth_from_simulation(simulation: Simulation) -> float:

Lines 75-85
   75      coefficient: complex,
   76      fwidth: float,
   77  ) -> CustomCurrentSource | None:
   78      if any(simulation.symmetry):
!  79          raise ValueError("Point-field adjoint sources require symmetry to be disabled.")
   80      if not monitor.colocate:
!  81          raise ValueError("Point-field adjoint sources require colocated field monitors.")
   82
   83      grid = simulation.discretize_monitor(monitor)
   84      coords = {}
   85      spatial_coords = grid.boundaries

Lines 87-95
   87      for axis, dim in enumerate("xyz"):
   88          if monitor.size[axis] == 0:
   89              coords[dim] = np.array([monitor.center[axis]])
   90          else:
!  91              coords[dim] = np.array(spatial_coords_dict[dim][:-1])
   92      values = (
   93          2
   94          * -1j
   95          * coefficient

Lines 116-124
  116      values *= scaling_factor
  117      values = np.nan_to_num(values, nan=0.0)
  118
  119      if np.all(values == 0):
! 120          return None
  121
  122      dataset = FieldDataset(**{component: ScalarFieldDataArray(values, coords=coords)})
  123      return CustomCurrentSource(
  124          center=monitor.geometry.center,

Lines 132-140
  132  def diffraction_monitor_medium(simulation: Simulation, monitor: DiffractionMonitor) -> object:
  133      structures = [simulation.scene.background_structure, *list(simulation.structures or ())]
  134      mediums = simulation.scene.intersecting_media(monitor, structures)
  135      if len(mediums) != 1:
! 136          raise ValueError("Diffraction monitor plane must be homogeneous to build adjoint sources.")
  137      return list(mediums)[0]
  138
  139
  140  def bloch_vec_for_axis(simulation: Simulation, axis_name: str) -> float:

Lines 141-149
  141      boundary = simulation.boundary_spec[axis_name]
  142      plus = boundary.plus
  143      if hasattr(plus, "bloch_vec"):
  144          return float(plus.bloch_vec)
! 145      return 0.0
  146
  147
  148  def diffraction_order_range(
  149      size: float, bloch_vec: float, freq: float, medium: object

Lines 155-163
  155      limit = abs(index) * freq * size / C_0
  156      order_min = int(np.ceil(-limit - bloch_vec))
  157      order_max = int(np.floor(limit - bloch_vec))
  158      if order_max < order_min:
! 159          return np.array([], dtype=int)
  160      return np.arange(order_min, order_max + 1, dtype=int)

Lines 198-206
  198      theta_vals, phi_vals = DiffractionData.compute_angles((ux, uy))
  199      angle_theta = float(theta_vals[0, 0, 0])
  200      angle_phi = float(phi_vals[0, 0, 0])
  201      if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 202          raise ValueError("Adjoint source not available for evanescent diffraction order.")
  203
  204      pol_angle = 0.0 if polarization == "p" else np.pi / 2
  205      bck_eps = medium.eps_model(freq)
  206      return _diffraction_plane_wave(

Lines 230-242
  230      angle_theta = float(theta_data.sel(**angle_sel_kwargs))
  231      angle_phi = float(phi_data.sel(**angle_sel_kwargs))
  232
  233      if np.isnan(angle_theta):
! 234          return None
  235
  236      pol_str = str(polarization)
  237      if pol_str not in ("p", "s"):
! 238          raise ValueError(f"Something went wrong, given pol='{pol_str}' in adjoint source.")
  239
  240      pol_angle = 0.0 if pol_str == "p" else np.pi / 2
  241      bck_eps = diff_data.medium.eps_model(freq)
  242      return _diffraction_plane_wave(
```

**tidy3d/components/autograd/utils.py**

```
Lines 62-70
   62          if k in target:
   63              val = target[k]
   64              if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
   65                  if len(val) != len(v):
!  66                      raise ValueError(
   67                          f"Cannot accumulate field map for key '{k}': "
   68                          f"length mismatch ({len(val)} vs {len(v)})."
   69                      )
   70                  target[k] = type(val)(x + y for x, y in zip(val, v))
```

**tidy3d/components/data/monitor_data.py**

```
Lines 174-186
  174          return []
  175
  176      def supports_parallel_adjoint(self) -> bool:
  177          """Return ``True`` if this monitor data supports parallel adjoint sources."""
! 178          return False
  179
  180      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  181          """Return parallel adjoint bases for this monitor data."""
! 182          return []
  183
  184      @staticmethod
  185      def get_amplitude(x) -> complex:
  186          """Get the complex amplitude out of some data."""

Lines 1463-1471
 1463
 1464      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
 1465          """Return parallel adjoint bases for single-point field monitors."""
 1466          if not self.supports_parallel_adjoint():
!1467              return []
 1468          component_freqs = [
 1469              (str(component), data_array.coords["f"].values)
 1470              for component, data_array in self.field_components.items()
 1471          ]

Lines 1875-1883
 1875          return val
 1876
 1877      def supports_parallel_adjoint(self) -> bool:
 1878          """Return ``True`` for mode monitor amplitude adjoints."""
!1879          return True
 1880
 1881      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
 1882          """Return parallel adjoint bases for mode monitor amplitudes."""
 1883          amps = self.amps

Lines 4106-4114
 4106          return DataArray(np.stack([amp_phi, amp_theta], axis=3), coords=coords)
 4107
 4108      def supports_parallel_adjoint(self) -> bool:
 4109          """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
!4110          return True
 4111
 4112      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
 4113          """Return parallel adjoint bases for diffraction monitor amplitudes."""
 4114          amps = self.amps
```

**tidy3d/components/monitor.py**

```
Lines 125-133
  125          return self.storage_size(num_cells=num_cells, tmesh=tmesh)
  126
  127      def supports_parallel_adjoint(self) -> bool:
  128          """Return ``True`` if this monitor can provide parallel adjoint bases."""
! 129          return False
  130
  131      def parallel_adjoint_bases(
  132          self, simulation: Simulation, monitor_index: int
  133      ) -> list[ParallelAdjointBasis]:

Lines 131-139
  131      def parallel_adjoint_bases(
  132          self, simulation: Simulation, monitor_index: int
  133      ) -> list[ParallelAdjointBasis]:
  134          """Return parallel adjoint bases for this monitor."""
! 135          return []
  136
  137
  138  class FreqMonitor(Monitor, ABC):
  139      """:class:`Monitor` that records data in the frequency-domain."""

Lines 1112-1120
 1112          return amps_size + fields_size
 1113
 1114      def supports_parallel_adjoint(self) -> bool:
 1115          """Return ``True`` for mode monitor amplitude adjoints."""
!1116          return True
 1117
 1118      def parallel_adjoint_bases(
 1119          self, simulation: Simulation, monitor_index: int
 1120      ) -> list[ParallelAdjointBasis]:

Lines 1853-1861
 1853          return BYTES_COMPLEX * len(self.ux) * len(self.uy) * len(self.freqs) * 6
 1854
 1855      def supports_parallel_adjoint(self) -> bool:
 1856          """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
!1857          return True
 1858
 1859      def parallel_adjoint_bases(
 1860          self, simulation: Simulation, monitor_index: int
 1861      ) -> list[ParallelAdjointBasis]:

Lines 1859-1867
 1859      def parallel_adjoint_bases(
 1860          self, simulation: Simulation, monitor_index: int
 1861      ) -> list[ParallelAdjointBasis]:
 1862          """Return parallel adjoint bases for diffraction monitor amplitudes."""
!1863          from tidy3d.components.autograd.source_factory import (
 1864              bloch_vec_for_axis,
 1865              diffraction_monitor_medium,
 1866              diffraction_order_range,
 1867          )

Lines 1865-1902
 1865              diffraction_monitor_medium,
 1866              diffraction_order_range,
 1867          )
 1868
!1869          medium = diffraction_monitor_medium(simulation, self)
 1870
!1871          axis_names = ("x", "y", "z")
!1872          normal_axis = self.normal_axis
!1873          transverse_axes = [axis_names[i] for i in range(3) if i != normal_axis]
!1874          axis_x, axis_y = transverse_axes
 1875
!1876          size_x = simulation.size[axis_names.index(axis_x)]
!1877          size_y = simulation.size[axis_names.index(axis_y)]
!1878          bloch_vec_x = bloch_vec_for_axis(simulation, axis_x)
!1879          bloch_vec_y = bloch_vec_for_axis(simulation, axis_y)
 1880
!1881          bases: list[DiffractionAdjointBasis] = []
!1882          freqs = [float(freq) for freq in self.freqs]
!1883          for freq in freqs:
!1884              orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
!1885              orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
!1886              if orders_x.size == 0 or orders_y.size == 0:
!1887                  continue
 1888
!1889              ux = _reciprocal_coords(
 1890                  orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
 1891              )
!1892              uy = _reciprocal_coords(
 1893                  orders=orders_y, size=size_y, bloch_vec=bloch_vec_y, freq=freq, medium=medium
 1894              )
!1895              theta_vals, _ = _compute_angles((ux, uy))
!1896              order_x_index = {int(val): idx for idx, val in enumerate(orders_x)}
!1897              order_y_index = {int(val): idx for idx, val in enumerate(orders_y)}
!1898              bases.extend(
 1899                  _build_diffraction_bases_for_freq(
 1900                      monitor_name=self.name,
 1901                      monitor_index=monitor_index,
 1902                      freq=freq,

Lines 1911-1919
 1911                          order_x_index[ox], order_y_index[oy], 0
 1912                      ],
 1913                  )
 1914              )
!1915          return bases
 1916
 1917
 1918  class DiffractionMonitor(PlanarMonitor, FreqMonitor):
 1919      """:class:`Monitor` that uses a 2D Fourier transform to compute the

Lines 1987-1995
 1987          return BYTES_COMPLEX * len(self.freqs)
 1988
 1989      def supports_parallel_adjoint(self) -> bool:
 1990          """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
!1991          return True
 1992
 1993      def parallel_adjoint_bases(
 1994          self, simulation: Simulation, monitor_index: int
 1995      ) -> list[ParallelAdjointBasis]:

Lines 2017-2025
 2017          for freq in freqs:
 2018              orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
 2019              orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
 2020              if orders_x.size == 0 or orders_y.size == 0:
!2021                  continue
 2022
 2023              ux = _reciprocal_coords(
 2024                  orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
 2025              )
```

**tidy3d/web/api/autograd/parallel_adjoint.py**

```
Lines 40-48
   40  def _scale_field_map(field_map: AutogradFieldMap, scale: float) -> AutogradFieldMap:
   41      scaled = {}
   42      for k, v in field_map.items():
   43          if isinstance(v, (list, tuple)):
!  44              scaled[k] = type(v)(scale * x for x in v)
   45          else:
   46              scaled[k] = scale * v
   47      return scaled

Lines 59-69
   59      unsupported: list[str] = []
   60      for monitor_index, monitor in enumerate(simulation.monitors):
   61          try:
   62              bases_for_monitor = monitor.parallel_adjoint_bases(simulation, monitor_index)
!  63          except ValueError:
!  64              unsupported.append(monitor.name)
!  65              continue
   66          if bases_for_monitor:
   67              bases.extend(bases_for_monitor)
   68          elif not monitor.supports_parallel_adjoint():
   69              unsupported.append(monitor.name)

Lines 103-115
  103      basis_spec: object,
  104  ) -> object:
  105      post_norm = sim_data_adj.simulation.post_norm
  106      if not hasattr(basis_spec, "freq"):
! 107          return post_norm
  108      freqs = np.asarray(post_norm.coords["f"].values)
  109      idx = int(np.argmin(np.abs(freqs - basis_spec.freq)))
  110      if not np.isclose(freqs[idx], basis_spec.freq):
! 111          raise td.exceptions.AdjointError(
  112              "Parallel adjoint basis frequency not found in adjoint post-normalization."
  113          )
  114      return post_norm.isel(f=[idx])

Lines 132-150
  132          for key, data_array in monitor_data.field_components.items():
  133              if "f" in data_array.dims:
  134                  freqs = np.asarray(data_array.coords["f"].values)
  135                  if freqs.size == 0:
! 136                      raise td.exceptions.AdjointError(
  137                          "Parallel adjoint expected frequency data but no frequencies were found."
  138                      )
  139                  idx = int(np.argmin(np.abs(freqs - freq)))
  140                  if not np.isclose(freqs[idx], freq, rtol=1e-10, atol=0.0):
! 141                      raise td.exceptions.AdjointError(
  142                          "Parallel adjoint basis frequency not found in monitor data."
  143                      )
  144                  updates[key] = data_array.isel(f=[idx])
  145          return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False, **updates)
! 146      return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False)
  147
  148
  149  def _select_sim_data_freq(
  150      sim_data_adj: td.SimulationData,

Lines 156-164
  156      for monitor in sim.monitors:
  157          if hasattr(monitor, "freqs"):
  158              monitor_updated = monitor.updated_copy(freqs=[freq])
  159          else:
! 160              monitor_updated = monitor
  161          monitors.append(monitor_updated)
  162          monitor_map[monitor.name] = monitor_updated
  163      sim_updated = sim.updated_copy(monitors=monitors)

Lines 224-232
  224      simulation: td.Simulation,
  225      basis_sources: list[tuple[ParallelAdjointBasis, Any]],
  226  ) -> list[tuple[list[ParallelAdjointBasis], AdjointSourceInfo]]:
  227      if not basis_sources:
! 228          return []
  229
  230      sim_data_stub = td.SimulationData(simulation=simulation, data=())
  231      sources = [source for _, source in basis_sources]
  232      sources_processed = td.SimulationData._adjoint_src_width_single(sources)

Lines 310-321
  310              coefficient=coefficient,
  311              fwidth=fwidth,
  312          )
  313          if source is None:
! 314              raise ValueError("Adjoint point source has zero amplitude.")
  315          return adjoint_source_info_single(source)
  316
! 317      raise ValueError("Unsupported parallel adjoint basis.")
  318
  319
  320  @dataclass(frozen=True)
  321  class ParallelAdjointPayload:

Lines 370-392
  370                  simulation=simulation,
  371                  basis=basis,
  372                  coefficient=1.0 + 0.0j,
  373              )
! 374          except ValueError as exc:
! 375              td.log.info(
  376                  f"Skipping parallel adjoint basis for monitor '{basis.monitor_name}': {exc}"
  377              )
! 378              continue
  379          basis_sources.append((basis, source_info.sources[0]))
  380
  381      if not basis_sources:
! 382          if basis_specs:
! 383              td.log.info("Parallel adjoint produced no simulations for this task.")
  384          else:
! 385              td.log.warning(
  386                  "Parallel adjoint disabled because no eligible monitor outputs were found."
  387              )
! 388          return None
  389
  390      grouped = _group_parallel_adjoint_bases_by_port(simulation, basis_sources)
  391      if len(grouped) > max_num_adjoint_per_fwd:
  392          raise AdjointError(

Lines 411-425
  411          task_map[adj_task_name] = bases
  412          used_bases.extend(bases)
  413
  414      if not sims_adj_dict:
! 415          if basis_specs:
! 416              td.log.info("Parallel adjoint produced no simulations for this task.")
  417          else:
! 418              td.log.warning(
  419                  "Parallel adjoint disabled because no eligible monitor outputs were found."
  420              )
! 421          return None
  422
  423      td.log.info(
  424          "Parallel adjoint enabled: launched "
  425          f"{len(sims_adj_dict)} canonical adjoint simulations for task '{task_name}'."

Lines 437-445
  437      task_paths: dict[str, str],
  438      base_dir: PathLike,
  439  ) -> None:
  440      if not task_names:
! 441          return
  442      target_dir = Path(base_dir) / config.adjoint.local_adjoint_dir
  443      target_dir.mkdir(parents=True, exist_ok=True)
  444      for task_name in task_names:
  445          src_path = task_paths.get(task_name)

Lines 444-459
  444      for task_name in task_names:
  445          src_path = task_paths.get(task_name)
  446          if not src_path:
  447              continue
! 448          src = Path(src_path)
! 449          if not src.exists():
! 450              continue
! 451          dst = target_dir / src.name
! 452          if src.resolve() == dst.resolve():
! 453              continue
! 454          dst.parent.mkdir(parents=True, exist_ok=True)
! 455          src.replace(dst)
  456
  457
  458  def apply_parallel_adjoint(
  459      data_fields_vjp: AutogradFieldMap,

Lines 461-469
  461      sim_data_orig: td.SimulationData,
  462  ) -> tuple[AutogradFieldMap, AutogradFieldMap]:
  463      basis_maps = parallel_info.get("basis_maps")
  464      if basis_maps is None:
! 465          return {}, data_fields_vjp
  466
  467      data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
  468      vjp_parallel: AutogradFieldMap = {}
  469      norm_cache: dict[int, np.ndarray] = {}

Lines 476-488
  476      used_bases = 0
  477      for basis in basis_specs:
  478          basis_map = basis_maps.get(basis)
  479          if basis_map is None:
! 480              continue
  481          basis_real = basis_map.get("real")
  482          basis_imag = basis_map.get("imag")
  483          if basis_real is None or basis_imag is None:
! 484              continue
  485          tracked_bases += 1
  486          if isinstance(basis, DiffractionAdjointBasis):
  487              norm = norm_cache.get(basis.monitor_index)
  488              if norm is None:

Lines 518-531
  518              f"{unused_sims} simulations were unused. Disable parallel adjoint to avoid "
  519              "unused precomputations."
  520          )
  521      else:
! 522          td.log.warning(
  523              f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  524              "evaluation. Disable parallel adjoint to avoid unused precomputations."
  525          )
  526  else:
! 527      td.log.warning(
  528          f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  529          f"evaluation; {unused_bases} had zero VJP coefficients. Disable parallel adjoint "
  530          "to avoid unused precomputations."
  531      )
```
0c9fe1b to 4d59014 Compare
4d59014 to 7cabb37 Compare
yaugenst-flex left a comment:
Thanks @marcorudolphflex, this is pretty great. I had a cursory glance at the PR to try to understand a bit what's going on and left some questions/comments, but I'll look deeper into the implementation when I find some time. I guess one thing to note is that this introduces a lot of new code, even modules. Not a problem in itself, but I'd maybe have a closer look at whether any of this can be simplified.
Also, could you show some plots/verification against the non-parallel adjoint?
> - Mode direction policy (for mode monitors): `config.adjoint.parallel_adjoint_mode_direction_policy`
>   - `"assume_outgoing"` (default): pick the mode direction based on monitor position relative to the simulation center and flip it for the adjoint.
>   - `"run_both_directions"`: launch parallel adjoint sources for both `+` and `-` directions.
>   - `"no_parallel"`: disable parallel adjoint entirely.
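For illustration, a minimal sketch of how these flags might be toggled, assuming the config paths quoted above are exposed on the global `td.config` object (this usage is an assumption, not taken from the PR):

```python
import tidy3d as td

# Assumption: the quoted config paths live on the global `td.config` object.
td.config.adjoint.local_gradient = True     # parallel adjoint requires local gradients
td.config.adjoint.parallel_all_port = True  # enable the feature
# Launch unit adjoints for both "+" and "-" instead of guessing the outgoing one:
td.config.adjoint.parallel_adjoint_mode_direction_policy = "run_both_directions"
```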
Why do mode monitors separately have a flag to turn parallel adjoint off, in addition to the global config?

TBD whether users need that, in case they want to override the global toggle for this less-determined mode-monitor case. Since we have a config field anyway, I think it doesn't hurt. Or could that be confusing for users regarding its effect alongside the global toggle?

By "less-determined", do you mean that it's harder to predict which adjoint sources to run in parallel, and that's why a user would want to turn it off?
> - Only effective when: `config.adjoint.local_gradient = True`
> - If `local_gradient=False`, the flag is ignored and behavior remains unchanged.

Why only local gradients? Couldn't this be supported in remote too? Maybe it's fine as an initial version, but I don't see how this couldn't be done for remote.

Probably yes, this was the "easy" start.
> #### Limits and guardrails you should expect
>
> - Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
Are all parallel simulations counted as adjoint toward this cap?
tidy3d/plugins/autograd/README.md (outdated)

> #### Limits and guardrails you should expect
>
> - Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
> - If enabling parallel adjoint would exceed the cap, the run logs a warning and proceeds with the sequential path for that forward run (or a safe subset, depending on policy).
We might not want to proceed at all in that case, not sure. Since this is a flag that we wouldn't turn on by default, the user will generally have requested it, so they might want to increase the cap instead of running sequentially.

True, changed it to raise an AdjointError, as we currently do for sequential adjoint.

It would be important to explain/understand here which scenarios launch how many adjoint simulations in the parallel case, and what the edge cases are, so there are no surprises.

Added a section on this to the README.
7cabb37 to f351142 Compare
f351142 to 9abfd55 Compare
9abfd55 to ac73e25 Compare
ac73e25 to dac5685 Compare
dac5685 to 29306de Compare
29306de to e05cfd7 Compare
e05cfd7 to e13d19f Compare
e13d19f to 6b77f75 Compare
> When enabled, Tidy3D launches eligible adjoint simulations in parallel with the forward simulation by running a set of canonical "unit" adjoint solves up front. During the backward pass, it reuses those precomputed results and scales them with the actual VJP coefficients from your objective.
>
> Net effect: reduced gradient wall-clock time (often close to ~2x faster in the "many deterministic adjoints" regime), at the cost of sometimes running adjoint solves that your objective ultimately does not use.
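The reuse step works because the adjoint problem is linear in the VJP coefficients: each precomputed unit solve can be rescaled and summed once the objective's VJP is known. A minimal numpy sketch of that recombination with hypothetical names (not the actual tidy3d implementation), mirroring the per-basis "real"/"imag" maps visible in the diff above:

```python
import numpy as np

def recombine_unit_gradients(unit_grads_real, unit_grads_imag, coefficients):
    """Hypothetical sketch: combine per-basis unit gradients by linearity.

    unit_grads_real / unit_grads_imag: gradient arrays from unit adjoint
    solves driven by the real and imaginary excitation of each basis.
    coefficients: complex VJP values extracted during the backward pass.
    """
    if not unit_grads_real:
        return None  # no bases tracked
    total = np.zeros_like(unit_grads_real[0])
    for g_re, g_im, c in zip(unit_grads_real, unit_grads_imag, coefficients):
        # linearity: grad for coefficient c = Re(c) * grad_re + Im(c) * grad_im
        total += c.real * g_re + c.imag * g_im
    return total
```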
Are there cases where the number of unused simulations can be large? I guess this is where `max_adjoint_per_fwd` protects the user from accidentally running a bunch of sims? Would we also want to issue a warning if a large number of adjoint simulations are unused (maybe this is done already)?

Yes, this can happen if the objective does not use all modes or frequencies. Added a warning that reports the number of unused parallel simulations.
> Parallel adjoint launches one canonical adjoint simulation per eligible "basis," so the total
> count is driven by how many distinct outputs your monitors expose:
>
> - **Mode monitors**: one basis per `(freq, mode_index, direction)`. If
In the sequential version, we do some grouping of adjoint sources based on the number of ports versus the number of frequencies. Does the same happen for parallel adjoint?

I'm thinking of a multi-port optimization (like multiple mode monitors) with a single-frequency objective. In the sequential case, once we know the VJP coming from the objective, we can launch all the adjoint sources at the same time, provided we set the amplitude and phase for the single frequency. Thinking through this, we wouldn't be able to do the same in the parallel case, right? Instead, if I'm understanding correctly, we would run them all as single-frequency injections from each mode monitor and then combine results after computing the objective function grad.

No, I also think this is not possible, as we do not know the individual per-port VJPs beforehand. Added a note in the README to clarify that.
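To make the launch count concrete, a rough back-of-the-envelope sketch based on the counting rule quoted above (this is an assumed illustration; actual counts are also reduced by evanescent-order filtering and capped by `max_adjoint_per_fwd`):

```python
# Assumed counting rule from the README excerpt: a mode monitor contributes
# one basis (one canonical adjoint solve) per (freq, mode_index, direction).
n_freqs = 3
n_mode_indices = 2
n_directions = 2  # under "run_both_directions"; 1 under "assume_outgoing"
n_mode_bases = n_freqs * n_mode_indices * n_directions
print(n_mode_bases)  # 12 canonical adjoint simulations for this monitor
```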
> ) -> None:
>     if not parallel_info or not sims_adj:
>         return
>     td.log.warning(
Does this mean part of the process will happen with parallel and part with sequential?

Currently yes, but that does not really make sense at this point. Changed it so that we completely fall back to sequential in this case.
>         raise td.exceptions.AdjointError(
>             "Parallel adjoint basis frequency not found in adjoint post-normalization."
>         )
>     return post_norm.isel(f=[idx])
Can this be called without the list, i.e. `post_norm.isel(f=idx)`?

Using a scalar index would drop the `f` dimension in xarray, which breaks downstream expectations for `post_norm.f` and frequency-aligned broadcasting. The list keeps a length-1 `f` dim and matches `_select_monitor_data_freq`.
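For reference, the xarray behavior in question can be checked directly; this is standard `isel` semantics:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(3.0), coords={"f": [1e14, 2e14, 3e14]}, dims="f")
print(da.isel(f=1).dims)    # () -- scalar index drops the "f" dimension
print(da.isel(f=[1]).dims)  # ('f',) -- list index keeps a length-1 "f" dim
```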
> monitor = simulation.monitors[basis.monitor_index]
> fwidth = adjoint_fwidth_from_simulation(simulation)
>
> if isinstance(basis, DiffractionAdjointBasis):
Is it expected that we would have cases of mismatched basis and monitor type?

Not in the current setup; it was introduced to detect regressions. But I guess this is over-defensive here, so I removed it.
Thanks @marcorudolphflex for working on this, it is a really huge effort! I've been working through the doc you included here and the code and still have a good bit to go. I am curious to understand a bit more about which cases are mostly accelerated by the parallel adjoint approach, and in which cases we end up running extra simulations that we don't need as a result (or that could have been grouped if running sequentially).
2ace906 to 9cdf564 Compare
Thanks for the review! I think the essential cases where we run more or unused simulations are frequency-grouped simulations, and cases where monitor components have a zero VJP because the objective does not use them.
9cdf564 to cacd0bf Compare
Cursor Bugbot has reviewed your changes and found 1 potential issue.
> values = np.nan_to_num(values, nan=0.0)
>
> if np.all(values == 0):
>     return None
**Floating-point equality check for array values** (Medium Severity)

The check `if np.all(values == 0):` uses exact floating-point equality to determine if an adjoint source has zero amplitude. Due to floating-point arithmetic in the computation of `values` (involving `scaling_factor`, `EPSILON_0`, `omega0`, and `size_element`), the array may contain very small non-zero values even when the source is effectively zero. This could cause the function to return a non-None source when it should return None, or vice versa.
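A hedged sketch of the tolerance-based alternative Bugbot is suggesting; the `atol` value below is a placeholder, not a vetted threshold:

```python
import numpy as np

def is_effectively_zero(values: np.ndarray, atol: float = 1e-30) -> bool:
    """Return True if all entries are zero within a tolerance.

    Placeholder tolerance: a real fix would choose atol relative to the
    expected amplitude scale of the adjoint source values.
    """
    return bool(np.allclose(values, 0.0, atol=atol))
```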
Thanks for your reply on this and the extra information. I was intending to get to thinking through how this might be broken up today, but ended up spending the day on the CustomMedium debugging. I'll get around to this tomorrow, and maybe we could chat sometime Thursday about the different cases here!


Note: Medium Risk

Touches core autograd forward/backward execution paths and adds new scheduling/fallback logic that changes how/when adjoint simulations are launched under `local_gradient=True`, so regressions could affect gradient correctness or runtime behavior despite strong test coverage.

Overview

Enables parallel adjoint scheduling for autograd local gradients: when `config.adjoint.parallel_all_port=True`, the forward `run`/`run_async` can launch a batch containing the forward solve plus canonical "unit" adjoint simulations for eligible monitor outputs, then reuse/scale those precomputed results during VJP evaluation.

Introduces new parallel-adjoint infrastructure: monitor/monitor-data expose `supports_parallel_adjoint()` and `parallel_adjoint_bases()` for mode amplitudes, diffraction amplitudes, and single-point field probes; new source factories deterministically build the corresponding adjoint sources; and a new `parallel_adjoint` web API module prepares payloads, applies precomputed basis maps, and falls back to sequential adjoints for any remaining VJP entries (with warnings and caps via `max_adjoint_per_fwd`).

Refactors and hardens the adjoint pipeline by centralizing adjoint-simulation construction (`make_adjoint_simulation`), adding shared VJP-map filtering/NaN checks and a reusable `accumulate_field_map`, improving test emulation determinism, and adding extensive tests/docs/config reference for the new flags (including `parallel_adjoint_mode_direction_policy`).

Written by Cursor Bugbot for commit cacd0bf.
Greptile Overview

Greptile Summary

This PR implements parallel adjoint scheduling for autograd simulations, allowing eligible adjoint simulations to run concurrently with forward simulations when `local_gradient=True`. The feature launches canonical "unit" adjoint solves up front and scales them during the backward pass, reducing gradient computation wall-clock time.

Key changes:

- `config.adjoint.parallel_all_port` configuration flag to enable the feature
- `config.adjoint.parallel_adjoint_mode_direction_policy` to control mode direction handling
- `ParallelAdjointDescriptor` classes for mode, diffraction, and point-field monitors
- `supports_parallel_adjoint()` and `parallel_adjoint_descriptors()` methods
- `make_adjoint_simulation()` function

The implementation includes proper fallback mechanisms when monitors are unsupported or limits are exceeded, ensuring backward compatibility.

Confidence Score: 3/5

See `tidy3d/web/api/autograd/parallel_adjoint.py` (lines 322, 327, 329) and `tidy3d/components/autograd/source_factory.py` (lines 94, 207) for floating-point comparison fixes.

Important Files Changed

- `supports_parallel_adjoint()` and `parallel_adjoint_descriptors()` methods; extracted mode source creation to factory.
- `make_adjoint_simulation()` function for reuse; clean refactoring with no logic changes.

Sequence Diagram
```mermaid
sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant ParallelAdjoint as Parallel Adjoint
    participant Batch as Batch Executor
    participant Solver as FDTD Solver
    User->>AutogradAPI: run(sim, local_gradient=True)
    AutogradAPI->>ParallelAdjoint: prepare_parallel_adjoint(sim)
    ParallelAdjoint->>ParallelAdjoint: collect descriptors from monitors
    ParallelAdjoint->>ParallelAdjoint: filter by direction policy
    ParallelAdjoint->>ParallelAdjoint: create canonical adjoint sims
    ParallelAdjoint-->>AutogradAPI: ParallelAdjointPayload
    alt Parallel Adjoint Enabled
        AutogradAPI->>Batch: run_async({fwd, adj_1, adj_2, ...})
        Batch->>Solver: run forward sim
        Batch->>Solver: run adjoint sim 1
        Batch->>Solver: run adjoint sim 2
        Batch-->>AutogradAPI: BatchData
        AutogradAPI->>AutogradAPI: populate_parallel_adjoint_bases()
        AutogradAPI-->>User: SimulationData + aux_data
    else Parallel Adjoint Disabled
        AutogradAPI->>Solver: run forward sim only
        AutogradAPI-->>User: SimulationData
    end
    User->>AutogradAPI: backward pass (VJP)
    AutogradAPI->>ParallelAdjoint: apply_parallel_adjoint(vjp, bases)
    ParallelAdjoint->>ParallelAdjoint: compute coefficients from VJP
    ParallelAdjoint->>ParallelAdjoint: scale and accumulate basis maps
    ParallelAdjoint-->>AutogradAPI: vjp_parallel + vjp_fallback
    alt Has fallback VJPs
        AutogradAPI->>Solver: run sequential adjoint for remaining
        Solver-->>AutogradAPI: adjoint field data
        AutogradAPI->>AutogradAPI: combine vjp_parallel + sequential
    end
    AutogradAPI-->>User: gradient
```