Skip to content
80 changes: 80 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class VectorCombine {
bool foldShuffleOfSelects(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldPermuteOfIntrinsic(Instruction &I);
bool foldShuffleOfIntrinsics(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
Expand Down Expand Up @@ -2961,6 +2962,83 @@ bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
return true;
}

/// Try to convert
/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
Value *V0;
ArrayRef<int> Mask;
if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_Undef(), m_Mask(Mask))))
return false;

auto *II0 = dyn_cast<IntrinsicInst>(V0);
if (!II0)
return false;

auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(II0->getType());
if (!ShuffleDstTy || !IntrinsicSrcTy)
return false;

// Validate it's a pure permute, mask should only reference the first vector
unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
if (any_of(Mask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
return false;

Intrinsic::ID IID = II0->getIntrinsicID();
if (!isTriviallyVectorizable(IID))
return false;

// Cost analysis
InstructionCost OldCost =
TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) +
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleDstTy,
IntrinsicSrcTy, Mask, CostKind, 0, nullptr, {V0}, &I);

SmallVector<Type *> NewArgsTy;
InstructionCost NewCost = 0;
for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
NewArgsTy.push_back(II0->getArgOperand(I)->getType());
} else {
auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
ShuffleDstTy->getNumElements());
NewArgsTy.push_back(ArgTy);
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
ArgTy, VecTy, Mask, CostKind, 0, nullptr,
{II0->getArgOperand(I)});
}
}
IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);

LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
<< OldCost << " vs NewCost: " << NewCost << "\n");

if (NewCost > OldCost)
return false;

// Transform
SmallVector<Value *> NewArgs;
for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
NewArgs.push_back(II0->getArgOperand(I));
} else {
Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I), Mask);
NewArgs.push_back(Shuf);
Worklist.pushValue(Shuf);
}
}

Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);

if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic))
NewInst->copyIRFlags(II0);

replaceValue(I, *NewIntrinsic);
return true;
}

using InstLane = std::pair<Use *, int>;

static InstLane lookThroughShuffles(Use *U, int Lane) {
Expand Down Expand Up @@ -4719,6 +4797,8 @@ bool VectorCombine::run() {
return true;
if (foldShuffleOfShuffles(I))
return true;
if (foldPermuteOfIntrinsic(I))
return true;
if (foldShuffleOfIntrinsics(I))
return true;
if (foldSelectShuffle(I))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=vector-combine -S -mtriple=aarch64 %s | FileCheck %s

; This file tests the foldPermuteOfIntrinsic optimization which transforms:
; shuffle(intrinsic(args), poison) -> intrinsic(shuffle(args))
; when the shuffle is a permute (operates on single vector) and cost model
; determines the transformation is beneficial.

;; ============================================================================
;; Positive Tests - Should Optimize
;; ============================================================================

define <4 x i32> @extract_lower_sadd_sat(<8 x i32> %v1, <8 x i32> %v2) {
; CHECK-LABEL: @extract_lower_sadd_sat(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
%result = shufflevector <8 x i32> %sat, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

define <4 x i32> @extract_lower_uadd_sat(<8 x i32> %v1, <8 x i32> %v2) {
; CHECK-LABEL: @extract_lower_uadd_sat(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <8 x i32> @llvm.uadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
%result = shufflevector <8 x i32> %sat, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

define <4 x float> @extract_lower_fma(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
; CHECK-LABEL: @extract_lower_fma(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[B:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x float> [[C:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[TMP3]])
; CHECK-NEXT: ret <4 x float> [[RESULT]]
;
%fma = call <8 x float> @llvm.fma.v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c)
%result = shufflevector <8 x float> %fma, <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x float> %result
}

define <4 x i32> @extract_lower_abs_should_not_shuffle_scalar(<8 x i32> %v) {
; CHECK-LABEL: @extract_lower_abs_should_not_shuffle_scalar(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP1]], i1 false)
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%abs = call <8 x i32> @llvm.abs.v8i32(<8 x i32> %v, i1 false)
%result = shufflevector <8 x i32> %abs, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

define <2 x i64> @extract_lower_i64(<4 x i64> %v1, <4 x i64> %v2) {
; CHECK-LABEL: @extract_lower_i64(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i64> [[V1:%.*]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i64> [[V2:%.*]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[RESULT:%.*]] = call <2 x i64> @llvm.sadd.sat.v2i64(<2 x i64> [[TMP1]], <2 x i64> [[TMP2]])
; CHECK-NEXT: ret <2 x i64> [[RESULT]]
;
%sat = call <4 x i64> @llvm.sadd.sat.v4i64(<4 x i64> %v1, <4 x i64> %v2)
%result = shufflevector <4 x i64> %sat, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
ret <2 x i64> %result
}

define <8 x i16> @extract_lower_i16(<16 x i16> %v1, <16 x i16> %v2) {
; CHECK-LABEL: @extract_lower_i16(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <16 x i16> [[V1:%.*]], <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <16 x i16> [[V2:%.*]], <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[RESULT:%.*]] = call <8 x i16> @llvm.sadd.sat.v8i16(<8 x i16> [[TMP1]], <8 x i16> [[TMP2]])
; CHECK-NEXT: ret <8 x i16> [[RESULT]]
;
%sat = call <16 x i16> @llvm.sadd.sat.v16i16(<16 x i16> %v1, <16 x i16> %v2)
%result = shufflevector <16 x i16> %sat, <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
ret <8 x i16> %result
}

;; ============================================================================
;; Negative Tests - Should NOT Optimize
;; ============================================================================

define <4 x i32> @same_size_permute(<4 x i32> %v1, <4 x i32> %v2) {
; CHECK-LABEL: @same_size_permute(
; CHECK-NEXT: [[SAT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[V1:%.*]], <4 x i32> [[V2:%.*]])
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <4 x i32> [[SAT]], <4 x i32> poison, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %v1, <4 x i32> %v2)
%result = shufflevector <4 x i32> %sat, <4 x i32> poison, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
ret <4 x i32> %result
}

define <4 x i32> @not_a_permute_uses_second_operand(<4 x i32> %v1, <4 x i32> %v2, <4 x i32> %other) {
; CHECK-LABEL: @not_a_permute_uses_second_operand(
; CHECK-NEXT: [[SAT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[V1:%.*]], <4 x i32> [[V2:%.*]])
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <4 x i32> [[SAT]], <4 x i32> [[OTHER:%.*]], <4 x i32> <i32 0, i32 4, i32 1, i32 5>
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %v1, <4 x i32> %v2)
%result = shufflevector <4 x i32> %sat, <4 x i32> %other, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
ret <4 x i32> %result
}

define <4 x i32> @not_an_intrinsic(<8 x i32> %v1, <8 x i32> %v2) {
; CHECK-LABEL: @not_an_intrinsic(
; CHECK-NEXT: [[ADD:%.*]] = add <8 x i32> [[V1:%.*]], [[V2:%.*]]
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <8 x i32> [[ADD]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%add = add <8 x i32> %v1, %v2
%result = shufflevector <8 x i32> %add, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

declare <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32>, <8 x i32>)
declare <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32>, <4 x i32>)
declare <4 x i64> @llvm.sadd.sat.v4i64(<4 x i64>, <4 x i64>)
declare <2 x i64> @llvm.sadd.sat.v2i64(<2 x i64>, <2 x i64>)
declare <16 x i16> @llvm.sadd.sat.v16i16(<16 x i16>, <16 x i16>)
declare <8 x i16> @llvm.sadd.sat.v8i16(<8 x i16>, <8 x i16>)

declare <8 x i32> @llvm.uadd.sat.v8i32(<8 x i32>, <8 x i32>)
declare <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32>, <4 x i32>)

declare <8 x i32> @llvm.abs.v8i32(<8 x i32>, i1 immarg)
declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1 immarg)

declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ define <8 x i8> @abs_different(<8 x i8> %a) {

define <4 x i32> @poison_intrinsic(<2 x i16> %l256) {
; CHECK-LABEL: @poison_intrinsic(
; CHECK-NEXT: [[L266:%.*]] = call <2 x i16> @llvm.abs.v2i16(<2 x i16> [[L256:%.*]], i1 false)
; CHECK-NEXT: [[L267:%.*]] = shufflevector <2 x i16> [[L266]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[L271:%.*]] = zext <4 x i16> [[L267]] to <4 x i32>
; CHECK-NEXT: [[L267:%.*]] = shufflevector <2 x i16> [[L266:%.*]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[L267]], i1 false)
; CHECK-NEXT: [[L271:%.*]] = zext <4 x i16> [[TMP2]] to <4 x i32>
; CHECK-NEXT: ret <4 x i32> [[L271]]
;
%l266 = call <2 x i16> @llvm.abs.v2i16(<2 x i16> %l256, i1 false)
Expand Down
18 changes: 12 additions & 6 deletions llvm/test/Transforms/VectorCombine/X86/shuffle-of-fma-const.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64-- -mcpu=x86-64-v3 | FileCheck %s --check-prefixes=CHECK,AVX

define <4 x float> @shuffle_fma_const_chain(<4 x float> %a0) {
; CHECK-LABEL: define <4 x float> @shuffle_fma_const_chain(
; CHECK-SAME: <4 x float> [[A0:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[F:%.*]] = tail call noundef <4 x float> @llvm.fma.v4f32(<4 x float> [[A0]], <4 x float> splat (float 0x3F8DE8D040000000), <4 x float> splat (float 0xBFB3715EE0000000))
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x float> [[F]], <4 x float> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
; CHECK-NEXT: ret <4 x float> [[RES]]
; SSE-LABEL: define <4 x float> @shuffle_fma_const_chain(
; SSE-SAME: <4 x float> [[A0:%.*]]) #[[ATTR0:[0-9]+]] {
; SSE-NEXT: [[F:%.*]] = tail call noundef <4 x float> @llvm.fma.v4f32(<4 x float> [[A0]], <4 x float> splat (float 0x3F8DE8D040000000), <4 x float> splat (float 0xBFB3715EE0000000))
; SSE-NEXT: [[RES:%.*]] = shufflevector <4 x float> [[F]], <4 x float> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
; SSE-NEXT: ret <4 x float> [[RES]]
;
; AVX-LABEL: define <4 x float> @shuffle_fma_const_chain(
; AVX-SAME: <4 x float> [[A0:%.*]]) #[[ATTR0:[0-9]+]] {
; AVX-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[A0]], <4 x float> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
; AVX-NEXT: [[RES:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[TMP1]], <4 x float> splat (float 0x3F8DE8D040000000), <4 x float> splat (float 0xBFB3715EE0000000))
; AVX-NEXT: ret <4 x float> [[RES]]
;
%f = tail call noundef <4 x float> @llvm.fma.v4f32(<4 x float> %a0, <4 x float> splat (float 0x3F8DE8D040000000), <4 x float> splat (float 0xBFB3715EE0000000))
%res = shufflevector <4 x float> %f, <4 x float> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
Expand All @@ -16,7 +22,7 @@ define <4 x float> @shuffle_fma_const_chain(<4 x float> %a0) {

define <8 x float> @concat_fma_const_chain(<4 x float> %a0, <4 x float> %a1) {
; CHECK-LABEL: define <8 x float> @concat_fma_const_chain(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]]) #[[ATTR0]] {
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[A0]], <4 x float> [[A1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.fma.v8f32(<8 x float> [[TMP1]], <8 x float> splat (float 0x3F8DE8D040000000), <8 x float> splat (float 0xBFB3715EE0000000))
; CHECK-NEXT: ret <8 x float> [[RES]]
Expand Down
Loading