Open
Description
When running notebooks/how_to/regression_across_subjects.ipynb on the server (frank), I get these errors:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[15], line 23
16 continue
18 meshes = {
19 "global": global_template,
20 subject_id: subject_template,
21 session_id: mesh,
22 }
---> 23 plddmm.geometry.parallel_transport_ABC(
24 meshes,
25 output_dir=transport_dir,
26 registration_dir=registrations_dir,
27 compute_shoot=True,
28 shoot_dir=parallel_shoot_dir,
29 use_pole_ladder=True,
30 **registration_kwargs,
31 )
File ~/anaconda3/envs/deformetrica/lib/python3.12/site-packages/polpo/lddmm/geometry.py:325, in parallel_transport_ABC(meshes, output_dir, kernel_width, kernel_type, kernel_device, use_pole_ladder, compute_shoot, registration_dir, shoot_dir, **registration_kwargs)
320 transp_target = c_name
321 transport_output_dir = output_dir / io.build_parallel_transport_name(
322 source, geod_target, transp_target
323 )
--> 325 out = parallel_transport(
...
100 if as_numpy:
--> 101 return transported_cp.numpy(), transported_mom.numpy()
103 return transported_cp, transported_mom
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
The errors go away when I change the following lines in polpo/lddmm/geometry.py so that each tensor is copied to host memory before being converted to NumPy:
line 101: return transported_cp.numpy(), transported_mom.numpy() -> return transported_cp.cpu().numpy(), transported_mom.cpu().numpy()
line 396: return vel.numpy() -> return vel.cpu().numpy()
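Instead of patching each call site separately, the conversion could be centralized in a small helper. The sketch below is hypothetical (`to_numpy` is not an existing polpo function). It assumes the returned values are `torch.Tensor` objects, as the error message suggests, and detects them by duck typing so the snippet itself only needs NumPy. It also calls `detach()`, because `Tensor.numpy()` likewise refuses tensors that require grad.

```python
import numpy as np


def to_numpy(x):
    """Convert x to a NumPy array, copying device (e.g. CUDA) tensors to host first.

    Hypothetical helper, not part of polpo. torch.Tensor inputs are detected by
    duck typing (they have .detach() and .cpu()), so torch is not imported here.
    """
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        # Drop autograd history, copy to host memory (a no-op for CPU tensors),
        # then convert; calling .numpy() directly on a CUDA tensor raises TypeError.
        return x.detach().cpu().numpy()
    # NumPy arrays, lists and other array-likes pass through np.asarray.
    return np.asarray(x)
```

With such a helper, line 101 would become `return to_numpy(transported_cp), to_numpy(transported_mom)` and line 396 would become `return to_numpy(vel)`. Because `Tensor.cpu()` returns the original tensor without copying when it is already in CPU memory, the change should also be harmless on machines without a GPU.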