|
| 1 | +// Copyright 2004-present Facebook. All Rights Reserved. |
| 2 | + |
| 3 | +#pragma once |
| 4 | + |
| 5 | +#include <c10/util/TypeSafeSignMath.h> |
| 6 | + |
| 7 | +#include <algorithm> |
| 8 | +#include <cstddef> |
| 9 | +#include <iterator> |
| 10 | +#include <type_traits> |
| 11 | + |
| 12 | +namespace c10 { |
| 13 | + |
| 14 | +namespace detail { |
| 15 | + |
| 16 | +template < |
| 17 | + typename I, |
| 18 | + bool one_sided = false, |
| 19 | + std::enable_if_t<std::is_integral_v<I>, int> = 0> |
| 20 | +struct integer_iterator { |
| 21 | + using iterator_category = std::input_iterator_tag; |
| 22 | + using value_type = I; |
| 23 | + using difference_type = std::ptrdiff_t; |
| 24 | + using pointer = I*; |
| 25 | + using reference = I&; |
| 26 | + |
| 27 | + explicit integer_iterator(I value) : value(value) {} |
| 28 | + |
| 29 | + I operator*() const { |
| 30 | + return value; |
| 31 | + } |
| 32 | + |
| 33 | + I const* operator->() const { |
| 34 | + return &value; |
| 35 | + } |
| 36 | + |
| 37 | + integer_iterator& operator++() { |
| 38 | + ++value; |
| 39 | + return *this; |
| 40 | + } |
| 41 | + |
| 42 | + integer_iterator operator++(int) { |
| 43 | + const auto copy = *this; |
| 44 | + ++*this; |
| 45 | + return copy; |
| 46 | + } |
| 47 | + |
| 48 | + bool operator==(const integer_iterator& other) const { |
| 49 | + if constexpr (one_sided) { |
| 50 | + // Range-for loops' end test is `begin != end`, not `begin < |
| 51 | + // end`. To handle `c10::irange(n)` where n < 0 (which should be |
| 52 | + // empty), we just make `begin != end` fail whenever `end` is |
| 53 | + // negative. |
| 54 | + return is_negative(other.value) || value == other.value; |
| 55 | + } else { |
| 56 | + return value == other.value; |
| 57 | + } |
| 58 | + // Suppress "warning: missing return statement at end of non-void function" |
| 59 | + // which Nvidia's Robert Crovella confirms is an NVCC compiler error |
| 60 | + // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27 |
| 61 | + // `__builtin_unreachable();` would be best here, but it's not |
| 62 | + // available with all compilers. So we instead return an arbitrary |
| 63 | + // value trusting that this line will, in fact, never be reached. |
| 64 | + return false; // Horrible hack |
| 65 | + } |
| 66 | + |
| 67 | + bool operator!=(const integer_iterator& other) const { |
| 68 | + return !(*this == other); |
| 69 | + } |
| 70 | + |
| 71 | + protected: |
| 72 | + I value; |
| 73 | +}; |
| 74 | + |
| 75 | +} // namespace detail |
| 76 | + |
| 77 | +template < |
| 78 | + typename I, |
| 79 | + bool one_sided = false, |
| 80 | + std::enable_if_t<std::is_integral_v<I>, bool> = true> |
| 81 | +struct integer_range { |
| 82 | + public: |
| 83 | + integer_range(I begin, I end) : begin_(begin), end_(end) {} |
| 84 | + using iterator = detail::integer_iterator<I, one_sided>; |
| 85 | + iterator begin() const { |
| 86 | + return begin_; |
| 87 | + } |
| 88 | + iterator end() const { |
| 89 | + return end_; |
| 90 | + } |
| 91 | + |
| 92 | + private: |
| 93 | + iterator begin_; |
| 94 | + iterator end_; |
| 95 | +}; |
| 96 | + |
| 97 | +/// Creates an integer range for the half-open interval [begin, end) |
| 98 | +/// If end<=begin, then the range is empty. |
| 99 | +/// The range has the type of the `end` integer; `begin` integer is |
| 100 | +/// cast to this type. |
| 101 | +template < |
| 102 | + typename Integer1, |
| 103 | + typename Integer2, |
| 104 | + std::enable_if_t<std::is_integral_v<Integer1>, bool> = true, |
| 105 | + std::enable_if_t<std::is_integral_v<Integer2>, bool> = true> |
| 106 | +integer_range<Integer2> irange(Integer1 begin, Integer2 end) { |
| 107 | + // If end<=begin then the range is empty; we can achieve this effect by |
| 108 | + // choosing the larger of {begin, end} as the loop terminator |
| 109 | + return { |
| 110 | + static_cast<Integer2>(begin), |
| 111 | + std::max(static_cast<Integer2>(begin), end)}; |
| 112 | +} |
| 113 | + |
| 114 | +/// Creates an integer range for the half-open interval [0, end) |
| 115 | +/// If end<=begin, then the range is empty |
| 116 | +template < |
| 117 | + typename Integer, |
| 118 | + std::enable_if_t<std::is_integral_v<Integer>, bool> = true> |
| 119 | +integer_range<Integer, true> irange(Integer end) { |
| 120 | + return {Integer(), end}; |
| 121 | +} |
| 122 | + |
| 123 | +} // namespace c10 |
0 commit comments