From 6aff16193f26e183efe3bce605164424fd4ff7db Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 19 Jul 2025 04:25:39 -0500 Subject: [PATCH 1/2] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index bb0760b39..7f6f8092e 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -12723,6 +12723,7 @@ bool isLegalConcatToOneDimDUS(stablehlo::ConcatenateOp outer, return false; } } + rhs = rhsSlice; } if (!lhs && !rhs) { From 1394db07c8c993271928b9dba156a767fae600c5 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 19 Jul 2025 07:55:49 -0500 Subject: [PATCH 2/2] add test --- test/lit_tests/concat_to_dus3.mlir | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 test/lit_tests/concat_to_dus3.mlir diff --git a/test/lit_tests/concat_to_dus3.mlir b/test/lit_tests/concat_to_dus3.mlir new file mode 100644 index 000000000..b47091641 --- /dev/null +++ b/test/lit_tests/concat_to_dus3.mlir @@ -0,0 +1,15 @@ +// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=concat_to_onedim_dus" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s + +func.func @main(%arg0: tensor<100000x768xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<768xf32>) -> tensor<100000x768xf32> { + %0 = stablehlo.reshape %arg1 : (tensor<768xf32>) -> tensor<1x768xf32> + %1 = stablehlo.slice %arg0 [1:100000, 0:768] : (tensor<100000x768xf32>) -> tensor<99999x768xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x768xf32>, tensor<99999x768xf32>) -> tensor<100000x768xf32> + return %2 : tensor<100000x768xf32> +} + +// CHECK: func.func @main(%arg0: tensor<100000x768xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<768xf32>) -> tensor<100000x768xf32> { +// CHECK-NEXT: %c = stablehlo.constant 
dense<0> : tensor<i32> +// CHECK-NEXT: %0 = stablehlo.reshape %arg1 : (tensor<768xf32>) -> tensor<1x768xf32> +// CHECK-NEXT: %1 = stablehlo.dynamic_update_slice %arg0, %0, %c, %c : (tensor<100000x768xf32>, tensor<1x768xf32>, tensor<i32>, tensor<i32>) -> tensor<100000x768xf32> +// CHECK-NEXT: return %1 : tensor<100000x768xf32> +// CHECK-NEXT: }