Skip to content

Commit eb47008

Browse files
authored
[js/webgpu] FP16 Cast, Resize (#18035)
### Description <!-- Describe your changes. --> Cast/Resize with f16 are missing in vae-decoder-f16. With this change, vae-decoder-f16 becomes 315 ms from over than 1 seconds.
1 parent 688524a commit eb47008

File tree

2 files changed

+24
-25
lines changed

2 files changed

+24
-25
lines changed

onnxruntime/core/providers/js/operators/cast.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ const std::vector<MLDataType>& CastOpTypeConstraints() {
1414
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
1515
//
1616
static std::vector<MLDataType> types{
17-
// TODO(fs-eire): support f16 when it's ready
18-
// DataTypeImpl::GetTensorType<MLFloat16>(),
17+
DataTypeImpl::GetTensorType<MLFloat16>(),
1918
DataTypeImpl::GetTensorType<float>(),
2019
DataTypeImpl::GetTensorType<int32_t>(),
2120
DataTypeImpl::GetTensorType<uint32_t>(),

onnxruntime/core/providers/js/operators/resize.cc

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55

66
namespace onnxruntime {
77
namespace js {
8-
#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \
9-
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
10-
Resize, \
11-
domain, \
12-
10, 10, \
13-
kJsExecutionProvider, \
14-
(*KernelDefBuilder::Create()) \
15-
.InputMemoryType(OrtMemTypeCPUInput, 1) \
16-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
8+
#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \
9+
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
10+
Resize, \
11+
domain, \
12+
10, 10, \
13+
kJsExecutionProvider, \
14+
(*KernelDefBuilder::Create()) \
15+
.InputMemoryType(OrtMemTypeCPUInput, 1) \
16+
.TypeConstraint("T", JsepSupportedFloatTypes()), \
1717
Resize);
1818

1919
#define REGISTER_RESIZE_VERSIONED_KERNEL(domain, sinceVersion, endVerion) \
@@ -26,22 +26,22 @@ namespace js {
2626
.InputMemoryType(OrtMemTypeCPUInput, 1) \
2727
.InputMemoryType(OrtMemTypeCPUInput, 2) \
2828
.InputMemoryType(OrtMemTypeCPUInput, 3) \
29-
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()) \
30-
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()), \
29+
.TypeConstraint("T1", JsepSupportedFloatTypes()) \
30+
.TypeConstraint("T2", JsepSupportedFloatTypes()), \
3131
Resize);
3232

33-
#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \
34-
ONNX_OPERATOR_KERNEL_EX( \
35-
Resize, \
36-
domain, \
37-
sinceVersion, \
38-
kJsExecutionProvider, \
39-
(*KernelDefBuilder::Create()) \
40-
.InputMemoryType(OrtMemTypeCPUInput, 1) \
41-
.InputMemoryType(OrtMemTypeCPUInput, 2) \
42-
.InputMemoryType(OrtMemTypeCPUInput, 3) \
43-
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()) \
44-
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()), \
33+
#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \
34+
ONNX_OPERATOR_KERNEL_EX( \
35+
Resize, \
36+
domain, \
37+
sinceVersion, \
38+
kJsExecutionProvider, \
39+
(*KernelDefBuilder::Create()) \
40+
.InputMemoryType(OrtMemTypeCPUInput, 1) \
41+
.InputMemoryType(OrtMemTypeCPUInput, 2) \
42+
.InputMemoryType(OrtMemTypeCPUInput, 3) \
43+
.TypeConstraint("T1", JsepSupportedFloatTypes()) \
44+
.TypeConstraint("T2", JsepSupportedFloatTypes()), \
4545
Resize);
4646

4747
#define REGISTER_RESIZE_KERNEL_DOMAIN(domain) \

0 commit comments

Comments
 (0)