LLVM 23.0.0git
StraightLineStrengthReduce.cpp
Go to the documentation of this file.
1//===- StraightLineStrengthReduce.cpp - -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements straight-line strength reduction (SLSR). Unlike loop
10// strength reduction, this algorithm is designed to reduce arithmetic
11// redundancy in straight-line code instead of loops. It has proven to be
12// effective in simplifying arithmetic statements derived from an unrolled loop.
13// It can also simplify the logic of SeparateConstOffsetFromGEP.
14//
15// There are many optimizations we can perform in the domain of SLSR.
16// We look for strength reduction candidates in the following forms:
17//
18// Form Add: B + i * S
19// Form Mul: (B + i) * S
20// Form GEP: &B[i * S]
21//
// where S is an integer variable, and i is a constant integer. If we find two
23// candidates S1 and S2 in the same form and S1 dominates S2, we may rewrite S2
24// in a simpler way with respect to S1 (index delta). For example,
25//
26// S1: X = B + i * S
27// S2: Y = B + i' * S => X + (i' - i) * S
28//
29// S1: X = (B + i) * S
30// S2: Y = (B + i') * S => X + (i' - i) * S
31//
32// S1: X = &B[i * S]
33// S2: Y = &B[i' * S] => &X[(i' - i) * S]
34//
35// Note: (i' - i) * S is folded to the extent possible.
36//
37// For Add and GEP forms, we can also rewrite a candidate in a simpler way
38// with respect to other dominating candidates if their B or S are different
39// but other parts are the same. For example,
40//
41// Base Delta:
42// S1: X = B + i * S
43// S2: Y = B' + i * S => X + (B' - B)
44//
45// S1: X = &B [i * S]
46// S2: Y = &B'[i * S] => X + (B' - B)
47//
48// Stride Delta:
49// S1: X = B + i * S
50// S2: Y = B + i * S' => X + i * (S' - S)
51//
52// S1: X = &B[i * S]
53// S2: Y = &B[i * S'] => X + i * (S' - S)
54//
// PS: Stride delta rewrite on Mul form is usually non-profitable, and Base
// delta rewrite is only sometimes profitable, so we support neither on Mul.
57//
58// This rewriting is in general a good idea. The code patterns we focus on
59// usually come from loop unrolling, so the delta is likely the same
60// across iterations and can be reused. When that happens, the optimized form
61// takes only one add starting from the second iteration.
62//
63// When such rewriting is possible, we call S1 a "basis" of S2. When S2 has
64// multiple bases, we choose to rewrite S2 with respect to its "immediate"
65// basis, the basis that is the closest ancestor in the dominator tree.
66//
67// TODO:
68//
// - Floating-point arithmetic when fast math is enabled.
70
72#include "llvm/ADT/APInt.h"
74#include "llvm/ADT/SetVector.h"
80#include "llvm/IR/Constants.h"
81#include "llvm/IR/DataLayout.h"
83#include "llvm/IR/Dominators.h"
85#include "llvm/IR/IRBuilder.h"
86#include "llvm/IR/Instruction.h"
88#include "llvm/IR/Module.h"
89#include "llvm/IR/Operator.h"
91#include "llvm/IR/Type.h"
92#include "llvm/IR/Value.h"
94#include "llvm/Pass.h"
100#include <cassert>
101#include <cstdint>
102#include <limits>
103#include <list>
104#include <queue>
105#include <vector>
106
107using namespace llvm;
108using namespace PatternMatch;
109
110#define DEBUG_TYPE "slsr"
111
// Sentinel passed to TTI->isLegalAddressingMode when the address space of the
// access is not known: the largest representable unsigned value.
static constexpr unsigned UnknownAddressSpace =
    std::numeric_limits<unsigned>::max();
114
// Debug counter registered as "slsr-counter". Per its description it gates
// whether rewriteCandidate is executed for a given candidate.
DEBUG_COUNTER(StraightLineStrengthReduceCounter, "slsr-counter",
              "Controls whether rewriteCandidate is executed.");
117
// Only for testing. Command-line switch "enable-poison-reuse-guard", on by
// default. NOTE(review): the code that reads this flag is not visible in this
// chunk; presumably it controls the DropList poison-annotation handling —
// confirm against the rest of the file.
static cl::opt<bool>
    EnablePoisonReuseGuard("enable-poison-reuse-guard", cl::init(true),
                           cl::desc("Enable poison-reuse guard"));
122
123namespace {
124
125class StraightLineStrengthReduceLegacyPass : public FunctionPass {
126 const DataLayout *DL = nullptr;
127
128public:
129 static char ID;
130
131 StraightLineStrengthReduceLegacyPass() : FunctionPass(ID) {
134 }
135
  // Declare the analyses this pass consumes (dominator tree, scalar evolution,
  // target transform info) and that it leaves the CFG unchanged.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    // We do not modify the shape of the CFG.
    AU.setPreservesCFG();
  }
143
144 bool doInitialization(Module &M) override {
145 DL = &M.getDataLayout();
146 return false;
147 }
148
149 bool runOnFunction(Function &F) override;
150};
151
152class StraightLineStrengthReduce {
153public:
154 StraightLineStrengthReduce(const DataLayout *DL, DominatorTree *DT,
155 ScalarEvolution *SE, TargetTransformInfo *TTI)
156 : DL(DL), DT(DT), SE(SE), TTI(TTI) {}
157
158 // SLSR candidate. Such a candidate must be in one of the forms described in
159 // the header comments.
160 struct Candidate {
161 enum Kind {
162 Invalid, // reserved for the default constructor
163 Add, // B + i * S
164 Mul, // (B + i) * S
165 GEP, // &B[..][i * S][..]
166 };
167
168 enum DKind {
169 InvalidDelta, // reserved for the default constructor
170 IndexDelta, // Delta is a constant from Index
171 BaseDelta, // Delta is a constant or variable from Base
172 StrideDelta, // Delta is a constant or variable from Stride
173 };
174
175 Candidate() = default;
176 Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
177 Instruction *I, const SCEV *StrideSCEV)
178 : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I),
179 StrideSCEV(StrideSCEV) {}
180
181 Kind CandidateKind = Invalid;
182
183 const SCEV *Base = nullptr;
184 // TODO: Swap Index and Stride's name.
185 // Note that Index and Stride of a GEP candidate do not necessarily have the
186 // same integer type. In that case, during rewriting, Stride will be
187 // sign-extended or truncated to Index's type.
188 ConstantInt *Index = nullptr;
189
190 Value *Stride = nullptr;
191
192 // The instruction this candidate corresponds to. It helps us to rewrite a
193 // candidate with respect to its immediate basis. Note that one instruction
194 // can correspond to multiple candidates depending on how you associate the
195 // expression. For instance,
196 //
197 // (a + 1) * (b + 2)
198 //
199 // can be treated as
200 //
201 // <Base: a, Index: 1, Stride: b + 2>
202 //
203 // or
204 //
205 // <Base: b, Index: 2, Stride: a + 1>
206 Instruction *Ins = nullptr;
207
208 // Points to the immediate basis of this candidate, or nullptr if we cannot
209 // find any basis for this candidate.
210 Candidate *Basis = nullptr;
211
212 DKind DeltaKind = InvalidDelta;
213
214 // Store SCEV of Stride to compute delta from different strides
215 const SCEV *StrideSCEV = nullptr;
216
217 // Points to (Y - X) that will be used to rewrite this candidate.
218 Value *Delta = nullptr;
219
220 // List of instructions we need to drop poison generating annotations from.
221 // This is used so we can defer dropping until the candidate is evaluated.
223
224 /// Cost model: Evaluate the computational efficiency of the candidate.
225 ///
226 /// Efficiency levels (higher is better):
227 /// ZeroInst (5) - [Variable] or [Const]
228 /// OneInstOneVar (4) - [Variable + Const] or [Variable * Const]
229 /// OneInstTwoVar (3) - [Variable + Variable] or [Variable * Variable]
230 /// TwoInstOneVar (2) - [Const + Const * Variable]
231 /// TwoInstTwoVar (1) - [Variable + Const * Variable]
232 enum EfficiencyLevel : unsigned {
233 Unknown = 0,
234 TwoInstTwoVar = 1,
235 TwoInstOneVar = 2,
236 OneInstTwoVar = 3,
237 OneInstOneVar = 4,
238 ZeroInst = 5
239 };
240
241 static EfficiencyLevel
242 getComputationEfficiency(Kind CandidateKind, const ConstantInt *Index,
243 const Value *Stride, const SCEV *Base = nullptr) {
244 bool IsConstantBase = false;
245 bool IsZeroBase = false;
246 // When evaluating the efficiency of a rewrite, if the Base's SCEV is
247 // not available, conservatively assume the base is not constant.
248 if (auto *ConstBase = dyn_cast_or_null<SCEVConstant>(Base)) {
249 IsConstantBase = true;
250 IsZeroBase = ConstBase->getValue()->isZero();
251 }
252
253 bool IsConstantStride = isa<ConstantInt>(Stride);
254 bool IsZeroStride =
255 IsConstantStride && cast<ConstantInt>(Stride)->isZero();
256 // All constants
257 if (IsConstantBase && IsConstantStride)
258 return ZeroInst;
259
260 // (Base + Index) * Stride
261 if (CandidateKind == Mul) {
262 if (IsZeroStride)
263 return ZeroInst;
264 if (Index->isZero())
265 return (IsConstantStride || IsConstantBase) ? OneInstOneVar
266 : OneInstTwoVar;
267
268 if (IsConstantBase)
269 return IsZeroBase && (Index->isOne() || Index->isMinusOne())
270 ? ZeroInst
271 : OneInstOneVar;
272
273 if (IsConstantStride) {
274 auto *CI = cast<ConstantInt>(Stride);
275 return (CI->isOne() || CI->isMinusOne()) ? OneInstOneVar
276 : TwoInstOneVar;
277 }
278 return TwoInstTwoVar;
279 }
280
281 // Base + Index * Stride
282 assert(CandidateKind == Add || CandidateKind == GEP);
283 if (Index->isZero() || IsZeroStride)
284 return ZeroInst;
285
286 bool IsSimpleIndex = Index->isOne() || Index->isMinusOne();
287
288 if (IsConstantBase)
289 return IsZeroBase ? (IsSimpleIndex ? ZeroInst : OneInstOneVar)
290 : (IsSimpleIndex ? OneInstOneVar : TwoInstOneVar);
291
292 if (IsConstantStride)
293 return IsZeroStride ? ZeroInst : OneInstOneVar;
294
295 if (IsSimpleIndex)
296 return OneInstTwoVar;
297
298 return TwoInstTwoVar;
299 }
300
301 // Evaluate if the given delta is profitable to rewrite this candidate.
302 bool isProfitableRewrite(const Value &Delta, const DKind DeltaKind) const {
303 // This function cannot accurately evaluate the profit of whole expression
304 // with context. A candidate (B + I * S) cannot express whether this
305 // instruction needs to compute on its own (I * S), which may be shared
306 // with other candidates or may need instructions to compute.
307 // If the rewritten form has the same strength, still rewrite to
308 // (X + Delta) since it may expose more CSE opportunities on Delta, as
309 // unrolled loops usually have identical Delta for each unrolled body.
310 //
311 // Note, this function should only be used on Index Delta rewrite.
312 // Base and Stride delta need context info to evaluate the register
313 // pressure impact from variable delta.
314 return getComputationEfficiency(CandidateKind, Index, Stride, Base) <=
315 getRewriteEfficiency(Delta, DeltaKind);
316 }
317
318 // Evaluate the rewrite efficiency of this candidate with its Basis
319 EfficiencyLevel getRewriteEfficiency() const {
320 return Basis ? getRewriteEfficiency(*Delta, DeltaKind) : Unknown;
321 }
322
323 // Evaluate the rewrite efficiency of this candidate with a given delta
324 EfficiencyLevel getRewriteEfficiency(const Value &Delta,
325 const DKind DeltaKind) const {
326 switch (DeltaKind) {
327 case BaseDelta: // [X + Delta]
328 return getComputationEfficiency(
329 CandidateKind,
330 ConstantInt::get(cast<IntegerType>(Delta.getType()), 1), &Delta);
331 case StrideDelta: // [X + Index * Delta]
332 return getComputationEfficiency(CandidateKind, Index, &Delta);
333 case IndexDelta: // [X + Delta * Stride]
334 return getComputationEfficiency(CandidateKind,
335 cast<ConstantInt>(&Delta), Stride);
336 default:
337 return Unknown;
338 }
339 }
340
341 bool isHighEfficiency() const {
342 return getComputationEfficiency(CandidateKind, Index, Stride, Base) >=
343 OneInstOneVar;
344 }
345
346 // Verify that this candidate has valid delta components relative to the
347 // basis
348 bool hasValidDelta(const Candidate &Basis) const {
349 switch (DeltaKind) {
350 case IndexDelta:
351 // Index differs, Base and Stride must match
352 return Base == Basis.Base && StrideSCEV == Basis.StrideSCEV;
353 case StrideDelta:
354 // Stride differs, Base and Index must match
355 return Base == Basis.Base && Index == Basis.Index;
356 case BaseDelta:
357 // Base differs, Stride and Index must match
358 return StrideSCEV == Basis.StrideSCEV && Index == Basis.Index;
359 default:
360 return false;
361 }
362 }
363 };
364
365 bool runOnFunction(Function &F);
366
367private:
368 // Fetch straight-line basis for rewriting C, update C.Basis to point to it,
369 // and store the delta between C and its Basis in C.Delta.
370 void setBasisAndDeltaFor(Candidate &C);
371 // Returns whether the candidate can be folded into an addressing mode.
372 bool isFoldable(const Candidate &C, TargetTransformInfo *TTI);
373
374 // Checks whether I is in a candidate form. If so, adds all the matching forms
375 // to Candidates, and tries to find the immediate basis for each of them.
376 void allocateCandidatesAndFindBasis(Instruction *I);
377
378 // Allocate candidates and find bases for Add instructions.
379 void allocateCandidatesAndFindBasisForAdd(Instruction *I);
380
381 // Given I = LHS + RHS, factors RHS into i * S and makes (LHS + i * S) a
382 // candidate.
383 void allocateCandidatesAndFindBasisForAdd(Value *LHS, Value *RHS,
384 Instruction *I);
385 // Allocate candidates and find bases for Mul instructions.
386 void allocateCandidatesAndFindBasisForMul(Instruction *I);
387
388 // Splits LHS into Base + Index and, if succeeds, calls
389 // allocateCandidatesAndFindBasis.
390 void allocateCandidatesAndFindBasisForMul(Value *LHS, Value *RHS,
391 Instruction *I);
392
393 // Allocate candidates and find bases for GetElementPtr instructions.
394 void allocateCandidatesAndFindBasisForGEP(GetElementPtrInst *GEP);
395
396 // Adds the given form <CT, B, Idx, S> to Candidates, and finds its immediate
397 // basis.
398 void allocateCandidatesAndFindBasis(Candidate::Kind CT, const SCEV *B,
399 ConstantInt *Idx, Value *S,
400 Instruction *I);
401
402 // Rewrites candidate C with respect to Basis.
403 void rewriteCandidate(const Candidate &C);
404
405 // Emit code that computes the "bump" from Basis to C.
406 static Value *emitBump(const Candidate &Basis, const Candidate &C,
407 IRBuilder<> &Builder, const DataLayout *DL);
408
409 const DataLayout *DL = nullptr;
410 DominatorTree *DT = nullptr;
411 ScalarEvolution *SE;
412 TargetTransformInfo *TTI = nullptr;
413 std::list<Candidate> Candidates;
414
415 // Map from SCEV to instructions that represent the value,
416 // instructions are sorted in depth-first order.
417 DenseMap<const SCEV *, SmallSetVector<Instruction *, 2>> SCEVToInsts;
418
419 // Record the dependency between instructions. If C.Basis == B, we would have
420 // {B.Ins -> {C.Ins, ...}}.
421 MapVector<Instruction *, std::vector<Instruction *>> DependencyGraph;
422
423 // Map between each instruction and its possible candidates.
424 DenseMap<Instruction *, SmallVector<Candidate *, 3>> RewriteCandidates;
425
426 // All instructions that have candidates sort in topological order based on
427 // dependency graph, from roots to leaves.
428 std::vector<Instruction *> SortedCandidateInsts;
429
430 // Record all instructions that are already rewritten and will be removed
431 // later.
432 std::vector<Instruction *> DeadInstructions;
433
434 // Classify candidates against Delta kind
435 class CandidateDictTy {
436 public:
437 using CandsTy = SmallVector<Candidate *, 8>;
438 using BBToCandsTy = DenseMap<const BasicBlock *, CandsTy>;
439
440 private:
441 // Index delta Basis must have the same (Base, StrideSCEV, Inst.Type)
442 using IndexDeltaKeyTy = std::tuple<const SCEV *, const SCEV *, Type *>;
443 DenseMap<IndexDeltaKeyTy, BBToCandsTy> IndexDeltaCandidates;
444
445 // Base delta Basis must have the same (StrideSCEV, Index, Inst.Type)
446 using BaseDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>;
447 DenseMap<BaseDeltaKeyTy, BBToCandsTy> BaseDeltaCandidates;
448
449 // Stride delta Basis must have the same (Base, Index, Inst.Type)
450 using StrideDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>;
451 DenseMap<StrideDeltaKeyTy, BBToCandsTy> StrideDeltaCandidates;
452
453 public:
454 // TODO: Disable index delta on GEP after we completely move
455 // from typed GEP to PtrAdd.
456 const BBToCandsTy *getCandidatesWithDeltaKind(const Candidate &C,
457 Candidate::DKind K) const {
458 assert(K != Candidate::InvalidDelta);
459 if (K == Candidate::IndexDelta) {
460 IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, C.Ins->getType());
461 auto It = IndexDeltaCandidates.find(IndexDeltaKey);
462 if (It != IndexDeltaCandidates.end())
463 return &It->second;
464 } else if (K == Candidate::BaseDelta) {
465 BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, C.Ins->getType());
466 auto It = BaseDeltaCandidates.find(BaseDeltaKey);
467 if (It != BaseDeltaCandidates.end())
468 return &It->second;
469 } else {
470 assert(K == Candidate::StrideDelta);
471 StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, C.Ins->getType());
472 auto It = StrideDeltaCandidates.find(StrideDeltaKey);
473 if (It != StrideDeltaCandidates.end())
474 return &It->second;
475 }
476 return nullptr;
477 }
478
479 // Pointers to C must remain valid until CandidateDict is cleared.
480 void add(Candidate &C) {
481 Type *ValueType = C.Ins->getType();
482 BasicBlock *BB = C.Ins->getParent();
483 IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, ValueType);
484 BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, ValueType);
485 StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, ValueType);
486 IndexDeltaCandidates[IndexDeltaKey][BB].push_back(&C);
487 BaseDeltaCandidates[BaseDeltaKey][BB].push_back(&C);
488 StrideDeltaCandidates[StrideDeltaKey][BB].push_back(&C);
489 }
490 // Remove all mappings from set
491 void clear() {
492 IndexDeltaCandidates.clear();
493 BaseDeltaCandidates.clear();
494 StrideDeltaCandidates.clear();
495 }
496 } CandidateDict;
497
498 const SCEV *getAndRecordSCEV(Value *V) {
499 auto *S = SE->getSCEV(V);
502 SCEVToInsts[S].insert(cast<Instruction>(V));
503
504 return S;
505 }
506
507 bool candidatePredicate(Candidate *Basis, Candidate &C, Candidate::DKind K);
508
509 bool searchFrom(const CandidateDictTy::BBToCandsTy &BBToCands, Candidate &C,
510 Candidate::DKind K);
511
512 // Get the nearest instruction before CI that represents the value of S,
513 // return nullptr if no instruction is associated with S or S is not a
514 // reusable expression.
515 Value *getNearestValueOfSCEV(const SCEV *S, const Instruction *CI) const {
517 return nullptr;
518
519 if (auto *SU = dyn_cast<SCEVUnknown>(S))
520 return SU->getValue();
521 if (auto *SC = dyn_cast<SCEVConstant>(S))
522 return SC->getValue();
523
524 auto It = SCEVToInsts.find(S);
525 if (It == SCEVToInsts.end())
526 return nullptr;
527
528 // Instructions are sorted in depth-first order, so search for the nearest
529 // instruction by walking the list in reverse order.
530 for (Instruction *I : reverse(It->second))
531 if (DT->dominates(I, CI))
532 return I;
533
534 return nullptr;
535 }
536
537 struct DeltaInfo {
538 Candidate *Cand;
539 Candidate::DKind DeltaKind;
540 Value *Delta;
541
542 DeltaInfo()
543 : Cand(nullptr), DeltaKind(Candidate::InvalidDelta), Delta(nullptr) {}
544 DeltaInfo(Candidate *Cand, Candidate::DKind DeltaKind, Value *Delta)
545 : Cand(Cand), DeltaKind(DeltaKind), Delta(Delta) {}
546 operator bool() const { return Cand != nullptr; }
547 };
548
549 friend raw_ostream &operator<<(raw_ostream &OS, const DeltaInfo &DI);
550
551 DeltaInfo compressPath(Candidate &C, Candidate *Basis) const;
552
553 Candidate *pickRewriteCandidate(Instruction *I) const;
554 void sortCandidateInstructions();
555 Value *getDelta(const Candidate &C, const Candidate &Basis,
556 Candidate::DKind K) const;
557 static bool isSimilar(Candidate &C, Candidate &Basis, Candidate::DKind K);
558
559 // Add Basis -> C in DependencyGraph and propagate
560 // C.Stride and C.Delta's dependency to C
561 void addDependency(Candidate &C, Candidate *Basis) {
562 if (Basis)
563 DependencyGraph[Basis->Ins].emplace_back(C.Ins);
564
565 // If any candidate of Inst has a basis, then Inst will be rewritten,
566 // C must be rewritten after rewriting Inst, so we need to propagate
567 // the dependency to C
568 auto PropagateDependency = [&](Instruction *Inst) {
569 if (auto CandsIt = RewriteCandidates.find(Inst);
570 CandsIt != RewriteCandidates.end() &&
571 llvm::any_of(CandsIt->second,
572 [](Candidate *Cand) { return Cand->Basis; }))
573 DependencyGraph[Inst].emplace_back(C.Ins);
574 };
575
576 // If C has a variable delta and the delta is a candidate,
577 // propagate its dependency to C
578 if (auto *DeltaInst = dyn_cast_or_null<Instruction>(C.Delta))
579 PropagateDependency(DeltaInst);
580
581 // If the stride is a candidate, propagate its dependency to C
582 if (auto *StrideInst = dyn_cast<Instruction>(C.Stride))
583 PropagateDependency(StrideInst);
584 };
585};
586
588 const StraightLineStrengthReduce::Candidate &C) {
589 OS << "Ins: " << *C.Ins << "\n Base: " << *C.Base
590 << "\n Index: " << *C.Index << "\n Stride: " << *C.Stride
591 << "\n StrideSCEV: " << *C.StrideSCEV;
592 if (C.Basis)
593 OS << "\n Delta: " << *C.Delta << "\n Basis: \n [ " << *C.Basis << " ]";
594 return OS;
595}
596
597[[maybe_unused]] LLVM_DUMP_METHOD inline raw_ostream &
598operator<<(raw_ostream &OS, const StraightLineStrengthReduce::DeltaInfo &DI) {
599 OS << "Cand: " << *DI.Cand << "\n";
600 OS << "Delta Kind: ";
601 switch (DI.DeltaKind) {
602 case StraightLineStrengthReduce::Candidate::IndexDelta:
603 OS << "Index";
604 break;
605 case StraightLineStrengthReduce::Candidate::BaseDelta:
606 OS << "Base";
607 break;
608 case StraightLineStrengthReduce::Candidate::StrideDelta:
609 OS << "Stride";
610 break;
611 default:
612 break;
613 }
614 OS << "\nDelta: " << *DI.Delta;
615 return OS;
616}
617
618} // end anonymous namespace
619
// Definition of the static pass identifier declared in
// StraightLineStrengthReduceLegacyPass.
char StraightLineStrengthReduceLegacyPass::ID = 0;
621
622INITIALIZE_PASS_BEGIN(StraightLineStrengthReduceLegacyPass, "slsr",
623 "Straight line strength reduction", false, false)
627INITIALIZE_PASS_END(StraightLineStrengthReduceLegacyPass, "slsr",
628 "Straight line strength reduction", false, false)
629
631 return new StraightLineStrengthReduceLegacyPass();
632}
633
634// A helper function that unifies the bitwidth of A and B.
635static void unifyBitWidth(APInt &A, APInt &B) {
636 if (A.getBitWidth() < B.getBitWidth())
637 A = A.sext(B.getBitWidth());
638 else if (A.getBitWidth() > B.getBitWidth())
639 B = B.sext(A.getBitWidth());
640}
641
642Value *StraightLineStrengthReduce::getDelta(const Candidate &C,
643 const Candidate &Basis,
644 Candidate::DKind K) const {
645 if (K == Candidate::IndexDelta) {
646 APInt Idx = C.Index->getValue();
647 APInt BasisIdx = Basis.Index->getValue();
648 unifyBitWidth(Idx, BasisIdx);
649 APInt IndexDelta = Idx - BasisIdx;
650 IntegerType *DeltaType =
651 IntegerType::get(C.Ins->getContext(), IndexDelta.getBitWidth());
652 return ConstantInt::get(DeltaType, IndexDelta);
653 } else if (K == Candidate::BaseDelta || K == Candidate::StrideDelta) {
654 const SCEV *BasisPart =
655 (K == Candidate::BaseDelta) ? Basis.Base : Basis.StrideSCEV;
656 const SCEV *CandPart = (K == Candidate::BaseDelta) ? C.Base : C.StrideSCEV;
657 const SCEV *Diff = SE->getMinusSCEV(CandPart, BasisPart);
658 return getNearestValueOfSCEV(Diff, C.Ins);
659 }
660 return nullptr;
661}
662
663bool StraightLineStrengthReduce::isSimilar(Candidate &C, Candidate &Basis,
664 Candidate::DKind K) {
665 bool SameType = false;
666 switch (K) {
667 case Candidate::StrideDelta:
668 SameType = C.StrideSCEV->getType() == Basis.StrideSCEV->getType();
669 break;
670 case Candidate::BaseDelta:
671 SameType = C.Base->getType() == Basis.Base->getType();
672 break;
673 case Candidate::IndexDelta:
674 SameType = true;
675 break;
676 default:;
677 }
678 return SameType && Basis.Ins != C.Ins &&
679 Basis.CandidateKind == C.CandidateKind;
680}
681
682// Try to find a Delta that C can reuse Basis to rewrite.
683// Set C.Delta, C.Basis, and C.DeltaKind if found.
684// Return true if found a constant delta.
685// Return false if not found or the delta is not a constant.
686bool StraightLineStrengthReduce::candidatePredicate(Candidate *Basis,
687 Candidate &C,
688 Candidate::DKind K) {
689 if (!isSimilar(C, *Basis, K))
690 return false;
691
692 assert(DT->dominates(Basis->Ins, C.Ins));
693 Value *Delta = getDelta(C, *Basis, K);
694 if (!Delta)
695 return false;
696
697 // IndexDelta rewrite is not always profitable, e.g.,
698 // X = B + 8 * S
699 // Y = B + S,
700 // rewriting Y to X - 7 * S is probably a bad idea.
701 // So, we need to check if the rewrite form's computation efficiency
702 // is better than the original form.
703 if (K == Candidate::IndexDelta &&
704 !C.isProfitableRewrite(*Delta, Candidate::IndexDelta))
705 return false;
706
707 // If there is a Delta that we can reuse Basis to rewrite C, clean up
708 // previously collected poison generating instructions.
709 for (Instruction *I : Basis->DropList)
710 I->dropPoisonGeneratingAnnotations();
711
712 // Record delta if none has been found yet, or the new delta is
713 // a constant that is better than the existing delta.
714 if (!C.Delta || isa<ConstantInt>(Delta)) {
715 C.Delta = Delta;
716 C.Basis = Basis;
717 C.DeltaKind = K;
718 }
719 return isa<ConstantInt>(C.Delta);
720}
721
722// return true if find a Basis with constant delta and stop searching,
723// return false if did not find a Basis or the delta is not a constant
724// and continue searching for a Basis with constant delta
725bool StraightLineStrengthReduce::searchFrom(
726 const CandidateDictTy::BBToCandsTy &BBToCands, Candidate &C,
727 Candidate::DKind K) {
728
729 // Stride delta rewrite on Mul form is usually non-profitable, and Base
730 // delta rewrite sometimes is profitable, so we do not support them on Mul.
731 if (C.CandidateKind == Candidate::Mul && K != Candidate::IndexDelta)
732 return false;
733
734 // Search dominating candidates by walking the immediate-dominator chain
735 // from the candidate's defining block upward. Visiting blocks in this
736 // order ensures we prefer the closest dominating basis.
737 const BasicBlock *BB = C.Ins->getParent();
738 while (BB) {
739 auto It = BBToCands.find(BB);
740 if (It != BBToCands.end())
741 for (Candidate *Basis : reverse(It->second))
742 if (candidatePredicate(Basis, C, K))
743 return true;
744
745 const DomTreeNode *Node = DT->getNode(BB);
746 if (!Node)
747 break;
748 Node = Node->getIDom();
749 BB = Node ? Node->getBlock() : nullptr;
750 }
751 return false;
752}
753
754void StraightLineStrengthReduce::setBasisAndDeltaFor(Candidate &C) {
755 if (const auto *BaseDeltaCandidates =
756 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::BaseDelta))
757 if (searchFrom(*BaseDeltaCandidates, C, Candidate::BaseDelta)) {
758 LLVM_DEBUG(dbgs() << "Found delta from Base: " << *C.Delta << "\n");
759 return;
760 }
761
762 if (const auto *StrideDeltaCandidates =
763 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::StrideDelta))
764 if (searchFrom(*StrideDeltaCandidates, C, Candidate::StrideDelta)) {
765 LLVM_DEBUG(dbgs() << "Found delta from Stride: " << *C.Delta << "\n");
766 return;
767 }
768
769 if (const auto *IndexDeltaCandidates =
770 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::IndexDelta))
771 if (searchFrom(*IndexDeltaCandidates, C, Candidate::IndexDelta)) {
772 LLVM_DEBUG(dbgs() << "Found delta from Index: " << *C.Delta << "\n");
773 return;
774 }
775
776 // If we did not find a constant delta, we might have found a variable delta
777 if (C.Delta) {
778 LLVM_DEBUG({
779 dbgs() << "Found delta from ";
780 if (C.DeltaKind == Candidate::BaseDelta)
781 dbgs() << "Base: ";
782 else
783 dbgs() << "Stride: ";
784 dbgs() << *C.Delta << "\n";
785 });
786 assert(C.DeltaKind != Candidate::InvalidDelta && C.Basis);
787 }
788}
789
// Compress the path from `Basis` to the deepest Basis in the Basis chain
// to avoid non-profitable data dependency and improve ILP.
// X = A + 1
// Y = X + 1
// Z = Y + 1
// ->
// X = A + 1
// Y = A + 2
// Z = A + 3
//
// Walking up Basis->Basis->..., the new root advances to the next ancestor
// only while C can be expressed against that ancestor with a constant delta:
//  - IndexDelta (same Base and StrideSCEV), if the index difference is 0 or 1
//    or StrideSCEV is a SCEVConstant; or
//  - StrideDelta (same Base and Index) or BaseDelta (same StrideSCEV and
//    Index), if the SCEV difference folds to a SCEVConstant.
// The walk stops at the first ancestor that fails these tests. Mul
// candidates, and bases that have no basis of their own, are not compressed.
//
// Return the delta info for C against the new Basis, or an empty DeltaInfo
// (which converts to false) when the basis is unchanged.
auto StraightLineStrengthReduce::compressPath(Candidate &C,
                                              Candidate *Basis) const
    -> DeltaInfo {
  if (!Basis || !Basis->Basis || C.CandidateKind == Candidate::Mul)
    return {};
  Candidate *Root = Basis;
  Value *NewDelta = nullptr;
  auto NewKind = Candidate::InvalidDelta;

  while (Root->Basis) {
    Candidate *NextRoot = Root->Basis;
    // Prefer an index delta when Base and Stride match. It is only accepted
    // for a difference of 0 or 1, or when the stride is a constant.
    if (C.Base == NextRoot->Base && C.StrideSCEV == NextRoot->StrideSCEV &&
        isSimilar(C, *NextRoot, Candidate::IndexDelta)) {
      ConstantInt *CI =
          cast<ConstantInt>(getDelta(C, *NextRoot, Candidate::IndexDelta));
      if (CI->isZero() || CI->isOne() || isa<SCEVConstant>(C.StrideSCEV)) {
        Root = NextRoot;
        NewKind = Candidate::IndexDelta;
        NewDelta = CI;
        continue;
      }
    }

    // Otherwise, work out which single component (Stride or Base) differs.
    const SCEV *CandPart = nullptr;
    const SCEV *BasisPart = nullptr;
    auto CurrKind = Candidate::InvalidDelta;
    if (C.Base == NextRoot->Base && C.Index == NextRoot->Index) {
      CandPart = C.StrideSCEV;
      BasisPart = NextRoot->StrideSCEV;
      CurrKind = Candidate::StrideDelta;
    } else if (C.StrideSCEV == NextRoot->StrideSCEV &&
               C.Index == NextRoot->Index) {
      CandPart = C.Base;
      BasisPart = NextRoot->Base;
      CurrKind = Candidate::BaseDelta;
    } else
      break;

    assert(CandPart && BasisPart);
    if (!isSimilar(C, *NextRoot, CurrKind))
      break;

    // Only a difference that SCEV folds to a constant lets the root advance.
    if (auto DeltaVal =
            dyn_cast<SCEVConstant>(SE->getMinusSCEV(CandPart, BasisPart))) {
      Root = NextRoot;
      NewDelta = DeltaVal->getValue();
      NewKind = CurrKind;
    } else
      break;
  }

  if (Root != Basis) {
    assert(NewKind != Candidate::InvalidDelta && NewDelta);
    LLVM_DEBUG(dbgs() << "Found new Basis with " << *NewDelta
                      << " from path compression.\n");
    return {Root, NewKind, NewDelta};
  }

  return {};
}
860
861// Topologically sort candidate instructions based on their relationship in
862// dependency graph.
863void StraightLineStrengthReduce::sortCandidateInstructions() {
864 SortedCandidateInsts.clear();
865 // An instruction may have multiple candidates that get different Basis
866 // instructions, and each candidate can get dependencies from Basis and
867 // Stride when Stride will also be rewritten by SLSR. Hence, an instruction
868 // may have multiple dependencies. Use InDegree to ensure all dependencies
869 // processed before processing itself.
870 DenseMap<Instruction *, int> InDegree;
871 for (auto &KV : DependencyGraph) {
872 InDegree.try_emplace(KV.first, 0);
873
874 for (auto *Child : KV.second) {
875 InDegree[Child]++;
876 }
877 }
878 std::queue<Instruction *> WorkList;
879 DenseSet<Instruction *> Visited;
880
881 for (auto &KV : DependencyGraph)
882 if (InDegree[KV.first] == 0)
883 WorkList.push(KV.first);
884
885 while (!WorkList.empty()) {
886 Instruction *I = WorkList.front();
887 WorkList.pop();
888 if (!Visited.insert(I).second)
889 continue;
890
891 SortedCandidateInsts.push_back(I);
892
893 for (auto *Next : DependencyGraph[I]) {
894 auto &Degree = InDegree[Next];
895 if (--Degree == 0)
896 WorkList.push(Next);
897 }
898 }
899
900 assert(SortedCandidateInsts.size() == DependencyGraph.size() &&
901 "Dependency graph should not have cycles");
902}
903
904auto StraightLineStrengthReduce::pickRewriteCandidate(Instruction *I) const
905 -> Candidate * {
906 // Return the candidate of instruction I that has the highest profit.
907 auto It = RewriteCandidates.find(I);
908 if (It == RewriteCandidates.end())
909 return nullptr;
910
911 Candidate *BestC = nullptr;
912 auto BestEfficiency = Candidate::Unknown;
913 for (Candidate *C : reverse(It->second))
914 if (C->Basis) {
915 auto Efficiency = C->getRewriteEfficiency();
916 if (Efficiency > BestEfficiency) {
917 BestEfficiency = Efficiency;
918 BestC = C;
919 }
920 }
921
922 return BestC;
923}
924
926 const TargetTransformInfo *TTI) {
927 SmallVector<const Value *, 4> Indices(GEP->indices());
928 return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
930}
931
932// Returns whether (Base + Index * Stride) can be folded to an addressing mode.
933static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride,
935 // Index->getSExtValue() may crash if Index is wider than 64-bit.
936 return Index->getBitWidth() <= 64 &&
937 TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true,
938 Index->getSExtValue(), UnknownAddressSpace);
939}
940
941bool StraightLineStrengthReduce::isFoldable(const Candidate &C,
942 TargetTransformInfo *TTI) {
943 if (C.CandidateKind == Candidate::Add)
944 return isAddFoldable(C.Base, C.Index, C.Stride, TTI);
945 if (C.CandidateKind == Candidate::GEP)
947 return false;
948}
949
950void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
951 Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
952 Instruction *I) {
953 // Record the SCEV of S that we may use it as a variable delta.
954 // Ensure that we rewrite C with a existing IR that reproduces delta value.
955
956 Candidate C(CT, B, Idx, S, I, getAndRecordSCEV(S));
957 // If we can fold I into an addressing mode, computing I is likely free or
958 // takes only one instruction. So, we don't need to analyze or rewrite it.
959 //
960 // Currently, this algorithm can at best optimize complex computations into
961 // a `variable +/* constant` form. However, some targets have stricter
962 // constraints on the their addressing mode.
963 // For example, a `variable + constant` can only be folded to an addressing
964 // mode if the constant falls within a certain range.
965 // So, we also check if the instruction is already high efficient enough
966 // for the strength reduction algorithm.
967 if (!isFoldable(C, TTI) && !C.isHighEfficiency()) {
968 setBasisAndDeltaFor(C);
969
970 // Compress unnecessary rewrite to improve ILP
971 if (auto Res = compressPath(C, C.Basis)) {
972 C.Basis = Res.Cand;
973 C.DeltaKind = Res.DeltaKind;
974 C.Delta = Res.Delta;
975 }
976 }
977 // Regardless of whether we find a basis for C, we need to push C to the
978 // candidate list so that it can be the basis of other candidates.
979 LLVM_DEBUG(dbgs() << "Allocated Candidate: " << C << "\n");
980 Candidates.push_back(C);
981 RewriteCandidates[C.Ins].push_back(&Candidates.back());
982 // Only add to the dict if this instruction is safe to reuse as a basis. By
983 // doing this early we avoid calling canReuseInstruction repeatedly for the
984 // same instruction. The DropList is stored on the Candidate so
985 // candidatePredicate can drop the flags when a rewrite is being done.
987 SE->canReuseInstruction(SE->getSCEV(I), I, Candidates.back().DropList)) {
988 CandidateDict.add(Candidates.back());
989 }
990}
991
992void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
993 Instruction *I) {
994 switch (I->getOpcode()) {
995 case Instruction::Add:
996 allocateCandidatesAndFindBasisForAdd(I);
997 break;
998 case Instruction::Mul:
999 allocateCandidatesAndFindBasisForMul(I);
1000 break;
1001 case Instruction::GetElementPtr:
1002 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(I));
1003 break;
1004 }
1005}
1006
1007void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
1008 Instruction *I) {
1009 // Try matching B + i * S.
1010 if (!isa<IntegerType>(I->getType()))
1011 return;
1012
1013 assert(I->getNumOperands() == 2 && "isn't I an add?");
1014 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
1015 allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
1016 if (LHS != RHS)
1017 allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
1018}
1019
1020void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
1021 Value *LHS, Value *RHS, Instruction *I) {
1022 Value *S = nullptr;
1023 ConstantInt *Idx = nullptr;
1024 if (match(RHS, m_Mul(m_Value(S), m_ConstantInt(Idx)))) {
1025 // I = LHS + RHS = LHS + Idx * S
1026 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
1027 } else if (match(RHS, m_Shl(m_Value(S), m_ConstantInt(Idx)))) {
1028 // I = LHS + RHS = LHS + (S << Idx) = LHS + S * (1 << Idx)
1029 APInt One(Idx->getBitWidth(), 1);
1030 Idx = ConstantInt::get(Idx->getContext(), One << Idx->getValue());
1031 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
1032 } else {
1033 // At least, I = LHS + 1 * RHS
1034 ConstantInt *One = ConstantInt::get(cast<IntegerType>(I->getType()), 1);
1035 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
1036 I);
1037 }
1038}
1039
1040// Returns true if A matches B + C where C is constant.
1041static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C) {
1042 return match(A, m_c_Add(m_Value(B), m_ConstantInt(C)));
1043}
1044
1045// Returns true if A matches B | C where C is constant.
1046static bool matchesOr(Value *A, Value *&B, ConstantInt *&C) {
1047 return match(A, m_c_Or(m_Value(B), m_ConstantInt(C)));
1048}
1049
1050void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
1051 Value *LHS, Value *RHS, Instruction *I) {
1052 Value *B = nullptr;
1053 ConstantInt *Idx = nullptr;
1054 if (matchesAdd(LHS, B, Idx)) {
1055 // If LHS is in the form of "Base + Index", then I is in the form of
1056 // "(Base + Index) * RHS".
1057 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
1058 } else if (matchesOr(LHS, B, Idx) && haveNoCommonBitsSet(B, Idx, *DL)) {
1059 // If LHS is in the form of "Base | Index" and Base and Index have no common
1060 // bits set, then
1061 // Base | Index = Base + Index
1062 // and I is thus in the form of "(Base + Index) * RHS".
1063 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
1064 } else {
1065 // Otherwise, at least try the form (LHS + 0) * RHS.
1066 ConstantInt *Zero = ConstantInt::get(cast<IntegerType>(I->getType()), 0);
1067 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
1068 I);
1069 }
1070}
1071
1072void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
1073 Instruction *I) {
1074 // Try matching (B + i) * S.
1075 // TODO: we could extend SLSR to float and vector types.
1076 if (!isa<IntegerType>(I->getType()))
1077 return;
1078
1079 assert(I->getNumOperands() == 2 && "isn't I a mul?");
1080 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
1081 allocateCandidatesAndFindBasisForMul(LHS, RHS, I);
1082 if (LHS != RHS) {
1083 // Symmetrically, try to split RHS to Base + Index.
1084 allocateCandidatesAndFindBasisForMul(RHS, LHS, I);
1085 }
1086}
1087
1088void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
1089 GetElementPtrInst *GEP) {
1090 // TODO: handle vector GEPs
1091 if (GEP->getType()->isVectorTy())
1092 return;
1093
1094 SmallVector<SCEVUse, 4> IndexExprs;
1095 for (Use &Idx : GEP->indices())
1096 IndexExprs.push_back(SE->getSCEV(Idx));
1097
1099 for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
1100 if (GTI.isStruct())
1101 continue;
1102
1103 SCEVUse OrigIndexExpr = IndexExprs[I - 1];
1104 IndexExprs[I - 1] = SE->getZero(OrigIndexExpr.getPointer()->getType());
1105
1106 // The base of this candidate is GEP's base plus the offsets of all
1107 // indices except this current one.
1108 SCEVUse BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs);
1109 Value *ArrayIdx = GEP->getOperand(I);
1110 uint64_t ElementSize = GTI.getSequentialElementStride(*DL);
1111 IntegerType *PtrIdxTy = cast<IntegerType>(DL->getIndexType(GEP->getType()));
1112 // If the element size overflows the type, truncate.
1113 ConstantInt *ElementSizeIdx =
1114 ConstantInt::getSigned(PtrIdxTy, ElementSize, /*ImplicitTrunc=*/true);
1115 if (ArrayIdx->getType()->getIntegerBitWidth() <=
1116 DL->getIndexSizeInBits(GEP->getAddressSpace())) {
1117 // Skip factoring if ArrayIdx is wider than the index size, because
1118 // ArrayIdx is implicitly truncated to the index size.
1119 allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
1120 ArrayIdx, GEP);
1121 }
1122 // When ArrayIdx is the sext of a value, we try to factor that value as
1123 // well. Handling this case is important because array indices are
1124 // typically sign-extended to the pointer index size.
1125 Value *TruncatedArrayIdx = nullptr;
1126 if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) &&
1127 TruncatedArrayIdx->getType()->getIntegerBitWidth() <=
1128 DL->getIndexSizeInBits(GEP->getAddressSpace())) {
1129 // Skip factoring if TruncatedArrayIdx is wider than the pointer size,
1130 // because TruncatedArrayIdx is implicitly truncated to the pointer size.
1131 allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
1132 TruncatedArrayIdx, GEP);
1133 }
1134
1135 IndexExprs[I - 1] = OrigIndexExpr;
1136 }
1137}
1138
1139Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
1140 const Candidate &C,
1141 IRBuilder<> &Builder,
1142 const DataLayout *DL) {
1143 auto CreateMul = [&](Value *LHS, Value *RHS) {
1144 if (ConstantInt *CR = dyn_cast<ConstantInt>(RHS)) {
1145 const APInt &ConstRHS = CR->getValue();
1146 IntegerType *DeltaType =
1147 IntegerType::get(C.Ins->getContext(), ConstRHS.getBitWidth());
1148 if (ConstRHS.isPowerOf2()) {
1149 ConstantInt *Exponent =
1150 ConstantInt::get(DeltaType, ConstRHS.logBase2());
1151 return Builder.CreateShl(LHS, Exponent);
1152 }
1153 if (ConstRHS.isNegatedPowerOf2()) {
1154 ConstantInt *Exponent =
1155 ConstantInt::get(DeltaType, (-ConstRHS).logBase2());
1156 return Builder.CreateNeg(Builder.CreateShl(LHS, Exponent));
1157 }
1158 }
1159
1160 return Builder.CreateMul(LHS, RHS);
1161 };
1162
1163 Value *Delta = C.Delta;
1164 // If Delta is 0, C is a fully redundant of C.Basis,
1165 // just replace C.Ins with Basis.Ins
1166 if (ConstantInt *CI = dyn_cast<ConstantInt>(Delta);
1167 CI && CI->getValue().isZero())
1168 return nullptr;
1169
1170 if (C.DeltaKind == Candidate::IndexDelta) {
1171 APInt IndexDelta = cast<ConstantInt>(C.Delta)->getValue();
1172 // IndexDelta
1173 // X = B + i * S
1174 // Y = B + i` * S
1175 // = B + (i + IndexDelta) * S
1176 // = B + i * S + IndexDelta * S
1177 // = X + IndexDelta * S
1178 // Bump = (i' - i) * S
1179
1180 // Common case 1: if (i' - i) is 1, Bump = S.
1181 if (IndexDelta == 1)
1182 return C.Stride;
1183 // Common case 2: if (i' - i) is -1, Bump = -S.
1184 if (IndexDelta.isAllOnes())
1185 return Builder.CreateNeg(C.Stride);
1186
1187 IntegerType *DeltaType =
1188 IntegerType::get(Basis.Ins->getContext(), IndexDelta.getBitWidth());
1189 Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
1190
1191 return CreateMul(ExtendedStride, C.Delta);
1192 }
1193
1194 assert(C.DeltaKind == Candidate::StrideDelta ||
1195 C.DeltaKind == Candidate::BaseDelta);
1196 assert(C.CandidateKind != Candidate::Mul);
1197 // StrideDelta
1198 // X = B + i * S
1199 // Y = B + i * S'
1200 // = B + i * (S + StrideDelta)
1201 // = B + i * S + i * StrideDelta
1202 // = X + i * StrideDelta
1203 // Bump = i * (S' - S)
1204 //
1205 // BaseDelta
1206 // X = B + i * S
1207 // Y = B' + i * S
1208 // = (B + BaseDelta) + i * S
1209 // = X + BaseDelta
1210 // Bump = (B' - B).
1211 Value *Bump = C.Delta;
1212 if (C.DeltaKind == Candidate::StrideDelta) {
1213 // If this value is consumed by a GEP, promote StrideDelta before doing
1214 // StrideDelta * Index to ensure the same semantics as the original GEP.
1215 if (C.CandidateKind == Candidate::GEP) {
1216 auto *GEP = cast<GetElementPtrInst>(C.Ins);
1217 Type *NewScalarIndexTy =
1218 DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
1219 Bump = Builder.CreateSExtOrTrunc(Bump, NewScalarIndexTy);
1220 }
1221 if (!C.Index->isOne()) {
1222 Value *ExtendedIndex =
1223 Builder.CreateSExtOrTrunc(C.Index, Bump->getType());
1224 Bump = CreateMul(Bump, ExtendedIndex);
1225 }
1226 }
1227 return Bump;
1228}
1229
1230void StraightLineStrengthReduce::rewriteCandidate(const Candidate &C) {
1231 if (!DebugCounter::shouldExecute(StraightLineStrengthReduceCounter))
1232 return;
1233
1234 const Candidate &Basis = *C.Basis;
1235 assert(C.Delta && C.CandidateKind == Basis.CandidateKind &&
1236 C.hasValidDelta(Basis));
1237
1238 IRBuilder<> Builder(C.Ins);
1239 Value *Bump = emitBump(Basis, C, Builder, DL);
1240 Value *Reduced = nullptr; // equivalent to but weaker than C.Ins
1241 // If delta is 0, C is a fully redundant of Basis, and Bump is nullptr,
1242 // just replace C.Ins with Basis.Ins
1243 if (!Bump)
1244 Reduced = Basis.Ins;
1245 else {
1246 switch (C.CandidateKind) {
1247 case Candidate::Add:
1248 case Candidate::Mul: {
1249 // C = Basis + Bump
1250 Value *NegBump;
1251 if (match(Bump, m_Neg(m_Value(NegBump)))) {
1252 // If Bump is a neg instruction, emit C = Basis - (-Bump).
1253 Reduced = Builder.CreateSub(Basis.Ins, NegBump);
1254 // We only use the negative argument of Bump, and Bump itself may be
1255 // trivially dead.
1257 } else {
1258 // It's tempting to preserve nsw on Bump and/or Reduced. However, it's
1259 // usually unsound, e.g.,
1260 //
1261 // X = (-2 +nsw 1) *nsw INT_MAX
1262 // Y = (-2 +nsw 3) *nsw INT_MAX
1263 // =>
1264 // Y = X + 2 * INT_MAX
1265 //
1266 // Neither + and * in the resultant expression are nsw.
1267 Reduced = Builder.CreateAdd(Basis.Ins, Bump);
1268 }
1269 break;
1270 }
1271 case Candidate::GEP: {
1272 bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
1273 // C = (char *)Basis + Bump
1274 Reduced = Builder.CreatePtrAdd(Basis.Ins, Bump, "", InBounds);
1275 break;
1276 }
1277 default:
1278 llvm_unreachable("C.CandidateKind is invalid");
1279 };
1280 Reduced->takeName(C.Ins);
1281 }
1282 C.Ins->replaceAllUsesWith(Reduced);
1283 DeadInstructions.push_back(C.Ins);
1284}
1285
1286bool StraightLineStrengthReduceLegacyPass::runOnFunction(Function &F) {
1287 if (skipFunction(F))
1288 return false;
1289
1290 auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1291 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
1292 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1293 return StraightLineStrengthReduce(DL, DT, SE, TTI).runOnFunction(F);
1294}
1295
1296bool StraightLineStrengthReduce::runOnFunction(Function &F) {
1297 LLVM_DEBUG(dbgs() << "SLSR on Function: " << F.getName() << "\n");
1298 // Traverse the dominator tree in the depth-first order. This order makes sure
1299 // all bases of a candidate are in Candidates when we process it.
1300 for (const auto Node : depth_first(DT))
1301 for (auto &I : *(Node->getBlock()))
1302 allocateCandidatesAndFindBasis(&I);
1303
1304 // Build the dependency graph and sort candidate instructions from dependency
1305 // roots to leaves
1306 for (auto &C : Candidates) {
1307 DependencyGraph.try_emplace(C.Ins);
1308 addDependency(C, C.Basis);
1309 }
1310 sortCandidateInstructions();
1311
1312 // Rewrite candidates in the topological order that rewrites a Candidate
1313 // always before rewriting its Basis
1314 for (Instruction *I : reverse(SortedCandidateInsts))
1315 if (Candidate *C = pickRewriteCandidate(I))
1316 rewriteCandidate(*C);
1317
1318 for (auto *DeadIns : DeadInstructions)
1319 // A dead instruction may be another dead instruction's op,
1320 // don't delete an instruction twice
1321 if (DeadIns->getParent())
1323
1324 bool Ret = !DeadInstructions.empty();
1325 DeadInstructions.clear();
1326 DependencyGraph.clear();
1327 RewriteCandidates.clear();
1328 SortedCandidateInsts.clear();
1329 // First clear all references to candidates in the list
1330 CandidateDict.clear();
1331 // Then destroy the list
1332 Candidates.clear();
1333 return Ret;
1334}
1335
1336PreservedAnalyses
1338 const DataLayout *DL = &F.getDataLayout();
1339 auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
1340 auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
1341 auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
1342
1343 if (!StraightLineStrengthReduce(DL, DT, SE, TTI).runOnFunction(F))
1344 return PreservedAnalyses::all();
1345
1351 return PA;
1352}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file provides an implementation of debug counters.
#define DEBUG_COUNTER(VARNAME, COUNTERNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool runOnFunction(Function &F, bool PostInlining)
Hexagon Common GEP
Module.h This file contains the declarations for the Module class.
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Machine Check Debug Module
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
static BinaryOperator * CreateMul(Value *S1, Value *S2, const Twine &Name, BasicBlock::iterator InsertBefore, Value *FlagsOp)
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallVector class.
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
static void unifyBitWidth(APInt &A, APInt &B)
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
static const unsigned UnknownAddressSpace
static cl::opt< bool > EnablePoisonReuseGuard("enable-poison-reuse-guard", cl::init(true), cl::desc("Enable poison-reuse guard"))
#define LLVM_DEBUG(...)
Definition Debug.h:119
This pass exposes codegen information to IR-level passes.
Value * RHS
Value * LHS
Class for arbitrary precision integers.
Definition APInt.h:78
bool isNegatedPowerOf2() const
Check if this APInt's negated value is a power of two greater than zero.
Definition APInt.h:450
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
unsigned logBase2() const
Definition APInt.h:1784
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition Constants.h:225
static ConstantInt * getSigned(IntegerType *Ty, int64_t V, bool ImplicitTrunc=false)
Return a ConstantInt with the specified value for the specified type.
Definition Constants.h:135
bool isZero() const
This is just a convenience method to make client code smaller for a common case.
Definition Constants.h:219
unsigned getBitWidth() const
getBitWidth - Return the scalar bitwidth of this constant.
Definition Constants.h:162
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
A parsed version of the target data layout string, and methods for querying it.
Definition DataLayout.h:64
static bool shouldExecute(CounterInfo &Counter)
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:301
iterator end()
Definition DenseMap.h:143
Analysis pass which computes a DominatorTree.
Definition Dominators.h:274
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:310
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Value * CreatePtrAdd(Value *Ptr, Value *Offset, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Definition IRBuilder.h:2101
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNSW=false)
Definition IRBuilder.h:1852
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1461
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1533
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1444
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
Definition IRBuilder.h:2163
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1478
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:350
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
This class represents an analyzed expression in the program.
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Analysis pass providing the TargetTransformInfo.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
@ TCC_Free
Expected to fold away in lowering.
LLVM_ABI unsigned getIntegerBitWidth() const
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
Definition Value.cpp:399
std::pair< iterator, bool > insert(const ValueT &V)
Definition DenseSet.h:212
TypeSize getSequentialElementStride(const DataLayout &DL) const
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
bool match(Val *V, const Pattern &P)
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches a Add with LHS and RHS in either order.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
initializer< Ty > init(const Ty &Val)
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool haveNoCommonBitsSet(const WithCache< const Value * > &LHSCache, const WithCache< const Value * > &RHSCache, const SimplifyQuery &SQ)
Return true if LHS and RHS have no common bits set.
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:535
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI void initializeStraightLineStrengthReduceLegacyPassPass(PassRegistry &)
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
generic_gep_type_iterator<> gep_type_iterator
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
TargetTransformInfo TTI
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
FunctionAddr VTableAddr Next
Definition InstrProf.h:141
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
gep_type_iterator gep_type_begin(const User *GEP)
PointerUnion< const Value *, const PseudoSourceValue * > ValueType
iterator_range< df_iterator< T > > depth_first(const T &G)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI FunctionPass * createStraightLineStrengthReducePass()
SCEVUseT< const SCEV * > SCEVUse
SCEVPtrT getPointer() const