Skip to content

Commit a8ab914

Browse files
committed
[Models] Stopped zeroing out node/vertex loads if tmax>1
1 parent d907cab commit a8ab914

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12-
- Added `jax.lax.custom_linear_solver()` to backward rule of implicit `custom_vjp` of fixed-point solver. This new function acts as a thin wrapper around the *sparse* linear solver of `EquilibriumModel.linearsolve_fn()` that defines a custom transpose rule to be compatible with `jax.linear_transpose()`. The transpose is needed by `lineax`, inside `FunctionLinearOperator`. Without the wrapper and the transpose, we cannot use implicit differentiation with a sparse linear solver and a fixed-point solver. Now we can.
12+
- Added automatic support for dense and sparse stiffness matrices in `custom_vjp` of `solver_fixed_point_implicit()`. For sparse matrices, we apply `jax.lax.custom_linear_solver()` as a thin wrapper around the sparse linear solve defined in `EquilibriumModel.linearsolve_fn()` to generate a transpose rule for it. The transpose rule is required by `lineax`, inside `FunctionLinearOperator`. Without the wrapper and the transpose, we cannot use implicit differentiation with a sparse linear solver and a fixed-point solver. Now we can.
1313
- Implemented `geometry.length_vector_sqrd()`.
1414
- Print out statistics with `ndigits` of precision in `FDDatastructure.print_stats()`.
1515
- Listed `lineax` and `optimistix` as dependencies.

src/jax_fdm/equilibrium/models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ def __call__(self, params, structure):
172172
xyz_free = self.equilibrium(q, xyz_fixed, load_nodes, structure)
173173

174174
if tmax > 1:
175-
load_nodes = jnp.zeros_like(load_nodes)
176-
load_state = LoadState(load_nodes, load_state.edges, load_state.faces)
177175

178176
xyz_free = self.eq_iterative_fn(
179177
q,

0 commit comments

Comments
 (0)