diff --git a/onnxruntime/core/providers/js/operators/cast.cc b/onnxruntime/core/providers/js/operators/cast.cc index f05e1eac4329c..9b6ac6d7e253b 100644 --- a/onnxruntime/core/providers/js/operators/cast.cc +++ b/onnxruntime/core/providers/js/operators/cast.cc @@ -14,8 +14,7 @@ const std::vector& CastOpTypeConstraints() { // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section // static std::vector types{ - // TODO(fs-eire): support f16 when it's ready - // DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/providers/js/operators/resize.cc b/onnxruntime/core/providers/js/operators/resize.cc index 7619c33a477aa..5b2e385777a37 100644 --- a/onnxruntime/core/providers/js/operators/resize.cc +++ b/onnxruntime/core/providers/js/operators/resize.cc @@ -5,15 +5,15 @@ namespace onnxruntime { namespace js { -#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - Resize, \ - domain, \ - 10, 10, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Resize, \ + domain, \ + 10, 10, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .TypeConstraint("T", JsepSupportedFloatTypes()), \ Resize); #define REGISTER_RESIZE_VERSIONED_KERNEL(domain, sinceVersion, endVerion) \ @@ -26,22 +26,22 @@ namespace js { .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ .InputMemoryType(OrtMemTypeCPUInput, 3) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + .TypeConstraint("T1", JsepSupportedFloatTypes()) \ + .TypeConstraint("T2", JsepSupportedFloatTypes()), \ Resize); -#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \ - ONNX_OPERATOR_KERNEL_EX( \ - Resize, \ - domain, \ - sinceVersion, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .InputMemoryType(OrtMemTypeCPUInput, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 3) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ +#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \ + ONNX_OPERATOR_KERNEL_EX( \ + Resize, \ + domain, \ + sinceVersion, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T1", JsepSupportedFloatTypes()) \ + .TypeConstraint("T2", JsepSupportedFloatTypes()), \ Resize); #define REGISTER_RESIZE_KERNEL_DOMAIN(domain) \