Skip to content

Commit 7b0b127

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Log dtype names on input dtype mismatch (pytorch#7537)
Summary: Update the error message when input tensor scalar type is incorrect. We've seen this get hit a few times and it should be easier to debug than it is. New Message: ``` [method.cpp:834] Input 0 has unexpected scalar type: expected Float but was Byte. ``` Old Message: ``` [method.cpp:826] The 0-th input tensor's scalartype does not meet requirement: found 0 but expected 6 ``` Test Plan: Built executorch bento kernel locally and tested with an incorrect scalar type to view the new error message. ``` [method.cpp:834] Input 0 has unexpected scalar type: expected Float but was Byte. ``` I also locally patched and built the bento kernel with ET_ENABLE_ENUM_STRINGS=0. ``` [method.cpp:834] Input 0 has unexpected scalar type: expected 6 but was 0. ``` Reviewed By: digantdesai, SS-JIA Differential Revision: D67887770 Pulled By: GregoryComer
1 parent 39e8538 commit 7b0b127

File tree

7 files changed

+78
-5
lines changed

7 files changed

+78
-5
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
2+
#include <executorch/runtime/platform/log.h>
3+
4+
namespace executorch {
5+
namespace runtime {
6+
7+
/**
8+
* Convert a scalar type value to a string representation. If ET_ENABLE_ENUM_STRINGS is
9+
* set (it is on bby default), this will return a string name (for example, "Float").
10+
* Otherwise, it will return a string representation of the index value ("6").
11+
*
12+
* If the user buffer is not large enough to hold the string representation, the string
13+
* will be truncated.
14+
*
15+
* The return value is the number of characters written, or in the case of truncation,
16+
* the number of characters that would be written if the buffer was large enough.
17+
*/
18+
size_t scalar_type_to_string(::executorch::aten::ScalarType t, char* buffer, size_t buffer_size) {
19+
#if ET_ENABLE_ENUM_STRINGS
20+
const char* name_str;
21+
#define DEFINE_CASE(unused, name) \
22+
case ::executorch::aten::ScalarType::name: \
23+
name_str = #name; \
24+
break;
25+
26+
switch (t) {
27+
ET_FORALL_SCALAR_TYPES(DEFINE_CASE)
28+
default:
29+
name_str = "Unknown";
30+
break;
31+
}
32+
33+
return snprintf(buffer, buffer_size, "%s", name_str);
34+
#undef DEFINE_CASE
35+
#else
36+
return snprintf(buffer, buffer_size, "%d", static_cast<int>(t));
37+
#endif // ET_ENABLE_ENUM_TO_STRING
38+
}
39+
40+
} // namespace runtime
41+
} // namespace executorch

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,19 @@ struct promote_types {
12941294
CTYPE_ALIAS, \
12951295
__VA_ARGS__))
12961296

1297+
/**
1298+
* Convert a scalar type value to a string representation. If ET_ENABLE_ENUM_STRINGS is
1299+
* set (it is on bby default), this will return a string name (for example, "Float").
1300+
* Otherwise, it will return a string representation of the index value ("6").
1301+
*
1302+
* If the user buffer is not large enough to hold the string representation, the string
1303+
* will be truncated.
1304+
*
1305+
* The return value is the number of characters written, or in the case of truncation,
1306+
* the number of characters that would be written if the buffer was large enough.
1307+
*/
1308+
size_t scalar_type_to_string(::executorch::aten::ScalarType t, char* buffer, size_t buffer_size);
1309+
12971310
} // namespace runtime
12981311
} // namespace executorch
12991312

runtime/core/exec_aten/util/targets.bzl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@ def define_common_targets():
1919

2020
runtime.cxx_library(
2121
name = "scalar_type_util" + aten_suffix,
22-
srcs = [],
22+
srcs = ["scalar_type_util.cpp"],
2323
exported_headers = [
2424
"scalar_type_util.h",
2525
],
2626
visibility = [
2727
"//executorch/...",
2828
"@EXECUTORCH_CLIENTS",
2929
],
30+
deps = [
31+
"//executorch/runtime/platform:platform",
32+
],
3033
exported_preprocessor_flags = exported_preprocessor_flags_,
3134
exported_deps = exported_deps_,
3235
exported_external_deps = ["libtorch"] if aten_mode else [],

runtime/core/portable_type/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def define_common_targets():
4949
"scalar_type.h",
5050
"qint_types.h",
5151
"bits_types.h",
52+
"string_view.h",
5253
],
5354
visibility = [
5455
"//executorch/extension/...",

runtime/executor/method.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -816,14 +816,22 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
816816
if (e.isTensor()) {
817817
const auto& t_dst = e.toTensor();
818818
const auto& t_src = input_evalue.toTensor();
819+
820+
#if ET_LOG_ENABLED
821+
char dst_type_name[16];
822+
char src_type_name[16];
823+
824+
scalar_type_to_string(t_dst.scalar_type(), dst_type_name, sizeof(dst_type_name));
825+
scalar_type_to_string(t_src.scalar_type(), src_type_name, sizeof(src_type_name));
826+
#endif
827+
819828
ET_CHECK_OR_RETURN_ERROR(
820829
t_dst.scalar_type() == t_src.scalar_type(),
821830
InvalidArgument,
822-
"The %zu-th input tensor's scalartype does not meet requirement: found %" PRId8
823-
" but expected %" PRId8,
831+
"Input %zu has unexpected scalar type: expected %s but was %s.",
824832
input_idx,
825-
static_cast<int8_t>(t_src.scalar_type()),
826-
static_cast<int8_t>(t_dst.scalar_type()));
833+
dst_type_name,
834+
src_type_name);
827835
// Reset the shape for the Method's input as the size of forwarded input
828836
// tensor for shape dynamism. Also is a safety check if need memcpy.
829837
Error err = resize_tensor(t_dst, t_src.sizes());

runtime/executor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def define_common_targets():
8282
"//executorch/runtime/core:evalue" + aten_suffix,
8383
"//executorch/runtime/core:event_tracer" + aten_suffix,
8484
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
85+
"//executorch/runtime/core/exec_aten/util:scalar_type_util" + aten_suffix,
8586
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
8687
"//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
8788
"//executorch/runtime/kernel:operator_registry",

runtime/platform/log.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333
#define ET_LOG_ENABLED 1
3434
#endif // !defined(ET_LOG_ENABLED)
3535

36+
// Enable ET_ENABLE_ENUM_STRINGS by default. This option gates inclusion of
37+
// enum string names and can be disabled by explicitly setting it to 0.
38+
// #ifndef ET_ENABLE_ENUM_STRINGS
39+
// #define ET_ENABLE_ENUM_STRINGS 1
40+
#// endif
41+
3642
namespace executorch {
3743
namespace runtime {
3844

0 commit comments

Comments
 (0)