Skip to content

Commit f613682

Browse files
bottlerfacebook-github-bot
authored andcommitted
marching_cubes type fix
Summary: fixes #1679 Reviewed By: MichaelRamamonjisoa Differential Revision: D50949933 fbshipit-source-id: 5c467de8bf84dd2a3d61748b3846678582d24ea3
1 parent 2f11ddc commit f613682

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

pytorch3d/csrc/marching_cubes/marching_cubes.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
223223
compactedVoxelArray,
224224
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
225225
voxelOccupied,
226-
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
226+
const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
227227
voxelOccupiedScan,
228228
uint numVoxels) {
229229
uint id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -255,7 +255,7 @@ __global__ void GenerateFacesKernel(
255255
at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
256256
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
257257
compactedVoxelArray,
258-
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> numVertsScanned,
258+
at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
259259
const uint activeVoxels,
260260
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
261261
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
@@ -471,7 +471,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
471471
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
472472

473473
// number of active voxels
474-
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
474+
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
475475

476476
const int device_id = vol.device().index();
477477
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
@@ -492,7 +492,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
492492
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
493493
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
494494
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
495-
d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
495+
d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
496496
numVoxels);
497497
AT_CUDA_CHECK(cudaGetLastError());
498498
cudaDeviceSynchronize();
@@ -502,7 +502,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
502502
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
503503

504504
// total number of vertices
505-
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
505+
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
506506

507507
// Execute "GenerateFacesKernel" kernel
508508
// This runs only on the occupied voxels.
@@ -522,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
522522
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
523523
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
524524
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
525-
d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
525+
d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
526526
activeVoxels,
527527
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
528528
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),

tests/test_marching_cubes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -939,8 +939,11 @@ def test_ball_example(self):
939939
u = u[None].float()
940940
verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
941941
verts2, faces2 = marching_cubes(u, 0, return_local_coords=False)
942-
self.assertClose(verts[0], verts2[0])
943-
self.assertClose(faces[0], faces2[0])
942+
self.assertClose(verts2[0], verts[0])
943+
self.assertClose(faces2[0], faces[0])
944+
verts3, faces3 = marching_cubes(u.cuda(), 0, return_local_coords=False)
945+
self.assertEqual(len(verts3), len(verts))
946+
self.assertEqual(len(faces3), len(faces))
944947

945948
@staticmethod
946949
def marching_cubes_with_init(algo_type: str, batch_size: int, V: int, device: str):

0 commit comments

Comments
 (0)