88from __future__ import annotations
99
1010import logging
11- import math
1211
1312import torch
14- from e3nn import o3
1513
16- YTOL = 0.999999
1714
18-
19- def init_edge_rot_mat (edge_distance_vec , rot_clip = False ):
15+ def init_edge_rot_euler_angles (edge_distance_vec ):
2016 edge_vec_0 = edge_distance_vec
2117 edge_vec_0_distance = torch .sqrt (torch .sum (edge_vec_0 ** 2 , dim = 1 ))
2218
@@ -25,54 +21,30 @@ def init_edge_rot_mat(edge_distance_vec, rot_clip=False):
2521 if len (edge_vec_0_distance ) > 0 and torch .min (edge_vec_0_distance ) < 0.0001 :
2622 logging .error (f"Error edge_vec_0_distance: { torch .min (edge_vec_0_distance )} " )
2723
28- norm_x = edge_vec_0 / (edge_vec_0_distance .view (- 1 , 1 ))
29-
30- if rot_clip :
31- yprod = norm_x @ norm_x .new_tensor ([0.0 , 1.0 , 0.0 ])
32- norm_x [yprod > YTOL ] = norm_x .new_tensor ([0.0 , 1.0 , 0.0 ])
33- norm_x [yprod < - YTOL ] = norm_x .new_tensor ([0.0 , - 1.0 , 0.0 ])
34-
35- edge_vec_2 = torch .rand_like (edge_vec_0 ) - 0.5
36- edge_vec_2 = edge_vec_2 / (torch .sqrt (torch .sum (edge_vec_2 ** 2 , dim = 1 )).view (- 1 , 1 ))
37- # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x
38- # With two 90 degree rotated vectors, at least one should not be aligned with norm_x
39- edge_vec_2b = edge_vec_2 .clone ()
40- edge_vec_2b [:, 0 ] = - edge_vec_2 [:, 1 ]
41- edge_vec_2b [:, 1 ] = edge_vec_2 [:, 0 ]
42- edge_vec_2c = edge_vec_2 .clone ()
43- edge_vec_2c [:, 1 ] = - edge_vec_2 [:, 2 ]
44- edge_vec_2c [:, 2 ] = edge_vec_2 [:, 1 ]
45- vec_dot_b = torch .abs (torch .sum (edge_vec_2b * norm_x , dim = 1 )).view (- 1 , 1 )
46- vec_dot_c = torch .abs (torch .sum (edge_vec_2c * norm_x , dim = 1 )).view (- 1 , 1 )
47-
48- vec_dot = torch .abs (torch .sum (edge_vec_2 * norm_x , dim = 1 )).view (- 1 , 1 )
49- edge_vec_2 = torch .where (torch .gt (vec_dot , vec_dot_b ), edge_vec_2b , edge_vec_2 )
50- vec_dot = torch .abs (torch .sum (edge_vec_2 * norm_x , dim = 1 )).view (- 1 , 1 )
51- edge_vec_2 = torch .where (torch .gt (vec_dot , vec_dot_c ), edge_vec_2c , edge_vec_2 )
52-
53- vec_dot = torch .abs (torch .sum (edge_vec_2 * norm_x , dim = 1 ))
54- # Check the vectors aren't aligned
55- if len (vec_dot ) > 0 :
56- assert torch .max (vec_dot ) < 0.99
57-
58- norm_z = torch .cross (norm_x , edge_vec_2 , dim = 1 )
59- norm_z = norm_z / (torch .sqrt (torch .sum (norm_z ** 2 , dim = 1 , keepdim = True )))
60- norm_z = norm_z / (torch .sqrt (torch .sum (norm_z ** 2 , dim = 1 )).view (- 1 , 1 ))
61- norm_y = torch .cross (norm_x , norm_z , dim = 1 )
62- norm_y = norm_y / (torch .sqrt (torch .sum (norm_y ** 2 , dim = 1 , keepdim = True )))
63-
64- # Construct the 3D rotation matrix
65- norm_x = norm_x .view (- 1 , 3 , 1 )
66- norm_y = - norm_y .view (- 1 , 3 , 1 )
67- norm_z = norm_z .view (- 1 , 3 , 1 )
68-
69- edge_rot_mat_inv = torch .cat ([norm_z , norm_x , norm_y ], dim = 2 )
70- edge_rot_mat = torch .transpose (edge_rot_mat_inv , 1 , 2 )
71-
72- if rot_clip :
73- return edge_rot_mat
74- else :
75- return edge_rot_mat .detach ()
24+ # make unit vectors
25+ xyz = edge_vec_0 / (edge_vec_0_distance .view (- 1 , 1 ))
26+
27+ # are we standing at the north pole
28+ mask = xyz [:, 1 ].abs ().isclose (xyz .new_ones (1 ))
29+
30+ # compute alpha and beta
31+
32+ # latitude (beta)
33+ beta = xyz .new_zeros (xyz .shape [0 ])
34+ beta [~ mask ] = torch .acos (xyz [~ mask , 1 ])
35+ beta [mask ] = torch .acos (xyz [mask , 1 ]).detach ()
36+
37+ # longitude (alpha)
38+ alpha = torch .zeros_like (beta )
39+ alpha [~ mask ] = torch .atan2 (xyz [~ mask , 0 ], xyz [~ mask , 2 ])
40+ alpha [mask ] = torch .atan2 (xyz [mask , 0 ], xyz [mask , 2 ]).detach ()
41+
42+ # random gamma (roll)
43+ gamma = torch .rand_like (alpha ) * 2 * torch .pi
44+ # gamma = torch.zeros_like(alpha)
45+
46+ # intrinsic to extrinsic swap
47+ return - gamma , - beta , - alpha
7648
7749
7850# Borrowed from e3nn @ 0.4.0:
@@ -118,58 +90,24 @@ def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
11890 return M
11991
12092
121- def rotation_to_wigner (
122- edge_rot_mat : torch .Tensor ,
93+ def eulers_to_wigner (
94+ eulers : torch .Tensor ,
12395 start_lmax : int ,
12496 end_lmax : int ,
12597 Jd : list [torch .Tensor ],
126- rot_clip : bool = False ,
12798) -> torch .Tensor :
12899 """
129100 set <rot_clip=True> to handle gradient instability when using gradient-based force/stress prediction.
130101 """
131- x = edge_rot_mat @ edge_rot_mat .new_tensor ([0.0 , 1.0 , 0.0 ])
132- alpha , beta = o3 .xyz_to_angles (x )
133- R = (
134- o3 .angles_to_matrix (alpha , beta , torch .zeros_like (alpha )).transpose (- 1 , - 2 )
135- @ edge_rot_mat
136- )
137- gamma = torch .atan2 (R [..., 0 , 2 ], R [..., 0 , 0 ])
138-
139- if rot_clip :
140- yprod = (x @ x .new_tensor ([0 , 1 , 0 ])).detach ()
141- mask = (yprod > - YTOL ) & (yprod < YTOL )
142- alpha_detach = alpha [~ mask ].clone ().detach ()
143- gamma_detach = gamma [~ mask ].clone ().detach ()
144- beta_detach = beta .clone ().detach ()
145- beta_detach [yprod > YTOL ] = 0.0
146- beta_detach [yprod < - YTOL ] = math .pi
147- beta_detach = beta_detach [~ mask ]
102+ alpha , beta , gamma = eulers
148103
149104 size = int ((end_lmax + 1 ) ** 2 ) - int ((start_lmax ) ** 2 )
150- wigner = torch .zeros (
151- len (alpha ), size , size , device = edge_rot_mat .device , dtype = edge_rot_mat .dtype
152- )
105+ wigner = torch .zeros (len (alpha ), size , size , device = alpha .device , dtype = alpha .dtype )
153106 start = 0
154107 for lmax in range (start_lmax , end_lmax + 1 ):
155- if rot_clip :
156- block = wigner_D (lmax , alpha [mask ], beta [mask ], gamma [mask ], Jd ).to (
157- wigner .dtype
158- )
159- block_detach = wigner_D (
160- lmax , alpha_detach , beta_detach , gamma_detach , Jd
161- ).to (wigner .dtype )
162- end = start + block .size ()[1 ]
163- wigner [mask , start :end , start :end ] = block
164- wigner [~ mask , start :end , start :end ] = block_detach
165- start = end
166- else :
167- block = wigner_D (lmax , alpha , beta , gamma , Jd )
168- end = start + block .size ()[1 ]
169- wigner [:, start :end , start :end ] = block
170- start = end
171-
172- if rot_clip :
173- return wigner
174- else :
175- return wigner .detach ()
108+ block = wigner_D (lmax , alpha , beta , gamma , Jd )
109+ end = start + block .size ()[1 ]
110+ wigner [:, start :end , start :end ] = block
111+ start = end
112+
113+ return wigner
0 commit comments