Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,15 @@ declare_mlir_dialect_extension_python_bindings(
"../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SparseTensorTransformOps.td
SOURCES
dialects/transform/sparse_tensor.py
DIALECT_NAME transform
EXTENSION_NAME sparse_tensor_transform)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/dialects/SparseTensorTransformOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===-- SparseTensorTransfromOps.td ------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS
#define PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS

include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td"

#endif
5 changes: 5 additions & 0 deletions mlir/python/mlir/dialects/transform/sparse_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._sparse_tensor_transform_ops_gen import *
31 changes: 31 additions & 0 deletions mlir/test/python/dialects/transform_sparse_tensor_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import sparse_tensor


def run(f):
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
f(sequence.bodyTarget)
transform.YieldOp()
print("\nTEST:", f.__name__)
print(module)
return f


@run
def testMatchSparseInOut(target):
sparse_tensor.MatchSparseInOut(transform.AnyOpType.get(), target)
# CHECK-LABEL: TEST: testMatchSparseInOut
# CHECK: transform.sequence
# CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op):
# CHECK-NEXT: transform.sparse_tensor.match.sparse_inout %[[ARG0]]
20 changes: 20 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,25 @@ gentbl_filegroup(
],
)

gentbl_filegroup(
name = "SparseTensorTransformOpsPyGen",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=transform",
"-dialect-extension=sparse_tensor_transform",
],
"mlir/dialects/_sparse_tensor_transform_ops_gen.py",
),
],
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/SparseTensorTransformOps.td",
deps = [
"//mlir:SparseTensorTransformOpsTdFiles",
],
)

gentbl_filegroup(
name = "TensorTransformOpsPyGen",
tbl_outs = [
Expand Down Expand Up @@ -1309,6 +1328,7 @@ filegroup(
":LoopTransformOpsPyGen",
":MemRefTransformOpsPyGen",
":PDLTransformOpsPyGen",
":SparseTensorTransformOpsPyGen",
":StructureTransformEnumPyGen",
":StructuredTransformOpsPyGen",
":TensorTransformOpsPyGen",
Expand Down