diff --git a/runtime/core/portable_type/c10/README.md b/runtime/core/portable_type/c10/README.md index df14d22a4cf..104a6717ba7 100644 --- a/runtime/core/portable_type/c10/README.md +++ b/runtime/core/portable_type/c10/README.md @@ -1,7 +1,13 @@ -We added an extra c10 directory so that runtime/core/portable_type/c10 +This directory contains header files from `c10` in PyTorch core that +need to be used in ExecuTorch core. They are copied here rather than +being found through the torch pip package to keep the core build +hermetic for embedded use cases. The headers should be exact copies +from PyTorch core; if they are out of sync, please send a PR! + +We added an extra c10 directory so that `runtime/core/portable_type/c10` can be the directory to put on your include path, rather than -runtime/core/portable_type, because using runtime/core/portable_type +`runtime/core/portable_type`, because using `runtime/core/portable_type` would cause all headers in that directory to be includeable with `#include <foo.h>`. In particular, that includes -runtime/core/portable_type/complex.h, which would shadow the C99 -complex.h standard header. +`runtime/core/portable_type/complex.h`, which would shadow the C99 +`complex.h` standard header.
diff --git a/runtime/core/portable_type/c10/c10/targets.bzl b/runtime/core/portable_type/c10/c10/targets.bzl index 1e60b70a4b8..64436278e79 100644 --- a/runtime/core/portable_type/c10/c10/targets.bzl +++ b/runtime/core/portable_type/c10/c10/targets.bzl @@ -26,6 +26,7 @@ def define_common_targets(): "util/TypeSafeSignMath.h", "util/bit_cast.h", "util/floating_point_utils.h", + "util/irange.h", ], exported_preprocessor_flags = [ # NOTE: If we define C10_EMBEDDED to prevent Half and diff --git a/runtime/core/portable_type/c10/c10/util/irange.h b/runtime/core/portable_type/c10/c10/util/irange.h new file mode 100644 index 00000000000..2719a82075c --- /dev/null +++ b/runtime/core/portable_type/c10/c10/util/irange.h @@ -0,0 +1,123 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include <c10/util/TypeSafeSignMath.h> + +#include <algorithm> +#include <cstddef> +#include <iterator> +#include <type_traits> + +namespace c10 { + +namespace detail { + +template < + typename I, + bool one_sided = false, + std::enable_if_t<std::is_integral_v<I>, int> = 0> +struct integer_iterator { + using iterator_category = std::input_iterator_tag; + using value_type = I; + using difference_type = std::ptrdiff_t; + using pointer = I*; + using reference = I&; + + explicit integer_iterator(I value) : value(value) {} + + I operator*() const { + return value; + } + + I const* operator->() const { + return &value; + } + + integer_iterator& operator++() { + ++value; + return *this; + } + + integer_iterator operator++(int) { + const auto copy = *this; + ++*this; + return copy; + } + + bool operator==(const integer_iterator& other) const { + if constexpr (one_sided) { + // Range-for loops' end test is `begin != end`, not `begin < + // end`. To handle `c10::irange(n)` where n < 0 (which should be + // empty), we just make `begin != end` fail whenever `end` is + // negative.
+ return is_negative(other.value) || value == other.value; + } else { + return value == other.value; + } + // Suppress "warning: missing return statement at end of non-void function" + // which Nvidia's Robert Crovella confirms is an NVCC compiler error + // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27 + // `__builtin_unreachable();` would be best here, but it's not + // available with all compilers. So we instead return an arbitrary + // value trusting that this line will, in fact, never be reached. + return false; // Horrible hack + } + + bool operator!=(const integer_iterator& other) const { + return !(*this == other); + } + + protected: + I value; +}; + +} // namespace detail + +template < + typename I, + bool one_sided = false, + std::enable_if_t<std::is_integral_v<I>, bool> = true> +struct integer_range { + public: + integer_range(I begin, I end) : begin_(begin), end_(end) {} + using iterator = detail::integer_iterator<I, one_sided>; + iterator begin() const { + return begin_; + } + iterator end() const { + return end_; + } + + private: + iterator begin_; + iterator end_; +}; + +/// Creates an integer range for the half-open interval [begin, end) +/// If end<=begin, then the range is empty. +/// The range has the type of the `end` integer; `begin` integer is +/// cast to this type.
+template < + typename Integer1, + typename Integer2, + std::enable_if_t<std::is_integral_v<Integer1>, bool> = true, + std::enable_if_t<std::is_integral_v<Integer2>, bool> = true> +integer_range<Integer2> irange(Integer1 begin, Integer2 end) { + // If end<=begin then the range is empty; we can achieve this effect by + // choosing the larger of {begin, end} as the loop terminator + return { + static_cast<Integer2>(begin), + std::max(static_cast<Integer2>(begin), end)}; +} + +/// Creates an integer range for the half-open interval [0, end) +/// If end<=begin, then the range is empty +template < + typename Integer, + std::enable_if_t<std::is_integral_v<Integer>, bool> = true> +integer_range<Integer, true> irange(Integer end) { + return {Integer(), end}; +} + +} // namespace c10 diff --git a/runtime/core/portable_type/targets.bzl b/runtime/core/portable_type/targets.bzl index 43efeca208c..6178f2c0f9a 100644 --- a/runtime/core/portable_type/targets.bzl +++ b/runtime/core/portable_type/targets.bzl @@ -28,6 +28,9 @@ def define_common_targets(): "//executorch/runtime/core/exec_aten/...", "//executorch/runtime/core/portable_type/test/...", ], + deps = [ + "//executorch/runtime/core/portable_type/c10/c10:c10", + ], exported_deps = [ ":scalar_type", "//executorch/runtime/core:core", diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index b978e23cbd6..6366a8eac28 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -11,6 +11,8 @@ #include #include +#include <c10/util/irange.h> + #include #include #include @@ -30,7 +32,7 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) { dim == 0 || sizes != nullptr, "Sizes must be provided for non-scalar tensors"); ssize_t numel = 1; // Zero-dimensional tensors (scalars) have numel == 1. - for (ssize_t i = 0; i < dim; ++i) { + for (const auto i : c10::irange(dim)) { ET_CHECK_MSG( sizes[i] >= 0, "Size must be non-negative, got %d at dimension %zd",