Skip to content

Commit 7606854

Browse files
eclipse0922facebook-github-bot
authored andcommitted
Fix windows build (#1689)
Summary: Change the data type usage in the code to ensure cross-platform compatibility long -> int64_t <img width="628" alt="image" src="https://github.com/facebookresearch/pytorch3d/assets/6214316/40041f7f-3c09-4571-b9ff-676c625802e9"> Tested under Win 11 and Ubuntu 22.04 with CUDA 12.1.1 torch 2.1.1 Related issues & PR #9 #1679 Pull Request resolved: #1689 Reviewed By: MichaelRamamonjisoa Differential Revision: D51521562 Pulled By: bottler fbshipit-source-id: d8ea81e223c642e0e9fb283f5d7efc9d6ac00d93
1 parent 83bacda commit 7606854

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pytorch3d/csrc/marching_cubes/marching_cubes.cu

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
223223
compactedVoxelArray,
224224
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
225225
voxelOccupied,
226-
const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
226+
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
227227
voxelOccupiedScan,
228228
uint numVoxels) {
229229
uint id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -255,7 +255,8 @@ __global__ void GenerateFacesKernel(
255255
at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
256256
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
257257
compactedVoxelArray,
258-
at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
258+
at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
259+
numVertsScanned,
259260
const uint activeVoxels,
260261
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
261262
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
@@ -471,7 +472,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
471472
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
472473

473474
// number of active voxels
474-
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
475+
int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();
475476

476477
const int device_id = vol.device().index();
477478
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
@@ -492,7 +493,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
492493
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
493494
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
494495
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
495-
d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
496+
d_voxelOccupiedScan_
497+
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
496498
numVoxels);
497499
AT_CUDA_CHECK(cudaGetLastError());
498500
cudaDeviceSynchronize();
@@ -502,7 +504,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
502504
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
503505

504506
// total number of vertices
505-
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
507+
int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();
506508

507509
// Execute "GenerateFacesKernel" kernel
508510
// This runs only on the occupied voxels.
@@ -522,7 +524,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
522524
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
523525
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
524526
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
525-
d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
527+
d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
526528
activeVoxels,
527529
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
528530
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),

0 commit comments

Comments
 (0)