
Commit 9e21659

yuanluxu authored and facebook-github-bot committed
Fixed windows MSVC build compatibility (#9)
Summary: Fixed a few MSVC compiler (Visual Studio 2019, MSVC 19.16.27034) compatibility issues:

1. Replaced `long` with `int64_t`; `aten::data_ptr<long>` is not supported by MSVC (see the sketch below).
2. pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp: the `inline` function is not correctly recognized by MSVC, so it is made `static`.
3. pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh: MSVC does not compile `const auto kEpsilon = 1e-30;` into both host and device code, so it is changed to a macro.
4. pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh: in `const float area2 = pow(area, 2.0);` the literal `2.0` is treated as a double by MSVC and raises an error, so it is changed to `2.0f`.
5. pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp: the return type of `std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu()` does not match the declaration in rasterize_points_cpu.h.

Pull Request resolved: #9

Reviewed By: nikhilaravi

Differential Revision: D19986567

Pulled By: yuanluxu

fbshipit-source-id: f4d98525d088c99c513b85193db6f0fc69c7f017
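
Background on fix 1, not part of the commit itself: MSVC uses the LLP64 data model, so `long` is 32 bits on Windows while it is 64 bits on typical Linux builds. `data_ptr<T>()` is only provided for PyTorch's fixed set of element types, which includes `int64_t` but not a separate 32-bit `long`, so `data_ptr<long>()` builds on Linux only because `long` and `int64_t` are the same type there. A minimal sketch of the portable pattern, using a hypothetical `SumIndices` helper that is not in the PyTorch3D sources:
```
#include <torch/extension.h>
#include <cstdint>

// Hypothetical helper: sums a 1-D index tensor of dtype torch::kInt64.
// int64_t matches the tensor's element type on every platform, whereas
// data_ptr<long>() has no matching specialization under MSVC.
int64_t SumIndices(const torch::Tensor& idx) {
  const int64_t* data = idx.data_ptr<int64_t>();
  int64_t total = 0;
  for (int64_t i = 0; i < idx.numel(); ++i) {
    total += data[i];
  }
  return total;
}
```
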
1 parent a3baa36 commit 9e21659

File tree

5 files changed: +60, -18 lines changed

INSTALL.md

Lines changed: 39 additions & 1 deletion
@@ -7,7 +7,7 @@
 
 The core library is written in PyTorch. Several components have underlying implementation in CUDA for improved performance. A subset of these components have CPU implementations in C++/Pytorch. It is advised to use PyTorch3d with GPU support in order to use all the features.
 
-- Linux or macOS
+- Linux or macOS or Windows
 - Python ≥ 3.6
 - PyTorch 1.4
 - torchvision that matches the PyTorch installation. You can install them together at pytorch.org to make sure of this.
@@ -72,3 +72,41 @@ To rebuild after installing from a local clone run, `rm -rf build/ **/*.so` then
 ```
 MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install -e .
 ```
+
+**Install from local clone on Windows:**
+
+If you are using pre-compiled PyTorch 1.4 and torchvision 0.5, you should make the following changes to the PyTorch source code to successfully compile with Visual Studio 2019 (MSVC 19.16.27034) and CUDA 10.1.
+
+Change python/Lib/site-packages/torch/include/csrc/jit/script/module.h
+
+L466, 476, 493, 506, 536
+```
+-static constexpr *
++static const *
+```
+Change python/Lib/site-packages/torch/include/csrc/jit/argument_spec.h
+
+L190
+```
+-static constexpr size_t DEPTH_LIMIT = 128;
++static const size_t DEPTH_LIMIT = 128;
+```
+
+Change python/Lib/site-packages/torch/include/pybind11/cast.h
+
+L1449
+```
+-explicit operator type&() { return *(this->value); }
++explicit operator type& () { return *((type*)(this->value)); }
+```
+
+After patching, you can go to the "x64 Native Tools Command Prompt for VS 2019" to compile and install:
+```
+cd pytorch3d
+python3 setup.py install
+```
+After installing, verify that all unit tests pass:
+```
+cd tests
+python3 -m unittest discover -p *.py
+```

pytorch3d/csrc/gather_scatter/gather_scatter.cu

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 // TODO(T47953967) to make this cuda kernel support all datatypes.
 __global__ void gather_scatter_kernel(
     const float* __restrict__ input,
-    const long* __restrict__ edges,
+    const int64_t* __restrict__ edges,
     float* __restrict__ output,
     bool directed,
     bool backward,
@@ -21,8 +21,8 @@ __global__ void gather_scatter_kernel(
   // Edges are split evenly across the blocks.
   for (int e = blockIdx.x; e < E; e += gridDim.x) {
     // Get indices of vertices which form the edge.
-    const long v0 = edges[2 * e + v0_idx];
-    const long v1 = edges[2 * e + v1_idx];
+    const int64_t v0 = edges[2 * e + v0_idx];
+    const int64_t v1 = edges[2 * e + v1_idx];
 
     // Split vertex features evenly across threads.
     // This implementation will be quite wasteful when D<128 since there will be
@@ -57,7 +57,7 @@ at::Tensor gather_scatter_cuda(
 
  gather_scatter_kernel<<<blocks, threads>>>(
      input.data_ptr<float>(),
-     edges.data_ptr<long>(),
+     edges.data_ptr<int64_t>(),
      output.data_ptr<float>(),
      directed,
      backward,

pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu

Lines changed: 9 additions & 9 deletions
@@ -6,7 +6,7 @@
 template <typename scalar_t>
 __device__ void WarpReduce(
     volatile scalar_t* min_dists,
-    volatile long* min_idxs,
+    volatile int64_t* min_idxs,
     const size_t tid) {
   // s = 32
   if (min_dists[tid] > min_dists[tid + 32]) {
@@ -57,7 +57,7 @@ template <typename scalar_t>
 __global__ void NearestNeighborKernel(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
-    long* __restrict__ idx,
+    int64_t* __restrict__ idx,
     const size_t N,
     const size_t P1,
     const size_t P2,
@@ -74,7 +74,7 @@ __global__ void NearestNeighborKernel(
   extern __shared__ char shared_buf[];
   scalar_t* x = (scalar_t*)shared_buf; // scalar_t[DD]
   scalar_t* min_dists = &x[D_2]; // scalar_t[NUM_THREADS]
-  long* min_idxs = (long*)&min_dists[blockDim.x]; // long[NUM_THREADS]
+  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
 
   const size_t n = blockIdx.y; // index of batch element.
   const size_t i = blockIdx.x; // index of point within batch element.
@@ -147,14 +147,14 @@ template <typename scalar_t>
 __global__ void NearestNeighborKernelD3(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
-    long* __restrict__ idx,
+    int64_t* __restrict__ idx,
     const size_t N,
     const size_t P1,
     const size_t P2) {
   // Single shared memory buffer which is split and cast to different types.
   extern __shared__ char shared_buf[];
   scalar_t* min_dists = (scalar_t*)shared_buf; // scalar_t[NUM_THREADS]
-  long* min_idxs = (long*)&min_dists[blockDim.x]; // long[NUM_THREADS]
+  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
 
   const size_t D = 3;
   const size_t n = blockIdx.y; // index of batch element.
@@ -230,12 +230,12 @@ at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
     // Use the specialized kernel for D=3.
     AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] {
       size_t shared_size = threads * sizeof(size_t) +
-          threads * sizeof(long);
+          threads * sizeof(int64_t);
      NearestNeighborKernelD3<scalar_t>
          <<<blocks, threads, shared_size>>>(
              p1.data_ptr<scalar_t>(),
              p2.data_ptr<scalar_t>(),
-             idx.data_ptr<long>(),
+             idx.data_ptr<int64_t>(),
              N,
              P1,
              P2);
@@ -248,11 +248,11 @@ at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
     // need to be rounded to the next even size.
     size_t D_2 = D + (D % 2);
     size_t shared_size = (D_2 + threads) * sizeof(size_t);
-    shared_size += threads * sizeof(long);
+    shared_size += threads * sizeof(int64_t);
     NearestNeighborKernel<scalar_t><<<blocks, threads, shared_size>>>(
         p1.data_ptr<scalar_t>(),
         p2.data_ptr<scalar_t>(),
-        idx.data_ptr<long>(),
+        idx.data_ptr<int64_t>(),
         N,
         P1,
         P2,
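
One detail worth noting in the hunks above: the `sizeof(long)` → `sizeof(int64_t)` edits in the shared-memory size computation are a correctness fix as well as a build fix, because `sizeof(long)` is 4 under MSVC and would under-allocate the dynamic buffer that is later carved into 8-byte index slots. A minimal standalone sketch of that buffer layout, with a hypothetical kernel that is not the library's:
```
#include <cstdint>

// Hypothetical kernel: one dynamic shared buffer holding blockDim.x floats
// followed by blockDim.x int64_t indices, mirroring the layout used by the
// nearest-neighbor kernels above.
__global__ void MinWithIndexKernel(const float* __restrict__ vals, int n) {
  extern __shared__ char shared_buf[];
  float* min_dists = (float*)shared_buf;                 // float[blockDim.x]
  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x];  // int64_t[blockDim.x]
  if (threadIdx.x < (unsigned)n) {
    min_dists[threadIdx.x] = vals[threadIdx.x];
    min_idxs[threadIdx.x] = threadIdx.x;
  }
  // ... a warp/tree reduction over the shared arrays would follow here ...
}

// Host-side launch: size the buffer with int64_t, which is 8 bytes everywhere;
// sizeof(long) is only 4 under MSVC and would make this allocation too small.
// size_t shared_size = threads * sizeof(float) + threads * sizeof(int64_t);
// MinWithIndexKernel<<<blocks, threads, shared_size>>>(d_vals, n);
```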

pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,11 @@
 #include "float_math.cuh"
 
 // Set epsilon for preventing floating point errors and division by 0.
+#ifdef _MSC_VER
+#define kEpsilon 1e-30f
+#else
 const auto kEpsilon = 1e-30;
+#endif
 
 // Determines whether a point p is on the right side of a 2D line segment
 // given by the end points v0, v1.
@@ -93,7 +97,7 @@ BarycentricCoordsBackward(
     const float2& v2,
     const float3& grad_bary_upstream) {
   const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
-  const float area2 = pow(area, 2.0);
+  const float area2 = pow(area, 2.0f);
   const float e0 = EdgeFunctionForward(p, v1, v2);
   const float e1 = EdgeFunctionForward(p, v2, v0);
   const float e2 = EdgeFunctionForward(p, v0, v1);

pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 // Given a pixel coordinate 0 <= i < S, convert it to a normalized device
 // coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized
 // pixels, and assume that each pixel falls in the *center* of its range.
-inline float PixToNdc(const int i, const int S) {
+static float PixToNdc(const int i, const int S) {
   // NDC x-offset + (i * pixel_width + half_pixel_width)
   return -1 + (2 * i + 1.0f) / S;
 }
@@ -74,7 +74,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
   return std::make_tuple(point_idxs, zbuf, pix_dists);
 }
 
-std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
+torch::Tensor RasterizePointsCoarseCpu(
     const torch::Tensor& points,
     const int image_size,
     const float radius,
@@ -140,7 +140,7 @@ std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
       bin_y_max = bin_y_min + bin_width;
     }
   }
-  return std::make_tuple(points_per_bin, bin_points);
+  return bin_points;
 }
 
 torch::Tensor RasterizePointsBackwardCpu(
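
The remaining hunks bring the definition of `RasterizePointsCoarseCpu` back in line with its declaration in rasterize_points_cpu.h, since a definition whose return type conflicts with the visible declaration is rejected at build time. A minimal sketch of that rule with hypothetical names (the real function takes more parameters):
```
#include <torch/extension.h>

// In the header (hypothetical):
//   torch::Tensor CoarseBins(const torch::Tensor& points, int image_size);

// In the .cpp, the definition must use the identical return type; returning a
// std::tuple here, as the old code did, conflicts with that declaration.
torch::Tensor CoarseBins(const torch::Tensor& points, int image_size) {
  // Placeholder body: the real function assigns points to coarse bins.
  return torch::zeros({image_size, image_size},
                      points.options().dtype(torch::kInt32));
}
```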

0 commit comments
