Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
]
# Compiler flag lists selected by key. The keys look like compiler types
# ("msvc" versus everything else) -- confirm against the build code that
# reads this mapping.
COMPILE_OPTIONS = {
    "msvc": ["/Ox", "/EHsc"],
    # Each key may appear only once: a repeated key in a dict literal is
    # silently overridden by the last occurrence.
    "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function", "-std=c++11"],
}
COMPILER_DIRECTIVES = {
"language_level": -3,
Expand Down
74 changes: 64 additions & 10 deletions thinc/backends/cpu_kernels.hh
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,58 @@

// All elementwise functions, such as most activations, work in-place.

template <typename A, typename L>
L argmax(A* arr, L len)

// Outcome of an argmax computation: the largest value that was found,
// paired with the position it was found at.
template <class T, class L>
struct argmax_result {
    T max;      // Largest value seen.
    L max_idx;  // Zero-based position of that value.
};

// Finds the largest element of arr[0..len) with a single linear scan.
//
// Returns the maximum value together with its zero-based index. When several
// elements compare equal to the maximum, the lowest index is reported, since
// a later element only replaces the current best if it is strictly greater.
//
// Precondition: arr must point to at least one readable element (len >= 1);
// arr[0] is read unconditionally to seed the result.
template <typename T, typename L>
argmax_result<T, L> argmax(T const *arr, L len)
{
    static_assert(std::is_floating_point<T>::value,
        "Array should be floating point");
    static_assert(std::is_integral<L>::value, "Array length should be integral");

    // Start from the first element so the loop can begin at index 1.
    argmax_result<T, L> r { arr[0], 0 };

    for (L i = 1; i < len; ++i) {
        if (arr[i] > r.max) {
            r.max = arr[i];
            r.max_idx = i;
        }
    }

    return r;
}

// The next two templates define argmax for a fixed number of elements.

// Base case of the fixed-arity argmax: with a single value, that value is
// the maximum and it sits at index 0.
template <typename T, typename L>
argmax_result<T, L> argmax(T a)
{
    static_assert(std::is_floating_point<T>::value, "Argument should be floating point");
    return argmax_result<T, L>{ a, 0 };
}

// Recursive case of the fixed-arity argmax: compares the first value with
// the best of the remaining values. Indices count from 0 at `a`.
template<typename T, typename L, typename... Args>
argmax_result<T, L> argmax(T a, Args... args)
{
    static_assert(std::is_floating_point<T>::value, "Arguments should be floating point");

    // Best of the remaining values; its index is relative to the first
    // element of `args`.
    argmax_result<T, L> tail_best = argmax<T, L>(args...);

    // Unless the tail's best is strictly greater, the head wins.
    if (!(tail_best.max > a)) {
        tail_best.max_idx = 0;
        tail_best.max = a;
        return tail_best;
    }

    // The winner lives in the tail: shift its index past `a`.
    tail_best.max_idx += 1;
    return tail_best;
}


template <typename A, typename L>
void vec_add(A* X, const A* Y, A scale, L N)
{
Expand All @@ -46,12 +81,31 @@ void cpu_maxout(A* best__bo, L* which__bo, const A* cands__bop, L B, L O, L P)
"Array should be floating point");
static_assert(std::is_integral<L>::value, "Array length should be integral");

for (int i = 0; i < B * O; ++i) {
which__bo[i] = argmax(cands__bop + i * P, P);
best__bo[i] = cands__bop[i * P + which__bo[i]];
// For small inputs, we use an unrolled argmax.
if (P == 2) {
for (int i = 0; i < B * O; ++i) {
A const *input = cands__bop + i * P;
auto r = argmax<A, L>(input[0], input[1]);
which__bo[i] = r.max_idx;
best__bo[i] = r.max;
}
} else if (P == 3) {
for (int i = 0; i < B * O; ++i) {
A const *input = cands__bop + i * P;
auto r = argmax<A, L>(input[0], input[1], input[2]);
which__bo[i] = r.max_idx;
best__bo[i] = r.max;
}
} else {
for (int i = 0; i < B * O; ++i) {
auto r = argmax<A, L>(cands__bop + i * P, P);
which__bo[i] = r.max_idx;
best__bo[i] = r.max;
}
}
}


template <typename A, typename L>
void cpu_backprop_maxout(A* dX__bop, const A* dX__bo, const L* which__bo,
L B, L O, L P)
Expand Down