LLVM 23.0.0git
SPIRVISelLowering.cpp
Go to the documentation of this file.
1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
23#include "llvm/IR/IntrinsicsSPIRV.h"
24
25#define DEBUG_TYPE "spirv-lower"
26
27using namespace llvm;
28
// SPIRVTargetLowering constructor.
// NOTE(review): the listing line that opens this definition (class-qualified
// name and the TargetMachine parameter) is absent from this extract.
// The visible part forwards TM and ST to the TargetLowering base class and
// stores ST in the STI member.
30 const SPIRVSubtarget &ST)
31 : TargetLowering(TM, ST), STI(ST) {
32 // Even with SPV_ALTERA_arbitrary_precision_integers enabled, atomic sizes are
33 // limited by atomicrmw xchg operation, which only supports operand up to 64
34 // bits wide, as defined in SPIR-V legalizer. Currently, spirv-val doesn't
35 // consider 128-bit OpTypeInt as valid either.
// NOTE(review): listing lines 36-37 are absent from this extract; presumably
// they hold the calls that apply the atomic size limit described above --
// confirm against the upstream file.
38}
39
40// Returns true if the types logically match, as defined in
41// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
//
// As implemented here, two types match only when they share an opcode and an
// operand count and are either:
//  - arrays whose length operand (operand 2) is the same register and whose
//    element types are identical or themselves logically match, or
//  - structs whose member types (operands 1..N-1) are pairwise identical or
//    logically match.
// Any other kind of type yields false.
42static bool typesLogicallyMatch(const SPIRVTypeInst Ty1,
43 const SPIRVTypeInst Ty2,
// NOTE(review): listing line 44 (the trailing parameter, used below as GR, and
// the opening brace) is absent from this extract.
// Types built from different opcodes can never match.
45 if (Ty1->getOpcode() != Ty2->getOpcode())
46 return false;
47
// A differing operand count rules out a match as well.
48 if (Ty1->getNumOperands() != Ty2->getNumOperands())
49 return false;
50
51 if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
52 // Array must have the same size.
53 if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
54 return false;
55
56 SPIRVTypeInst ElemType1 =
// NOTE(review): listing line 57 (initializer of ElemType1) is absent from this
// extract.
58 SPIRVTypeInst ElemType2 =
// NOTE(review): listing line 59 (initializer of ElemType2) is absent from this
// extract.
// Identical element types match trivially; otherwise recurse.
60 return ElemType1 == ElemType2 ||
61 typesLogicallyMatch(ElemType1, ElemType2, GR);
62 }
63
64 if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
// Operand 0 is skipped; the loop visits operands 1..N-1 pairwise.
65 for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
66 SPIRVTypeInst ElemType1 =
// NOTE(review): listing line 67 (initializer of ElemType1) is absent from this
// extract.
68 SPIRVTypeInst ElemType2 =
// NOTE(review): listing line 69 (initializer of ElemType2) is absent from this
// extract.
// A single member pair that is neither identical nor a logical match fails
// the whole struct.
70 if (ElemType1 != ElemType2 &&
71 !typesLogicallyMatch(ElemType1, ElemType2, GR))
72 return false;
73 }
74 return true;
75 }
// Same-opcode types that are neither arrays nor structs are reported as not
// matching by this helper.
76 return false;
77}
78
// SPIRVTargetLowering::getNumRegistersForCallingConv (override).
// NOTE(review): the listing line that opens this definition is absent from
// this extract.
// Returns 1 for 3-element vectors of i1 or i8 and for non-vector integers of
// at most 64 bits; every other type is delegated to getNumRegisters().
// CC is not consulted in the visible body.
80 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
81 // This code avoids CallLowering fail inside getVectorTypeBreakdown
82 // on v3i1 arguments. Maybe we need to return 1 for all types.
83 // TODO: remove it once this case is supported by the default implementation.
84 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
85 (VT.getVectorElementType() == MVT::i1 ||
86 VT.getVectorElementType() == MVT::i8))
87 return 1;
// Scalar integers up to 64 bits are treated as occupying a single register.
88 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
89 return 1;
90 return getNumRegisters(Context, VT);
91}
92
// SPIRVTargetLowering::getRegisterTypeForCallingConv (override).
// NOTE(review): the listing lines that open this definition are absent from
// this extract.
// Maps 3-element i1 vectors to v4i1 and 3-element i8 vectors to v4i8; every
// other type (including other 3-element vectors) is delegated to
// getRegisterType().
95 EVT VT) const {
96 // This code avoids CallLowering fail inside getVectorTypeBreakdown
97 // on v3i1 arguments. Maybe we need to return i32 for all types.
98 // TODO: remove it once this case is supported by the default implementation.
99 if (VT.isVector() && VT.getVectorNumElements() == 3) {
100 if (VT.getVectorElementType() == MVT::i1)
101 return MVT::v4i1;
102 else if (VT.getVectorElementType() == MVT::i8)
103 return MVT::v4i8;
104 }
105 return getRegisterType(Context, VT);
106}
107
// SPIRVTargetLowering::getTgtMemIntrinsic (override).
// NOTE(review): the listing lines that open this definition (the Infos and I
// parameters used below) are absent from this extract.
// For spv_load and spv_store, appends one IntrinsicInfo entry to Infos that
// carries the alignment, the memory-operand flags and a fixed i64 memory type.
// For every other intrinsic nothing is appended.
110 MachineFunction &MF, unsigned Intrinsic) const {
111 IntrinsicInfo Info;
// Index of the alignment operand: 3 for spv_store, 2 for spv_load. The flags
// operand is read from the index just before it.
112 unsigned AlignIdx = 3;
113 switch (Intrinsic) {
114 case Intrinsic::spv_load:
115 AlignIdx = 2;
116 [[fallthrough]];
117 case Intrinsic::spv_store: {
// The alignment is only read when the call has enough operands.
// NOTE(review): Align() is documented as holding a non-zero power of two, so
// this assumes the operand is never 0 -- confirm against the intrinsic's
// producers.
118 if (I.getNumOperands() >= AlignIdx + 1) {
119 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
120 Info.align = Align(AlignOp->getZExtValue());
121 }
// NOTE(review): unlike the alignment, the flags operand is read without an
// operand-count check and is cast<> to ConstantInt unconditionally.
122 Info.flags = static_cast<MachineMemOperand::Flags>(
123 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
124 Info.memVT = MVT::i64;
125 // TODO: take into account opaque pointers (don't use getElementType).
126 // MVT::getVT(PtrTy->getElementType());
127 Infos.push_back(Info);
128 return;
129 }
130 default:
131 break;
132 }
133}
134
// Body of SPIRVTargetLowering::getConstraintType (override).
// NOTE(review): the listing lines holding the signature and opening brace are
// absent from this extract.
// Every constraint string is classified as C_RegisterClass.
137 // SPIR-V represents inline assembly via OpAsmINTEL where constraints are
138 // passed through as literals defined by client API. Return C_RegisterClass
139 // for any constraint since SPIR-V does not distinguish between register,
140 // immediate, or memory operands at this level.
141 return C_RegisterClass;
142}
143
// SPIRVTargetLowering::getRegForInlineAsmConstraint (override).
// NOTE(review): listing line 145 (qualified name and first parameter) is
// absent from this extract.
// Returns a (register number, register class) pair. The register number is
// always 0. The class is nullptr for constraints starting with "{"; otherwise
// it is chosen from VT alone:
//   floating point -> vfID (vector) or fID (scalar)
//   integer        -> vID (vector) or iID (scalar)
//   anything else  -> iID
144std::pair<unsigned, const TargetRegisterClass *>
146 StringRef Constraint,
147 MVT VT) const {
148 const TargetRegisterClass *RC = nullptr;
// Constraints written in braces get no register class.
149 if (Constraint.starts_with("{"))
150 return std::make_pair(0u, RC);
151
152 if (VT.isFloatingPoint())
153 RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
154 else if (VT.isInteger())
155 RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
156 else
157 RC = &SPIRV::iIDRegClass;
158
159 return std::make_pair(0u, RC);
160}
161
// Body of getTypeReg(MachineRegisterInfo *MRI, Register OpReg).
// NOTE(review): the listing line holding the signature and opening brace is
// absent from this extract.
// If OpReg is defined by an OpFunctionParameter, returns that instruction's
// operand 1 register; otherwise (including when no definition is found)
// returns OpReg unchanged.
163 const MachineInstr *Inst = MRI->getVRegDef(OpReg);
164 return Inst && Inst->getOpcode() == SPIRV::OpFunctionParameter
165 ? Inst->getOperand(1).getReg()
166 : OpReg;
167}
168
// doInsertBitcast: builds an OpBitcast of OpReg to NewPtrType using a
// MachineIRBuilder positioned at I, then rewrites operand OpIdx of I to use
// the bitcast result.
// NOTE(review): the listing lines that open this definition (the STI, MRI, GR
// and I parameters used below) are absent from this extract.
171 Register OpReg, unsigned OpIdx,
172 SPIRVTypeInst NewPtrType) {
173 MachineIRBuilder MIB(I);
// Fresh virtual register that will hold the bitcast result.
174 Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
175 MIB.buildInstr(SPIRV::OpBitcast)
176 .addDef(NewReg)
177 .addUse(GR.getSPIRVTypeID(NewPtrType))
178 .addUse(OpReg)
// NOTE(review): listing line 179 (start of the call whose last argument
// follows) is absent from this extract.
180 *STI.getRegBankInfo());
// Redirect the consumer to the bitcast result.
181 I.getOperand(OpIdx).setReg(NewReg);
182}
183
// createNewPtrType: returns a pointer type in the same storage class as
// OpType, whose pointee is either ResType (when ReuseType is true) or a type
// derived from the LLVM type ResTy.
// NOTE(review): the listing line that opens this definition (the GR and I
// parameters used below) is absent from this extract.
185 SPIRVTypeInst OpType, bool ReuseType,
186 SPIRVTypeInst ResType,
187 const Type *ResTy) {
// The storage class is read from operand 1 of the existing pointer type.
188 SPIRV::StorageClass::StorageClass SC =
189 static_cast<SPIRV::StorageClass::StorageClass>(
190 OpType->getOperand(1).getImm());
191 MachineIRBuilder MIB(I);
192 SPIRVTypeInst NewBaseType =
193 ReuseType ? ResType
// NOTE(review): listing line 194 (the callee of the else-branch, which takes
// the arguments on the next line) is absent from this extract.
195 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
196 return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
197}
198
199// Insert a bitcast before the instruction to keep SPIR-V code valid
200// when there is a type mismatch between results and operand types.
//
// Operand OpIdx of I is expected to be a pointer whose pointee type equals
// ResType. Nothing is done when ResType is null, when the operand's type is
// unknown or not an OpTypePointer, when its pointee type is unknown, or when
// the types already agree. ResTy is the LLVM-level counterpart of ResType and
// is only consulted when ResType belongs to a different MachineFunction
// than I.
201static void validatePtrTypes(const SPIRVSubtarget &STI,
// NOTE(review): listing line 202 (the MRI and GR parameters used below) is
// absent from this extract.
203 MachineInstr &I, unsigned OpIdx,
204 SPIRVTypeInst ResType,
205 const Type *ResTy = nullptr) {
206 // Get operand type
207 MachineFunction *MF = I.getParent()->getParent();
208 Register OpReg = I.getOperand(OpIdx).getReg();
209 Register OpTypeReg = getTypeReg(MRI, OpReg);
210 const MachineInstr *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
211 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
212 return;
213 // Get operand's pointee type
214 Register ElemTypeReg = OpType->getOperand(2).getReg();
215 SPIRVTypeInst ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
216 if (!ElemType)
217 return;
218 // Check if we need a bitcast to make a statement valid
// Within one MachineFunction the type instructions are compared directly;
// across functions the comparison falls back to the LLVM types.
219 bool IsSameMF = MF == ResType->getParent()->getParent();
220 bool IsEqualTypes = IsSameMF ? ElemType == ResType
221 : GR.getTypeForSPIRVType(ElemType) == ResTy;
222 if (IsEqualTypes)
223 return;
224 // There is a type mismatch between results and operand types
225 // and we insert a bitcast before the instruction to keep SPIR-V code valid
// ResType is reused as the new pointee only when it belongs to this function.
226 SPIRVTypeInst NewPtrType =
227 createNewPtrType(GR, I, OpType, IsSameMF, ResType, ResTy);
228 if (!GR.isBitcastCompatible(NewPtrType, OpType))
// NOTE(review): listing line 229 (the call that receives the message below) is
// absent from this extract; presumably it reports a fatal error -- confirm.
230 "insert validation bitcast: incompatible result and operand types");
231 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
232}
233
234// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
235// that doesn't point to OpTypeEvent.
// NOTE(review): listing lines 236-238 (the signature up to the I parameter)
// are absent from this extract.
239 MachineInstr &I) {
// The pointer being checked is operand 2 of I.
240 constexpr unsigned OpIdx = 2;
241 MachineFunction *MF = I.getParent()->getParent();
242 Register OpReg = I.getOperand(OpIdx).getReg();
243 Register OpTypeReg = getTypeReg(MRI, OpReg);
244 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
// Leave the operand alone when its type is unknown or not a pointer.
245 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
246 return;
247 SPIRVTypeInst ElemType =
248 GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
// Nothing to do when the pointee is unknown or already an OpTypeEvent.
249 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
250 return;
251 // Insert a bitcast before the instruction to keep SPIR-V code valid.
252 LLVMContext &Context = MF->getFunction().getContext();
// The new pointee is derived from the "spirv.Event" target extension type
// (ReuseType is false, so the null ResType argument is not used).
253 SPIRVTypeInst NewPtrType =
254 createNewPtrType(GR, I, OpType, false, nullptr,
255 TargetExtType::get(Context, "spirv.Event"));
256 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
257}
258
// Body of validateLifetimeStart(STI, MRI, GR, I).
// NOTE(review): listing lines 259-261 (the signature) are absent from this
// extract.
// Inserts a bitcast on operand 0 of I unless the pointee type of that operand
// is unknown, OpTypeVoid, or an 8-bit OpTypeInt.
262 Register PtrReg = I.getOperand(0).getReg();
263 MachineFunction *MF = I.getParent()->getParent();
264 Register PtrTypeReg = getTypeReg(MRI, PtrReg);
265 SPIRVTypeInst PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
266 SPIRVTypeInst PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
// These pointee types are accepted as they are; no bitcast is inserted.
267 if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
268 (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
269 PonteeElemType->getOperand(1).getImm() == 8))
270 return;
271 // To keep the code valid a bitcast must be inserted
272 SPIRV::StorageClass::StorageClass SC =
273 static_cast<SPIRV::StorageClass::StorageClass>(
274 PtrType->getOperand(1).getImm());
275 MachineIRBuilder MIB(I);
276 LLVMContext &Context = MF->getFunction().getContext();
277 SPIRVTypeInst NewPtrType =
// NOTE(review): listing line 278 (the initializer of NewPtrType) is absent
// from this extract; presumably it builds a pointer to an 8-bit integer in
// storage class SC using Context and MIB -- confirm against the upstream file.
279 doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
280}
281
// validatePtrUnwrapStructField: if operand OpIdx of I is a pointer to a struct
// that wraps exactly one member of vector, int, float or bool type, inserts a
// bitcast so the operand becomes a pointer to that member type (same storage
// class). In every other case the operand is left untouched.
// NOTE(review): listing lines 282-284 (the signature up to the parameters
// below) are absent from this extract.
285 MachineInstr &I, unsigned OpIdx) {
286 MachineFunction *MF = I.getParent()->getParent();
287 Register OpReg = I.getOperand(OpIdx).getReg();
288 Register OpTypeReg = getTypeReg(MRI, OpReg);
289 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
290 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
291 return;
292 SPIRVTypeInst ElemType =
293 GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
// Two operands on the struct type means a single member (operand 1).
294 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
295 ElemType->getNumOperands() != 2)
296 return;
297 // It's a structure-wrapper around another type with a single member field.
298 SPIRVTypeInst MemberType =
299 GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
300 if (!MemberType)
301 return;
// Only vector, integer, float and bool members are unwrapped.
302 unsigned MemberTypeOp = MemberType->getOpcode();
303 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
304 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
305 return;
306 // It's a structure-wrapper around a valid type. Insert a bitcast before the
307 // instruction to keep SPIR-V code valid.
308 SPIRV::StorageClass::StorageClass SC =
309 static_cast<SPIRV::StorageClass::StorageClass>(
310 OpType->getOperand(1).getImm());
311 MachineIRBuilder MIB(I);
312 SPIRVTypeInst NewPtrType =
313 GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
314 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
315}
316
317// Insert a bitcast before the function call instruction to keep SPIR-V code
318// valid when there is a type mismatch between actual and expected types of an
319// argument:
320// %formal = OpFunctionParameter %formal_type
321// ...
322// %res = OpFunctionCall %ty %fun %actual ...
323// implies that %actual is of %formal_type. With opaque pointers
324// we may need to insert a bitcast to ensure this.
//
// Walks the OpFunctionParameter instructions that follow FunDef in lock-step
// with the call operands starting at index 3, and validates each pointer
// argument against the pointee type of the matching formal parameter.
// NOTE(review): listing line 325 (return type, name and the STI parameter) is
// absent from this extract.
326 MachineRegisterInfo *DefMRI,
327 MachineRegisterInfo *CallMRI,
328 SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
329 MachineInstr *FunDef) {
330 if (FunDef->getOpcode() != SPIRV::OpFunction)
331 return;
// The first call operand that is paired with a formal parameter.
332 unsigned OpIdx = 3;
// Stops at the first non-parameter instruction or when the call runs out of
// operands, whichever comes first.
333 for (FunDef = FunDef->getNextNode();
334 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
335 OpIdx < FunCall.getNumOperands();
336 FunDef = FunDef->getNextNode(), OpIdx++) {
337 SPIRVTypeInst DefPtrType =
338 DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
// Only pointer-typed formals are checked; DefElemType stays null otherwise.
339 SPIRVTypeInst DefElemType =
340 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
341 ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
342 DefPtrType->getParent()->getParent())
343 : nullptr;
344 if (DefElemType) {
345 const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
346 // validatePtrTypes() works in the context of the call site
347 // When we process historical records about forward calls
348 // we need to switch context to the (forward) call site and
349 // then restore it back to the current machine function.
350 MachineFunction *CurMF =
351 GR.setCurrentFunc(*FunCall.getParent()->getParent());
352 validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
353 DefElemTy);
354 GR.setCurrentFunc(*CurMF);
355 }
356 }
357}
358
359// Ensure there is no mismatch between actual and expected arg types: calls
360// with a processed definition. Return Function pointer if it's a forward
361// call (ahead of definition), and nullptr otherwise.
// NOTE(review): listing lines 362 and 364 (return type, name, and the STI and
// GR parameters used below) are absent from this extract.
363 MachineRegisterInfo *CallMRI,
365 MachineInstr &FunCall) {
// The callee is taken from operand 2 of the call.
366 const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
// NOTE(review): F is null when GV is not a Function; it is still passed to
// getFunctionDefinition() and may be returned as-is -- confirm that is
// intended.
367 const Function *F = dyn_cast<Function>(GV);
368 MachineInstr *FunDef =
369 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
// No recorded definition yet: report the callee so the caller can register
// this call as a forward call.
370 if (!FunDef)
371 return F;
372 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
373 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
374 return nullptr;
375}
376
377// Ensure there is no mismatch between actual and expected arg types: calls
378// ahead of a processed definition.
// NOTE(review): listing lines 379-380 (return type, name, and the STI, DefMRI
// and GR parameters used below) are absent from this extract.
381 MachineInstr &FunDef) {
382 const Function *F = GR.getFunctionByDefinition(&FunDef);
// NOTE(review): listing line 383 (which introduces FwdCalls, presumably the
// recorded forward calls of F, and guards the loop) is absent from this
// extract.
// Each recorded call is validated against this definition, using the register
// info of the function that contains the call.
384 for (MachineInstr *FunCall : *FwdCalls) {
385 MachineRegisterInfo *CallMRI =
386 &FunCall->getParent()->getParent()->getRegInfo();
387 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
388 }
389}
390
391// Validation of an access chain.
// If the type recorded for operand 0 of I is a pointer, operand 2 of I is
// validated (and bitcast if needed) so that it points to the same pointee
// type. Nothing happens when that type is unknown or not a pointer.
// NOTE(review): listing lines 392-393 (the signature) are absent from this
// extract.
394 SPIRVTypeInst BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
395 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
396 SPIRVTypeInst BaseElemType =
397 GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
398 validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
399 }
400}
401
402// TODO: the logic of inserting additional bitcast's is to be moved
403// to pre-IRTranslation passes eventually
//
// SPIRVTargetLowering::finalizeLowering (override): walks every instruction
// of MF once and patches it so the emitted SPIR-V stays valid:
//  - inserts pointer bitcasts where an operand's pointee type disagrees with
//    the type the instruction implies,
//  - rewrites integer add/sub and bitwise ops on bool into logical ops,
//  - turns zero-valued OpConstantI of non-integer type into OpConstantNull.
// NOTE(review): listing line 404 (the signature) is absent from this extract.
405 // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
406 // We'd like to avoid the needless second processing pass.
407 if (ProcessedMF.find(&MF) != ProcessedMF.end())
408 return;
409
410 MachineRegisterInfo *MRI = &MF.getRegInfo();
411 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
412 GR.setCurrentFunc(MF);
413 for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
// NOTE(review): listing line 414 (which introduces MBB, presumably the block
// designated by I) is absent from this extract.
415 for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
416 MBBI != MBBE;) {
// The iterator is advanced before MI is processed, so instructions inserted
// in front of MI are not revisited by this loop.
417 MachineInstr &MI = *MBBI++;
418 switch (MI.getOpcode()) {
419 case SPIRV::OpAtomicLoad:
420 case SPIRV::OpAtomicExchange:
421 case SPIRV::OpAtomicCompareExchange:
422 case SPIRV::OpAtomicCompareExchangeWeak:
423 case SPIRV::OpAtomicIIncrement:
424 case SPIRV::OpAtomicIDecrement:
425 case SPIRV::OpAtomicIAdd:
426 case SPIRV::OpAtomicISub:
427 case SPIRV::OpAtomicSMin:
428 case SPIRV::OpAtomicUMin:
429 case SPIRV::OpAtomicSMax:
430 case SPIRV::OpAtomicUMax:
431 case SPIRV::OpAtomicAnd:
432 case SPIRV::OpAtomicOr:
433 case SPIRV::OpAtomicXor:
434 // for the above listed instructions
435 // OpAtomicXXX <ResType>, ptr %Op, ...
436 // implies that %Op is a pointer to <ResType>
437 case SPIRV::OpLoad:
438 // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
// NOTE(review): listing line 439 (the condition guarding the break below) is
// absent from this extract.
440 break;
441
// The pointer is operand 2; the expected pointee is the type of the result
// (operand 0).
442 validatePtrTypes(STI, MRI, GR, MI, 2,
443 GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
444 break;
445 case SPIRV::OpAtomicStore:
446 // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
447 // implies that %Op points to the <Obj>'s type
448 validatePtrTypes(STI, MRI, GR, MI, 0,
449 GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
450 break;
451 case SPIRV::OpStore:
452 // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
453 validatePtrTypes(STI, MRI, GR, MI, 0,
454 GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
455 break;
456 case SPIRV::OpPtrCastToGeneric:
457 case SPIRV::OpGenericCastToPtr:
458 case SPIRV::OpGenericCastToPtrExplicit:
459 validateAccessChain(STI, MRI, GR, MI);
460 break;
461 case SPIRV::OpPtrAccessChain:
462 case SPIRV::OpInBoundsPtrAccessChain:
// Only access chains with exactly four operands are validated.
463 if (MI.getNumOperands() == 4)
464 validateAccessChain(STI, MRI, GR, MI);
465 break;
466
467 case SPIRV::OpFunctionCall:
468 // ensure there is no mismatch between actual and expected arg types:
469 // calls with a processed definition
// Calls with no operands beyond index 2 are skipped. A non-null result means
// the callee has no recorded definition yet, so the call is registered as a
// forward call.
470 if (MI.getNumOperands() > 3)
471 if (const Function *F = validateFunCall(STI, MRI, GR, MI))
472 GR.addForwardCall(F, &MI);
473 break;
474 case SPIRV::OpFunction:
475 // ensure there is no mismatch between actual and expected arg types:
476 // calls ahead of a processed definition
477 validateForwardCalls(STI, MRI, GR, MI);
478 break;
479
480 // ensure that LLVM IR add/sub instructions result in logical SPIR-V
481 // instructions when applied to bool type
482 case SPIRV::OpIAddS:
483 case SPIRV::OpIAddV:
484 case SPIRV::OpISubS:
485 case SPIRV::OpISubV:
// Only the instruction descriptor is swapped; the operands are kept as they
// are.
486 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
487 SPIRV::OpTypeBool))
488 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
489 break;
490
491 // ensure that LLVM IR bitwise instructions result in logical SPIR-V
492 // instructions when applied to bool type
493 case SPIRV::OpBitwiseOrS:
494 case SPIRV::OpBitwiseOrV:
495 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
496 SPIRV::OpTypeBool))
497 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
498 break;
499 case SPIRV::OpBitwiseAndS:
500 case SPIRV::OpBitwiseAndV:
501 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
502 SPIRV::OpTypeBool))
503 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
504 break;
505 case SPIRV::OpBitwiseXorS:
506 case SPIRV::OpBitwiseXorV:
507 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
508 SPIRV::OpTypeBool))
509 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
510 break;
511 case SPIRV::OpLifetimeStart:
512 case SPIRV::OpLifetimeStop:
// The pointer is only validated when the immediate in operand 1 is positive.
513 if (MI.getOperand(1).getImm() > 0)
514 validateLifetimeStart(STI, MRI, GR, MI);
515 break;
516 case SPIRV::OpGroupAsyncCopy:
// Operands 3 and 4 are each unwrapped independently.
517 validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
518 validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
519 break;
520 case SPIRV::OpGroupWaitEvents:
521 // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
522 validateGroupWaitEventsPtr(STI, MRI, GR, MI);
523 break;
524 case SPIRV::OpConstantI: {
// NOTE(review): Type is dereferenced without a null check; this assumes
// operand 1 always resolves to a known SPIR-V type -- confirm.
525 SPIRVTypeInst Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
526 if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
527 MI.getOperand(2).getImm() == 0) {
528 // Validate the null constant of a target extension type
529 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
// Drop everything past operands 0 and 1, removing from the back.
530 for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
531 MI.removeOperand(i);
532 }
533 } break;
534 case SPIRV::OpExtInst: {
535 // prefetch
// Only extended instructions whose operands 2 and 3 are immediates and whose
// set is OpenCL_std are handled; `continue` moves on to the next instruction
// of the block.
536 if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
537 MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
538 continue;
539 switch (MI.getOperand(3).getImm()) {
540 case SPIRV::OpenCLExtInst::frexp:
541 case SPIRV::OpenCLExtInst::lgamma_r:
542 case SPIRV::OpenCLExtInst::remquo: {
543 // The last operand must be of a pointer to i32 or vector of i32
544 // values.
545 MachineIRBuilder MIB(MI);
546 SPIRVTypeInst Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
547 SPIRVTypeInst RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
548 assert(RetType && "Expected return type");
// A scalar result expects a pointer to i32; a vector result expects a pointer
// to a vector built from Int32Type and the immediate in RetType's operand 2.
549 validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
550 RetType->getOpcode() != SPIRV::OpTypeVector
551 ? Int32Type
// NOTE(review): listing line 552 (the callee of the else-branch, which takes
// the arguments on the next two lines) is absent from this extract.
553 Int32Type, RetType->getOperand(2).getImm(),
554 MIB, false));
555 } break;
556 case SPIRV::OpenCLExtInst::fract:
557 case SPIRV::OpenCLExtInst::modf:
558 case SPIRV::OpenCLExtInst::sincos:
559 // The last operand must be of a pointer to the base type represented
560 // by the previous operand.
561 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
562 "Expected v-reg");
// NOTE(review): listing lines 563 and 565 (the callee and the start of its
// last argument) are absent from this extract.
564 STI, MRI, GR, MI, MI.getNumOperands() - 1,
566 MI.getOperand(MI.getNumOperands() - 2).getReg()));
567 break;
568 case SPIRV::OpenCLExtInst::prefetch:
569 // Expected `ptr` type is a pointer to float, integer or vector, but
570 // the pointee value can be wrapped into a struct.
571 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
572 "Expected v-reg");
573 validatePtrUnwrapStructField(STI, MRI, GR, MI,
574 MI.getNumOperands() - 2);
575 break;
576 }
577 } break;
578 }
579 }
580 }
// Remember MF so a second invocation returns early (see the check at the
// top).
581 ProcessedMF.insert(&MF);
// NOTE(review): listing line 582 is absent from this extract; presumably it
// forwards to the base-class finalizeLowering -- confirm.
583}
584
585// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
586// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
587// match or if the instruction was modified to make them match.
//
// As implemented, the only modification performed is inserting an
// OpCopyLogical on the result when OpIdx is a def and the two types logically
// match. A logical match on a non-def operand hits llvm_unreachable. Types
// that do not logically match yield false.
// NOTE(review): listing line 588 (return type and qualified name) is absent
// from this extract.
589 MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
590 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
591 SPIRVTypeInst PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
592 SPIRVTypeInst PointeeType = GR.getPointeeType(PtrType);
593 SPIRVTypeInst OpType = GR.getResultType(I.getOperand(OpIdx).getReg());
594
// Identical type instructions need no fix-up.
595 if (PointeeType == OpType)
596 return true;
597
// NOTE(review): typesLogicallyMatch() dereferences both arguments, so this
// assumes PointeeType and OpType are non-null here -- confirm.
598 if (typesLogicallyMatch(PointeeType, OpType, GR)) {
599 // Apply OpCopyLogical to OpIdx.
600 if (I.getOperand(OpIdx).isDef() &&
601 insertLogicalCopyOnResult(I, PointeeType)) {
602 return true;
603 }
604
// The copy is only implemented for results so far. The return below follows
// llvm_unreachable and is never expected to execute.
605 llvm_unreachable("Unable to add OpCopyLogical yet.");
606 return false;
607 }
608
609 return false;
610}
611
// SPIRVTargetLowering::insertLogicalCopyOnResult: retargets I so that it
// defines a fresh register of NewResultType, then builds an OpCopyLogical
// that converts the new result back into I's original result register and
// type, so existing users of the old register are unaffected. Always returns
// true.
// NOTE(review): listing line 612 (return type and qualified name) is absent
// from this extract.
613 MachineInstr &I, SPIRVTypeInst NewResultType) const {
614 MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
615 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
616
617 Register NewResultReg =
618 createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
619 Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);
620
621 assert(llvm::size(I.defs()) == 1 && "Expected only one def");
622 MachineOperand &OldResult = *I.defs().begin();
623 Register OldResultReg = OldResult.getReg();
// The first use operand of I is treated as its result-type operand.
624 MachineOperand &OldType = *I.uses().begin();
625 Register OldTypeReg = OldType.getReg();
626
// Retarget I to produce the new register with the new type.
627 OldResult.setReg(NewResultReg);
628 OldType.setReg(NewTypeReg);
629
// NOTE(review): getNextNode() is documented as returning nullptr for the list
// tail, so this dereference assumes I is never the last instruction of its
// block -- confirm.
630 MachineIRBuilder MIB(*I.getNextNode());
631 MIB.buildInstr(SPIRV::OpCopyLogical)
632 .addDef(OldResultReg)
633 .addUse(OldTypeReg)
634 .addUse(NewResultReg)
635 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
636 *STI.getRegBankInfo());
637 return true;
638}
639
655
658 // TODO: Pointer operand should be cast to integer in atomicrmw xchg, since
659 // SPIR-V only supports atomic exchange for integer and floating-point types.
661}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator MBBI
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Register const TargetRegisterInfo * TRI
MachineInstr unsigned OpIdx
static bool typesLogicallyMatch(const SPIRVTypeInst Ty1, const SPIRVTypeInst Ty2, SPIRVGlobalRegistry &GR)
static void validateLifetimeStart(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVTypeInst ResType, const Type *ResTy=nullptr)
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx)
Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg)
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)
void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)
const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, Register OpReg, unsigned OpIdx, SPIRVTypeInst NewPtrType)
static SPIRVTypeInst createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, SPIRVTypeInst OpType, bool ReuseType, SPIRVTypeInst ResType, const Type *ResTy)
This file describes how to lower LLVM code to machine code.
an instruction that atomically reads a memory location, combines it with another value,...
@ FAdd
*p = old + v
@ FSub
*p = old - v
@ UIncWrap
Increment one up to a maximum value.
@ FMin
*p = minnum(old, v) minnum matches the behavior of llvm.minnum.
@ FMax
*p = maxnum(old, v) maxnum matches the behavior of llvm.maxnum.
@ UDecWrap
Decrement one until a minimum value or zero.
BinOp getOperation() const
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
Machine Value Type.
bool isVector() const
Return true if this is a vector value type.
bool isInteger() const
Return true if this is an integer or a vector integer type.
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineInstrBundleIterator< MachineInstr > iterator
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
BasicBlockListType::iterator iterator
void insert(iterator MBBI, MachineBasicBlock *MBB)
Helper class to build MachineInstr.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
void constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const
const MachineInstrBuilder & addUse(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Returns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
Flags
Flags values. These may be or'd together.
MachineOperand class - Representation of each machine instruction operand.
const GlobalValue * getGlobal() const
int64_t getImm() const
LLVM_ABI void setReg(Register Reg)
Change the register this operand corresponds to.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
void addForwardCall(const Function *F, MachineInstr *MI)
SPIRVTypeInst getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)
SPIRVTypeInst getOrCreateSPIRVVectorType(SPIRVTypeInst BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder, bool EmitIR)
SPIRVTypeInst getResultType(Register VReg, MachineFunction *MF=nullptr)
const Type * getTypeForSPIRVType(SPIRVTypeInst Ty) const
bool isBitcastCompatible(SPIRVTypeInst Type1, SPIRVTypeInst Type2) const
const MachineInstr * getFunctionDefinition(const Function *F)
SPIRVTypeInst getOrCreateSPIRVPointerType(const Type *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC)
Register getSPIRVTypeID(SPIRVTypeInst SpirvType) const
SPIRVTypeInst getPointeeType(SPIRVTypeInst PtrType)
SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)
SPIRVTypeInst getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const
MachineFunction * setCurrentFunc(MachineFunction &MF)
SPIRVTypeInst getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
const Function * getFunctionByDefinition(const MachineInstr *MI)
const SPIRVInstrInfo * getInstrInfo() const override
const SPIRVRegisterInfo * getRegisterInfo() const override
const RegisterBankInfo * getRegBankInfo() const override
AtomicExpansionKind shouldCastAtomicRMWIInIR(AtomicRMWInst *RMWI) const override
Returns how the given atomic atomicrmw should be cast by the IR-level AtomicExpand pass.
bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx, unsigned OpIdx) const
unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const override
Return the number of registers that this ValueType will eventually require.
unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain targets require unusual breakdowns of certain types.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
AtomicExpansionKind shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const override
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
void finalizeLowering(MachineFunction &MF) const override
Execute target specific actions to finalize target lowering.
void getTgtMemIntrinsic(SmallVectorImpl< IntrinsicInfo > &Infos, const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
bool insertLogicalCopyOnResult(MachineInstr &I, SPIRVTypeInst NewResultType) const
SPIRVTargetLowering(const TargetMachine &TM, const SPIRVSubtarget &ST)
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
ConstraintType getConstraintType(StringRef Constraint) const override
Given a constraint, return the type of constraint it is for this target.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:258
static LLVM_ABI TargetExtType * get(LLVMContext &Context, StringRef Name, ArrayRef< Type * > Types={}, ArrayRef< unsigned > Ints={})
Return a target extension type having the specified name and optional type and integer parameters.
Definition Type.cpp:907
virtual void finalizeLowering(MachineFunction &MF) const
Execute target specific actions to finalize target lowering.
virtual AtomicExpansionKind shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
void setMaxAtomicSizeInBitsSupported(unsigned SizeInBits)
Set the maximum atomic operation size supported by the backend.
void setMinCmpXchgSizeInBits(unsigned SizeInBits)
Sets the minimum cmpxchg or ll/sc size supported by the backend.
AtomicExpansionKind
Enum that specifies what an atomic load/AtomicRMWInst is expanded to, if at all.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
TargetLowering(const TargetLowering &)=delete
Primary interface to the complete machine description for the target machine.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
This namespace contains an enum with a value for every intrinsic/builtin function known by LLVM.
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1669
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
Register createVirtualRegister(SPIRVTypeInst SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:163
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
Extended Value Type.
Definition ValueTypes.h:35
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition ValueTypes.h:381
bool isVector() const
Return true if this is a vector value type.
Definition ValueTypes.h:176
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition ValueTypes.h:336
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition ValueTypes.h:344
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition ValueTypes.h:160