From: Simon Pilgrim
Date: Wed, 24 Apr 2019 16:53:17 +0000 (+0000)
Subject: [InstCombine][X86] Use generic expansion of PACKSS/PACKUS for constant folding. NFCI.
X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=7847a7c9b09d942aa613e2e4bed94a4a12e7edde;p=llvm

[InstCombine][X86] Use generic expansion of PACKSS/PACKUS for constant folding. NFCI.

This patch rewrites the existing PACKSS/PACKUS constant folding code as a generic expansion (saturating clamp, truncate, then lane-wise shuffle).

This is a first NFCI step toward expanding PACKSS/PACKUS intrinsics that are acting as non-saturating truncations (although technically the expansion could be used in all cases, we'll probably want to be conservative).

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@359111 91177308-0d34-0410-b5e6-96231b3b80d8
---

diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 51c72eb1837..b1bb9281ea2 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -541,7 +541,8 @@ static Value *simplifyX86varShift(const IntrinsicInst &II,
   return Builder.CreateAShr(Vec, ShiftVec);
 }
 
-static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) {
+static Value *simplifyX86pack(IntrinsicInst &II,
+                              InstCombiner::BuilderTy &Builder, bool IsSigned) {
   Value *Arg0 = II.getArgOperand(0);
   Value *Arg1 = II.getArgOperand(1);
   Type *ResTy = II.getType();
@@ -552,68 +553,61 @@
 
   Type *ArgTy = Arg0->getType();
   unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128;
-  unsigned NumDstElts = ResTy->getVectorNumElements();
   unsigned NumSrcElts = ArgTy->getVectorNumElements();
-  assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types");
+  assert(ResTy->getVectorNumElements() == (2 * NumSrcElts) &&
+         "Unexpected packing types");
 
-  unsigned NumDstEltsPerLane = NumDstElts / NumLanes;
   unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes;
   unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits();
-  assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) &&
+  unsigned SrcScalarSizeInBits = ArgTy->getScalarSizeInBits();
+  assert(SrcScalarSizeInBits == (2 * DstScalarSizeInBits) &&
          "Unexpected packing types");
 
   // Constant folding.
-  auto *Cst0 = dyn_cast<Constant>(Arg0);
-  auto *Cst1 = dyn_cast<Constant>(Arg1);
-  if (!Cst0 || !Cst1)
+  if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
     return nullptr;
 
-  SmallVector<Constant *, 32> Vals;
-  for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
-    for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) {
-      unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane;
-      auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0;
-      auto *COp = Cst->getAggregateElement(SrcIdx);
-      if (COp && isa<UndefValue>(COp)) {
-        Vals.push_back(UndefValue::get(ResTy->getScalarType()));
-        continue;
-      }
+  // Clamp Values - signed/unsigned both use signed clamp values, but they
+  // differ on the min/max values.
+  APInt MinValue, MaxValue;
+  if (IsSigned) {
+    // PACKSS: Truncate signed value with signed saturation.
+    // Source values less than dst minint are saturated to minint.
+    // Source values greater than dst maxint are saturated to maxint.
+    MinValue =
+        APInt::getSignedMinValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits);
+    MaxValue =
+        APInt::getSignedMaxValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits);
+  } else {
+    // PACKUS: Truncate signed value with unsigned saturation.
+    // Source values less than zero are saturated to zero.
+    // Source values greater than dst maxuint are saturated to maxuint.
+    MinValue = APInt::getNullValue(SrcScalarSizeInBits);
+    MaxValue = APInt::getLowBitsSet(SrcScalarSizeInBits, DstScalarSizeInBits);
+  }
 
-      auto *CInt = dyn_cast_or_null<ConstantInt>(COp);
-      if (!CInt)
-        return nullptr;
+  auto *MinC = Constant::getIntegerValue(ArgTy, MinValue);
+  auto *MaxC = Constant::getIntegerValue(ArgTy, MaxValue);
+  Arg0 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg0, MinC), MinC, Arg0);
+  Arg1 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg1, MinC), MinC, Arg1);
+  Arg0 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg0, MaxC), MaxC, Arg0);
+  Arg1 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg1, MaxC), MaxC, Arg1);
 
-      APInt Val = CInt->getValue();
-      assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() &&
-             "Unexpected constant bitwidth");
-
-      if (IsSigned) {
-        // PACKSS: Truncate signed value with signed saturation.
-        // Source values less than dst minint are saturated to minint.
-        // Source values greater than dst maxint are saturated to maxint.
-        if (Val.isSignedIntN(DstScalarSizeInBits))
-          Val = Val.trunc(DstScalarSizeInBits);
-        else if (Val.isNegative())
-          Val = APInt::getSignedMinValue(DstScalarSizeInBits);
-        else
-          Val = APInt::getSignedMaxValue(DstScalarSizeInBits);
-      } else {
-        // PACKUS: Truncate signed value with unsigned saturation.
-        // Source values less than zero are saturated to zero.
-        // Source values greater than dst maxuint are saturated to maxuint.
-        if (Val.isIntN(DstScalarSizeInBits))
-          Val = Val.trunc(DstScalarSizeInBits);
-        else if (Val.isNegative())
-          Val = APInt::getNullValue(DstScalarSizeInBits);
-        else
-          Val = APInt::getAllOnesValue(DstScalarSizeInBits);
-      }
+  // Truncate clamped args to dst size.
+  auto *TruncTy = VectorType::get(ResTy->getScalarType(), NumSrcElts);
+  Arg0 = Builder.CreateTrunc(Arg0, TruncTy);
+  Arg1 = Builder.CreateTrunc(Arg1, TruncTy);
 
-      Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val));
-    }
+  // Shuffle args together at the lane level.
+  SmallVector<unsigned, 32> PackMask;
+  for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+    for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt)
+      PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane));
+    for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt)
+      PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane) + NumSrcElts);
   }
 
-  return ConstantVector::get(Vals);
+  return Builder.CreateShuffleVector(Arg0, Arg1, PackMask);
 }
 
 // Replace X86-specific intrinsics with generic floor-ceil where applicable.
@@ -2977,7 +2971,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   case Intrinsic::x86_avx2_packsswb:
   case Intrinsic::x86_avx512_packssdw_512:
   case Intrinsic::x86_avx512_packsswb_512:
-    if (Value *V = simplifyX86pack(*II, true))
+    if (Value *V = simplifyX86pack(*II, Builder, true))
       return replaceInstUsesWith(*II, V);
     break;
 
@@ -2987,7 +2981,7 @@
   case Intrinsic::x86_avx2_packuswb:
   case Intrinsic::x86_avx512_packusdw_512:
   case Intrinsic::x86_avx512_packuswb_512:
-    if (Value *V = simplifyX86pack(*II, false))
+    if (Value *V = simplifyX86pack(*II, Builder, false))
       return replaceInstUsesWith(*II, V);
     break;
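
For reference, below is a minimal standalone scalar model (not part of the patch; packssOld/packssNew and the i16 -> i8 PACKSS case are illustrative only) of why clamping in the source width and then truncating produces the same values as the removed per-element saturation code:

  // Standalone sketch, not LLVM code: compare the old per-element
  // saturation with the new clamp-then-truncate scheme for the
  // signed (PACKSS) i16 -> i8 case.
  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  // Old folding: saturate while truncating each constant element.
  static int8_t packssOld(int16_t V) {
    if (V < INT8_MIN)
      return INT8_MIN;
    if (V > INT8_MAX)
      return INT8_MAX;
    return static_cast<int8_t>(V);
  }

  // New expansion: clamp to [INT8_MIN, INT8_MAX] while still in the
  // source width (the icmp/select pairs in the patch), then truncate.
  static int8_t packssNew(int16_t V) {
    int16_t Clamped = std::min<int16_t>(std::max<int16_t>(V, INT8_MIN), INT8_MAX);
    return static_cast<int8_t>(Clamped); // plain trunc, no saturation needed
  }

  int main() {
    // Exhaustive check over all i16 inputs: both schemes agree.
    for (int V = INT16_MIN; V <= INT16_MAX; ++V)
      assert(packssOld(static_cast<int16_t>(V)) ==
             packssNew(static_cast<int16_t>(V)));
    return 0;
  }

The unsigned (PACKUS) case is analogous with a clamp range of [0, dst maxuint]; for constant operands the builder folds the compares, selects, truncs and the shuffle, so the intrinsic still folds to a constant vector.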