Skip to content

Commit e17b90e

Browse files
committed
merge main
Signed-off-by: Liqun Fu <liqfu@microsoft.com>
2 parents 9eaaf77 + efa0cc2 commit e17b90e

37 files changed

Lines changed: 1167 additions & 365 deletions

docs/OperatorKernels.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,10 @@ Do not modify directly.*
157157
|||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
158158
|ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
159159
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
160-
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
161-
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
160+
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
161+
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
162+
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
163+
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
162164
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
163165
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
164166
|||[1, 12]|**T** = tensor(float)|

include/onnxruntime/core/framework/float8.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,10 @@ struct Float8E4M3FNUZ {
208208
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
209209
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
210210
if (saturate) {
211+
// the highest available value
211212
val |= 0x7F;
212213
} else {
213-
// infinity
214+
// NaN
214215
val = 0x80;
215216
}
216217
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
@@ -362,8 +363,10 @@ struct Float8E5M2 {
362363
val = (b & 0x80000000) >> 24; // sign
363364
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
364365
if (saturate) {
366+
// the highest available value
365367
val |= 0x7B;
366368
} else {
369+
// the infinity
367370
val |= 0x7C;
368371
}
369372
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN

include/onnxruntime/core/graph/graph.h

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <functional>
67
#include <limits>
78
#include <memory>
89
#include <string>
@@ -83,10 +84,10 @@ class Node {
8384
gsl::span<NodeArg* const> output_args,
8485
const NodeAttributes* attributes,
8586
std::string_view domain) {
86-
Init(std::string{name}, std::string{op_type}, std::string{description},
87-
std::vector<NodeArg*>{input_args.begin(), input_args.end()},
88-
std::vector<NodeArg*>{output_args.begin(), output_args.end()},
89-
attributes, std::string{domain});
87+
Init(name, op_type, description,
88+
input_args,
89+
output_args,
90+
attributes, domain);
9091
}
9192
#endif
9293

@@ -563,13 +564,13 @@ class Node {
563564
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);
564565

565566
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
566-
void Init(const std::string& name,
567-
const std::string& op_type,
568-
const std::string& description,
569-
const std::vector<NodeArg*>& input_args,
570-
const std::vector<NodeArg*>& output_args,
567+
void Init(std::string_view name,
568+
std::string_view op_type,
569+
std::string_view description,
570+
gsl::span<NodeArg* const> input_args,
571+
gsl::span<NodeArg* const> output_args,
571572
const NodeAttributes* attributes,
572-
const std::string& domain);
573+
std::string_view domain);
573574
#endif
574575

575576
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1141,8 +1142,22 @@ class Graph {
11411142
*/
11421143
Status InlineFunction(Node& node);
11431144

1145+
/**
1146+
Directly insert the nodes in the function proto provided into the graph.
1147+
The function converts Constant nodes into the initializers in the graph.
1148+
It then creates a node in the graph for each of the function nodes.
1149+
All of the names are expected to be specialized, and, therefore unique.
1150+
See function_utils::Specialize().
1151+
1152+
The Graph needs to be Resolve()d after this call.
1153+
@param func_to_inline
1154+
@returns Status indicating success or providing an error message.
1155+
*/
1156+
1157+
Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);
1158+
11441159
/** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
1145-
be used as a GraphProto attribute in another Node..
1160+
be used as a GraphProto attribute in another Node.
11461161
e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
11471162
define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
11481163
when the Graph is resolved.
@@ -1391,6 +1406,13 @@ class Graph {
13911406
Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
13921407
const ArgNameToTypeMap& name_to_type);
13931408

1409+
/** Helper that converts and adds constant node proto to an initializer in the graph.
1410+
@param constant_node_proto Constant node to convert
1411+
@param new_name use the new name for the initializer.
1412+
*/
1413+
Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
1414+
std::optional<std::string_view> new_name);
1415+
13941416
#endif
13951417

13961418
Version IrVersion() const noexcept {

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
6767
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
6868
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
6969

70+
// This setting controls whether to enable AheadOfTime function inlining.
71+
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
72+
// as possible with the help of enabled execution providers.
73+
// This can reduce the number of function calls and improve performance because it is done before
74+
// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
75+
// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
76+
// "0": enable; "1": disable.
77+
// Its default value is "0".
78+
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
79+
7080
#ifdef ENABLE_TRAINING
7181
// Specifies a list of op types for memory footprint reduction.
7282
// The value should be a ","-delimited list of pair of

js/web/lib/wasm/jsep/webgpu/ops/pool.ts

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
1818
if (!inputs || inputs.length !== 1) {
1919
throw new Error('Pool ops requires 1 input.');
2020
}
21-
if (inputs[0].dims.length !== 4) {
22-
throw new Error('Pool ops supports 2-D inputs only for now.');
21+
if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) {
22+
throw new Error('Pool ops supports 1-D or 2-D inputs only for now.');
2323
}
2424
};
2525

2626
const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
2727
input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => {
2828
const isChannelsLast = attributes.format === 'NHWC';
29-
const inputShapeAsChannelFirst =
30-
isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice();
29+
const inputShapeAsChannelFirst = input.dims.slice();
30+
if (isChannelsLast) {
31+
inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
32+
}
3133
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
3234
const kernelShape = attributes.kernelShape.slice();
3335
const strides = attributes.strides.slice();
@@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo
4446
} else {
4547
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
4648
}
47-
return [
48-
newAttributes,
49-
isChannelsLast ?
50-
[
51-
outputShapeAsChannelFirst[0], outputShapeAsChannelFirst[2], outputShapeAsChannelFirst[3],
52-
outputShapeAsChannelFirst[1]
53-
] :
54-
outputShapeAsChannelFirst
55-
];
49+
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
50+
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
51+
return [newAttributes, isChannelsLast ? outputShapeAsChannelLast : outputShapeAsChannelFirst];
5652
};
5753

5854
const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
@@ -76,22 +72,22 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
7672
let codeHEnd = '';
7773
if (pwStart + pwEnd !== 0) {
7874
codeW = `
79-
for (var i: u32 = 0u; i < ${kw}u; i++) {
80-
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
81-
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
82-
pad++;
83-
continue;
84-
}
85-
let x_val = x[${x.indicesToOffset('xIndices')}];
86-
${op1}
87-
}`;
75+
for (var i: u32 = 0u; i < ${kw}u; i++) {
76+
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
77+
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
78+
pad++;
79+
continue;
80+
}
81+
let x_val = x[${x.indicesToOffset('xIndices')}];
82+
${op1}
83+
}`;
8884
} else {
8985
codeW = `
90-
for (var i: u32 = 0u; i < ${kw}u; i++) {
91-
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
92-
let x_val = x[${x.indicesToOffset('xIndices')}];
93-
${op1}
94-
}`;
86+
for (var i: u32 = 0u; i < ${kw}u; i++) {
87+
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
88+
let x_val = x[${x.indicesToOffset('xIndices')}];
89+
${op1}
90+
}`;
9591
}
9692

9793
if (attributes.kernelShape.length === 2) {

0 commit comments

Comments
 (0)