82#define DEBUG_TYPE "complex-deinterleaving"
84STATISTIC(NumComplexTransformations,
"Amount of complex patterns transformed");
87 "enable-complex-deinterleaving",
115 Value *Real =
nullptr;
116 Value *Imag =
nullptr;
119 return Real ==
Other.Real && Imag ==
Other.Imag;
138 static bool isEqual(
const ComplexValue &LHS,
const ComplexValue &RHS) {
139 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
144template <
typename T,
typename IterT>
145std::optional<T> findCommonBetweenCollections(IterT
A, IterT
B) {
147 if (Common !=
A.end())
148 return std::make_optional(*Common);
152class ComplexDeinterleavingLegacyPass :
public FunctionPass {
156 ComplexDeinterleavingLegacyPass(
const TargetMachine *TM =
nullptr)
157 : FunctionPass(ID), TM(TM) {}
159 StringRef getPassName()
const override {
160 return "Complex Deinterleaving Pass";
164 void getAnalysisUsage(AnalysisUsage &AU)
const override {
170 const TargetMachine *TM;
173class ComplexDeinterleavingGraph;
174struct ComplexDeinterleavingCompositeNode {
179 Vals.push_back({
R,
I});
184 : Operation(
Op), Vals(
Other) {}
187 friend class ComplexDeinterleavingGraph;
188 using CompositeNode = ComplexDeinterleavingCompositeNode;
189 bool OperandsValid =
true;
198 std::optional<FastMathFlags> Flags;
201 ComplexDeinterleavingRotation::Rotation_0;
203 Value *ReplacementNode =
nullptr;
207 OperandsValid =
false;
208 Operands.push_back(Node);
212 void dump(raw_ostream &OS) {
213 auto PrintValue = [&](
Value *
V) {
221 auto PrintNodeRef = [&](CompositeNode *Ptr) {
228 OS <<
"- CompositeNode: " <<
this <<
"\n";
229 for (
unsigned I = 0;
I < Vals.size();
I++) {
230 OS <<
" Real(" <<
I <<
") : ";
231 PrintValue(Vals[
I].Real);
232 OS <<
" Imag(" <<
I <<
") : ";
233 PrintValue(Vals[
I].Imag);
235 OS <<
" ReplacementNode: ";
236 PrintValue(ReplacementNode);
237 OS <<
" Operation: " << (int)Operation <<
"\n";
238 OS <<
" Rotation: " << ((int)Rotation * 90) <<
"\n";
239 OS <<
" Operands: \n";
240 for (
const auto &
Op : Operands) {
246 bool areOperandsValid() {
return OperandsValid; }
249class ComplexDeinterleavingGraph {
257 using Addend = std::pair<Value *, bool>;
259 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
263 struct PartialMulCandidate {
271 explicit ComplexDeinterleavingGraph(
const TargetLowering *TL,
272 const TargetLibraryInfo *TLI,
274 : TL(TL), TLI(TLI), Factor(Factor) {}
277 const TargetLowering *TL =
nullptr;
278 const TargetLibraryInfo *TLI =
nullptr;
281 DenseMap<ComplexValues, CompositeNode *> CachedResult;
282 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
284 SmallPtrSet<Instruction *, 16> FinalInstructions;
287 DenseMap<Instruction *, CompositeNode *> RootToNode;
314 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
322 PHINode *RealPHI =
nullptr;
323 PHINode *ImagPHI =
nullptr;
327 bool PHIsFound =
false;
335 DenseMap<PHINode *, PHINode *> OldToNewPHI;
340 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
342 "Reduction related nodes must have Real and Imaginary parts");
343 return new (Allocator.Allocate())
344 ComplexDeinterleavingCompositeNode(
Operation, R,
I);
350 for (
auto &V : Vals) {
352 ((
Operation != ComplexDeinterleavingOperation::ReductionPHI &&
353 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
354 (
V.Real &&
V.Imag)) &&
355 "Reduction related nodes must have Real and Imaginary parts");
358 return new (Allocator.Allocate())
359 ComplexDeinterleavingCompositeNode(
Operation, Vals);
362 CompositeNode *submitCompositeNode(CompositeNode *Node) {
363 CompositeNodes.push_back(Node);
364 if (
Node->Vals[0].Real)
380 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
386 identifyNodeWithImplicitAdd(Instruction *
I, Instruction *J,
387 std::pair<Value *, Value *> &CommonOperandI);
396 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
397 CompositeNode *identifySymmetricOperation(
ComplexValues &Vals);
398 CompositeNode *identifyPartialReduction(
Value *R,
Value *
I);
399 CompositeNode *identifyDotProduct(
Value *Inst);
406 return identifyNode(Vals);
413 CompositeNode *identifyAdditions(AddendList &RealAddends,
414 AddendList &ImagAddends,
415 std::optional<FastMathFlags> Flags,
419 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
420 AddendList &ImagAddends);
425 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
426 SmallVectorImpl<Product> &ImagMuls,
434 SmallVectorImpl<PartialMulCandidate> &Candidates);
442 CompositeNode *identifyReassocNodes(Instruction *
I, Instruction *J);
444 CompositeNode *identifyRoot(Instruction *
I);
462 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
466 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
468 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
475 void processReductionOperation(
Value *OperationReplacement,
476 CompositeNode *Node);
477 void processReductionSingle(
Value *OperationReplacement, CompositeNode *Node);
481 void dump(raw_ostream &OS) {
482 for (
const auto &Node : CompositeNodes)
488 bool identifyNodes(Instruction *RootI);
493 bool collectPotentialReductions(BasicBlock *
B);
495 void identifyReductionNodes();
505class ComplexDeinterleaving {
507 ComplexDeinterleaving(
const TargetLowering *tl,
const TargetLibraryInfo *tli)
508 : TL(tl), TLI(tli) {}
512 bool evaluateBasicBlock(BasicBlock *
B,
unsigned Factor);
514 const TargetLowering *TL =
nullptr;
515 const TargetLibraryInfo *TLI =
nullptr;
520char ComplexDeinterleavingLegacyPass::ID = 0;
523 "Complex Deinterleaving",
false,
false)
529 const TargetLowering *TL = TM->getSubtargetImpl(
F)->getTargetLowering();
540 return new ComplexDeinterleavingLegacyPass(TM);
543bool ComplexDeinterleavingLegacyPass::runOnFunction(
Function &
F) {
545 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
546 return ComplexDeinterleaving(TL, &TLI).runOnFunction(
F);
549bool ComplexDeinterleaving::runOnFunction(Function &
F) {
552 dbgs() <<
"Complex deinterleaving has been explicitly disabled.\n");
558 dbgs() <<
"Complex deinterleaving has been disabled, target does "
559 "not support lowering of complex number operations.\n");
565 Changed |= evaluateBasicBlock(&
B, 2);
570 Changed |= evaluateBasicBlock(&
B, 4);
580 if ((Mask.size() & 1))
583 int HalfNumElements = Mask.size() / 2;
584 for (
int Idx = 0; Idx < HalfNumElements; ++Idx) {
585 int MaskIdx = Idx * 2;
586 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
595 int HalfNumElements = Mask.size() / 2;
597 for (
int Idx = 1; Idx < HalfNumElements; ++Idx) {
598 if (Mask[Idx] != (Idx * 2) +
Offset)
612 if (
I->getOpcode() == Instruction::FNeg)
613 return I->getOperand(0);
615 return I->getOperand(1);
618bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *
B,
unsigned Factor) {
619 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
620 if (Graph.collectPotentialReductions(
B))
621 Graph.identifyReductionNodes();
624 Graph.identifyNodes(&
I);
626 if (Graph.checkNodes()) {
627 Graph.replaceNodes();
634ComplexDeinterleavingGraph::CompositeNode *
635ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
636 Instruction *Real, Instruction *Imag,
637 std::pair<Value *, Value *> &PartialMatch) {
638 LLVM_DEBUG(
dbgs() <<
"identifyNodeWithImplicitAdd " << *Real <<
" / " << *Imag
646 if ((Real->
getOpcode() != Instruction::FMul &&
647 Real->
getOpcode() != Instruction::Mul) ||
648 (Imag->
getOpcode() != Instruction::FMul &&
649 Imag->
getOpcode() != Instruction::Mul)) {
651 dbgs() <<
" - Real or imaginary instruction is not fmul or mul\n");
684 Value *CommonOperand;
685 Value *UncommonRealOp;
686 Value *UncommonImagOp;
688 if (R0 == I0 || R0 == I1) {
691 }
else if (R1 == I0 || R1 == I1) {
699 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
700 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
701 Rotation == ComplexDeinterleavingRotation::Rotation_270)
702 std::swap(UncommonRealOp, UncommonImagOp);
706 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
707 Rotation == ComplexDeinterleavingRotation::Rotation_180)
708 PartialMatch.first = CommonOperand;
710 PartialMatch.second = CommonOperand;
712 if (!PartialMatch.first || !PartialMatch.second) {
717 CompositeNode *CommonNode =
718 identifyNode(PartialMatch.first, PartialMatch.second);
724 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
730 CompositeNode *
Node = prepareCompositeNode(
731 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
732 Node->Rotation = Rotation;
733 Node->addOperand(CommonNode);
734 Node->addOperand(UncommonNode);
735 return submitCompositeNode(Node);
738ComplexDeinterleavingGraph::CompositeNode *
739ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
741 LLVM_DEBUG(
dbgs() <<
"identifyPartialMul " << *Real <<
" / " << *Imag
745 auto IsAdd = [](
unsigned Op) {
746 return Op == Instruction::FAdd ||
Op == Instruction::Add;
748 auto IsSub = [](
unsigned Op) {
749 return Op == Instruction::FSub ||
Op == Instruction::Sub;
753 Rotation = ComplexDeinterleavingRotation::Rotation_0;
755 Rotation = ComplexDeinterleavingRotation::Rotation_90;
757 Rotation = ComplexDeinterleavingRotation::Rotation_180;
759 Rotation = ComplexDeinterleavingRotation::Rotation_270;
768 LLVM_DEBUG(
dbgs() <<
" - Contract is missing from the FastMath flags.\n");
791 Value *CommonOperand;
792 Value *UncommonRealOp;
793 Value *UncommonImagOp;
795 if (R0 == I0 || R0 == I1) {
798 }
else if (R1 == I0 || R1 == I1) {
806 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
807 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
808 Rotation == ComplexDeinterleavingRotation::Rotation_270)
809 std::swap(UncommonRealOp, UncommonImagOp);
811 std::pair<Value *, Value *> PartialMatch(
812 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
813 Rotation == ComplexDeinterleavingRotation::Rotation_180)
816 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_270)
824 if (!CRInst || !CIInst) {
825 LLVM_DEBUG(
dbgs() <<
" - Common operands are not instructions.\n");
829 CompositeNode *CNode =
830 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
836 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
842 assert(PartialMatch.first && PartialMatch.second);
843 CompositeNode *CommonRes =
844 identifyNode(PartialMatch.first, PartialMatch.second);
850 CompositeNode *
Node = prepareCompositeNode(
851 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
852 Node->Rotation = Rotation;
853 Node->addOperand(CommonRes);
854 Node->addOperand(UncommonRes);
855 Node->addOperand(CNode);
856 return submitCompositeNode(Node);
859ComplexDeinterleavingGraph::CompositeNode *
860ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
861 LLVM_DEBUG(
dbgs() <<
"identifyAdd " << *Real <<
" / " << *Imag <<
"\n");
865 if ((Real->
getOpcode() == Instruction::FSub &&
866 Imag->
getOpcode() == Instruction::FAdd) ||
867 (Real->
getOpcode() == Instruction::Sub &&
869 Rotation = ComplexDeinterleavingRotation::Rotation_90;
870 else if ((Real->
getOpcode() == Instruction::FAdd &&
871 Imag->
getOpcode() == Instruction::FSub) ||
872 (Real->
getOpcode() == Instruction::Add &&
874 Rotation = ComplexDeinterleavingRotation::Rotation_270;
876 LLVM_DEBUG(
dbgs() <<
" - Unhandled case, rotation is not assigned.\n");
885 if (!AR || !AI || !BR || !BI) {
890 CompositeNode *ResA = identifyNode(AR, AI);
892 LLVM_DEBUG(
dbgs() <<
" - AR/AI is not identified as a composite node.\n");
895 CompositeNode *ResB = identifyNode(BR, BI);
897 LLVM_DEBUG(
dbgs() <<
" - BR/BI is not identified as a composite node.\n");
901 CompositeNode *
Node =
902 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
903 Node->Rotation = Rotation;
904 Node->addOperand(ResA);
905 Node->addOperand(ResB);
906 return submitCompositeNode(Node);
910 unsigned OpcA =
A->getOpcode();
911 unsigned OpcB =
B->getOpcode();
913 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
914 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
915 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
916 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
927 switch (
I->getOpcode()) {
928 case Instruction::FAdd:
929 case Instruction::FSub:
930 case Instruction::FMul:
931 case Instruction::FNeg:
932 case Instruction::Add:
933 case Instruction::Sub:
934 case Instruction::Mul:
941ComplexDeinterleavingGraph::CompositeNode *
942ComplexDeinterleavingGraph::identifySymmetricOperation(
ComplexValues &Vals) {
944 unsigned FirstOpc = FirstReal->getOpcode();
945 for (
auto &V : Vals) {
962 for (
auto &V : Vals) {
968 CompositeNode *Op0 = identifyNode(OpVals);
969 CompositeNode *Op1 =
nullptr;
973 if (FirstReal->isBinaryOp()) {
975 for (
auto &V : Vals) {
980 Op1 = identifyNode(OpVals);
986 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
987 Node->Opcode = FirstReal->getOpcode();
989 Node->Flags = FirstReal->getFastMathFlags();
991 Node->addOperand(Op0);
992 if (FirstReal->isBinaryOp())
993 Node->addOperand(Op1);
995 return submitCompositeNode(Node);
998ComplexDeinterleavingGraph::CompositeNode *
999ComplexDeinterleavingGraph::identifyDotProduct(
Value *V) {
1001 ComplexDeinterleavingOperation::CDot,
V->getType())) {
1002 LLVM_DEBUG(
dbgs() <<
"Target doesn't support complex deinterleaving "
1003 "operation CDot with the type "
1004 << *
V->getType() <<
"\n");
1012 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst,
nullptr);
1014 CompositeNode *ANode =
nullptr;
1016 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1018 Value *AReal =
nullptr;
1019 Value *AImag =
nullptr;
1020 Value *BReal =
nullptr;
1021 Value *BImag =
nullptr;
1026 return CI->getOperand(0);
1040 if (
match(Inst, PatternRot0)) {
1041 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1042 }
else if (
match(Inst, PatternRot270)) {
1043 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1054 if (!
match(Inst, PatternRot90Rot180))
1057 A0 = UnwrapCast(A0);
1058 A1 = UnwrapCast(A1);
1061 ANode = identifyNode(A0, A1);
1064 ANode = identifyNode(A1, A0);
1068 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1074 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1078 AReal = UnwrapCast(AReal);
1079 AImag = UnwrapCast(AImag);
1080 BReal = UnwrapCast(BReal);
1081 BImag = UnwrapCast(BImag);
1084 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1085 if (AReal->
getType() != ExpectedOperandTy)
1087 if (AImag->
getType() != ExpectedOperandTy)
1089 if (BReal->
getType() != ExpectedOperandTy)
1091 if (BImag->
getType() != ExpectedOperandTy)
1094 if (
Phi->getType() != VTy && RealUser->getType() != VTy)
1097 CompositeNode *
Node = identifyNode(AReal, AImag);
1102 if (ANode && Node != ANode) {
1105 <<
"Identified node is different from previously identified node. "
1106 "Unable to confidently generate a complex operation node\n");
1110 CN->addOperand(Node);
1111 CN->addOperand(identifyNode(BReal, BImag));
1112 CN->addOperand(identifyNode(Phi, RealUser));
1114 return submitCompositeNode(CN);
1117ComplexDeinterleavingGraph::CompositeNode *
1118ComplexDeinterleavingGraph::identifyPartialReduction(
Value *R,
Value *
I) {
1123 if (!
R->hasUseList() || !
I->hasUseList())
1127 findCommonBetweenCollections<Value *>(
R->users(),
I->users());
1132 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1135 if (CompositeNode *CN = identifyDotProduct(IInst))
1141ComplexDeinterleavingGraph::CompositeNode *
1142ComplexDeinterleavingGraph::identifyNode(
ComplexValues &Vals) {
1143 auto It = CachedResult.
find(Vals);
1144 if (It != CachedResult.
end()) {
1149 if (Vals.
size() == 1) {
1150 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1153 if (CompositeNode *CN = identifyPartialReduction(R,
I))
1155 bool IsReduction = RealPHI ==
R && (!ImagPHI || ImagPHI ==
I);
1156 if (!IsReduction &&
R->getType() !=
I->getType())
1160 if (CompositeNode *CN = identifySplat(Vals))
1163 for (
auto &V : Vals) {
1170 if (CompositeNode *CN = identifyDeinterleave(Vals))
1173 if (Vals.size() == 1) {
1174 assert(Factor == 2 &&
"Can only handle interleave factors of 2");
1177 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1180 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1184 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1187 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1189 ComplexDeinterleavingOperation::CAdd, NewVTy);
1192 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1197 if (CompositeNode *CN = identifyAdd(Real, Imag))
1201 if (HasCMulSupport && HasCAddSupport) {
1202 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1208 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1212 CachedResult[Vals] =
nullptr;
1216ComplexDeinterleavingGraph::CompositeNode *
1217ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1218 Instruction *Imag) {
1219 auto IsOperationSupported = [](
unsigned Opcode) ->
bool {
1220 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1221 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1222 Opcode == Instruction::Sub;
1225 if (!IsOperationSupported(Real->
getOpcode()) ||
1226 !IsOperationSupported(Imag->
getOpcode()))
1229 std::optional<FastMathFlags>
Flags;
1232 LLVM_DEBUG(
dbgs() <<
"The flags in Real and Imaginary instructions are "
1238 if (!
Flags->allowReassoc()) {
1241 <<
"the 'Reassoc' attribute is missing in the FastMath flags\n");
1250 AddendList &Addends) ->
bool {
1252 SmallPtrSet<Value *, 8> Visited;
1253 while (!Worklist.
empty()) {
1255 if (!Visited.
insert(V).second)
1260 Addends.emplace_back(V, IsPositive);
1270 if (
I != Insn &&
I->hasNUsesOrMore(2)) {
1271 LLVM_DEBUG(
dbgs() <<
"Found potential sub-expression: " << *
I <<
"\n");
1272 Addends.emplace_back(
I, IsPositive);
1275 switch (
I->getOpcode()) {
1276 case Instruction::FAdd:
1277 case Instruction::Add:
1281 case Instruction::FSub:
1285 case Instruction::Sub:
1293 case Instruction::FMul:
1294 case Instruction::Mul: {
1296 if (
isNeg(
I->getOperand(0))) {
1298 IsPositive = !IsPositive;
1300 A =
I->getOperand(0);
1303 if (
isNeg(
I->getOperand(1))) {
1305 IsPositive = !IsPositive;
1307 B =
I->getOperand(1);
1309 Muls.push_back(Product{
A,
B, IsPositive});
1312 case Instruction::FNeg:
1316 Addends.emplace_back(
I, IsPositive);
1320 if (Flags &&
I->getFastMathFlags() != *Flags) {
1322 "inconsistent with the root instructions' flags: "
1331 AddendList RealAddends, ImagAddends;
1332 if (!Collect(Real, RealMuls, RealAddends) ||
1333 !Collect(Imag, ImagMuls, ImagAddends))
1336 if (RealAddends.size() != ImagAddends.size())
1339 CompositeNode *FinalNode =
nullptr;
1340 if (!RealMuls.
empty() || !ImagMuls.
empty()) {
1343 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1344 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1350 if (!RealAddends.empty() || !ImagAddends.empty()) {
1351 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1355 assert(FinalNode &&
"FinalNode can not be nullptr here");
1356 assert(FinalNode->Vals.size() == 1);
1358 FinalNode->Vals[0].Real = Real;
1359 FinalNode->Vals[0].Imag = Imag;
1360 submitCompositeNode(FinalNode);
1364bool ComplexDeinterleavingGraph::collectPartialMuls(
1366 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1368 auto FindCommonInstruction = [](
const Product &Real,
1369 const Product &Imag) ->
Value * {
1370 if (Real.Multiplicand == Imag.Multiplicand ||
1371 Real.Multiplicand == Imag.Multiplier)
1372 return Real.Multiplicand;
1374 if (Real.Multiplier == Imag.Multiplicand ||
1375 Real.Multiplier == Imag.Multiplier)
1376 return Real.Multiplier;
1385 for (
unsigned i = 0; i < RealMuls.
size(); ++i) {
1386 bool FoundCommon =
false;
1387 for (
unsigned j = 0;
j < ImagMuls.
size(); ++
j) {
1388 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1392 auto *
A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1393 : RealMuls[i].Multiplicand;
1394 auto *
B = ImagMuls[
j].Multiplicand == Common ? ImagMuls[
j].Multiplier
1395 : ImagMuls[
j].Multiplicand;
1397 auto Node = identifyNode(
A,
B);
1403 Node = identifyNode(
B,
A);
1415ComplexDeinterleavingGraph::CompositeNode *
1416ComplexDeinterleavingGraph::identifyMultiplications(
1417 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1419 if (RealMuls.
size() != ImagMuls.
size())
1423 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1427 DenseMap<Value *, CompositeNode *> CommonToNode;
1428 SmallVector<bool> Processed(
Info.size(),
false);
1429 for (
unsigned I = 0;
I <
Info.size(); ++
I) {
1433 PartialMulCandidate &InfoA =
Info[
I];
1434 for (
unsigned J =
I + 1; J <
Info.size(); ++J) {
1438 PartialMulCandidate &InfoB =
Info[J];
1439 auto *InfoReal = &InfoA;
1440 auto *InfoImag = &InfoB;
1442 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1443 if (!NodeFromCommon) {
1445 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1447 if (!NodeFromCommon)
1450 CommonToNode[InfoReal->Common] = NodeFromCommon;
1451 CommonToNode[InfoImag->Common] = NodeFromCommon;
1452 Processed[
I] =
true;
1453 Processed[J] =
true;
1457 SmallVector<bool> ProcessedReal(RealMuls.
size(),
false);
1458 SmallVector<bool> ProcessedImag(ImagMuls.
size(),
false);
1460 for (
auto &PMI : Info) {
1461 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1464 auto It = CommonToNode.
find(PMI.Common);
1467 if (It == CommonToNode.
end()) {
1469 dbgs() <<
"Unprocessed independent partial multiplication:\n";
// Dump both halves of the unmatched candidate: the real-side product and
// its paired imaginary-side product (previously the real one was listed twice).
1470 for (
auto *
Mul : {&RealMuls[PMI.RealIdx], &ImagMuls[PMI.ImagIdx]})
1472 <<
" multiplied by " << *
Mul->Multiplicand <<
"\n";
1477 auto &RealMul = RealMuls[PMI.RealIdx];
1478 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1480 auto NodeA = It->second;
1481 auto NodeB = PMI.Node;
1482 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1497 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1498 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1503 if (IsMultiplicandReal) {
1505 if (RealMul.IsPositive && ImagMul.IsPositive)
1507 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1514 if (!RealMul.IsPositive && ImagMul.IsPositive)
1516 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1523 dbgs() <<
"Identified partial multiplication (X, Y) * (U, V):\n";
1524 dbgs().
indent(4) <<
"X: " << *NodeA->Vals[0].Real <<
"\n";
1525 dbgs().
indent(4) <<
"Y: " << *NodeA->Vals[0].Imag <<
"\n";
1526 dbgs().
indent(4) <<
"U: " << *NodeB->Vals[0].Real <<
"\n";
1527 dbgs().
indent(4) <<
"V: " << *NodeB->Vals[0].Imag <<
"\n";
1528 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1531 CompositeNode *NodeMul = prepareCompositeNode(
1532 ComplexDeinterleavingOperation::CMulPartial,
nullptr,
nullptr);
1533 NodeMul->Rotation = Rotation;
1534 NodeMul->addOperand(NodeA);
1535 NodeMul->addOperand(NodeB);
1537 NodeMul->addOperand(Result);
1538 submitCompositeNode(NodeMul);
1540 ProcessedReal[PMI.RealIdx] =
true;
1541 ProcessedImag[PMI.ImagIdx] =
true;
1545 if (!
all_of(ProcessedReal, [](
bool V) {
return V; }) ||
1546 !
all_of(ProcessedImag, [](
bool V) {
return V; })) {
1551 dbgs() <<
"Unprocessed products (Real):\n";
1552 for (
size_t i = 0; i < ProcessedReal.size(); ++i) {
1553 if (!ProcessedReal[i])
1554 dbgs().
indent(4) << (RealMuls[i].IsPositive ?
"+" :
"-")
1555 << *RealMuls[i].Multiplier <<
" multiplied by "
1556 << *RealMuls[i].Multiplicand <<
"\n";
1558 dbgs() <<
"Unprocessed products (Imag):\n";
1559 for (
size_t i = 0; i < ProcessedImag.size(); ++i) {
1560 if (!ProcessedImag[i])
1561 dbgs().
indent(4) << (ImagMuls[i].IsPositive ?
"+" :
"-")
1562 << *ImagMuls[i].Multiplier <<
" multiplied by "
1563 << *ImagMuls[i].Multiplicand <<
"\n";
1572ComplexDeinterleavingGraph::CompositeNode *
1573ComplexDeinterleavingGraph::identifyAdditions(
1574 AddendList &RealAddends, AddendList &ImagAddends,
1575 std::optional<FastMathFlags> Flags, CompositeNode *
Accumulator =
nullptr) {
1576 if (RealAddends.size() != ImagAddends.size())
1579 CompositeNode *
Result =
nullptr;
1585 Result = extractPositiveAddend(RealAddends, ImagAddends);
1590 while (!RealAddends.empty()) {
1591 auto ItR = RealAddends.begin();
1592 auto [
R, IsPositiveR] = *ItR;
1594 bool FoundImag =
false;
1595 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1596 auto [
I, IsPositiveI] = *ItI;
1598 if (IsPositiveR && IsPositiveI)
1599 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1600 else if (!IsPositiveR && IsPositiveI)
1601 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1602 else if (!IsPositiveR && !IsPositiveI)
1603 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1605 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1607 CompositeNode *AddNode =
nullptr;
1608 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1609 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1610 AddNode = identifyNode(R,
I);
1612 AddNode = identifyNode(
I, R);
1616 dbgs() <<
"Identified addition:\n";
1619 dbgs().
indent(4) <<
"Rotation - " << (int)Rotation * 90 <<
"\n";
1622 CompositeNode *TmpNode =
nullptr;
1624 TmpNode = prepareCompositeNode(
1625 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1627 TmpNode->Opcode = Instruction::FAdd;
1628 TmpNode->Flags = *
Flags;
1630 TmpNode->Opcode = Instruction::Add;
1632 }
else if (Rotation ==
1634 TmpNode = prepareCompositeNode(
1635 ComplexDeinterleavingOperation::Symmetric,
nullptr,
nullptr);
1637 TmpNode->Opcode = Instruction::FSub;
1638 TmpNode->Flags = *
Flags;
1640 TmpNode->Opcode = Instruction::Sub;
1643 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1645 TmpNode->Rotation = Rotation;
1648 TmpNode->addOperand(Result);
1649 TmpNode->addOperand(AddNode);
1650 submitCompositeNode(TmpNode);
1652 RealAddends.erase(ItR);
1653 ImagAddends.erase(ItI);
1664ComplexDeinterleavingGraph::CompositeNode *
1665ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1666 AddendList &ImagAddends) {
1667 for (
auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1668 for (
auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1669 auto [
R, IsPositiveR] = *ItR;
1670 auto [
I, IsPositiveI] = *ItI;
1671 if (IsPositiveR && IsPositiveI) {
1672 auto Result = identifyNode(R,
I);
1674 RealAddends.erase(ItR);
1675 ImagAddends.erase(ItI);
1684bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1689 auto It = RootToNode.
find(RootI);
1690 if (It != RootToNode.
end()) {
1691 auto RootNode = It->second;
1692 assert(RootNode->Operation ==
1693 ComplexDeinterleavingOperation::ReductionOperation ||
1694 RootNode->Operation ==
1695 ComplexDeinterleavingOperation::ReductionSingle);
1696 assert(RootNode->Vals.size() == 1 &&
1697 "Cannot handle reductions involving multiple complex values");
1706 ReplacementAnchor =
R->comesBefore(
I) ?
I :
R;
1708 ReplacementAnchor =
R;
1710 if (ReplacementAnchor != RootI)
1716 auto RootNode = identifyRoot(RootI);
1723 dbgs() <<
"Complex deinterleaving graph for " <<
F->getName()
1724 <<
"::" <<
B->getName() <<
".\n";
1728 RootToNode[RootI] = RootNode;
1733bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *
B) {
1734 bool FoundPotentialReduction =
false;
1743 if (Br->getSuccessor(0) !=
B && Br->getSuccessor(1) !=
B)
1746 for (
auto &
PHI :
B->phis()) {
1747 if (
PHI.getNumIncomingValues() != 2)
1750 if (!
PHI.getType()->isVectorTy())
1760 for (
auto *U : ReductionOp->users()) {
1767 if (NumUsers != 2 || !FinalReduction || FinalReduction->
getParent() ==
B ||
1771 ReductionInfo[ReductionOp] = {&
PHI, FinalReduction};
1773 auto BackEdgeIdx =
PHI.getBasicBlockIndex(
B);
1774 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1775 Incoming =
PHI.getIncomingBlock(IncomingIdx);
1776 FoundPotentialReduction =
true;
1782 FinalInstructions.
insert(InitPHI);
1784 return FoundPotentialReduction;
1787void ComplexDeinterleavingGraph::identifyReductionNodes() {
1788 assert(Factor == 2 &&
"Cannot handle multiple complex values");
1790 SmallVector<bool> Processed(ReductionInfo.
size(),
false);
1792 for (
auto &
P : ReductionInfo)
1797 for (
size_t i = 0; i < OperationInstruction.
size(); ++i) {
1800 for (
size_t j = i + 1;
j < OperationInstruction.
size(); ++
j) {
1803 auto *Real = OperationInstruction[i];
1804 auto *Imag = OperationInstruction[
j];
1805 if (Real->getType() != Imag->
getType())
1808 RealPHI = ReductionInfo[Real].first;
1809 ImagPHI = ReductionInfo[Imag].first;
1811 auto Node = identifyNode(Real, Imag);
1815 Node = identifyNode(Real, Imag);
1821 if (Node && PHIsFound) {
1822 LLVM_DEBUG(
dbgs() <<
"Identified reduction starting from instructions: "
1823 << *Real <<
" / " << *Imag <<
"\n");
1824 Processed[i] =
true;
1825 Processed[
j] =
true;
1826 auto RootNode = prepareCompositeNode(
1827 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1828 RootNode->addOperand(Node);
1829 RootToNode[Real] = RootNode;
1830 RootToNode[Imag] = RootNode;
1831 submitCompositeNode(RootNode);
1836 auto *Real = OperationInstruction[i];
1839 if (Processed[i] || Real->getNumOperands() < 2)
1843 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1846 RealPHI = ReductionInfo[Real].first;
1849 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1850 if (Node && PHIsFound) {
1852 dbgs() <<
"Identified single reduction starting from instruction: "
1853 << *Real <<
"/" << *ReductionInfo[Real].second <<
"\n");
1862 if (ReductionInfo[Real].second->getType()->isVectorTy())
1865 Processed[i] =
true;
1866 auto RootNode = prepareCompositeNode(
1867 ComplexDeinterleavingOperation::ReductionSingle, Real,
nullptr);
1868 RootNode->addOperand(Node);
1869 RootToNode[Real] = RootNode;
1870 submitCompositeNode(RootNode);
1878bool ComplexDeinterleavingGraph::checkNodes() {
1879 bool FoundDeinterleaveNode =
false;
1880 for (CompositeNode *
N : CompositeNodes) {
1881 if (!
N->areOperandsValid())
1884 if (
N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1885 FoundDeinterleaveNode =
true;
1890 if (!FoundDeinterleaveNode) {
1892 dbgs() <<
"Couldn't find a deinterleave node within the graph, cannot "
1893 "guarantee safety during graph transformation.\n");
1898 SmallPtrSet<Instruction *, 16> AllInstructions;
1899 SmallVector<Instruction *, 8> Worklist;
1900 for (
auto &Pair : RootToNode)
1905 while (!Worklist.
empty()) {
1908 if (!AllInstructions.
insert(
I).second)
1913 if (!FinalInstructions.
count(
I))
1920 for (
auto *
I : AllInstructions) {
1922 if (RootToNode.count(
I))
1925 for (User *U :
I->users()) {
1937 SmallPtrSet<Instruction *, 16> Visited;
1938 while (!Worklist.
empty()) {
1940 if (!Visited.
insert(
I).second)
1945 if (RootToNode.count(
I)) {
1947 <<
" could be deinterleaved but its chain of complex "
1948 "operations have an outside user\n");
1949 RootToNode.erase(
I);
1952 if (!AllInstructions.count(
I) || FinalInstructions.
count(
I))
1955 for (User *U :
I->users())
1963 return !RootToNode.
empty();
1966ComplexDeinterleavingGraph::CompositeNode *
1967ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1974 for (
unsigned I = 0;
I < Factor;
I += 2) {
1982 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
2010 return identifyNode(Real, Imag);
2013ComplexDeinterleavingGraph::CompositeNode *
2014ComplexDeinterleavingGraph::identifyDeinterleave(
ComplexValues &Vals) {
2018 auto CheckExtract = [&](
Value *
V,
unsigned ExpectedIdx,
2019 Instruction *ExpectedInsn) -> ExtractValueInst * {
2021 if (!EVI || EVI->getNumIndices() != 1 ||
2022 EVI->getIndices()[0] != ExpectedIdx ||
2024 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2029 for (
unsigned Idx = 0; Idx < Vals.
size(); Idx++) {
2030 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2,
II);
2031 if (RealEVI && Idx == 0)
2033 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1,
II)) {
2040 if (IntrinsicII->getIntrinsicID() !=
2045 CompositeNode *PlaceholderNode = prepareCompositeNode(
2047 PlaceholderNode->ReplacementNode =
II->getOperand(0);
2048 for (
auto &V : Vals) {
2052 return submitCompositeNode(PlaceholderNode);
2055 if (Vals.size() != 1)
2058 Value *Real = Vals[0].Real;
2059 Value *Imag = Vals[0].Imag;
2062 if (!RealShuffle || !ImagShuffle) {
2063 if (RealShuffle || ImagShuffle)
2064 LLVM_DEBUG(
dbgs() <<
" - There's a shuffle where there shouldn't be.\n");
2068 Value *RealOp1 = RealShuffle->getOperand(1);
2073 Value *ImagOp1 = ImagShuffle->getOperand(1);
2079 Value *RealOp0 = RealShuffle->getOperand(0);
2080 Value *ImagOp0 = ImagShuffle->getOperand(0);
2082 if (RealOp0 != ImagOp0) {
2087 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2088 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2094 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2095 LLVM_DEBUG(
dbgs() <<
" - Masks do not have the correct initial value.\n");
2101 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2102 Value *
Op = Shuffle->getOperand(0);
2106 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2108 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2114 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) ->
bool {
2118 ArrayRef<int>
Mask = Shuffle->getShuffleMask();
2121 Value *
Op = Shuffle->getOperand(0);
2123 int NumElements = OpTy->getNumElements();
2127 return Last < NumElements;
2130 if (RealShuffle->getType() != ImagShuffle->getType()) {
2134 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2138 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2143 CompositeNode *PlaceholderNode =
2145 RealShuffle, ImagShuffle);
2146 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2147 FinalInstructions.
insert(RealShuffle);
2148 FinalInstructions.
insert(ImagShuffle);
2149 return submitCompositeNode(PlaceholderNode);
2152ComplexDeinterleavingGraph::CompositeNode *
2153ComplexDeinterleavingGraph::identifySplat(
ComplexValues &Vals) {
2154 auto IsSplat = [](
Value *
V) ->
bool {
2167 if (
Const->getOpcode() != Instruction::ShuffleVector)
2172 VTy = Shuf->getType();
2173 Mask = Shuf->getShuffleMask();
2181 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2191 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2192 for (
auto &V : Vals) {
2193 if (!IsSplat(
V.Real) || !IsSplat(
V.Imag))
2198 if (!Real || !Imag || Real->getParent() != FirstBB ||
2199 Imag->getParent() != FirstBB)
2203 for (
auto &V : Vals) {
2210 for (
auto &V : Vals) {
2214 FinalInstructions.
insert(Real);
2215 FinalInstructions.
insert(Imag);
2218 CompositeNode *PlaceholderNode =
2219 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2220 return submitCompositeNode(PlaceholderNode);
2223ComplexDeinterleavingGraph::CompositeNode *
2224ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2225 Instruction *Imag) {
2226 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2230 CompositeNode *PlaceholderNode = prepareCompositeNode(
2231 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2232 return submitCompositeNode(PlaceholderNode);
2235ComplexDeinterleavingGraph::CompositeNode *
2236ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2237 Instruction *Imag) {
2240 if (!SelectReal || !SelectImag)
2257 auto NodeA = identifyNode(AR, AI);
2261 auto NodeB = identifyNode(
RA, BI);
2265 CompositeNode *PlaceholderNode = prepareCompositeNode(
2266 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2267 PlaceholderNode->addOperand(NodeA);
2268 PlaceholderNode->addOperand(NodeB);
2269 FinalInstructions.
insert(MaskA);
2270 FinalInstructions.
insert(MaskB);
2271 return submitCompositeNode(PlaceholderNode);
2275 std::optional<FastMathFlags> Flags,
2279 case Instruction::FNeg:
2280 I =
B.CreateFNeg(InputA);
2282 case Instruction::FAdd:
2283 I =
B.CreateFAdd(InputA, InputB);
2285 case Instruction::Add:
2286 I =
B.CreateAdd(InputA, InputB);
2288 case Instruction::FSub:
2289 I =
B.CreateFSub(InputA, InputB);
2291 case Instruction::Sub:
2292 I =
B.CreateSub(InputA, InputB);
2294 case Instruction::FMul:
2295 I =
B.CreateFMul(InputA, InputB);
2297 case Instruction::Mul:
2298 I =
B.CreateMul(InputA, InputB);
2308Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2309 CompositeNode *Node) {
2310 if (
Node->ReplacementNode)
2311 return Node->ReplacementNode;
2313 auto ReplaceOperandIfExist = [&](CompositeNode *
Node,
2314 unsigned Idx) ->
Value * {
2315 return Node->Operands.size() > Idx
2316 ? replaceNode(Builder,
Node->Operands[Idx])
2320 Value *ReplacementNode =
nullptr;
2321 switch (
Node->Operation) {
2322 case ComplexDeinterleavingOperation::CDot: {
2323 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2324 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2327 "Node inputs need to be of the same type"));
2332 case ComplexDeinterleavingOperation::CAdd:
2333 case ComplexDeinterleavingOperation::CMulPartial:
2334 case ComplexDeinterleavingOperation::Symmetric: {
2335 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2336 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2339 "Node inputs need to be of the same type"));
2342 "Accumulator and input need to be of the same type"));
2343 if (
Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2348 Builder,
Node->Operation,
Node->Rotation, Input0, Input1,
2352 case ComplexDeinterleavingOperation::Deinterleave:
2355 case ComplexDeinterleavingOperation::Splat: {
2357 for (
auto &V :
Node->Vals) {
2358 Ops.push_back(
V.Real);
2359 Ops.push_back(
V.Imag);
2366 for (
auto V :
Node->Vals) {
2374 ReplacementNode = IRB.CreateVectorInterleave(
Ops);
2380 case ComplexDeinterleavingOperation::ReductionPHI: {
2385 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2387 OldToNewPHI[OldPHI] = NewPHI;
2388 ReplacementNode = NewPHI;
2391 case ComplexDeinterleavingOperation::ReductionSingle:
2392 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2393 processReductionSingle(ReplacementNode, Node);
2395 case ComplexDeinterleavingOperation::ReductionOperation:
2396 ReplacementNode = replaceNode(Builder,
Node->Operands[0]);
2397 processReductionOperation(ReplacementNode, Node);
2399 case ComplexDeinterleavingOperation::ReductionSelect: {
2402 auto *
A = replaceNode(Builder,
Node->Operands[0]);
2403 auto *
B = replaceNode(Builder,
Node->Operands[1]);
2410 assert(ReplacementNode &&
"Target failed to create Intrinsic call.");
2411 NumComplexTransformations += 1;
2412 Node->ReplacementNode = ReplacementNode;
2413 return ReplacementNode;
2416void ComplexDeinterleavingGraph::processReductionSingle(
2417 Value *OperationReplacement, CompositeNode *Node) {
2419 auto *OldPHI = ReductionInfo[Real].first;
2420 auto *NewPHI = OldToNewPHI[OldPHI];
2422 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2424 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2428 Value *NewInit =
nullptr;
2430 if (
C->isNullValue())
2438 NewPHI->addIncoming(NewInit, Incoming);
2439 NewPHI->addIncoming(OperationReplacement, BackEdge);
2441 auto *FinalReduction = ReductionInfo[Real].second;
2448void ComplexDeinterleavingGraph::processReductionOperation(
2449 Value *OperationReplacement, CompositeNode *Node) {
2452 auto *OldPHIReal = ReductionInfo[Real].first;
2453 auto *OldPHIImag = ReductionInfo[Imag].first;
2454 auto *NewPHI = OldToNewPHI[OldPHIReal];
2457 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2458 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2463 NewPHI->addIncoming(NewInit, Incoming);
2464 NewPHI->addIncoming(OperationReplacement, BackEdge);
2468 auto *FinalReductionReal = ReductionInfo[Real].second;
2469 auto *FinalReductionImag = ReductionInfo[Imag].second;
2472 BasicBlock *ExitBB = Br->getSuccessor(Br->getSuccessor(0) == BackEdge);
2476 OperationReplacement->
getType(),
2477 OperationReplacement);
2480 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2484 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2487void ComplexDeinterleavingGraph::replaceNodes() {
2488 SmallVector<Instruction *, 16> DeadInstrRoots;
2489 for (
auto *RootInstruction : OrderedRoots) {
2492 if (!RootToNode.count(RootInstruction))
2496 auto RootNode = RootToNode[RootInstruction];
2497 Value *
R = replaceNode(Builder, RootNode);
2499 if (RootNode->Operation ==
2500 ComplexDeinterleavingOperation::ReductionOperation) {
2503 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2504 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2507 }
else if (RootNode->Operation ==
2508 ComplexDeinterleavingOperation::ReductionSingle) {
2510 auto &
Info = ReductionInfo[RootInst];
2511 Info.first->removeIncomingValue(BackEdge);
2514 assert(R &&
"Unable to find replacement for RootInstruction");
2515 DeadInstrRoots.
push_back(RootInstruction);
2516 RootInstruction->replaceAllUsesWith(R);
2520 for (
auto *
I : DeadInstrRoots)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
This file implements a map that provides insertion order iteration.
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This file describes how to lower LLVM code to machine code.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
Get the array size.
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
bool allowContract() const
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
An opaque object representing a hash code.
const ParentTy * getParent() const
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
ComplexDeinterleavingOperation
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number a...
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
ComplexDeinterleavingRotation
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list is empty.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
ComplexDeinterleavingPass(const TargetMachine &TM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static ComplexValue getEmptyKey()
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...