From 44ffd2ca9e8180ff9cf229c216989b5714cb82b9 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Mon, 8 Apr 2019 21:23:50 +0000 Subject: [PATCH] [InstCombine] peek through fdiv to find a squared sqrt A more general canonicalization between fdiv and fmul would not handle this case because that would have to be limited by uses to prevent 2 values from becoming 3 values: (x/y) * (x/y) --> (x*x) / (y*y) (But we probably should still have that limited -- but more general -- canonicalization independently of this change.) git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@357943 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineMulDivRem.cpp | 19 +++++++ test/Transforms/InstCombine/fmul-sqrt.ll | 56 ++++++++----------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 03855f6bac4..6be2efdcc53 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -441,6 +441,25 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { return replaceInstUsesWith(I, Sqrt); } + // Like the similar transform in instsimplify, this requires 'nsz' because + // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0. 
+ if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && + Op0->hasNUses(2)) { + // Peek through fdiv to find squaring of square root: + // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y + if (match(Op0, m_FDiv(m_Value(X), + m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(XX, Y, &I); + } + // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X) + if (match(Op0, m_FDiv(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y)), + m_Value(X)))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(Y, XX, &I); + } + } + // exp(X) * exp(Y) -> exp(X + Y) // Match as long as at least one of exp has only one use. if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && diff --git a/test/Transforms/InstCombine/fmul-sqrt.ll b/test/Transforms/InstCombine/fmul-sqrt.ll index 1a95c72f31d..6ab70e4d3cd 100644 --- a/test/Transforms/InstCombine/fmul-sqrt.ll +++ b/test/Transforms/InstCombine/fmul-sqrt.ll @@ -90,9 +90,7 @@ define double @sqrt_a_sqrt_b_sqrt_c_sqrt_d_reassoc(double %a, double %b, double define double @rsqrt_squared(double %x) { ; CHECK-LABEL: @rsqrt_squared( -; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]]) -; CHECK-NEXT: [[RSQRT:%.*]] = fdiv fast double 1.000000e+00, [[SQRT]] -; CHECK-NEXT: [[SQUARED:%.*]] = fmul fast double [[RSQRT]], [[RSQRT]] +; CHECK-NEXT: [[SQUARED:%.*]] = fdiv fast double 1.000000e+00, [[X:%.*]] ; CHECK-NEXT: ret double [[SQUARED]] ; %sqrt = call fast double @llvm.sqrt.f64(double %x) @@ -103,9 +101,8 @@ define double @rsqrt_squared(double %x) { define double @sqrt_divisor_squared(double %x, double %y) { ; CHECK-LABEL: @sqrt_divisor_squared( -; CHECK-NEXT: [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]]) -; CHECK-NEXT: [[DIV:%.*]] = fdiv double [[Y:%.*]], [[SQRT]] -; CHECK-NEXT: [[SQUARED:%.*]] = fmul reassoc nnan nsz double [[DIV]], [[DIV]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan nsz double [[Y:%.*]], [[Y]] +; CHECK-NEXT: [[SQUARED:%.*]] = fdiv 
reassoc nnan nsz double [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret double [[SQUARED]] ; %sqrt = call double @llvm.sqrt.f64(double %x) @@ -114,19 +111,21 @@ define double @sqrt_divisor_squared(double %x, double %y) { ret double %squared } -define double @sqrt_dividend_squared(double %x, double %y) { +define <2 x float> @sqrt_dividend_squared(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @sqrt_dividend_squared( -; CHECK-NEXT: [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]]) -; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[SQRT]], [[Y:%.*]] -; CHECK-NEXT: [[SQUARED:%.*]] = fmul fast double [[DIV]], [[DIV]] -; CHECK-NEXT: ret double [[SQUARED]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <2 x float> [[Y:%.*]], [[Y]] +; CHECK-NEXT: [[SQUARED:%.*]] = fdiv fast <2 x float> [[X:%.*]], [[TMP1]] +; CHECK-NEXT: ret <2 x float> [[SQUARED]] ; - %sqrt = call double @llvm.sqrt.f64(double %x) - %div = fdiv fast double %sqrt, %y - %squared = fmul fast double %div, %div - ret double %squared + %sqrt = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x) + %div = fdiv fast <2 x float> %sqrt, %y + %squared = fmul fast <2 x float> %div, %div + ret <2 x float> %squared } +; We do not transform this because it would result in an extra instruction. +; This might still be a good optimization for the backend. 
+ define double @sqrt_divisor_squared_extra_use(double %x, double %y) { ; CHECK-LABEL: @sqrt_divisor_squared_extra_use( ; CHECK-NEXT: [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]]) @@ -146,8 +145,8 @@ define double @sqrt_dividend_squared_extra_use(double %x, double %y) { ; CHECK-LABEL: @sqrt_dividend_squared_extra_use( ; CHECK-NEXT: [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]]) ; CHECK-NEXT: call void @use(double [[SQRT]]) -; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[SQRT]], [[Y:%.*]] -; CHECK-NEXT: [[SQUARED:%.*]] = fmul fast double [[DIV]], [[DIV]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[Y:%.*]], [[Y]] +; CHECK-NEXT: [[SQUARED:%.*]] = fdiv fast double [[X]], [[TMP1]] ; CHECK-NEXT: ret double [[SQUARED]] ; %sqrt = call double @llvm.sqrt.f64(double %x) @@ -172,8 +171,12 @@ define double @sqrt_divisor_not_enough_FMF(double %x, double %y) { ret double %squared } -define double @sqrt_squared_extra_use(double %x) { -; CHECK-LABEL: @sqrt_squared_extra_use( +; TODO: This is a special-case of the general pattern. If we have a constant +; operand, the extra use limitation could be eased because this does not +; result in an extra instruction (1.0 * 1.0 is constant folded). + +define double @rsqrt_squared_extra_use(double %x) { +; CHECK-LABEL: @rsqrt_squared_extra_use( ; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[X:%.*]]) ; CHECK-NEXT: [[RSQRT:%.*]] = fdiv fast double 1.000000e+00, [[SQRT]] ; CHECK-NEXT: call void @use(double [[RSQRT]]) @@ -186,18 +189,3 @@ define double @sqrt_squared_extra_use(double %x) { %squared = fmul fast double %rsqrt, %rsqrt ret double %squared } - -; Minimal FMF to reassociate fmul+fdiv. 
- -define <2 x float> @sqrt_squared_vec(<2 x float> %x) { -; CHECK-LABEL: @sqrt_squared_vec( -; CHECK-NEXT: [[SQRT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X:%.*]]) -; CHECK-NEXT: [[RSQRT:%.*]] = fdiv <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[SQRT]] -; CHECK-NEXT: [[SQUARED:%.*]] = fmul reassoc <2 x float> [[RSQRT]], [[RSQRT]] -; CHECK-NEXT: ret <2 x float> [[SQUARED]] -; - %sqrt = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x) - %rsqrt = fdiv <2 x float> <float 1.0, float 1.0>, %sqrt - %squared = fmul reassoc <2 x float> %rsqrt, %rsqrt - ret <2 x float> %squared -} -- 2.50.1