Skip to content

Commit 43b2b2e

Browse files
authored
Revert "Fix complex log1p accuracy with large abs values." (#88290)
Reverts #88260 The test fails on the GCC7 buildbot.
1 parent e72c949 commit 43b2b2e

File tree

2 files changed

+41
-57
lines changed

2 files changed

+41
-57
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

+24-26
Original file line numberDiff line numberDiff line change
@@ -570,39 +570,37 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
570570
ConversionPatternRewriter &rewriter) const override {
571571
auto type = cast<ComplexType>(adaptor.getComplex().getType());
572572
auto elementType = cast<FloatType>(type.getElementType());
573-
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
573+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
574574
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
575575

576-
Value real = b.create<complex::ReOp>(adaptor.getComplex());
577-
Value imag = b.create<complex::ImOp>(adaptor.getComplex());
576+
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
577+
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
578578

579579
Value half = b.create<arith::ConstantOp>(elementType,
580580
b.getFloatAttr(elementType, 0.5));
581581
Value one = b.create<arith::ConstantOp>(elementType,
582582
b.getFloatAttr(elementType, 1));
583-
Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
584-
Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
585-
Value absImag = b.create<math::AbsFOp>(imag, fmf);
586-
587-
Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
588-
Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
589-
590-
Value maxAbsOfRealPlusOneAndImagMinusOne = b.create<arith::SelectOp>(
591-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, realPlusOne, absImag,
592-
fmf),
593-
real, b.create<arith::SubFOp>(maxAbs, one, fmf));
594-
Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf);
595-
Value logOfMaxAbsOfRealPlusOneAndImag =
596-
b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
597-
Value logOfSqrtPart = b.create<math::Log1pOp>(
598-
b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf);
599-
Value r = b.create<arith::AddFOp>(
600-
b.create<arith::MulFOp>(half, logOfSqrtPart, fmf),
601-
logOfMaxAbsOfRealPlusOneAndImag, fmf);
602-
Value resultReal = b.create<arith::SelectOp>(
603-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs,
604-
r);
605-
Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
583+
Value two = b.create<arith::ConstantOp>(elementType,
584+
b.getFloatAttr(elementType, 2));
585+
586+
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
587+
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
588+
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
589+
Value sumSq = b.create<arith::MulFOp>(real, real, fmf.getValue());
590+
sumSq = b.create<arith::AddFOp>(
591+
sumSq, b.create<arith::MulFOp>(real, two, fmf.getValue()),
592+
fmf.getValue());
593+
sumSq = b.create<arith::AddFOp>(
594+
sumSq, b.create<arith::MulFOp>(imag, imag, fmf.getValue()),
595+
fmf.getValue());
596+
Value logSumSq =
597+
b.create<math::Log1pOp>(elementType, sumSq, fmf.getValue());
598+
Value resultReal = b.create<arith::MulFOp>(logSumSq, half, fmf.getValue());
599+
600+
Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf.getValue());
601+
602+
Value resultImag =
603+
b.create<math::Atan2Op>(elementType, imag, realPlusOne, fmf.getValue());
606604
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
607605
resultImag);
608606
return success();

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

+17-31
Original file line numberDiff line numberDiff line change
@@ -300,22 +300,15 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
300300
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
301301
// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
302302
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
303+
// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
304+
// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
305+
// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
306+
// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32
307+
// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
308+
// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32
309+
// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32
310+
// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
303311
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
304-
// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] : f32
305-
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
306-
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
307-
// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
308-
// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
309-
// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 : f32
310-
// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
311-
// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
312-
// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] : f32
313-
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] : f32
314-
// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] : f32
315-
// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] : f32
316-
// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] : f32
317-
// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] : f32
318-
// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
319312
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32
320313
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
321314
// CHECK: return %[[RESULT]] : complex<f32>
@@ -970,22 +963,15 @@ func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
970963
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
971964
// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
972965
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
973-
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
974-
// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
975-
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
976-
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
977-
// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
978-
// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
979-
// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 fastmath<nnan,contract> : f32
980-
// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
981-
// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
982-
// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] fastmath<nnan,contract> : f32
983-
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] fastmath<nnan,contract> : f32
984-
// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] fastmath<nnan,contract> : f32
985-
// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] fastmath<nnan,contract> : f32
986-
// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] fastmath<nnan,contract> : f32
987-
// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] fastmath<nnan,contract> : f32
988-
// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
966+
// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
967+
// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
968+
// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath<nnan,contract> : f32
969+
// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
970+
// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
971+
// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath<nnan,contract> : f32
972+
// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath<nnan,contract> : f32
973+
// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath<nnan,contract> : f32
974+
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
989975
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
990976
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
991977
// CHECK: return %[[RESULT]] : complex<f32>

0 commit comments

Comments
 (0)