ScalarizeMaskedMemIntrin.cpp
1//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2// intrinsics
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass replaces masked memory intrinsics - when unsupported by the target
11// - with a chain of basic blocks that handle the elements one by one if the
12// appropriate mask bit is set.
13//
14//===----------------------------------------------------------------------===//
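//
// The transformation is exposed both through the legacy pass below and through
// the new-pass-manager ScalarizeMaskedMemIntrinPass; assuming the new-PM
// registration reuses the DEBUG_TYPE string, the pass can be exercised in
// isolation roughly like:
//
//   opt -passes=scalarize-masked-mem-intrin -S in.ll -o out.ll
//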
15
17#include "llvm/ADT/Twine.h"
21#include "llvm/IR/BasicBlock.h"
22#include "llvm/IR/Constant.h"
23#include "llvm/IR/Constants.h"
25#include "llvm/IR/Dominators.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Instruction.h"
32#include "llvm/IR/Type.h"
33#include "llvm/IR/Value.h"
35#include "llvm/Pass.h"
39#include <cassert>
40#include <optional>
41
42using namespace llvm;
43
44#define DEBUG_TYPE "scalarize-masked-mem-intrin"
45
46namespace {
47
48class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
49public:
50 static char ID; // Pass identification, replacement for typeid
51
52 explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
53 initializeScalarizeMaskedMemIntrinLegacyPassPass(
54 *PassRegistry::getPassRegistry());
55 }
56
57 bool runOnFunction(Function &F) override;
58
59 StringRef getPassName() const override {
60 return "Scalarize Masked Memory Intrinsics";
61 }
62
63 void getAnalysisUsage(AnalysisUsage &AU) const override {
64 AU.addRequired<TargetTransformInfoWrapperPass>();
65 AU.addPreserved<DominatorTreeWrapperPass>();
66 }
67};
68
69} // end anonymous namespace
70
71static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
72 const TargetTransformInfo &TTI, const DataLayout &DL,
73 bool HasBranchDivergence, DomTreeUpdater *DTU);
74static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
75 const TargetTransformInfo &TTI,
76 const DataLayout &DL, bool HasBranchDivergence,
77 DomTreeUpdater *DTU);
78
79char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
80
81INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
82 "Scalarize unsupported masked memory intrinsics", false,
83 false)
84INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
85INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
86INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
87 "Scalarize unsupported masked memory intrinsics", false,
88 false)
89
90FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
91 return new ScalarizeMaskedMemIntrinLegacyPass();
92}
93
94static bool isConstantIntVector(Value *Mask) {
95 Constant *C = dyn_cast<Constant>(Mask);
96 if (!C)
97 return false;
98
99 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
100 for (unsigned i = 0; i != NumElts; ++i) {
101 Constant *CElt = C->getAggregateElement(i);
102 if (!CElt || !isa<ConstantInt>(CElt))
103 return false;
104 }
105
106 return true;
107}
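//
// For example, a mask such as <4 x i1> <i1 true, i1 false, i1 true, i1 true>
// passes this check, whereas a non-constant mask, or one with an undef/poison
// element, does not (getAggregateElement then yields something other than a
// ConstantInt).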
108
109static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
110 unsigned Idx) {
111 return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
112}
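//
// Example: bitcasting an <8 x i1> mask to i8 puts element 0 in bit 0 on a
// little-endian target but in bit 7 on a big-endian target, hence the
// VectorWidth - 1 - Idx remapping used by the scalar bit tests below.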
113
114// Translate a masked load intrinsic like
115// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr,
116// <16 x i1> %mask, <16 x i32> %passthru)
117// to a chain of basic blocks, loading the elements one by one if
118// the appropriate mask bit is set
119//
120// %1 = bitcast i8* %addr to i32*
121// %2 = extractelement <16 x i1> %mask, i32 0
122// br i1 %2, label %cond.load, label %else
123//
124// cond.load: ; preds = %0
125// %3 = getelementptr i32* %1, i32 0
126// %4 = load i32* %3
127// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
128// br label %else
129//
130// else: ; preds = %0, %cond.load
131// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
132// %6 = extractelement <16 x i1> %mask, i32 1
133// br i1 %6, label %cond.load1, label %else2
134//
135// cond.load1: ; preds = %else
136// %7 = getelementptr i32* %1, i32 1
137// %8 = load i32* %7
138// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
139// br label %else2
140//
141// else2: ; preds = %else, %cond.load1
142// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
143// %10 = extractelement <16 x i1> %mask, i32 2
144// br i1 %10, label %cond.load4, label %else5
145//
146static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
147 CallInst *CI, DomTreeUpdater *DTU,
148 bool &ModifiedDT) {
149 Value *Ptr = CI->getArgOperand(0);
150 Value *Mask = CI->getArgOperand(1);
151 Value *Src0 = CI->getArgOperand(2);
152
153 const Align AlignVal = CI->getParamAlign(0).valueOrOne();
154 VectorType *VecType = cast<FixedVectorType>(CI->getType());
155
156 Type *EltTy = VecType->getElementType();
157
158 IRBuilder<> Builder(CI->getContext());
159 Instruction *InsertPt = CI;
160 BasicBlock *IfBlock = CI->getParent();
161
162 Builder.SetInsertPoint(InsertPt);
163 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
164
165 // Short-cut if the mask is all-true.
166 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
167 LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
168 NewI->copyMetadata(*CI);
169 NewI->takeName(CI);
170 CI->replaceAllUsesWith(NewI);
171 CI->eraseFromParent();
172 return;
173 }
174
175 // Adjust alignment for the scalar instruction.
176 const Align AdjustedAlignVal =
177 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
178 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
179
180 // The result vector
181 Value *VResult = Src0;
182
183 if (isConstantIntVector(Mask)) {
184 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
185 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
186 continue;
187 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
188 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
189 VResult = Builder.CreateInsertElement(VResult, Load, Idx);
190 }
191 CI->replaceAllUsesWith(VResult);
192 CI->eraseFromParent();
193 return;
194 }
195
196 // Optimize the case where the "masked load" is a predicated load - that is,
197 // where the mask is the splat of a non-constant scalar boolean. In that case,
198// use that splatted value as the guard on a conditional vector load.
199 if (isSplatValue(Mask, /*Index=*/0)) {
200 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
201 Mask->getName() + ".first");
202 Instruction *ThenTerm =
203 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
204 /*BranchWeights=*/nullptr, DTU);
205
206 BasicBlock *CondBlock = ThenTerm->getParent();
207 CondBlock->setName("cond.load");
208 Builder.SetInsertPoint(CondBlock->getTerminator());
209 LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
210 CI->getName() + ".cond.load");
211 Load->copyMetadata(*CI);
212
213 BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
214 Builder.SetInsertPoint(PostLoad, PostLoad->begin());
215 PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
216 Phi->addIncoming(Load, CondBlock);
217 Phi->addIncoming(Src0, IfBlock);
218 Phi->takeName(CI);
219
220 CI->replaceAllUsesWith(Phi);
221 CI->eraseFromParent();
222 ModifiedDT = true;
223 return;
224 }
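  // For such a uniform mask the code above boils down to roughly:
  //
  //   %mask.first = extractelement <16 x i1> %mask, i64 0
  //   br i1 %mask.first, label %cond.load, label %post
  // cond.load:
  //   %v = load <16 x i32>, ptr %ptr
  //   br label %post
  // post:
  //   %res = phi <16 x i32> [ %v, %cond.load ], [ %passthru, %entry ]
  //
  // (value and block names are illustrative).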
225 // If the mask is not v1i1, use scalar bit test operations. This generates
226 // better results on X86 at least. However, don't do this on GPUs and other
227 // machines with divergence, as there each i1 needs a vector register.
228 Value *SclrMask = nullptr;
229 if (VectorWidth != 1 && !HasBranchDivergence) {
230 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
231 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
232 }
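  // Worked example (little-endian, VectorWidth == 16, Idx == 3): the loop
  // below ends up testing a single scalar bit,
  //
  //   %scalar_mask = bitcast <16 x i1> %mask to i16
  //   %bit = and i16 %scalar_mask, 8        ; APInt::getOneBitSet(16, 3)
  //   %pred = icmp ne i16 %bit, 0
  //
  // rather than extracting sixteen separate i1 values (names illustrative).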
233
234 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
235 // Fill the "else" block, created in the previous iteration
236 //
237 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
238 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
239 // %cond = icmp ne i16 %mask_1, 0
240 // br i1 %cond, label %cond.load, label %else
241 // On GPUs, use
242// %cond = extractelement %mask, Idx
243 // instead
244 Value *Predicate;
245 if (SclrMask != nullptr) {
246 Value *Mask = Builder.getInt(APInt::getOneBitSet(
247 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
248 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
249 Builder.getIntN(VectorWidth, 0));
250 } else {
251 Predicate = Builder.CreateExtractElement(Mask, Idx);
252 }
253
254 // Create "cond" block
255 //
256 // %EltAddr = getelementptr i32* %1, i32 0
257 // %Elt = load i32* %EltAddr
258 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
259 //
260 Instruction *ThenTerm =
261 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
262 /*BranchWeights=*/nullptr, DTU);
263
264 BasicBlock *CondBlock = ThenTerm->getParent();
265 CondBlock->setName("cond.load");
266
267 Builder.SetInsertPoint(CondBlock->getTerminator());
268 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
269 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
270 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
271
272 // Create "else" block, fill it in the next iteration
273 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
274 NewIfBlock->setName("else");
275 BasicBlock *PrevIfBlock = IfBlock;
276 IfBlock = NewIfBlock;
277
278 // Create the phi to join the new and previous value.
279 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
280 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
281 Phi->addIncoming(NewVResult, CondBlock);
282 Phi->addIncoming(VResult, PrevIfBlock);
283 VResult = Phi;
284 }
285
286 CI->replaceAllUsesWith(VResult);
287 CI->eraseFromParent();
288
289 ModifiedDT = true;
290}
291
292// Translate a masked store intrinsic, like
293// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr,
294// <16 x i1> %mask)
295// to a chain of basic blocks that store the elements one by one if
296// the appropriate mask bit is set
297//
298// %1 = bitcast i8* %addr to i32*
299// %2 = extractelement <16 x i1> %mask, i32 0
300// br i1 %2, label %cond.store, label %else
301//
302// cond.store: ; preds = %0
303// %3 = extractelement <16 x i32> %val, i32 0
304// %4 = getelementptr i32* %1, i32 0
305// store i32 %3, i32* %4
306// br label %else
307//
308// else: ; preds = %0, %cond.store
309// %5 = extractelement <16 x i1> %mask, i32 1
310// br i1 %5, label %cond.store1, label %else2
311//
312// cond.store1: ; preds = %else
313// %6 = extractelement <16 x i32> %val, i32 1
314// %7 = getelementptr i32* %1, i32 1
315// store i32 %6, i32* %7
316// br label %else2
317// . . .
318static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
319 CallInst *CI, DomTreeUpdater *DTU,
320 bool &ModifiedDT) {
321 Value *Src = CI->getArgOperand(0);
322 Value *Ptr = CI->getArgOperand(1);
323 Value *Mask = CI->getArgOperand(2);
324
325 const Align AlignVal = CI->getParamAlign(1).valueOrOne();
326 auto *VecType = cast<VectorType>(Src->getType());
327
328 Type *EltTy = VecType->getElementType();
329
330 IRBuilder<> Builder(CI->getContext());
331 Instruction *InsertPt = CI;
332 Builder.SetInsertPoint(InsertPt);
333 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
334
335 // Short-cut if the mask is all-true.
336 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
337 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
338 Store->takeName(CI);
339 Store->copyMetadata(*CI);
340 CI->eraseFromParent();
341 return;
342 }
343
344 // Adjust alignment for the scalar instruction.
345 const Align AdjustedAlignVal =
346 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
347 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
348
349 if (isConstantIntVector(Mask)) {
350 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
351 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
352 continue;
353 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
354 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
355 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
356 }
357 CI->eraseFromParent();
358 return;
359 }
360
361 // Optimize the case where the "masked store" is a predicated store - that is,
362 // when the mask is the splat of a non-constant scalar boolean. In that case,
363 // optimize to a conditional store.
364 if (isSplatValue(Mask, /*Index=*/0)) {
365 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
366 Mask->getName() + ".first");
367 Instruction *ThenTerm =
368 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
369 /*BranchWeights=*/nullptr, DTU);
370 BasicBlock *CondBlock = ThenTerm->getParent();
371 CondBlock->setName("cond.store");
372 Builder.SetInsertPoint(CondBlock->getTerminator());
373
374 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
375 Store->takeName(CI);
376 Store->copyMetadata(*CI);
377
378 CI->eraseFromParent();
379 ModifiedDT = true;
380 return;
381 }
382
383 // If the mask is not v1i1, use scalar bit test operations. This generates
384 // better results on X86 at least. However, don't do this on GPUs or other
385 // machines with branch divergence, as there each i1 takes up a register.
386 Value *SclrMask = nullptr;
387 if (VectorWidth != 1 && !HasBranchDivergence) {
388 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
389 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
390 }
391
392 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
393 // Fill the "else" block, created in the previous iteration
394 //
395 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
396 // %cond = icmp ne i16 %mask_1, 0
397// br i1 %cond, label %cond.store, label %else
398 //
399 // On GPUs, use
400// %cond = extractelement %mask, Idx
401 // instead
402 Value *Predicate;
403 if (SclrMask != nullptr) {
404 Value *Mask = Builder.getInt(APInt::getOneBitSet(
405 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
406 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
407 Builder.getIntN(VectorWidth, 0));
408 } else {
409 Predicate = Builder.CreateExtractElement(Mask, Idx);
410 }
411
412 // Create "cond" block
413 //
414 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
415 // %EltAddr = getelementptr i32* %1, i32 0
416// store i32 %OneElt, i32* %EltAddr
417 //
418 Instruction *ThenTerm =
419 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
420 /*BranchWeights=*/nullptr, DTU);
421
422 BasicBlock *CondBlock = ThenTerm->getParent();
423 CondBlock->setName("cond.store");
424
425 Builder.SetInsertPoint(CondBlock->getTerminator());
426 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
427 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
428 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
429
430 // Create "else" block, fill it in the next iteration
431 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
432 NewIfBlock->setName("else");
433
434 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
435 }
436 CI->eraseFromParent();
437
438 ModifiedDT = true;
439}
440
441// Translate a masked gather intrinsic like
442// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
443// <16 x i1> %Mask, <16 x i32> %Src)
444// to a chain of basic blocks, loading the elements one by one if
445// the appropriate mask bit is set
446//
447// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
448// %Mask0 = extractelement <16 x i1> %Mask, i32 0
449// br i1 %Mask0, label %cond.load, label %else
450//
451// cond.load:
452// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
453// %Load0 = load i32, i32* %Ptr0, align 4
454// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
455// br label %else
456//
457// else:
458// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
459// %Mask1 = extractelement <16 x i1> %Mask, i32 1
460// br i1 %Mask1, label %cond.load1, label %else2
461//
462// cond.load1:
463// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
464// %Load1 = load i32, i32* %Ptr1, align 4
465// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
466// br label %else2
467// . . .
468// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
469// ret <16 x i32> %Result
470static void scalarizeMaskedGather(const DataLayout &DL,
471 bool HasBranchDivergence, CallInst *CI,
472 DomTreeUpdater *DTU, bool &ModifiedDT) {
473 Value *Ptrs = CI->getArgOperand(0);
474 Value *Mask = CI->getArgOperand(1);
475 Value *Src0 = CI->getArgOperand(2);
476
477 auto *VecType = cast<FixedVectorType>(CI->getType());
478 Type *EltTy = VecType->getElementType();
479
480 IRBuilder<> Builder(CI->getContext());
481 Instruction *InsertPt = CI;
482 BasicBlock *IfBlock = CI->getParent();
483 Builder.SetInsertPoint(InsertPt);
484 Align AlignVal = CI->getParamAlign(0).valueOrOne();
485
486 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
487
488 // The result vector
489 Value *VResult = Src0;
490 unsigned VectorWidth = VecType->getNumElements();
491
492 // Short-cut if the mask is a vector of constants.
493 if (isConstantIntVector(Mask)) {
494 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
495 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
496 continue;
497 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
498 LoadInst *Load =
499 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
500 VResult =
501 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
502 }
503 CI->replaceAllUsesWith(VResult);
504 CI->eraseFromParent();
505 return;
506 }
507
508 // If the mask is not v1i1, use scalar bit test operations. This generates
509 // better results on X86 at least. However, don't do this on GPUs or other
510 // machines with branch divergence, as there, each i1 takes up a register.
511 Value *SclrMask = nullptr;
512 if (VectorWidth != 1 && !HasBranchDivergence) {
513 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
514 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
515 }
516
517 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
518 // Fill the "else" block, created in the previous iteration
519 //
520 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
521 // %cond = icmp ne i16 %mask_1, 0
522// br i1 %cond, label %cond.load, label %else
523 //
524 // On GPUs, use
525// %cond = extractelement %mask, Idx
526 // instead
527
528 Value *Predicate;
529 if (SclrMask != nullptr) {
530 Value *Mask = Builder.getInt(APInt::getOneBitSet(
531 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
532 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
533 Builder.getIntN(VectorWidth, 0));
534 } else {
535 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
536 }
537
538 // Create "cond" block
539 //
540 // %EltAddr = getelementptr i32* %1, i32 0
541 // %Elt = load i32* %EltAddr
542 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
543 //
544 // We mark the branch weights as explicitly unknown given they would only
545 // be derivable from the mask which we do not have VP information for.
546 Instruction *ThenTerm =
547 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
548 getExplicitlyUnknownBranchWeightsIfProfiled(
549 *CI->getFunction(), DEBUG_TYPE),
550 DTU);
551
552 BasicBlock *CondBlock = ThenTerm->getParent();
553 CondBlock->setName("cond.load");
554
555 Builder.SetInsertPoint(CondBlock->getTerminator());
556 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
557 LoadInst *Load =
558 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
559 Value *NewVResult =
560 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
561
562 // Create "else" block, fill it in the next iteration
563 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
564 NewIfBlock->setName("else");
565 BasicBlock *PrevIfBlock = IfBlock;
566 IfBlock = NewIfBlock;
567
568 // Create the phi to join the new and previous value.
569 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
570 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
571 Phi->addIncoming(NewVResult, CondBlock);
572 Phi->addIncoming(VResult, PrevIfBlock);
573 VResult = Phi;
574 }
575
576 CI->replaceAllUsesWith(VResult);
577 CI->eraseFromParent();
578
579 ModifiedDT = true;
580}
581
582// Translate a masked scatter intrinsic, like
583// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
584// <16 x i1> %Mask)
585// to a chain of basic blocks that store the elements one by one if
586// the appropriate mask bit is set.
587//
588// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
589// %Mask0 = extractelement <16 x i1> %Mask, i32 0
590// br i1 %Mask0, label %cond.store, label %else
591//
592// cond.store:
593// %Elt0 = extractelement <16 x i32> %Src, i32 0
594// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
595// store i32 %Elt0, i32* %Ptr0, align 4
596// br label %else
597//
598// else:
599// %Mask1 = extractelement <16 x i1> %Mask, i32 1
600// br i1 %Mask1, label %cond.store1, label %else2
601//
602// cond.store1:
603// %Elt1 = extractelement <16 x i32> %Src, i32 1
604// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
605// store i32 %Elt1, i32* %Ptr1, align 4
606// br label %else2
607// . . .
608static void scalarizeMaskedScatter(const DataLayout &DL,
609 bool HasBranchDivergence, CallInst *CI,
610 DomTreeUpdater *DTU, bool &ModifiedDT) {
611 Value *Src = CI->getArgOperand(0);
612 Value *Ptrs = CI->getArgOperand(1);
613 Value *Mask = CI->getArgOperand(2);
614
615 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
616
617 assert(
618 isa<VectorType>(Ptrs->getType()) &&
619 isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
620 "Vector of pointers is expected in masked scatter intrinsic");
621
622 IRBuilder<> Builder(CI->getContext());
623 Instruction *InsertPt = CI;
624 Builder.SetInsertPoint(InsertPt);
625 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
626
627 Align AlignVal = CI->getParamAlign(1).valueOrOne();
628 unsigned VectorWidth = SrcFVTy->getNumElements();
629
630 // Short-cut if the mask is a vector of constants.
631 if (isConstantIntVector(Mask)) {
632 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
633 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
634 continue;
635 Value *OneElt =
636 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
637 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
638 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
639 }
640 CI->eraseFromParent();
641 return;
642 }
643
644 // If the mask is not v1i1, use scalar bit test operations. This generates
645 // better results on X86 at least.
646 Value *SclrMask = nullptr;
647 if (VectorWidth != 1 && !HasBranchDivergence) {
648 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
649 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
650 }
651
652 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
653 // Fill the "else" block, created in the previous iteration
654 //
655 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
656 // %cond = icmp ne i16 %mask_1, 0
657// br i1 %cond, label %cond.store, label %else
658 //
659 // On GPUs, use
660// %cond = extractelement %mask, Idx
661 // instead
662 Value *Predicate;
663 if (SclrMask != nullptr) {
664 Value *Mask = Builder.getInt(APInt::getOneBitSet(
665 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
666 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
667 Builder.getIntN(VectorWidth, 0));
668 } else {
669 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
670 }
671
672 // Create "cond" block
673 //
674 // %Elt1 = extractelement <16 x i32> %Src, i32 1
675 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
676// store i32 %Elt1, i32* %Ptr1
677 //
678 // We mark the branch weights as explicitly unknown given they would only
679 // be derivable from the mask which we do not have VP information for.
680 Instruction *ThenTerm =
681 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
682 getExplicitlyUnknownBranchWeightsIfProfiled(
683 *CI->getFunction(), DEBUG_TYPE),
684 DTU);
685
686 BasicBlock *CondBlock = ThenTerm->getParent();
687 CondBlock->setName("cond.store");
688
689 Builder.SetInsertPoint(CondBlock->getTerminator());
690 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
691 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
692 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
693
694 // Create "else" block, fill it in the next iteration
695 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
696 NewIfBlock->setName("else");
697
698 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
699 }
700 CI->eraseFromParent();
701
702 ModifiedDT = true;
703}
704
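// Translate a masked expandload intrinsic, like
//   <16 x i32> @llvm.masked.expandload.v16i32(ptr %ptr, <16 x i1> %mask,
//                                             <16 x i32> %passthru)
// which loads the enabled elements from consecutive memory locations starting
// at %ptr and expands them into the corresponding enabled result lanes (the
// disabled lanes take the passthru value). Constant masks become a plain
// load/shuffle sequence; otherwise a chain of conditional blocks is emitted
// that only advances the pointer when a lane is active.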
705static void scalarizeMaskedExpandLoad(const DataLayout &DL,
706 bool HasBranchDivergence, CallInst *CI,
707 DomTreeUpdater *DTU, bool &ModifiedDT) {
708 Value *Ptr = CI->getArgOperand(0);
709 Value *Mask = CI->getArgOperand(1);
710 Value *PassThru = CI->getArgOperand(2);
711 Align Alignment = CI->getParamAlign(0).valueOrOne();
712
713 auto *VecType = cast<FixedVectorType>(CI->getType());
714
715 Type *EltTy = VecType->getElementType();
716
717 IRBuilder<> Builder(CI->getContext());
718 Instruction *InsertPt = CI;
719 BasicBlock *IfBlock = CI->getParent();
720
721 Builder.SetInsertPoint(InsertPt);
722 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
723
724 unsigned VectorWidth = VecType->getNumElements();
725
726 // The result vector
727 Value *VResult = PassThru;
728
729 // Adjust alignment for the scalar instruction.
730 const Align AdjustedAlignment =
731 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
732
733 // Short-cut if the mask is a vector of constants.
734 // Create a build_vector pattern, with loads/poisons as necessary and then
735 // shuffle blend with the pass through value.
736 if (isConstantIntVector(Mask)) {
737 unsigned MemIndex = 0;
738 VResult = PoisonValue::get(VecType);
739 SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
740 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
741 Value *InsertElt;
742 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
743 InsertElt = PoisonValue::get(EltTy);
744 ShuffleMask[Idx] = Idx + VectorWidth;
745 } else {
746 Value *NewPtr =
747 Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
748 InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
749 "Load" + Twine(Idx));
750 ShuffleMask[Idx] = Idx;
751 ++MemIndex;
752 }
753 VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
754 "Res" + Twine(Idx));
755 }
756 VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
757 CI->replaceAllUsesWith(VResult);
758 CI->eraseFromParent();
759 return;
760 }
761
762 // If the mask is not v1i1, use scalar bit test operations. This generates
763 // better results on X86 at least. However, don't do this on GPUs or other
764 // machines with branch divergence, as there, each i1 takes up a register.
765 Value *SclrMask = nullptr;
766 if (VectorWidth != 1 && !HasBranchDivergence) {
767 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
768 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
769 }
770
771 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
772 // Fill the "else" block, created in the previous iteration
773 //
774 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
775 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
776 // br i1 %mask_1, label %cond.load, label %else
777 //
778 // On GPUs, use
779// %cond = extractelement %mask, Idx
780 // instead
781
782 Value *Predicate;
783 if (SclrMask != nullptr) {
784 Value *Mask = Builder.getInt(APInt::getOneBitSet(
785 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
786 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
787 Builder.getIntN(VectorWidth, 0));
788 } else {
789 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
790 }
791
792 // Create "cond" block
793 //
794 // %EltAddr = getelementptr i32* %1, i32 0
795 // %Elt = load i32* %EltAddr
796 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
797 //
798 Instruction *ThenTerm =
799 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
800 /*BranchWeights=*/nullptr, DTU);
801
802 BasicBlock *CondBlock = ThenTerm->getParent();
803 CondBlock->setName("cond.load");
804
805 Builder.SetInsertPoint(CondBlock->getTerminator());
806 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
807 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
808
809 // Move the pointer if there are more blocks to come.
810 Value *NewPtr;
811 if ((Idx + 1) != VectorWidth)
812 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
813
814 // Create "else" block, fill it in the next iteration
815 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
816 NewIfBlock->setName("else");
817 BasicBlock *PrevIfBlock = IfBlock;
818 IfBlock = NewIfBlock;
819
820 // Create the phi to join the new and previous value.
821 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
822 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
823 ResultPhi->addIncoming(NewVResult, CondBlock);
824 ResultPhi->addIncoming(VResult, PrevIfBlock);
825 VResult = ResultPhi;
826
827 // Add a PHI for the pointer if this isn't the last iteration.
828 if ((Idx + 1) != VectorWidth) {
829 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
830 PtrPhi->addIncoming(NewPtr, CondBlock);
831 PtrPhi->addIncoming(Ptr, PrevIfBlock);
832 Ptr = PtrPhi;
833 }
834 }
835
836 CI->replaceAllUsesWith(VResult);
837 CI->eraseFromParent();
838
839 ModifiedDT = true;
840}
841
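// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %value, ptr %ptr,
//                                          <16 x i1> %mask)
// which stores the enabled elements of %value to consecutive memory locations
// starting at %ptr. Constant masks become a plain sequence of scalar stores;
// otherwise a chain of conditional blocks is emitted that only advances the
// pointer when a lane is active.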
842static void scalarizeMaskedCompressStore(const DataLayout &DL,
843 bool HasBranchDivergence, CallInst *CI,
844 DomTreeUpdater *DTU,
845 bool &ModifiedDT) {
846 Value *Src = CI->getArgOperand(0);
847 Value *Ptr = CI->getArgOperand(1);
848 Value *Mask = CI->getArgOperand(2);
849 Align Alignment = CI->getParamAlign(1).valueOrOne();
850
851 auto *VecType = cast<FixedVectorType>(Src->getType());
852
853 IRBuilder<> Builder(CI->getContext());
854 Instruction *InsertPt = CI;
855 BasicBlock *IfBlock = CI->getParent();
856
857 Builder.SetInsertPoint(InsertPt);
858 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
859
860 Type *EltTy = VecType->getElementType();
861
862 // Adjust alignment for the scalar instruction.
863 const Align AdjustedAlignment =
864 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
865
866 unsigned VectorWidth = VecType->getNumElements();
867
868 // Short-cut if the mask is a vector of constants.
869 if (isConstantIntVector(Mask)) {
870 unsigned MemIndex = 0;
871 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
872 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
873 continue;
874 Value *OneElt =
875 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
876 Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
877 Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
878 ++MemIndex;
879 }
880 CI->eraseFromParent();
881 return;
882 }
883
884 // If the mask is not v1i1, use scalar bit test operations. This generates
885 // better results on X86 at least. However, don't do this on GPUs or other
886 // machines with branch divergence, as there, each i1 takes up a register.
887 Value *SclrMask = nullptr;
888 if (VectorWidth != 1 && !HasBranchDivergence) {
889 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
890 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
891 }
892
893 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
894 // Fill the "else" block, created in the previous iteration
895 //
896 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
897 // br i1 %mask_1, label %cond.store, label %else
898 //
899 // On GPUs, use
900// %cond = extractelement %mask, Idx
901 // instead
902 Value *Predicate;
903 if (SclrMask != nullptr) {
904 Value *Mask = Builder.getInt(APInt::getOneBitSet(
905 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
906 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
907 Builder.getIntN(VectorWidth, 0));
908 } else {
909 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
910 }
911
912 // Create "cond" block
913 //
914 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
915 // %EltAddr = getelementptr i32* %1, i32 0
916// store i32 %OneElt, i32* %EltAddr
917 //
918 Instruction *ThenTerm =
919 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
920 /*BranchWeights=*/nullptr, DTU);
921
922 BasicBlock *CondBlock = ThenTerm->getParent();
923 CondBlock->setName("cond.store");
924
925 Builder.SetInsertPoint(CondBlock->getTerminator());
926 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
927 Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);
928
929 // Move the pointer if there are more blocks to come.
930 Value *NewPtr;
931 if ((Idx + 1) != VectorWidth)
932 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
933
934 // Create "else" block, fill it in the next iteration
935 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
936 NewIfBlock->setName("else");
937 BasicBlock *PrevIfBlock = IfBlock;
938 IfBlock = NewIfBlock;
939
940 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
941
942 // Add a PHI for the pointer if this isn't the last iteration.
943 if ((Idx + 1) != VectorWidth) {
944 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
945 PtrPhi->addIncoming(NewPtr, CondBlock);
946 PtrPhi->addIncoming(Ptr, PrevIfBlock);
947 Ptr = PtrPhi;
948 }
949 }
950 CI->eraseFromParent();
951
952 ModifiedDT = true;
953}
954
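// Scalarize a masked histogram update such as
//   @llvm.experimental.vector.histogram.add(<8 x ptr> %ptrs, i32 %inc,
//                                           <8 x i1> %mask)
// (type-mangling suffix omitted) by, for every enabled lane, loading from the
// lane's pointer, applying the requested update (add, uadd.sat, umin or umax
// with %inc), and storing the result back.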
955static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
956 DomTreeUpdater *DTU,
957 bool &ModifiedDT) {
958 // If we extend histogram to return a result someday (like the updated vector)
959 // then we'll need to support it here.
960 assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
961 Value *Ptrs = CI->getArgOperand(0);
962 Value *Inc = CI->getArgOperand(1);
963 Value *Mask = CI->getArgOperand(2);
964
965 auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
966 Type *EltTy = Inc->getType();
967
968 IRBuilder<> Builder(CI->getContext());
969 Instruction *InsertPt = CI;
970 Builder.SetInsertPoint(InsertPt);
971
972 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
973
974 // FIXME: Do we need to add an alignment parameter to the intrinsic?
975 unsigned VectorWidth = AddrType->getNumElements();
976 auto CreateHistogramUpdateValue = [&](IntrinsicInst *CI, Value *Load,
977 Value *Inc) -> Value * {
978 Value *UpdateOp;
979 switch (CI->getIntrinsicID()) {
980 case Intrinsic::experimental_vector_histogram_add:
981 UpdateOp = Builder.CreateAdd(Load, Inc);
982 break;
983 case Intrinsic::experimental_vector_histogram_uadd_sat:
984 UpdateOp =
985 Builder.CreateIntrinsic(Intrinsic::uadd_sat, {EltTy}, {Load, Inc});
986 break;
987 case Intrinsic::experimental_vector_histogram_umin:
988 UpdateOp = Builder.CreateIntrinsic(Intrinsic::umin, {EltTy}, {Load, Inc});
989 break;
990 case Intrinsic::experimental_vector_histogram_umax:
991 UpdateOp = Builder.CreateIntrinsic(Intrinsic::umax, {EltTy}, {Load, Inc});
992 break;
993
994 default:
995 llvm_unreachable("Unexpected histogram intrinsic");
996 }
997 return UpdateOp;
998 };
999
1000 // Short-cut if the mask is a vector of constants.
1001 if (isConstantIntVector(Mask)) {
1002 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
1003 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
1004 continue;
1005 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
1006 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
1007 Value *Update =
1008 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1009 Builder.CreateStore(Update, Ptr);
1010 }
1011 CI->eraseFromParent();
1012 return;
1013 }
1014
1015 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
1016 Value *Predicate =
1017 Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
1018
1019 Instruction *ThenTerm =
1020 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
1021 /*BranchWeights=*/nullptr, DTU);
1022
1023 BasicBlock *CondBlock = ThenTerm->getParent();
1024 CondBlock->setName("cond.histogram.update");
1025
1026 Builder.SetInsertPoint(CondBlock->getTerminator());
1027 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
1028 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
1029 Value *UpdateOp =
1030 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1031 Builder.CreateStore(UpdateOp, Ptr);
1032
1033 // Create "else" block, fill it in the next iteration
1034 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
1035 NewIfBlock->setName("else");
1036 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
1037 }
1038
1039 CI->eraseFromParent();
1040 ModifiedDT = true;
1041}
1042
1043static bool runImpl(Function &F, const TargetTransformInfo &TTI,
1044 DominatorTree *DT) {
1045 std::optional<DomTreeUpdater> DTU;
1046 if (DT)
1047 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1048
1049 bool EverMadeChange = false;
1050 bool MadeChange = true;
1051 auto &DL = F.getDataLayout();
1052 bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
1053 while (MadeChange) {
1054 MadeChange = false;
1055 for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
1056 bool ModifiedDTOnIteration = false;
1057 MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
1058 HasBranchDivergence, DTU ? &*DTU : nullptr);
1059
1060 // Restart BB iteration if the dominator tree of the Function was changed
1061 if (ModifiedDTOnIteration)
1062 break;
1063 }
1064
1065 EverMadeChange |= MadeChange;
1066 }
1067 return EverMadeChange;
1068}
1069
1070bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
1071 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1072 DominatorTree *DT = nullptr;
1073 if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
1074 DT = &DTWP->getDomTree();
1075 return runImpl(F, TTI, DT);
1076}
1077
1078PreservedAnalyses
1079ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
1080 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1081 auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
1082 if (!runImpl(F, TTI, DT))
1083 return PreservedAnalyses::all();
1084 PreservedAnalyses PA;
1085 PA.preserve<TargetIRAnalysis>();
1086 PA.preserve<DominatorTreeAnalysis>();
1087 return PA;
1088}
1089
1090static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
1091 const TargetTransformInfo &TTI, const DataLayout &DL,
1092 bool HasBranchDivergence, DomTreeUpdater *DTU) {
1093 bool MadeChange = false;
1094
1095 BasicBlock::iterator CurInstIterator = BB.begin();
1096 while (CurInstIterator != BB.end()) {
1097 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
1098 MadeChange |=
1099 optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
1100 if (ModifiedDT)
1101 return true;
1102 }
1103
1104 return MadeChange;
1105}
1106
1107static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
1108 const TargetTransformInfo &TTI,
1109 const DataLayout &DL, bool HasBranchDivergence,
1110 DomTreeUpdater *DTU) {
1111 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
1112 if (II) {
1113 // The scalarization code below does not work for scalable vectors.
1114 if (isa<ScalableVectorType>(II->getType()) ||
1115 any_of(II->args(),
1116 [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
1117 return false;
1118 switch (II->getIntrinsicID()) {
1119 default:
1120 break;
1121 case Intrinsic::experimental_vector_histogram_add:
1122 case Intrinsic::experimental_vector_histogram_uadd_sat:
1123 case Intrinsic::experimental_vector_histogram_umin:
1124 case Intrinsic::experimental_vector_histogram_umax:
1125 if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
1126 CI->getArgOperand(1)->getType()))
1127 return false;
1128 scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
1129 return true;
1130 case Intrinsic::masked_load:
1131 // Scalarize unsupported vector masked load
1132 if (TTI.isLegalMaskedLoad(
1133 CI->getType(), CI->getParamAlign(0).valueOrOne(),
1134 cast<PointerType>(CI->getArgOperand(0)->getType())
1135 ->getAddressSpace(),
1139 return false;
1140 scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1141 return true;
1142 case Intrinsic::masked_store:
1143 if (TTI.isLegalMaskedStore(
1144 CI->getArgOperand(0)->getType(),
1145 CI->getParamAlign(1).valueOrOne(),
1146 cast<PointerType>(CI->getArgOperand(1)->getType())
1147 ->getAddressSpace(),
1151 return false;
1152 scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1153 return true;
1154 case Intrinsic::masked_gather: {
1155 Align Alignment = CI->getParamAlign(0).valueOrOne();
1156 Type *LoadTy = CI->getType();
1157 if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
1158 !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
1159 return false;
1160 scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1161 return true;
1162 }
1163 case Intrinsic::masked_scatter: {
1164 Align Alignment = CI->getParamAlign(1).valueOrOne();
1165 Type *StoreTy = CI->getArgOperand(0)->getType();
1166 if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
1167 !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
1168 Alignment))
1169 return false;
1170 scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1171 return true;
1172 }
1173 case Intrinsic::masked_expandload:
1174 if (TTI.isLegalMaskedExpandLoad(
1175 CI->getType(),
1176 CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
1177 return false;
1178 scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1179 return true;
1180 case Intrinsic::masked_compressstore:
1181 if (TTI.isLegalMaskedCompressStore(
1182 CI->getArgOperand(0)->getType(),
1183 CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
1184 return false;
1185 scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
1186 ModifiedDT);
1187 return true;
1188 }
1189 }
1190
1191 return false;
1192}