Skip to content

Commit 4c6d01e

Browse files
authored
[ BE ] refactored init edge rot mat (#1368)
* refactored init edge rot mat * cleanup * fix dtype issue * fix test constants * fix cuda graph version * clean up * test if gamma=0 fixes CI gpu test * NCCL debug * typo * add gamma back in for tests * only run the one test * fix gpu debug * wrong test * try to run full mlip_unit * test more * run larger set of tests * run more test * debug memory * warning suppressed by pytest, do a print * try empty cache * revert yaml * clear cuda memory between tests
1 parent a99ed63 commit 4c6d01e

8 files changed

Lines changed: 117 additions & 294 deletions

File tree

.github/workflows/test.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,13 @@ jobs:
8484
env:
8585
HF_TOKEN: ${{ secrets.HF_TOKEN }}
8686
run: |
87-
pytest tests --durations=0 -vv --ignore=tests/demo/ocpapi/tests/integration/ --ignore=tests/applications/ --ignore=tests/perf --cov-report=xml --cov=fairchem --junitxml=junit.xml -o junit_family=legacy -c ./packages/fairchem-core/pyproject.toml
87+
NCCL_DEBUG=INFO pytest tests --durations=0 -vv --ignore=tests/demo/ocpapi/tests/integration/ --ignore=tests/applications/ --ignore=tests/perf --cov-report=xml --cov=fairchem --junitxml=junit.xml -o junit_family=legacy -c ./packages/fairchem-core/pyproject.toml
8888
8989
- name: Core GPU tests
9090
env:
9191
HF_TOKEN: ${{ secrets.HF_TOKEN }}
9292
run: |
93-
pytest tests/core --durations=0 -vv -m gpu -c ./packages/fairchem-core/pyproject.toml -s --junitxml=junit-gpu.xml -o junit_family=legacy --cov-report=xml:gpu-coverage.xml
93+
NCCL_DEBUG=INFO pytest tests/core --durations=0 -vv -m gpu -c ./packages/fairchem-core/pyproject.toml -s --junitxml=junit-gpu.xml -o junit_family=legacy --cov-report=xml:gpu-coverage.xml
9494
9595
- if: ${{ matrix.python_version == '3.12' }}
9696
name: codecov-coverage
@@ -156,8 +156,7 @@ jobs:
156156
env:
157157
HF_TOKEN: ${{ secrets.HF_TOKEN }}
158158
run: |
159-
pytest tests/core --durations=0 -vv -m gpu -c ./packages/fairchem-core/pyproject.toml
160-
159+
NCCL_DEBUG=INFO pytest tests/core --durations=0 -vv -m gpu -c ./packages/fairchem-core/pyproject.toml
161160
- name: Cleanup
162161
if: always()
163162
uses: ./.github/actions/multi-trigger-cleanup

src/fairchem/core/models/uma/common/rotation.py

Lines changed: 35 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@
88
from __future__ import annotations
99

1010
import logging
11-
import math
1211

1312
import torch
14-
from e3nn import o3
1513

16-
YTOL = 0.999999
1714

18-
19-
def init_edge_rot_mat(edge_distance_vec, rot_clip=False):
15+
def init_edge_rot_euler_angles(edge_distance_vec):
2016
edge_vec_0 = edge_distance_vec
2117
edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))
2218

@@ -25,54 +21,30 @@ def init_edge_rot_mat(edge_distance_vec, rot_clip=False):
2521
if len(edge_vec_0_distance) > 0 and torch.min(edge_vec_0_distance) < 0.0001:
2622
logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}")
2723

28-
norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))
29-
30-
if rot_clip:
31-
yprod = norm_x @ norm_x.new_tensor([0.0, 1.0, 0.0])
32-
norm_x[yprod > YTOL] = norm_x.new_tensor([0.0, 1.0, 0.0])
33-
norm_x[yprod < -YTOL] = norm_x.new_tensor([0.0, -1.0, 0.0])
34-
35-
edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5
36-
edge_vec_2 = edge_vec_2 / (torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1))
37-
# Create two rotated copys of the random vectors in case the random vector is aligned with norm_x
38-
# With two 90 degree rotated vectors, at least one should not be aligned with norm_x
39-
edge_vec_2b = edge_vec_2.clone()
40-
edge_vec_2b[:, 0] = -edge_vec_2[:, 1]
41-
edge_vec_2b[:, 1] = edge_vec_2[:, 0]
42-
edge_vec_2c = edge_vec_2.clone()
43-
edge_vec_2c[:, 1] = -edge_vec_2[:, 2]
44-
edge_vec_2c[:, 2] = edge_vec_2[:, 1]
45-
vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1)
46-
vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1)
47-
48-
vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
49-
edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2)
50-
vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
51-
edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2)
52-
53-
vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1))
54-
# Check the vectors aren't aligned
55-
if len(vec_dot) > 0:
56-
assert torch.max(vec_dot) < 0.99
57-
58-
norm_z = torch.cross(norm_x, edge_vec_2, dim=1)
59-
norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True)))
60-
norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1))
61-
norm_y = torch.cross(norm_x, norm_z, dim=1)
62-
norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True)))
63-
64-
# Construct the 3D rotation matrix
65-
norm_x = norm_x.view(-1, 3, 1)
66-
norm_y = -norm_y.view(-1, 3, 1)
67-
norm_z = norm_z.view(-1, 3, 1)
68-
69-
edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2)
70-
edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2)
71-
72-
if rot_clip:
73-
return edge_rot_mat
74-
else:
75-
return edge_rot_mat.detach()
24+
# make unit vectors
25+
xyz = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))
26+
27+
# are we standing at the north pole
28+
mask = xyz[:, 1].abs().isclose(xyz.new_ones(1))
29+
30+
# compute alpha and beta
31+
32+
# latitude (beta)
33+
beta = xyz.new_zeros(xyz.shape[0])
34+
beta[~mask] = torch.acos(xyz[~mask, 1])
35+
beta[mask] = torch.acos(xyz[mask, 1]).detach()
36+
37+
# longitude (alpha)
38+
alpha = torch.zeros_like(beta)
39+
alpha[~mask] = torch.atan2(xyz[~mask, 0], xyz[~mask, 2])
40+
alpha[mask] = torch.atan2(xyz[mask, 0], xyz[mask, 2]).detach()
41+
42+
# random gamma (roll)
43+
gamma = torch.rand_like(alpha) * 2 * torch.pi
44+
# gamma = torch.zeros_like(alpha)
45+
46+
# intrinsic to extrinsic swap
47+
return -gamma, -beta, -alpha
7648

7749

7850
# Borrowed from e3nn @ 0.4.0:
@@ -118,58 +90,24 @@ def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
11890
return M
11991

12092

121-
def rotation_to_wigner(
122-
edge_rot_mat: torch.Tensor,
93+
def eulers_to_wigner(
94+
eulers: torch.Tensor,
12395
start_lmax: int,
12496
end_lmax: int,
12597
Jd: list[torch.Tensor],
126-
rot_clip: bool = False,
12798
) -> torch.Tensor:
12899
"""
129100
set <rot_clip=True> to handle gradient instability when using gradient-based force/stress prediction.
130101
"""
131-
x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
132-
alpha, beta = o3.xyz_to_angles(x)
133-
R = (
134-
o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2)
135-
@ edge_rot_mat
136-
)
137-
gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0])
138-
139-
if rot_clip:
140-
yprod = (x @ x.new_tensor([0, 1, 0])).detach()
141-
mask = (yprod > -YTOL) & (yprod < YTOL)
142-
alpha_detach = alpha[~mask].clone().detach()
143-
gamma_detach = gamma[~mask].clone().detach()
144-
beta_detach = beta.clone().detach()
145-
beta_detach[yprod > YTOL] = 0.0
146-
beta_detach[yprod < -YTOL] = math.pi
147-
beta_detach = beta_detach[~mask]
102+
alpha, beta, gamma = eulers
148103

149104
size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2)
150-
wigner = torch.zeros(
151-
len(alpha), size, size, device=edge_rot_mat.device, dtype=edge_rot_mat.dtype
152-
)
105+
wigner = torch.zeros(len(alpha), size, size, device=alpha.device, dtype=alpha.dtype)
153106
start = 0
154107
for lmax in range(start_lmax, end_lmax + 1):
155-
if rot_clip:
156-
block = wigner_D(lmax, alpha[mask], beta[mask], gamma[mask], Jd).to(
157-
wigner.dtype
158-
)
159-
block_detach = wigner_D(
160-
lmax, alpha_detach, beta_detach, gamma_detach, Jd
161-
).to(wigner.dtype)
162-
end = start + block.size()[1]
163-
wigner[mask, start:end, start:end] = block
164-
wigner[~mask, start:end, start:end] = block_detach
165-
start = end
166-
else:
167-
block = wigner_D(lmax, alpha, beta, gamma, Jd)
168-
end = start + block.size()[1]
169-
wigner[:, start:end, start:end] = block
170-
start = end
171-
172-
if rot_clip:
173-
return wigner
174-
else:
175-
return wigner.detach()
108+
block = wigner_D(lmax, alpha, beta, gamma, Jd)
109+
end = start + block.size()[1]
110+
wigner[:, start:end, start:end] = block
111+
start = end
112+
113+
return wigner

0 commit comments

Comments
 (0)