diff --git a/testdata/dnn/onnx/data/input_gemm_no_transB.npy b/testdata/dnn/onnx/data/input_gemm_no_transB.npy
new file mode 100644
index 000000000..b56cdfaa5
Binary files /dev/null and b/testdata/dnn/onnx/data/input_gemm_no_transB.npy differ
diff --git a/testdata/dnn/onnx/data/input_gemm_transB_0.npy b/testdata/dnn/onnx/data/input_gemm_transB_0.npy
new file mode 100644
index 000000000..b56cdfaa5
Binary files /dev/null and b/testdata/dnn/onnx/data/input_gemm_transB_0.npy differ
diff --git a/testdata/dnn/onnx/data/output_gemm_no_transB.npy b/testdata/dnn/onnx/data/output_gemm_no_transB.npy
new file mode 100644
index 000000000..f9ea2ed37
Binary files /dev/null and b/testdata/dnn/onnx/data/output_gemm_no_transB.npy differ
diff --git a/testdata/dnn/onnx/data/output_gemm_transB_0.npy b/testdata/dnn/onnx/data/output_gemm_transB_0.npy
new file mode 100644
index 000000000..f9ea2ed37
Binary files /dev/null and b/testdata/dnn/onnx/data/output_gemm_transB_0.npy differ
diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py
index eed78b9dd..b5ab95229 100644
--- a/testdata/dnn/onnx/generate_onnx_models.py
+++ b/testdata/dnn/onnx/generate_onnx_models.py
@@ -11,7 +11,7 @@
 import onnxsim
 import google.protobuf.text_format
 import io
-
+from typing import Optional
 
 def assertExpected(s):
     if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))):
@@ -73,6 +73,24 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs):
     model = onnx.helper.make_model(graph, producer_name=name)
     onnx.save(model, models_files)
 
+def save_data_and_onnx_model(name, input_np, output_np, onnx_model):
+    print(name + " input has sizes", input_np.shape)
+    input_files = os.path.join("data", "input_" + name)
+    np.save(input_files, input_np.data)
+
+    print(name + " output has sizes", output_np.shape)
+    print()
+    output_files = os.path.join("data", "output_" + name)
+    np.save(output_files, np.ascontiguousarray(output_np.data))
+
+    models_files = os.path.join("models", name + ".onnx")
+
+    onnx_model_pb = onnx._serialize(onnx_model)
+    model_def = assertONNXExpected(onnx_model_pb)
+    with open(models_files, 'wb') as file:
+        file.write(model_def.SerializeToString())
+
+
 def simplify(name, rename=False, **kwargs):
     model, check = onnxsim.simplify(name, **kwargs)
     assert check, "couldn't valide"
@@ -1665,4 +1683,46 @@ def forward(self, a, b):
 save_data_and_model_multy_inputs('output_registration', model, a, b)
 model = onnx.load('models/output_registration.onnx')
 model.graph.node[0].name = model.graph.output[0].name
-onnx.save(model, 'models/output_registration.onnx')
\ No newline at end of file
+onnx.save(model, 'models/output_registration.onnx')
+
+# ########################## GEMM ##########################
+# The original code is : https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gemm.py
+def gemm_reference_implementation(A: np.ndarray, B: np.ndarray, C: Optional[np.ndarray] = None, alpha: float = 1., beta: float = 1., transA: int = 0,
+                                  transB: int = 0) -> np.ndarray:
+    A = A if transA == 0 else A.T
+    B = B if transB == 0 else B.T
+    C = C if C is not None else np.array(0)
+
+    Y = alpha * np.dot(A, B) + beta * C
+
+    return Y
+
+## gemm without transB
+input_np = np.random.rand(2, 10).astype("float32")
+inputs = [onnx.helper.make_tensor_value_info("input1", onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_np.dtype], shape=input_np.shape)]
+
+weight_np = np.random.rand(10, 3).astype("float32")
+weight_tensor = onnx.helper.make_tensor('weight_tensor', data_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[weight_np.dtype], dims=weight_np.shape, vals=weight_np)
+
+outputs = [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape=(2, 3))]
+
+nodes = [onnx.helper.make_node("Gemm", ["input1", "weight_tensor"], ["output"])]
+
+graph = onnx.helper.make_graph(nodes,
+                               "gemm_test",
+                               inputs,
+                               outputs, initializer=[weight_tensor])
+gemm_model = onnx.helper.make_model(graph)
+output_np = gemm_reference_implementation(input_np, weight_np)
+save_data_and_onnx_model("gemm_no_transB", input_np, output_np, gemm_model)
+
+## gemm with transB = 0
+
+nodes2 = [onnx.helper.make_node("Gemm", ["input1", "weight_tensor"], ["output"], transB=0)]
+graph2 = onnx.helper.make_graph(nodes2,
+                                "gemm_test",
+                                inputs,
+                                outputs, initializer=[weight_tensor])
+gemm_model2 = onnx.helper.make_model(graph2)
+output_np = gemm_reference_implementation(input_np, weight_np)
+save_data_and_onnx_model("gemm_transB_0", input_np, output_np, gemm_model2)
\ No newline at end of file
diff --git a/testdata/dnn/onnx/models/gemm_no_transB.onnx b/testdata/dnn/onnx/models/gemm_no_transB.onnx
new file mode 100644
index 000000000..07e47ff22
Binary files /dev/null and b/testdata/dnn/onnx/models/gemm_no_transB.onnx differ
diff --git a/testdata/dnn/onnx/models/gemm_transB_0.onnx b/testdata/dnn/onnx/models/gemm_transB_0.onnx
new file mode 100644
index 000000000..46bf7fe4a
Binary files /dev/null and b/testdata/dnn/onnx/models/gemm_transB_0.onnx differ