diff --git a/sycl/doc/extensions/SYCL_ONEAPI_dot_accumulate.asciidoc b/sycl/doc/extensions/SYCL_ONEAPI_dot_accumulate.asciidoc new file mode 100755 index 0000000000000..c34c4581d14d3 --- /dev/null +++ b/sycl/doc/extensions/SYCL_ONEAPI_dot_accumulate.asciidoc @@ -0,0 +1,148 @@ += SYCL_ONEAPI_dot_accumulate +:source-highlighter: coderay +:coderay-linenums-mode: table +:doctype: book +:encoding: utf-8 +:lang: en + +:blank: pass:[ +] + +// Set the default source code type in this document to C, +// for syntax highlighting purposes. +:language: c + +// This is what is needed for C++, since docbook uses c++ +// and everything else uses cpp. This doesn't work when +// source blocks are in table cells, though, so don't use +// C++ unless it is required. +//:language: {basebackend@docbook:c++:cpp} + +== Introduction + +IMPORTANT: This specification is a draft. + +NOTE: Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are trademarks of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc. used by permission by Khronos. + +NOTE: This document is better viewed when rendered as html with asciidoctor. GitHub does not render image icons. + +== Name Strings + +`SYCL_ONEAPI_dot_accumulate` + +This is a placeholder name. + +== Notice + +Copyright (c) 2020 Intel Corporation. All rights reserved. + +== Status + +Working Draft + +This is a preview extension specification, intended to provide early access to a feature for review and community feedback. When the feature matures, this specification may be released as a formal extension. + +Because the interfaces defined by this specification are not final and are subject to change, they are not intended to be used by shipping software products. + +== Version + +Built On: {docdate} + +Revision: C + +== Contact + +Ben Ashbaugh, Intel (ben 'dot' ashbaugh 'at' intel 'dot' com) + +== Dependencies + +This extension is written against the SYCL 1.2.1 specification, Revision v1.2.1-6. 
+ +== Overview + +This extension adds new SYCL built-in functions that may simplify development and provide access to specialized hardware instructions when a SYCL kernel needs to perform a dot product of two vectors followed by a scalar accumulation. + +== Enabling the extension + +The extension is always enabled. The dot product functionality may be emulated in software or executed using hardware when suitable instructions are available. + +== Modifications of SYCL 1.2.1 specification + +=== Add to Section 4.13.6 - Geometric Functions + +The following additional functions are available in the namespace `cl::sycl::ONEAPI` on the host and device. + +[cols="4a,4",options="header"] +|==== +| *Function* +| *Description* + +|[source,c] +---- +int32_t dot_acc(vec<int8_t,4> a, + vec<int8_t,4> b, + int32_t c) +int32_t dot_acc(vec<int8_t,4> a, + vec<uint8_t,4> b, + int32_t c) +int32_t dot_acc(vec<uint8_t,4> a, + vec<int8_t,4> b, + int32_t c) +int32_t dot_acc(vec<uint8_t,4> a, + vec<uint8_t,4> b, + int32_t c) +---- + +|Performs a four-component integer dot product accumulate operation. + +{blank} +The value that is returned is equivalent to + +{blank} +*dot*(_a_, _b_) + _c_ + +|==== + +== Sample Header + +[source,c++] +---- +namespace cl { +namespace sycl { +namespace ONEAPI { + +int32_t dot_acc(vec<int8_t,4> a, vec<int8_t,4> b, int32_t c); +int32_t dot_acc(vec<int8_t,4> a, vec<uint8_t,4> b, int32_t c); +int32_t dot_acc(vec<uint8_t,4> a, vec<int8_t,4> b, int32_t c); +int32_t dot_acc(vec<uint8_t,4> a, vec<uint8_t,4> b, int32_t c); + +int32_t dot_acc(int32_t a, int32_t b, int32_t c); +int32_t dot_acc(int32_t a, uint32_t b, int32_t c); +int32_t dot_acc(uint32_t a, int32_t b, int32_t c); +int32_t dot_acc(uint32_t a, uint32_t b, int32_t c); + +} // ONEAPI +} // sycl +} // cl +---- + +== Issues + +None. + +== Revision History + +[cols="5,15,15,70"] +[grid="rows"] +[options="header"] +|======================================== +|Rev|Date|Author|Changes +|A|2019-12-13|Ben Ashbaugh|*Initial draft* +|B|2019-12-18|Ben Ashbaugh|Switched to standard C++ fixed width types. +|C|2020-10-26|Rajiv Deodhar|Added int32 types. 
+|======================================== + +//************************************************************************ +//Other formatting suggestions: +// +//* Use *bold* text for host APIs, or [source] syntax highlighting. +//* Use `mono` text for device APIs, or [source] syntax highlighting. +//* Use `mono` text for extension names, types, or enum values. +//* Use _italics_ for parameters. +//************************************************************************ diff --git a/sycl/include/CL/sycl/ONEAPI/dot_product.hpp b/sycl/include/CL/sycl/ONEAPI/dot_product.hpp new file mode 100755 index 0000000000000..865a15de5acb3 --- /dev/null +++ b/sycl/include/CL/sycl/ONEAPI/dot_product.hpp @@ -0,0 +1,76 @@ +//==----------- dot_product.hpp ------- SYCL dot-product -------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// DP4A extension + +#pragma once + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ONEAPI { + +union Us { + char s[4]; + int32_t i; +}; +union Uu { + unsigned char s[4]; + uint32_t i; +}; + +int32_t dot_acc(int32_t pa, int32_t pb, int32_t c) { + Us a = *(reinterpret_cast(&pa)); + Us b = *(reinterpret_cast(&pb)); + return a.s[0] * b.s[0] + a.s[1] * b.s[1] + a.s[2] * b.s[2] + a.s[3] * b.s[3] + + c; +} + +int32_t dot_acc(uint32_t pa, uint32_t pb, int32_t c) { + Uu a = *(reinterpret_cast(&pa)); + Uu b = *(reinterpret_cast(&pb)); + return a.s[0] * b.s[0] + a.s[1] * b.s[1] + a.s[2] * b.s[2] + a.s[3] * b.s[3] + + c; +} + +int32_t dot_acc(int32_t pa, uint32_t pb, int32_t c) { + Us a = *(reinterpret_cast(&pa)); + Uu b = *(reinterpret_cast(&pb)); + return a.s[0] * b.s[0] + a.s[1] * b.s[1] + a.s[2] * b.s[2] + a.s[3] * b.s[3] + + c; +} + +int32_t dot_acc(uint32_t pa, int32_t pb, int32_t c) { 
+ Uu a = *(reinterpret_cast(&pa)); + Us b = *(reinterpret_cast(&pb)); + return a.s[0] * b.s[0] + a.s[1] * b.s[1] + a.s[2] * b.s[2] + a.s[3] * b.s[3] + + c; +} + +int32_t dot_acc(vec a, vec b, int32_t c) { + return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() + + c; +} + +int32_t dot_acc(vec a, vec b, int32_t c) { + return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() + + c; +} + +int32_t dot_acc(vec a, vec b, int32_t c) { + return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() + + c; +} + +int32_t dot_acc(vec a, vec b, int32_t c) { + return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() + + c; +} + +} // namespace ONEAPI +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/test/dot_product/dot_product_int_test.cpp b/sycl/test/dot_product/dot_product_int_test.cpp new file mode 100755 index 0000000000000..e68d321446426 --- /dev/null +++ b/sycl/test/dot_product/dot_product_int_test.cpp @@ -0,0 +1,248 @@ +// This test checks dp4a support +// For now we only check fallback support because DG1 hardware is not widespread + +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out +// RUN: env SYCL_DEVICE_TYPE=HOST %t.out +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// RUN: %ACC_RUN_PLACEHOLDER %t.out + +#include +#include +#include +#include +#include + +// Change if tests are added/removed +static int testCount = 4; +static int passCount; + +using namespace cl::sycl; +using namespace cl::sycl::detail::gtl; +using namespace cl::sycl::ONEAPI; + +constexpr int RangeLength = 100; + +// Verify 1D array +template +static bool verify_1D(const char *name, int X, T *A, T *A_ref) { + int ErrCnt = 0; + + for (int i = 0; i < X; i++) { + if (A_ref[i] != A[i]) { + if (++ErrCnt < 10) { + std::cout << name << " mismatch at " << i << ". 
Expected " << A_ref[i] + << " result is " << A[i] << "\n"; + } + } + } + + if (ErrCnt == 0) { + return true; + } + std::cout << " Failed. Failure rate: " << ErrCnt << "/" << X << "(" + << ErrCnt / (float)X * 100.f << "%)\n"; + return false; +} + +static bool testss(queue &Q) { + int A[RangeLength]; + int B[RangeLength]; + int C[RangeLength]; + int D[RangeLength]; + int D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i] = i | (i << 8) | (i << 16) | (i << 24); + B[i] = 0xFFFFFFFF; + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = (i * -1) + (i * -1) + (i * -1) + (i * -1) + C[i]; + } + + buffer Abuf(A, range<1>(RangeLength)); + buffer Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + buffer Dbuf(D, range<1>(RangeLength)); + + Q.submit([&](handler &cgh) { + auto Ap = Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testss D", RangeLength, D, D_ref); +} + +static bool testuu(queue &Q) { + unsigned int A[RangeLength]; + unsigned int B[RangeLength]; + int C[RangeLength]; + int D[RangeLength]; + int D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i] = i | (i << 8) | (i << 16) | (i << 24); + B[i] = 0xFFFFFFFF; + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = (i * 255) + (i * 255) + (i * 255) + (i * 255) + C[i]; + } + + buffer Abuf(A, range<1>(RangeLength)); + buffer Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + buffer Dbuf(D, range<1>(RangeLength)); + + Q.submit([&](handler &cgh) { + auto Ap = 
Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testuu D", RangeLength, D, D_ref); +} + +static bool testsu(queue &Q) { + int A[RangeLength]; + unsigned int B[RangeLength]; + int C[RangeLength]; + int D[RangeLength]; + int D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i] = 0xFFFFFFFF; + B[i] = i | (i << 8) | (i << 16) | (i << 24); + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = (i * -1) + (i * -1) + (i * -1) + (i * -1) + C[i]; + } + + buffer Abuf(A, range<1>(RangeLength)); + buffer Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + buffer Dbuf(D, range<1>(RangeLength)); + + Q.submit([&](handler &cgh) { + auto Ap = Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testsu D", RangeLength, D, D_ref); +} + +static bool testus(queue &Q) { + unsigned int A[RangeLength]; + int B[RangeLength]; + int C[RangeLength]; + int D[RangeLength]; + int D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i] = i | (i << 8) | (i << 16) | (i << 24); + B[i] = 0xFFFFFFFF; + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = (i * -1) + (i * -1) + (i * -1) + (i * -1) + C[i]; + } + + buffer Abuf(A, range<1>(RangeLength)); + buffer Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + 
buffer Dbuf(D, range<1>(RangeLength)); + + Q.submit([&](handler &cgh) { + auto Ap = Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testus D", RangeLength, D, D_ref); +} + +bool run_tests() { + queue Q([](exception_list L) { + for (auto ep : L) { + std::rethrow_exception(ep); + } + }); + + passCount = 0; + if (testss(Q)) { + ++passCount; + } + if (testuu(Q)) { + ++passCount; + } + if (testsu(Q)) { + ++passCount; + } + if (testus(Q)) { + ++passCount; + } + + auto D = Q.get_device(); + const char *devType = D.is_host() ? "Host" : D.is_cpu() ? "CPU" : "GPU"; + std::cout << passCount << " of " << testCount << " tests passed on " + << devType << "\n"; + + return (testCount == passCount); +} + +int main(int argc, char *argv[]) { + bool passed = true; + default_selector selector{}; + auto D = selector.select_device(); + const char *devType = D.is_host() ? "Host" : D.is_cpu() ? 
"CPU" : "GPU"; + std::cout << "Running on device " << devType << " (" + << D.get_info() << ")\n"; + passed &= run_tests(); + + if (!passed) { + std::cout << "FAILED\n"; + return 1; + } + std::cout << "PASSED\n"; + return 0; +} diff --git a/sycl/test/dot_product/dot_product_vec_test.cpp b/sycl/test/dot_product/dot_product_vec_test.cpp new file mode 100644 index 0000000000000..e6277b17f2d62 --- /dev/null +++ b/sycl/test/dot_product/dot_product_vec_test.cpp @@ -0,0 +1,260 @@ +// This test checks dp4a support with vec<> arguments +// For now we only check fallback support because DG1 hardware is not widespread + +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out +// RUN: env SYCL_DEVICE_TYPE=HOST %t.out +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out +// RUN: %ACC_RUN_PLACEHOLDER %t.out + +#include +#include +#include +#include +#include + +// Change if tests are added/removed +static int testCount = 4; +static int passCount; + +using namespace cl::sycl; +using namespace cl::sycl::detail::gtl; +using namespace cl::sycl::ONEAPI; + +constexpr int RangeLength = 100; + +// Verify 1D array +template +static bool verify_1D(const char *name, int X, T *A, T *A_ref) { + int ErrCnt = 0; + + for (int i = 0; i < X; i++) { + if (A_ref[i] != A[i]) { + if (++ErrCnt < 10) { + std::cout << name << " mismatch at " << i << ". Expected " << A_ref[i] + << " result is " << A[i] << "\n"; + } + } + } + + if (ErrCnt == 0) { + return true; + } + std::cout << " Failed. 
Failure rate: " << ErrCnt << "/" << X << "(" + << ErrCnt / (float)X * 100.f << "%)\n"; + return false; +} + +static bool testss(queue &Q) { + vec A[RangeLength]; + vec B[RangeLength]; + int32_t C[RangeLength]; + int32_t D[RangeLength]; + int32_t D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i].s0() = A[i].s1() = A[i].s2() = A[i].s3() = i; + B[i].s0() = B[i].s1() = B[i].s2() = B[i].s3() = 0xFF; + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = 4 * (i * -1) + C[i]; + } + + buffer, 1> Abuf(A, range<1>(RangeLength)); + buffer, 1> Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + buffer Dbuf(D, range<1>(RangeLength)); + + Q.submit([&](handler &cgh) { + auto Ap = Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testss D", RangeLength, D, D_ref); +} + +static bool testuu(queue &Q) { + vec A[RangeLength]; + vec B[RangeLength]; + int32_t C[RangeLength]; + int32_t D[RangeLength]; + int32_t D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i].s0() = A[i].s1() = A[i].s2() = A[i].s3() = i; + B[i].s0() = B[i].s1() = B[i].s2() = B[i].s3() = 0xFF; + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = 4 * (i * 255) + C[i]; + } + + buffer, 1> Abuf(A, range<1>(RangeLength)); + buffer, 1> Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + buffer Dbuf(D, range<1>(RangeLength)); + + Q.submit([&](handler &cgh) { + auto Ap = Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto 
Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testuu D", RangeLength, D, D_ref); +} + +static bool testsu(queue &Q) { + vec A[RangeLength]; + vec B[RangeLength]; + int32_t C[RangeLength]; + int32_t D[RangeLength]; + int32_t D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i].s0() = A[i].s1() = A[i].s2() = A[i].s3() = 0xFF; + B[i].s0() = B[i].s1() = B[i].s2() = B[i].s3() = i; + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = 4 * (i * -1) + C[i]; + } + + buffer, 1> Abuf(A, range<1>(RangeLength)); + buffer, 1> Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + buffer Dbuf(D, range<1>(RangeLength)); + + Q.submit([&](handler &cgh) { + auto Ap = Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testsu D", RangeLength, D, D_ref); +} + +static bool testus(queue &Q) { + vec A[RangeLength]; + vec B[RangeLength]; + int32_t C[RangeLength]; + int32_t D[RangeLength]; + int32_t D_ref[RangeLength]; + + std::memset(D, 0, RangeLength * sizeof(int)); + std::memset(D_ref, 0, RangeLength * sizeof(int)); + + for (int i = 0; i < RangeLength; i++) { + A[i].s0() = A[i].s1() = A[i].s2() = A[i].s3() = i; + B[i].s0() = B[i].s1() = B[i].s2() = B[i].s3() = 0xFF; + C[i] = i; + } + for (int i = 0; i < RangeLength; i++) { + D_ref[i] = 4 * (i * -1) + C[i]; + } + + buffer, 1> Abuf(A, range<1>(RangeLength)); + buffer, 1> Bbuf(B, range<1>(RangeLength)); + buffer Cbuf(C, range<1>(RangeLength)); + buffer Dbuf(D, range<1>(RangeLength)); + + 
Q.submit([&](handler &cgh) { + auto Ap = Abuf.get_access(cgh); + auto Bp = Bbuf.get_access(cgh); + auto Cp = Cbuf.get_access(cgh); + auto Dp = Dbuf.get_access(cgh); + + cgh.parallel_for(range<1>(RangeLength), [=](id<1> I) { + Dp[I] = dot_acc(Ap[I], Bp[I], Cp[I]); + }); + }); + const auto HAcc = Dbuf.get_access(); + + return verify_1D("testus D", RangeLength, D, D_ref); +} + +bool run_tests() { + queue Q([](exception_list L) { + for (auto ep : L) { + try { + std::rethrow_exception(ep); + } catch (std::exception &E) { + std::cout << "*** std exception caught:\n"; + std::cout << E.what(); + } catch (cl::sycl::exception const &E1) { + std::cout << "*** SYCL exception caught:\n"; + std::cout << E1.what(); + } + } + }); + + passCount = 0; + if (testss(Q)) { + ++passCount; + } + if (testuu(Q)) { + ++passCount; + } + if (testsu(Q)) { + ++passCount; + } + if (testus(Q)) { + ++passCount; + } + + auto D = Q.get_device(); + const char *devType = D.is_host() ? "Host" : D.is_cpu() ? "CPU" : "GPU"; + std::cout << passCount << " of " << testCount << " tests passed on " + << devType << "\n"; + + return (testCount == passCount); +} + +int main(int argc, char *argv[]) { + bool passed = true; + default_selector selector{}; + auto D = selector.select_device(); + const char *devType = D.is_host() ? "Host" : D.is_cpu() ? "CPU" : "GPU"; + std::cout << "Running on device " << devType << " (" + << D.get_info() << ")\n"; + try { + passed &= run_tests(); + } catch (exception e) { + std::cout << e.what(); + } + + if (!passed) { + std::cout << "FAILED\n"; + return 1; + } + std::cout << "PASSED\n"; + return 0; +}