LLVM 23.0.0git
AMDGPULowerKernelAttributes.cpp
Go to the documentation of this file.
1//===-- AMDGPULowerKernelAttributes.cpp------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file This pass attempts to make use of reqd_work_group_size metadata
10/// to eliminate loads from the dispatch packet and to constant fold OpenCL
11/// get_local_size-like functions.
12//
13//===----------------------------------------------------------------------===//
14
15#include "AMDGPU.h"
19#include "llvm/CodeGen/Passes.h"
20#include "llvm/IR/Constants.h"
21#include "llvm/IR/Function.h"
22#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/IntrinsicsAMDGPU.h"
26#include "llvm/IR/MDBuilder.h"
28#include "llvm/Pass.h"
29
30#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
31
32using namespace llvm;
33
34namespace {
35
// Byte offsets of the hsa_kernel_dispatch_packet_t fields this pass looks for
// when the base pointer is the dispatch pointer.
enum DispatchPackedOffsets {
  // Work-group size, one entry per dimension; processUse only accepts 2-byte
  // loads at these offsets.
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  // Grid size, one entry per dimension; processUse only accepts 4-byte loads
  // at these offsets.
  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20,
};
46
// Byte offsets, relative to the implicit kernel argument pointer, of the
// hidden arguments this pass looks for.
enum ImplicitArgOffsets {
  // Block (work-group) count per dimension; processUse only accepts 4-byte
  // loads at these offsets.
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  // Group size per dimension; processUse only accepts 2-byte loads at these
  // offsets.
  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  // Remainder per dimension; processUse only accepts 2-byte loads at these
  // offsets.
  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};
61
62class AMDGPULowerKernelAttributes : public ModulePass {
63public:
64 static char ID;
65
66 AMDGPULowerKernelAttributes() : ModulePass(ID) {}
67
68 bool runOnModule(Module &M) override;
69
70 StringRef getPassName() const override { return "AMDGPU Kernel Attributes"; }
71
72 void getAnalysisUsage(AnalysisUsage &AU) const override {
73 AU.setPreservesAll();
74 }
75};
76
77Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {
78 auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr
79 : Intrinsic::amdgcn_dispatch_ptr;
80 return Intrinsic::getDeclarationIfExists(&M, IntrinsicId);
81}
82
83} // end anonymous namespace
84
86 uint32_t MaxNumGroups) {
87 if (MaxNumGroups == 0 || MaxNumGroups == std::numeric_limits<uint32_t>::max())
88 return false;
89
90 if (!Load->getType()->isIntegerTy(32))
91 return false;
92
93 // TODO: If there is existing range metadata, preserve it if it is stricter.
94 if (Load->hasMetadata(LLVMContext::MD_range))
95 return false;
96
97 MDBuilder MDB(Load->getContext());
98 MDNode *Range = MDB.createRange(APInt(32, 1), APInt(32, MaxNumGroups + 1));
99 Load->setMetadata(LLVMContext::MD_range, Range);
100 return true;
101}
102
103static bool annotateGroupSizeLoadWithRangeMD(LoadInst *Load, bool IsRemainder) {
104 if (!Load->getType()->isIntegerTy(16))
105 return false;
106
107 // TODO: If there is existing range metadata, preserve it if it is stricter.
108 if (Load->hasMetadata(LLVMContext::MD_range))
109 return false;
110
111 MDBuilder MDB(Load->getContext());
112 MDNode *Range = MDB.createRange(
113 APInt(16, !IsRemainder),
114 APInt(16, AMDGPU::IsaInfo::getMaxFlatWorkGroupSize() + 1 - IsRemainder));
115 Load->setMetadata(LLVMContext::MD_range, Range);
116 return true;
117}
118
/// Rewrite the loads that hang off one call to the base-pointer intrinsic.
///
/// The function reads the enclosing function's "reqd_work_group_size"
/// metadata and its "uniform-work-group-size" / "amdgpu-max-num-workgroups"
/// attributes, classifies simple loads at known offsets from \p CI, and then
/// attaches range metadata and/or replaces values with constants.
///
/// \param CI           call whose users are scanned.
/// \param IsV5OrAbove  true: offsets are ImplicitArgOffsets (base is the
///                     implicit argument pointer); false: offsets are
///                     DispatchPackedOffsets (base is the dispatch pointer).
/// \returns true if any metadata was attached or any value was replaced.
///
/// NOTE(review): the embedded listing line numbers skip 150, 279-281,
/// 328-330, 368 and 375-376 inside this function, so statements are missing
/// from this copy. Each gap is flagged below; restore the lines from the
/// upstream file before building.
119static bool processUse(CallInst *CI, bool IsV5OrAbove) {
120 Function *F = CI->getFunction();
121
122 auto *MD = F->getMetadata("reqd_work_group_size");
// Only a three-operand node is accepted; operand I is later extracted as the
// constant size of dimension I.
123 const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
124
125 const bool HasUniformWorkGroupSize =
126 F->hasFnAttribute("uniform-work-group-size");
127
128 SmallVector<unsigned> MaxNumWorkgroups =
129 AMDGPU::getIntegerVecAttribute(*F, "amdgpu-max-num-workgroups",
130 /*Size=*/3, /*DefaultVal=*/0);
131
// Per-dimension (X, Y, Z) loads recognised by the scan below; an entry stays
// null when no matching load was found.
132 Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
133 Value *GroupSizes[3] = {nullptr, nullptr, nullptr};
134 Value *Remainders[3] = {nullptr, nullptr, nullptr};
135 Value *GridSizes[3] = {nullptr, nullptr, nullptr};
136
137 const DataLayout &DL = F->getDataLayout();
138 bool MadeChange = false;
139
140 // We expect to see several GEP users, casted to the appropriate type and
141 // loaded.
142 for (User *U : CI->users()) {
143 if (!U->hasOneUse())
144 continue;
145
146 int64_t Offset = 0;
147 auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
148 auto *BCI = dyn_cast<BitCastInst>(U);
149 if (!Load && !BCI) {
// NOTE(review): listing line 150, the condition guarding this 'continue', is
// missing. Offset is never assigned in the visible code, so presumably the
// missing line decomposes U into base + constant offset — confirm upstream.
151 continue;
152 Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
153 BCI = dyn_cast<BitCastInst>(*U->user_begin());
154 }
155
156 if (BCI) {
157 if (!BCI->hasOneUse())
158 continue;
159 Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
160 }
161
162 if (!Load || !Load->isSimple())
163 continue;
164
165 unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
166
167 // TODO: Handle merged loads.
168 if (IsV5OrAbove) { // Base is ImplicitArgPtr.
// Record the load for its dimension only when its store size matches the
// field width, and annotate it with range metadata where applicable.
169 switch (Offset) {
170 case HIDDEN_BLOCK_COUNT_X:
171 if (LoadSize == 4) {
172 BlockCounts[0] = Load;
173 MadeChange |=
174 annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[0]);
175 }
176 break;
177 case HIDDEN_BLOCK_COUNT_Y:
178 if (LoadSize == 4) {
179 BlockCounts[1] = Load;
180 MadeChange |=
181 annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[1]);
182 }
183 break;
184 case HIDDEN_BLOCK_COUNT_Z:
185 if (LoadSize == 4) {
186 BlockCounts[2] = Load;
187 MadeChange |=
188 annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[2]);
189 }
190 break;
191 case HIDDEN_GROUP_SIZE_X:
192 if (LoadSize == 2) {
193 GroupSizes[0] = Load;
194 MadeChange |= annotateGroupSizeLoadWithRangeMD(Load, false);
195 }
196 break;
197 case HIDDEN_GROUP_SIZE_Y:
198 if (LoadSize == 2) {
199 GroupSizes[1] = Load;
200 MadeChange |= annotateGroupSizeLoadWithRangeMD(Load, false);
201 }
202 break;
203 case HIDDEN_GROUP_SIZE_Z:
204 if (LoadSize == 2) {
205 GroupSizes[2] = Load;
206 MadeChange |= annotateGroupSizeLoadWithRangeMD(Load, false);
207 }
208 break;
209 case HIDDEN_REMAINDER_X:
210 if (LoadSize == 2) {
211 Remainders[0] = Load;
212 MadeChange |= annotateGroupSizeLoadWithRangeMD(Load, true);
213 }
214 break;
215 case HIDDEN_REMAINDER_Y:
216 if (LoadSize == 2) {
217 Remainders[1] = Load;
218 MadeChange |= annotateGroupSizeLoadWithRangeMD(Load, true);
219 }
220 break;
221 case HIDDEN_REMAINDER_Z:
222 if (LoadSize == 2) {
223 Remainders[2] = Load;
224 MadeChange |= annotateGroupSizeLoadWithRangeMD(Load, true);
225 }
226 break;
227 default:
228 break;
229 }
230 } else { // Base is DispatchPtr.
// No range metadata is attached on this path; loads are only recorded.
231 switch (Offset) {
232 case WORKGROUP_SIZE_X:
233 if (LoadSize == 2)
234 GroupSizes[0] = Load;
235 break;
236 case WORKGROUP_SIZE_Y:
237 if (LoadSize == 2)
238 GroupSizes[1] = Load;
239 break;
240 case WORKGROUP_SIZE_Z:
241 if (LoadSize == 2)
242 GroupSizes[2] = Load;
243 break;
244 case GRID_SIZE_X:
245 if (LoadSize == 4)
246 GridSizes[0] = Load;
247 break;
248 case GRID_SIZE_Y:
249 if (LoadSize == 4)
250 GridSizes[1] = Load;
251 break;
252 case GRID_SIZE_Z:
253 if (LoadSize == 4)
254 GridSizes[2] = Load;
255 break;
256 default:
257 break;
258 }
259 }
260 }
261
262 if (IsV5OrAbove && HasUniformWorkGroupSize) {
263 // Under v5 __ockl_get_local_size returns the value computed by the
264 // expression:
265 //
266 // workgroup_id < hidden_block_count ? hidden_group_size :
267 // hidden_remainder
268 //
269 // For functions with the attribute uniform-work-group-size=true, we can
270 // evaluate workgroup_id < hidden_block_count as true, and thus
271 // hidden_group_size is returned for __ockl_get_local_size.
272 for (int I = 0; I < 3; ++I) {
273 Value *BlockCount = BlockCounts[I];
274 if (!BlockCount)
275 continue;
276
277 using namespace llvm::PatternMatch;
// NOTE(review): listing lines 279-281, the initializer of GroupIDIntrin, are
// missing here — presumably a matcher for the workgroup-id intrinsic of
// dimension I; confirm upstream.
278 auto GroupIDIntrin =
282
283 for (User *ICmp : BlockCount->users()) {
284 if (match(ICmp, m_SpecificICmp(ICmpInst::ICMP_ULT, GroupIDIntrin,
285 m_Specific(BlockCount)))) {
286 ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
287 MadeChange = true;
288 }
289 }
290 }
291
292 // All remainders should be 0 with uniform work group size.
293 for (Value *Remainder : Remainders) {
294 if (!Remainder)
295 continue;
296 Remainder->replaceAllUsesWith(
297 Constant::getNullValue(Remainder->getType()));
298 MadeChange = true;
299 }
300 } else if (HasUniformWorkGroupSize) { // Pre-V5.
301 // Pattern match the code used to handle partial workgroup dispatches in the
302 // library implementation of get_local_size, so the entire function can be
303 // constant folded with a known group size.
304 //
305 // uint r = grid_size - group_id * group_size;
306 // get_local_size = (r < group_size) ? r : group_size;
307 //
308 // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
309 // the grid_size is required to be a multiple of group_size. In this case:
310 //
311 // grid_size - (group_id * group_size) < group_size
312 // ->
313 // grid_size < group_size + (group_id * group_size)
314 //
315 // (grid_size / group_size) < 1 + group_id
316 //
317 // grid_size / group_size is at least 1, so we can conclude the select
318 // condition is false (except for group_id == 0, where the select result is
319 // the same).
320 for (int I = 0; I < 3; ++I) {
321 Value *GroupSize = GroupSizes[I];
322 Value *GridSize = GridSizes[I];
323 if (!GroupSize || !GridSize)
324 continue;
325
326 using namespace llvm::PatternMatch;
// NOTE(review): listing lines 328-330, the initializer of GroupIDIntrin, are
// missing here as well; confirm upstream.
327 auto GroupIDIntrin =
331
332 for (User *U : GroupSize->users()) {
333 auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
334 if (!ZextGroupSize)
335 continue;
336
337 for (User *UMin : ZextGroupSize->users()) {
338 if (match(UMin, m_UMin(m_Sub(m_Specific(GridSize),
339 m_Mul(GroupIDIntrin,
340 m_Specific(ZextGroupSize))),
341 m_Specific(ZextGroupSize)))) {
// The matched umin folds to the group size: the required constant when
// reqd_work_group_size is present, otherwise the zero-extended load itself.
342 if (HasReqdWorkGroupSize) {
343 ConstantInt *KnownSize =
344 mdconst::extract<ConstantInt>(MD->getOperand(I));
345 UMin->replaceAllUsesWith(ConstantFoldIntegerCast(
346 KnownSize, UMin->getType(), false, DL));
347 } else {
348 UMin->replaceAllUsesWith(ZextGroupSize);
349 }
350
351 MadeChange = true;
352 }
353 }
354 }
355 }
356 }
357
358 // Upgrade the old method of calculating the block size using the grid size.
359 // We pattern match any case where the implicit argument group size is the
360 // divisor to a dispatch packet grid size read of the same dimension.
361 if (IsV5OrAbove) {
362 for (int I = 0; I < 3; I++) {
363 Value *GroupSize = GroupSizes[I];
364 if (!GroupSize || !GroupSize->getType()->isIntegerTy(16))
365 continue;
366
367 for (User *U : GroupSize->users()) {
// NOTE(review): listing line 368, which must declare Inst, is missing here —
// presumably Inst is U cast to an Instruction; confirm upstream.
369 if (isa<ZExtInst>(Inst) && !Inst->use_empty())
370 Inst = cast<Instruction>(*Inst->user_begin());
371
372 using namespace llvm::PatternMatch;
373 if (!match(
374 Inst,
// NOTE(review): listing lines 375-376, the head of the pattern, are missing
// here. Per the comment above the loop it matches a division whose dividend
// is a dispatch-packet grid-size load of dimension I; confirm upstream.
377 m_SpecificInt(GRID_SIZE_X + I * sizeof(uint32_t))))),
378 m_Value())))
379 continue;
380
381 IRBuilder<> Builder(Inst);
382
// Replace the matched instruction with a zero-extended load of the hidden
// block count of the same dimension, read through CI.
383 Value *GEP = Builder.CreateInBoundsGEP(
384 Builder.getInt8Ty(), CI,
385 {ConstantInt::get(Type::getInt64Ty(CI->getContext()),
386 HIDDEN_BLOCK_COUNT_X + I * sizeof(uint32_t))});
387 Instruction *BlockCount = Builder.CreateLoad(Builder.getInt32Ty(), GEP);
388 BlockCount->setMetadata(LLVMContext::MD_invariant_load,
389 MDNode::get(CI->getContext(), {}));
390 BlockCount->setMetadata(LLVMContext::MD_noundef,
391 MDNode::get(CI->getContext(), {}));
392
393 Value *BlockCountExt = Builder.CreateZExt(BlockCount, Inst->getType());
394 Inst->replaceAllUsesWith(BlockCountExt);
// NOTE(review): Inst is erased while GroupSize->users() is being iterated;
// this is only safe if Inst is never itself a direct user of GroupSize —
// confirm.
395 Inst->eraseFromParent();
396 MadeChange = true;
397 }
398 }
399 }
400
401 // If reqd_work_group_size is set, we can replace work group size with it.
402 if (!HasReqdWorkGroupSize)
403 return MadeChange;
404
405 for (int I = 0; I < 3; I++) {
406 Value *GroupSize = GroupSizes[I];
407 if (!GroupSize)
408 continue;
409
410 ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
// The metadata constant is cast (unsigned) to the width of the load it
// replaces.
411 GroupSize->replaceAllUsesWith(
412 ConstantFoldIntegerCast(KnownSize, GroupSize->getType(), false, DL));
413 MadeChange = true;
414 }
415
416 return MadeChange;
417}
418
419// TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
420// TargetPassConfig for subtarget.
//
// Legacy pass manager entry point: runs processUse() once on every distinct
// call to the base-pointer intrinsic in M.
//
// \returns true if processUse() reported a change for any call.
//
// NOTE(review): listing line 424, the initializer of IsV5OrAbove, is missing
// from this copy — presumably a code-object-version check on M; confirm
// upstream.
421bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
422 bool MadeChange = false;
423 bool IsV5OrAbove =
425 Function *BasePtr = getBasePtrIntrinsic(M, IsV5OrAbove);
426
427 if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
428 return false;
429
// HandledUses ensures a call that shows up more than once in the user list is
// processed only once. The cast<> asserts that every user is a CallInst.
430 SmallPtrSet<Instruction *, 4> HandledUses;
431 for (auto *U : BasePtr->users()) {
432 CallInst *CI = cast<CallInst>(U);
433 if (HandledUses.insert(CI).second) {
434 if (processUse(CI, IsV5OrAbove))
435 MadeChange = true;
436 }
437 }
438
439 return MadeChange;
440}
441
// Pass registration via the INITIALIZE_PASS_BEGIN/END macro pair, using
// DEBUG_TYPE as the pass argument and "AMDGPU Kernel Attributes" as its name.
// No statements appear between BEGIN and END.
442INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
443 "AMDGPU Kernel Attributes", false, false)
444INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
445 "AMDGPU Kernel Attributes", false, false)
446
// Definition of the static identifier declared in the class and passed to the
// ModulePass constructor.
447char AMDGPULowerKernelAttributes::ID = 0;
448
// Factory for the legacy pass: returns a newly allocated
// AMDGPULowerKernelAttributes.
// NOTE(review): listing line 449, the function header, is missing from this
// copy. The index at the end of the listing names it
// `ModulePass * createAMDGPULowerKernelAttributesPass()`; its namespace
// qualification is not visible here — restore it from upstream.
450 return new AMDGPULowerKernelAttributes();
451}
452
// Per-function entry point; the index at the end of the listing gives its
// signature as `PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)`.
// It runs processUse() on every call in F whose callee is the base-pointer
// intrinsic.
//
// NOTE(review): listing lines 453-454 (function header), 456 (initializer of
// IsV5OrAbove) and 470-471 (the final return, which presumably derives the
// preserved set from Changed) are missing from this copy — restore them from
// upstream.
455 bool IsV5OrAbove =
457 Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);
458
459 if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
460 return PreservedAnalyses::all();
461
462 bool Changed = false;
// Unlike runOnModule(), this walks F's instructions rather than the
// intrinsic's user list, so only calls inside F are considered.
463 for (Instruction &I : instructions(F)) {
464 if (CallInst *CI = dyn_cast<CallInst>(&I)) {
465 if (CI->getCalledFunction() == BasePtr)
466 Changed |= processUse(CI, IsV5OrAbove);
467 }
468 }
469
472}
static bool annotateGridSizeLoadWithRangeMD(LoadInst *Load, uint32_t MaxNumGroups)
static bool annotateGroupSizeLoadWithRangeMD(LoadInst *Load, bool IsRemainder)
static bool processUse(CallInst *CI, bool IsV5OrAbove)
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
This file contains the declarations for the subclasses of Constant, which represent the different fla...
#define DEBUG_TYPE
Hexagon Common GEP
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
Class for arbitrary precision integers.
Definition APInt.h:78
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
This class represents a function call, abstracting a target machine's calling convention.
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
This is the shared class of boolean and integer constants.
Definition Constants.h:87
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2788
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI void setMetadata(unsigned KindID, MDNode *Node)
Set the metadata of the specified kind to the specified node.
An instruction for reading from memory.
LLVM_ABI MDNode * createRange(const APInt &Lo, const APInt &Hi)
Return metadata describing the range [Lo, Hi).
Definition MDBuilder.cpp:96
Metadata node.
Definition Metadata.h:1080
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
Definition Metadata.h:1572
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition Pass.h:255
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
user_iterator user_begin()
Definition Value.h:403
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:259
iterator_range< user_iterator > users()
Definition Value.h:427
bool use_empty() const
Definition Value.h:347
Changed
constexpr unsigned getMaxFlatWorkGroupSize()
unsigned getAMDHSACodeObjectVersion(const Module &M)
SmallVector< unsigned > getIntegerVecAttribute(const Function &F, StringRef Name, unsigned Size, unsigned DefaultVal)
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
match_combine_or< CastInst_match< OpTy, ZExtInst >, OpTy > m_ZExtOrSelf(const OpTy &Op)
bool match(Val *V, const Pattern &P)
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
auto m_GEP(const OperandTypes &...Ops)
Matches GetElementPtrInst.
SpecificCmpClass_match< LHS, RHS, ICmpInst > m_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R)
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
BinaryOp_match< LHS, RHS, Instruction::UDiv > m_UDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)
MaxMin_match< ICmpInst, LHS, RHS, umin_pred_ty > m_UMin(const LHS &L, const RHS &R)
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
Definition Metadata.h:668
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
@ Offset
Definition DWP.cpp:532
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
ModulePass * createAMDGPULowerKernelAttributesPass()
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI Constant * ConstantFoldIntegerCast(Constant *C, Type *DestTy, bool IsSigned, const DataLayout &DL)
Constant fold a zext, sext or trunc, depending on IsSigned and whether the DestTy is wider or narrowe...
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)