FXC-4927 enable source differentiation#3197
Conversation
tests/test_components/autograd/numerical/test_autograd_source_numerical.py
Outdated
Show resolved
Hide resolved
9542484 to
5dec731
Compare
5dec731 to
6aa0c0b
Compare
6aa0c0b to
7d69d2a
Compare
7d69d2a to
e93bb93
Compare
e93bb93 to
360ac7a
Compare
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.
| scale = scale * step | ||
| if dim not in dims_to_integrate and field_data.sizes.get(dim, 0) > 1: | ||
| scale = scale / field_data.sizes[dim] | ||
| return weights * scale |
There was a problem hiding this comment.
Unused exported function never used in production
Low Severity
The function compute_source_weights is defined, added to __all__, and has a unit test, but it is never actually imported or used anywhere in the production code. CustomCurrentSource._compute_derivatives only imports transpose_interp_field_to_dataset, and CustomFieldSource._compute_derivatives imports compute_spatial_weights (not compute_source_weights), get_frequency_omega, and transpose_interp_field_to_dataset. This is dead code that adds maintenance burden without being utilized.
Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/autograd/derivative_utils.pyLines 1063-1071 1063 """
1064
1065 def _cell_size_weights(coord: np.ndarray) -> np.ndarray:
1066 if coord.size <= 1:
! 1067 return np.array([1.0], dtype=float)
1068 deltas = np.diff(coord)
1069 diff_left = np.pad(deltas, (1, 0), mode="edge")
1070 diff_right = np.pad(deltas, (0, 1), mode="edge")
1071 return 0.5 * (diff_left + diff_right)Lines 1073-1081 1073 weight_dims = []
1074 weight_arrays = []
1075 for dim in dims:
1076 if dim not in arr.coords:
! 1077 continue
1078 coord = np.asarray(arr.coords[dim].data)
1079 if coord.size <= 1:
1080 continue
1081 weight_dims.append(dim)Lines 1081-1089 1081 weight_dims.append(dim)
1082 weight_arrays.append(_cell_size_weights(coord))
1083
1084 if not weight_dims:
! 1085 return SpatialDataArray(1.0)
1086
1087 weights = np.ix_(*weight_arrays)
1088 weights_data = weights[0]
1089 for weight_array in weights[1:]:Lines 1105-1117 1105 weights = compute_spatial_weights(field_data, dims=dims_to_integrate)
1106 scale = 1.0
1107 for axis, dim in enumerate("xyz"):
1108 if dim not in field_data.coords:
! 1109 continue
1110 if dim in dims_to_integrate and field_data.sizes.get(dim, 0) == 1:
1111 axis_size = float(source_size[axis])
1112 if axis_size > 0.0:
! 1113 scale = scale * axis_size
1114 elif axis_size == 0.0 and dim in adjoint_field.coords:
1115 coord_vals = np.asarray(adjoint_field.coords[dim].data)
1116 if coord_vals.size > 1:
1117 step = np.min(np.abs(np.diff(coord_vals)))Lines 1117-1125 1117 step = np.min(np.abs(np.diff(coord_vals)))
1118 if np.isfinite(step) and step > 0.0:
1119 scale = scale * step
1120 if dim not in dims_to_integrate and field_data.sizes.get(dim, 0) > 1:
! 1121 scale = scale / field_data.sizes[dim]
1122 return weights * scale
1123
1124
1125 def transpose_interp_field_to_dataset(Lines 1136-1145 1136 if target_freqs.size == source_freqs.size and np.allclose(
1137 target_freqs, source_freqs, rtol=1e-12, atol=0.0
1138 ):
1139 return field
! 1140 method = "nearest" if target_freqs.size <= 1 or source_freqs.size <= 1 else "linear"
! 1141 return field.interp(
1142 {"f": target_freqs},
1143 method=method,
1144 kwargs={"bounds_error": False, "fill_value": 0.0},
1145 ).fillna(0.0)Lines 1149-1157 1149 ) -> np.ndarray:
1150 if param_coords_1d.size == 1:
1151 return field_values.sum(axis=0, keepdims=True)
1152 if np.any(param_coords_1d[1:] < param_coords_1d[:-1]):
! 1153 raise ValueError("Spatial coordinates must be sorted before computing derivatives.")
1154
1155 n_param = param_coords_1d.size
1156 n_field = field_values.shape[0]
1157 field_values_2d = field_values.reshape(n_field, -1)Lines 1199-1207 1199 values = np.asarray(weighted.data)
1200 dims = list(weighted.dims)
1201 for dim in "xyz":
1202 if dim not in field_coords or dim not in param_coords:
! 1203 continue
1204 axis_index = dims.index(dim)
1205 values = _interp_axis(values, axis_index, field_coords[dim], param_coords[dim])
1206
1207 out_coords = {dim: np.asarray(dataset_field.coords[dim].data) for dim in dataset_field.dims}Lines 1206-1214 1206
1207 out_coords = {dim: np.asarray(dataset_field.coords[dim].data) for dim in dataset_field.dims}
1208 result = SpatialDataArray(values, coords=out_coords, dims=tuple(dims))
1209 if tuple(dims) != tuple(dataset_field.dims):
! 1210 result = result.transpose(*dataset_field.dims)
1211 return result
1212
1213
1214 def get_frequency_omega(Lines 1217-1225 1217 """Return angular frequency aligned with field_data frequencies."""
1218 if "f" in field_data.dims:
1219 omega = 2 * np.pi * np.asarray(field_data.coords["f"].data)
1220 return FreqDataArray(omega, coords={"f": np.asarray(field_data.coords["f"].data)})
! 1221 return 2 * np.pi * float(np.asarray(frequencies).squeeze())
1222
1223
1224 __all__ = [
1225 "DerivativeInfo",tidy3d/components/base.pyLines 1222-1230 1222 # Handle multiple starting paths
1223 if paths:
1224 # If paths is a single tuple, convert to tuple of tuples
1225 if isinstance(paths[0], str):
! 1226 paths = (paths,)
1227
1228 # Process each starting path
1229 for starting_path in paths:
1230 # Navigate to the starting path in the dictionarytidy3d/components/simulation.pyLines 4894-4902 4894 structure_index_to_keys[index].append(fields)
4895 elif component_type == "sources":
4896 source_index_to_keys[index].append(fields)
4897 else:
! 4898 raise ValueError(
4899 f"Unknown component type '{component_type}' encountered while "
4900 "constructing adjoint monitors. "
4901 "Expected one of: 'structures', 'sources'."
4902 )tidy3d/components/source/base.pyLines 67-75 67 _warn_traced_size = _warn_unsupported_traced_argument("size")
68
69 def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
70 """Compute adjoint derivatives for source parameters."""
! 71 raise NotImplementedError(f"Can't compute derivative for 'Source': '{type(self)}'.")
72
73 @pydantic.validator("source_time", always=True)
74 def _freqs_lower_bound(cls, val):
75 """Raise validation error if central frequency is too low."""tidy3d/components/source/current.pyLines 230-238 230 transpose_interp_field_to_dataset,
231 )
232
233 if self.current_dataset is None:
! 234 return {tuple(path): 0.0 for path in derivative_info.paths}
235
236 derivative_map = {}
237 center = tuple(self.center)
238 h_adj = derivative_info.H_adj or {}Lines 240-252 240
241 for field_path in derivative_info.paths:
242 field_path = tuple(field_path)
243 if len(field_path) < 2 or field_path[0] != "current_dataset":
! 244 log.warning(
245 f"Unsupported traced source path '{field_path}' for CustomCurrentSource."
246 )
! 247 derivative_map[field_path] = 0.0
! 248 continue
249
250 field_name = field_path[1]
251 if (
252 len(field_name) != 2Lines 252-270 252 len(field_name) != 2
253 or field_name[0] not in ("E", "H")
254 or field_name[1] not in ("x", "y", "z")
255 ):
! 256 log.warning(f"Unsupported field component '{field_name}' in CustomCurrentSource.")
! 257 derivative_map[field_path] = 0.0
! 258 continue
259
260 field_data = getattr(self.current_dataset, field_name, None)
261 if field_data is None:
! 262 raise ValueError(f"Cannot find field '{field_name}' in current dataset.")
263
264 if field_name.startswith("H"):
! 265 adjoint_field = h_adj.get(field_name)
! 266 component_sign = -1.0
267 else: # "E" case
268 adjoint_field = e_adj.get(field_name)
269 component_sign = 1.0tidy3d/components/source/field.pyLines 254-262 254 transpose_interp_field_to_dataset,
255 )
256
257 if self.field_dataset is None:
! 258 return {tuple(path): 0.0 for path in derivative_info.paths}
259
260 derivative_map = {}
261 center = tuple(self.center)
262 e_adj = derivative_info.E_adj or {}Lines 261-282 261 center = tuple(self.center)
262 e_adj = derivative_info.E_adj or {}
263 h_adj = derivative_info.H_adj or {}
264 if self.injection_axis is None:
! 265 return {tuple(path): 0.0 for path in derivative_info.paths}
266
267 for field_path in derivative_info.paths:
268 field_path = tuple(field_path)
269 if len(field_path) < 2 or field_path[0] != "field_dataset":
! 270 log.warning(f"Unsupported traced source path '{field_path}' for CustomFieldSource.")
! 271 derivative_map[field_path] = 0.0
! 272 continue
273
274 field_name = field_path[1]
275 field_data = getattr(self.field_dataset, field_name, None)
276 if field_data is None:
! 277 derivative_map[field_path] = 0.0
! 278 continue
279
280 if (
281 len(field_name) != 2
282 or field_name[0] not in ("E", "H")Lines 281-296 281 len(field_name) != 2
282 or field_name[0] not in ("E", "H")
283 or field_name[1] not in ("x", "y", "z")
284 ):
! 285 log.warning(f"Unsupported field component '{field_name}' in CustomFieldSource.")
! 286 derivative_map[field_path] = 0.0
! 287 continue
288
289 component_axis = "xyz".index(field_name[1])
290 if component_axis == self.injection_axis:
! 291 derivative_map[field_path] = np.zeros_like(field_data.data)
! 292 continue
293
294 def _get_adjoint_and_sign(
295 *,
296 field_name: str,Lines 304-312 304 e_vec = np.eye(3)[component_axis]
305 cross = np.cross(n_vec, e_vec)
306
307 if not np.any(cross):
! 308 return None, 0.0 # indicates "no gradient"
309
310 target_axis = int(np.flatnonzero(cross)[0])
311 component_sign = float(cross[target_axis])Lines 313-322 313 if field_name.startswith("E"):
314 target_component = f"H{'xyz'[target_axis]}"
315 adjoint_field = h_adj.get(target_component)
316 else:
! 317 target_component = f"E{'xyz'[target_axis]}"
! 318 adjoint_field = e_adj.get(target_component)
319
320 return adjoint_field, component_sign
321
322 adjoint_field, component_sign = _get_adjoint_and_sign(Lines 328-337 328 )
329
330 if component_sign == 0.0:
331 # no gradient for injection_axis == component_axis
! 332 derivative_map[field_path] = np.zeros_like(field_data.data)
! 333 continue
334
335 adjoint_on_dataset = transpose_interp_field_to_dataset(
336 adjoint_field, field_data, center=center
337 )tidy3d/web/api/autograd/backward.pyLines 132-140 132 sim_data_adj, sim_data_orig, sim_data_fwd, component_index, component_paths
133 )
134 )
135 else:
! 136 raise ValueError(
137 f"Unexpected component_type='{component_type}' for component_index={component_index}. "
138 "Expected 'structures' or 'sources'."
139 )Lines 170-178 170 monitor_freqs = np.array(fld_adj.monitor.freqs)
171 if len(adjoint_frequencies) != len(monitor_freqs) or not np.allclose(
172 np.sort(adjoint_frequencies), np.sort(monitor_freqs), rtol=1e-10, atol=0
173 ):
! 174 raise ValueError(
175 f"Frequency mismatch in adjoint postprocessing for source {source_index}. "
176 f"Expected frequencies from monitor: {monitor_freqs}, "
177 f"but derivative map has: {adjoint_frequencies}. "
178 )Lines 266-274 266 monitor_freqs = np.array(fld_adj.monitor.freqs)
267 if len(adjoint_frequencies) != len(monitor_freqs) or not np.allclose(
268 np.sort(adjoint_frequencies), np.sort(monitor_freqs), rtol=1e-10, atol=0
269 ):
! 270 raise ValueError(
271 f"Frequency mismatch in adjoint postprocessing for structure {structure_index}. "
272 f"Expected frequencies from monitor: {monitor_freqs}, "
273 f"but derivative map has: {adjoint_frequencies}. "
274 )Lines 314-326 314 geometry_box = structure.geometry.bounding_box
315 background_structures_2d = []
316 sim_inf_background_medium = sim_orig.medium
317 if np.any(np.array(geometry_box.size) == 0.0):
! 318 zero_coordinate = tuple(geometry_box.size).index(0.0)
! 319 new_size = [td.inf, td.inf, td.inf]
! 320 new_size[zero_coordinate] = 0.0
321
! 322 background_structures_2d = [
323 structure.updated_copy(geometry=geometry_box.updated_copy(size=new_size))
324 ]
325 else:
326 sim_inf_background_medium = structure.mediumLines 356-364 356 n_freqs = len(adjoint_frequencies)
357 if not freq_chunk_size or freq_chunk_size <= 0:
358 freq_chunk_size = n_freqs
359 else:
! 360 freq_chunk_size = min(freq_chunk_size, n_freqs)
361
362 # process in chunks
363 vjp_value_map = {}Lines 431-443 431
432 # accumulate results
433 for path, value in vjp_chunk.items():
434 if path in vjp_value_map:
! 435 val = vjp_value_map[path]
! 436 if isinstance(val, (list, tuple)) and isinstance(value, (list, tuple)):
! 437 vjp_value_map[path] = type(val)(x + y for x, y in zip(val, value))
438 else:
! 439 vjp_value_map[path] += value
440 else:
441 vjp_value_map[path] = value
442 sim_fields_vjp = {}
443 # store vjps in output map |
Implemented adjoint gradients for
CustomCurrentSource.current_datasetandCustomFieldSource.field_dataset.here some raw results from the numerical tests
Note
Enables differentiation w.r.t. source field data and wires sources into the autograd forward/adjoint flow.
_compute_derivativesforCustomCurrentSource.current_datasetandCustomFieldSource.field_datasettranspose_interp_field_to_dataset,compute_spatial_weights,compute_source_weights,get_frequency_omegaSimulation._make_adjoint_monitorsto create field monitors for sources; no eps monitors for sources_strip_traced_fieldsto accept multiplestarting_pathsand update autograd API (setup_run, forward/backward paths) to includesourcespostprocess_adjto process structures and sources separately; add source-time scaling for VJPSIM_FIELDS_KEYSWritten by Cursor Bugbot for commit 360ac7a. This will update automatically on new commits. Configure here.
Greptile Overview
Greptile Summary
This PR implements adjoint gradient computation for
CustomCurrentSource.current_datasetandCustomFieldSource.field_dataset, enabling automatic differentiation with respect to source field data. The implementation extends the existing autograd infrastructure to support sources in addition to structures.Key Changes:
_compute_derivatives()methods toCustomCurrentSourceandCustomFieldSourcethat compute vector-Jacobian products (VJPs) by interpolating adjoint fields onto source datasets_make_adjoint_monitors()inSimulationto create field monitors for sources alongside existing structure monitorspostprocess_adj()in backward.py to handle both structures and sources through separate processing functionstranspose_interp_field_to_dataset(),compute_source_weights(), andget_frequency_omega()for source gradient computations_strip_traced_fields()in base.py to support multiple starting paths instead of a single pathImplementation Details:
For
CustomCurrentSource, the gradient is computed as0.5 * Re(source_time_scaling * adjoint_field * sign)where the sign depends on whether the component is E (+1) or H (-1).For
CustomFieldSource, the implementation uses the equivalence principle with cross products to determine the relationship between field components and injected currents, scaled byomega * epsilon_0 / cell_size.The numerical test results in the PR description show angle differences between adjoint and finite-difference gradients ranging from 0.02° to 3.0°, indicating good agreement.
Confidence Score: 4/5
tidy3d/components/source/current.pyandtidy3d/web/api/autograd/backward.pyto align with coding standardsImportant Files Changed
_compute_derivativesmethod toCustomCurrentSourcefor adjoint gradient computation with proper field interpolation and scaling_compute_derivativesmethod toCustomFieldSourcefor adjoint gradient computation with cross-product based current scaling_process_source_gradientsfunction with source time scaling_make_adjoint_monitorsto create field monitors for sources in addition to structures_strip_traced_fieldsto support multiple starting paths instead of single pathcompute_source_weights,transpose_interp_field_to_dataset, andget_frequency_omegafor source gradient computationSequence Diagram
sequenceDiagram participant User participant AutogradAPI as Autograd API participant Simulation participant Source as CustomSource participant BackwardPass as Backward Pass participant DerivativeInfo User->>AutogradAPI: run with traced source parameters AutogradAPI->>Simulation: execute forward simulation Simulation->>Simulation: _make_adjoint_monitors() Simulation->>Simulation: create source field monitors Note over Simulation: Forward simulation runs User->>AutogradAPI: compute gradients (backward pass) AutogradAPI->>BackwardPass: setup_adj(data_fields_vjp) BackwardPass->>BackwardPass: filter traced fields BackwardPass->>Simulation: _make_adjoint_sims() Note over Simulation: Adjoint simulation runs BackwardPass->>BackwardPass: postprocess_adj() BackwardPass->>BackwardPass: _process_source_gradients() BackwardPass->>DerivativeInfo: create DerivativeInfo with E_adj, H_adj BackwardPass->>Source: _compute_derivatives(derivative_info) alt CustomCurrentSource Source->>Source: transpose_interp_field_to_dataset() Source->>Source: compute VJP with source_time_scaling Source-->>BackwardPass: derivative_map else CustomFieldSource Source->>Source: compute cross products (n x E, n x H) Source->>Source: transpose_interp_field_to_dataset() Source->>Source: apply current_scale (omega * epsilon_0) Source-->>BackwardPass: derivative_map end BackwardPass-->>AutogradAPI: sim_fields_vjp AutogradAPI-->>User: gradients w.r.t. source parameters