Skip to content

Commit cfe91db

Browse files
committed
Merge branch 'main' into musa_mmcv_main
2 parents 768090f + d9e10e1 commit cfe91db

File tree

3 files changed

+42
-16
lines changed

3 files changed

+42
-16
lines changed

docs/zh_cn/mmcv-logo.png

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../docs/mmcv-logo.png
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include "pytorch_npu_helper.hpp"
2+
3+
using namespace NPU_NAME_SPACE;
4+
using namespace std;
5+
6+
void stack_group_points_forward_npu(int b, int c, int n, int nsample,
7+
const Tensor features_tensor,
8+
const Tensor features_batch_cnt_tensor,
9+
const Tensor idx_tensor,
10+
const Tensor idx_batch_cnt_tensor,
11+
Tensor out_tensor) {
12+
EXEC_NPU_CMD(aclnnStackGroupPoints, features_tensor,
13+
features_batch_cnt_tensor, idx_tensor, idx_batch_cnt_tensor,
14+
out_tensor);
15+
}
16+
17+
void stack_group_points_forward_impl(int b, int c, int n, int nsample,
18+
const Tensor features_tensor,
19+
const Tensor features_batch_cnt_tensor,
20+
const Tensor idx_tensor,
21+
const Tensor idx_batch_cnt_tensor,
22+
Tensor out_tensor);
23+
24+
REGISTER_NPU_IMPL(stack_group_points_forward_impl,
25+
stack_group_points_forward_npu);

tests/test_ops/test_group_points.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,27 +89,27 @@ def test_grouping_points(dtype, device):
8989
assert torch.allclose(output, expected_output)
9090

9191

92-
@pytest.mark.skipif(
93-
not (torch.cuda.is_available() or is_musa_available()),
94-
reason='requires CUDA/MUSA support')
95-
@pytest.mark.parametrize('dtype', [
92+
93+
@pytest.mark.parametrize('device', [
9694
pytest.param(
97-
torch.half,
95+
'cuda',
9896
marks=pytest.mark.skipif(
99-
is_musa_available(),
100-
reason='TODO [email protected]: not supported yet')),
101-
torch.float,
97+
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
10298
pytest.param(
103-
torch.double,
99+
'npu',
104100
marks=pytest.mark.skipif(
105-
is_musa_available(),
106-
reason='TODO [email protected]: not supported yet'))
101+
not IS_NPU_AVAILABLE, reason='requires NPU support')),
102+
pytest.param(
103+
'musa',
104+
marks=pytest.mark.skipif(
105+
not IS_MUSA_AVAILABLE, reason='requires MUSA support'))
107106
])
108-
def test_stack_grouping_points(dtype):
109-
if torch.cuda.is_available():
110-
device = 'cuda'
111-
elif is_musa_available():
112-
device = 'musa'
107+
@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
108+
def test_stack_grouping_points(dtype, device):
109+
if device == 'npu' and dtype == torch.double:
110+
return
111+
if device == 'musa' and dtype == torch.double:
112+
return
113113
idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
114114
[2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
115115
[1, 1, 1], [0, 0, 0]]).int().to(device)

0 commit comments

Comments
 (0)