
Commit c29c0b8

dylanbespalko authored and facebook-github-bot committed
Improved speed of Frobenius norm for non-complex dtypes (pytorch#30871)
Summary:
In-tree changes to PyTorch to support complex numbers are being submitted here. Out-of-tree support for CUDA complex numbers is here: [pytorch-cuda-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex)

Changes:
- [x] Fixed the performance issue raised in pytorch#30704 so that non-complex dtypes no longer call `conj()` and `real()`.
- [x] Fixed the tensor_to_numpy() conversion, likely broken by a `checkBackend()` call in pytorch#27064.
- [x] Fixed some ReduceOps and TensorCompare ops that recently added a `checkBackend()`:
  - `checkBackend()` is replaced with a device type check and a layout check.
  - This ensures the ComplexCPU Type ID is supported.
- [x] Added AVX support for complex `exp()`, as requested in pytorch#755.

Pull Request resolved: pytorch#30871
Differential Revision: D19200726
Pulled By: ezyang
fbshipit-source-id: d7e1be0b0a89c5d6e5f4a68ce5fcd2adc5b88277
1 parent 87ab293 commit c29c0b8

File tree: 8 files changed, +101 −52 lines

aten/src/ATen/cpu/vec256/vec256_complex_double.h

Lines changed: 8 additions & 1 deletion
@@ -202,7 +202,14 @@ template <> class Vec256<std::complex<double>> {
     AT_ERROR("not supported for complex numbers");
   }
   Vec256<std::complex<double>> exp() const {
-    return map(std::exp);
+    //exp(a + bi)
+    // = exp(a)*(cos(b) + sin(b)i)
+    auto exp = Sleef_expd4_u10(values);                              //exp(a)           exp(b)
+    exp = _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A);  //exp(a)           exp(a)
+
+    auto sin_cos = Sleef_sincosd4_u10(values);                       //[sin(a), cos(a)] [sin(b), cos(b)]
+    auto cos_sin = _mm256_blend_pd(sin_cos.y, sin_cos.x, 0x0A);      //cos(b)           sin(b)
+    return _mm256_mul_pd(exp, cos_sin);
   }
   Vec256<std::complex<double>> expm1() const {
     AT_ERROR("not supported for complex numbers");

aten/src/ATen/cpu/vec256/vec256_complex_float.h

Lines changed: 8 additions & 1 deletion
@@ -240,7 +240,14 @@ template <> class Vec256<std::complex<float>> {
     AT_ERROR("not supported for complex numbers");
   }
   Vec256<std::complex<float>> exp() const {
-    return map(std::exp);
+    //exp(a + bi)
+    // = exp(a)*(cos(b) + sin(b)i)
+    auto exp = Sleef_expf8_u10(values);                              //exp(a)           exp(b)
+    exp = _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA);  //exp(a)           exp(a)
+
+    auto sin_cos = Sleef_sincosf8_u10(values);                       //[sin(a), cos(a)] [sin(b), cos(b)]
+    auto cos_sin = _mm256_blend_ps(sin_cos.y, sin_cos.x, 0xAA);      //cos(b)           sin(b)
+    return _mm256_mul_ps(exp, cos_sin);
   }
   Vec256<std::complex<float>> expm1() const {
     AT_ERROR("not supported for complex numbers");

aten/src/ATen/native/Fill.cpp

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ Tensor& fill_out(Tensor& self, Scalar value) {
   // Ideally this fast pass should be implemented in TensorIterator,
   // but we also want to skip compute_types which in not avoidable
   // in TensorIterator for now.
-  if (self.device() == at::kCPU && self.numel() == 1 && !value.isComplex()) {
+  if (self.device() == at::kCPU && self.numel() == 1 && !self.is_complex() && !value.isComplex()) {
     AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() {
       fill_fast<scalar_t>(self, value);});
     return self;

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 10 additions & 2 deletions
@@ -513,7 +513,11 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   if (dim.size() == 1) {
     return at::norm(self, 2, dim, keepdim, self.scalar_type());
   }
-  return at::sqrt(at::sum((self.conj() * self).real(), dim, keepdim));
+  if (self.is_complex()){
+    return at::sqrt(at::sum((self.conj() * self).real(), dim, keepdim));
+  } else {
+    return at::sqrt(at::sum((self * self), dim, keepdim));
+  }
 }

 Tensor &frobenius_norm_out(
@@ -529,7 +533,11 @@ Tensor &frobenius_norm_out(
   if (dim.size() == 1) {
     return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type());
   }
-  return at::sqrt_out(result, at::sum((self.conj() * self).real(), dim, keepdim));
+  if (self.is_complex()){
+    return at::sqrt_out(result, at::sum((self.conj() * self).real(), dim, keepdim));
+  } else {
+    return at::sqrt_out(result, at::sum((self * self), dim, keepdim));
+  }
 }

 Tensor nuclear_norm(const Tensor& self, bool keepdim) {
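The speedup comes from skipping conj() and real() when the input is real: for real x, conj(x) * x is just x * x and the product is already real, so the Frobenius norm reduces to sqrt(sum(x^2)). A small standalone sketch of the real-dtype path (illustration only, not ATen code):

#include <cmath>
#include <numeric>
#include <vector>

// Frobenius norm of a real tensor, flattened: the square root of the sum of
// squares. No conj()/real() round trips are needed for real inputs.
double frobenius_norm_real(const std::vector<double>& x) {
  const double sum_sq = std::accumulate(
      x.begin(), x.end(), 0.0,
      [](double acc, double v) { return acc + v * v; });
  return std::sqrt(sum_sq);
}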

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 42 additions & 27 deletions
@@ -488,9 +488,10 @@ Tensor& logsumexp_out(Tensor& result, const Tensor& self, DimnameList dims, bool
 static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
                         IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
   auto p = opt_p.value_or(2.0);
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
-              "norm only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
-
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "norm only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "norm only supports strided layout, got: ", self.layout());

   ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
   TORCH_CHECK(
513514
if (self.is_sparse()) {
514515
return at::native_norm(self, p);
515516
} else {
516-
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
517-
"norm only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
517+
TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
518+
"norm only supports CPU AND CUDA device type, got: ", self.device().type());
519+
TORCH_CHECK(self.layout() == Layout::Strided,
520+
"norm only supports strided layout, got: ", self.layout());
518521
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
519522
"norm only supports floating-point dtypes");
520523

@@ -565,9 +568,10 @@ inline Tensor & _all(Tensor & result, TensorIterator & iter) {
 }

 Tensor all(const Tensor& self) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU ||
-              self.options().backend() == Backend::CUDA, "all only supports CPU AND CUDA "
-              "backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "all only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "all only supports strided layout, got: ", self.layout());
   TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
               "all only supports torch.uint8 and torch.bool dtypes");

@@ -583,9 +587,10 @@ Tensor all(const Tensor& self, int64_t dim, bool keepdim) {
 }

 Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU ||
-              self.options().backend() == Backend::CUDA, "all only supports CPU AND CUDA "
-              "backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "all only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "all only supports strided layout, got: ", self.layout());
   TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
               "all only supports torch.uint8 and torch.bool dtypes");
   dim = maybe_wrap_dim(dim, self.dim());
@@ -609,11 +614,10 @@ inline Tensor & _any(Tensor & result, TensorIterator & iter) {
 }

 Tensor any(const Tensor& self) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU ||
-              self.options().backend() == Backend::CUDA ||
-              self.options().backend() == Backend::SparseCPU ||
-              self.options().backend() == Backend::SparseCUDA, "any only supports CPU, CUDA, "
-              "SparseCPU and SparseCUDA backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "any only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse,
+              "any only supports strided AND sparse layout, got: ", self.layout());
   TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
               "all only supports torch.uint8 and torch.bool dtypes");

@@ -629,9 +633,10 @@ Tensor any(const Tensor& self, int64_t dim, bool keepdim) {
 }

 Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU ||
-              self.options().backend() == Backend::CUDA, "any only supports CPU AND CUDA "
-              "backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "any only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "any only supports strided layout, got: ", self.layout());
   TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
               "all only supports torch.uint8 and torch.bool dtypes");
   dim = maybe_wrap_dim(dim, self.dim());
@@ -730,8 +735,10 @@ Tensor argmin(const Tensor& self, c10::optional<int64_t> dim, bool keepdims) {
 }

 static Tensor &std_var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim, bool take_sqrt) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
-              "std and var only support CPU AND CUDA backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "std and var only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "std and var only supports strided layout, got: ", self.layout());
   TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
               "std and var only support floating-point dtypes");

@@ -769,8 +776,12 @@ static Tensor &std_var_out(Tensor &result, const Tensor &self, IntArrayRef dim,

 static std::tuple<Tensor&,Tensor&> std_var_mean_out(const char* fname, Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim, bool take_sqrt) {
   AT_ASSERT(result1.defined() && result2.defined());
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, fname, " only support CPU AND CUDA backend, got: ", toString(self.options().backend()));
-  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), fname, " only support floating-point dtypes");
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              fname, " only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              fname, " only supports strided layout, got: ", self.layout());
+  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
+              fname, " only support floating-point dtypes");
   TORCH_CHECK(result1.scalar_type() == result2.scalar_type(),
               "provided by result1 dtype must match dtype of result2. Got ",
               toString(result1.scalar_type()),
@@ -856,8 +867,10 @@ std::tuple<Tensor,Tensor> var_mean(const Tensor& self, bool unbiased) {
 }

 Tensor var(const Tensor& self, bool unbiased) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
-              "var only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "var only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "var only supports strided layout, got: ", self.layout());
   TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
               "var only supports floating-point dtypes");
   auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());
@@ -874,8 +887,10 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias
 }

 Tensor std(const Tensor& self, bool unbiased) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
-              "std only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "std only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "std only supports strided layout, got: ", self.layout());
   TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
               "std only supports floating-point dtypes");
   auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());
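The same replacement recurs throughout this file (and in TensorCompare.cpp and UnaryOps.cpp below): comparing Backend enums would reject the out-of-tree ComplexCPU Type ID, whereas checking device type and layout separately admits it while still excluding unsupported devices and sparse layouts. A hypothetical helper capturing the pattern (the commit inlines the checks instead; the helper name is invented for illustration):

// Accept any tensor on a CPU/CUDA device with strided layout, regardless of
// which backend Type ID it carries (e.g. the out-of-tree ComplexCPU).
static void check_cpu_cuda_strided(const char* name, const at::Tensor& t) {
  TORCH_CHECK(t.device().type() == at::DeviceType::CPU ||
                  t.device().type() == at::DeviceType::CUDA,
              name, " only supports CPU AND CUDA device type, got: ",
              t.device().type());
  TORCH_CHECK(t.layout() == at::Layout::Strided,
              name, " only supports strided layout, got: ", t.layout());
}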

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 12 additions & 6 deletions
@@ -159,8 +159,10 @@ std::tuple<Tensor, Tensor> mode(const Tensor& self, int64_t dim, bool keepdim) {

 std::tuple<Tensor &,Tensor &> mode_out(Tensor& values, Tensor& indices,
                                        const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
-              "mode only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "mode only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "mode only supports strided layout, got: ", self.layout());
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "mode")) {
     AT_ASSERT(values.dim() == 0);
@@ -207,8 +209,10 @@ std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {

 static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indices,
                                                   const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
-              "max only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "max only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "max only supports strided layout, got: ", self.layout());
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
     AT_ASSERT(max.dim() == 0);
@@ -263,8 +267,10 @@ std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {

 static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indices,
                                                   const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
-              "min only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+              "min only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "min only supports strided layout, got: ", self.layout());
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
     AT_ASSERT(min.dim() == 0);

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 12 additions & 3 deletions
@@ -229,7 +229,10 @@ Tensor& _clamp_out_cpu(
     optional<Scalar> min,
     optional<Scalar> max) {
   if (min && max) {
-    checkBackend("clamp", result, Backend::CPU);
+    TORCH_CHECK(self.device().type() == DeviceType::CPU,
+                "clamp only supports CPU device type, got: ", self.device().type());
+    TORCH_CHECK(self.layout() == Layout::Strided,
+                "clamp only supports strided layout, got: ", self.layout());
     auto iter = TensorIterator::unary_op(result, self,
                                          /*check_mem_overlap=*/true);
     clamp_stub(iter.device_type(), iter, *min, *max);
@@ -248,7 +251,10 @@ Tensor& _clamp_max__cpu(Tensor& self, Scalar max) {
 }

 Tensor& _clamp_max_out_cpu(Tensor& result, const Tensor& self, Scalar max) {
-  checkBackend("clamp_max", result, Backend::CPU);
+  TORCH_CHECK(self.device().type() == DeviceType::CPU,
+              "clamp_max only supports CPU device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "clamp_max only supports strided layout, got: ", self.layout());
   auto iter = TensorIterator::unary_op(result, self,
                                        /*check_mem_overlap=*/true);
   clamp_max_stub(iter.device_type(), iter, max);
@@ -260,7 +266,10 @@ Tensor& _clamp_min__cpu(Tensor& self, Scalar min) {
 }

 Tensor& _clamp_min_out_cpu(Tensor& result, const Tensor& self, Scalar min) {
-  checkBackend("clamp_min", result, Backend::CPU);
+  TORCH_CHECK(self.device().type() == DeviceType::CPU,
+              "clamp_min only supports CPU device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+              "clamp_min only supports strided layout, got: ", self.layout());
   auto iter = TensorIterator::unary_op(result, self,
                                        /*check_mem_overlap=*/true);
   clamp_min_stub(iter.device_type(), iter, min);

torch/csrc/utils/tensor_numpy.cpp

Lines changed: 8 additions & 11 deletions
@@ -74,18 +74,15 @@ static std::vector<int64_t> seq_to_aten_shape(PyObject *py_seq) {
 }

 PyObject* tensor_to_numpy(const at::Tensor& tensor) {
-  if (tensor.is_cuda()) {
-    throw TypeError(
-        "can't convert CUDA tensor to numpy. Use Tensor.cpu() to "
-        "copy the tensor to host memory first.");
+  if (tensor.device().type() != DeviceType::CPU) {
+    throw TypeError(
+        "can't convert %s device type tensor to numpy. Use Tensor.cpu() to "
+        "copy the tensor to host memory first.", tensor.device().type());
   }
-  if (tensor.is_sparse()) {
-    throw TypeError(
-        "can't convert sparse tensor to numpy. Use Tensor.to_dense() to "
-        "convert to a dense tensor first.");
-  }
-  if (tensor.options().backend() != Backend::CPU) {
-    throw TypeError("NumPy conversion for %s is not supported", tensor.toString().c_str());
+  if (tensor.layout() != Layout::Strided) {
+    throw TypeError(
+        "can't convert %s layout tensor to numpy."
+        "convert the tensor to a strided layout first.", tensor.layout());
   }
   if (tensor.requires_grad()) {
     throw std::runtime_error(
