LLVM 23.0.0git
ComplexDeinterleavingPass.cpp
Go to the documentation of this file.
1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementNode is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementNode fields of that Node are relevant, where the ReplacementNode
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
// Arguments of the "enable-complex-deinterleaving" command-line option, which
// defaults to true.
// NOTE(review): the line that opens this declaration and the line that closes
// it are absent from this listing (embedded numbering skips 86 and 89);
// recover them from the original file.
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(true),
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if the operation is a negation of V, and it works for both
107/// integers and floats.
108static bool isNeg(Value *V);
109
110/// Returns the operand for negation operation.
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
// Hash function overload for ComplexValue.
// NOTE(review): the function body is absent from this listing (embedded
// numbering skips 123-124); recover it from the original file.
122hash_code hash_value(const ComplexValue &Arg) {
125}
126} // end namespace
128
// DenseMapInfo specialisation that lets ComplexValue be used as a DenseMap
// key.
// NOTE(review): the bodies of getEmptyKey and getHashValue are absent from
// this listing (embedded numbering skips 131-132 and 135-136); recover them
// from the original file.
129template <> struct llvm::DenseMapInfo<ComplexValue> {
130 static inline ComplexValue getEmptyKey() {
133 }
134 static unsigned getHashValue(const ComplexValue &Val) {
137 }
// Keys are equal when both the real and the imaginary pointers match.
138 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
139 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
140 }
141};
142
143namespace {
144template <typename T, typename IterT>
145std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
146 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
147 if (Common != A.end())
148 return std::make_optional(*Common);
149 return std::nullopt;
150}
151
152class ComplexDeinterleavingLegacyPass : public FunctionPass {
153public:
154 static char ID;
155
156 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
157 : FunctionPass(ID), TM(TM) {}
158
159 StringRef getPassName() const override {
160 return "Complex Deinterleaving Pass";
161 }
162
163 bool runOnFunction(Function &F) override;
164 void getAnalysisUsage(AnalysisUsage &AU) const override {
165 AU.addRequired<TargetLibraryInfoWrapperPass>();
166 AU.setPreservesCFG();
167 }
168
169private:
170 const TargetMachine *TM;
171};
172
173class ComplexDeinterleavingGraph;
174struct ComplexDeinterleavingCompositeNode {
175
176 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
177 Value *R, Value *I)
178 : Operation(Op) {
179 Vals.push_back({R, I});
180 }
181
182 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
184 : Operation(Op), Vals(Other) {}
185
186private:
187 friend class ComplexDeinterleavingGraph;
188 using CompositeNode = ComplexDeinterleavingCompositeNode;
189 bool OperandsValid = true;
190
191public:
193 ComplexValues Vals;
194
195 // This two members are required exclusively for generating
196 // ComplexDeinterleavingOperation::Symmetric operations.
197 unsigned Opcode;
198 std::optional<FastMathFlags> Flags;
199
201 ComplexDeinterleavingRotation::Rotation_0;
203 Value *ReplacementNode = nullptr;
204
205 void addOperand(CompositeNode *Node) {
206 if (!Node)
207 OperandsValid = false;
208 Operands.push_back(Node);
209 }
210
211 void dump() { dump(dbgs()); }
212 void dump(raw_ostream &OS) {
213 auto PrintValue = [&](Value *V) {
214 if (V) {
215 OS << "\"";
216 V->print(OS, true);
217 OS << "\"\n";
218 } else
219 OS << "nullptr\n";
220 };
221 auto PrintNodeRef = [&](CompositeNode *Ptr) {
222 if (Ptr)
223 OS << Ptr << "\n";
224 else
225 OS << "nullptr\n";
226 };
227
228 OS << "- CompositeNode: " << this << "\n";
229 for (unsigned I = 0; I < Vals.size(); I++) {
230 OS << " Real(" << I << ") : ";
231 PrintValue(Vals[I].Real);
232 OS << " Imag(" << I << ") : ";
233 PrintValue(Vals[I].Imag);
234 }
235 OS << " ReplacementNode: ";
236 PrintValue(ReplacementNode);
237 OS << " Operation: " << (int)Operation << "\n";
238 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
239 OS << " Operands: \n";
240 for (const auto &Op : Operands) {
241 OS << " - ";
242 PrintNodeRef(Op);
243 }
244 }
245
246 bool areOperandsValid() { return OperandsValid; }
247};
248
// Builds, validates and lowers a graph of CompositeNodes for one basic block.
// Nodes are allocated from an internal bump allocator and recorded in
// CompositeNodes; nodes whose first real value is non-null are also cached by
// their value pairs.
249class ComplexDeinterleavingGraph {
250public:
// One multiplication term together with its sign.
251 struct Product {
252 Value *Multiplier;
253 Value *Multiplicand;
254 bool IsPositive;
255 };
256
257 using Addend = std::pair<Value *, bool>;
258 using AddendList = BumpPtrList<Addend>;
259 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
260
261 // Helper struct for holding info about potential partial multiplication
262 // candidates
263 struct PartialMulCandidate {
264 Value *Common;
265 CompositeNode *Node;
266 unsigned RealIdx;
267 unsigned ImagIdx;
268 bool IsNodeInverted;
269 };
270
271 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
272 const TargetLibraryInfo *TLI,
273 unsigned Factor)
274 : TL(TL), TLI(TLI), Factor(Factor) {}
275
276private:
277 const TargetLowering *TL = nullptr;
278 const TargetLibraryInfo *TLI = nullptr;
279 unsigned Factor;
280 SmallVector<CompositeNode *> CompositeNodes;
281 DenseMap<ComplexValues, CompositeNode *> CachedResult;
282 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
283
284 SmallPtrSet<Instruction *, 16> FinalInstructions;
285
286 /// Root instructions are instructions from which complex computation starts
287 DenseMap<Instruction *, CompositeNode *> RootToNode;
288
289 /// Topologically sorted root instructions
// NOTE(review): the member declaration this comment documents is absent from
// this listing (embedded numbering skips 290); recover it from the original
// file.
291
292 /// When examining a basic block for complex deinterleaving, if it is a simple
293 /// one-block loop, then the only incoming block is 'Incoming' and the
294 /// 'BackEdge' block is the block itself.
295 BasicBlock *BackEdge = nullptr;
296 BasicBlock *Incoming = nullptr;
297
298 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
299 /// %OutsideUser as it is shown in the IR:
300 ///
301 /// vector.body:
302 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
303 /// [ %ReductionOp, %vector.body ]
304 /// ...
305 /// %ReductionOp = fadd i64 ...
306 /// ...
307 /// br i1 %condition, label %vector.body, %middle.block
308 ///
309 /// middle.block:
310 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
311 ///
312 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
313 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
314 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
315
316 /// In the process of detecting a reduction, we consider a pair of
317 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
318 /// traverse the use-tree to detect complex operations. As this is a reduction
319 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
320 /// to the %ReductionOPs that we suspect to be complex.
321 /// RealPHI and ImagPHI are used by the identifyPHINode method.
322 PHINode *RealPHI = nullptr;
323 PHINode *ImagPHI = nullptr;
324
325 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
326 /// detection.
327 bool PHIsFound = false;
328
329 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
330 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
331 /// This mapping is populated during
332 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
333 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
334 /// replacement process.
335 DenseMap<PHINode *, PHINode *> OldToNewPHI;
336
// Allocates a node for the single pair (R, I) without registering it in the
// graph. Reduction-related operations must supply both R and I (asserted).
337 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
338 Value *R, Value *I) {
339 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
340 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
341 (R && I)) &&
342 "Reduction related nodes must have Real and Imaginary parts");
343 return new (Allocator.Allocate())
344 ComplexDeinterleavingCompositeNode(Operation, R, I);
345 }
346
// Same as above for a list of pairs; the per-pair check is compiled only when
// NDEBUG is not defined.
347 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
348 ComplexValues &Vals) {
349#ifndef NDEBUG
350 for (auto &V : Vals) {
351 assert(
352 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
353 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
354 (V.Real && V.Imag)) &&
355 "Reduction related nodes must have Real and Imaginary parts");
356 }
357#endif
358 return new (Allocator.Allocate())
359 ComplexDeinterleavingCompositeNode(Operation, Vals);
360 }
361
// Records Node in the graph and, when its first real value is non-null, caches
// it under its value pairs. Returns Node.
362 CompositeNode *submitCompositeNode(CompositeNode *Node) {
363 CompositeNodes.push_back(Node);
364 if (Node->Vals[0].Real)
365 CachedResult[Node->Vals] = Node;
366 return Node;
367 }
368
369 /// Identifies a complex partial multiply pattern and its rotation, based on
370 /// the following patterns
371 ///
372 /// 0: r: cr + ar * br
373 /// i: ci + ar * bi
374 /// 90: r: cr - ai * bi
375 /// i: ci + ai * br
376 /// 180: r: cr - ar * br
377 /// i: ci - ar * bi
378 /// 270: r: cr + ai * bi
379 /// i: ci - ai * br
380 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
381
382 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
383 /// is partially known from identifyPartialMul, filling in the other half of
384 /// the complex pair.
385 CompositeNode *
386 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
387 std::pair<Value *, Value *> &CommonOperandI);
388
389 /// Identifies a complex add pattern and its rotation, based on the following
390 /// patterns.
391 ///
392 /// 90: r: ar - bi
393 /// i: ai + br
394 /// 270: r: ar + bi
395 /// i: ai - br
396 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
397 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
398 CompositeNode *identifyPartialReduction(Value *R, Value *I);
399 CompositeNode *identifyDotProduct(Value *Inst);
400
401 CompositeNode *identifyNode(ComplexValues &Vals);
402
// Convenience overload: wraps the single pair (R, I) in a ComplexValues list
// and forwards to the overload above.
403 CompositeNode *identifyNode(Value *R, Value *I) {
404 ComplexValues Vals;
405 Vals.push_back({R, I});
406 return identifyNode(Vals);
407 }
408
409 /// Determine if a sum of complex numbers can be formed from \p RealAddends
410 /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
411 /// Return nullptr if it is not possible to construct a complex number.
412 /// \p Flags are needed to generate symmetric Add and Sub operations.
413 CompositeNode *identifyAdditions(AddendList &RealAddends,
414 AddendList &ImagAddends,
415 std::optional<FastMathFlags> Flags,
416 CompositeNode *Accumulator);
417
418 /// Extract one addend that has both real and imaginary parts positive.
419 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
420 AddendList &ImagAddends);
421
422 /// Determine if sum of multiplications of complex numbers can be formed from
423 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
424 /// to it. Return nullptr if it is not possible to construct a complex number.
425 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
426 SmallVectorImpl<Product> &ImagMuls,
427 CompositeNode *Accumulator);
428
429 /// Go through pairs of multiplication (one Real and one Imag) and find all
430 /// possible candidates for partial multiplication and put them into \p
431 /// Candidates. Returns true if every Product has a pair with a common operand.
432 bool collectPartialMuls(ArrayRef<Product> RealMuls,
433 ArrayRef<Product> ImagMuls,
434 SmallVectorImpl<PartialMulCandidate> &Candidates);
435
436 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
437 /// the order of complex computation operations may be significantly altered,
438 /// and the real and imaginary parts may not be executed in parallel. This
439 /// function takes this into consideration and employs a more general approach
440 /// to identify complex computations. Initially, it gathers all the addends
441 /// and multiplicands and then constructs a complex expression from them.
442 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
443
444 CompositeNode *identifyRoot(Instruction *I);
445
446 /// Identifies the Deinterleave operation applied to a vector containing
447 /// complex numbers. There are two ways to represent the Deinterleave
448 /// operation:
449 /// * Using two shufflevectors with even indices for \p Real instruction and
450 /// odd indices for \p Imag instructions (only for fixed-width vectors)
451 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
452 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
453 /// 2.
454 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
455
456 /// Identifies the operation that represents a complex number repeated in a
457 /// Splat vector. There are two possible types of splats: ConstantExpr with
458 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
459 /// initialization mask with all values set to zero.
460 CompositeNode *identifySplat(ComplexValues &Vals);
461
462 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
463
464 /// Identifies SelectInsts in a loop that has reduction with predication masks
465 /// and/or predicated tail folding
466 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
467
468 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
469
470 /// Complete IR modifications after producing new reduction operation:
471 /// * Populate the PHINode generated for
472 /// ComplexDeinterleavingOperation::ReductionPHI
473 /// * Deinterleave the final value outside of the loop and repurpose original
474 /// reduction users
475 void processReductionOperation(Value *OperationReplacement,
476 CompositeNode *Node);
477 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
478
479public:
// Dumps every recorded node, to the debug stream or to the given stream.
480 void dump() { dump(dbgs()); }
481 void dump(raw_ostream &OS) {
482 for (const auto &Node : CompositeNodes)
483 Node->dump(OS);
484 }
485
486 /// Returns false if the deinterleaving operation should be cancelled for the
487 /// current graph.
488 bool identifyNodes(Instruction *RootI);
489
490 /// In case \p B is a one-block loop, this function seeks potential reductions
491 /// and populates ReductionInfo. Returns true if any reductions were
492 /// identified.
493 bool collectPotentialReductions(BasicBlock *B);
494
495 void identifyReductionNodes();
496
497 /// Check that every instruction, from the roots to the leaves, has internal
498 /// uses.
499 bool checkNodes();
500
501 /// Perform the actual replacement of the underlying instruction graph.
502 void replaceNodes();
503};
504
505class ComplexDeinterleaving {
506public:
507 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
508 : TL(tl), TLI(tli) {}
509 bool runOnFunction(Function &F);
510
511private:
512 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
513
514 const TargetLowering *TL = nullptr;
515 const TargetLibraryInfo *TLI = nullptr;
516};
517
518} // namespace
519
// Definition of the legacy pass's static ID member.
520char ComplexDeinterleavingLegacyPass::ID = 0;
521
// Legacy pass registration under DEBUG_TYPE with the display name
// "Complex Deinterleaving".
522INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
523 "Complex Deinterleaving", false, false)
524INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
525 "Complex Deinterleaving", false, false)
526
// NOTE(review): the signature of this function is absent from this listing
// (embedded numbering skips 527-528), as are the statements that build `PA`
// (534-535); recover them from the original file.
// What remains: fetch the target lowering for F's subtarget and the
// TargetLibraryInfo analysis result, run the transformation, and report all
// analyses preserved when nothing changed.
529 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
530 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
531 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
532 return PreservedAnalyses::all();
533
536 return PA;
537}
538
// NOTE(review): the signature of this factory is absent from this listing
// (embedded numbering skips 539); recover it from the original file.
// Returns a newly allocated legacy pass constructed with TM.
540 return new ComplexDeinterleavingLegacyPass(TM);
541}
542
543bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
544 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
545 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
546 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
547}
548
// Runs the transformation over every basic block of F, first with an
// interleave factor of 2 and, only if that changed nothing, with a factor of
// 4. Returns true when any block was changed.
549bool ComplexDeinterleaving::runOnFunction(Function &F) {
// NOTE(review): the condition guarding this early exit and the opener of its
// debug statement are absent from this listing (embedded numbering skips
// 550-551); recover them from the original file.
552 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
553 return false;
554 }
555
// NOTE(review): likewise, the condition and debug-statement opener of this
// second early exit are absent (embedded numbering skips 556-557).
558 dbgs() << "Complex deinterleaving has been disabled, target does "
559 "not support lowering of complex number operations.\n");
560 return false;
561 }
562
563 bool Changed = false;
564 for (auto &B : F)
565 Changed |= evaluateBasicBlock(&B, 2);
566
567 // TODO: Permit changes for both interleave factors in the same function.
568 if (!Changed) {
569 for (auto &B : F)
570 Changed |= evaluateBasicBlock(&B, 4);
571 }
572
573 // TODO: We can also support interleave factors of 6 and 8 if needed.
574
575 return Changed;
576}
577
579 // If the size is not even, it's not an interleaving mask
580 if ((Mask.size() & 1))
581 return false;
582
583 int HalfNumElements = Mask.size() / 2;
584 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
585 int MaskIdx = Idx * 2;
586 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
587 return false;
588 }
589
590 return true;
591}
592
594 int Offset = Mask[0];
595 int HalfNumElements = Mask.size() / 2;
596
597 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
598 if (Mask[Idx] != (Idx * 2) + Offset)
599 return false;
600 }
601
602 return true;
603}
604
605bool isNeg(Value *V) {
606 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
607}
608
610 assert(isNeg(V));
611 auto *I = cast<Instruction>(V);
612 if (I->getOpcode() == Instruction::FNeg)
613 return I->getOperand(0);
614
615 return I->getOperand(1);
616}
617
618bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
619 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
620 if (Graph.collectPotentialReductions(B))
621 Graph.identifyReductionNodes();
622
623 for (auto &I : *B)
624 Graph.identifyNodes(&I);
625
626 if (Graph.checkNodes()) {
627 Graph.replaceNodes();
628 return true;
629 }
630
631 return false;
632}
633
634ComplexDeinterleavingGraph::CompositeNode *
635ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
636 Instruction *Real, Instruction *Imag,
637 std::pair<Value *, Value *> &PartialMatch) {
638 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
639 << "\n");
640
641 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
642 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
643 return nullptr;
644 }
645
646 if ((Real->getOpcode() != Instruction::FMul &&
647 Real->getOpcode() != Instruction::Mul) ||
648 (Imag->getOpcode() != Instruction::FMul &&
649 Imag->getOpcode() != Instruction::Mul)) {
651 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
652 return nullptr;
653 }
654
655 Value *R0 = Real->getOperand(0);
656 Value *R1 = Real->getOperand(1);
657 Value *I0 = Imag->getOperand(0);
658 Value *I1 = Imag->getOperand(1);
659
660 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
661 // rotations and use the operand.
662 unsigned Negs = 0;
663 Value *Op;
664 if (match(R0, m_Neg(m_Value(Op)))) {
665 Negs |= 1;
666 R0 = Op;
667 } else if (match(R1, m_Neg(m_Value(Op)))) {
668 Negs |= 1;
669 R1 = Op;
670 }
671
672 if (isNeg(I0)) {
673 Negs |= 2;
674 Negs ^= 1;
675 I0 = Op;
676 } else if (match(I1, m_Neg(m_Value(Op)))) {
677 Negs |= 2;
678 Negs ^= 1;
679 I1 = Op;
680 }
681
683
684 Value *CommonOperand;
685 Value *UncommonRealOp;
686 Value *UncommonImagOp;
687
688 if (R0 == I0 || R0 == I1) {
689 CommonOperand = R0;
690 UncommonRealOp = R1;
691 } else if (R1 == I0 || R1 == I1) {
692 CommonOperand = R1;
693 UncommonRealOp = R0;
694 } else {
695 LLVM_DEBUG(dbgs() << " - No equal operand\n");
696 return nullptr;
697 }
698
699 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
700 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
701 Rotation == ComplexDeinterleavingRotation::Rotation_270)
702 std::swap(UncommonRealOp, UncommonImagOp);
703
704 // Between identifyPartialMul and here we need to have found a complete valid
705 // pair from the CommonOperand of each part.
706 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
707 Rotation == ComplexDeinterleavingRotation::Rotation_180)
708 PartialMatch.first = CommonOperand;
709 else
710 PartialMatch.second = CommonOperand;
711
712 if (!PartialMatch.first || !PartialMatch.second) {
713 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
714 return nullptr;
715 }
716
717 CompositeNode *CommonNode =
718 identifyNode(PartialMatch.first, PartialMatch.second);
719 if (!CommonNode) {
720 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
721 return nullptr;
722 }
723
724 CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
725 if (!UncommonNode) {
726 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
727 return nullptr;
728 }
729
730 CompositeNode *Node = prepareCompositeNode(
731 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
732 Node->Rotation = Rotation;
733 Node->addOperand(CommonNode);
734 Node->addOperand(UncommonNode);
735 return submitCompositeNode(Node);
736}
737
738ComplexDeinterleavingGraph::CompositeNode *
739ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
740 Instruction *Imag) {
741 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
742 << "\n");
743
744 // Determine rotation
745 auto IsAdd = [](unsigned Op) {
746 return Op == Instruction::FAdd || Op == Instruction::Add;
747 };
748 auto IsSub = [](unsigned Op) {
749 return Op == Instruction::FSub || Op == Instruction::Sub;
750 };
752 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
753 Rotation = ComplexDeinterleavingRotation::Rotation_0;
754 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
755 Rotation = ComplexDeinterleavingRotation::Rotation_90;
756 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
757 Rotation = ComplexDeinterleavingRotation::Rotation_180;
758 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
759 Rotation = ComplexDeinterleavingRotation::Rotation_270;
760 else {
761 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
762 return nullptr;
763 }
764
765 if (isa<FPMathOperator>(Real) &&
766 (!Real->getFastMathFlags().allowContract() ||
767 !Imag->getFastMathFlags().allowContract())) {
768 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
769 return nullptr;
770 }
771
772 Value *CR = Real->getOperand(0);
773 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
774 if (!RealMulI)
775 return nullptr;
776 Value *CI = Imag->getOperand(0);
777 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
778 if (!ImagMulI)
779 return nullptr;
780
781 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
782 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
783 return nullptr;
784 }
785
786 Value *R0 = RealMulI->getOperand(0);
787 Value *R1 = RealMulI->getOperand(1);
788 Value *I0 = ImagMulI->getOperand(0);
789 Value *I1 = ImagMulI->getOperand(1);
790
791 Value *CommonOperand;
792 Value *UncommonRealOp;
793 Value *UncommonImagOp;
794
795 if (R0 == I0 || R0 == I1) {
796 CommonOperand = R0;
797 UncommonRealOp = R1;
798 } else if (R1 == I0 || R1 == I1) {
799 CommonOperand = R1;
800 UncommonRealOp = R0;
801 } else {
802 LLVM_DEBUG(dbgs() << " - No equal operand\n");
803 return nullptr;
804 }
805
806 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
807 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
808 Rotation == ComplexDeinterleavingRotation::Rotation_270)
809 std::swap(UncommonRealOp, UncommonImagOp);
810
811 std::pair<Value *, Value *> PartialMatch(
812 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
813 Rotation == ComplexDeinterleavingRotation::Rotation_180)
814 ? CommonOperand
815 : nullptr,
816 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_270)
818 ? CommonOperand
819 : nullptr);
820
821 auto *CRInst = dyn_cast<Instruction>(CR);
822 auto *CIInst = dyn_cast<Instruction>(CI);
823
824 if (!CRInst || !CIInst) {
825 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
826 return nullptr;
827 }
828
829 CompositeNode *CNode =
830 identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
831 if (!CNode) {
832 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
833 return nullptr;
834 }
835
836 CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
837 if (!UncommonRes) {
838 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
839 return nullptr;
840 }
841
842 assert(PartialMatch.first && PartialMatch.second);
843 CompositeNode *CommonRes =
844 identifyNode(PartialMatch.first, PartialMatch.second);
845 if (!CommonRes) {
846 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
847 return nullptr;
848 }
849
850 CompositeNode *Node = prepareCompositeNode(
851 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
852 Node->Rotation = Rotation;
853 Node->addOperand(CommonRes);
854 Node->addOperand(UncommonRes);
855 Node->addOperand(CNode);
856 return submitCompositeNode(Node);
857}
858
859ComplexDeinterleavingGraph::CompositeNode *
860ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
861 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
862
863 // Determine rotation
865 if ((Real->getOpcode() == Instruction::FSub &&
866 Imag->getOpcode() == Instruction::FAdd) ||
867 (Real->getOpcode() == Instruction::Sub &&
868 Imag->getOpcode() == Instruction::Add))
869 Rotation = ComplexDeinterleavingRotation::Rotation_90;
870 else if ((Real->getOpcode() == Instruction::FAdd &&
871 Imag->getOpcode() == Instruction::FSub) ||
872 (Real->getOpcode() == Instruction::Add &&
873 Imag->getOpcode() == Instruction::Sub))
874 Rotation = ComplexDeinterleavingRotation::Rotation_270;
875 else {
876 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
877 return nullptr;
878 }
879
880 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
881 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
882 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
883 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
884
885 if (!AR || !AI || !BR || !BI) {
886 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
887 return nullptr;
888 }
889
890 CompositeNode *ResA = identifyNode(AR, AI);
891 if (!ResA) {
892 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
893 return nullptr;
894 }
895 CompositeNode *ResB = identifyNode(BR, BI);
896 if (!ResB) {
897 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
898 return nullptr;
899 }
900
901 CompositeNode *Node =
902 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
903 Node->Rotation = Rotation;
904 Node->addOperand(ResA);
905 Node->addOperand(ResB);
906 return submitCompositeNode(Node);
907}
908
// NOTE(review): the signature of this helper is absent from this listing
// (embedded numbering skips 909); recover it from the original file.
// Returns true when A and B are one addition and one subtraction of the same
// kind (both floating-point or both integer), in either order.
910 unsigned OpcA = A->getOpcode();
911 unsigned OpcB = B->getOpcode();
912
913 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
914 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
915 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
916 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
917}
918
// NOTE(review): the signature of this helper and the initialiser of `Pattern`
// are absent from this listing (embedded numbering skips 919 and 921); recover
// them from the original file.
// Returns true when both A and B match the same pattern.
920 auto Pattern =
922
923 return match(A, Pattern) && match(B, Pattern);
924}
925
// NOTE(review): the signature of this helper is absent from this listing
// (embedded numbering skips 926); recover it from the original file.
// Returns true when I's opcode is one of FAdd, FSub, FMul, FNeg, Add, Sub or
// Mul, and false for every other opcode.
927 switch (I->getOpcode()) {
928 case Instruction::FAdd:
929 case Instruction::FSub:
930 case Instruction::FMul:
931 case Instruction::FNeg:
932 case Instruction::Add:
933 case Instruction::Sub:
934 case Instruction::Mul:
935 return true;
936 default:
937 return false;
938 }
939}
940
// Identifies an operation applied identically to the real and imaginary
// halves of every pair in Vals. All instructions must share the opcode of the
// first real instruction and, for floating-point operations, its fast-math
// flags. Returns the new Symmetric node, or nullptr when the pairs do not
// qualify or an operand cannot be identified.
941ComplexDeinterleavingGraph::CompositeNode *
942ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
943 auto *FirstReal = cast<Instruction>(Vals[0].Real);
944 unsigned FirstOpc = FirstReal->getOpcode();
945 for (auto &V : Vals) {
946 auto *Real = cast<Instruction>(V.Real);
947 auto *Imag = cast<Instruction>(V.Imag);
948 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
949 return nullptr;
950
// NOTE(review): the condition guarding the next return is absent from this
// listing (embedded numbering skips 951-952); recover it from the original
// file.
953 return nullptr;
954
955 if (isa<FPMathOperator>(FirstReal))
956 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
957 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
958 return nullptr;
959 }
960
// Pair up operand 0 of every real/imaginary instruction and identify it.
961 ComplexValues OpVals;
962 for (auto &V : Vals) {
963 auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
964 auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
965 OpVals.push_back({R0, I0});
966 }
967
968 CompositeNode *Op0 = identifyNode(OpVals);
969 CompositeNode *Op1 = nullptr;
970 if (Op0 == nullptr)
971 return nullptr;
972
// Binary operations also need their second operand identified.
973 if (FirstReal->isBinaryOp()) {
974 OpVals.clear();
975 for (auto &V : Vals) {
976 auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
977 auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
978 OpVals.push_back({R1, I1});
979 }
980 Op1 = identifyNode(OpVals);
981 if (Op1 == nullptr)
982 return nullptr;
983 }
984
// Record the opcode (and the fast-math flags for floating-point operations)
// on the new node.
985 auto Node =
986 prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
987 Node->Opcode = FirstReal->getOpcode();
988 if (isa<FPMathOperator>(FirstReal))
989 Node->Flags = FirstReal->getFastMathFlags();
990
991 Node->addOperand(Op0);
992 if (FirstReal->isBinaryOp())
993 Node->addOperand(Op1);
994
995 return submitCompositeNode(Node);
996}
997
// Tries to recognise V as a complex dot product (CDot) built from
// vector_partial_reduce_add intrinsics. On success a CDot node is submitted
// with three operands: the node for A, the node for B and the node formed from
// (Phi, RealUser). Returns nullptr when the target rejects the type, no
// pattern matches, or an operand type check fails.
998ComplexDeinterleavingGraph::CompositeNode *
999ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
// NOTE(review): the line opening this `if` (source line 1000) is missing from
// the listing; judging by the debug message it is the target-support query for
// CDot on V's type - confirm against the upstream file.
1001 ComplexDeinterleavingOperation::CDot, V->getType())) {
1002 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
1003 "operation CDot with the type "
1004 << *V->getType() << "\n");
1005 return nullptr;
1006 }
1007
// RealUser is whichever user appears first in Inst's user list.
1008 auto *Inst = cast<Instruction>(V);
1009 auto *RealUser = cast<Instruction>(*Inst->user_begin());
1010
1011 CompositeNode *CN =
1012 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
1013
1014 CompositeNode *ANode = nullptr;
1015
1016 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1017
// Values bound by whichever pattern matches below.
1018 Value *AReal = nullptr;
1019 Value *AImag = nullptr;
1020 Value *BReal = nullptr;
1021 Value *BImag = nullptr;
1022 Value *Phi = nullptr;
1023
// Returns the source operand of a cast instruction, or the value unchanged.
1024 auto UnwrapCast = [](Value *V) -> Value * {
1025 if (auto *CI = dyn_cast<CastInst>(V))
1026 return CI->getOperand(0);
1027 return V;
1028 };
1029
// NOTE(review): source lines 1031, 1036 and 1050 (the first argument line of
// each pattern below) are missing from this listing, so the three matcher
// expressions are incomplete as shown. `Phi` is never bound by any visible
// line, yet it is dereferenced further down - presumably the missing lines
// bind it; verify against the upstream file.
1030 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1032 m_Mul(m_Value(BReal), m_Value(AReal))),
1033 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
1034
1035 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1037 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
1038 m_Mul(m_Value(BImag), m_Value(AReal)));
1039
1040 if (match(Inst, PatternRot0)) {
1041 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1042 } else if (match(Inst, PatternRot270)) {
1043 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1044 } else {
1045 Value *A0, *A1;
1046 // The rotations 90 and 180 share the same operation pattern, so inspect the
1047 // order of the operands, identifying where the real and imaginary
1048 // components of A go, to discern between the aforementioned rotations.
1049 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1051 m_Mul(m_Value(BReal), m_Value(A0))),
1052 m_Mul(m_Value(BImag), m_Value(A1)));
1053
1054 if (!match(Inst, PatternRot90Rot180))
1055 return nullptr;
1056
1057 A0 = UnwrapCast(A0);
1058 A1 = UnwrapCast(A1);
1059
1060 // Test if A0 is real/A1 is imag
1061 ANode = identifyNode(A0, A1);
1062 if (!ANode) {
1063 // Test if A0 is imag/A1 is real
1064 ANode = identifyNode(A1, A0);
1065 // Unable to identify operand components, thus unable to identify rotation
1066 if (!ANode)
1067 return nullptr;
1068 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1069 AReal = A1;
1070 AImag = A0;
1071 } else {
1072 AReal = A0;
1073 AImag = A1;
1074 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1075 }
1076 }
1077
// Look through casts on all four multiplication operands before the type
// checks below.
1078 AReal = UnwrapCast(AReal);
1079 AImag = UnwrapCast(AImag);
1080 BReal = UnwrapCast(BReal);
1081 BImag = UnwrapCast(BImag);
1082
// Every unwrapped operand must have the type produced by
// getSubdividedVectorType(VTy, 2).
1083 VectorType *VTy = cast<VectorType>(V->getType());
1084 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
1085 if (AReal->getType() != ExpectedOperandTy)
1086 return nullptr;
1087 if (AImag->getType() != ExpectedOperandTy)
1088 return nullptr;
1089 if (BReal->getType() != ExpectedOperandTy)
1090 return nullptr;
1091 if (BImag->getType() != ExpectedOperandTy)
1092 return nullptr;
1093
// Rejected only when neither Phi nor RealUser has the result type VTy.
1094 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1095 return nullptr;
1096
1097 CompositeNode *Node = identifyNode(AReal, AImag);
1098
1099 // In the case that a node was identified to figure out the rotation, ensure
1100 // that trying to identify a node with AReal and AImag post-unwrap results in
1101 // the same node
1102 if (ANode && Node != ANode) {
1103 LLVM_DEBUG(
1104 dbgs()
1105 << "Identified node is different from previously identified node. "
1106 "Unable to confidently generate a complex operation node\n");
1107 return nullptr;
1108 }
1109
// The results of the identifyNode calls are added as operands without a null
// check at this point.
1110 CN->addOperand(Node);
1111 CN->addOperand(identifyNode(BReal, BImag));
1112 CN->addOperand(identifyNode(Phi, RealUser));
1113
1114 return submitCompositeNode(CN);
1115}
1116
1117ComplexDeinterleavingGraph::CompositeNode *
1118ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1119 // Partial reductions don't support non-vector types, so check these first
1120 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1121 return nullptr;
1122
1123 if (!R->hasUseList() || !I->hasUseList())
1124 return nullptr;
1125
1126 auto CommonUser =
1127 findCommonBetweenCollections<Value *>(R->users(), I->users());
1128 if (!CommonUser)
1129 return nullptr;
1130
1131 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1132 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1133 return nullptr;
1134
1135 if (CompositeNode *CN = identifyDotProduct(IInst))
1136 return CN;
1137
1138 return nullptr;
1139}
1140
1141ComplexDeinterleavingGraph::CompositeNode *
1142ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1143 auto It = CachedResult.find(Vals);
1144 if (It != CachedResult.end()) {
1145 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1146 return It->second;
1147 }
1148
1149 if (Vals.size() == 1) {
1150 assert(Factor == 2 && "Can only handle interleave factors of 2");
1151 Value *R = Vals[0].Real;
1152 Value *I = Vals[0].Imag;
1153 if (CompositeNode *CN = identifyPartialReduction(R, I))
1154 return CN;
1155 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1156 if (!IsReduction && R->getType() != I->getType())
1157 return nullptr;
1158 }
1159
1160 if (CompositeNode *CN = identifySplat(Vals))
1161 return CN;
1162
1163 for (auto &V : Vals) {
1164 auto *Real = dyn_cast<Instruction>(V.Real);
1165 auto *Imag = dyn_cast<Instruction>(V.Imag);
1166 if (!Real || !Imag)
1167 return nullptr;
1168 }
1169
1170 if (CompositeNode *CN = identifyDeinterleave(Vals))
1171 return CN;
1172
1173 if (Vals.size() == 1) {
1174 assert(Factor == 2 && "Can only handle interleave factors of 2");
1175 auto *Real = dyn_cast<Instruction>(Vals[0].Real);
1176 auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
1177 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1178 return CN;
1179
1180 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1181 return CN;
1182
1183 auto *VTy = cast<VectorType>(Real->getType());
1184 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1185
1186 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1187 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1188 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1189 ComplexDeinterleavingOperation::CAdd, NewVTy);
1190
1191 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1192 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1193 return CN;
1194 }
1195
1196 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1197 if (CompositeNode *CN = identifyAdd(Real, Imag))
1198 return CN;
1199 }
1200
1201 if (HasCMulSupport && HasCAddSupport) {
1202 if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
1203 return CN;
1204 }
1205 }
1206 }
1207
1208 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1209 return CN;
1210
1211 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1212 CachedResult[Vals] = nullptr;
1213 return nullptr;
1214}
1215
1216ComplexDeinterleavingGraph::CompositeNode *
1217ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1218 Instruction *Imag) {
1219 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1220 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1221 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1222 Opcode == Instruction::Sub;
1223 };
1224
1225 if (!IsOperationSupported(Real->getOpcode()) ||
1226 !IsOperationSupported(Imag->getOpcode()))
1227 return nullptr;
1228
1229 std::optional<FastMathFlags> Flags;
1230 if (isa<FPMathOperator>(Real)) {
1231 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1232 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1233 "not identical\n");
1234 return nullptr;
1235 }
1236
1237 Flags = Real->getFastMathFlags();
1238 if (!Flags->allowReassoc()) {
1239 LLVM_DEBUG(
1240 dbgs()
1241 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1242 return nullptr;
1243 }
1244 }
1245
1246 // Collect multiplications and addend instructions from the given instruction
1247 // while traversing it operands. Additionally, verify that all instructions
1248 // have the same fast math flags.
1249 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1250 AddendList &Addends) -> bool {
1251 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1252 SmallPtrSet<Value *, 8> Visited;
1253 while (!Worklist.empty()) {
1254 auto [V, IsPositive] = Worklist.pop_back_val();
1255 if (!Visited.insert(V).second)
1256 continue;
1257
1259 if (!I) {
1260 Addends.emplace_back(V, IsPositive);
1261 continue;
1262 }
1263
1264 // If an instruction has more than one user, it indicates that it either
1265 // has an external user, which will be later checked by the checkNodes
1266 // function, or it is a subexpression utilized by multiple expressions. In
1267 // the latter case, we will attempt to separately identify the complex
1268 // operation from here in order to create a shared
1269 // ComplexDeinterleavingCompositeNode.
1270 if (I != Insn && I->hasNUsesOrMore(2)) {
1271 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1272 Addends.emplace_back(I, IsPositive);
1273 continue;
1274 }
1275 switch (I->getOpcode()) {
1276 case Instruction::FAdd:
1277 case Instruction::Add:
1278 Worklist.emplace_back(I->getOperand(1), IsPositive);
1279 Worklist.emplace_back(I->getOperand(0), IsPositive);
1280 break;
1281 case Instruction::FSub:
1282 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1283 Worklist.emplace_back(I->getOperand(0), IsPositive);
1284 break;
1285 case Instruction::Sub:
1286 if (isNeg(I)) {
1287 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1288 } else {
1289 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1290 Worklist.emplace_back(I->getOperand(0), IsPositive);
1291 }
1292 break;
1293 case Instruction::FMul:
1294 case Instruction::Mul: {
1295 Value *A, *B;
1296 if (isNeg(I->getOperand(0))) {
1297 A = getNegOperand(I->getOperand(0));
1298 IsPositive = !IsPositive;
1299 } else {
1300 A = I->getOperand(0);
1301 }
1302
1303 if (isNeg(I->getOperand(1))) {
1304 B = getNegOperand(I->getOperand(1));
1305 IsPositive = !IsPositive;
1306 } else {
1307 B = I->getOperand(1);
1308 }
1309 Muls.push_back(Product{A, B, IsPositive});
1310 break;
1311 }
1312 case Instruction::FNeg:
1313 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1314 break;
1315 default:
1316 Addends.emplace_back(I, IsPositive);
1317 continue;
1318 }
1319
1320 if (Flags && I->getFastMathFlags() != *Flags) {
1321 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1322 "inconsistent with the root instructions' flags: "
1323 << *I << "\n");
1324 return false;
1325 }
1326 }
1327 return true;
1328 };
1329
1330 SmallVector<Product> RealMuls, ImagMuls;
1331 AddendList RealAddends, ImagAddends;
1332 if (!Collect(Real, RealMuls, RealAddends) ||
1333 !Collect(Imag, ImagMuls, ImagAddends))
1334 return nullptr;
1335
1336 if (RealAddends.size() != ImagAddends.size())
1337 return nullptr;
1338
1339 CompositeNode *FinalNode = nullptr;
1340 if (!RealMuls.empty() || !ImagMuls.empty()) {
1341 // If there are multiplicands, extract positive addend and use it as an
1342 // accumulator
1343 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1344 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1345 if (!FinalNode)
1346 return nullptr;
1347 }
1348
1349 // Identify and process remaining additions
1350 if (!RealAddends.empty() || !ImagAddends.empty()) {
1351 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1352 if (!FinalNode)
1353 return nullptr;
1354 }
1355 assert(FinalNode && "FinalNode can not be nullptr here");
1356 assert(FinalNode->Vals.size() == 1);
1357 // Set the Real and Imag fields of the final node and submit it
1358 FinalNode->Vals[0].Real = Real;
1359 FinalNode->Vals[0].Imag = Imag;
1360 submitCompositeNode(FinalNode);
1361 return FinalNode;
1362}
1363
1364bool ComplexDeinterleavingGraph::collectPartialMuls(
1365 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1366 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1367 // Helper function to extract a common operand from two products
1368 auto FindCommonInstruction = [](const Product &Real,
1369 const Product &Imag) -> Value * {
1370 if (Real.Multiplicand == Imag.Multiplicand ||
1371 Real.Multiplicand == Imag.Multiplier)
1372 return Real.Multiplicand;
1373
1374 if (Real.Multiplier == Imag.Multiplicand ||
1375 Real.Multiplier == Imag.Multiplier)
1376 return Real.Multiplier;
1377
1378 return nullptr;
1379 };
1380
1381 // Iterating over real and imaginary multiplications to find common operands
1382 // If a common operand is found, a partial multiplication candidate is created
1383 // and added to the candidates vector The function returns false if no common
1384 // operands are found for any product
1385 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1386 bool FoundCommon = false;
1387 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1388 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1389 if (!Common)
1390 continue;
1391
1392 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1393 : RealMuls[i].Multiplicand;
1394 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1395 : ImagMuls[j].Multiplicand;
1396
1397 auto Node = identifyNode(A, B);
1398 if (Node) {
1399 FoundCommon = true;
1400 PartialMulCandidates.push_back({Common, Node, i, j, false});
1401 }
1402
1403 Node = identifyNode(B, A);
1404 if (Node) {
1405 FoundCommon = true;
1406 PartialMulCandidates.push_back({Common, Node, i, j, true});
1407 }
1408 }
1409 if (!FoundCommon)
1410 return false;
1411 }
1412 return true;
1413}
1414
1415ComplexDeinterleavingGraph::CompositeNode *
1416ComplexDeinterleavingGraph::identifyMultiplications(
1417 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1418 CompositeNode *Accumulator = nullptr) {
1419 if (RealMuls.size() != ImagMuls.size())
1420 return nullptr;
1421
1423 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1424 return nullptr;
1425
1426 // Map to store common instruction to node pointers
1427 DenseMap<Value *, CompositeNode *> CommonToNode;
1428 SmallVector<bool> Processed(Info.size(), false);
1429 for (unsigned I = 0; I < Info.size(); ++I) {
1430 if (Processed[I])
1431 continue;
1432
1433 PartialMulCandidate &InfoA = Info[I];
1434 for (unsigned J = I + 1; J < Info.size(); ++J) {
1435 if (Processed[J])
1436 continue;
1437
1438 PartialMulCandidate &InfoB = Info[J];
1439 auto *InfoReal = &InfoA;
1440 auto *InfoImag = &InfoB;
1441
1442 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1443 if (!NodeFromCommon) {
1444 std::swap(InfoReal, InfoImag);
1445 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1446 }
1447 if (!NodeFromCommon)
1448 continue;
1449
1450 CommonToNode[InfoReal->Common] = NodeFromCommon;
1451 CommonToNode[InfoImag->Common] = NodeFromCommon;
1452 Processed[I] = true;
1453 Processed[J] = true;
1454 }
1455 }
1456
1457 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1458 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1459 CompositeNode *Result = Accumulator;
1460 for (auto &PMI : Info) {
1461 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1462 continue;
1463
1464 auto It = CommonToNode.find(PMI.Common);
1465 // TODO: Process independent complex multiplications. Cases like this:
1466 // A.real() * B where both A and B are complex numbers.
1467 if (It == CommonToNode.end()) {
1468 LLVM_DEBUG({
1469 dbgs() << "Unprocessed independent partial multiplication:\n";
1470 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1471 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1472 << " multiplied by " << *Mul->Multiplicand << "\n";
1473 });
1474 return nullptr;
1475 }
1476
1477 auto &RealMul = RealMuls[PMI.RealIdx];
1478 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1479
1480 auto NodeA = It->second;
1481 auto NodeB = PMI.Node;
1482 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1483 // The following table illustrates the relationship between multiplications
1484 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1485 // can see:
1486 //
1487 // Rotation | Real | Imag |
1488 // ---------+--------+--------+
1489 // 0 | x * u | x * v |
1490 // 90 | -y * v | y * u |
1491 // 180 | -x * u | -x * v |
1492 // 270 | y * v | -y * u |
1493 //
1494 // Check if the candidate can indeed be represented by partial
1495 // multiplication
1496 // TODO: Add support for multiplication by complex one
1497 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1498 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1499 continue;
1500
1501 // Determine the rotation based on the multiplications
1503 if (IsMultiplicandReal) {
1504 // Detect 0 and 180 degrees rotation
1505 if (RealMul.IsPositive && ImagMul.IsPositive)
1507 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1509 else
1510 continue;
1511
1512 } else {
1513 // Detect 90 and 270 degrees rotation
1514 if (!RealMul.IsPositive && ImagMul.IsPositive)
1516 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1518 else
1519 continue;
1520 }
1521
1522 LLVM_DEBUG({
1523 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1524 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1525 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1526 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1527 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1528 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1529 });
1530
1531 CompositeNode *NodeMul = prepareCompositeNode(
1532 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1533 NodeMul->Rotation = Rotation;
1534 NodeMul->addOperand(NodeA);
1535 NodeMul->addOperand(NodeB);
1536 if (Result)
1537 NodeMul->addOperand(Result);
1538 submitCompositeNode(NodeMul);
1539 Result = NodeMul;
1540 ProcessedReal[PMI.RealIdx] = true;
1541 ProcessedImag[PMI.ImagIdx] = true;
1542 }
1543
1544 // Ensure all products have been processed, if not return nullptr.
1545 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1546 !all_of(ProcessedImag, [](bool V) { return V; })) {
1547
1548 // Dump debug information about which partial multiplications are not
1549 // processed.
1550 LLVM_DEBUG({
1551 dbgs() << "Unprocessed products (Real):\n";
1552 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1553 if (!ProcessedReal[i])
1554 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1555 << *RealMuls[i].Multiplier << " multiplied by "
1556 << *RealMuls[i].Multiplicand << "\n";
1557 }
1558 dbgs() << "Unprocessed products (Imag):\n";
1559 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1560 if (!ProcessedImag[i])
1561 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1562 << *ImagMuls[i].Multiplier << " multiplied by "
1563 << *ImagMuls[i].Multiplicand << "\n";
1564 }
1565 });
1566 return nullptr;
1567 }
1568
1569 return Result;
1570}
1571
1572ComplexDeinterleavingGraph::CompositeNode *
1573ComplexDeinterleavingGraph::identifyAdditions(
1574 AddendList &RealAddends, AddendList &ImagAddends,
1575 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1576 if (RealAddends.size() != ImagAddends.size())
1577 return nullptr;
1578
1579 CompositeNode *Result = nullptr;
1580 // If we have accumulator use it as first addend
1581 if (Accumulator)
1583 // Otherwise find an element with both positive real and imaginary parts.
1584 else
1585 Result = extractPositiveAddend(RealAddends, ImagAddends);
1586
1587 if (!Result)
1588 return nullptr;
1589
1590 while (!RealAddends.empty()) {
1591 auto ItR = RealAddends.begin();
1592 auto [R, IsPositiveR] = *ItR;
1593
1594 bool FoundImag = false;
1595 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1596 auto [I, IsPositiveI] = *ItI;
1598 if (IsPositiveR && IsPositiveI)
1599 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1600 else if (!IsPositiveR && IsPositiveI)
1601 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1602 else if (!IsPositiveR && !IsPositiveI)
1603 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1604 else
1605 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1606
1607 CompositeNode *AddNode = nullptr;
1608 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1609 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1610 AddNode = identifyNode(R, I);
1611 } else {
1612 AddNode = identifyNode(I, R);
1613 }
1614 if (AddNode) {
1615 LLVM_DEBUG({
1616 dbgs() << "Identified addition:\n";
1617 dbgs().indent(4) << "X: " << *R << "\n";
1618 dbgs().indent(4) << "Y: " << *I << "\n";
1619 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1620 });
1621
1622 CompositeNode *TmpNode = nullptr;
1624 TmpNode = prepareCompositeNode(
1625 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1626 if (Flags) {
1627 TmpNode->Opcode = Instruction::FAdd;
1628 TmpNode->Flags = *Flags;
1629 } else {
1630 TmpNode->Opcode = Instruction::Add;
1631 }
1632 } else if (Rotation ==
1634 TmpNode = prepareCompositeNode(
1635 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1636 if (Flags) {
1637 TmpNode->Opcode = Instruction::FSub;
1638 TmpNode->Flags = *Flags;
1639 } else {
1640 TmpNode->Opcode = Instruction::Sub;
1641 }
1642 } else {
1643 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1644 nullptr, nullptr);
1645 TmpNode->Rotation = Rotation;
1646 }
1647
1648 TmpNode->addOperand(Result);
1649 TmpNode->addOperand(AddNode);
1650 submitCompositeNode(TmpNode);
1651 Result = TmpNode;
1652 RealAddends.erase(ItR);
1653 ImagAddends.erase(ItI);
1654 FoundImag = true;
1655 break;
1656 }
1657 }
1658 if (!FoundImag)
1659 return nullptr;
1660 }
1661 return Result;
1662}
1663
1664ComplexDeinterleavingGraph::CompositeNode *
1665ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1666 AddendList &ImagAddends) {
1667 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1668 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1669 auto [R, IsPositiveR] = *ItR;
1670 auto [I, IsPositiveI] = *ItI;
1671 if (IsPositiveR && IsPositiveI) {
1672 auto Result = identifyNode(R, I);
1673 if (Result) {
1674 RealAddends.erase(ItR);
1675 ImagAddends.erase(ItI);
1676 return Result;
1677 }
1678 }
1679 }
1680 }
1681 return nullptr;
1682}
1683
1684bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1685 // This potential root instruction might already have been recognized as
1686 // reduction. Because RootToNode maps both Real and Imaginary parts to
1687 // CompositeNode we should choose only one either Real or Imag instruction to
1688 // use as an anchor for generating complex instruction.
1689 auto It = RootToNode.find(RootI);
1690 if (It != RootToNode.end()) {
1691 auto RootNode = It->second;
1692 assert(RootNode->Operation ==
1693 ComplexDeinterleavingOperation::ReductionOperation ||
1694 RootNode->Operation ==
1695 ComplexDeinterleavingOperation::ReductionSingle);
1696 assert(RootNode->Vals.size() == 1 &&
1697 "Cannot handle reductions involving multiple complex values");
1698 // Find out which part, Real or Imag, comes later, and only if we come to
1699 // the latest part, add it to OrderedRoots.
1700 auto *R = cast<Instruction>(RootNode->Vals[0].Real);
1701 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
1702 : nullptr;
1703
1704 Instruction *ReplacementAnchor;
1705 if (I)
1706 ReplacementAnchor = R->comesBefore(I) ? I : R;
1707 else
1708 ReplacementAnchor = R;
1709
1710 if (ReplacementAnchor != RootI)
1711 return false;
1712 OrderedRoots.push_back(RootI);
1713 return true;
1714 }
1715
1716 auto RootNode = identifyRoot(RootI);
1717 if (!RootNode)
1718 return false;
1719
1720 LLVM_DEBUG({
1721 Function *F = RootI->getFunction();
1722 BasicBlock *B = RootI->getParent();
1723 dbgs() << "Complex deinterleaving graph for " << F->getName()
1724 << "::" << B->getName() << ".\n";
1725 dump(dbgs());
1726 dbgs() << "\n";
1727 });
1728 RootToNode[RootI] = RootNode;
1729 OrderedRoots.push_back(RootI);
1730 return true;
1731}
1732
1733bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1734 bool FoundPotentialReduction = false;
1735 if (Factor != 2)
1736 return false;
1737
1738 auto *Br = dyn_cast<CondBrInst>(B->getTerminator());
1739 if (!Br)
1740 return false;
1741
1742 // Identify simple one-block loop
1743 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1744 return false;
1745
1746 for (auto &PHI : B->phis()) {
1747 if (PHI.getNumIncomingValues() != 2)
1748 continue;
1749
1750 if (!PHI.getType()->isVectorTy())
1751 continue;
1752
1753 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1754 if (!ReductionOp)
1755 continue;
1756
1757 // Check if final instruction is reduced outside of current block
1758 Instruction *FinalReduction = nullptr;
1759 auto NumUsers = 0u;
1760 for (auto *U : ReductionOp->users()) {
1761 ++NumUsers;
1762 if (U == &PHI)
1763 continue;
1764 FinalReduction = dyn_cast<Instruction>(U);
1765 }
1766
1767 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1768 isa<PHINode>(FinalReduction))
1769 continue;
1770
1771 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1772 BackEdge = B;
1773 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1774 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1775 Incoming = PHI.getIncomingBlock(IncomingIdx);
1776 FoundPotentialReduction = true;
1777
1778 // If the initial value of PHINode is an Instruction, consider it a leaf
1779 // value of a complex deinterleaving graph.
1780 if (auto *InitPHI =
1781 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1782 FinalInstructions.insert(InitPHI);
1783 }
1784 return FoundPotentialReduction;
1785}
1786
1787void ComplexDeinterleavingGraph::identifyReductionNodes() {
1788 assert(Factor == 2 && "Cannot handle multiple complex values");
1789
1790 SmallVector<bool> Processed(ReductionInfo.size(), false);
1791 SmallVector<Instruction *> OperationInstruction;
1792 for (auto &P : ReductionInfo)
1793 OperationInstruction.push_back(P.first);
1794
1795 // Identify a complex computation by evaluating two reduction operations that
1796 // potentially could be involved
1797 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1798 if (Processed[i])
1799 continue;
1800 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1801 if (Processed[j])
1802 continue;
1803 auto *Real = OperationInstruction[i];
1804 auto *Imag = OperationInstruction[j];
1805 if (Real->getType() != Imag->getType())
1806 continue;
1807
1808 RealPHI = ReductionInfo[Real].first;
1809 ImagPHI = ReductionInfo[Imag].first;
1810 PHIsFound = false;
1811 auto Node = identifyNode(Real, Imag);
1812 if (!Node) {
1813 std::swap(Real, Imag);
1814 std::swap(RealPHI, ImagPHI);
1815 Node = identifyNode(Real, Imag);
1816 }
1817
1818 // If a node is identified and reduction PHINode is used in the chain of
1819 // operations, mark its operation instructions as used to prevent
1820 // re-identification and attach the node to the real part
1821 if (Node && PHIsFound) {
1822 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1823 << *Real << " / " << *Imag << "\n");
1824 Processed[i] = true;
1825 Processed[j] = true;
1826 auto RootNode = prepareCompositeNode(
1827 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1828 RootNode->addOperand(Node);
1829 RootToNode[Real] = RootNode;
1830 RootToNode[Imag] = RootNode;
1831 submitCompositeNode(RootNode);
1832 break;
1833 }
1834 }
1835
1836 auto *Real = OperationInstruction[i];
1837 // We want to check that we have 2 operands, but the function attributes
1838 // being counted as operands bloats this value.
1839 if (Processed[i] || Real->getNumOperands() < 2)
1840 continue;
1841
1842 // Can only combined integer reductions at the moment.
1843 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1844 continue;
1845
1846 RealPHI = ReductionInfo[Real].first;
1847 ImagPHI = nullptr;
1848 PHIsFound = false;
1849 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1850 if (Node && PHIsFound) {
1851 LLVM_DEBUG(
1852 dbgs() << "Identified single reduction starting from instruction: "
1853 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1854
1855 // Reducing to a single vector is not supported, only permit reducing down
1856 // to scalar values.
1857 // Doing this here will leave the prior node in the graph,
1858 // however with no uses the node will be unreachable by the replacement
1859 // process. That along with the usage outside the graph should prevent the
1860 // replacement process from kicking off at all for this graph.
1861 // TODO Add support for reducing to a single vector value
1862 if (ReductionInfo[Real].second->getType()->isVectorTy())
1863 continue;
1864
1865 Processed[i] = true;
1866 auto RootNode = prepareCompositeNode(
1867 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1868 RootNode->addOperand(Node);
1869 RootToNode[Real] = RootNode;
1870 submitCompositeNode(RootNode);
1871 }
1872 }
1873
1874 RealPHI = nullptr;
1875 ImagPHI = nullptr;
1876}
1877
1878bool ComplexDeinterleavingGraph::checkNodes() {
1879 bool FoundDeinterleaveNode = false;
1880 for (CompositeNode *N : CompositeNodes) {
1881 if (!N->areOperandsValid())
1882 return false;
1883
1884 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1885 FoundDeinterleaveNode = true;
1886 }
1887
1888 // We need a deinterleave node in order to guarantee that we're working with
1889 // complex numbers.
1890 if (!FoundDeinterleaveNode) {
1891 LLVM_DEBUG(
1892 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1893 "guarantee safety during graph transformation.\n");
1894 return false;
1895 }
1896
1897 // Collect all instructions from roots to leaves
1898 SmallPtrSet<Instruction *, 16> AllInstructions;
1899 SmallVector<Instruction *, 8> Worklist;
1900 for (auto &Pair : RootToNode)
1901 Worklist.push_back(Pair.first);
1902
1903 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1904 // chains
1905 while (!Worklist.empty()) {
1906 auto *I = Worklist.pop_back_val();
1907
1908 if (!AllInstructions.insert(I).second)
1909 continue;
1910
1911 for (Value *Op : I->operands()) {
1912 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1913 if (!FinalInstructions.count(I))
1914 Worklist.emplace_back(OpI);
1915 }
1916 }
1917 }
1918
1919 // Find instructions that have users outside of chain
1920 for (auto *I : AllInstructions) {
1921 // Skip root nodes
1922 if (RootToNode.count(I))
1923 continue;
1924
1925 for (User *U : I->users()) {
1926 if (AllInstructions.count(cast<Instruction>(U)))
1927 continue;
1928
1929 // Found an instruction that is not used by XCMLA/XCADD chain
1930 Worklist.emplace_back(I);
1931 break;
1932 }
1933 }
1934
1935 // If any instructions are found to be used outside, find and remove roots
1936 // that somehow connect to those instructions.
1937 SmallPtrSet<Instruction *, 16> Visited;
1938 while (!Worklist.empty()) {
1939 auto *I = Worklist.pop_back_val();
1940 if (!Visited.insert(I).second)
1941 continue;
1942
1943 // Found an impacted root node. Removing it from the nodes to be
1944 // deinterleaved
1945 if (RootToNode.count(I)) {
1946 LLVM_DEBUG(dbgs() << "Instruction " << *I
1947 << " could be deinterleaved but its chain of complex "
1948 "operations have an outside user\n");
1949 RootToNode.erase(I);
1950 }
1951
1952 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1953 continue;
1954
1955 for (User *U : I->users())
1956 Worklist.emplace_back(cast<Instruction>(U));
1957
1958 for (Value *Op : I->operands()) {
1959 if (auto *OpI = dyn_cast<Instruction>(Op))
1960 Worklist.emplace_back(OpI);
1961 }
1962 }
1963 return !RootToNode.empty();
1964}
1965
1966ComplexDeinterleavingGraph::CompositeNode *
1967ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1968 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1970 Intrinsic->getIntrinsicID())
1971 return nullptr;
1972
1973 ComplexValues Vals;
1974 for (unsigned I = 0; I < Factor; I += 2) {
1975 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
1976 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
1977 if (!Real || !Imag)
1978 return nullptr;
1979 Vals.push_back({Real, Imag});
1980 }
1981
1982 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1983 if (!Node1)
1984 return nullptr;
1985 return Node1;
1986 }
1987
1988 // TODO: We could also add support for fixed-width interleave factors of 4
1989 // and above, but currently for symmetric operations the interleaves and
1990 // deinterleaves are already removed by VectorCombine. If we extend this to
1991 // permit complex multiplications, reductions, etc. then we should also add
1992 // support for fixed-width here.
1993 if (Factor != 2)
1994 return nullptr;
1995
1996 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1997 if (!SVI)
1998 return nullptr;
1999
2000 // Look for a shufflevector that takes separate vectors of the real and
2001 // imaginary components and recombines them into a single vector.
2002 if (!isInterleavingMask(SVI->getShuffleMask()))
2003 return nullptr;
2004
2005 Instruction *Real;
2006 Instruction *Imag;
2007 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
2008 return nullptr;
2009
2010 return identifyNode(Real, Imag);
2011}
2012
2013ComplexDeinterleavingGraph::CompositeNode *
2014ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2015 Instruction *II = nullptr;
2016
2017 // Must be at least one complex value.
2018 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2019 Instruction *ExpectedInsn) -> ExtractValueInst * {
2020 auto *EVI = dyn_cast<ExtractValueInst>(V);
2021 if (!EVI || EVI->getNumIndices() != 1 ||
2022 EVI->getIndices()[0] != ExpectedIdx ||
2023 !isa<Instruction>(EVI->getAggregateOperand()) ||
2024 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2025 return nullptr;
2026 return EVI;
2027 };
2028
2029 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2030 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2031 if (RealEVI && Idx == 0)
2033 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2034 II = nullptr;
2035 break;
2036 }
2037 }
2038
2039 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(II)) {
2040 if (IntrinsicII->getIntrinsicID() !=
2042 return nullptr;
2043
2044 // The remaining should match too.
2045 CompositeNode *PlaceholderNode = prepareCompositeNode(
2047 PlaceholderNode->ReplacementNode = II->getOperand(0);
2048 for (auto &V : Vals) {
2049 FinalInstructions.insert(cast<Instruction>(V.Real));
2050 FinalInstructions.insert(cast<Instruction>(V.Imag));
2051 }
2052 return submitCompositeNode(PlaceholderNode);
2053 }
2054
2055 if (Vals.size() != 1)
2056 return nullptr;
2057
2058 Value *Real = Vals[0].Real;
2059 Value *Imag = Vals[0].Imag;
2060 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
2061 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
2062 if (!RealShuffle || !ImagShuffle) {
2063 if (RealShuffle || ImagShuffle)
2064 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2065 return nullptr;
2066 }
2067
2068 Value *RealOp1 = RealShuffle->getOperand(1);
2069 if (!isa<UndefValue>(RealOp1) && !match(RealOp1, m_Zero())) {
2070 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2071 return nullptr;
2072 }
2073 Value *ImagOp1 = ImagShuffle->getOperand(1);
2074 if (!isa<UndefValue>(ImagOp1) && !match(ImagOp1, m_Zero())) {
2075 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2076 return nullptr;
2077 }
2078
2079 Value *RealOp0 = RealShuffle->getOperand(0);
2080 Value *ImagOp0 = ImagShuffle->getOperand(0);
2081
2082 if (RealOp0 != ImagOp0) {
2083 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2084 return nullptr;
2085 }
2086
2087 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2088 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2089 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
2090 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2091 return nullptr;
2092 }
2093
2094 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2095 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2096 return nullptr;
2097 }
2098
2099 // Type checking, the shuffle type should be a vector type of the same
2100 // scalar type, but half the size
2101 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2102 Value *Op = Shuffle->getOperand(0);
2103 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
2104 auto *OpTy = cast<FixedVectorType>(Op->getType());
2105
2106 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2107 return false;
2108 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2109 return false;
2110
2111 return true;
2112 };
2113
2114 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2115 if (!CheckType(Shuffle))
2116 return false;
2117
2118 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2119 int Last = *Mask.rbegin();
2120
2121 Value *Op = Shuffle->getOperand(0);
2122 auto *OpTy = cast<FixedVectorType>(Op->getType());
2123 int NumElements = OpTy->getNumElements();
2124
2125 // Ensure that the deinterleaving shuffle only pulls from the first
2126 // shuffle operand.
2127 return Last < NumElements;
2128 };
2129
2130 if (RealShuffle->getType() != ImagShuffle->getType()) {
2131 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2132 return nullptr;
2133 }
2134 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2135 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2136 return nullptr;
2137 }
2138 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2139 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2140 return nullptr;
2141 }
2142
2143 CompositeNode *PlaceholderNode =
2145 RealShuffle, ImagShuffle);
2146 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
2147 FinalInstructions.insert(RealShuffle);
2148 FinalInstructions.insert(ImagShuffle);
2149 return submitCompositeNode(PlaceholderNode);
2150}
2151
2152ComplexDeinterleavingGraph::CompositeNode *
2153ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2154 auto IsSplat = [](Value *V) -> bool {
2155 // Fixed-width vector with constants
2157 return true;
2158
2159 if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2160 return isa<VectorType>(V->getType());
2161
2162 VectorType *VTy;
2163 ArrayRef<int> Mask;
2164 // Splats are represented differently depending on whether the repeated
2165 // value is a constant or an Instruction
2166 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2167 if (Const->getOpcode() != Instruction::ShuffleVector)
2168 return false;
2169 VTy = cast<VectorType>(Const->getType());
2170 Mask = Const->getShuffleMask();
2171 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2172 VTy = Shuf->getType();
2173 Mask = Shuf->getShuffleMask();
2174 } else {
2175 return false;
2176 }
2177
2178 // When the data type is <1 x Type>, it's not possible to differentiate
2179 // between the ComplexDeinterleaving::Deinterleave and
2180 // ComplexDeinterleaving::Splat operations.
2181 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2182 return false;
2183
2184 return all_equal(Mask) && Mask[0] == 0;
2185 };
2186
2187 // The splats must meet the following requirements:
2188 // 1. Must either be all instructions or all values.
2189 // 2. Non-constant splats must live in the same block.
2190 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Vals[0].Real)) {
2191 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2192 for (auto &V : Vals) {
2193 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2194 return nullptr;
2195
2196 auto *Real = dyn_cast<Instruction>(V.Real);
2197 auto *Imag = dyn_cast<Instruction>(V.Imag);
2198 if (!Real || !Imag || Real->getParent() != FirstBB ||
2199 Imag->getParent() != FirstBB)
2200 return nullptr;
2201 }
2202 } else {
2203 for (auto &V : Vals) {
2204 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(V.Real) ||
2205 isa<Instruction>(V.Imag))
2206 return nullptr;
2207 }
2208 }
2209
2210 for (auto &V : Vals) {
2211 auto *Real = dyn_cast<Instruction>(V.Real);
2212 auto *Imag = dyn_cast<Instruction>(V.Imag);
2213 if (Real && Imag) {
2214 FinalInstructions.insert(Real);
2215 FinalInstructions.insert(Imag);
2216 }
2217 }
2218 CompositeNode *PlaceholderNode =
2219 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
2220 return submitCompositeNode(PlaceholderNode);
2221}
2222
2223ComplexDeinterleavingGraph::CompositeNode *
2224ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2225 Instruction *Imag) {
2226 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2227 return nullptr;
2228
2229 PHIsFound = true;
2230 CompositeNode *PlaceholderNode = prepareCompositeNode(
2231 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2232 return submitCompositeNode(PlaceholderNode);
2233}
2234
2235ComplexDeinterleavingGraph::CompositeNode *
2236ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2237 Instruction *Imag) {
2238 auto *SelectReal = dyn_cast<SelectInst>(Real);
2239 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2240 if (!SelectReal || !SelectImag)
2241 return nullptr;
2242
2243 Instruction *MaskA, *MaskB;
2244 Instruction *AR, *AI, *RA, *BI;
2245 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2246 m_Instruction(RA))) ||
2247 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2248 m_Instruction(BI))))
2249 return nullptr;
2250
2251 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2252 return nullptr;
2253
2254 if (!MaskA->getType()->isVectorTy())
2255 return nullptr;
2256
2257 auto NodeA = identifyNode(AR, AI);
2258 if (!NodeA)
2259 return nullptr;
2260
2261 auto NodeB = identifyNode(RA, BI);
2262 if (!NodeB)
2263 return nullptr;
2264
2265 CompositeNode *PlaceholderNode = prepareCompositeNode(
2266 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2267 PlaceholderNode->addOperand(NodeA);
2268 PlaceholderNode->addOperand(NodeB);
2269 FinalInstructions.insert(MaskA);
2270 FinalInstructions.insert(MaskB);
2271 return submitCompositeNode(PlaceholderNode);
2272}
2273
2274static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2275 std::optional<FastMathFlags> Flags,
2276 Value *InputA, Value *InputB) {
2277 Value *I;
2278 switch (Opcode) {
2279 case Instruction::FNeg:
2280 I = B.CreateFNeg(InputA);
2281 break;
2282 case Instruction::FAdd:
2283 I = B.CreateFAdd(InputA, InputB);
2284 break;
2285 case Instruction::Add:
2286 I = B.CreateAdd(InputA, InputB);
2287 break;
2288 case Instruction::FSub:
2289 I = B.CreateFSub(InputA, InputB);
2290 break;
2291 case Instruction::Sub:
2292 I = B.CreateSub(InputA, InputB);
2293 break;
2294 case Instruction::FMul:
2295 I = B.CreateFMul(InputA, InputB);
2296 break;
2297 case Instruction::Mul:
2298 I = B.CreateMul(InputA, InputB);
2299 break;
2300 default:
2301 llvm_unreachable("Incorrect symmetric opcode");
2302 }
2303 if (Flags)
2304 cast<Instruction>(I)->setFastMathFlags(*Flags);
2305 return I;
2306}
2307
2308Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2309 CompositeNode *Node) {
2310 if (Node->ReplacementNode)
2311 return Node->ReplacementNode;
2312
2313 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2314 unsigned Idx) -> Value * {
2315 return Node->Operands.size() > Idx
2316 ? replaceNode(Builder, Node->Operands[Idx])
2317 : nullptr;
2318 };
2319
2320 Value *ReplacementNode = nullptr;
2321 switch (Node->Operation) {
2322 case ComplexDeinterleavingOperation::CDot: {
2323 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2324 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2325 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2326 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2327 "Node inputs need to be of the same type"));
2328 ReplacementNode = TL->createComplexDeinterleavingIR(
2329 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2330 break;
2331 }
2332 case ComplexDeinterleavingOperation::CAdd:
2333 case ComplexDeinterleavingOperation::CMulPartial:
2334 case ComplexDeinterleavingOperation::Symmetric: {
2335 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2336 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2337 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2338 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2339 "Node inputs need to be of the same type"));
2341 (Input0->getType() == Accumulator->getType() &&
2342 "Accumulator and input need to be of the same type"));
2343 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2344 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2345 Input0, Input1);
2346 else
2347 ReplacementNode = TL->createComplexDeinterleavingIR(
2348 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2349 Accumulator);
2350 break;
2351 }
2352 case ComplexDeinterleavingOperation::Deinterleave:
2353 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2354 break;
2355 case ComplexDeinterleavingOperation::Splat: {
2357 for (auto &V : Node->Vals) {
2358 Ops.push_back(V.Real);
2359 Ops.push_back(V.Imag);
2360 }
2361 auto *R = dyn_cast<Instruction>(Node->Vals[0].Real);
2362 auto *I = dyn_cast<Instruction>(Node->Vals[0].Imag);
2363 if (R && I) {
2364 // Splats that are not constant are interleaved where they are located
2365 Instruction *InsertPoint = R;
2366 for (auto V : Node->Vals) {
2367 if (InsertPoint->comesBefore(cast<Instruction>(V.Real)))
2368 InsertPoint = cast<Instruction>(V.Real);
2369 if (InsertPoint->comesBefore(cast<Instruction>(V.Imag)))
2370 InsertPoint = cast<Instruction>(V.Imag);
2371 }
2372 InsertPoint = InsertPoint->getNextNode();
2373 IRBuilder<> IRB(InsertPoint);
2374 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2375 } else {
2376 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2377 }
2378 break;
2379 }
2380 case ComplexDeinterleavingOperation::ReductionPHI: {
2381 // If Operation is ReductionPHI, a new empty PHINode is created.
2382 // It is filled later when the ReductionOperation is processed.
2383 auto *OldPHI = cast<PHINode>(Node->Vals[0].Real);
2384 auto *VTy = cast<VectorType>(Node->Vals[0].Real->getType());
2385 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2386 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2387 OldToNewPHI[OldPHI] = NewPHI;
2388 ReplacementNode = NewPHI;
2389 break;
2390 }
2391 case ComplexDeinterleavingOperation::ReductionSingle:
2392 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2393 processReductionSingle(ReplacementNode, Node);
2394 break;
2395 case ComplexDeinterleavingOperation::ReductionOperation:
2396 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2397 processReductionOperation(ReplacementNode, Node);
2398 break;
2399 case ComplexDeinterleavingOperation::ReductionSelect: {
2400 auto *MaskReal = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
2401 auto *MaskImag = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
2402 auto *A = replaceNode(Builder, Node->Operands[0]);
2403 auto *B = replaceNode(Builder, Node->Operands[1]);
2404 auto *NewMask = Builder.CreateVectorInterleave({MaskReal, MaskImag});
2405 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2406 break;
2407 }
2408 }
2409
2410 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2411 NumComplexTransformations += 1;
2412 Node->ReplacementNode = ReplacementNode;
2413 return ReplacementNode;
2414}
2415
2416void ComplexDeinterleavingGraph::processReductionSingle(
2417 Value *OperationReplacement, CompositeNode *Node) {
2418 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2419 auto *OldPHI = ReductionInfo[Real].first;
2420 auto *NewPHI = OldToNewPHI[OldPHI];
2421 auto *VTy = cast<VectorType>(Real->getType());
2422 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2423
2424 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2425
2426 IRBuilder<> Builder(Incoming->getTerminator());
2427
2428 Value *NewInit = nullptr;
2429 if (auto *C = dyn_cast<Constant>(Init)) {
2430 if (C->isNullValue())
2431 NewInit = Constant::getNullValue(NewVTy);
2432 }
2433
2434 if (!NewInit)
2435 NewInit =
2436 Builder.CreateVectorInterleave({Init, Constant::getNullValue(VTy)});
2437
2438 NewPHI->addIncoming(NewInit, Incoming);
2439 NewPHI->addIncoming(OperationReplacement, BackEdge);
2440
2441 auto *FinalReduction = ReductionInfo[Real].second;
2442 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2443
2444 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2445 FinalReduction->replaceAllUsesWith(AddReduce);
2446}
2447
2448void ComplexDeinterleavingGraph::processReductionOperation(
2449 Value *OperationReplacement, CompositeNode *Node) {
2450 auto *Real = cast<Instruction>(Node->Vals[0].Real);
2451 auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
2452 auto *OldPHIReal = ReductionInfo[Real].first;
2453 auto *OldPHIImag = ReductionInfo[Imag].first;
2454 auto *NewPHI = OldToNewPHI[OldPHIReal];
2455
2456 // We have to interleave initial origin values coming from IncomingBlock
2457 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2458 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2459
2460 IRBuilder<> Builder(Incoming->getTerminator());
2461 auto *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});
2462
2463 NewPHI->addIncoming(NewInit, Incoming);
2464 NewPHI->addIncoming(OperationReplacement, BackEdge);
2465
2466 // Deinterleave complex vector outside of loop so that it can be finally
2467 // reduced
2468 auto *FinalReductionReal = ReductionInfo[Real].second;
2469 auto *FinalReductionImag = ReductionInfo[Imag].second;
2470
2471 auto *Br = cast<CondBrInst>(BackEdge->getTerminator());
2472 BasicBlock *ExitBB = Br->getSuccessor(Br->getSuccessor(0) == BackEdge);
2473 Builder.SetInsertPoint(&*ExitBB->getFirstInsertionPt());
2474
2475 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2476 OperationReplacement->getType(),
2477 OperationReplacement);
2478
2479 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2480 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2481
2482 Builder.SetInsertPoint(FinalReductionImag);
2483 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2484 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2485}
2486
2487void ComplexDeinterleavingGraph::replaceNodes() {
2488 SmallVector<Instruction *, 16> DeadInstrRoots;
2489 for (auto *RootInstruction : OrderedRoots) {
2490 // Check if this potential root went through check process and we can
2491 // deinterleave it
2492 if (!RootToNode.count(RootInstruction))
2493 continue;
2494
2495 IRBuilder<> Builder(RootInstruction);
2496 auto RootNode = RootToNode[RootInstruction];
2497 Value *R = replaceNode(Builder, RootNode);
2498
2499 if (RootNode->Operation ==
2500 ComplexDeinterleavingOperation::ReductionOperation) {
2501 auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
2502 auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
2503 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2504 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2505 DeadInstrRoots.push_back(RootReal);
2506 DeadInstrRoots.push_back(RootImag);
2507 } else if (RootNode->Operation ==
2508 ComplexDeinterleavingOperation::ReductionSingle) {
2509 auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
2510 auto &Info = ReductionInfo[RootInst];
2511 Info.first->removeIncomingValue(BackEdge);
2512 DeadInstrRoots.push_back(Info.second);
2513 } else {
2514 assert(R && "Unable to find replacement for RootInstruction");
2515 DeadInstrRoots.push_back(RootInstruction);
2516 RootInstruction->replaceAllUsesWith(R);
2517 }
2518 }
2519
2520 for (auto *I : DeadInstrRoots)
2522}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static MCDisassembler::DecodeStatus addOperand(MCInst &Inst, const MCOperand &Opnd)
Rewrite undef for PHI
This file defines the BumpPtrAllocator interface.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool isInstructionPotentiallySymmetric(Instruction *I)
static Value * getNegOperand(Value *V)
Returns the operand for negation operation.
static bool isNeg(Value *V)
Returns true if the operation is a negation of V, and it works for both integers and floats.
static cl::opt< bool > ComplexDeinterleavingEnabled("enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden)
static bool isInstructionPairAdd(Instruction *A, Instruction *B)
static Value * replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional< FastMathFlags > Flags, Value *InputA, Value *InputB)
static bool isInterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is interleaving.
static bool isDeinterleavingMask(ArrayRef< int > Mask)
Checks the given mask, and determines whether said mask is deinterleaving.
SmallVector< struct ComplexValue, 2 > ComplexValues
static bool isInstructionPairMul(Instruction *A, Instruction *B)
static bool runOnFunction(Function &F, bool PostInlining)
#define DEBUG_TYPE
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
This file implements a map that provides insertion order iteration.
#define T
uint64_t IntrinsicInst * II
#define P(N)
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
SI optimize exec mask operations pre RA
static LLVM_ATTRIBUTE_ALWAYS_INLINE bool CheckType(MVT::SimpleValueType VT, SDValue N, const TargetLowering *TLI, const DataLayout &DL)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:119
This file describes how to lower LLVM code to machine code.
This pass exposes codegen information to IR-level passes.
BinaryOperator * Mul
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
size_t size() const
Get the array size.
Definition ArrayRef.h:141
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
iterator end()
Definition DenseMap.h:143
bool allowContract() const
Definition FMF.h:69
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2684
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
LLVM_ABI Value * CreateVectorInterleave(ArrayRef< Value * > Ops, const Twine &Name="")
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI bool isIdenticalTo(const Instruction *I) const LLVM_READONLY
Return true if the specified instruction is exactly identical to the current one.
size_type size() const
Definition MapVector.h:58
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetLibraryInfo.
virtual bool isComplexDeinterleavingOperationSupported(ComplexDeinterleavingOperation Operation, Type *Ty) const
Does this target support complex deinterleaving with the given operation and type.
virtual Value * createComplexDeinterleavingIR(IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator=nullptr) const
Create the IR node for the given complex deinterleaving operation.
virtual bool isComplexDeinterleavingSupported() const
Does this target support complex deinterleaving.
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
Primary interface to the complete machine description for the target machine.
virtual const TargetSubtargetInfo * getSubtargetImpl(const Function &) const
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
virtual const TargetLowering * getTargetLowering() const
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:288
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:552
An opaque object representing a hash code.
Definition Hashing.h:78
const ParentTy * getParent() const
Definition ilist_node.h:34
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BR
Control flow instructions. These all have token chains.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor)
Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
@ Offset
Definition DWP.cpp:558
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
hash_code hash_value(const FixedPointSemantics &Val)
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:535
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI FunctionPass * createComplexDeinterleavingPass(const TargetMachine *TM)
This pass implements generation of target-specific intrinsics to support handling of complex number arithmetic.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:547
@ Other
Any other memory.
Definition ModRef.h:68
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1771
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer list are equal or the list is empty.
Definition STLExtras.h:2165
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:325
AllocatorList< T, BumpPtrAllocator > BumpPtrList
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:863
#define N
ComplexDeinterleavingPass(const TargetMachine &TM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS)
static unsigned getHashValue(const ComplexValue &Val)
An information struct used to provide DenseMap with the various necessary components for a given valu...