LLVM 23.0.0git
NVPTXISelLowering.cpp
Go to the documentation of this file.
1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
16#include "NVPTX.h"
17#include "NVPTXISelDAGToDAG.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
22#include "NVPTXUtilities.h"
23#include "NVVMProperties.h"
24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/StringRef.h"
41#include "llvm/IR/Argument.h"
42#include "llvm/IR/Attributes.h"
43#include "llvm/IR/Constants.h"
44#include "llvm/IR/DataLayout.h"
47#include "llvm/IR/FPEnv.h"
48#include "llvm/IR/Function.h"
49#include "llvm/IR/GlobalValue.h"
50#include "llvm/IR/IRBuilder.h"
51#include "llvm/IR/Instruction.h"
53#include "llvm/IR/IntrinsicsNVPTX.h"
54#include "llvm/IR/Module.h"
55#include "llvm/IR/Type.h"
56#include "llvm/IR/Value.h"
68#include <algorithm>
69#include <cassert>
70#include <cmath>
71#include <cstdint>
72#include <iterator>
73#include <optional>
74#include <string>
75#include <tuple>
76#include <utility>
77#include <vector>
78
79#define DEBUG_TYPE "nvptx-lower"
80
81using namespace llvm;
82
84 "nvptx-sched4reg",
85 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
86
88 "nvptx-fma-level", cl::Hidden,
89 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
90 " 1: do it 2: do it aggressively"),
91 cl::init(2));
92
94 "nvptx-prec-divf32", cl::Hidden,
96 "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
98 clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
99 clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
101 "Use IEEE Compliant F32 div.rnd if available (default)"),
103 "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
105
107 "nvptx-prec-sqrtf32", cl::Hidden,
108 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
109 cl::init(true));
110
111// PTX atom.add.f32 has fixed FTZ behavior that may not match the function's
112// (see shouldExpandAtomicRMWInIR), so by default we fall back to a CAS loop
113// when they disagree. This flag is an escape hatch to use atom.add anyway,
114// trading correct denormal handling for the speed of the native instruction.
116 "nvptx-allow-ftz-atomics", cl::Hidden,
117 cl::desc("NVPTX Specific: Lower atomicrmw fadd to atom.add even when its "
118 "FTZ behavior does not match the function's denormal mode."),
119 cl::init(false));
120
121/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
122/// does NOT use lg2.approx for log2, so this is disabled by default.
124 "nvptx-approx-log2f32",
125 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
126 cl::init(false));
127
130 const SDNode &N) const {
131 // If nvptx-prec-div32=N is used on the command-line, always honor it
132 if (UsePrecDivF32.getNumOccurrences() > 0)
133 return UsePrecDivF32;
134
135 const SDNodeFlags Flags = N.getFlags();
136 if (Flags.hasApproximateFuncs())
138
140}
141
143 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
144 if (UsePrecSqrtF32.getNumOccurrences() > 0)
145 return UsePrecSqrtF32;
146
147 if (N) {
148 const SDNodeFlags Flags = N->getFlags();
149 if (Flags.hasApproximateFuncs())
150 return false;
151 }
152
153 return true;
154}
155
160
161static bool IsPTXVectorType(MVT VT) {
162 switch (VT.SimpleTy) {
163 default:
164 return false;
165 case MVT::v2i1:
166 case MVT::v4i1:
167 case MVT::v2i8:
168 case MVT::v4i8:
169 case MVT::v8i8: // <2 x i8x4>
170 case MVT::v16i8: // <4 x i8x4>
171 case MVT::v2i16:
172 case MVT::v4i16:
173 case MVT::v8i16: // <4 x i16x2>
174 case MVT::v2i32:
175 case MVT::v4i32:
176 case MVT::v2i64:
177 case MVT::v2f16:
178 case MVT::v4f16:
179 case MVT::v8f16: // <4 x f16x2>
180 case MVT::v2bf16:
181 case MVT::v4bf16:
182 case MVT::v8bf16: // <4 x bf16x2>
183 case MVT::v2f32:
184 case MVT::v4f32:
185 case MVT::v2f64:
186 case MVT::v4i64:
187 case MVT::v4f64:
188 case MVT::v8i32:
189 case MVT::v8f32:
190 case MVT::v16f16: // <8 x f16x2>
191 case MVT::v16bf16: // <8 x bf16x2>
192 case MVT::v16i16: // <8 x i16x2>
193 case MVT::v32i8: // <8 x i8x4>
194 return true;
195 }
196}
197
// When legalizing vector loads/stores, this function is called, which does two
// things:
// 1. Determines whether the vector is something we want to custom lower;
//    std::nullopt is returned if we do not want to custom lower it.
// 2. If we do want to handle it, returns a pair of:
//    - unsigned int NumElts - The number of elements in the final vector
//    - MVT EltVT - The type of the elements in the final vector
205static std::optional<std::pair<unsigned int, MVT>>
207 unsigned AddressSpace) {
208 const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
209
210 if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
211 VectorEVT.getSizeInBits() == 256)
212 return {{4, MVT::i64}};
213
214 if (!VectorEVT.isSimple())
215 return std::nullopt;
216 const MVT VectorVT = VectorEVT.getSimpleVT();
217
218 if (!VectorVT.isVector()) {
219 if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
220 return {{2, MVT::i64}};
221 return std::nullopt;
222 }
223
224 const MVT EltVT = VectorVT.getVectorElementType();
225 const unsigned NumElts = VectorVT.getVectorNumElements();
226
227 // The size of the PTX virtual register that holds a packed type.
228 unsigned PackRegSize;
229
230 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
231 // legal. We can (and should) split that into 2 stores of <2 x double> here
232 // but I'm leaving that as a TODO for now.
233 switch (VectorVT.SimpleTy) {
234 default:
235 return std::nullopt;
236
237 case MVT::v4i64:
238 case MVT::v4f64:
239 // This is a "native" vector type iff the address space is global and the
240 // target supports 256-bit loads/stores
241 if (!CanLowerTo256Bit)
242 return std::nullopt;
243 [[fallthrough]];
244 case MVT::v2i8:
245 case MVT::v2i64:
246 case MVT::v2f64:
247 // This is a "native" vector type
248 return std::pair(NumElts, EltVT);
249
250 case MVT::v16f16: // <8 x f16x2>
251 case MVT::v16bf16: // <8 x bf16x2>
252 case MVT::v16i16: // <8 x i16x2>
253 case MVT::v32i8: // <8 x i8x4>
254 // This can be upsized into a "native" vector type iff the address space is
255 // global and the target supports 256-bit loads/stores.
256 if (!CanLowerTo256Bit)
257 return std::nullopt;
258 [[fallthrough]];
259 case MVT::v2i16: // <1 x i16x2>
260 case MVT::v2f16: // <1 x f16x2>
261 case MVT::v2bf16: // <1 x bf16x2>
262 case MVT::v4i8: // <1 x i8x4>
263 case MVT::v4i16: // <2 x i16x2>
264 case MVT::v4f16: // <2 x f16x2>
265 case MVT::v4bf16: // <2 x bf16x2>
266 case MVT::v8i8: // <2 x i8x4>
267 case MVT::v8f16: // <4 x f16x2>
268 case MVT::v8bf16: // <4 x bf16x2>
269 case MVT::v8i16: // <4 x i16x2>
270 case MVT::v16i8: // <4 x i8x4>
271 PackRegSize = 32;
272 break;
273
274 case MVT::v8f32: // <4 x f32x2>
275 case MVT::v8i32: // <4 x i32x2>
276 // This is a "native" vector type iff the address space is global and the
277 // target supports 256-bit loads/stores
278 if (!CanLowerTo256Bit)
279 return std::nullopt;
280 [[fallthrough]];
281 case MVT::v2f32: // <1 x f32x2>
282 case MVT::v4f32: // <2 x f32x2>
283 case MVT::v2i32: // <1 x i32x2>
284 case MVT::v4i32: // <2 x i32x2>
285 if (!STI.hasF32x2Instructions())
286 return std::pair(NumElts, EltVT);
287 PackRegSize = 64;
288 break;
289 }
290
291 // If we reach here, then we can pack 2 or more elements into a single 32-bit
292 // or 64-bit PTX register and treat the vector as a new vector containing
293 // packed elements.
294
295 // Number of elements to pack in one word.
296 const unsigned NPerReg = PackRegSize / EltVT.getSizeInBits();
297
298 return std::pair(NumElts / NPerReg, MVT::getVectorVT(EltVT, NPerReg));
299}
300
301/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
302/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
303/// the types as required by the calling convention (with special handling for
304/// i8s).
305/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
306/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
307/// LowerCall, and LowerReturn.
308static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
309 LLVMContext &Ctx, CallingConv::ID CallConv,
310 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
312 uint64_t StartingOffset = 0) {
313 SmallVector<EVT, 16> TempVTs;
314 SmallVector<uint64_t, 16> TempOffsets;
315 ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets,
316 StartingOffset);
317
318 for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
319 MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
320 unsigned NumRegs = TLI.getNumRegistersForCallingConv(Ctx, CallConv, VT);
321
322 // Since we actually can load/store b8, we need to ensure that we'll use
323 // the original sized type for any i8s or i8 vectors.
324 if (VT.getScalarType() == MVT::i8) {
325 if (RegisterVT == MVT::i16)
326 RegisterVT = MVT::i8;
327 else if (RegisterVT == MVT::v2i16)
328 RegisterVT = MVT::v2i8;
329 else
330 assert(RegisterVT == MVT::v4i8 &&
331 "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
332 }
333
334 // TODO: This is horribly incorrect for cases where the vector elements are
335 // not a multiple of bytes (ex i1) and legal or i8. However, this problem
336 // has existed for as long as NVPTX has and no one has complained, so we'll
337 // leave it for now.
338 for (unsigned I : seq(NumRegs)) {
339 ValueVTs.push_back(RegisterVT);
340 Offsets.push_back(Off + I * RegisterVT.getStoreSize());
341 }
342 }
343}
344
345// We return an EVT that can hold N VTs
346// If the VT is a vector, the resulting EVT is a flat vector with the same
347// element type as VT's element type.
348static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
349 if (N == 1)
350 return VT;
351
352 return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(),
353 VT.getVectorNumElements() * N)
354 : EVT::getVectorVT(C, VT, N);
355}
356
358 const SDLoc &dl, SelectionDAG &DAG) {
359 if (V.getValueType() == VT) {
360 assert(I == 0 && "Index must be 0 for scalar value");
361 return V;
362 }
363
364 if (!VT.isVector())
365 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V,
366 DAG.getVectorIdxConstant(I, dl));
367
368 return DAG.getNode(
369 ISD::EXTRACT_SUBVECTOR, dl, VT, V,
371}
372
373template <typename T>
374static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
375 SelectionDAG &DAG, T GetElement) {
376 if (N == 1)
377 return GetElement(0);
378
380 for (const unsigned I : llvm::seq(N)) {
381 SDValue Val = GetElement(I);
382 if (Val.getValueType().isVector())
383 DAG.ExtractVectorElements(Val, Values);
384 else
385 Values.push_back(Val);
386 }
387
388 EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(),
389 Values.size());
390 return DAG.getBuildVector(VT, dl, Values);
391}
392
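/// PromoteScalarIntegerPTX
/// Used to make sure the arguments/returns are suitable for passing,
/// promoting them to a larger size if they're not.
///
/// Returns the promoted type for a scalar integer \p VT (its bit width rounded
/// up to i1, i8, i16, i32 or i64), or \p VT unchanged if it is not a scalar
/// integer. Scalar integers wider than 64 bits are not supported.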
399 if (VT.isScalarInteger()) {
400 switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
401 default:
403 "Promotion is not suitable for scalars of size larger than 64-bits");
404 case 1:
405 return MVT::i1;
406 case 2:
407 case 4:
408 case 8:
409 return MVT::i8;
410 case 16:
411 return MVT::i16;
412 case 32:
413 return MVT::i32;
414 case 64:
415 return MVT::i64;
416 }
417 }
418 return VT;
419}
420
421// Check whether we can merge loads/stores of some of the pieces of a
422// flattened function parameter or return value into a single vector
423// load/store.
424//
425// The flattened parameter is represented as a list of EVTs and
426// offsets, and the whole structure is aligned to ParamAlignment. This
427// function determines whether we can load/store pieces of the
428// parameter starting at index Idx using a single vectorized op of
429// size AccessSize. If so, it returns the number of param pieces
430// covered by the vector op. Otherwise, it returns 1.
431template <typename T>
433 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
434 const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
435
436 // Can't vectorize if param alignment is not sufficient.
437 if (ParamAlignment < AccessSize)
438 return 1;
439 // Can't vectorize if offset is not aligned.
440 if (Offsets[Idx] & (AccessSize - 1))
441 return 1;
442
443 EVT EltVT = ValueVTs[Idx];
444 unsigned EltSize = EltVT.getStoreSize();
445
446 // Element is too large to vectorize.
447 if (EltSize >= AccessSize)
448 return 1;
449
450 unsigned NumElts = AccessSize / EltSize;
451 // Can't vectorize if AccessBytes if not a multiple of EltSize.
452 if (AccessSize != EltSize * NumElts)
453 return 1;
454
455 // We don't have enough elements to vectorize.
456 if (Idx + NumElts > ValueVTs.size())
457 return 1;
458
459 // PTX ISA can only deal with 2- and 4-element vector ops.
460 if (NumElts != 4 && NumElts != 2)
461 return 1;
462
463 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
464 // Types do not match.
465 if (ValueVTs[j] != EltVT)
466 return 1;
467
468 // Elements are not contiguous.
469 if (Offsets[j] - Offsets[j - 1] != EltSize)
470 return 1;
471 }
472 // OK. We can vectorize ValueVTs[i..i+NumElts)
473 return NumElts;
474}
475
476// Computes whether and how we can vectorize the loads/stores of a
477// flattened function parameter or return value.
478//
479// The flattened parameter is represented as the list of ValueVTs and
480// Offsets, and is aligned to ParamAlignment bytes. We return a vector
481// of the same size as ValueVTs indicating how each piece should be
482// loaded/stored (i.e. as a scalar, or as part of a vector
483// load/store).
484template <typename T>
487 const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
488 bool IsVAArg = false) {
489 // Set vector size to match ValueVTs and mark all elements as
490 // scalars by default.
491
492 if (IsVAArg)
493 return SmallVector<unsigned>(ValueVTs.size(), 1);
494
495 SmallVector<unsigned, 16> VectorInfo;
496
497 const auto GetNumElts = [&](unsigned I) -> unsigned {
498 for (const unsigned AccessSize : {16, 8, 4, 2}) {
499 const unsigned NumElts = canMergeParamLoadStoresStartingAt(
500 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
501 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
502 "Unexpected vectorization size");
503 if (NumElts != 1)
504 return NumElts;
505 }
506 return 1;
507 };
508
509 // Check what we can vectorize using 128/64/32-bit accesses.
510 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
511 const unsigned NumElts = GetNumElts(I);
512 VectorInfo.push_back(NumElts);
513 I += NumElts;
514 }
515 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
516 ValueVTs.size());
517 return VectorInfo;
518}
519
520// NVPTXTargetLowering Constructor.
522 const NVPTXSubtarget &STI)
523 : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
524 // always lower memset, memcpy, and memmove intrinsics to load/store
525 // instructions, rather
526 // then generating calls to memset, mempcy or memmove.
530
533
534 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
535 // condition branches.
536 setJumpIsExpensive(true);
537
538 // Wide divides are _very_ slow. Try to reduce the width of the divide if
539 // possible.
540 addBypassSlowDiv(64, 32);
541
542 // By default, use the Source scheduling
543 if (sched4reg)
545 else
547
548 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
549 LegalizeAction NoF16Action) {
550 bool IsOpSupported = STI.allowFP16Math();
551 switch (Op) {
552 // Several FP16 instructions are available on sm_80 only.
553 case ISD::FMINNUM:
554 case ISD::FMAXNUM:
557 case ISD::FMAXIMUM:
558 case ISD::FMINIMUM:
559 case ISD::FMAXIMUMNUM:
560 case ISD::FMINIMUMNUM:
561 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
562 break;
563 case ISD::FEXP2:
564 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
565 break;
566 }
567 setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
568 };
569
570 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
571 LegalizeAction NoBF16Action) {
572 bool IsOpSupported = STI.hasNativeBF16Support(Op);
574 Op, VT, IsOpSupported ? Action : NoBF16Action);
575 };
576
577 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
578 LegalizeAction NoI16x2Action) {
579 bool IsOpSupported = false;
580 // instructions are available on sm_90 only
581 switch (Op) {
582 case ISD::ADD:
583 case ISD::SMAX:
584 case ISD::SMIN:
585 case ISD::UMIN:
586 case ISD::UMAX:
587 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
588 break;
589 }
590 setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
591 };
592
593 addRegisterClass(MVT::i1, &NVPTX::B1RegClass);
594 addRegisterClass(MVT::i16, &NVPTX::B16RegClass);
595 addRegisterClass(MVT::v2i16, &NVPTX::B32RegClass);
596 addRegisterClass(MVT::v4i8, &NVPTX::B32RegClass);
597 addRegisterClass(MVT::i32, &NVPTX::B32RegClass);
598 addRegisterClass(MVT::i64, &NVPTX::B64RegClass);
599 addRegisterClass(MVT::f32, &NVPTX::B32RegClass);
600 addRegisterClass(MVT::f64, &NVPTX::B64RegClass);
601 addRegisterClass(MVT::f16, &NVPTX::B16RegClass);
602 addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
603 addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
604 addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
605
606 if (STI.hasF32x2Instructions()) {
607 addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
608 addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
609 }
610
611 // Conversion to/from FP16/FP16x2 is always legal.
616
618 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
620
621 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
622 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
623
624 // Conversion to/from BFP16/BFP16x2 is always legal.
629
630 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
631 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
632 if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
633 AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
634
635 // Conversion to/from i16/i16x2 is always legal.
640
645
646 // No support for these operations with v2f32/v2i32
647 setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
648 setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
649
652 MVT::v2i32, Expand);
653
654 // Need custom lowering in case the index is dynamic.
655 if (STI.hasF32x2Instructions())
656 setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
657 Custom);
658
659 // Custom conversions to/from v2i8.
661
662 // Only logical ops can be done on v4i8/v2i32 directly, others must be done
663 // elementwise.
680 {MVT::v4i8, MVT::v2i32}, Expand);
681
682 // Operations not directly supported by NVPTX.
683 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
684 MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
685 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
688 }
689
690 // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
691 setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
692
693 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
694 // For others we will expand to a SHL/SRA pair.
700 setOperationAction(ISD::SIGN_EXTEND_INREG, {MVT::v2i16, MVT::v2i32}, Expand);
701
708
711
713 {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
714 Expand);
715
716 if (STI.hasHWROT32()) {
719 Custom);
720 }
721
722 setOperationAction(ISD::BR_JT, MVT::Other, STI.hasBrx() ? Legal : Expand);
724
725 // We want to legalize constant related memmove and memcopy
726 // intrinsics.
728
729 // FP extload/truncstore is not legal in PTX. We need to expand all these.
730 for (auto FloatVTs :
732 for (MVT ValVT : FloatVTs) {
733 for (MVT MemVT : FloatVTs) {
734 setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Expand);
735 setTruncStoreAction(ValVT, MemVT, Expand);
736 }
737 }
738 }
739
740 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
741 // how they'll be lowered in ISel anyway, and by doing this a little earlier
742 // we allow for more DAG combine opportunities.
743 for (auto IntVTs :
745 for (MVT ValVT : IntVTs)
746 for (MVT MemVT : IntVTs)
747 if (isTypeLegal(ValVT))
748 setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Custom);
749
750 // PTX does not support load / store predicate registers
752 for (MVT VT : MVT::integer_valuetypes()) {
754 Promote);
755 setTruncStoreAction(VT, MVT::i1, Expand);
756 }
757
758 // Disable generations of extload/truncstore for v2i32/v2i16/v2i8. The generic
759 // expansion for these nodes when they are unaligned is incorrect if the
760 // type is a vector.
761 //
762 // TODO: Fix the generic expansion for these nodes found in
763 // TargetLowering::expandUnalignedLoad/Store.
765 MVT::v2i8, Expand);
767 {MVT::v2i8, MVT::v2i16}, Expand);
768 setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
769 setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
770 setTruncStoreAction(MVT::v2i32, MVT::v2i8, Expand);
771
772 // Register custom handling for illegal type loads/stores. We'll try to custom
773 // lower almost all illegal types and logic in the lowering will discard cases
774 // we can't handle.
775 setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::i256, MVT::f128},
776 Custom);
778 if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
780 Custom);
781
782 // Custom legalization for LDU intrinsics.
783 // TODO: The logic to lower these is not very robust and we should rewrite it.
784 // Perhaps LDU should not be represented as an intrinsic at all.
787 if (IsPTXVectorType(VT))
789
793 MVT::i1, Expand);
794
795 // This is legal in NVPTX
800
801 setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
803
804 // TRAP can be lowered to PTX trap
805 setOperationAction(ISD::TRAP, MVT::Other, Legal);
806 // DEBUGTRAP can be lowered to PTX brkpt
808
809 // Support varargs.
814
816 {MVT::i16, MVT::i32, MVT::i64}, Legal);
817 // PTX abs.s is undefined for INT_MIN, so ISD::ABS (which requires
818 // abs(INT_MIN) == INT_MIN) must be expanded. ABS_MIN_POISON matches
819 // PTX abs semantics since INT_MIN input is poison/undefined.
820 setOperationAction(ISD::ABS, {MVT::i16, MVT::i32, MVT::i64}, Expand);
821 setOperationAction(ISD::ABS_MIN_POISON, {MVT::i16, MVT::i32, MVT::i64},
822 Legal);
823
825 Promote);
828
829 setI16x2OperationAction(ISD::ABS_MIN_POISON, MVT::v2i16, Legal, Custom);
830 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
831 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
832 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
833 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
834 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
835 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
836
837 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
838 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
839 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
840 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
841 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
842 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
843
844 // Other arithmetic and logic ops are unsupported.
848 {MVT::v2i16, MVT::v2i32}, Expand);
849
850 // v2i32 is not supported for any arithmetic operations
855 MVT::v2i32, Expand);
856
861 if (STI.getPTXVersion() >= 43) {
866 }
867
869 setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
872
873 // PTX does not directly support SELP of i1, so promote to i32 first
875
876 // PTX cannot multiply two i64s in a single instruction.
879
880 // We have some custom DAG combine patterns for these nodes
882 ISD::AND,
884 ISD::FADD,
891 ISD::MUL,
893 ISD::SHL,
894 ISD::SREM,
895 ISD::UREM,
899 ISD::LOAD,
904
905 // If the vector operands require register coalescing, scalarize instead
906 if (STI.hasF32x2Instructions())
908
909 // setcc for f16x2 and bf16x2 needs special handling to prevent
910 // legalizer's attempt to scalarize it due to v2i1 not being legal.
911 if (STI.allowFP16Math() || STI.hasBF16Math())
913
914 // Vector reduction operations. These may be turned into shuffle or tree
915 // reductions depending on what instructions are available for each type.
917 MVT EltVT = VT.getVectorElementType();
918 if (EltVT == MVT::f32 || EltVT == MVT::f64) {
921 VT, Custom);
922 }
923 }
924
925 // Promote fp16 arithmetic if fp16 hardware isn't available or the
926 // user passed --nvptx-no-fp16-math. The flag is useful because,
927 // although sm_53+ GPUs have some sort of FP16 support in
928 // hardware, only sm_53 and sm_60 have full implementation. Others
929 // only have token amount of hardware and are likely to run faster
930 // by using fp32 units instead.
931 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
932 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
933 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
934 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
935 // bf16 must be promoted to f32.
936 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
937 if (getOperationAction(Op, MVT::bf16) == Promote)
938 AddPromotedToType(Op, MVT::bf16, MVT::f32);
939 setOperationAction(Op, MVT::v2f32,
940 STI.hasF32x2Instructions() ? Legal : Expand);
941 }
942
943 // On SM80, we select add/mul/sub as fma to avoid promotion to float
944 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
945 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
946 if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
948 }
949 }
950 }
951
952 // f16/f16x2 neg was introduced in PTX 60, SM_53.
953 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
954 STI.getPTXVersion() >= 60 &&
955 STI.allowFP16Math();
956 for (const auto &VT : {MVT::f16, MVT::v2f16})
958 IsFP16FP16x2NegAvailable ? Legal : Expand);
959
960 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
961 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
962 setOperationAction(ISD::FNEG, MVT::v2f32, Expand);
963 // (would be) Library functions.
964
965 // These map to conversion instructions for scalar FP types.
966 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
968 setOperationAction(Op, MVT::f16, Legal);
969 setOperationAction(Op, MVT::f32, Legal);
970 setOperationAction(Op, MVT::f64, Legal);
971 setOperationAction(Op, MVT::v2f16, Expand);
972 setOperationAction(Op, MVT::v2bf16, Expand);
973 setOperationAction(Op, MVT::v2f32, Expand);
974 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
975 if (getOperationAction(Op, MVT::bf16) == Promote)
976 AddPromotedToType(Op, MVT::bf16, MVT::f32);
977 }
978
979 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
981 }
982 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
983 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
986 }
987 }
988
989 // Expand v2f32 = fp_extend
991 // Expand v2[b]f16 = fp_round v2f32
992 setOperationAction(ISD::FP_ROUND, {MVT::v2bf16, MVT::v2f16}, Expand);
993
994 // sm_80 only has conversions between f32 and bf16. Custom lower all other
995 // bf16 conversions.
996 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
997 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
1000 VT, Custom);
1001 }
1004 MVT::bf16, Custom);
1005 }
1006
1010 setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
1014 AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
1015
1016 // 'Expand' implements FCOPYSIGN without calling an external library.
1023
1024 // These map to corresponding instructions for f32/f64. f16 must be
1025 // promoted to f32. v2f16 is expanded to f16, which is then promoted
1026 // to f32.
1027 for (const auto &Op :
1029 setOperationAction(Op, MVT::f16, Promote);
1030 setOperationAction(Op, MVT::f32, Legal);
1031 // only div/rem/sqrt are legal for f64
1032 if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
1033 setOperationAction(Op, MVT::f64, Legal);
1034 }
1035 setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand);
1036 setOperationAction(Op, MVT::bf16, Promote);
1037 AddPromotedToType(Op, MVT::bf16, MVT::f32);
1038 }
1039 setOperationAction(ISD::FREM, {MVT::f32, MVT::f64}, Custom);
1040
1041 setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
1042 setOperationAction(ISD::FABS, MVT::v2f32, Expand);
1043 if (STI.getPTXVersion() >= 65) {
1044 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
1045 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
1046 } else {
1048 setOperationAction(ISD::FABS, MVT::v2f16, Expand);
1049 }
1050 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
1051 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
1052 if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
1053 AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
1054
1055 for (const auto &Op :
1057 setOperationAction(Op, MVT::f32, Legal);
1058 setOperationAction(Op, MVT::f64, Legal);
1059 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
1060 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1061 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1062 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
1063 if (getOperationAction(Op, MVT::bf16) == Promote)
1064 AddPromotedToType(Op, MVT::bf16, MVT::f32);
1065 setOperationAction(Op, MVT::v2f32, Expand);
1066 }
1067 bool SupportsF32MinMaxNaN =
1068 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
1069 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
1070 setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
1071 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
1072 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1073 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
1074 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1075 setOperationAction(Op, MVT::v2f32, Expand);
1076 }
1077
1078 // Custom lowering for inline asm with 128-bit operands
1081
1082 // FEXP2 support:
1083 // - f32
1084 // - f16/f16x2 (sm_70+, PTX 7.0+)
1085 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
1086 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
1088 setOperationAction(ISD::FEXP2, MVT::v2f32, Expand);
1089 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
1090 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
1091 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
1092 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
1093
1094 // FLOG2 supports f32 only
1095 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1096 if (UseApproxLog2F32) {
1098 setOperationPromotedToType(ISD::FLOG2, MVT::f16, MVT::f32);
1099 setOperationPromotedToType(ISD::FLOG2, MVT::bf16, MVT::f32);
1100 setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16, MVT::v2f32},
1101 Expand);
1102 }
1103
1104 setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom);
1105
1106 setOperationAction(ISD::ATOMIC_LOAD_SUB, {MVT::i32, MVT::i64}, Expand);
1107
1108 // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1109 // type, we need to custom lower it.
1111 Custom);
1112
1113 // Now deduce the information based on the above mentioned
1114 // actions
1115 computeRegisterProperties(STI.getRegisterInfo());
1116
1117 // PTX support for 16-bit CAS is emulated. Only use 32+
1118 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1119 setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
1121
1122 // Custom lowering for tcgen05.ld vector operands
1124 {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1125 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::v2f32,
1126 MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32,
1127 MVT::v64f32, MVT::v128f32},
1128 Custom);
1129
1130 // Custom lowering for tcgen05.st vector operands
1132 {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1133 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
1134 Custom);
1135
1136 // Enable custom lowering for the following:
1137 // * MVT::i128 - clusterlaunchcontrol
1138 // * MVT::i32 - prmt
1139 // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
1140 // * MVT::Other - internal.addrspace.wrap
1142 {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
1143
1144 // Custom lowering for bswap
1145 setOperationAction(ISD::BSWAP, {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1146 Custom);
1147}
1148
1151 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1152 VT.getScalarType() == MVT::i1)
1153 return TypeSplitVector;
1155}
1156
1158 int Enabled, int &ExtraSteps,
1159 bool &UseOneConst,
1160 bool Reciprocal) const {
1163 return SDValue();
1164
1165 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1166 ExtraSteps = 0;
1167
1168 SDLoc DL(Operand);
1169 EVT VT = Operand.getValueType();
1170 bool Ftz = useF32FTZ(DAG.getMachineFunction());
1171
1172 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1173 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
1174 DAG.getConstant(IID, DL, MVT::i32), Operand);
1175 };
1176
1177 // The sqrt and rsqrt refinement processes assume we always start out with an
1178 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1179 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1180 // any refinement, we must return a regular sqrt.
1181 if (Reciprocal || ExtraSteps > 0) {
1182 if (VT == MVT::f32)
1183 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1184 : Intrinsic::nvvm_rsqrt_approx_f);
1185 else if (VT == MVT::f64)
1186 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1187 else
1188 return SDValue();
1189 } else {
1190 if (VT == MVT::f32)
1191 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1192 : Intrinsic::nvvm_sqrt_approx_f);
1193 else {
1194 // There's no sqrt.approx.f64 instruction, so we emit
1195 // reciprocal(rsqrt(x)). This is faster than
1196 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1197 // x * rsqrt(x).)
1198 return DAG.getNode(
1200 DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
1201 MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1202 }
1203 }
1204}
1205
1206static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
1207 const DataLayout &DL);
1208
1210 const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
1212 std::optional<unsigned> FirstVAArg, const CallBase &CB,
1213 unsigned UniqueCallSite) const {
1214 auto PtrVT = getPointerTy(DL);
1215
1216 std::string Prototype;
1217 raw_string_ostream O(Prototype);
1218 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1219
1220 if (RetTy->isVoidTy()) {
1221 O << "()";
1222 } else {
1223 O << "(";
1224 if (shouldPassAsArray(RetTy)) {
1225 const Align RetAlign = getArgumentAlignment(&CB, RetTy, 0, DL);
1226 O << ".param .align " << RetAlign.value() << " .b8 _["
1227 << DL.getTypeAllocSize(RetTy) << "]";
1228 } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
1229 unsigned size = 0;
1230 if (auto *ITy = dyn_cast<IntegerType>(RetTy)) {
1231 size = ITy->getBitWidth();
1232 } else {
1233 assert(RetTy->isFloatingPointTy() &&
1234 "Floating point type expected here");
1235 size = RetTy->getPrimitiveSizeInBits();
1236 }
1237 // PTX ABI requires all scalar return values to be at least 32
1238 // bits in size. fp16 normally uses .b16 as its storage type in
1239 // PTX, so its size must be adjusted here, too.
1241
1242 O << ".param .b" << size << " _";
1243 } else if (isa<PointerType>(RetTy)) {
1244 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1245 } else {
1246 llvm_unreachable("Unknown return type");
1247 }
1248 O << ") ";
1249 }
1250 O << "_ (";
1251
1252 bool first = true;
1253
1254 const unsigned NumArgs = FirstVAArg.value_or(Args.size());
1255 auto AllOuts = ArrayRef(Outs);
1256 for (const unsigned I : llvm::seq(NumArgs)) {
1257 const auto ArgOuts =
1258 AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
1259 AllOuts = AllOuts.drop_front(ArgOuts.size());
1260
1261 Type *Ty = Args[I].Ty;
1262 if (!first) {
1263 O << ", ";
1264 }
1265 first = false;
1266
1267 if (ArgOuts[0].Flags.isByVal()) {
1268 // Indirect calls need strict ABI alignment so we disable optimizations by
1269 // not providing a function to optimize.
1270 Type *ETy = Args[I].IndirectType;
1271 Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1272 Align ParamByValAlign =
1273 getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
1274
1275 O << ".param .align " << ParamByValAlign.value() << " .b8 _["
1276 << ArgOuts[0].Flags.getByValSize() << "]";
1277 } else {
1278 if (shouldPassAsArray(Ty)) {
1279 Align ParamAlign =
1280 getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
1281 O << ".param .align " << ParamAlign.value() << " .b8 _["
1282 << DL.getTypeAllocSize(Ty) << "]";
1283 continue;
1284 }
1285 // i8 types in IR will be i16 types in SDAG
1286 assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
1287 (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
1288 "type mismatch between callee prototype and arguments");
1289 // scalar type
1290 unsigned sz = 0;
1291 if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
1292 sz = promoteScalarArgumentSize(ITy->getBitWidth());
1293 } else if (isa<PointerType>(Ty)) {
1294 sz = PtrVT.getSizeInBits();
1295 } else {
1296 sz = Ty->getPrimitiveSizeInBits();
1297 }
1298 O << ".param .b" << sz << " _";
1299 }
1300 }
1301
1302 if (FirstVAArg)
1303 O << (first ? "" : ",") << " .param .align "
1304 << STI.getMaxRequiredAlignment() << " .b8 _[]";
1305 O << ")";
1306 if (shouldEmitPTXNoReturn(&CB, *nvTM))
1307 O << " .noreturn";
1308 O << ";";
1309
1310 return Prototype;
1311}
1312
1313static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
1314 const DataLayout &DL) {
1315 if (!CB) {
1316 // CallSite is zero, fallback to ABI type alignment
1317 return DL.getABITypeAlign(Ty);
1318 }
1319
1320 const Function *DirectCallee = CB->getCalledFunction();
1321
1322 if (!DirectCallee) {
1323 // We don't have a direct function symbol, but that may be because of
1324 // constant cast instructions in the call.
1325
1326 // With bitcast'd call targets, the instruction will be the call
1327 if (const auto *CI = dyn_cast<CallInst>(CB)) {
1328 // Check if we have call alignment metadata
1329 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1330 return StackAlign.value();
1331 }
1332 DirectCallee = getMaybeBitcastedCallee(CB);
1333 }
1334
1335 // Check for function alignment information if we found that the
1336 // ultimate target is a Function
1337 if (DirectCallee)
1338 return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
1339
1340 // Call is indirect, fall back to the ABI type alignment
1341 return DL.getABITypeAlign(Ty);
1342}
1343
1345 const DataLayout &DL,
1346 const TargetLowering &TL) {
1347 if (Ptr->getOpcode() == ISD::FrameIndex) {
1348 auto Ty = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
1349 Ptr = DAG.getAddrSpaceCast(SDLoc(), Ty, Ptr, ADDRESS_SPACE_GENERIC,
1351
1353 }
1354
1355 // Peel of an addrspacecast to generic and load directly from the specific
1356 // address space.
1357 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1358 const auto *ASC = cast<AddrSpaceCastSDNode>(Ptr);
1359 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1360 Ptr = ASC->getOperand(0);
1361 return MachinePointerInfo(ASC->getSrcAddressSpace());
1362 }
1363 }
1364
1365 return MachinePointerInfo();
1366}
1367
1369 if (Flags.isSExt())
1370 return ISD::SIGN_EXTEND;
1371 if (Flags.isZExt())
1372 return ISD::ZERO_EXTEND;
1373 return ISD::ANY_EXTEND;
1374}
1375
1377 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1378 SDLoc dl) {
1379 const EVT ActualVT = V.getValueType();
1380 assert((ActualVT == ExpectedVT ||
1381 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1382 "Non-integer argument type size mismatch");
1383 if (ExpectedVT.bitsGT(ActualVT))
1384 return DAG.getNode(getExtOpcode(Flags), dl, ExpectedVT, V);
1385 if (ExpectedVT.bitsLT(ActualVT))
1386 return DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, V);
1387
1388 return V;
1389}
1390
1392 SmallVectorImpl<SDValue> &InVals) const {
1393
1394 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1396 "Support for variadic functions (unsized array parameter) introduced "
1397 "in PTX ISA version 6.0 and requires target sm_30.");
1398
1399 SelectionDAG &DAG = CLI.DAG;
1400 SDLoc dl = CLI.DL;
1401 const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1402 SDValue Callee = CLI.Callee;
1403 ArgListTy &Args = CLI.getArgs();
1404 Type *RetTy = CLI.RetTy;
1405 const CallBase *CB = CLI.CB;
1406 const DataLayout &DL = DAG.getDataLayout();
1407 LLVMContext &Ctx = *DAG.getContext();
1408
1409 const auto GetI32 = [&](const unsigned I) {
1410 return DAG.getConstant(I, dl, MVT::i32);
1411 };
1412
1413 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1414 const SDValue CallChain = CLI.Chain;
1415 const SDValue StartChain =
1416 DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl);
1417 SDValue DeclareGlue = StartChain.getValue(1);
1418
1419 SmallVector<SDValue, 16> CallPrereqs{StartChain};
1420
1421 const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1422 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1423 // loaded/stored using i16, so it's handled here as well.
1424 const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
1425 SDValue Declare =
1426 DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
1427 {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
1428 CallPrereqs.push_back(Declare);
1429 DeclareGlue = Declare.getValue(1);
1430 return Declare;
1431 };
1432
1433 const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1434 unsigned Size) {
1435 SDValue Declare = DAG.getNode(
1436 NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
1437 {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
1438 CallPrereqs.push_back(Declare);
1439 DeclareGlue = Declare.getValue(1);
1440 return Declare;
1441 };
1442
1443 // Variadic arguments.
1444 //
1445 // Normally, for each argument, we declare a param scalar or a param
1446 // byte array in the .param space, and store the argument value to that
1447 // param scalar or array starting at offset 0.
1448 //
1449 // In the case of the first variadic argument, we declare a vararg byte array
1450 // with size 0. The exact size of this array isn't known at this point, so
1451 // it'll be patched later. All the variadic arguments will be stored to this
1452 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1453 // initially set to 0, so it can be used for non-variadic arguments (which use
1454 // 0 offset) to simplify the code.
1455 //
1456 // After all vararg is processed, 'VAOffset' holds the size of the
1457 // vararg byte array.
1458 assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1459 "Non-VarArg function with extra arguments");
1460
1461 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1462 unsigned VAOffset = 0; // current offset in the param array
1463
1464 const SDValue VADeclareParam =
1465 CLI.Args.size() > FirstVAArg
1466 ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
1467 Align(STI.getMaxRequiredAlignment()), 0)
1468 : SDValue();
1469
1470 // Args.size() and Outs.size() need not match.
1471 // Outs.size() will be larger
1472 // * if there is an aggregate argument with multiple fields (each field
1473 // showing up separately in Outs)
1474 // * if there is a vector argument with more than typical vector-length
1475 // elements (generally if more than 4) where each vector element is
1476 // individually present in Outs.
1477 // So a different index should be used for indexing into Outs/OutVals.
1478 // See similar issue in LowerFormalArguments.
1479 auto AllOuts = ArrayRef(CLI.Outs);
1480 auto AllOutVals = ArrayRef(CLI.OutVals);
1481 assert(AllOuts.size() == AllOutVals.size() &&
1482 "Outs and OutVals must be the same size");
1483 // Declare the .params or .reg need to pass values
1484 // to the function
1485 for (const auto E : llvm::enumerate(Args)) {
1486 const auto ArgI = E.index();
1487 const auto Arg = E.value();
1488 const auto ArgOuts =
1489 AllOuts.take_while([&](auto O) { return O.OrigArgIndex == ArgI; });
1490 const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
1491 AllOuts = AllOuts.drop_front(ArgOuts.size());
1492 AllOutVals = AllOutVals.drop_front(ArgOuts.size());
1493
1494 const bool IsVAArg = (ArgI >= FirstVAArg);
1495 const bool IsByVal = Arg.IsByVal;
1496
1497 const SDValue ParamSymbol =
1498 getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
1499
1500 assert((!IsByVal || Arg.IndirectType) &&
1501 "byval arg must have indirect type");
1502 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1503
1504 const Align ArgAlign = [&]() {
1505 if (IsByVal) {
1506 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1507 // so we don't need to worry whether it's naturally aligned or not.
1508 // See TargetLowering::LowerCallTo().
1509 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1511 InitialAlign, DL);
1512 }
1513 return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
1514 }();
1515
1516 const unsigned TySize = DL.getTypeAllocSize(ETy);
1517 assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
1518 "type size mismatch");
1519
1520 const SDValue ArgDeclare = [&]() {
1521 if (IsVAArg)
1522 return VADeclareParam;
1523
1524 if (IsByVal || shouldPassAsArray(Arg.Ty))
1525 return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
1526
1527 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1528 assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
1529 "Only int and float types are supported as non-array arguments");
1530
1531 return MakeDeclareScalarParam(ParamSymbol, TySize);
1532 }();
1533
1534 if (IsByVal) {
1535 assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
1536 SDValue SrcPtr = ArgOutVals[0];
1537 const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
1538 const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1539
1540 if (IsVAArg)
1541 VAOffset = alignTo(VAOffset, ArgAlign);
1542
1543 SmallVector<EVT, 4> ValueVTs, MemVTs;
1545 ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
1546
1547 unsigned J = 0;
1548 const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
1549 for (const unsigned NumElts : VI) {
1550 EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
1551 Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
1552 SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
1553 SDValue SrcLoad =
1554 DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
1555
1556 TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
1557 Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
1558 SDValue ParamAddr =
1559 DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
1560 SDValue StoreParam = DAG.getStore(
1561 ArgDeclare, dl, SrcLoad, ParamAddr,
1563 CallPrereqs.push_back(StoreParam);
1564
1565 J += NumElts;
1566 }
1567 if (IsVAArg)
1568 VAOffset += TySize;
1569 } else {
1572 ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
1573 VAOffset);
1574 assert(VTs.size() == Offsets.size() && "Size mismatch");
1575 assert(VTs.size() == ArgOuts.size() && "Size mismatch");
1576
1577 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1578 // than 32-bits are sign extended or zero extended, depending on
1579 // whether they are signed or unsigned types. This case applies
1580 // only to scalar parameters and not to aggregate values.
1581 const bool ExtendIntegerParam =
1582 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
1583
1584 const auto GetStoredValue = [&](const unsigned I) {
1585 SDValue StVal = ArgOutVals[I];
1587 StVal.getValueType() &&
1588 "OutVal type should always be legal");
1589
1590 const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
1591 const EVT StoreVT =
1592 ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1593
1594 return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
1595 };
1596
1597 unsigned J = 0;
1598 const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
1599 for (const unsigned NumElts : VI) {
1600 const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
1601
1602 unsigned Offset;
1603 if (IsVAArg) {
1604 // TODO: We may need to support vector types that can be passed
1605 // as scalars in variadic arguments.
1606 assert(NumElts == 1 &&
1607 "Vectorization should be disabled for vaargs.");
1608
1609 // Align each part of the variadic argument to their type.
1610 VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
1611 Offset = VAOffset;
1612
1613 const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1614 VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
1615 } else {
1616 assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
1617 Offset = Offsets[J];
1618 }
1619
1620 SDValue Ptr =
1621 DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
1622
1623 const MaybeAlign CurrentAlign = ExtendIntegerParam
1624 ? MaybeAlign(std::nullopt)
1625 : commonAlignment(ArgAlign, Offset);
1626
1627 SDValue Val =
1628 getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) {
1629 return GetStoredValue(J + K);
1630 });
1631
1632 SDValue StoreParam = DAG.getStore(
1633 ArgDeclare, dl, Val, Ptr,
1635 CallPrereqs.push_back(StoreParam);
1636
1637 J += NumElts;
1638 }
1639 }
1640 }
1641
1642 // Handle Result
1643 if (!Ins.empty()) {
1644 const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
1645 const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
1646 if (shouldPassAsArray(RetTy)) {
1647 const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1648 MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1649 } else {
1650 MakeDeclareScalarParam(RetSymbol, ResultSize);
1651 }
1652 }
1653
1654 // Set the size of the vararg param byte array if the callee is a variadic
1655 // function and the variadic part is not empty.
1656 if (VADeclareParam) {
1657 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
1658 VADeclareParam.getOperand(1),
1659 VADeclareParam.getOperand(2), GetI32(VAOffset),
1660 VADeclareParam.getOperand(4)};
1661 DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
1662 VADeclareParam->getVTList(), DeclareParamOps);
1663 }
1664
1665 const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
1666 const auto *CalleeF = Func ? dyn_cast<Function>(Func->getGlobal()) : nullptr;
1667
1668 // If the type of the callsite does not match that of the function, convert
1669 // the callsite to an indirect call.
1670 const bool ConvertToIndirectCall =
1671 CalleeF && CB->getFunctionType() != CalleeF->getFunctionType();
1672
1673 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1674 // between them we must rely on the call site value which is valid for
1675 // indirect calls but is always null for libcalls.
1676 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1677
1678 if (isa<ExternalSymbolSDNode>(Callee)) {
1679 Function* CalleeFunc = nullptr;
1680
1681 // Try to find the callee in the current module.
1682 Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
1683 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1684
1685 // Set the "libcall callee" attribute to indicate that the function
1686 // must always have a declaration.
1687 CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
1688 }
1689
1690 if (IsIndirectCall) {
1691 // This is indirect function call case : PTX requires a prototype of the
1692 // form
1693 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1694 // to be emitted, and the label has to used as the last arg of call
1695 // instruction.
1696 // The prototype is embedded in a string and put as the operand for a
1697 // CallPrototype SDNode which will print out to the value of the string.
1698 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1699 std::string Proto =
1700 getPrototype(DL, RetTy, Args, CLI.Outs,
1701 HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
1702 UniqueCallSite);
1703 const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1704 const SDValue PrototypeDeclare = DAG.getNode(
1705 NVPTXISD::CallPrototype, dl, MVT::Other,
1706 {StartChain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32)});
1707 CallPrereqs.push_back(PrototypeDeclare);
1708 }
1709
1710 const bool IsUnknownIntrinsic =
1711 CalleeF && CalleeF->isIntrinsic() &&
1712 CalleeF->getIntrinsicID() == Intrinsic::not_intrinsic;
1713 if (IsUnknownIntrinsic) {
1716 "call to unknown intrinsic '" + CalleeF->getName() +
1717 "' cannot be lowered by the NVPTX backend",
1718 dl.getDebugLoc()));
1719 }
1720
1721 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1722 const unsigned NumArgs =
1723 std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
1724 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1725 /// NumParams, Callee, Proto)
1726 const SDValue CallToken = DAG.getTokenFactor(dl, CallPrereqs);
1727 const SDValue Call = DAG.getNode(
1728 NVPTXISD::CALL, dl, MVT::Other,
1729 {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1730 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
1731
1732 SmallVector<SDValue, 16> LoadChains{Call};
1733 SmallVector<SDValue, 16> ProxyRegOps;
1734 if (!Ins.empty()) {
1737 ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
1738 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1739
1740 const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1741 const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
1742
1743 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1744 // 32-bits are sign extended or zero extended, depending on whether
1745 // they are signed or unsigned types.
1746 const bool ExtendIntegerRetVal =
1747 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
1748
1749 unsigned I = 0;
1750 const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
1751 for (const unsigned NumElts : VI) {
1752 const MaybeAlign CurrentAlign =
1753 ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
1754 : commonAlignment(RetAlign, Offsets[I]);
1755
1756 const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
1757 const EVT LoadVT =
1758 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1759 const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
1760 SDValue Ptr =
1761 DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
1762
1763 SDValue R = DAG.getLoad(
1764 VecVT, dl, Call, Ptr,
1766
1767 LoadChains.push_back(R.getValue(1));
1768 for (const unsigned J : llvm::seq(NumElts))
1769 ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG));
1770 I += NumElts;
1771 }
1772 }
1773
1774 const SDValue EndToken = DAG.getTokenFactor(dl, LoadChains);
1775 const SDValue CallEnd = DAG.getCALLSEQ_END(EndToken, UniqueCallSite,
1776 UniqueCallSite + 1, SDValue(), dl);
1777
1778 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1779 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1780 // dangling.
1781 for (const auto [I, Reg] : llvm::enumerate(ProxyRegOps)) {
1782 SDValue Proxy =
1783 DAG.getNode(NVPTXISD::ProxyReg, dl, Reg.getValueType(), {CallEnd, Reg});
1784 SDValue Ret = correctParamType(Proxy, Ins[I].VT, Ins[I].Flags, DAG, dl);
1785 InVals.push_back(Ret);
1786 }
1787
1788 // set IsTailCall to false for now, until we figure out how to express
1789 // tail call optimization in PTX
1790 CLI.IsTailCall = false;
1791 return CallEnd;
1792}
1793
1795 SelectionDAG &DAG) const {
1796
1797 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1798 const Function &Fn = DAG.getMachineFunction().getFunction();
1799
1801 Fn,
1802 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1803 "requires target sm_52.",
1804 SDLoc(Op).getDebugLoc()));
1805 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
1806 Op.getOperand(0)};
1807 return DAG.getMergeValues(Ops, SDLoc());
1808 }
1809
1810 SDLoc DL(Op.getNode());
1811 SDValue Chain = Op.getOperand(0);
1812 SDValue Size = Op.getOperand(1);
1813 uint64_t Align = Op.getConstantOperandVal(2);
1814
1815 // The alignment on a ISD::DYNAMIC_STACKALLOC node may be 0 to indicate that
1816 // the default stack alignment should be used.
1817 if (Align == 0)
1819
1820 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
1821 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
1822
1823 SDValue Alloc =
1824 DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, {LocalVT, MVT::Other},
1825 {Chain, DAG.getZExtOrTrunc(Size, DL, LocalVT),
1826 DAG.getTargetConstant(Align, DL, MVT::i32)});
1827
1828 SDValue ASC = DAG.getAddrSpaceCast(
1830
1831 return DAG.getMergeValues({ASC, SDValue(Alloc.getNode(), 1)}, DL);
1832}
1833
1835 SelectionDAG &DAG) const {
1836 SDLoc DL(Op.getNode());
1837 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1838 const Function &Fn = DAG.getMachineFunction().getFunction();
1839
1841 Fn,
1842 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1843 ">= sm_52.",
1844 DL.getDebugLoc()));
1845 return Op.getOperand(0);
1846 }
1847
1848 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
1849 SDValue Chain = Op.getOperand(0);
1850 SDValue Ptr = Op.getOperand(1);
1851 SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC,
1853 return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
1854}
1855
1857 SelectionDAG &DAG) const {
1858 SDLoc DL(Op.getNode());
1859 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1860 const Function &Fn = DAG.getMachineFunction().getFunction();
1861
1863 Fn,
1864 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
1865 "sm_52.",
1866 DL.getDebugLoc()));
1867 auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
1868 return DAG.getMergeValues(Ops, DL);
1869 }
1870
1871 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
1872 SDValue Chain = Op.getOperand(0);
1873 SDValue SS =
1874 DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
1875 SDValue ASC = DAG.getAddrSpaceCast(
1876 DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
1877 return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
1878}
1879
1880// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1881// (see LegalizeDAG.cpp). This is slow and uses local memory.
1882// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1883SDValue
1884NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1885 SDNode *Node = Op.getNode();
1886 SDLoc dl(Node);
1888 unsigned NumOperands = Node->getNumOperands();
1889 for (unsigned i = 0; i < NumOperands; ++i) {
1890 SDValue SubOp = Node->getOperand(i);
1891 EVT VVT = SubOp.getNode()->getValueType(0);
1892 EVT EltVT = VVT.getVectorElementType();
1893 unsigned NumSubElem = VVT.getVectorNumElements();
1894 for (unsigned j = 0; j < NumSubElem; ++j) {
1895 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1896 DAG.getIntPtrConstant(j, dl)));
1897 }
1898 }
1899 return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
1900}
1901
1903 SelectionDAG &DAG,
1904 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1905 assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
1906 Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
1907 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
1908 {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
1909}
1910
1912 SelectionDAG &DAG,
1913 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1914 return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
1915}
1916
1917/// Reduces the elements using the scalar operations provided. The operations
1918/// are sorted descending in number of inputs they take. The flags on the
1919/// original reduction operation will be propagated to each scalar operation.
1920/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
1921/// used in ExpandReductions and SelectionDAG.
1923 const SmallVector<SDValue> &Elements, EVT EltTy,
1924 ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
1925 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1926 // Build the reduction tree at each level, starting with all the elements.
1927 SmallVector<SDValue> Level = Elements;
1928
1929 unsigned OpIdx = 0;
1930 while (Level.size() > 1) {
1931 // Try to reduce this level using the current operator.
1932 const auto [Op, NumInputs] = Ops[OpIdx];
1933
1934 // Build the next level by partially reducing all elements.
1935 SmallVector<SDValue> ReducedLevel;
1936 unsigned I = 0, E = Level.size();
1937 for (; I + NumInputs <= E; I += NumInputs) {
1938 // Reduce elements in groups of [NumInputs], as much as possible.
1939 ReducedLevel.push_back(DAG.getNode(
1940 Op, DL, EltTy, ArrayRef<SDValue>(Level).slice(I, NumInputs), Flags));
1941 }
1942
1943 if (I < E) {
1944 // Handle leftover elements.
1945
1946 if (ReducedLevel.empty()) {
1947 // We didn't reduce anything at this level. We need to pick a smaller
1948 // operator.
1949 ++OpIdx;
1950 assert(OpIdx < Ops.size() && "no smaller operators for reduction");
1951 continue;
1952 }
1953
1954 // We reduced some things but there's still more left, meaning the
1955 // operator's number of inputs doesn't evenly divide this level size. Move
1956 // these elements to the next level.
1957 for (; I < E; ++I)
1958 ReducedLevel.push_back(Level[I]);
1959 }
1960
1961 // Process the next level.
1962 Level = ReducedLevel;
1963 }
1964
1965 return *Level.begin();
1966}
1967
1968// Get scalar reduction opcode
1969static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1970 switch (ReductionOpcode) {
1972 return ISD::FMAXNUM;
1974 return ISD::FMINNUM;
1976 return ISD::FMAXIMUM;
1978 return ISD::FMINIMUM;
1979 default:
1980 llvm_unreachable("unhandled reduction opcode");
1981 }
1982}
1983
1984/// Get 3-input scalar reduction opcode
1985static std::optional<unsigned>
1986getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1987 switch (ReductionOpcode) {
1989 return NVPTXISD::FMAXNUM3;
1991 return NVPTXISD::FMINNUM3;
1993 return NVPTXISD::FMAXIMUM3;
1995 return NVPTXISD::FMINIMUM3;
1996 default:
1997 return std::nullopt;
1998 }
1999}
2000
2001/// Lower reductions to either a sequence of operations or a tree if
2002/// reassociations are allowed. This method will use larger operations like
2003/// max3/min3 when the target supports them.
2004SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2005 SelectionDAG &DAG) const {
2006 SDLoc DL(Op);
2007 const SDNodeFlags Flags = Op->getFlags();
2008 SDValue Vector = Op.getOperand(0);
2009
2010 const unsigned Opcode = Op->getOpcode();
2011 const EVT EltTy = Vector.getValueType().getVectorElementType();
2012
2013 // Whether we can use 3-input min/max when expanding the reduction.
2014 const bool CanUseMinMax3 =
2015 EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
2016 STI.getPTXVersion() >= 88 &&
2017 (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
2018 Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
2019
2020 // A list of SDNode opcodes with equivalent semantics, sorted descending by
2021 // number of inputs they take.
2022 SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2023
2024 if (auto Opcode3Elem = getScalar3OpcodeForReduction(Opcode);
2025 CanUseMinMax3 && Opcode3Elem)
2026 ScalarOps.push_back({*Opcode3Elem, 3});
2027 ScalarOps.push_back({getScalarOpcodeForReduction(Opcode), 2});
2028
2030 DAG.ExtractVectorElements(Vector, Elements);
2031
2032 return buildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2033}
2034
2035SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2036 // Handle bitcasting from v2i8 without hitting the default promotion
2037 // strategy which goes through stack memory.
2038 EVT FromVT = Op->getOperand(0)->getValueType(0);
2039 if (FromVT != MVT::v2i8) {
2040 return Op;
2041 }
2042
2043 // Pack vector elements into i16 and bitcast to final type
2044 SDLoc DL(Op);
2045 SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2046 Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2047 SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2048 Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2049 SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2050 SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2051 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2052 SDValue AsInt = DAG.getNode(
2053 ISD::OR, DL, MVT::i16,
2054 {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
2055 EVT ToVT = Op->getValueType(0);
2056 return DAG.getBitcast(ToVT, AsInt);
2057}
2058
2059// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
2060// would get lowered as two constant loads and vector-packing move.
2061// Instead we want just a constant move:
2062// mov.b32 %r2, 0x40003C00
2063SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2064 SelectionDAG &DAG) const {
2065 EVT VT = Op->getValueType(0);
2066 if (!(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()))
2067 return Op;
2068 SDLoc DL(Op);
2069
2070 if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2071 return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2072 isa<ConstantFPSDNode>(Operand);
2073 })) {
2074 if (VT != MVT::v4i8)
2075 return Op;
2076 // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
2077 // to optimize calculation of constant parts.
2078 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2079 uint64_t SelectionValue) -> SDValue {
2080 SDValue L = Left;
2081 SDValue R = Right;
2082 if (Cast) {
2083 L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2084 R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2085 }
2086 return getPRMT(L, R, SelectionValue, DL, DAG);
2087 };
2088 auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2089 auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2090 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2091 return DAG.getBitcast(VT, PRMT3210);
2092 }
2093
2094 // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
2095 auto GetOperand = [](SDValue Op, int N) -> APInt {
2096 const SDValue &Operand = Op->getOperand(N);
2097 EVT VT = Op->getValueType(0);
2098 if (Operand->isUndef())
2099 return APInt(32, 0);
2100 APInt Value;
2101 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2102 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2103 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2104 Value = Operand->getAsAPIntVal();
2105 else
2106 llvm_unreachable("Unsupported type");
2107 // i8 values are carried around as i16, so we need to zero out upper bits,
2108 // so they do not get in the way of combining individual byte values
2109 if (VT == MVT::v4i8)
2110 Value = Value.trunc(8);
2111 return Value.zext(32);
2112 };
2113
2114 // Construct a 32-bit constant by shifting into place smaller values
2115 // (elements of the vector type VT).
2116 // For example, if VT has 2 elements, then N == 2:
2117 // ShiftAmount = 32 / N = 16
2118 // Value |= Op0 (b16) << 0
2119 // Value |= Op1 (b16) << 16
2120 // If N == 4:
2121 // ShiftAmount = 32 / N = 8
2122 // Value |= Op0 (b8) << 0
2123 // Value |= Op1 (b8) << 8
2124 // Value |= Op2 (b8) << 16
2125 // Value |= Op3 (b8) << 24
2126 // ...etc
2127 APInt Value(32, 0);
2128 const unsigned NumElements = VT.getVectorNumElements();
2129 assert(32 % NumElements == 0 && "must evenly divide bit length");
2130 const unsigned ShiftAmount = 32 / NumElements;
2131 for (unsigned ElementNo : seq(NumElements))
2132 Value |= GetOperand(Op, ElementNo).shl(ElementNo * ShiftAmount);
2133 SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2134 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
2135}
2136
2137SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2138 SelectionDAG &DAG) const {
2139 SDValue Index = Op->getOperand(1);
2140 SDValue Vector = Op->getOperand(0);
2141 SDLoc DL(Op);
2142 EVT VectorVT = Vector.getValueType();
2143
2144 if (VectorVT == MVT::v4i8) {
2145 SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
2146 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2147 DAG.getConstant(0x7770, DL, MVT::i32));
2148 SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector),
2149 DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);
2150 SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
2151 SDNodeFlags Flags;
2152 Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
2153 Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
2154 Ext->setFlags(Flags);
2155 return Ext;
2156 }
2157
2158 // Constant index will be matched by tablegen.
2159 if (isa<ConstantSDNode>(Index.getNode()))
2160 return Op;
2161
2162 // Extract individual elements and select one of them.
2163 assert(NVPTX::isPackedVectorTy(VectorVT) &&
2164 VectorVT.getVectorNumElements() == 2 && "Unexpected vector type.");
2165 EVT EltVT = VectorVT.getVectorElementType();
2166
2167 SDLoc dl(Op.getNode());
2168 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2169 DAG.getIntPtrConstant(0, dl));
2170 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2171 DAG.getIntPtrConstant(1, dl));
2172 return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2174}
2175
2176SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2177 SelectionDAG &DAG) const {
2178 SDValue Vector = Op->getOperand(0);
2179 EVT VectorVT = Vector.getValueType();
2180
2181 if (VectorVT != MVT::v4i8)
2182 return Op;
2183 SDLoc DL(Op);
2184 SDValue Value = Op->getOperand(1);
2185 if (Value->isUndef())
2186 return Vector;
2187
2188 SDValue Index = Op->getOperand(2);
2189
2190 SDValue BFI =
2191 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2192 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2193 DAG.getNode(ISD::MUL, DL, MVT::i32,
2194 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2195 DAG.getConstant(8, DL, MVT::i32)),
2196 DAG.getConstant(8, DL, MVT::i32)});
2197 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2198}
2199
2200SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2201 SelectionDAG &DAG) const {
2202 SDValue V1 = Op.getOperand(0);
2203 EVT VectorVT = V1.getValueType();
2204 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2205 return Op;
2206
2207 // Lower shuffle to PRMT instruction.
2208 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2209 SDValue V2 = Op.getOperand(1);
2210 uint32_t Selector = 0;
2211 for (auto I : llvm::enumerate(SVN->getMask())) {
2212 if (I.value() != -1) // -1 is a placeholder for undef.
2213 Selector |= (I.value() << (I.index() * 4));
2214 }
2215
2216 SDLoc DL(Op);
2217 SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, V1),
2218 DAG.getBitcast(MVT::i32, V2), Selector, DL, DAG);
2219 return DAG.getBitcast(Op.getValueType(), PRMT);
2220}
2221/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2222/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2223/// amount, or
2224/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2225/// amount.
2226SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2227 SelectionDAG &DAG) const {
2228 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2229 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2230
2231 EVT VT = Op.getValueType();
2232 unsigned VTBits = VT.getSizeInBits();
2233 SDLoc dl(Op);
2234 SDValue ShOpLo = Op.getOperand(0);
2235 SDValue ShOpHi = Op.getOperand(1);
2236 SDValue ShAmt = Op.getOperand(2);
2237 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2238
2239 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2240 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2241 // {dHi, dLo} = {aHi, aLo} >> Amt
2242 // dHi = aHi >> Amt
2243 // dLo = shf.r.clamp aLo, aHi, Amt
2244
2245 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2246 SDValue Lo =
2247 DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2248
2249 SDValue Ops[2] = { Lo, Hi };
2250 return DAG.getMergeValues(Ops, dl);
2251 }
2252 else {
2253 // {dHi, dLo} = {aHi, aLo} >> Amt
2254 // - if (Amt>=size) then
2255 // dLo = aHi >> (Amt-size)
2256 // dHi = aHi >> Amt (this is either all 0 or all 1)
2257 // else
2258 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2259 // dHi = aHi >> Amt
2260
2261 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2262 DAG.getConstant(VTBits, dl, MVT::i32),
2263 ShAmt);
2264 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2265 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2266 DAG.getConstant(VTBits, dl, MVT::i32));
2267 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2268 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2269 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2270
2271 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2272 DAG.getConstant(VTBits, dl, MVT::i32),
2273 ISD::SETGE);
2274 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2275 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2276
2277 SDValue Ops[2] = { Lo, Hi };
2278 return DAG.getMergeValues(Ops, dl);
2279 }
2280}
2281
2282/// LowerShiftLeftParts - Lower SHL_PARTS, which
2283/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2284/// amount, or
2285/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2286/// amount.
2287SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2288 SelectionDAG &DAG) const {
2289 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2290 assert(Op.getOpcode() == ISD::SHL_PARTS);
2291
2292 EVT VT = Op.getValueType();
2293 unsigned VTBits = VT.getSizeInBits();
2294 SDLoc dl(Op);
2295 SDValue ShOpLo = Op.getOperand(0);
2296 SDValue ShOpHi = Op.getOperand(1);
2297 SDValue ShAmt = Op.getOperand(2);
2298
2299 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2300 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2301 // {dHi, dLo} = {aHi, aLo} << Amt
2302 // dHi = shf.l.clamp aLo, aHi, Amt
2303 // dLo = aLo << Amt
2304
2305 SDValue Hi =
2306 DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2307 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2308
2309 SDValue Ops[2] = { Lo, Hi };
2310 return DAG.getMergeValues(Ops, dl);
2311 }
2312 else {
2313 // {dHi, dLo} = {aHi, aLo} << Amt
2314 // - if (Amt>=size) then
2315 // dLo = aLo << Amt (all 0)
2316 // dLo = aLo << (Amt-size)
2317 // else
2318 // dLo = aLo << Amt
2319 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2320
2321 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2322 DAG.getConstant(VTBits, dl, MVT::i32),
2323 ShAmt);
2324 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2325 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2326 DAG.getConstant(VTBits, dl, MVT::i32));
2327 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2328 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2329 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2330
2331 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2332 DAG.getConstant(VTBits, dl, MVT::i32),
2333 ISD::SETGE);
2334 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2335 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2336
2337 SDValue Ops[2] = { Lo, Hi };
2338 return DAG.getMergeValues(Ops, dl);
2339 }
2340}
2341
2342/// If the types match, convert the generic copysign to the NVPTXISD version,
2343/// otherwise bail ensuring that mismatched cases are properly expaned.
2344SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2345 SelectionDAG &DAG) const {
2346 EVT VT = Op.getValueType();
2347 SDLoc DL(Op);
2348
2349 SDValue In1 = Op.getOperand(0);
2350 SDValue In2 = Op.getOperand(1);
2351 EVT SrcVT = In2.getValueType();
2352
2353 if (!SrcVT.bitsEq(VT))
2354 return SDValue();
2355
2356 return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2);
2357}
2358
2359SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2360 EVT VT = Op.getValueType();
2361
2362 if (VT == MVT::f32)
2363 return LowerFROUND32(Op, DAG);
2364
2365 if (VT == MVT::f64)
2366 return LowerFROUND64(Op, DAG);
2367
2368 llvm_unreachable("unhandled type");
2369}
2370
2371// This is the the rounding method used in CUDA libdevice in C like code:
2372// float roundf(float A)
2373// {
2374// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2375// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2376// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2377// }
2378SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2379 SelectionDAG &DAG) const {
2380 SDLoc SL(Op);
2381 SDValue A = Op.getOperand(0);
2382 EVT VT = Op.getValueType();
2383
2384 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2385
2386 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2387 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2388 const unsigned SignBitMask = 0x80000000;
2389 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2390 DAG.getConstant(SignBitMask, SL, MVT::i32));
2391 const unsigned PointFiveInBits = 0x3F000000;
2392 SDValue PointFiveWithSignRaw =
2393 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2394 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2395 SDValue PointFiveWithSign =
2396 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2397 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2398 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2399
2400 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2401 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2402 SDValue IsLarge =
2403 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2404 ISD::SETOGT);
2405 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2406
2407 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2408 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2409 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2410 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2411 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2412}
2413
2414// The implementation of round(double) is similar to that of round(float) in
2415// that they both separate the value range into three regions and use a method
2416// specific to the region to round the values. However, round(double) first
2417// calculates the round of the absolute value and then adds the sign back while
2418// round(float) directly rounds the value with sign.
2419SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2420 SelectionDAG &DAG) const {
2421 SDLoc SL(Op);
2422 SDValue A = Op.getOperand(0);
2423 EVT VT = Op.getValueType();
2424
2425 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2426
2427 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2428 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2429 DAG.getConstantFP(0.5, SL, VT));
2430 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2431
2432 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2433 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2434 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2435 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2436 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2437 DAG.getConstantFP(0, SL, VT),
2438 RoundedA);
2439
2440 // Add sign to rounded_A
2441 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2442 DAG.getNode(ISD::FTRUNC, SL, VT, A);
2443
2444 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2445 SDValue IsLarge =
2446 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2447 ISD::SETOGT);
2448 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2449}
2450
2452 EVT VT = N->getValueType(0);
2453 EVT NVT = MVT::f32;
2454 if (VT.isVector()) {
2455 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2456 }
2457 SDLoc DL(N);
2458 SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2459 SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2460 SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2461 return DAG.getFPExtendOrRound(Res, DL, VT);
2462}
2463
2464SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2465 SelectionDAG &DAG) const {
2466 if (useF32FTZ(DAG.getMachineFunction())) {
2467 return PromoteBinOpToF32(Op.getNode(), DAG);
2468 }
2469 return Op;
2470}
2471
2472SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2473 SelectionDAG &DAG) const {
2474 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2475
2476 if (Op.getValueType() == MVT::bf16) {
2477 SDLoc Loc(Op);
2478 return DAG.getNode(
2479 ISD::FP_ROUND, Loc, MVT::bf16,
2480 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2481 DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
2482 }
2483
2484 // Everything else is considered legal.
2485 return Op;
2486}
2487
2488SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2489 SelectionDAG &DAG) const {
2490 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2491
2492 if (Op.getOperand(0).getValueType() == MVT::bf16) {
2493 SDLoc Loc(Op);
2494 return DAG.getNode(
2495 Op.getOpcode(), Loc, Op.getValueType(),
2496 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2497 }
2498
2499 // Everything else is considered legal.
2500 return Op;
2501}
2502
2503SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2504 SelectionDAG &DAG) const {
2505 EVT NarrowVT = Op.getValueType();
2506 SDValue Wide = Op.getOperand(0);
2507 EVT WideVT = Wide.getValueType();
2508 if (NarrowVT.getScalarType() == MVT::bf16) {
2509 const TargetLowering *TLI = STI.getTargetLowering();
2510 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2511 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2512 }
2513 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2514 // This combination was the first to support f32 -> bf16.
2515 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2516 if (WideVT.getScalarType() == MVT::f32) {
2517 return Op;
2518 }
2519 if (WideVT.getScalarType() == MVT::f64) {
2520 SDLoc Loc(Op);
2521 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2522 // the hardware f32 -> bf16 instruction.
2524 WideVT.changeElementType(*DAG.getContext(), MVT::f32), Wide, Loc,
2525 DAG);
2526 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2527 }
2528 }
2529 return TLI->expandFP_ROUND(Op.getNode(), DAG);
2530 }
2531 }
2532
2533 // Everything else is considered legal.
2534 return Op;
2535}
2536
2537SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2538 SelectionDAG &DAG) const {
2539 SDValue Narrow = Op.getOperand(0);
2540 EVT NarrowVT = Narrow.getValueType();
2541 EVT WideVT = Op.getValueType();
2542 if (NarrowVT.getScalarType() == MVT::bf16) {
2543 if (WideVT.getScalarType() == MVT::f32 &&
2544 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2545 SDLoc Loc(Op);
2546 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2547 }
2548 if (WideVT.getScalarType() == MVT::f64 &&
2549 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2550 EVT F32 = NarrowVT.changeElementType(*DAG.getContext(), MVT::f32);
2551 SDLoc Loc(Op);
2552 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2553 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2554 } else {
2555 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2556 }
2557 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2558 }
2559 }
2560
2561 // Everything else is considered legal.
2562 return Op;
2563}
2564
2566 SDLoc DL(Op);
2567 if (Op.getValueType() != MVT::v2i16)
2568 return Op;
2569 EVT EltVT = Op.getValueType().getVectorElementType();
2570 SmallVector<SDValue> VecElements;
2571 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2572 SmallVector<SDValue> ScalarArgs;
2573 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2574 [&](const SDUse &O) {
2575 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2576 O.get(), DAG.getIntPtrConstant(I, DL));
2577 });
2578 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2579 }
2580 SDValue V =
2581 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2582 return V;
2583}
2584
2586 bool hasOffset = false) {
2587 // skip lowering if the vector operand is already legalized
2588 if (!Op->getOperand(hasOffset ? 4 : 3).getValueType().isVector())
2589 return Op;
2590
2591 SDNode *N = Op.getNode();
2592 SDLoc DL(N);
2594
2595 // split the vector argument
2596 for (size_t I = 0; I < N->getNumOperands(); I++) {
2597 SDValue Val = N->getOperand(I);
2598 EVT ValVT = Val.getValueType();
2599 if (ValVT.isVector()) {
2600 EVT EltVT = ValVT.getVectorElementType();
2601 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2602 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2603 DAG.getIntPtrConstant(J, DL)));
2604 } else
2605 Ops.push_back(Val);
2606 }
2607
2609 SDValue Tcgen05StNode =
2610 DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, N->getVTList(), Ops,
2611 MemSD->getMemoryVT(), MemSD->getMemOperand());
2612
2613 return Tcgen05StNode;
2614}
2615
2617 SDLoc DL(Op);
2618 SDValue Src = Op.getOperand(0);
2619 EVT VT = Op.getValueType();
2620
2621 switch (VT.getSimpleVT().SimpleTy) {
2622 case MVT::i16: {
2623 SDValue Extended = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Src);
2624 SDValue Swapped =
2625 getPRMT(Extended, DAG.getConstant(0, DL, MVT::i32), 0x7701, DL, DAG);
2626 return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Swapped);
2627 }
2628 case MVT::i32: {
2629 return getPRMT(Src, DAG.getConstant(0, DL, MVT::i32), 0x0123, DL, DAG);
2630 }
2631 case MVT::v2i16: {
2632 SDValue Converted = DAG.getBitcast(MVT::i32, Src);
2633 SDValue Swapped =
2634 getPRMT(Converted, DAG.getConstant(0, DL, MVT::i32), 0x2301, DL, DAG);
2635 return DAG.getNode(ISD::BITCAST, DL, MVT::v2i16, Swapped);
2636 }
2637 case MVT::i64: {
2638 SDValue UnpackSrc =
2639 DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, Src);
2640 SDValue SwappedLow =
2641 getPRMT(UnpackSrc.getValue(0), DAG.getConstant(0, DL, MVT::i32), 0x0123,
2642 DL, DAG);
2643 SDValue SwappedHigh =
2644 getPRMT(UnpackSrc.getValue(1), DAG.getConstant(0, DL, MVT::i32), 0x0123,
2645 DL, DAG);
2646 return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64,
2647 {SwappedHigh, SwappedLow});
2648 }
2649 default:
2650 llvm_unreachable("unsupported type for bswap");
2651 }
2652}
2653
2654static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
2655 switch (IID) {
2656 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2657 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1;
2658 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2659 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2;
2660 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2661 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2662 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2663 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2664 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2665 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2666 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2667 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2668 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2669 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2670 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2671 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2672 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2673 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2674 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2675 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2676 case Intrinsic::
2677 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2678 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2679 case Intrinsic::
2680 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2681 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2682 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2683 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1;
2684 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2685 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2;
2686 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2687 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2688 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2689 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2690 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2691 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2692 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2693 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2694 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2695 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2696 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2697 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2698 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2699 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2700 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2701 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2702 case Intrinsic::
2703 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2704 return NVPTXISD::
2705 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2706 case Intrinsic::
2707 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2708 return NVPTXISD::
2709 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2710 };
2711 llvm_unreachable("unhandled tcgen05.mma.disable_output_lane intrinsic");
2712}
2713
2715 SDNode *N = Op.getNode();
2716 SDLoc DL(N);
2717 unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
2718
2720 // split the vector argument
2721 for (size_t I = 0; I < N->getNumOperands(); I++) {
2722 if (I == 1)
2723 continue; // skip IID
2724 SDValue Val = N->getOperand(I);
2725 EVT ValVT = Val.getValueType();
2726 if (ValVT.isVector()) {
2727 EVT EltVT = ValVT.getVectorElementType();
2728 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2729 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2730 DAG.getIntPtrConstant(J, DL)));
2731 } else
2732 Ops.push_back(Val);
2733 }
2734
2736 SDValue Tcgen05MMANode = DAG.getMemIntrinsicNode(
2737 getTcgen05MMADisableOutputLane(IID), DL, N->getVTList(), Ops,
2738 MemSD->getMemoryVT(), MemSD->getMemOperand());
2739
2740 return Tcgen05MMANode;
2741}
2742
2743// Lower vector return type of tcgen05.ld intrinsics
2744static std::optional<std::pair<SDValue, SDValue>>
2745lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
2746 SDLoc DL(N);
2747 EVT ResVT = N->getValueType(0);
2748 if (!ResVT.isVector())
2749 return {}; // already legalized.
2750
2751 const unsigned NumElts = ResVT.getVectorNumElements();
2752
2753 // Create the return type of the instructions
2754 SmallVector<EVT, 5> ListVTs;
2755 for (unsigned i = 0; i < NumElts; ++i)
2756 ListVTs.push_back(MVT::i32);
2757
2758 ListVTs.push_back(N->getValueType(1)); // Chain
2759
2760 SDVTList ResVTs = DAG.getVTList(ListVTs);
2761
2762 SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
2763 N->getOperand(2)};
2764
2765 if (HasOffset) {
2766 Ops.push_back(N->getOperand(3)); // offset
2767 Ops.push_back(N->getOperand(4)); // Pack flag
2768 } else
2769 Ops.push_back(N->getOperand(3)); // Pack flag
2770
2772 SDValue NewNode =
2774 MemSD->getMemoryVT(), MemSD->getMemOperand());
2775
2776 // split the vector result
2777 SmallVector<SDValue, 4> ScalarRes;
2778 for (unsigned i = 0; i < NumElts; ++i) {
2779 SDValue Res = NewNode.getValue(i);
2780 ScalarRes.push_back(Res);
2781 }
2782
2783 SDValue Chain = NewNode.getValue(NumElts);
2784 SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
2785 return {{BuildVector, Chain}};
2786}
2787
2789 unsigned Val) {
2790 SDNode *N = Op.getNode();
2791 SDLoc DL(N);
2792
2793 const Function &Fn = DAG.getMachineFunction().getFunction();
2794
2795 unsigned AS = 0;
2796 if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(N))
2797 AS = MemN->getAddressSpace();
2798 Type *PtrTy = PointerType::get(*DAG.getContext(), AS);
2800
2802 Fn,
2803 "Intrinsic " +
2804 Intrinsic::getName(N->getConstantOperandVal(1), {PtrTy}, M) +
2805 " with value " + Twine(Val) +
2806 " is not supported on the given target.",
2807 DL.getDebugLoc()));
2808 return Op.getOperand(0);
2809}
2810
2812 SDNode *N = Op.getNode();
2813 SDLoc DL(N);
2814
2815 // immediate argument representing elemtype
2816 unsigned Val = N->getConstantOperandVal(3);
2817
2819 Val))
2820 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2821
2822 return Op;
2823}
2824
2826 SDNode *N = Op.getNode();
2827 SDLoc DL(N);
2828
2829 // immediate argument representing swizzle mode
2830 unsigned Val = N->getConstantOperandVal(3);
2831
2833 Val))
2834 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2835
2836 return Op;
2837}
2838
2840 SDNode *N = Op.getNode();
// Custom lowering for void intrinsics. Dispatches on the intrinsic ID in
// operand 1 to a dedicated helper. IDs without custom handling return Op
// unchanged.
// NOTE(review): the signature line (2839) is missing from this extraction.
2841 SDValue Intrin = N->getOperand(1);
2842
2843 // Get the intrinsic ID
2844 unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2845 switch (IntrinNo) {
2846 default:
2847 break;
// tcgen05.st variants that have no offset operand.
2848 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
2849 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
2850 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
2851 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
2852 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
2853 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
2854 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
2855 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
2856 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
2857 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
2858 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
2859 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
2860 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
2861 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
2862 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
2863 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
2864 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
2865 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
2866 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
2867 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
2868 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
2869 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
2870 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
2871 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
2872 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
2873 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
2874 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2875 return lowerTcgen05St(Op, DAG);
// The 16x32bx2 shape carries an extra offset operand.
2876 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
2877 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
2878 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
2879 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
2880 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
2881 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
2882 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
2883 return lowerTcgen05St(Op, DAG, /* hasOffset */ true);
// tcgen05.mma variants that take a disable-output-lane operand.
2884 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2885 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2886 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2887 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2888 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2889 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2890 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2891 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2892 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2893 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2894 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2895 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2896 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2897 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2898 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2899 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2900 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2901 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2902 case Intrinsic::
2903 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2904 case Intrinsic::
2905 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2906 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2907 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2908 case Intrinsic::
2909 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2910 case Intrinsic::
2911 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
// NOTE(review): the handler call for the disable_output_lane cases (line
// 2912) is missing from this extraction. As shown, these cases would fall
// through into the tensormap.replace.elemtype case.
2913 case Intrinsic::nvvm_tensormap_replace_elemtype:
2914 return lowerTensormapReplaceElemtype(Op, DAG);
2915 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
// NOTE(review): the return for the swizzle_mode case (line 2916) is missing
// from this extraction. It presumably calls the swizzle-mode validation helper.
2917 }
2918 return Op;
2919}
2920
2922 SelectionDAG &DAG) {
// Lowers the clusterlaunchcontrol.query_cancel.* intrinsics. Operand 0 is
// the intrinsic ID. Operand 1 is the i128 try-cancel response, which is
// split into two i64 halves that feed the matching NVPTXISD node.
// Returns an empty SDValue if operand 1 is no longer i128.
// NOTE(review): the first signature line (2921) is missing from this
// extraction.
2923
2924 SDNode *N = Op.getNode();
2925 if (N->getOperand(1).getValueType() != MVT::i128) {
2926 // return, if the operand is already lowered
2927 return SDValue();
2928 }
2929
2930 unsigned IID =
2931 cast<ConstantSDNode>(N->getOperand(0).getNode())->getZExtValue();
// Map each query intrinsic to its target node. Any other ID is a bug in
// the caller's dispatch.
2932 auto Opcode = [&]() {
2933 switch (IID) {
2934 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2935 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
2936 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2937 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
2938 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2939 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
2940 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2941 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
2942 default:
2943 llvm_unreachable("unsupported/unhandled intrinsic");
2944 }
2945 }();
2946
2947 SDLoc DL(N);
2948 SDValue TryCancelResponse = N->getOperand(1);
// Reinterpret the 128-bit response as <2 x i64>. Element 0 and element 1
// become the first and second operand of the new node.
2949 SDValue Cast = DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, TryCancelResponse);
2950 SDValue TryCancelResponse0 =
2951 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
2952 DAG.getIntPtrConstant(0, DL));
2953 SDValue TryCancelResponse1 =
2954 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
2955 DAG.getIntPtrConstant(1, DL));
2956
2957 return DAG.getNode(Opcode, DL, N->getVTList(),
2958 {TryCancelResponse0, TryCancelResponse1});
2959}
2960
2962 SDNode *N = Op.getNode();
// Lowers the f32x4 -> packed 8-bit/4-bit float conversion intrinsics that
// use stochastic rounding ("rs"). The new node's operands are the four f32
// elements of operand 1, the random bits from operand 2, and an i32
// conversion-mode flag (RS, optionally | RELU_FLAG).
// NOTE(review): the signature line (2961), the declaration of `Ops` (2970)
// and the line bringing `CvtMode` into scope (2975) are missing from this
// extraction.
2963 SDLoc DL(N);
2964 SDValue F32Vec = N->getOperand(1);
2965 SDValue RBits = N->getOperand(2);
2966
2967 unsigned IntrinsicID = N->getConstantOperandVal(0);
2968
2969 // Extract the 4 float elements from the vector
2971 for (unsigned i = 0; i < 4; ++i)
2972 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2973 DAG.getIntPtrConstant(i, DL)));
2974
2976
// Per-intrinsic target opcode, result type and mode flag. The e2m1 forms
// produce an i16 result; all others produce v4i8.
2977 auto [OpCode, RetTy, CvtModeFlag] =
2978 [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
2979 switch (IntrinsicID) {
2980 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2981 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
2982 CvtMode::RS | CvtMode::RELU_FLAG};
2983 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
2984 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2985 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2986 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
2987 CvtMode::RS | CvtMode::RELU_FLAG};
2988 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
2989 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2990 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2991 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
2992 CvtMode::RS | CvtMode::RELU_FLAG};
2993 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
2994 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2995 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2996 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
2997 CvtMode::RS | CvtMode::RELU_FLAG};
2998 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
2999 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
3000 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
3001 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
3002 CvtMode::RS | CvtMode::RELU_FLAG};
3003 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
3004 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
3005 default:
3006 llvm_unreachable("unsupported/unhandled intrinsic");
3007 }
3008 }();
3009
// Trailing operands: random bits, then the mode flag as an i32 constant.
3010 Ops.push_back(RBits);
3011 Ops.push_back(DAG.getConstant(CvtModeFlag, DL, MVT::i32));
3012
3013 return DAG.getNode(OpCode, DL, RetTy, Ops);
3014}
3015
3017 const unsigned Mode = [&]() {
// Lowers the nvvm.prmt* byte-permute intrinsics to getPRMT. The intrinsic
// ID (operand 0) selects the PRMT mode. Operand 1 is the first source. When
// the node has four operands, operand 2 is the second source; otherwise a
// constant 0 is used. The last operand is always the selector.
// NOTE(review): the signature line (3016) and each case's `return <mode>`
// line (3020, 3022, ..., 3032) are missing from this extraction.
3018 switch (Op->getConstantOperandVal(0)) {
3019 case Intrinsic::nvvm_prmt:
3021 case Intrinsic::nvvm_prmt_b4e:
3023 case Intrinsic::nvvm_prmt_ecl:
3025 case Intrinsic::nvvm_prmt_ecr:
3027 case Intrinsic::nvvm_prmt_f4e:
3029 case Intrinsic::nvvm_prmt_rc16:
3031 case Intrinsic::nvvm_prmt_rc8:
3033 default:
3034 llvm_unreachable("unsupported/unhandled intrinsic");
3035 }
3036 }();
3037 SDLoc DL(Op);
3038 SDValue A = Op->getOperand(1);
3039 SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(2)
3040 : DAG.getConstant(0, DL, MVT::i32);
3041 SDValue Selector = (Op->op_end() - 1)->get();
3042 return getPRMT(A, B, Selector, DL, DAG, Mode);
3043}
3044
3045#define TCGEN05_LD_RED_INTR(SHAPE, NUM, TYPE) \
3046 Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_x##NUM##_##TYPE
3047
3048#define TCGEN05_LD_RED_INST(SHAPE, NUM, TYPE) \
3049 NVPTXISD::TCGEN05_LD_RED_##SHAPE##_X##NUM##_##TYPE
3050
3051static unsigned getTcgen05LdRedID(Intrinsic::ID IID) {
3052 switch (IID) {
3053 case TCGEN05_LD_RED_INTR(32x32b, 2, f32):
3054 return TCGEN05_LD_RED_INST(32x32b, 2, F32);
3055 case TCGEN05_LD_RED_INTR(32x32b, 4, f32):
3056 return TCGEN05_LD_RED_INST(32x32b, 4, F32);
3057 case TCGEN05_LD_RED_INTR(32x32b, 8, f32):
3058 return TCGEN05_LD_RED_INST(32x32b, 8, F32);
3059 case TCGEN05_LD_RED_INTR(32x32b, 16, f32):
3060 return TCGEN05_LD_RED_INST(32x32b, 16, F32);
3061 case TCGEN05_LD_RED_INTR(32x32b, 32, f32):
3062 return TCGEN05_LD_RED_INST(32x32b, 32, F32);
3063 case TCGEN05_LD_RED_INTR(32x32b, 64, f32):
3064 return TCGEN05_LD_RED_INST(32x32b, 64, F32);
3065 case TCGEN05_LD_RED_INTR(32x32b, 128, f32):
3066 return TCGEN05_LD_RED_INST(32x32b, 128, F32);
3067 case TCGEN05_LD_RED_INTR(16x32bx2, 2, f32):
3068 return TCGEN05_LD_RED_INST(16x32bx2, 2, F32);
3069 case TCGEN05_LD_RED_INTR(16x32bx2, 4, f32):
3070 return TCGEN05_LD_RED_INST(16x32bx2, 4, F32);
3071 case TCGEN05_LD_RED_INTR(16x32bx2, 8, f32):
3072 return TCGEN05_LD_RED_INST(16x32bx2, 8, F32);
3073 case TCGEN05_LD_RED_INTR(16x32bx2, 16, f32):
3074 return TCGEN05_LD_RED_INST(16x32bx2, 16, F32);
3075 case TCGEN05_LD_RED_INTR(16x32bx2, 32, f32):
3076 return TCGEN05_LD_RED_INST(16x32bx2, 32, F32);
3077 case TCGEN05_LD_RED_INTR(16x32bx2, 64, f32):
3078 return TCGEN05_LD_RED_INST(16x32bx2, 64, F32);
3079 case TCGEN05_LD_RED_INTR(16x32bx2, 128, f32):
3080 return TCGEN05_LD_RED_INST(16x32bx2, 128, F32);
3081 case TCGEN05_LD_RED_INTR(32x32b, 2, i32):
3082 return TCGEN05_LD_RED_INST(32x32b, 2, I32);
3083 case TCGEN05_LD_RED_INTR(32x32b, 4, i32):
3084 return TCGEN05_LD_RED_INST(32x32b, 4, I32);
3085 case TCGEN05_LD_RED_INTR(32x32b, 8, i32):
3086 return TCGEN05_LD_RED_INST(32x32b, 8, I32);
3087 case TCGEN05_LD_RED_INTR(32x32b, 16, i32):
3088 return TCGEN05_LD_RED_INST(32x32b, 16, I32);
3089 case TCGEN05_LD_RED_INTR(32x32b, 32, i32):
3090 return TCGEN05_LD_RED_INST(32x32b, 32, I32);
3091 case TCGEN05_LD_RED_INTR(32x32b, 64, i32):
3092 return TCGEN05_LD_RED_INST(32x32b, 64, I32);
3093 case TCGEN05_LD_RED_INTR(32x32b, 128, i32):
3094 return TCGEN05_LD_RED_INST(32x32b, 128, I32);
3095 case TCGEN05_LD_RED_INTR(16x32bx2, 2, i32):
3096 return TCGEN05_LD_RED_INST(16x32bx2, 2, I32);
3097 case TCGEN05_LD_RED_INTR(16x32bx2, 4, i32):
3098 return TCGEN05_LD_RED_INST(16x32bx2, 4, I32);
3099 case TCGEN05_LD_RED_INTR(16x32bx2, 8, i32):
3100 return TCGEN05_LD_RED_INST(16x32bx2, 8, I32);
3101 case TCGEN05_LD_RED_INTR(16x32bx2, 16, i32):
3102 return TCGEN05_LD_RED_INST(16x32bx2, 16, I32);
3103 case TCGEN05_LD_RED_INTR(16x32bx2, 32, i32):
3104 return TCGEN05_LD_RED_INST(16x32bx2, 32, I32);
3105 case TCGEN05_LD_RED_INTR(16x32bx2, 64, i32):
3106 return TCGEN05_LD_RED_INST(16x32bx2, 64, I32);
3107 case TCGEN05_LD_RED_INTR(16x32bx2, 128, i32):
3108 return TCGEN05_LD_RED_INST(16x32bx2, 128, I32);
3109 default:
3110 llvm_unreachable("Invalid tcgen05.ld.red intrinsic ID");
3111 }
3112}
3113
3114// Lower vector return type of tcgen05.ld.red intrinsics.
// Returns std::nullopt when the result type is not a vector (already
// legalized). Otherwise returns {rebuilt vector result, reduction value,
// output chain}, produced by a single memory-intrinsic node whose results
// are NumElts scalars, one extra scalar for the reduction, and the chain.
// NOTE(review): the signature line (3116) and the line defining `MemSD`
// (3142) are missing from this extraction. MemSD is presumably `N` cast to
// a memory SDNode. Confirm.
3115static std::optional<std::tuple<SDValue, SDValue, SDValue>>
3117 SDLoc DL(N);
3118 EVT ResVT = N->getValueType(0);
3119 if (!ResVT.isVector())
3120 return {}; // already legalized.
3121
3122 const unsigned NumElts = ResVT.getVectorNumElements();
3123
3124 // Create the return type of the instructions
3125 // +1 represents the reduction value
// This brace form selects SmallVector's (count, value) constructor: NumElts + 1
// copies of f32 or i32, chosen by the vector's element type.
3126 SmallVector<EVT, 132> ListVTs{
3127 NumElts + 1,
3128 ResVT.getVectorElementType().isFloatingPoint() ? MVT::f32 : MVT::i32};
3129
3130 ListVTs.push_back(MVT::Other); // Chain
3131
3132 SDVTList ResVTs = DAG.getVTList(ListVTs);
3133
3134 // Prepare the Operands
3135 SmallVector<SDValue, 8> Ops{N->getOperand(0)}; // Chain
3136
3137 // skip IID at index 1
3138 for (unsigned i = 2; i < N->getNumOperands(); i++)
3139 Ops.push_back(N->getOperand(i));
3140
3141 unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
3143 SDValue NewNode =
3144 DAG.getMemIntrinsicNode(getTcgen05LdRedID(IID), DL, ResVTs, Ops,
3145 MemSD->getMemoryVT(), MemSD->getMemOperand());
3146
3147 // Split vector result
3148 SmallVector<SDValue, 132> ScalarRes;
3149 for (unsigned i = 0; i < NumElts; ++i) {
3150 SDValue Res = NewNode.getValue(i);
3151 ScalarRes.push_back(Res);
3152 }
3153
// Result layout of NewNode: [0, NumElts) vector lanes, NumElts = reduction
// value, NumElts + 1 = chain.
3154 SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
3155 SDValue RedResult = NewNode.getValue(NumElts);
3156 SDValue Chain = NewNode.getValue(NumElts + 1);
3157 return {{BuildVector, RedResult, Chain}};
3158}
3159
3161 switch (Op->getConstantOperandVal(1)) {
// Custom lowering for intrinsics that take a chain. Operand 1 holds the
// intrinsic ID. Only intrinsics whose (legal) v2 result type keeps them out
// of ReplaceNodeResults are handled here. They are rebuilt as merged
// (values..., chain) results, or an empty SDValue is returned when the
// helper reports nothing to do.
// NOTE(review): the signature line (3160) is missing from this extraction.
3162 default:
3163 return Op;
3164
3165 // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
3166 // lower them through LowerOperation() instead of ReplaceNodeResults().
3167 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
3168 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
3169 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
3170 if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG))
3171 return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
3172 return SDValue();
3173
// The 16x32bx2 shape carries an extra offset operand.
3174 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
3175 if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG, /*HasOffset=*/true))
3176 return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
3177 return SDValue();
3178
// ld.red variants produce three results: vector, reduction value, chain.
3179 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
3180 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
3181 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32:
3182 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32:
3183 if (auto Res = lowerTcgen05LdRed(Op.getNode(), DAG))
3184 return DAG.getMergeValues(
3185 {std::get<0>(*Res), std::get<1>(*Res), std::get<2>(*Res)}, SDLoc(Op));
3186 return SDValue();
3187 }
3188}
3189
3191 switch (Op->getConstantOperandVal(0)) {
// Custom lowering for chainless intrinsics. Operand 0 holds the intrinsic
// ID. Unhandled IDs return Op unchanged.
// NOTE(review): the signature line (3190) is missing from this extraction.
3192 default:
3193 return Op;
3194 case Intrinsic::nvvm_prmt:
3195 case Intrinsic::nvvm_prmt_b4e:
3196 case Intrinsic::nvvm_prmt_ecl:
3197 case Intrinsic::nvvm_prmt_ecr:
3198 case Intrinsic::nvvm_prmt_f4e:
3199 case Intrinsic::nvvm_prmt_rc16:
3200 case Intrinsic::nvvm_prmt_rc8:
3201 return lowerPrmtIntrinsic(Op, DAG);
// The address-space wrapper is a no-op at this point: forward its operand.
3202 case Intrinsic::nvvm_internal_addrspace_wrap:
3203 return Op.getOperand(1);
3204 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
3205 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
3206 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
3207 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
// NOTE(review): the return for the clusterlaunchcontrol cases (line 3208) is
// missing from this extraction. As shown, they would fall through to
// lowerCvtRSIntrinsics.
3209 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
3210 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
3211 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
3212 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
3213 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
3214 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
3215 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
3216 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
3217 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
3218 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
3219 return lowerCvtRSIntrinsics(Op, DAG);
3220 }
3221}
3222
3223// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
3224// Lower these into a node returning the correct type which is zero-extended
3225// back to the correct size.
// Only i64 operands are expected here; the assert enforces that. The count
// fits in 32 bits and is never negative, so the zero extension is tagged
// NonNeg.
// NOTE(review): the signature line (3226) is missing from this extraction.
3227 SDValue V = Op->getOperand(0);
3228 assert(V.getValueType() == MVT::i64 &&
3229 "Unexpected CTLZ/CTPOP type to legalize");
3230
3231 SDLoc DL(Op);
3232 SDValue CT = DAG.getNode(Op->getOpcode(), DL, MVT::i32, V);
3233 return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CT, SDNodeFlags::NonNeg);
3234}
3235
3237 unsigned Opcode, SelectionDAG &DAG) {
// Expands a 64-bit funnel shift (Opcode is FSHL or FSHR) with a constant
// shift amount into two 32-bit funnel shifts over the halves of A and B.
// Returns an empty SDValue when ShiftAmount is not a constant, presumably
// so the legalizer falls back to its default expansion.
// NOTE(review): the first signature line (3236) is missing from this
// extraction.
3238 assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);
3239
3240 const auto *AmtConst = dyn_cast<ConstantSDNode>(ShiftAmount);
3241 if (!AmtConst)
3242 return SDValue();
// 64-bit funnel shifts use the amount modulo 64.
3243 const auto Amt = AmtConst->getZExtValue() & 63;
3244
3245 SDValue UnpackA =
3246 DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, A);
3247 SDValue UnpackB =
3248 DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, B);
3249
3250 // Arch is little endian: 0 = low bits, 1 = high bits
3251 SDValue ALo = UnpackA.getValue(0);
3252 SDValue AHi = UnpackA.getValue(1);
3253 SDValue BLo = UnpackB.getValue(0);
3254 SDValue BHi = UnpackB.getValue(1);
3255
3256 // The bitfield consists of { AHi : ALo : BHi : BLo }
3257 //
3258 // * FSHL, Amt < 32 - The window will contain { AHi : ALo : BHi }
3259 // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
3260 // * FSHR, Amt < 32 - The window will contain { ALo : BHi : BLo }
3261 // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
3262 //
3263 // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
3264 // are not needed at all. Amt = 0 is a no-op producing either A or B depending
3265 // on the direction. Amt = 32 can be implemented by a packing and unpacking
3266 // move to select and arrange the 32bit values. For simplicity, these cases
3267 // are not handled here explicitly and instead we rely on DAGCombiner to
3268 // remove the no-op funnel shifts we insert.
3269 auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
3270 ? std::make_tuple(AHi, ALo, BHi)
3271 : std::make_tuple(ALo, BHi, BLo);
3272
// Within the chosen 96-bit window, shift by the remaining amount (mod 32).
// The result's high word funnels {High, Mid}; the low word funnels {Mid, Low}.
3273 SDValue NewAmt = DAG.getConstant(Amt & 31, DL, MVT::i32);
3274 SDValue RHi = DAG.getNode(Opcode, DL, MVT::i32, {High, Mid, NewAmt});
3275 SDValue RLo = DAG.getNode(Opcode, DL, MVT::i32, {Mid, Low, NewAmt});
3276
// Repack as i64 with the low word first.
3277 return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64, {RLo, RHi});
3278}
3279
// Lowers ISD::FSHL/FSHR by forwarding (A, B, amount) and the node's own
// opcode to expandFSH64.
// NOTE(review): the signature line (3280) is missing from this extraction.
3281 return expandFSH64(Op->getOperand(0), Op->getOperand(1), Op->getOperand(2),
3282 SDLoc(Op), Op->getOpcode(), DAG);
3283}
3284
// Lowers ISD::ROTL/ROTR as a funnel shift whose two inputs are the same
// value: rotl(x, n) == fshl(x, x, n) and rotr(x, n) == fshr(x, x, n).
// NOTE(review): the signature line (3285) is missing from this extraction.
3286 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
3287 return expandFSH64(Op->getOperand(0), Op->getOperand(0), Op->getOperand(1),
3288 SDLoc(Op), Opcode, DAG);
3289}
3290
3292 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
3293 // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
3294 // the semantics of LLVM's frem.
// The node's fast-math flags are propagated to the div and trunc. When they
// include ninf, the infinity guard below is skipped.
// NOTE(review): the signature line (3291) and the trailing arguments of the
// FMUL/FSUB calls (3304, 3306) are missing from this extraction.
3295 SDLoc DL(Op);
3296 SDValue X = Op->getOperand(0);
3297 SDValue Y = Op->getOperand(1);
3298 EVT Ty = Op.getValueType();
3299 SDNodeFlags Flags = Op->getFlags();
3300
3301 SDValue Div = DAG.getNode(ISD::FDIV, DL, Ty, X, Y, Flags);
3302 SDValue Trunc = DAG.getNode(ISD::FTRUNC, DL, Ty, Div, Flags);
3303 SDValue Mul = DAG.getNode(ISD::FMUL, DL, Ty, Trunc, Y,
3305 SDValue Sub = DAG.getNode(ISD::FSUB, DL, Ty, X, Mul,
3307
3308 if (Flags.hasNoInfs())
3309 return Sub;
3310
3311 // If Y is infinite, return X
3312 SDValue AbsY = DAG.getNode(ISD::FABS, DL, Ty, Y);
3313 SDValue Inf =
3314 DAG.getConstantFP(APFloat::getInf(Ty.getFltSemantics()), DL, Ty);
3315 SDValue IsInf = DAG.getSetCC(DL, MVT::i1, AbsY, Inf, ISD::SETEQ);
3316 return DAG.getSelect(DL, Ty, IsInf, X, Sub);
3317}
3318
3320 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
// Custom lowering of an i1 ISD::SELECT (cond, true, false).
// NOTE(review): the signature line (3319) is missing from this extraction.
3321
3322 SDValue Cond = Op->getOperand(0);
3323 SDValue TrueVal = Op->getOperand(1);
3324 SDValue FalseVal = Op->getOperand(2);
3325 SDLoc DL(Op);
3326
3327 // If both operands are truncated, we push the select through the truncates.
// The select is done in the narrower of the two source types. Only bit 0
// survives the final truncate to i1, so extending or truncating either side
// to that type is safe.
3328 if (TrueVal.getOpcode() == ISD::TRUNCATE &&
3329 FalseVal.getOpcode() == ISD::TRUNCATE) {
3330 TrueVal = TrueVal.getOperand(0);
3331 FalseVal = FalseVal.getOperand(0);
3332
3333 EVT VT = TrueVal.getSimpleValueType().bitsLE(FalseVal.getSimpleValueType())
3334 ? TrueVal.getValueType()
3335 : FalseVal.getValueType();
3336 TrueVal = DAG.getAnyExtOrTrunc(TrueVal, DL, VT);
3337 FalseVal = DAG.getAnyExtOrTrunc(FalseVal, DL, VT);
3338 SDValue Select = DAG.getSelect(DL, VT, Cond, TrueVal, FalseVal);
3339 return DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
3340 }
3341
3342 // Otherwise, expand the select into a series of logical operations. These
3343 // often can be folded into other operations either by us or ptxas.
// Both arms are frozen first. A select does not propagate poison from the
// unchosen arm, but AND/OR would.
3344 TrueVal = DAG.getFreeze(TrueVal);
3345 FalseVal = DAG.getFreeze(FalseVal);
3346 SDValue And1 = DAG.getNode(ISD::AND, DL, MVT::i1, Cond, TrueVal);
3347 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
3348 SDValue And2 = DAG.getNode(ISD::AND, DL, MVT::i1, NotCond, FalseVal);
3349 SDValue Or = DAG.getNode(ISD::OR, DL, MVT::i1, And1, And2);
3350 return Or;
3351}
3352
3354 SDNode *N = Op.getNode();
// Lowers a masked vector store (ISD::MSTORE) of a 256-bit vector into an
// NVPTXISD::StoreV4 (4 x 64-bit) or StoreV8 (8 x 32-bit) node. Masked-off
// lanes are encoded as a sentinel register 0 operand, which is handled in
// tablegen.
// NOTE(review): the signature line (3353), the declaration of `Ops` (3385),
// and lines 3405/3410 are missing from this extraction.
3355
3356 SDValue Chain = N->getOperand(0);
3357 SDValue Val = N->getOperand(1);
3358 SDValue BasePtr = N->getOperand(2);
3359 SDValue Offset = N->getOperand(3);
3360 SDValue Mask = N->getOperand(4);
3361
3362 SDLoc DL(N);
3363 EVT ValVT = Val.getValueType();
3364 MemSDNode *MemSD = cast<MemSDNode>(N);
3365 assert(ValVT.isVector() && "Masked vector store must have vector type");
3366 assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
3367 "Unexpected alignment for masked store");
3368
3369 unsigned Opcode = 0;
3370 switch (ValVT.getSimpleVT().SimpleTy) {
3371 default:
3372 llvm_unreachable("Unexpected masked vector store type");
3373 case MVT::v4i64:
3374 case MVT::v4f64: {
3375 Opcode = NVPTXISD::StoreV4;
3376 break;
3377 }
3378 case MVT::v8i32:
3379 case MVT::v8f32: {
3380 Opcode = NVPTXISD::StoreV8;
3381 break;
3382 }
3383 }
3384
3386
3387 // Construct the new SDNode. First operand is the chain.
3388 Ops.push_back(Chain);
3389
3390 // The next N operands are the values to store. Encode the mask into the
3391 // values using the sentinel register 0 to represent a masked-off element.
3392 assert(Mask.getValueType().isVector() &&
3393 Mask.getValueType().getVectorElementType() == MVT::i1 &&
3394 "Mask must be a vector of i1");
3395 assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
3396 "Mask expected to be a BUILD_VECTOR");
3397 assert(Mask.getValueType().getVectorNumElements() ==
3398 ValVT.getVectorNumElements() &&
3399 "Mask size must be the same as the vector size");
3400 for (auto [I, Op] : enumerate(Mask->ops())) {
3401 // Mask elements must be constants.
3402 if (Op.getNode()->getAsZExtVal() == 0) {
3403 // Append a sentinel register 0 to the Ops vector to represent a masked
3404 // off element, this will be handled in tablegen
3406 ValVT.getVectorElementType()));
3407 } else {
3408 // Extract the element from the vector to store
3409 SDValue ExtVal =
3411 Val, DAG.getIntPtrConstant(I, DL));
3412 Ops.push_back(ExtVal);
3413 }
3414 }
3415
3416 // Next, the pointer operand.
3417 Ops.push_back(BasePtr);
3418
3419 // Finally, the offset operand. We expect this to always be undef, and it will
3420 // be ignored in lowering, but to mirror the handling of the other vector
3421 // store instructions we include it in the new SDNode.
3422 assert(Offset.getOpcode() == ISD::UNDEF &&
3423 "Offset operand expected to be undef");
3424 Ops.push_back(Offset);
3425
// Reuse the original memory VT and memory operand so alias information
// is preserved.
3426 SDValue NewSt =
3427 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3428 MemSD->getMemoryVT(), MemSD->getMemOperand());
3429
3430 return NewSt;
3431}
3432
// Entry point for custom lowering: dispatches each node marked Custom to
// its dedicated lowering routine by opcode. Opcodes with no handler are a
// bug (llvm_unreachable).
// NOTE(review): this extraction is missing the signature line (3434) and
// several case labels (e.g. 3442, 3444, 3446, 3452-3465, 3523, 3535), so
// some returns below appear without their case label.
3433SDValue
3435 switch (Op.getOpcode()) {
// An empty SDValue here presumably defers to the legalizer's default
// handling for these opcodes.
3436 case ISD::RETURNADDR:
3437 return SDValue();
3438 case ISD::FRAMEADDR:
3439 return SDValue();
3440 case ISD::ADDRSPACECAST:
3441 return LowerADDRSPACECAST(Op, DAG);
3443 return lowerIntrinsicWChain(Op, DAG);
3445 return lowerIntrinsicWOChain(Op, DAG);
3447 return lowerIntrinsicVoid(Op, DAG);
3448 case ISD::BUILD_VECTOR:
3449 return LowerBUILD_VECTOR(Op, DAG);
3450 case ISD::BITCAST:
3451 return LowerBITCAST(Op, DAG);
3453 return Op;
3455 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
3457 return LowerINSERT_VECTOR_ELT(Op, DAG);
3459 return LowerVECTOR_SHUFFLE(Op, DAG);
3461 return LowerCONCAT_VECTORS(Op, DAG);
3466 return LowerVECREDUCE(Op, DAG);
3467 case ISD::STORE:
3468 return LowerSTORE(Op, DAG);
3469 case ISD::MSTORE: {
3470 assert(STI.has256BitVectorLoadStore(
3471 cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
3472 "Masked store vector not supported on subtarget.");
3473 return lowerMSTORE(Op, DAG);
3474 }
3475 case ISD::LOAD:
3476 return LowerLOAD(Op, DAG);
3477 case ISD::MLOAD:
3478 return LowerMLOAD(Op, DAG);
3479 case ISD::SHL_PARTS:
3480 return LowerShiftLeftParts(Op, DAG);
3481 case ISD::SRA_PARTS:
3482 case ISD::SRL_PARTS:
3483 return LowerShiftRightParts(Op, DAG);
3484 case ISD::SELECT:
3485 return lowerSELECT(Op, DAG);
3486 case ISD::FROUND:
3487 return LowerFROUND(Op, DAG);
3488 case ISD::FCOPYSIGN:
3489 return LowerFCOPYSIGN(Op, DAG);
3490 case ISD::SINT_TO_FP:
3491 case ISD::UINT_TO_FP:
3492 return LowerINT_TO_FP(Op, DAG);
3493 case ISD::FP_TO_SINT:
3494 case ISD::FP_TO_UINT:
3495 // fptosi/fptoui to i1 truncate toward zero, so the only defined results
3496 // are {0,-1} (signed) and {0,1} (unsigned); every other input results in
3497 // poison. Thus we can simply lower to `x <= -1.0` or `x >= 1.0`.
3498 if (Op.getValueType() == MVT::i1) {
3499 SDLoc DL(Op);
3500 SDValue X = Op.getOperand(0);
3501 bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT;
3502 return DAG.getSetCC(
3503 DL, MVT::i1, X,
3504 DAG.getConstantFP(IsSigned ? -1.0 : 1.0, DL, X.getValueType()),
3505 IsSigned ? ISD::SETOLE : ISD::SETOGE);
3506 }
3507 return LowerFP_TO_INT(Op, DAG);
3508 case ISD::FP_ROUND:
3509 return LowerFP_ROUND(Op, DAG);
3510 case ISD::FP_EXTEND:
3511 return LowerFP_EXTEND(Op, DAG);
3512 case ISD::VAARG:
3513 return LowerVAARG(Op, DAG);
3514 case ISD::VASTART:
3515 return LowerVASTART(Op, DAG);
3516 case ISD::FSHL:
3517 case ISD::FSHR:
3518 return lowerFSH(Op, DAG);
3519 case ISD::ROTL:
3520 case ISD::ROTR:
3521 return lowerROT(Op, DAG);
3522 case ISD::ABS:
3524 case ISD::SMIN:
3525 case ISD::SMAX:
3526 case ISD::UMIN:
3527 case ISD::UMAX:
3528 case ISD::ADD:
3529 case ISD::SUB:
3530 case ISD::MUL:
3531 case ISD::SHL:
3532 case ISD::SREM:
3533 case ISD::UREM:
3534 return LowerVectorArith(Op, DAG);
3536 return LowerDYNAMIC_STACKALLOC(Op, DAG);
3537 case ISD::STACKRESTORE:
3538 return LowerSTACKRESTORE(Op, DAG);
3539 case ISD::STACKSAVE:
3540 return LowerSTACKSAVE(Op, DAG);
3541 case ISD::CopyToReg:
3542 return LowerCopyToReg_128(Op, DAG);
3543 case ISD::FADD:
3544 case ISD::FSUB:
3545 case ISD::FMUL:
3546 // Used only for bf16 on SM80, where we select fma for non-ftz operation
3547 return PromoteBinOpIfF32FTZ(Op, DAG);
3548 case ISD::CTPOP:
3549 case ISD::CTLZ:
3550 return lowerCTLZCTPOP(Op, DAG);
3551 case ISD::FREM:
3552 return lowerFREM(Op, DAG);
3553 case ISD::BSWAP:
3554 return lowerBSWAP(Op, DAG);
3555 default:
3556 llvm_unreachable("Custom lowering not defined for operation");
3557 }
3558}
3559
3560// This will prevent AsmPrinter from trying to print the jump tables itself.
3564
// Lowers ISD::ADDRSPACECAST.
//  * Casts where either side is the generic space are returned unchanged.
//  * Shared <-> shared-cluster casts are split into two casts through the
//    generic space.
//  * Any other cast between two non-generic spaces becomes UNDEF.
// NOTE(review): lines 3567 (the definition of N), 3575-3576 (the rest of the
// shared/shared-cluster condition) and 3580 (GenerictVT's initializer) are
// missing from this extraction.
3565SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
3566 SelectionDAG &DAG) const {
3568 unsigned SrcAS = N->getSrcAddressSpace();
3569 unsigned DestAS = N->getDestAddressSpace();
3570 if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
3571 DestAS != llvm::ADDRESS_SPACE_GENERIC) {
3572 // Shared and SharedCluster can be converted to each other through generic
3573 // space
3574 if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
3577 DestAS == llvm::ADDRESS_SPACE_SHARED)) {
3578 SDLoc DL(Op.getNode());
3579 const MVT GenerictVT =
3581 SDValue GenericConversion = DAG.getAddrSpaceCast(
3582 DL, GenerictVT, Op.getOperand(0), SrcAS, ADDRESS_SPACE_GENERIC);
// Despite its name, this second hop serves both directions: DestAS may be
// either shared or shared-cluster.
3583 SDValue SharedClusterConversion =
3584 DAG.getAddrSpaceCast(DL, Op.getValueType(), GenericConversion,
3585 ADDRESS_SPACE_GENERIC, DestAS);
3586 return SharedClusterConversion;
3587 }
3588
3589 return DAG.getUNDEF(Op.getValueType());
3590 }
3591
3592 return Op;
3593}
3594
3595// This function is almost a copy of SelectionDAG::expandVAArg().
3596// The only difference is that this one produces loads from local address space.
// Steps: load the current va_list pointer, round it up to the argument's
// alignment when that exceeds the minimum stack-argument alignment, store
// back the pointer advanced past this argument, then load the argument.
// Operands: 0 = chain, 1 = va_list pointer, 2 = SrcValue, 3 = alignment.
// NOTE(review): lines 3625 (the increment amount) and 3633 (the pointer type
// used for SrcV) are missing from this extraction.
3597SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
3598 const TargetLowering *TLI = STI.getTargetLowering();
3599 SDLoc DL(Op);
3600
3601 SDNode *Node = Op.getNode();
3602 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
3603 EVT VT = Node->getValueType(0);
3604 auto *Ty = VT.getTypeForEVT(*DAG.getContext());
3605 SDValue Tmp1 = Node->getOperand(0);
3606 SDValue Tmp2 = Node->getOperand(1);
3607 const MaybeAlign MA(Node->getConstantOperandVal(3));
3608
3609 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
3610 Tmp1, Tmp2, MachinePointerInfo(V));
3611 SDValue VAList = VAListLoad;
3612
// Align up: (p + A - 1) & -A.
3613 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
3614 VAList = DAG.getNode(
3615 ISD::ADD, DL, VAList.getValueType(), VAList,
3616 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
3617
3618 VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
3619 DAG.getSignedConstant(-(int64_t)MA->value(), DL,
3620 VAList.getValueType()));
3621 }
3622
3623 // Increment the pointer, VAList, to the next vaarg
3624 Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
3626 DL, VAList.getValueType()));
3627
3628 // Store the incremented VAList to the legalized pointer
3629 Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
3630 MachinePointerInfo(V));
3631
3632 const Value *SrcV = Constant::getNullValue(
3634
3635 // Load the actual argument out of the pointer VAList
3636 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
3637}
3638
3639SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
3640 const TargetLowering *TLI = STI.getTargetLowering();
3641 SDLoc DL(Op);
3642 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
3643
3644 // Store the address of unsized array <function>_vararg[] in the ap object.
3645 SDValue VAReg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
3646
3647 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
3648 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
3649 MachinePointerInfo(SV));
3650}
3651
// Converts a masked vector load into a plain load node. It also returns a
// per-byte "used bytes" mask: bit k is set when byte k of the loaded vector
// belongs to an enabled mask lane. The mask is UINT32_MAX when the subtarget
// lacks the used-bytes-mask pragma.
// Operands of N: 0 = chain, 1 = base pointer, 3 = constant i1 mask,
// 4 = passthru (poison/undef only).
// NOTE(review): the line with this function's name and first parameter
// (3653) is missing from this extraction.
3652static std::pair<MemSDNode *, uint32_t>
3654 const NVPTXSubtarget &STI) {
3655 SDValue Chain = N->getOperand(0);
3656 SDValue BasePtr = N->getOperand(1);
3657 SDValue Mask = N->getOperand(3);
3658 [[maybe_unused]] SDValue Passthru = N->getOperand(4);
3659
3660 SDLoc DL(N);
3661 EVT ResVT = N->getValueType(0);
3662 assert(ResVT.isVector() && "Masked vector load must have vector type");
3663 // While we only expect poison passthru vectors as an input to the backend,
3664 // when the legalization framework splits a poison vector in half, it creates
3665 // two undef vectors, so we can technically expect those too.
3666 assert((Passthru.getOpcode() == ISD::POISON ||
3667 Passthru.getOpcode() == ISD::UNDEF) &&
3668 "Passthru operand expected to be poison or undef");
3669
3670 // Extract the mask and convert it to a uint32_t representing the used bytes
3671 // of the entire vector load
3672 uint32_t UsedBytesMask = 0;
3673 uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
3674 assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
3675 uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
3676 uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
3677
// Walk lanes from last to first so lane 0 ends up in the lowest bits.
3678 for (SDValue Op : reverse(Mask->ops())) {
3679 // We technically only want to do this shift for every
3680 // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
3681 // so this shift is a no-op.
3682 UsedBytesMask <<= ElementSizeInBytes;
3683
3684 // Mask elements must be constants.
3685 if (Op->getAsZExtVal() != 0)
3686 UsedBytesMask |= ElementMask;
3687 }
3688
// NOTE(review): the UINT32_MAX comparison only detects an all-on mask for
// 32-byte vectors; for smaller vectors an all-on mask has fewer set bits.
3689 assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
3690 "Unexpected masked load with elements masked all on or all off");
3691
3692 // Create a new load sd node to be handled normally by ReplaceLoadVector.
3693 MemSDNode *NewLD = cast<MemSDNode>(
3694 DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
3695
3696 // If our subtarget does not support the used bytes mask pragma, "drop" the
3697 // mask by setting it to UINT32_MAX
3698 if (!STI.hasUsedBytesMaskPragma())
3699 UsedBytesMask = UINT32_MAX;
3700
3701 return {NewLD, UsedBytesMask};
3702}
3703
3704/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
3705static std::optional<std::pair<SDValue, SDValue>>
3708 const EVT ResVT = LD->getValueType(0);
3709 const EVT MemVT = LD->getMemoryVT();
3710
3711 // If we're doing sign/zero extension as part of the load, avoid lowering to
3712 // a LoadV node. TODO: consider relaxing this restriction.
3713 if (ResVT != MemVT)
3714 return std::nullopt;
3715
3716 const auto NumEltsAndEltVT =
3717 getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
3718 if (!NumEltsAndEltVT)
3719 return std::nullopt;
3720 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3721
3722 Align Alignment = LD->getAlign();
3723 const auto &TD = DAG.getDataLayout();
3724 Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
3725 if (Alignment < PrefAlign) {
3726 // This load is not sufficiently aligned, so bail out and let this vector
3727 // load be scalarized. Note that we may still be able to emit smaller
3728 // vector loads. For example, if we are loading a <4 x float> with an
3729 // alignment of 8, this check will fail but the legalizer will try again
3730 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3731 return std::nullopt;
3732 }
3733
3734 // If we have a masked load, convert it to a normal load now
3735 std::optional<uint32_t> UsedBytesMask = std::nullopt;
3736 if (LD->getOpcode() == ISD::MLOAD)
3737 std::tie(LD, UsedBytesMask) =
3739
3740 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
3741 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
3742 // loaded type to i16 and propagate the "real" type as the memory type.
3743 const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
3744
3745 unsigned Opcode;
3746 switch (NumElts) {
3747 default:
3748 return std::nullopt;
3749 case 2:
3750 Opcode = NVPTXISD::LoadV2;
3751 break;
3752 case 4:
3753 Opcode = NVPTXISD::LoadV4;
3754 break;
3755 case 8:
3756 Opcode = NVPTXISD::LoadV8;
3757 break;
3758 }
3759 auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
3760 ListVTs.push_back(MVT::Other);
3761 SDVTList LdResVTs = DAG.getVTList(ListVTs);
3762
3763 SDLoc DL(LD);
3764
3765 // Copy regular operands
3766 SmallVector<SDValue, 8> OtherOps(LD->ops());
3767
3768 OtherOps.push_back(
3769 DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
3770
3771 // The select routine does not have access to the LoadSDNode instance, so
3772 // pass along the extension information
3773 OtherOps.push_back(
3774 DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
3775
3776 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
3777 LD->getMemOperand());
3778
3779 SmallVector<SDValue> ScalarRes;
3780 if (EltVT.isVector()) {
3782 assert(NumElts * EltVT.getVectorNumElements() ==
3783 ResVT.getVectorNumElements());
3784 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
3785 // into individual elements.
3786 for (const unsigned I : llvm::seq(NumElts)) {
3787 SDValue SubVector = NewLD.getValue(I);
3788 DAG.ExtractVectorElements(SubVector, ScalarRes);
3789 }
3790 } else {
3791 for (const unsigned I : llvm::seq(NumElts)) {
3792 SDValue Res = NewLD.getValue(I);
3793 if (LoadEltVT != EltVT)
3794 Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
3795 ScalarRes.push_back(Res);
3796 }
3797 }
3798
3799 SDValue LoadChain = NewLD.getValue(NumElts);
3800
3801 const MVT BuildVecVT =
3802 MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
3803 SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
3804 SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
3805
3806 return {{LoadValue, LoadChain}};
3807}
3808
3811 const NVPTXSubtarget &STI) {
3812 if (auto Res = replaceLoadVector(N, DAG, STI))
3813 Results.append({Res->first, Res->second});
3814}
3815
3817 const NVPTXSubtarget &STI) {
3818 if (auto Res = replaceLoadVector(N, DAG, STI))
3819 return DAG.getMergeValues({Res->first, Res->second}, SDLoc(N));
3820 return SDValue();
3821}
3822
3823// v = ld i1* addr
3824// =>
3825// v1 = ld i8* addr (-> i16)
3826// v = trunc i16 to i1
3828 SDLoc dl(LD);
3829 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3830 assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
3831 SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
3832 LD->getBasePtr(), LD->getPointerInfo(),
3833 MVT::i8, LD->getAlign(),
3834 LD->getMemOperand()->getFlags());
3835 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
3836 // The legalizer (the caller) is expecting two values from the legalized
3837 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3838 // in LegalizeDAG.cpp which also uses MergeValues.
3839 return DAG.getMergeValues({result, LD->getChain()}, dl);
3840}
3841
3842SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3843 LoadSDNode *LD = cast<LoadSDNode>(Op);
3844
3845 if (Op.getValueType() == MVT::i1)
3846 return lowerLOADi1(LD, DAG);
3847
3848 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3849 // how they'll be lowered in ISel anyway, and by doing this a little earlier
3850 // we allow for more DAG combine opportunities.
3851 if (LD->getExtensionType() == ISD::EXTLOAD) {
3852 assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3853 "Unexpected fpext-load");
3854 return DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Op), Op.getValueType(),
3855 LD->getChain(), LD->getBasePtr(), LD->getMemoryVT(),
3856 LD->getMemOperand());
3857 }
3858
3859 llvm_unreachable("Unexpected custom lowering for load");
3860}
3861
3862SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
3863 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3864 // masked loads of these types and have to handle them here.
3865 // v2f32 also needs to be handled here if the subtarget has f32x2
3866 // instructions, making it legal.
3867 //
3868 // Note: misaligned masked loads should never reach this point
3869 // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
3870 // will validate alignment. Therefore, we do not need to special case handle
3871 // them here.
3872 EVT VT = Op.getValueType();
3873 if (NVPTX::isPackedVectorTy(VT)) {
3875 cast<MemSDNode>(Op.getNode()), DAG, STI);
3876 MemSDNode *LD = std::get<0>(Result);
3877 uint32_t UsedBytesMask = std::get<1>(Result);
3878
3879 SDLoc DL(LD);
3880
3881 // Copy regular operands
3882 SmallVector<SDValue, 8> OtherOps(LD->ops());
3883
3884 OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
3885
3886 // We currently are not lowering extending loads, but pass the extension
3887 // type anyway as later handling expects it.
3888 OtherOps.push_back(
3889 DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
3890 SDValue NewLD =
3891 DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
3892 LD->getMemoryVT(), LD->getMemOperand());
3893 return NewLD;
3894 }
3895 return SDValue();
3896}
3897
3899 const NVPTXSubtarget &STI) {
3900 MemSDNode *N = cast<MemSDNode>(Op.getNode());
3901 SDValue Val = N->getOperand(1);
3902 SDLoc DL(N);
3903 const EVT ValVT = Val.getValueType();
3904 const EVT MemVT = N->getMemoryVT();
3905
3906 // If we're truncating as part of the store, avoid lowering to a StoreV node.
3907 // TODO: consider relaxing this restriction.
3908 if (ValVT != MemVT)
3909 return SDValue();
3910
3911 const auto NumEltsAndEltVT =
3912 getVectorLoweringShape(ValVT, STI, N->getAddressSpace());
3913 if (!NumEltsAndEltVT)
3914 return SDValue();
3915 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3916
3917 const DataLayout &TD = DAG.getDataLayout();
3918
3919 Align Alignment = N->getAlign();
3920 Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
3921 if (Alignment < PrefAlign) {
3922 // This store is not sufficiently aligned, so bail out and let this vector
3923 // store be scalarized. Note that we may still be able to emit smaller
3924 // vector stores. For example, if we are storing a <4 x float> with an
3925 // alignment of 8, this check will fail but the legalizer will try again
3926 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3927 return SDValue();
3928 }
3929
3930 unsigned Opcode;
3931 switch (NumElts) {
3932 default:
3933 return SDValue();
3934 case 2:
3935 Opcode = NVPTXISD::StoreV2;
3936 break;
3937 case 4:
3938 Opcode = NVPTXISD::StoreV4;
3939 break;
3940 case 8:
3941 Opcode = NVPTXISD::StoreV8;
3942 break;
3943 }
3944
3946
3947 // First is the chain
3948 Ops.push_back(N->getOperand(0));
3949
3950 // Then the split values
3951 if (EltVT.isVector()) {
3953 assert(NumElts * EltVT.getVectorNumElements() ==
3954 ValVT.getVectorNumElements());
3955 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
3956 // stored as b32s
3957 const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
3958 for (const unsigned I : llvm::seq(NumElts)) {
3959 SmallVector<SDValue, 4> SubVectorElts;
3960 DAG.ExtractVectorElements(Val, SubVectorElts, I * NumEltsPerSubVector,
3961 NumEltsPerSubVector);
3962 Ops.push_back(DAG.getBuildVector(EltVT, DL, SubVectorElts));
3963 }
3964 } else {
3965 SDValue V = DAG.getBitcast(MVT::getVectorVT(EltVT, NumElts), Val);
3966 for (const unsigned I : llvm::seq(NumElts)) {
3967 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, V,
3968 DAG.getIntPtrConstant(I, DL));
3969
3970 // Since StoreV2 is a target node, we cannot rely on DAG type
3971 // legalization. Therefore, we must ensure the type is legal. For i1 and
3972 // i8, we set the stored type to i16 and propagate the "real" type as the
3973 // memory type.
3974 if (EltVT.getSizeInBits() < 16)
3975 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
3976 Ops.push_back(ExtVal);
3977 }
3978 }
3979
3980 // Then any remaining arguments
3981 Ops.append(N->op_begin() + 2, N->op_end());
3982
3983 SDValue NewSt =
3984 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3985 N->getMemoryVT(), N->getMemOperand());
3986
3987 // return DCI.CombineTo(N, NewSt, true);
3988 return NewSt;
3989}
3990
3991SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3992 StoreSDNode *Store = cast<StoreSDNode>(Op);
3993 EVT VT = Store->getMemoryVT();
3994
3995 if (VT == MVT::i1)
3996 return LowerSTOREi1(Op, DAG);
3997
3998 // Lower store of any other vector type, including v2f32 as we want to break
3999 // it apart since this is not a widely-supported type.
4000 return lowerSTOREVector(Op, DAG, STI);
4001}
4002
4003// st i1 v, addr
4004// =>
4005// v1 = zxt v to i16
4006// st.u8 i16, addr
4007SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
4008 SDNode *Node = Op.getNode();
4009 SDLoc dl(Node);
4010 StoreSDNode *ST = cast<StoreSDNode>(Node);
4011 SDValue Tmp1 = ST->getChain();
4012 SDValue Tmp2 = ST->getBasePtr();
4013 SDValue Tmp3 = ST->getValue();
4014 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
4015 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
4016 SDValue Result =
4017 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
4018 ST->getAlign(), ST->getMemOperand()->getFlags());
4019 return Result;
4020}
4021
4022SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
4023 SelectionDAG &DAG) const {
4024 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
4025 // operand so that it can pass the legalization.
4026
4027 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
4028 "Custom lowering for 128-bit CopyToReg only");
4029
4030 SDNode *Node = Op.getNode();
4031 SDLoc DL(Node);
4032
4033 SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
4034 SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
4035 DAG.getIntPtrConstant(0, DL));
4036 SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
4037 DAG.getIntPtrConstant(1, DL));
4038
4040 SmallVector<EVT, 3> ResultsType(Node->values());
4041
4042 NewOps[0] = Op->getOperand(0); // Chain
4043 NewOps[1] = Op->getOperand(1); // Dst Reg
4044 NewOps[2] = Lo; // Lower 64-bit
4045 NewOps[3] = Hi; // Higher 64-bit
4046 if (Op.getNumOperands() == 4)
4047 NewOps[4] = Op->getOperand(3); // Glue if exists
4048
4049 return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
4050}
4051
4052unsigned NVPTXTargetLowering::getNumRegisters(
4053 LLVMContext &Context, EVT VT,
4054 std::optional<MVT> RegisterVT = std::nullopt) const {
4055 if (VT == MVT::i128 && RegisterVT == MVT::i128)
4056 return 1;
4057 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
4058}
4059
4060bool NVPTXTargetLowering::splitValueIntoRegisterParts(
4061 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
4062 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
4063 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
4064 Parts[0] = Val;
4065 return true;
4066 }
4067 return false;
4068}
4069
4070// This creates target external symbol for a function parameter.
4071// Name of the symbol is composed from its index and the function name.
4072// Negative index corresponds to special parameter (unsized array) used for
4073// passing variable arguments.
4074SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
4075 EVT T) const {
4076 StringRef SavedStr = nvTM->getStrPool().save(
4078 return DAG.getExternalSymbol(SavedStr.data(), T);
4079}
4080
4081SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
4082 EVT T) const {
4083 const StringRef SavedStr = nvTM->getStrPool().save("param" + Twine(I));
4084 return DAG.getExternalSymbol(SavedStr.data(), T);
4085}
4086
4088 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
4089 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
4090 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
4091 const DataLayout &DL = DAG.getDataLayout();
4092 LLVMContext &Ctx = *DAG.getContext();
4093 auto PtrVT = getPointerTy(DAG.getDataLayout());
4094
4095 const Function &F = DAG.getMachineFunction().getFunction();
4096 const bool IsKernel = isKernelFunction(F);
4097
4098 SDValue Root = DAG.getRoot();
4099 SmallVector<SDValue, 16> OutChains;
4100
4101 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
4102 // Ins.size() will be larger
4103 // * if there is an aggregate argument with multiple fields (each field
4104 // showing up separately in Ins)
4105 // * if there is a vector argument with more than typical vector-length
4106 // elements (generally if more than 4) where each vector element is
4107 // individually present in Ins.
4108 // So a different index should be used for indexing into Ins.
4109 // See similar issue in LowerCall.
4110
4111 auto AllIns = ArrayRef(Ins);
4112 for (const auto &Arg : F.args()) {
4113 const auto ArgIns = AllIns.take_while(
4114 [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
4115 AllIns = AllIns.drop_front(ArgIns.size());
4116
4117 Type *Ty = Arg.getType();
4118
4119 if (ArgIns.empty())
4120 report_fatal_error("Empty parameter types are not supported");
4121
4122 if (Arg.use_empty()) {
4123 // argument is dead
4124 for (const auto &In : ArgIns) {
4125 assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
4126 InVals.push_back(DAG.getUNDEF(In.VT));
4127 }
4128 continue;
4129 }
4130
4131 SDValue ArgSymbol = getParamSymbol(DAG, Arg.getArgNo(), PtrVT);
4132
4133 // In the following cases, assign a node order of "i+1"
4134 // to newly created nodes. The SDNodes for params have to
4135 // appear in the same order as their order of appearance
4136 // in the original function. "i+1" holds that order.
4137 if (Arg.hasByValAttr()) {
4138 // Param has ByVal attribute
4139 // Return MoveParam(param symbol).
4140 // Ideally, the param symbol can be returned directly,
4141 // but when SDNode builder decides to use it in a CopyToReg(),
4142 // machine instruction fails because TargetExternalSymbol
4143 // (not lowered) is target dependent, and CopyToReg assumes
4144 // the source is lowered.
4145 assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
4146 const auto &ByvalIn = ArgIns[0];
4147 assert(getValueType(DL, Ty) == ByvalIn.VT &&
4148 "Ins type did not match function type");
4149 assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
4150
4151 SDValue P;
4152 if (IsKernel) {
4153 assert(isParamGridConstant(Arg) && "ByVal argument must be lowered to "
4154 "grid_constant by NVPTXLowerArgs");
4155 P = ArgSymbol;
4156 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4157 } else {
4158 P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
4159 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4160 P = DAG.getAddrSpaceCast(dl, ByvalIn.VT, P, ADDRESS_SPACE_LOCAL,
4162 }
4163 InVals.push_back(P);
4164 } else {
4167 ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
4168 assert(VTs.size() == ArgIns.size() && "Size mismatch");
4169 assert(VTs.size() == Offsets.size() && "Size mismatch");
4170
4171 const Align ArgAlign = getFunctionArgumentAlignment(
4172 &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
4173
4174 unsigned I = 0;
4175 const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
4176 for (const unsigned NumElts : VI) {
4177 // i1 is loaded/stored as i8
4178 const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
4179 const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
4180
4181 SDValue VecAddr = DAG.getObjectPtrOffset(
4182 dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
4183
4184 const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
4185 const unsigned AS = IsKernel ? NVPTX::AddressSpace::EntryParam
4187 SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
4188 MachinePointerInfo(AS), PartAlign,
4191 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4192 for (const unsigned J : llvm::seq(NumElts)) {
4193 SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
4194
4195 Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags,
4196 DAG, dl);
4197 InVals.push_back(Elt);
4198 }
4199 I += NumElts;
4200 }
4201 }
4202 }
4203
4204 if (!OutChains.empty())
4205 DAG.setRoot(DAG.getTokenFactor(dl, OutChains));
4206
4207 return Chain;
4208}
4209
4210SDValue
4212 bool isVarArg,
4214 const SmallVectorImpl<SDValue> &OutVals,
4215 const SDLoc &dl, SelectionDAG &DAG) const {
4216 const Function &F = DAG.getMachineFunction().getFunction();
4217 Type *RetTy = F.getReturnType();
4218
4219 if (RetTy->isVoidTy()) {
4220 assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
4221 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
4222 }
4223
4224 const DataLayout &DL = DAG.getDataLayout();
4225 LLVMContext &Ctx = *DAG.getContext();
4226
4227 const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
4228 const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
4229
4230 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
4231 // 32-bits are sign extended or zero extended, depending on whether
4232 // they are signed or unsigned types.
4233 const bool ExtendIntegerRetVal =
4234 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
4235
4238 ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
4239 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
4240
4241 const auto GetRetVal = [&](unsigned I) -> SDValue {
4242 SDValue RetVal = OutVals[I];
4244 RetVal.getValueType() &&
4245 "OutVal type should always be legal");
4246
4247 const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
4248 const EVT StoreVT =
4249 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
4250 return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl);
4251 };
4252
4253 unsigned I = 0;
4254 const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
4255 for (const unsigned NumElts : VI) {
4256 const MaybeAlign CurrentAlign = ExtendIntegerRetVal
4257 ? MaybeAlign(std::nullopt)
4258 : commonAlignment(RetAlign, Offsets[I]);
4259
4261 NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
4262
4263 SDValue Ptr =
4264 DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
4265
4266 Chain = DAG.getStore(Chain, dl, Val, Ptr,
4268 CurrentAlign);
4269
4270 I += NumElts;
4271 }
4272
4273 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
4274}
4275
4277 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
4278 SelectionDAG &DAG) const {
4279 if (Constraint.size() > 1)
4280 return;
4282}
4283
// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need information that is only available in the
// "Value" type of the destination pointer, in particular its address space.
4291 MachineFunction &MF, unsigned Intrinsic) const {
4292 IntrinsicInfo Info;
4293 switch (Intrinsic) {
4294 default:
4295 return;
4296 case Intrinsic::nvvm_match_all_sync_i32p:
4297 case Intrinsic::nvvm_match_all_sync_i64p:
4298 Info.opc = ISD::INTRINSIC_W_CHAIN;
4299 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
4300 // in order to model data exchange with other threads, but perform no real
4301 // memory accesses.
4302 Info.memVT = MVT::i1;
4303
4304 // Our result depends on both our and other thread's arguments.
4306 Infos.push_back(Info);
4307 return;
4308 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
4309 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
4310 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
4311 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
4312 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
4313 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
4314 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
4315 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
4316 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
4317 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
4318 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
4319 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
4320 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
4321 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
4322 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
4323 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
4324 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
4325 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
4326 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
4327 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
4328 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
4329 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
4330 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
4331 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
4332 Info.opc = ISD::INTRINSIC_W_CHAIN;
4333 Info.memVT = MVT::v8f16;
4334 Info.ptrVal = I.getArgOperand(0);
4335 Info.offset = 0;
4336 Info.flags = MachineMemOperand::MOLoad;
4337 Info.align = Align(16);
4338 Infos.push_back(Info);
4339 return;
4340 }
4341 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
4342 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
4343 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
4344 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
4345 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
4346 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
4347 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
4348 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
4349 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
4350 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
4351 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
4352 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
4353 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
4354 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
4355 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
4356 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
4357 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
4358 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
4359 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
4360 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
4361 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
4362 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
4363 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
4364 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
4365 Info.opc = ISD::INTRINSIC_W_CHAIN;
4366 Info.memVT = MVT::v2i32;
4367 Info.ptrVal = I.getArgOperand(0);
4368 Info.offset = 0;
4369 Info.flags = MachineMemOperand::MOLoad;
4370 Info.align = Align(8);
4371 Infos.push_back(Info);
4372 return;
4373 }
4374
4375 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
4376 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
4377 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
4378 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
4379 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
4380 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
4381 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
4382 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
4383 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
4384 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
4385 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
4386 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
4387 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
4388 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
4389 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
4390 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
4391
4392 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
4393 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
4394 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
4395 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
4396 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
4397 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
4398 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
4399 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
4400 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
4401 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
4402 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
4403 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
4404 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
4405 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
4406 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
4407 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
4408 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
4409 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
4410 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
4411 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
4412 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
4413 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
4414 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
4415 Info.opc = ISD::INTRINSIC_W_CHAIN;
4416 Info.memVT = MVT::v4i32;
4417 Info.ptrVal = I.getArgOperand(0);
4418 Info.offset = 0;
4419 Info.flags = MachineMemOperand::MOLoad;
4420 Info.align = Align(16);
4421 Infos.push_back(Info);
4422 return;
4423 }
4424
4425 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
4426 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
4427 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
4428 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
4429 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
4430 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
4431 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
4432 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
4433
4434 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
4435 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
4436 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
4437 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
4438 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
4439 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
4440 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
4441 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
4442 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
4443 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
4444 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
4445 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
4446 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
4447 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
4448 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
4449 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
4450 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
4451 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
4452 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
4453 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
4454 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
4455 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
4456 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
4457 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
4458 Info.opc = ISD::INTRINSIC_W_CHAIN;
4459 Info.memVT = MVT::i32;
4460 Info.ptrVal = I.getArgOperand(0);
4461 Info.offset = 0;
4462 Info.flags = MachineMemOperand::MOLoad;
4463 Info.align = Align(4);
4464 Infos.push_back(Info);
4465 return;
4466 }
4467
4468 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
4469 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
4470 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
4471 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
4472 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
4473 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
4474 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
4475 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
4476 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
4477 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
4478 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
4479 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
4480 Info.opc = ISD::INTRINSIC_W_CHAIN;
4481 Info.memVT = MVT::v4f16;
4482 Info.ptrVal = I.getArgOperand(0);
4483 Info.offset = 0;
4484 Info.flags = MachineMemOperand::MOLoad;
4485 Info.align = Align(16);
4486 Infos.push_back(Info);
4487 return;
4488 }
4489
4490 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
4491 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
4492 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
4493 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
4494 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
4495 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
4496 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
4497 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
4498 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
4499 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
4500 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
4501 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
4502 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
4503 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
4504 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
4505 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
4506 Info.opc = ISD::INTRINSIC_W_CHAIN;
4507 Info.memVT = MVT::v8f32;
4508 Info.ptrVal = I.getArgOperand(0);
4509 Info.offset = 0;
4510 Info.flags = MachineMemOperand::MOLoad;
4511 Info.align = Align(16);
4512 Infos.push_back(Info);
4513 return;
4514 }
4515
4516 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
4517 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
4518 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
4519 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
4520
4521 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
4522 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
4523 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
4524 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
4525
4526 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
4527 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
4528 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
4529 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
4530 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
4531 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
4532 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
4533 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
4534 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
4535 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
4536 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
4537 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
4538 Info.opc = ISD::INTRINSIC_W_CHAIN;
4539 Info.memVT = MVT::v8i32;
4540 Info.ptrVal = I.getArgOperand(0);
4541 Info.offset = 0;
4542 Info.flags = MachineMemOperand::MOLoad;
4543 Info.align = Align(16);
4544 Infos.push_back(Info);
4545 return;
4546 }
4547
4548 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
4549 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
4550 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
4551 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
4552 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
4553 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
4554 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
4555 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
4556 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
4557 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
4558 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
4559 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
4560 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
4561 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
4562 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
4563 Info.opc = ISD::INTRINSIC_W_CHAIN;
4564 Info.memVT = MVT::v2i32;
4565 Info.ptrVal = I.getArgOperand(0);
4566 Info.offset = 0;
4567 Info.flags = MachineMemOperand::MOLoad;
4568 Info.align = Align(8);
4569 Infos.push_back(Info);
4570 return;
4571 }
4572
4573 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
4574 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
4575 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
4576 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
4577
4578 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
4579 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
4580 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
4581 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
4582 Info.opc = ISD::INTRINSIC_W_CHAIN;
4583 Info.memVT = MVT::f64;
4584 Info.ptrVal = I.getArgOperand(0);
4585 Info.offset = 0;
4586 Info.flags = MachineMemOperand::MOLoad;
4587 Info.align = Align(8);
4588 Infos.push_back(Info);
4589 return;
4590 }
4591
4592 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
4593 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
4594 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
4595 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
4596 Info.opc = ISD::INTRINSIC_W_CHAIN;
4597 Info.memVT = MVT::v2f64;
4598 Info.ptrVal = I.getArgOperand(0);
4599 Info.offset = 0;
4600 Info.flags = MachineMemOperand::MOLoad;
4601 Info.align = Align(16);
4602 Infos.push_back(Info);
4603 return;
4604 }
4605
4606 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
4607 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
4608 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
4609 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
4610 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
4611 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
4612 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
4613 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
4614 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
4615 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
4616 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
4617 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
4618 Info.opc = ISD::INTRINSIC_VOID;
4619 Info.memVT = MVT::v4f16;
4620 Info.ptrVal = I.getArgOperand(0);
4621 Info.offset = 0;
4622 Info.flags = MachineMemOperand::MOStore;
4623 Info.align = Align(16);
4624 Infos.push_back(Info);
4625 return;
4626 }
4627
4628 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
4629 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
4630 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
4631 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
4632 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
4633 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
4634 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
4635 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
4636 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
4637 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
4638 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
4639 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
4640 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
4641 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
4642 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
4643 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
4644 Info.opc = ISD::INTRINSIC_VOID;
4645 Info.memVT = MVT::v8f32;
4646 Info.ptrVal = I.getArgOperand(0);
4647 Info.offset = 0;
4648 Info.flags = MachineMemOperand::MOStore;
4649 Info.align = Align(16);
4650 Infos.push_back(Info);
4651 return;
4652 }
4653
4654 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
4655 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
4656 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
4657 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
4658 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
4659 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
4660 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
4661 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
4662 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
4663 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
4664 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
4665 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
4666 Info.opc = ISD::INTRINSIC_VOID;
4667 Info.memVT = MVT::v8i32;
4668 Info.ptrVal = I.getArgOperand(0);
4669 Info.offset = 0;
4670 Info.flags = MachineMemOperand::MOStore;
4671 Info.align = Align(16);
4672 Infos.push_back(Info);
4673 return;
4674 }
4675
4676 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
4677 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
4678 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
4679 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
4680 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
4681 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
4682 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
4683 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
4684 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
4685 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
4686 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
4687 Info.opc = ISD::INTRINSIC_VOID;
4688 Info.memVT = MVT::v2i32;
4689 Info.ptrVal = I.getArgOperand(0);
4690 Info.offset = 0;
4691 Info.flags = MachineMemOperand::MOStore;
4692 Info.align = Align(8);
4693 Infos.push_back(Info);
4694 return;
4695 }
4696
4697 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
4698 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
4699 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
4700 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
4701 Info.opc = ISD::INTRINSIC_VOID;
4702 Info.memVT = MVT::v2f64;
4703 Info.ptrVal = I.getArgOperand(0);
4704 Info.offset = 0;
4705 Info.flags = MachineMemOperand::MOStore;
4706 Info.align = Align(16);
4707 Infos.push_back(Info);
4708 return;
4709 }
4710
4711 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
4712 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
4713 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
4714 Info.opc = ISD::INTRINSIC_VOID;
4715 Info.memVT = MVT::i32;
4716 Info.ptrVal = I.getArgOperand(0);
4717 Info.offset = 0;
4718 Info.flags = MachineMemOperand::MOStore;
4719 Info.align = Align(4);
4720 Infos.push_back(Info);
4721 return;
4722 }
4723
4724 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
4725 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
4726 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
4727 Info.opc = ISD::INTRINSIC_VOID;
4728 Info.memVT = MVT::v4i32;
4729 Info.ptrVal = I.getArgOperand(0);
4730 Info.offset = 0;
4731 Info.flags = MachineMemOperand::MOStore;
4732 Info.align = Align(16);
4733 Infos.push_back(Info);
4734 return;
4735 }
4736
4737 case Intrinsic::nvvm_prefetch_tensormap: {
4738 auto &DL = I.getDataLayout();
4739 Info.opc = ISD::INTRINSIC_VOID;
4740 Info.memVT = getPointerTy(DL);
4741 Info.ptrVal = I.getArgOperand(0);
4742 Info.offset = 0;
4743 Info.flags =
4745 Info.align.reset();
4746 Infos.push_back(Info);
4747 return;
4748 }
4749
4750 case Intrinsic::nvvm_tensormap_replace_global_address:
4751 case Intrinsic::nvvm_tensormap_replace_global_stride: {
4752 Info.opc = ISD::INTRINSIC_VOID;
4753 Info.memVT = MVT::i64;
4754 Info.ptrVal = I.getArgOperand(0);
4755 Info.offset = 0;
4756 Info.flags = MachineMemOperand::MOStore;
4757 Info.align.reset();
4758 Infos.push_back(Info);
4759 return;
4760 }
4761
4762 case Intrinsic::nvvm_tensormap_replace_rank:
4763 case Intrinsic::nvvm_tensormap_replace_box_dim:
4764 case Intrinsic::nvvm_tensormap_replace_global_dim:
4765 case Intrinsic::nvvm_tensormap_replace_element_stride:
4766 case Intrinsic::nvvm_tensormap_replace_elemtype:
4767 case Intrinsic::nvvm_tensormap_replace_interleave_layout:
4768 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
4769 case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
4770 case Intrinsic::nvvm_tensormap_replace_fill_mode: {
4771 Info.opc = ISD::INTRINSIC_VOID;
4772 Info.memVT = MVT::i32;
4773 Info.ptrVal = I.getArgOperand(0);
4774 Info.offset = 0;
4775 Info.flags = MachineMemOperand::MOStore;
4776 Info.align.reset();
4777 Infos.push_back(Info);
4778 return;
4779 }
4780
4781 case Intrinsic::nvvm_ldu_global_i:
4782 case Intrinsic::nvvm_ldu_global_f:
4783 case Intrinsic::nvvm_ldu_global_p: {
4784 Info.opc = ISD::INTRINSIC_W_CHAIN;
4785 Info.memVT = getValueType(I.getDataLayout(), I.getType());
4786 Info.ptrVal = I.getArgOperand(0);
4787 Info.offset = 0;
4788 Info.flags = MachineMemOperand::MOLoad;
4789 Info.align = cast<ConstantInt>(I.getArgOperand(1))->getMaybeAlignValue();
4790
4791 Infos.push_back(Info);
4792 return;
4793 }
4794 case Intrinsic::nvvm_tex_1d_v4f32_s32:
4795 case Intrinsic::nvvm_tex_1d_v4f32_f32:
4796 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
4797 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
4798 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
4799 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
4800 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
4801 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
4802 case Intrinsic::nvvm_tex_2d_v4f32_s32:
4803 case Intrinsic::nvvm_tex_2d_v4f32_f32:
4804 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
4805 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
4806 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
4807 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
4808 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
4809 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
4810 case Intrinsic::nvvm_tex_3d_v4f32_s32:
4811 case Intrinsic::nvvm_tex_3d_v4f32_f32:
4812 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
4813 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
4814 case Intrinsic::nvvm_tex_cube_v4f32_f32:
4815 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
4816 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
4817 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
4818 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
4819 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
4820 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
4821 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
4822 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
4823 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
4824 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
4825 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
4826 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
4827 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
4828 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
4829 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
4830 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
4831 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
4832 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
4833 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
4834 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
4835 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
4836 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
4837 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
4838 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
4839 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
4840 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
4841 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
4842 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
4843 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
4844 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
4845 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
4846 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
4847 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
4848 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
4849 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
4850 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
4851 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
4852 Info.opc = ISD::INTRINSIC_W_CHAIN;
4853 Info.memVT = MVT::v4f32;
4854 Info.ptrVal = nullptr;
4855 Info.offset = 0;
4856 Info.flags = MachineMemOperand::MOLoad;
4857 Info.align = Align(16);
4858 Infos.push_back(Info);
4859 return;
4860
4861 case Intrinsic::nvvm_tex_1d_v4s32_s32:
4862 case Intrinsic::nvvm_tex_1d_v4s32_f32:
4863 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
4864 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
4865 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
4866 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
4867 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
4868 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
4869 case Intrinsic::nvvm_tex_2d_v4s32_s32:
4870 case Intrinsic::nvvm_tex_2d_v4s32_f32:
4871 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
4872 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
4873 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
4874 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
4875 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
4876 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
4877 case Intrinsic::nvvm_tex_3d_v4s32_s32:
4878 case Intrinsic::nvvm_tex_3d_v4s32_f32:
4879 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
4880 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
4881 case Intrinsic::nvvm_tex_cube_v4s32_f32:
4882 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
4883 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
4884 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
4885 case Intrinsic::nvvm_tex_cube_v4u32_f32:
4886 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
4887 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
4888 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
4889 case Intrinsic::nvvm_tex_1d_v4u32_s32:
4890 case Intrinsic::nvvm_tex_1d_v4u32_f32:
4891 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
4892 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
4893 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
4894 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
4895 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
4896 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
4897 case Intrinsic::nvvm_tex_2d_v4u32_s32:
4898 case Intrinsic::nvvm_tex_2d_v4u32_f32:
4899 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
4900 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
4901 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
4902 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
4903 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
4904 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
4905 case Intrinsic::nvvm_tex_3d_v4u32_s32:
4906 case Intrinsic::nvvm_tex_3d_v4u32_f32:
4907 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
4908 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
4909 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
4910 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
4911 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
4912 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4913 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4914 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4915 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4916 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4917 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4918 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4919 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4920 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4921 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4922 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4923 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4924 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4925 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4926 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4927 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4928 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4929 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4930 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4931 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4932 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4933 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4934 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4935 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4936 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4937 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4938 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4939 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4940 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4941 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4942 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4943 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4944 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4945 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4946 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4947 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4948 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4949 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4950 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4951 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4952 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4953 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4954 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4955 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4956 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4957 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4958 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4959 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4960 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4961 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4962 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4963 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4964 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4965 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4966 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4967 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4968 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4969 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4970 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4971 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4972 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4973 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4974 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4975 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4976 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4977 Info.opc = ISD::INTRINSIC_W_CHAIN;
4978 Info.memVT = MVT::v4i32;
4979 Info.ptrVal = nullptr;
4980 Info.offset = 0;
4981 Info.flags = MachineMemOperand::MOLoad;
4982 Info.align = Align(16);
4983 Infos.push_back(Info);
4984 return;
4985
4986 case Intrinsic::nvvm_suld_1d_i8_clamp:
4987 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4988 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4989 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4990 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4991 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4992 case Intrinsic::nvvm_suld_2d_i8_clamp:
4993 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4994 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4995 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4996 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4997 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4998 case Intrinsic::nvvm_suld_3d_i8_clamp:
4999 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
5000 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
5001 case Intrinsic::nvvm_suld_1d_i8_trap:
5002 case Intrinsic::nvvm_suld_1d_v2i8_trap:
5003 case Intrinsic::nvvm_suld_1d_v4i8_trap:
5004 case Intrinsic::nvvm_suld_1d_array_i8_trap:
5005 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
5006 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
5007 case Intrinsic::nvvm_suld_2d_i8_trap:
5008 case Intrinsic::nvvm_suld_2d_v2i8_trap:
5009 case Intrinsic::nvvm_suld_2d_v4i8_trap:
5010 case Intrinsic::nvvm_suld_2d_array_i8_trap:
5011 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
5012 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
5013 case Intrinsic::nvvm_suld_3d_i8_trap:
5014 case Intrinsic::nvvm_suld_3d_v2i8_trap:
5015 case Intrinsic::nvvm_suld_3d_v4i8_trap:
5016 case Intrinsic::nvvm_suld_1d_i8_zero:
5017 case Intrinsic::nvvm_suld_1d_v2i8_zero:
5018 case Intrinsic::nvvm_suld_1d_v4i8_zero:
5019 case Intrinsic::nvvm_suld_1d_array_i8_zero:
5020 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
5021 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
5022 case Intrinsic::nvvm_suld_2d_i8_zero:
5023 case Intrinsic::nvvm_suld_2d_v2i8_zero:
5024 case Intrinsic::nvvm_suld_2d_v4i8_zero:
5025 case Intrinsic::nvvm_suld_2d_array_i8_zero:
5026 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
5027 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
5028 case Intrinsic::nvvm_suld_3d_i8_zero:
5029 case Intrinsic::nvvm_suld_3d_v2i8_zero:
5030 case Intrinsic::nvvm_suld_3d_v4i8_zero:
5031 Info.opc = ISD::INTRINSIC_W_CHAIN;
5032 Info.memVT = MVT::i8;
5033 Info.ptrVal = nullptr;
5034 Info.offset = 0;
5035 Info.flags = MachineMemOperand::MOLoad;
5036 Info.align = Align(16);
5037 Infos.push_back(Info);
5038 return;
5039
5040 case Intrinsic::nvvm_suld_1d_i16_clamp:
5041 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
5042 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
5043 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
5044 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
5045 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
5046 case Intrinsic::nvvm_suld_2d_i16_clamp:
5047 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
5048 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
5049 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
5050 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
5051 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
5052 case Intrinsic::nvvm_suld_3d_i16_clamp:
5053 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
5054 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
5055 case Intrinsic::nvvm_suld_1d_i16_trap:
5056 case Intrinsic::nvvm_suld_1d_v2i16_trap:
5057 case Intrinsic::nvvm_suld_1d_v4i16_trap:
5058 case Intrinsic::nvvm_suld_1d_array_i16_trap:
5059 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
5060 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
5061 case Intrinsic::nvvm_suld_2d_i16_trap:
5062 case Intrinsic::nvvm_suld_2d_v2i16_trap:
5063 case Intrinsic::nvvm_suld_2d_v4i16_trap:
5064 case Intrinsic::nvvm_suld_2d_array_i16_trap:
5065 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
5066 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
5067 case Intrinsic::nvvm_suld_3d_i16_trap:
5068 case Intrinsic::nvvm_suld_3d_v2i16_trap:
5069 case Intrinsic::nvvm_suld_3d_v4i16_trap:
5070 case Intrinsic::nvvm_suld_1d_i16_zero:
5071 case Intrinsic::nvvm_suld_1d_v2i16_zero:
5072 case Intrinsic::nvvm_suld_1d_v4i16_zero:
5073 case Intrinsic::nvvm_suld_1d_array_i16_zero:
5074 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
5075 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
5076 case Intrinsic::nvvm_suld_2d_i16_zero:
5077 case Intrinsic::nvvm_suld_2d_v2i16_zero:
5078 case Intrinsic::nvvm_suld_2d_v4i16_zero:
5079 case Intrinsic::nvvm_suld_2d_array_i16_zero:
5080 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
5081 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
5082 case Intrinsic::nvvm_suld_3d_i16_zero:
5083 case Intrinsic::nvvm_suld_3d_v2i16_zero:
5084 case Intrinsic::nvvm_suld_3d_v4i16_zero:
5085 Info.opc = ISD::INTRINSIC_W_CHAIN;
5086 Info.memVT = MVT::i16;
5087 Info.ptrVal = nullptr;
5088 Info.offset = 0;
5089 Info.flags = MachineMemOperand::MOLoad;
5090 Info.align = Align(16);
5091 Infos.push_back(Info);
5092 return;
5093
5094 case Intrinsic::nvvm_suld_1d_i32_clamp:
5095 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
5096 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
5097 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
5098 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
5099 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
5100 case Intrinsic::nvvm_suld_2d_i32_clamp:
5101 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
5102 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
5103 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
5104 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
5105 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
5106 case Intrinsic::nvvm_suld_3d_i32_clamp:
5107 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
5108 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
5109 case Intrinsic::nvvm_suld_1d_i32_trap:
5110 case Intrinsic::nvvm_suld_1d_v2i32_trap:
5111 case Intrinsic::nvvm_suld_1d_v4i32_trap:
5112 case Intrinsic::nvvm_suld_1d_array_i32_trap:
5113 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
5114 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
5115 case Intrinsic::nvvm_suld_2d_i32_trap:
5116 case Intrinsic::nvvm_suld_2d_v2i32_trap:
5117 case Intrinsic::nvvm_suld_2d_v4i32_trap:
5118 case Intrinsic::nvvm_suld_2d_array_i32_trap:
5119 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
5120 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
5121 case Intrinsic::nvvm_suld_3d_i32_trap:
5122 case Intrinsic::nvvm_suld_3d_v2i32_trap:
5123 case Intrinsic::nvvm_suld_3d_v4i32_trap:
5124 case Intrinsic::nvvm_suld_1d_i32_zero:
5125 case Intrinsic::nvvm_suld_1d_v2i32_zero:
5126 case Intrinsic::nvvm_suld_1d_v4i32_zero:
5127 case Intrinsic::nvvm_suld_1d_array_i32_zero:
5128 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
5129 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
5130 case Intrinsic::nvvm_suld_2d_i32_zero:
5131 case Intrinsic::nvvm_suld_2d_v2i32_zero:
5132 case Intrinsic::nvvm_suld_2d_v4i32_zero:
5133 case Intrinsic::nvvm_suld_2d_array_i32_zero:
5134 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
5135 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
5136 case Intrinsic::nvvm_suld_3d_i32_zero:
5137 case Intrinsic::nvvm_suld_3d_v2i32_zero:
5138 case Intrinsic::nvvm_suld_3d_v4i32_zero:
5139 Info.opc = ISD::INTRINSIC_W_CHAIN;
5140 Info.memVT = MVT::i32;
5141 Info.ptrVal = nullptr;
5142 Info.offset = 0;
5143 Info.flags = MachineMemOperand::MOLoad;
5144 Info.align = Align(16);
5145 Infos.push_back(Info);
5146 return;
5147
5148 case Intrinsic::nvvm_suld_1d_i64_clamp:
5149 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
5150 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
5151 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
5152 case Intrinsic::nvvm_suld_2d_i64_clamp:
5153 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
5154 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
5155 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
5156 case Intrinsic::nvvm_suld_3d_i64_clamp:
5157 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
5158 case Intrinsic::nvvm_suld_1d_i64_trap:
5159 case Intrinsic::nvvm_suld_1d_v2i64_trap:
5160 case Intrinsic::nvvm_suld_1d_array_i64_trap:
5161 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
5162 case Intrinsic::nvvm_suld_2d_i64_trap:
5163 case Intrinsic::nvvm_suld_2d_v2i64_trap:
5164 case Intrinsic::nvvm_suld_2d_array_i64_trap:
5165 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
5166 case Intrinsic::nvvm_suld_3d_i64_trap:
5167 case Intrinsic::nvvm_suld_3d_v2i64_trap:
5168 case Intrinsic::nvvm_suld_1d_i64_zero:
5169 case Intrinsic::nvvm_suld_1d_v2i64_zero:
5170 case Intrinsic::nvvm_suld_1d_array_i64_zero:
5171 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
5172 case Intrinsic::nvvm_suld_2d_i64_zero:
5173 case Intrinsic::nvvm_suld_2d_v2i64_zero:
5174 case Intrinsic::nvvm_suld_2d_array_i64_zero:
5175 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
5176 case Intrinsic::nvvm_suld_3d_i64_zero:
5177 case Intrinsic::nvvm_suld_3d_v2i64_zero:
5178 Info.opc = ISD::INTRINSIC_W_CHAIN;
5179 Info.memVT = MVT::i64;
5180 Info.ptrVal = nullptr;
5181 Info.offset = 0;
5182 Info.flags = MachineMemOperand::MOLoad;
5183 Info.align = Align(16);
5184 Infos.push_back(Info);
5185 return;
5186
5187 case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
5188 case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
5189 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
5190 Info.opc = ISD::INTRINSIC_W_CHAIN;
5191 Info.memVT = MVT::v1i32;
5192 Info.ptrVal = I.getArgOperand(0);
5193 Info.offset = 0;
5194 Info.flags = MachineMemOperand::MOLoad;
5195 Info.align.reset();
5196 Infos.push_back(Info);
5197 return;
5198 }
5199
5200 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
5201 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
5202 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
5203 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
5204 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
5205 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32: {
5206 Info.opc = ISD::INTRINSIC_W_CHAIN;
5207 Info.memVT = MVT::v2i32;
5208 Info.ptrVal = I.getArgOperand(0);
5209 Info.offset = 0;
5210 Info.flags = MachineMemOperand::MOLoad;
5211 Info.align.reset();
5212 Infos.push_back(Info);
5213 return;
5214 }
5215
5216 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
5217 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32: {
5218 Info.opc = ISD::INTRINSIC_W_CHAIN;
5219 Info.memVT = MVT::v2f32;
5220 Info.ptrVal = I.getArgOperand(0);
5221 Info.offset = 0;
5222 Info.flags = MachineMemOperand::MOLoad;
5223 Info.align.reset();
5224 Infos.push_back(Info);
5225 return;
5226 }
5227
5228 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
5229 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
5230 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
5231 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
5232 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
5233 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
5234 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32: {
5235 Info.opc = ISD::INTRINSIC_W_CHAIN;
5236 Info.memVT = MVT::v4i32;
5237 Info.ptrVal = I.getArgOperand(0);
5238 Info.offset = 0;
5239 Info.flags = MachineMemOperand::MOLoad;
5240 Info.align.reset();
5241 Infos.push_back(Info);
5242 return;
5243 }
5244
5245 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
5246 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32: {
5247 Info.opc = ISD::INTRINSIC_W_CHAIN;
5248 Info.memVT = MVT::v4f32;
5249 Info.ptrVal = I.getArgOperand(0);
5250 Info.offset = 0;
5251 Info.flags = MachineMemOperand::MOLoad;
5252 Info.align.reset();
5253 Infos.push_back(Info);
5254 return;
5255 }
5256
5257 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
5258 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
5259 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
5260 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
5261 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
5262 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
5263 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32: {
5264 Info.opc = ISD::INTRINSIC_W_CHAIN;
5265 Info.memVT = MVT::v8i32;
5266 Info.ptrVal = I.getArgOperand(0);
5267 Info.offset = 0;
5268 Info.flags = MachineMemOperand::MOLoad;
5269 Info.align.reset();
5270 Infos.push_back(Info);
5271 return;
5272 }
5273
5274 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
5275 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32: {
5276 Info.opc = ISD::INTRINSIC_W_CHAIN;
5277 Info.memVT = MVT::v8f32;
5278 Info.ptrVal = I.getArgOperand(0);
5279 Info.offset = 0;
5280 Info.flags = MachineMemOperand::MOLoad;
5281 Info.align.reset();
5282 Infos.push_back(Info);
5283 return;
5284 }
5285
5286 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
5287 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
5288 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
5289 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
5290 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
5291 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
5292 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32: {
5293 Info.opc = ISD::INTRINSIC_W_CHAIN;
5294 Info.memVT = MVT::v16i32;
5295 Info.ptrVal = I.getArgOperand(0);
5296 Info.offset = 0;
5297 Info.flags = MachineMemOperand::MOLoad;
5298 Info.align.reset();
5299 Infos.push_back(Info);
5300 return;
5301 }
5302
5303 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
5304 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32: {
5305 Info.opc = ISD::INTRINSIC_W_CHAIN;
5306 Info.memVT = MVT::v16f32;
5307 Info.ptrVal = I.getArgOperand(0);
5308 Info.offset = 0;
5309 Info.flags = MachineMemOperand::MOLoad;
5310 Info.align.reset();
5311 Infos.push_back(Info);
5312 return;
5313 }
5314
5315 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
5316 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
5317 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
5318 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
5319 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
5320 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
5321 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32: {
5322 Info.opc = ISD::INTRINSIC_W_CHAIN;
5323 Info.memVT = MVT::v32i32;
5324 Info.ptrVal = I.getArgOperand(0);
5325 Info.offset = 0;
5326 Info.flags = MachineMemOperand::MOLoad;
5327 Info.align.reset();
5328 Infos.push_back(Info);
5329 return;
5330 }
5331
5332 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
5333 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32: {
5334 Info.opc = ISD::INTRINSIC_W_CHAIN;
5335 Info.memVT = MVT::v32f32;
5336 Info.ptrVal = I.getArgOperand(0);
5337 Info.offset = 0;
5338 Info.flags = MachineMemOperand::MOLoad;
5339 Info.align.reset();
5340 Infos.push_back(Info);
5341 return;
5342 }
5343
5344 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
5345 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
5346 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
5347 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
5348 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
5349 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
5350 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32: {
5351 Info.opc = ISD::INTRINSIC_W_CHAIN;
5352 Info.memVT = MVT::v64i32;
5353 Info.ptrVal = I.getArgOperand(0);
5354 Info.offset = 0;
5355 Info.flags = MachineMemOperand::MOLoad;
5356 Info.align.reset();
5357 Infos.push_back(Info);
5358 return;
5359 }
5360
5361 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
5362 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32: {
5363 Info.opc = ISD::INTRINSIC_W_CHAIN;
5364 Info.memVT = MVT::v64f32;
5365 Info.ptrVal = I.getArgOperand(0);
5366 Info.offset = 0;
5367 Info.flags = MachineMemOperand::MOLoad;
5368 Info.align.reset();
5369 Infos.push_back(Info);
5370 return;
5371 }
5372
5373 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
5374 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
5375 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
5376 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
5377 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
5378 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
5379 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32: {
5380 Info.opc = ISD::INTRINSIC_W_CHAIN;
5381 Info.memVT = MVT::v128i32;
5382 Info.ptrVal = I.getArgOperand(0);
5383 Info.offset = 0;
5384 Info.flags = MachineMemOperand::MOLoad;
5385 Info.align.reset();
5386 Infos.push_back(Info);
5387 return;
5388 }
5389
5390 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
5391 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32: {
5392 Info.opc = ISD::INTRINSIC_W_CHAIN;
5393 Info.memVT = MVT::v128f32;
5394 Info.ptrVal = I.getArgOperand(0);
5395 Info.offset = 0;
5396 Info.flags = MachineMemOperand::MOLoad;
5397 Info.align.reset();
5398 Infos.push_back(Info);
5399 return;
5400 }
5401
5402 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
5403 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
5404 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
5405 Info.opc = ISD::INTRINSIC_VOID;
5406 Info.memVT = MVT::i32;
5407 Info.ptrVal = I.getArgOperand(0);
5408 Info.offset = 0;
5409 Info.flags = MachineMemOperand::MOStore;
5410 Info.align.reset();
5411 Infos.push_back(Info);
5412 return;
5413 }
5414
5415 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
5416 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
5417 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
5418 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
5419 Info.opc = ISD::INTRINSIC_VOID;
5420 Info.memVT = MVT::v2i32;
5421 Info.ptrVal = I.getArgOperand(0);
5422 Info.offset = 0;
5423 Info.flags = MachineMemOperand::MOStore;
5424 Info.align.reset();
5425 Infos.push_back(Info);
5426 return;
5427 }
5428
5429 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
5430 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
5431 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
5432 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
5433 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
5434 Info.opc = ISD::INTRINSIC_VOID;
5435 Info.memVT = MVT::v4i32;
5436 Info.ptrVal = I.getArgOperand(0);
5437 Info.offset = 0;
5438 Info.flags = MachineMemOperand::MOStore;
5439 Info.align.reset();
5440 Infos.push_back(Info);
5441 return;
5442 }
5443
5444 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
5445 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
5446 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
5447 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
5448 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
5449 Info.opc = ISD::INTRINSIC_VOID;
5450 Info.memVT = MVT::v8i32;
5451 Info.ptrVal = I.getArgOperand(0);
5452 Info.offset = 0;
5453 Info.flags = MachineMemOperand::MOStore;
5454 Info.align.reset();
5455 Infos.push_back(Info);
5456 return;
5457 }
5458
5459 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
5460 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
5461 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
5462 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
5463 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
5464 Info.opc = ISD::INTRINSIC_VOID;
5465 Info.memVT = MVT::v16i32;
5466 Info.ptrVal = I.getArgOperand(0);
5467 Info.offset = 0;
5468 Info.flags = MachineMemOperand::MOStore;
5469 Info.align.reset();
5470 Infos.push_back(Info);
5471 return;
5472 }
5473
5474 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
5475 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
5476 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
5477 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
5478 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
5479 Info.opc = ISD::INTRINSIC_VOID;
5480 Info.memVT = MVT::v32i32;
5481 Info.ptrVal = I.getArgOperand(0);
5482 Info.offset = 0;
5483 Info.flags = MachineMemOperand::MOStore;
5484 Info.align.reset();
5485 Infos.push_back(Info);
5486 return;
5487 }
5488
5489 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
5490 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
5491 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
5492 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
5493 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
5494 Info.opc = ISD::INTRINSIC_VOID;
5495 Info.memVT = MVT::v64i32;
5496 Info.ptrVal = I.getArgOperand(0);
5497 Info.offset = 0;
5498 Info.flags = MachineMemOperand::MOStore;
5499 Info.align.reset();
5500 Infos.push_back(Info);
5501 return;
5502 }
5503
5504 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
5505 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
5506 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
5507 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
5508 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
5509 Info.opc = ISD::INTRINSIC_VOID;
5510 Info.memVT = MVT::v128i32;
5511 Info.ptrVal = I.getArgOperand(0);
5512 Info.offset = 0;
5513 Info.flags = MachineMemOperand::MOStore;
5514 Info.align.reset();
5515 Infos.push_back(Info);
5516 return;
5517 }
5518 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
5519 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
5520 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
5521 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
5522 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
5523 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
5524 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
5525 case Intrinsic::
5526 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
5527 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
5528 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
5529 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
5530 case Intrinsic::
5531 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift: {
5532 // We are reading and writing back to TMem
5533 Info.opc = ISD::INTRINSIC_VOID;
5534 Info.memVT = MVT::v4i32;
5535 Info.ptrVal = I.getArgOperand(0);
5536 Info.offset = 0;
5538 Info.align = Align(16);
5539 Infos.push_back(Info);
5540 return;
5541 }
5542
5543 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
5544 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
5545 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
5546 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
5547 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
5548 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
5549 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
5550 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
5551 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
5552 case Intrinsic::
5553 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
5554 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
5555 case Intrinsic::
5556 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift: {
5557 // We are reading and writing back to TMem
5558 Info.opc = ISD::INTRINSIC_VOID;
5559 Info.memVT = MVT::v8i32;
5560 Info.ptrVal = I.getArgOperand(0);
5561 Info.offset = 0;
5563 Info.align = Align(16);
5564 Infos.push_back(Info);
5565 return;
5566 }
5567 }
5568}
5569
5570// Helper for getting a function parameter name. Name is composed from
5571// its index and the function name. Negative index corresponds to special
5572// parameter (unsized array) used for passing variable arguments.
5574 int Idx) const {
5575 std::string ParamName;
5576 raw_string_ostream ParamStr(ParamName);
5577
5578 ParamStr << getTargetMachine().getSymbol(F)->getName();
5579 if (Idx < 0)
5580 ParamStr << "_vararg";
5581 else
5582 ParamStr << "_param_" << Idx;
5583
5584 return ParamName;
5585}
5586
5587/// isLegalAddressingMode - Return true if the addressing mode represented
5588/// by AM is legal for this target, for a load/store of the specified type.
5589/// Used to guide target specific optimizations, like loop strength reduction
5590/// (LoopStrengthReduce.cpp) and memory optimization for address mode
5591/// (CodeGenPrepare.cpp)
5593 const AddrMode &AM, Type *Ty,
5594 unsigned AS, Instruction *I) const {
5595 // AddrMode - This represents an addressing mode of:
5596 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
5597 //
5598 // The legal address modes are
5599 // - [avar]
5600 // - [areg]
5601 // - [areg+immoff]
5602 // - [immAddr]
5603
5604 // immoff must fit in a signed 32-bit int
5605 if (!APInt(64, AM.BaseOffs).isSignedIntN(32))
5606 return false;
5607
5608 if (AM.BaseGV)
5609 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
5610
5611 switch (AM.Scale) {
5612 case 0: // "r", "r+i" or "i" is allowed
5613 break;
5614 case 1:
5615 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
5616 return false;
5617 // Otherwise we have r+i.
5618 break;
5619 default:
5620 // No scale > 1 is allowed
5621 return false;
5622 }
5623 return true;
5624}
5625
5626//===----------------------------------------------------------------------===//
5627// NVPTX Inline Assembly Support
5628//===----------------------------------------------------------------------===//
5629
5630/// getConstraintType - Given a constraint letter, return the type of
5631/// constraint it is for this target.
5634 if (Constraint.size() == 1) {
5635 switch (Constraint[0]) {
5636 default:
5637 break;
5638 case 'b':
5639 case 'r':
5640 case 'h':
5641 case 'c':
5642 case 'l':
5643 case 'f':
5644 case 'd':
5645 case 'q':
5646 case '0':
5647 case 'N':
5648 return C_RegisterClass;
5649 }
5650 }
5651 return TargetLowering::getConstraintType(Constraint);
5652}
5653
5654std::pair<unsigned, const TargetRegisterClass *>
5656 StringRef Constraint,
5657 MVT VT) const {
5658 if (Constraint.size() == 1) {
5659 switch (Constraint[0]) {
5660 case 'b':
5661 return std::make_pair(0U, &NVPTX::B1RegClass);
5662 case 'c':
5663 case 'h':
5664 return std::make_pair(0U, &NVPTX::B16RegClass);
5665 case 'r':
5666 case 'f':
5667 return std::make_pair(0U, &NVPTX::B32RegClass);
5668 case 'l':
5669 case 'N':
5670 case 'd':
5671 return std::make_pair(0U, &NVPTX::B64RegClass);
5672 case 'q': {
5673 if (STI.getSmVersion() < 70)
5674 report_fatal_error("Inline asm with 128 bit operands is only "
5675 "supported for sm_70 and higher!");
5676 return std::make_pair(0U, &NVPTX::B128RegClass);
5677 }
5678 }
5679 }
5680 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
5681}
5682
5683//===----------------------------------------------------------------------===//
5684// NVPTX DAG Combining
5685//===----------------------------------------------------------------------===//
5686
5688 CodeGenOptLevel OptLevel) const {
5689 // Always honor command-line argument
5690 if (FMAContractLevelOpt.getNumOccurrences() > 0)
5691 return FMAContractLevelOpt > 0;
5692
5693 // Do not contract if we're not optimizing the code.
5694 if (OptLevel == CodeGenOptLevel::None)
5695 return false;
5696
5697 // Honor TargetOptions flags that explicitly say fusion is okay.
5699 return true;
5700
5701 return false;
5702}
5703
5704static bool isConstZero(const SDValue &Operand) {
5705 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
5706 return Const && Const->getZExtValue() == 0;
5707}
5708
5709/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5710/// operands N0 and N1. This is a helper for PerformADDCombine that is
5711/// called with the default operands, and if that fails, with commuted
5712/// operands.
5713static SDValue
5716 EVT VT = N0.getValueType();
5717
5718 // Since integer multiply-add costs the same as integer multiply
5719 // but is more costly than integer add, do the fusion only when
5720 // the mul is only used in the add.
5721 // TODO: this may not be true for later architectures, consider relaxing this
5722 if (!N0.getNode()->hasOneUse())
5723 return SDValue();
5724
5725 // fold (add (select cond, 0, (mul a, b)), c)
5726 // -> (select cond, c, (add (mul a, b), c))
5727 //
5728 if (N0.getOpcode() == ISD::SELECT) {
5729 unsigned ZeroOpNum;
5730 if (isConstZero(N0->getOperand(1)))
5731 ZeroOpNum = 1;
5732 else if (isConstZero(N0->getOperand(2)))
5733 ZeroOpNum = 2;
5734 else
5735 return SDValue();
5736
5737 SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
5738 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
5739 return SDValue();
5740
5741 SDLoc DL(N);
5742 SDValue Mul =
5743 DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
5744 SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
5745 return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
5746 ((ZeroOpNum == 1) ? N1 : MAD),
5747 ((ZeroOpNum == 1) ? MAD : N1));
5748 }
5749
5750 return SDValue();
5751}
5752
5753SDValue NVPTXTargetLowering::performFADDCombineWithOperands(
5755 CodeGenOptLevel OptLevel) const {
5756 EVT VT = N0.getValueType();
5757 if (N0.getOpcode() == ISD::FMUL) {
5758 if (!(allowFMA(DCI.DAG.getMachineFunction(), OptLevel) ||
5759 (N->getFlags().hasAllowContract() &&
5760 N0->getFlags().hasAllowContract())))
5761 return SDValue();
5762
5763 // For floating point:
5764 // Do the fusion only when the mul has less than 5 uses and all
5765 // are add.
5766 // The heuristic is that if a use is not an add, then that use
5767 // cannot be fused into fma, therefore mul is still needed anyway.
5768 // If there are more than 4 uses, even if they are all add, fusing
5769 // them will increase register pressue.
5770 //
5771 int numUses = 0;
5772 int nonAddCount = 0;
5773 for (const SDNode *User : N0.getNode()->users()) {
5774 numUses++;
5775 if (User->getOpcode() != ISD::FADD)
5776 ++nonAddCount;
5777 if (numUses >= 5)
5778 return SDValue();
5779 }
5780 if (nonAddCount) {
5781 int orderNo = N->getIROrder();
5782 int orderNo2 = N0.getNode()->getIROrder();
5783 // simple heuristics here for considering potential register
5784 // pressure, the logics here is that the differnce are used
5785 // to measure the distance between def and use, the longer distance
5786 // more likely cause register pressure.
5787 if (orderNo - orderNo2 < 500)
5788 return SDValue();
5789
5790 // Now, check if at least one of the FMUL's operands is live beyond the
5791 // node N, which guarantees that the FMA will not increase register
5792 // pressure at node N.
5793 bool opIsLive = false;
5794 const SDNode *left = N0.getOperand(0).getNode();
5795 const SDNode *right = N0.getOperand(1).getNode();
5796
5797 if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5798 opIsLive = true;
5799
5800 if (!opIsLive)
5801 for (const SDNode *User : left->users()) {
5802 int orderNo3 = User->getIROrder();
5803 if (orderNo3 > orderNo) {
5804 opIsLive = true;
5805 break;
5806 }
5807 }
5808
5809 if (!opIsLive)
5810 for (const SDNode *User : right->users()) {
5811 int orderNo3 = User->getIROrder();
5812 if (orderNo3 > orderNo) {
5813 opIsLive = true;
5814 break;
5815 }
5816 }
5817
5818 if (!opIsLive)
5819 return SDValue();
5820 }
5821
5822 return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
5823 N0.getOperand(1), N1);
5824 }
5825
5826 return SDValue();
5827}
5828
5829/// Fold unpacking movs into a load by increasing the number of return values.
5830///
5831/// ex:
5832/// L: v2f16,ch = load <p>
5833/// a: f16 = extractelt L:0, 0
5834/// b: f16 = extractelt L:0, 1
5835/// use(a, b)
5836///
5837/// ...is turned into...
5838///
5839/// L: f16,f16,ch = LoadV2 <p>
5840/// use(L:0, L:1)
5841static SDValue
5843 // Don't run this optimization before the legalizer
5844 if (!DCI.isAfterLegalizeDAG())
5845 return SDValue();
5846
5847 EVT ElementVT = N->getValueType(0);
5848 // Avoid non-packed types and v4i8
5849 if (!NVPTX::isPackedVectorTy(ElementVT) || ElementVT == MVT::v4i8)
5850 return SDValue();
5851
5852 // Check whether all outputs are either used by an extractelt or are
5853 // glue/chain nodes
5854 if (!all_of(N->uses(), [&](SDUse &U) {
5855 // Skip glue, chain nodes
5856 if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
5857 return true;
5858 if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
5859 if (N->getOpcode() != ISD::LOAD)
5860 return true;
5861 // Since this is an ISD::LOAD, check all extractelts are used. If
5862 // any are not used, we don't want to defeat another optimization that
5863 // will narrow the load.
5864 //
5865 // For example:
5866 //
5867 // L: v2f16,ch = load <p>
5868 // e0: f16 = extractelt L:0, 0
5869 // e1: f16 = extractelt L:0, 1 <-- unused
5870 // store e0
5871 //
5872 // Can be optimized by DAGCombiner to:
5873 //
5874 // L: f16,ch = load <p>
5875 // store L:0
5876 return !U.getUser()->use_empty();
5877 }
5878
5879 // Otherwise, this use prevents us from splitting a value.
5880 return false;
5881 }))
5882 return SDValue();
5883
5884 auto *LD = cast<MemSDNode>(N);
5885 SDLoc DL(LD);
5886
5887 // the new opcode after we double the number of operands
5888 unsigned Opcode;
5889 SmallVector<SDValue> Operands(LD->ops());
5890 unsigned OldNumOutputs; // non-glue, non-chain outputs
5891 switch (LD->getOpcode()) {
5892 case ISD::LOAD:
5893 OldNumOutputs = 1;
5894 // Any packed type is legal, so the legalizer will not have lowered
5895 // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
5896 // here.
5897 Opcode = NVPTXISD::LoadV2;
5898 // append a "full" used bytes mask operand right before the extension type
5899 // operand, signifying that all bytes are used.
5900 Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
5901 Operands.push_back(DCI.DAG.getIntPtrConstant(
5902 cast<LoadSDNode>(LD)->getExtensionType(), DL));
5903 break;
5904 case NVPTXISD::LoadV2:
5905 OldNumOutputs = 2;
5906 Opcode = NVPTXISD::LoadV4;
5907 break;
5908 case NVPTXISD::LoadV4:
5909 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5910 // load size here. This is already a 256-bit load.
5911 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5912 return SDValue();
5913 OldNumOutputs = 4;
5914 Opcode = NVPTXISD::LoadV8;
5915 break;
5916 case NVPTXISD::LoadV8:
5917 // PTX doesn't support the next doubling of outputs
5918 return SDValue();
5919 }
5920
5921 // the non-glue, non-chain outputs in the new load
5922 const unsigned NewNumOutputs = OldNumOutputs * 2;
5923 SmallVector<EVT> NewVTs(NewNumOutputs, ElementVT.getVectorElementType());
5924 // add remaining chain and glue values
5925 NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
5926
5927 // Create the new load
5928 SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
5929 Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(),
5930 LD->getMemOperand());
5931
5932 // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
5933 // the outputs the same. These nodes will be optimized away in later
5934 // DAGCombiner iterations.
5936 for (unsigned I : seq(OldNumOutputs))
5937 Results.push_back(DCI.DAG.getBuildVector(
5938 ElementVT, DL, {NewLoad.getValue(I * 2), NewLoad.getValue(I * 2 + 1)}));
5939 // Add remaining chain and glue nodes
5940 for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs))
5941 Results.push_back(NewLoad.getValue(NewNumOutputs + I));
5942
5943 return DCI.DAG.getMergeValues(Results, DL);
5944}
5945
5946/// Fold packing movs into a store.
5947///
5948/// ex:
5949/// v1: v2f16 = BUILD_VECTOR a:f16, b:f16
5950/// v2: v2f16 = BUILD_VECTOR c:f16, d:f16
5951/// StoreV2 v1, v2
5952///
5953/// ...is turned into...
5954///
5955/// StoreV4 a, b, c, d
5958 unsigned Front, unsigned Back) {
5959 // We want to run this as late as possible since other optimizations may
5960 // eliminate the BUILD_VECTORs.
5961 if (!DCI.isAfterLegalizeDAG())
5962 return SDValue();
5963
5964 // Get the type of the operands being stored.
5965 EVT ElementVT = N->getOperand(Front).getValueType();
5966
5967 // Avoid non-packed types and v4i8
5968 if (!NVPTX::isPackedVectorTy(ElementVT) || ElementVT == MVT::v4i8)
5969 return SDValue();
5970
5971 auto *ST = cast<MemSDNode>(N);
5972
5973 // The new opcode after we double the number of operands.
5974 unsigned Opcode;
5975 switch (N->getOpcode()) {
5976 case ISD::STORE:
5977 // Any packed type is legal, so the legalizer will not have lowered
5978 // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
5979 // it here.
5980 Opcode = NVPTXISD::StoreV2;
5981 break;
5982 case NVPTXISD::StoreV2:
5983 Opcode = NVPTXISD::StoreV4;
5984 break;
5985 case NVPTXISD::StoreV4:
5986 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5987 // store size here. This is already a 256-bit store.
5988 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5989 return SDValue();
5990 Opcode = NVPTXISD::StoreV8;
5991 break;
5992 case NVPTXISD::StoreV8:
5993 // PTX doesn't support the next doubling of operands
5994 return SDValue();
5995 default:
5996 llvm_unreachable("Unhandled store opcode");
5997 }
5998
5999 // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
6000 // their elements.
6001 SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
6002 for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
6003 if (BV.getOpcode() != ISD::BUILD_VECTOR)
6004 return SDValue();
6005
6006 // If the operand has multiple uses, this optimization can increase register
6007 // pressure.
6008 if (!BV.hasOneUse())
6009 return SDValue();
6010
6011 // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
6012 // any signs they may be folded by some other pattern or rule.
6013 for (SDValue Op : BV->ops()) {
6014 // Peek through bitcasts
6015 if (Op.getOpcode() == ISD::BITCAST)
6016 Op = Op.getOperand(0);
6017
6018 // This may be folded into a PRMT.
6019 if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
6020 Op->getOperand(0).getValueType() == MVT::i32)
6021 return SDValue();
6022
6023 // This may be folded into cvt.bf16x2
6024 if (Op.getOpcode() == ISD::FP_ROUND)
6025 return SDValue();
6026 }
6027 Operands.append({BV.getOperand(0), BV.getOperand(1)});
6028 }
6029 Operands.append(N->op_end() - Back, N->op_end());
6030
6031 // Now we replace the store
6032 return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands,
6033 ST->getMemoryVT(), ST->getMemOperand());
6034}
6035
6037 const NVPTXSubtarget &STI) {
6038
6039 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
6040 // Here is our chance to custom lower a store with a non-simple type.
6041 // Unfortunately, we can't do this in the legalizer because there is no
6042 // way to setOperationAction for an non-simple type.
6044 if (!ST->getValue().getValueType().isSimple())
6045 return lowerSTOREVector(SDValue(ST, 0), DCI.DAG, STI);
6046 }
6047
6048 return combinePackingMovIntoStore(N, DCI, 1, 2);
6049}
6050
6052 const NVPTXSubtarget &STI) {
6053 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
6054 // Here is our chance to custom lower a load with a non-simple type.
6055 // Unfortunately, we can't do this in the legalizer because there is no
6056 // way to setOperationAction for an non-simple type.
6057 if (!N->getValueType(0).isSimple())
6058 return lowerLoadVector(N, DCI.DAG, STI);
6059 }
6060
6061 return combineUnpackingMovIntoLoad(N, DCI);
6062}
6063
6064/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
6065///
6068 CodeGenOptLevel OptLevel) {
6069 if (OptLevel == CodeGenOptLevel::None)
6070 return SDValue();
6071
6072 SDValue N0 = N->getOperand(0);
6073 SDValue N1 = N->getOperand(1);
6074
6075 // Skip non-integer, non-scalar case
6076 EVT VT = N0.getValueType();
6077 if (VT.isVector() || VT != MVT::i32)
6078 return SDValue();
6079
6080 // First try with the default operand order.
6081 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
6082 return Result;
6083
6084 // If that didn't work, try again with the operands commuted.
6085 return PerformADDCombineWithOperands(N, N1, N0, DCI);
6086}
6087
6088/// Check if a v2f32 BUILD_VECTOR provably packs values from non-adjacent
6089/// register pairs (non-coalescable).
6090static bool isNonCoalescableBuildVector(const SDValue &BV) {
6091 if (BV.getOpcode() != ISD::BUILD_VECTOR || BV.getValueType() != MVT::v2f32)
6092 return false;
6093
6094 SDValue Elt0 = BV.getOperand(0);
6095 SDValue Elt1 = BV.getOperand(1);
6096
6097 bool IsExt0 = Elt0.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6098 bool IsExt1 = Elt1.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6099
6100 // If neither element is an EXTRACT_VECTOR_ELT they are free-standing
6101 // scalars and the register allocator can still place them side-by-side.
6102 if (!IsExt0 && !IsExt1)
6103 return false;
6104
6105 // If exactly one element is an EXTRACT_VECTOR_ELT, the other is a scalar
6106 // that cannot generally occupy the adjacent register slot.
6107 if (IsExt0 != IsExt1)
6108 return true;
6109
6110 // At this point both sources are extracting from vectors. If they are from
6111 // different vectors, then the BUILD_VECTOR is non-coalescable.
6112 SDValue Src0 = Elt0.getOperand(0);
6113 SDValue Src1 = Elt1.getOperand(0);
6114 if (Src0 != Src1)
6115 return true;
6116
6117 auto *Idx0 = dyn_cast<ConstantSDNode>(Elt0.getOperand(1));
6118 auto *Idx1 = dyn_cast<ConstantSDNode>(Elt1.getOperand(1));
6119 // If both indices are dynamic they will be lowered to
6120 // loads and the vector will be spilled to local memory. The register
6121 // allocator can easily place the results in adjacent registers.
6122 if (!Idx0 && !Idx1)
6123 return false;
6124
6125 // If one index is dynamic and the other is constant, the value from the
6126 // constant load will result in an additional register to pair with the result
6127 // from the dynamic load. We consider this non-coalescable.
6128 if ((Idx0 && !Idx1) || (!Idx0 && Idx1))
6129 return true;
6130
6131 // Both are constant, adjacent pairs are coalescable
6132 return std::abs(Idx0->getSExtValue() - Idx1->getSExtValue()) != 1;
6133}
6134
6135/// Return true if FMUL v2f32 node \p N may be scalarized to fold each lane's
6136/// product into a scalar FMA.
6137bool NVPTXTargetLowering::mayFoldFMULIntoFMA(SDNode *N, MachineFunction &MF,
6138 CodeGenOptLevel OptLevel) const {
6139 if (N->getOpcode() != ISD::FMUL || N->getValueType(0) != MVT::v2f32)
6140 return false;
6141 const bool GlobalFMA = allowFMA(MF, OptLevel);
6142 if (!N->getFlags().hasAllowContract() && !GlobalFMA)
6143 return false;
6144
6145 const SDNode *FirstFAdd = nullptr;
6146 unsigned NumScalarFAdd = 0;
6147
6148 // Both lanes must feed unique FADDs
6149 for (SDNode *EE : N->users()) {
6150 if (NumScalarFAdd == 2)
6151 return false;
6152
6153 if (EE->getOpcode() != ISD::EXTRACT_VECTOR_ELT || !EE->hasOneUse() ||
6154 !isa<ConstantSDNode>(EE->getOperand(1)))
6155 return false;
6156
6157 const SDNode *const FAdd = *EE->users().begin();
6158 if (FAdd->getOpcode() != ISD::FADD ||
6159 (!GlobalFMA && !FAdd->getFlags().hasAllowContract()))
6160 return false;
6161
6162 if (!FirstFAdd)
6163 FirstFAdd = FAdd;
6164 else if (FAdd == FirstFAdd)
6165 return false;
6166
6167 NumScalarFAdd++;
6168 }
6169
6170 return NumScalarFAdd == 2;
6171}
6172
6173/// Scalarize a v2f32 arithmetic node (FADD, FMUL, FSUB, FMA) when at least
6174/// one operand is a BUILD_VECTOR that repacks values from non-adjacent register
6175/// pairs. Without this combine the BUILD_VECTOR forces allocation of a
6176/// temporary 64-bit register, increasing register pressure.
6177///
6178/// Example - before:
6179/// t0: v2f32,v2f32,ch = LoadV2 ...
6180/// t1: f32 = extract_vector_elt t0, 0
6181/// t2: f32 = extract_vector_elt t0:1, 0
6182/// t3: v2f32 = BUILD_VECTOR t1, t2 ;; non-coalescable repack
6183/// t4: v2f32 = fma t_a, t3, t_c
6184///
6185/// After:
6186/// t0: v2f32,v2f32,ch = LoadV2 ...
6187/// t1: f32 = extract_vector_elt t0, 0
6188/// t2: f32 = extract_vector_elt t0:1, 0
6189/// a0: f32 = extract_vector_elt t_a, 0
6190/// a1: f32 = extract_vector_elt t_a, 1
6191/// c0: f32 = extract_vector_elt t_c, 0
6192/// c1: f32 = extract_vector_elt t_c, 1
6193/// r0: f32 = fma a0, t1, c0
6194/// r1: f32 = fma a1, t2, c1
6195/// t4: v2f32 = BUILD_VECTOR r0, r1
6196///
6197/// Also scalarizes an FMUL when all output lanes feed into scalar FADDs
6198/// to enable scalar FMA combining.
///
/// Returns the replacement BUILD_VECTOR of the two scalar results, or an empty
/// SDValue when N is not v2f32 or neither trigger applies.
6199SDValue NVPTXTargetLowering::performScalarizeV2F32Op(
6201 CodeGenOptLevel OptLevel) const {
6202 EVT VT = N->getValueType(0);
6203 if (VT != MVT::v2f32)
6204 return SDValue();
6205
// Only scalarize when it pays off: some operand is a repacking BUILD_VECTOR,
// or this is an FMUL whose lanes can each become a scalar FMA.
6206 if (none_of(N->ops(), isNonCoalescableBuildVector) &&
6207 !mayFoldFMULIntoFMA(N, DCI.DAG.getMachineFunction(), OptLevel))
6208 return SDValue();
6209
6210 SelectionDAG &DAG = DCI.DAG;
6211 SDLoc DL(N);
6212 EVT EltVT = VT.getVectorElementType();
6213 unsigned Opc = N->getOpcode();
6214
6215 // For each operand, get the scalar element at the given index: if the operand
6216 // is a BUILD_VECTOR, grab the element directly; otherwise, emit an
6217 // EXTRACT_VECTOR_ELT.
6218 auto GetElement = [&](SDValue Op, unsigned Index) -> SDValue {
6219 if (Op.getOpcode() == ISD::BUILD_VECTOR)
6220 return Op.getOperand(Index);
6221 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op,
6222 DAG.getVectorIdxConstant(Index, DL));
6223 };
6224
6225 // Build scalar operand lists for element 0 and element 1.
6226 SmallVector<SDValue, 3> Ops0, Ops1;
6227 for (const SDValue &Op : N->ops()) {
6228 Ops0.push_back(GetElement(Op, 0));
6229 Ops1.push_back(GetElement(Op, 1));
6230 }
6231
// Both scalar nodes inherit N's node flags.
6232 SDValue Res0 = DAG.getNode(Opc, DL, EltVT, Ops0, N->getFlags());
6233 SDValue Res1 = DAG.getNode(Opc, DL, EltVT, Ops1, N->getFlags());
6234
6235 return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Res0, Res1);
6236}
6237
6238/// Target-specific dag combine xforms for ISD::FADD.
///
/// First tries the v2f32 scalarization combine. Otherwise, for scalar f32/f64
/// adds only, tries performFADDCombineWithOperands with the operands in their
/// original order and then commuted. Returns an empty SDValue if nothing
/// applies.
6239SDValue
6240NVPTXTargetLowering::performFADDCombine(SDNode *N,
6242 CodeGenOptLevel OptLevel) const {
6243 if (SDValue Result = performScalarizeV2F32Op(N, DCI, OptLevel))
6244 return Result;
6245
6246 SDValue N0 = N->getOperand(0);
6247 SDValue N1 = N->getOperand(1);
6248
// The remaining combines only handle scalar f32 and f64.
6249 EVT VT = N0.getValueType();
6250 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
6251 return SDValue();
6252
6253 // First try with the default operand order.
6254 if (SDValue Result = performFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
6255 return Result;
6256
6257 // If that didn't work, try again with the operands commuted.
6258 return performFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
6259}
6260
6261/// Get 3-input version of a 2-input min/max opcode
6262static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
6263 switch (MinMax2Opcode) {
6264 case ISD::FMAXNUM:
6265 case ISD::FMAXIMUMNUM:
6266 return NVPTXISD::FMAXNUM3;
6267 case ISD::FMINNUM:
6268 case ISD::FMINIMUMNUM:
6269 return NVPTXISD::FMINNUM3;
6270 case ISD::FMAXIMUM:
6271 return NVPTXISD::FMAXIMUM3;
6272 case ISD::FMINIMUM:
6273 return NVPTXISD::FMINIMUM3;
6274 default:
6275 llvm_unreachable("Invalid 2-input min/max opcode");
6276 }
6277}
6278
6279/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
6280/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
///
/// Fires only for f32 results with PTX ISA version >= 8.8 and SM >= 100. The
/// inner node must use the same opcode and have a single use, so folding it
/// does not leave it alive for other users. Returns an empty SDValue otherwise.
6283 unsigned PTXVersion, unsigned SmVersion) {
6284
6285 // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
6286 EVT VT = N->getValueType(0);
6287 if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
6288 return SDValue();
6289
6290 SDValue Op0 = N->getOperand(0);
6291 SDValue Op1 = N->getOperand(1);
6292 unsigned MinMaxOp2 = N->getOpcode();
6293 unsigned MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
6294
// NOTE(review): only N's flags are carried over to the 3-input node; the inner
// node's flags are dropped. Confirm this is intended when they differ, e.g.
// when the outer node has nnan but the inner one does not.
6295 if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
6296 // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
6297 SDValue A = Op0.getOperand(0);
6298 SDValue B = Op0.getOperand(1);
6299 SDValue C = Op1;
6300 return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
6301 } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
6302 // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
6303 SDValue A = Op0;
6304 SDValue B = Op1.getOperand(0);
6305 SDValue C = Op1.getOperand(1);
6306 return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
6307 }
6308 return SDValue();
6309}
6310
// Combine for ISD::SREM / ISD::UREM (asserted below). If Num already has a
// user that is the matching sdiv/udiv of (Num, Den), rewrite the remainder as
// Num - (Num / Den) * Den so that existing division can be reused instead of
// computing the remainder separately. Does nothing below -O2.
6313 CodeGenOptLevel OptLevel) {
6314 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
6315
6316 // Don't do anything at less than -O2.
6317 if (OptLevel < CodeGenOptLevel::Default)
6318 return SDValue();
6319
6320 SelectionDAG &DAG = DCI.DAG;
6321 SDLoc DL(N);
6322 EVT VT = N->getValueType(0);
6323 bool IsSigned = N->getOpcode() == ISD::SREM;
6324 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
6325
6326 const SDValue &Num = N->getOperand(0);
6327 const SDValue &Den = N->getOperand(1);
6328
// Look for an existing division with exactly the same operands.
6329 for (const SDNode *U : Num->users()) {
6330 if (U->getOpcode() == DivOpc && U->getOperand(0) == Num &&
6331 U->getOperand(1) == Den) {
6332 // Num % Den -> Num - (Num / Den) * Den
6333 return DAG.getNode(ISD::SUB, DL, VT, Num,
6334 DAG.getNode(ISD::MUL, DL, VT,
6335 DAG.getNode(DivOpc, DL, VT, Num, Den),
6336 Den));
6337 }
6338 }
6339 return SDValue();
6340}
6341
6342// sext (mul.iN nsw x, y) => mul.wide.sN x, y
6343// zext (mul.iN nuw x, y) => mul.wide.uN x, y
6344// sext (shl.iN nsw x, const) => mul.wide.sN x, (1 << const)
6345// zext (shl.iN nuw x, const) => mul.wide.uN x, (1 << const)
//
// Only i16 -> i32 and i32 -> i64 extensions of a single-use operand are
// handled. The operand must carry nsw (for sext) or nuw (for zext). Skipped
// at -O0. Returns an empty SDValue when the pattern does not match.
6348 CodeGenOptLevel OptLevel) {
6349 assert(N->getOpcode() == ISD::SIGN_EXTEND ||
6350 N->getOpcode() == ISD::ZERO_EXTEND);
6351
6352 if (OptLevel == CodeGenOptLevel::None)
6353 return SDValue();
6354
6355 SDValue Op = N->getOperand(0);
6356 if (!Op.hasOneUse())
6357 return SDValue();
6358
6359 EVT ToVT = N->getValueType(0);
6360 EVT FromVT = Op.getValueType();
6361 if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
6362 (ToVT == MVT::i64 && FromVT == MVT::i32)))
6363 return SDValue();
6364
// The no-wrap flag matching the extension kind is what makes the narrow
// operation's result equal to the wide multiply's result.
6365 bool IsSigned = N->getOpcode() == ISD::SIGN_EXTEND;
6366 if ((IsSigned && !Op->getFlags().hasNoSignedWrap()) ||
6367 (!IsSigned && !Op->getFlags().hasNoUnsignedWrap()))
6368 return SDValue();
6369
6370 SDLoc DL(N);
6371 SDValue LHS = Op.getOperand(0);
6372 SDValue RHS = Op.getOperand(1);
6373 unsigned MulWideOpcode =
6374 IsSigned ? NVPTXISD::MUL_WIDE_SIGNED : NVPTXISD::MUL_WIDE_UNSIGNED;
6375 if (Op.getOpcode() == ISD::MUL) {
6376 return DCI.DAG.getNode(MulWideOpcode, DL, ToVT, LHS, RHS);
// A left shift by a constant is treated as a multiply by (1 << const).
6377 } else if (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(RHS)) {
6378 const auto ShiftAmt = Op.getConstantOperandVal(1);
6379 const auto MulVal = APInt(FromVT.getSizeInBits(), 1) << ShiftAmt;
6380
6381 // Note that the sext (shl nsw ...) case doesn't work if 1 << const
6382 // overflows to a negative value! The only valid input values in this
6383 // case are 0 and -1 (all other values yield poison because of the nsw),
6384 // and mul.wide.sN would give us the wrong sign for -1. We could use
6385 // mul.wide.uN, but since this is a weird case anyway, we might as well not
6386 // apply this transformation at all.
6387 if (IsSigned && MulVal.isNegative())
6388 return SDValue();
6389
6390 RHS = DCI.DAG.getConstant(MulVal, DL, FromVT);
6391 return DCI.DAG.getNode(MulWideOpcode, DL, ToVT, LHS, RHS);
6392 }
6393
6394 return SDValue();
6395}
6396
6402
6403/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
6404/// that can be demoted to \p OptSize bits without loss of information. The
6405/// signedness of the operand, if determinable, is placed in \p S.
///
/// Recognizes only sign_extend / sign_extend_inreg (Signed) and zero_extend
/// (Unsigned) whose operand 0 type is at most \p OptSize bits wide. Any other
/// node returns false with \p S left as Unknown.
///
/// NOTE(review): for SIGN_EXTEND_INREG, operand 0 has the same (full) width as
/// the result, and the narrowed type lives in operand 1 (a VTSDNode). This
/// size check therefore appears unable to succeed for a mul operand of width
/// 2 * OptSize. Confirm, and if so read the VT from operand 1 instead.
6407 unsigned OptSize,
6408 OperandSignedness &S) {
6409 S = Unknown;
6410
6411 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
6412 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
6413 EVT OrigVT = Op.getOperand(0).getValueType();
6414 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6415 S = Signed;
6416 return true;
6417 }
6418 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
6419 EVT OrigVT = Op.getOperand(0).getValueType();
6420 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6421 S = Unsigned;
6422 return true;
6423 }
6424 }
6425
6426 return false;
6427}
6428
6429/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
6430/// be demoted to \p OptSize bits without loss of information. If the operands
6431/// contain a constant, it should appear as the RHS operand. The signedness of
6432/// the operands is placed in \p IsSigned.
///
/// \p IsSigned is written as soon as the LHS signedness is known, even if the
/// RHS check then fails.
6434 unsigned OptSize,
6435 bool &IsSigned) {
6436 OperandSignedness LHSSign;
6437
6438 // The LHS operand must be a demotable op
6439 if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
6440 return false;
6441
6442 // We should have been able to determine the signedness from the LHS
6443 if (LHSSign == Unknown)
6444 return false;
6445
6446 IsSigned = (LHSSign == Signed);
6447
6448 // The RHS can be a demotable op or a constant
// Constant RHS: it must fit in OptSize bits when read with the LHS's
// signedness.
6450 const APInt &Val = CI->getAPIntValue();
6451 if (LHSSign == Unsigned) {
6452 return Val.isIntN(OptSize);
6453 } else {
6454 return Val.isSignedIntN(OptSize);
6455 }
6456 } else {
// Non-constant RHS: it must be a demotable extension of the same signedness.
6457 OperandSignedness RHSSign;
6458 if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
6459 return false;
6460
6461 return LHSSign == RHSSign;
6462 }
6463}
6464
6465/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
6466/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
6467/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
6468/// amount.
///
/// Only i32 and i64 results are considered. Returns an empty SDValue unless
/// both operands pass AreMulWideOperandsDemotable.
6471 EVT MulType = N->getValueType(0);
6472 if (MulType != MVT::i32 && MulType != MVT::i64) {
6473 return SDValue();
6474 }
6475
6476 SDLoc DL(N);
6477 unsigned OptSize = MulType.getSizeInBits() >> 1;
6478 SDValue LHS = N->getOperand(0);
6479 SDValue RHS = N->getOperand(1);
6480
6481 // Canonicalize the multiply so the constant (if any) is on the right
6482 if (N->getOpcode() == ISD::MUL) {
6483 if (isa<ConstantSDNode>(LHS)) {
6484 std::swap(LHS, RHS);
6485 }
6486 }
6487
6488 // If we have a SHL, determine the actual multiply amount
6489 if (N->getOpcode() == ISD::SHL) {
6491 if (!ShlRHS) {
6492 return SDValue();
6493 }
6494
// Only in-range shift amounts are rewritten, as a multiply by 1 << amount.
6495 APInt ShiftAmt = ShlRHS->getAPIntValue();
6496 unsigned BitWidth = MulType.getSizeInBits();
6497 if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
6498 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
6499 RHS = DCI.DAG.getConstant(MulVal, DL, MulType);
6500 } else {
6501 return SDValue();
6502 }
6503 }
6504
6505 bool Signed;
6506 // Verify that our operands are demotable
6507 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
6508 return SDValue();
6509 }
6510
6511 EVT DemotedVT;
6512 if (MulType == MVT::i32) {
6513 DemotedVT = MVT::i16;
6514 } else {
6515 DemotedVT = MVT::i32;
6516 }
6517
6518 // Truncate the operands to the correct size. Note that these are just for
6519 // type consistency and will (likely) be eliminated in later phases.
6520 SDValue TruncLHS =
6521 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, LHS);
6522 SDValue TruncRHS =
6523 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, RHS);
6524
6525 unsigned Opc;
6526 if (Signed) {
6527 Opc = NVPTXISD::MUL_WIDE_SIGNED;
6528 } else {
6529 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
6530 }
6531
6532 return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
6533}
6534
6535static bool isConstOne(const SDValue &Operand) {
6536 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
6537 return Const && Const->getZExtValue() == 1;
6538}
6539
// Match (add Y, 1) or (add 1, Y) and return Y. Returns an empty SDValue if Add
// is not an ADD with a constant-one operand.
6541 if (Add->getOpcode() != ISD::ADD)
6542 return SDValue();
6543
6544 if (isConstOne(Add->getOperand(0)))
6545 return Add->getOperand(1);
6546
6547 if (isConstOne(Add->getOperand(1)))
6548 return Add->getOperand(0);
6549
6550 return SDValue();
6551}
6552
// Rewrites X * (Y + 1) as (X * Y) + X; returns an empty SDValue otherwise.
// NOTE(review): the signature and the condition that binds Y are not shown in
// this listing. Y is presumably the non-constant operand of an (add Y, 1) --
// confirm.
6555
6557 SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
6558 return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
6559 }
6560
6561 return SDValue();
6562}
6563
// (mul X, (select Cond, 1, Y)) -> (select Cond, X, (mul X, Y)), plus the
// mirrored form when the constant one is the false operand. The multiply moves
// into the non-constant arm, so the constant arm needs no multiply. Returns an
// empty SDValue if Select is not a SELECT with a constant-one arm, or if the
// profitability guard below rejects Y.
6565 SDLoc DL,
6567 if (Select->getOpcode() != ISD::SELECT)
6568 return SDValue();
6569
6570 SDValue Cond = Select->getOperand(0);
6571
// ConstOpNo records which select arm (1 = true, 2 = false) holds the constant.
6572 unsigned ConstOpNo;
6573 if (isConstOne(Select->getOperand(1)))
6574 ConstOpNo = 1;
6575 else if (isConstOne(Select->getOperand(2)))
6576 ConstOpNo = 2;
6577 else
6578 return SDValue();
6579
6580 SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
6581
6582 // Do not combine if the resulting sequence is not obviously profitable.
6584 return SDValue();
6585
6586 SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
6587
// X * 1 == X, so the constant arm becomes X itself.
6588 return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
6589 (ConstOpNo == 1) ? X : NewMul,
6590 (ConstOpNo == 1) ? NewMul : X);
6591}
6592
// Scalar integer (i16/i32/i64) MUL combines. Each pattern is tried with the
// operands in both orders; returns the first successful rewrite or an empty
// SDValue.
6593static SDValue
6596
6597 EVT VT = N0.getValueType();
6598 if (VT.isVector())
6599 return SDValue();
6600
6601 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6602 return SDValue();
6603
6604 SDLoc DL(N);
6605
6606 // (mul x, (add y, 1)) -> (add (mul x, y), x)
6607 if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
6608 return Res;
6609 if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
6610 return Res;
6611
6612 // (mul x, (select c, 1, y)) -> (select c, x, (mul x, y)), or mirrored
6613 if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
6614 return Res;
6615 if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
6616 return Res;
6617
6618 return SDValue();
6619}
6620
6621/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
/// Skipped at -O0. Tries mul.wide formation first, then the operand-pair
/// combines in PerformMULCombineWithOperands.
6624 CodeGenOptLevel OptLevel) {
6625 if (OptLevel == CodeGenOptLevel::None)
6626 return SDValue();
6627
6628 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6629 return Ret;
6630
6631 SDValue N0 = N->getOperand(0);
6632 SDValue N1 = N->getOperand(1);
6633 return PerformMULCombineWithOperands(N, N0, N1, DCI);
6634}
6635
6636/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
/// Currently only attempts mul.wide formation, and only above -O0.
6639 CodeGenOptLevel OptLevel) {
6640 if (OptLevel > CodeGenOptLevel::None) {
6641 // Try mul.wide combining at OptLevel > 0
6642 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6643 return Ret;
6644 }
6645
6646 return SDValue();
6647}
6648
// Lower a v2i1 SETCC of v2f16 operands (or v2bf16 on SM >= 90) to one
// two-result NVPTXISD setp node (SETP_F16X2 for v2f16), then rebuild the v2i1
// with BUILD_VECTOR. Returns an empty SDValue for other types.
6651 unsigned int SmVersion) {
6652 EVT CCType = N->getValueType(0);
6653 SDValue A = N->getOperand(0);
6654 SDValue B = N->getOperand(1);
6655
6656 EVT AType = A.getValueType();
6657 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
6658 return SDValue();
6659
6660 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
6661 return SDValue();
6662
6663 SDLoc DL(N);
6664 // setp.f16x2 returns two scalar predicates, which we need to
6665 // convert back to v2i1. The returned result will be scalarized by
6666 // the legalizer, but the comparison will remain a single vector
6667 // instruction.
6668 SDValue CCNode = DCI.DAG.getNode(
6669 A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
6671 DL, DCI.DAG.getVTList(MVT::i1, MVT::i1), {A, B, N->getOperand(2)});
6672 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
6673 CCNode.getValue(1));
6674}
6675
// Rewrite extract_vector_elt with a constant, non-zero index as
// bitcast-to-integer + SRA + truncate, so the element is extracted in-register.
// Applies only to non-packed, non-v8i8, multi-element vectors of 16, 32 or 64
// bits that are not native vector loads and not undef. Returns an empty
// SDValue when the pattern does not apply.
6678 SDValue Vector = peekThroughFreeze(N->getOperand(0));
6679 SDLoc DL(N);
6680 EVT VectorVT = Vector.getValueType();
6681 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
6682 IsPTXVectorType(VectorVT.getSimpleVT()))
6683 return SDValue(); // Native vector loads already combine nicely w/
6684 // extract_vector_elt.
6685 // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
6686 // we already handle them OK.
6687 if (VectorVT.getVectorNumElements() == 1 ||
6688 NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)
6689 return SDValue();
6690
6691 // Don't mess with undef values as sra may be simplified to 0, not undef.
6692 if (Vector->isUndef() || ISD::allOperandsUndef(Vector.getNode()))
6693 return SDValue();
6694
6695 uint64_t VectorBits = VectorVT.getSizeInBits();
6696 // We only handle the types we can extract in-register.
6697 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
6698 return SDValue();
6699
6700 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
6701 // Index == 0 is handled by generic DAG combiner.
6702 if (!Index || Index->getZExtValue() == 0)
6703 return SDValue();
6704
6705 MVT IVT = MVT::getIntegerVT(VectorBits);
6706 EVT EltVT = VectorVT.getVectorElementType();
6707 EVT EltIVT = EltVT.changeTypeToInteger();
6708 uint64_t EltBits = EltVT.getScalarSizeInBits();
6709
// Shift the wanted element down to bit 0, then keep only its bits.
6710 SDValue Result = DCI.DAG.getNode(
6711 ISD::TRUNCATE, DL, EltIVT,
6712 DCI.DAG.getNode(
6713 ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
6714 DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
6715
6716 // If element has non-integer type, bitcast it back to the expected type.
6717 if (EltVT != EltIVT)
6718 Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
6719 // Past legalizer, we may need to extend i8 -> i16 to match the register type.
6720 if (EltVT != N->getValueType(0))
6721 Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);
6722
6723 return Result;
6724}
6725
6726/// Transform patterns like:
6727/// (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt))
6728/// (select (ult shift_amt, BitWidth), (srl/shl x, shift_amt), 0)
6729/// Into:
6730/// (NVPTXISD::SRL_CLAMP x, shift_amt) or (NVPTXISD::SHL_CLAMP x, shift_amt)
6731///
6732/// These patterns arise from code like `s >= 32 ? 0 : x >> s`. In LLVM,
6733/// over-shifting a value results in poison, but PTX shr/shl instructions clamp
6734/// the shift amount to BitWidth, making the guard redundant.
6735///
6736/// Note: We only handle SRL and SHL, not SRA, because arithmetic right shifts
6737/// can produce 0 or -1 when shift >= BitWidth.
6738/// Note: We don't handle uge or ule. These don't appear because of
6739/// canonicalization.
///
/// Runs only after DAG legalization. Returns an empty SDValue when no pattern
/// matches.
6742 if (!DCI.isAfterLegalizeDAG())
6743 return SDValue();
6744
6745 using namespace SDPatternMatch;
6746 unsigned BitWidth = N->getValueType(0).getSizeInBits();
6747 SDValue ShiftAmt, ShiftOp;
6748
6749 // Match logical shifts where the shift amount in the guard matches the shift
6750 // amount in the operation.
6751 auto LogicalShift =
6752 m_AllOf(m_Value(ShiftOp),
6753 m_AnyOf(m_Srl(m_Value(), m_TruncOrSelf(m_Deferred(ShiftAmt))),
6754 m_Shl(m_Value(), m_TruncOrSelf(m_Deferred(ShiftAmt)))));
6755
6756 // shift_amt > BitWidth-1 ? 0 : shift_op
6757 bool MatchedUGT =
6758 sd_match(N, m_Select(m_SetCC(m_Value(ShiftAmt),
6760 m_SpecificCondCode(ISD::SETUGT)),
6761 m_Zero(), LogicalShift));
6762 // shift_amt < BitWidth ? shift_op : 0
6763 bool MatchedULT =
6764 !MatchedUGT &&
6765 sd_match(N, m_Select(m_SetCC(m_Value(ShiftAmt),
6767 m_SpecificCondCode(ISD::SETULT)),
6768 LogicalShift, m_Zero()));
6769
6770 if (!MatchedUGT && !MatchedULT)
6771 return SDValue();
6772
6773 // In LLVM IR, the shift amount and the value-to-be-shifted are the same
6774 // type, whereas in PTX the shift amount is always i32. Therefore when
6775 // shifting types larger than i32, we can only do this transformation if we
6776 // know that the upper bits of the shift amount are known zero.
6777 SDValue ClampAmt = ShiftOp.getOperand(1);
6778 unsigned ClampAmtBits = ClampAmt.getValueSizeInBits();
6779 if (ShiftAmt.getValueSizeInBits() > ClampAmtBits &&
6780 DCI.DAG.computeKnownBits(ShiftAmt).countMaxActiveBits() > ClampAmtBits)
6781 return SDValue();
6782
6783 // Return a clamp shift operation, which has the same semantics as PTX shift.
6784 unsigned ClampOpc = ShiftOp.getOpcode() == ISD::SRL ? NVPTXISD::SRL_CLAMP
6785 : NVPTXISD::SHL_CLAMP;
6786 return DCI.DAG.getNode(ClampOpc, SDLoc(N), ShiftOp.getValueType(),
6787 ShiftOp.getOperand(0), ClampAmt);
6788}
6789
// Split a v4i8 VSELECT into four scalar selects on i32 values, then rebuild
// the v4i8 with BUILD_VECTOR. Returns an empty SDValue for other vector types.
6792 SDValue VA = N->getOperand(1);
6793 EVT VectorVT = VA.getValueType();
6794 if (VectorVT != MVT::v4i8)
6795 return SDValue();
6796
6797 // We need to split vselect into individual per-element operations. Because we
6798 // use BFE/BFI instruction for byte extraction/insertion, we do end up with
6799 // 32-bit values, so we may as well do comparison as i32 to avoid conversions
6800 // to/from i16 normally used for i8 values.
6802 SDLoc DL(N);
6803 SDValue VCond = N->getOperand(0);
6804 SDValue VB = N->getOperand(2);
6805 for (int I = 0; I < 4; ++I) {
6806 SDValue C = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i1, VCond,
6807 DCI.DAG.getConstant(I, DL, MVT::i32));
6808 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
6809 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VA,
6810 DCI.DAG.getConstant(I, DL, MVT::i32)),
6811 DL, MVT::i32);
6812 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
6813 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VB,
6814 DCI.DAG.getConstant(I, DL, MVT::i32)),
6815 DL, MVT::i32);
6816 E.push_back(DCI.DAG.getAnyExtOrTrunc(
6817 DCI.DAG.getNode(ISD::SELECT, DL, MVT::i32, C, EA, EB), DL, MVT::i8));
6818 }
6819 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
6820}
6821
// After DAG legalization, fold a two-element 32-bit packed BUILD_VECTOR into a
// single PRMT byte permute. Each operand (optionally behind a bitcast) must be
// a single-use i16 truncate of an i32 value, optionally of an `srl x, 16`. The
// permute selects the right two bytes from each i32 source. Returns an empty
// SDValue when the pattern does not apply.
6822static SDValue
6824 auto VT = N->getValueType(0);
6825 if (!DCI.isAfterLegalizeDAG() ||
6826 // only process v2*16 types
6827 !(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector() &&
6828 VT.getVectorNumElements() == 2))
6829 return SDValue();
6830
6831 auto Op0 = N->getOperand(0);
6832 auto Op1 = N->getOperand(1);
6833
6834 // Start out by assuming we want to take the lower 2 bytes of each i32
6835 // operand.
6836 uint64_t Op0Bytes = 0x10;
6837 uint64_t Op1Bytes = 0x54;
6838
6839 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
6840 {&Op1, &Op1Bytes}};
6841
6842 // Check that each operand is an i16, truncated from an i32 operand. We'll
6843 // select individual bytes from those original operands. Optionally, fold in a
6844 // shift right of that original operand.
6845 for (auto &[Op, OpBytes] : OpData) {
6846 // Eat up any bitcast
6847 if (Op->getOpcode() == ISD::BITCAST)
6848 *Op = Op->getOperand(0);
6849
6850 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
6851 Op->getOperand(0).getValueType() == MVT::i32))
6852 return SDValue();
6853
6854 // If the truncate has multiple uses, this optimization can increase
6855 // register pressure
6856 if (!Op->hasOneUse())
6857 return SDValue();
6858
6859 *Op = Op->getOperand(0);
6860
6861 // Optionally, fold in a shift-right of the original operand and let permute
6862 // pick the two higher bytes of the original value directly.
6863 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
6864 if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
6865 // Shift the PRMT byte selector to pick upper bytes from each respective
6866 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
6867 assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
6868 "PRMT selector values out of range");
6869 *OpBytes += 0x22;
6870 *Op = Op->getOperand(0);
6871 }
6872 }
6873 }
6874
6875 SDLoc DL(N);
6876 auto &DAG = DCI.DAG;
6877
// Op0's two selector nibbles fill the low byte of the selector and Op1's fill
// the next byte.
6878 auto PRMT =
6879 getPRMT(DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
6880 (Op1Bytes << 8) | Op0Bytes, DL, DAG);
6881 return DAG.getBitcast(VT, PRMT);
6882}
6883
// Fold a round-trip address space cast, asc[B -> A](asc[A -> B](x)), to x.
// Returns an empty SDValue otherwise.
6886 auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
6887
6888 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(ASCN1->getOperand(0))) {
6889 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
6890
6891 // Fold asc[B -> A](asc[A -> B](x)) -> x
6892 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
6893 return ASCN2->getOperand(0);
6894 }
6895
6896 return SDValue();
6897}
6898
6899// Given a constant selector value and a prmt mode, return the selector value
6900// normalized to the generic prmt mode. See the PTX ISA documentation for more
6901// details:
6902// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
//
// For the modes handled by the switch, only the low two bits of the selector
// (V) matter. Each mode expands V into four 4-bit byte indices, packed into a
// generic selector by GetSelector (index 0 in bits [3:0], ... index 3 in
// bits [15:12]).
6903static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
6904 assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6905
6907 return Selector;
6908
6909 const unsigned V = Selector.trunc(2).getZExtValue();
6910
6911 const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
6912 unsigned S3) {
6913 return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
6914 };
6915
6916 switch (Mode) {
6918 return GetSelector(V, V + 1, V + 2, V + 3);
6920 return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
6922 return GetSelector(V, V, V, V);
6924 return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);
6926 return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);
6928 unsigned V1 = (V & 1) << 1;
6929 return GetSelector(V1, V1 + 1, V1, V1 + 1);
6930 }
6931 default:
6932 llvm_unreachable("Invalid PRMT mode");
6933 }
6934}
6935
6936static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
6937 assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
6938 Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6939 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6940 APInt BitField = B.concat(A);
6941 APInt SelectorVal = getPRMTSelector(Selector, Mode);
6942 APInt Result(32, 0);
6943 for (unsigned I : llvm::seq(4U)) {
6944 APInt Sel = SelectorVal.extractBits(4, I * 4);
6945 unsigned Idx = Sel.getLoBits(3).getZExtValue();
6946 unsigned Sign = Sel.getHiBits(1).getZExtValue();
6947 APInt Byte = BitField.extractBits(8, Idx * 8);
6948 if (Sign)
6949 Byte = Byte.ashr(8);
6950 Result.insertBits(Byte, I * 8);
6951 }
6952 return Result;
6953}
6954
// Constant-fold PRMT when both data operands and the selector (operands 0-2)
// are constants; operand 3 supplies the mode. Disabled at -O0.
6956 CodeGenOptLevel OptLevel) {
6957 if (OptLevel == CodeGenOptLevel::None)
6958 return SDValue();
6959
6960 // Constant fold PRMT
6961 if (isa<ConstantSDNode>(N->getOperand(0)) &&
6962 isa<ConstantSDNode>(N->getOperand(1)) &&
6963 isa<ConstantSDNode>(N->getOperand(2)))
6964 return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0),
6965 N->getConstantOperandAPInt(1),
6966 N->getConstantOperandAPInt(2),
6967 N->getConstantOperandVal(3)),
6968 SDLoc(N), N->getValueType(0));
6969 return SDValue();
6970}
6971
6972// During call lowering we wrap the return values in a ProxyReg node which
6973// depend on the chain value produced by the completed call. This ensures that
6974// the full call is emitted in cases where libcalls are used to legalize
6975// operations. To improve the functioning of other DAG combines we pull all
6976// operations we can through one of these nodes, ensuring that the ProxyReg
6977// directly wraps a load. That is:
6978//
6979// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
6980//
// Returns the rebuilt value, or an empty SDValue if any part of R cannot be
// pulled through. In that case the caller leaves the ProxyReg untouched.
6983 switch (R.getOpcode()) {
// Unary ops: sink through the single operand.
6984 case ISD::TRUNCATE:
6985 case ISD::ANY_EXTEND:
6986 case ISD::SIGN_EXTEND:
6987 case ISD::ZERO_EXTEND:
6988 case ISD::BITCAST: {
6989 if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
6990 return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), V);
6991 return SDValue();
6992 }
// Binary ops: both operands must be sinkable.
6993 case ISD::SHL:
6994 case ISD::SRL:
6995 case ISD::SRA:
6996 case ISD::OR: {
6997 if (SDValue A = sinkProxyReg(R.getOperand(0), Chain, DCI))
6998 if (SDValue B = sinkProxyReg(R.getOperand(1), Chain, DCI))
6999 return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), A, B);
7000 return SDValue();
7001 }
// Constants need no ProxyReg and are returned unchanged.
7002 case ISD::Constant:
7003 return R;
// Loads are the leaves that get wrapped in a new ProxyReg on Chain.
7004 case ISD::LOAD:
7005 case NVPTXISD::LoadV2:
7006 case NVPTXISD::LoadV4: {
7007 return DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(R), R.getValueType(),
7008 {Chain, R});
7009 }
// Only once DCI.isBeforeLegalize() is false; the reason is not stated here.
7010 case ISD::BUILD_VECTOR: {
7011 if (DCI.isBeforeLegalize())
7012 return SDValue();
7013
7015 for (auto &Op : R->ops()) {
7016 SDValue V = sinkProxyReg(Op, Chain, DCI);
7017 if (!V)
7018 return SDValue();
7019 Ops.push_back(V);
7020 }
7021 return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops);
7022 }
7024 if (DCI.isBeforeLegalize())
7025 return SDValue();
7026
7027 if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
7029 R.getValueType(), V, R.getOperand(1));
7030 return SDValue();
7031 }
7032 default:
7033 return SDValue();
7034 }
7035}
7036
7037static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
7038 switch (AddIntrinsicID) {
7039 default:
7040 break;
7041 case Intrinsic::nvvm_add_rn_sat_f16:
7042 case Intrinsic::nvvm_add_rn_sat_v2f16:
7043 return NVPTXISD::SUB_RN_SAT;
7044 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7045 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7046 return NVPTXISD::SUB_RN_FTZ_SAT;
7047 }
7048 llvm_unreachable("Invalid F16 add intrinsic");
7049}
7050
// Fold an f16/v2f16 saturating-add intrinsic with an FNEG operand into the
// matching saturating-subtract node:
//   add(a, -b) -> sub(a, b)
//   add(-a, b) -> sub(b, a)
// The data operands are N's operands 1 and 2, since operand 0 carries the
// intrinsic ID. Returns an empty SDValue if neither operand is an FNEG.
7052 Intrinsic::ID AddIntrinsicID) {
7053 SDValue Op1 = N->getOperand(1);
7054 SDValue Op2 = N->getOperand(2);
7055
7056 SDValue SubOp1, SubOp2;
7057
7058 if (Op1.getOpcode() == ISD::FNEG) {
7059 SubOp1 = Op2;
7060 SubOp2 = Op1.getOperand(0);
7061 } else if (Op2.getOpcode() == ISD::FNEG) {
7062 SubOp1 = Op1;
7063 SubOp2 = Op2.getOperand(0);
7064 } else {
7065 return SDValue();
7066 }
7067
7068 SDLoc DL(N);
7069 return DAG.getNode(getF16SubOpc(AddIntrinsicID), DL, N->getValueType(0),
7070 SubOp1, SubOp2);
7071}
7072
// DAG combines for chainless intrinsic nodes, keyed on the intrinsic ID in
// operand 0. Currently this only handles the f16/v2f16 saturating-add + fneg
// fold.
7075 const NVPTXSubtarget &STI) {
7076 unsigned IID = N->getConstantOperandVal(0);
7077
7078 switch (IID) {
7079 default:
7080 break;
7081 case Intrinsic::nvvm_add_rn_sat_f16:
7082 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7083 case Intrinsic::nvvm_add_rn_sat_v2f16:
7084 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7085 return combineF16AddWithNeg(N, DCI.DAG, IID);
7086 }
7087 return SDValue();
7088}
7089
// ProxyReg combine. Operand 0 is the chain and operand 1 the wrapped value.
// Unless the value is already a plain LOAD, try to sink the ProxyReg below the
// operations that produce it (see sinkProxyReg).
7092
7093 SDValue Chain = N->getOperand(0);
7094 SDValue Reg = N->getOperand(1);
7095
7096 // If the ProxyReg is not wrapping a load, try to pull the operations through
7097 // the ProxyReg.
7098 if (Reg.getOpcode() != ISD::LOAD) {
7099 if (SDValue V = sinkProxyReg(Reg, Chain, DCI))
7100 return V;
7101 }
7102
7103 return SDValue();
7104}
7105
/// Entry point for NVPTX-specific DAG combines: dispatch on N's opcode to the
/// per-opcode combine helpers. Returns an empty SDValue when no combine
/// applies.
7106SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
7107 DAGCombinerInfo &DCI) const {
7109 switch (N->getOpcode()) {
7110 default:
7111 break;
7112 case ISD::ADD:
7113 return PerformADDCombine(N, DCI, OptLevel);
7114 case ISD::ADDRSPACECAST:
7115 return combineADDRSPACECAST(N, DCI);
7116 case ISD::SIGN_EXTEND:
7117 case ISD::ZERO_EXTEND:
7118 return combineSZExtToMulWide(N, DCI, OptLevel);
7119 case ISD::BUILD_VECTOR:
7120 return PerformBUILD_VECTORCombine(N, DCI);
7122 return PerformEXTRACTCombine(N, DCI);
7123 case ISD::FADD:
7124 return performFADDCombine(N, DCI, OptLevel);
// Other v2f32 arithmetic only has the scalarization combine.
7125 case ISD::FMA:
7126 case ISD::FMUL:
7127 case ISD::FSUB:
7128 return performScalarizeV2F32Op(N, DCI, OptLevel);
7129 case ISD::FMAXNUM:
7130 case ISD::FMINNUM:
7131 case ISD::FMAXIMUM:
7132 case ISD::FMINIMUM:
7133 case ISD::FMAXIMUMNUM:
7134 case ISD::FMINIMUMNUM:
7135 return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
7136 STI.getSmVersion());
7137 case ISD::LOAD:
7138 case NVPTXISD::LoadV2:
7139 case NVPTXISD::LoadV4:
7140 return combineLOAD(N, DCI, STI);
7141 case ISD::MUL:
7142 return PerformMULCombine(N, DCI, OptLevel);
7143 case NVPTXISD::PRMT:
7144 return combinePRMT(N, DCI, OptLevel);
7145 case NVPTXISD::ProxyReg:
7146 return combineProxyReg(N, DCI);
7147 case ISD::SETCC:
7148 return PerformSETCCCombine(N, DCI, STI.getSmVersion());
7149 case ISD::SHL:
7150 return PerformSHLCombine(N, DCI, OptLevel);
7151 case ISD::SREM:
7152 case ISD::UREM:
7153 return PerformREMCombine(N, DCI, OptLevel);
7154 case ISD::STORE:
7155 case NVPTXISD::StoreV2:
7156 case NVPTXISD::StoreV4:
7157 return combineSTORE(N, DCI, STI);
7158 case ISD::SELECT:
7159 return PerformSELECTShiftCombine(N, DCI);
7160 case ISD::VSELECT:
7161 return PerformVSELECTCombine(N, DCI);
7163 return combineIntrinsicWOChain(N, DCI, STI);
7164 }
7165 return SDValue();
7166}
7167
7170 // Handle bitcasting to v2i8 without hitting the default promotion
7171 // strategy which goes through stack memory.
7172 SDValue Op(Node, 0);
7173 EVT ToVT = Op->getValueType(0);
7174 if (ToVT != MVT::v2i8) {
7175 return;
7176 }
7177
7178 // Bitcast to i16 and unpack elements into a vector
7179 SDLoc DL(Node);
7180 SDValue AsInt = DAG.getBitcast(MVT::i16, Op->getOperand(0));
7181 SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
7182 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
7183 SDValue Vec1 =
7184 DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
7185 DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
7186 Results.push_back(
7187 DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
7188}
7189
7192 SDValue Chain = N->getOperand(0);
7193 SDValue Intrin = N->getOperand(1);
7194 SDLoc DL(N);
7195
7196 // Get the intrinsic ID
7197 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
7198 switch (IntrinNo) {
7199 default:
7200 return;
7201 case Intrinsic::nvvm_ldu_global_i:
7202 case Intrinsic::nvvm_ldu_global_f:
7203 case Intrinsic::nvvm_ldu_global_p: {
7204 EVT ResVT = N->getValueType(0);
7205
7206 if (ResVT.isVector()) {
7207 // Vector LDG/LDU
7208
7209 unsigned NumElts = ResVT.getVectorNumElements();
7210 EVT EltVT = ResVT.getVectorElementType();
7211
7212 // Since LDU/LDG are target nodes, we cannot rely on DAG type
7213 // legalization.
7214 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
7215 // loaded type to i16 and propagate the "real" type as the memory type.
7216 bool NeedTrunc = false;
7217 if (EltVT.getSizeInBits() < 16) {
7218 EltVT = MVT::i16;
7219 NeedTrunc = true;
7220 }
7221
7222 unsigned Opcode = 0;
7223 SDVTList LdResVTs;
7224
7225 switch (NumElts) {
7226 default:
7227 return;
7228 case 2:
7229 Opcode = NVPTXISD::LDUV2;
7230 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
7231 break;
7232 case 4: {
7233 Opcode = NVPTXISD::LDUV4;
7234 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
7235 LdResVTs = DAG.getVTList(ListVTs);
7236 break;
7237 }
7238 }
7239
7240 SmallVector<SDValue, 8> OtherOps;
7241
7242 // Copy regular operands
7243
7244 OtherOps.push_back(Chain); // Chain
7245 // Skip operand 1 (intrinsic ID)
7246 // Others
7247 OtherOps.append(N->op_begin() + 2, N->op_end());
7248
7250
7251 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
7252 MemSD->getMemoryVT(),
7253 MemSD->getMemOperand());
7254
7255 SmallVector<SDValue, 4> ScalarRes;
7256
7257 for (unsigned i = 0; i < NumElts; ++i) {
7258 SDValue Res = NewLD.getValue(i);
7259 if (NeedTrunc)
7260 Res =
7261 DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
7262 ScalarRes.push_back(Res);
7263 }
7264
7265 SDValue LoadChain = NewLD.getValue(NumElts);
7266
7267 SDValue BuildVec =
7268 DAG.getBuildVector(ResVT, DL, ScalarRes);
7269
7270 Results.push_back(BuildVec);
7271 Results.push_back(LoadChain);
7272 } else {
7273 // i8 LDG/LDU
7274 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
7275 "Custom handling of non-i8 ldu/ldg?");
7276
7277 // Just copy all operands as-is
7279
7280 // Force output to i16
7281 SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
7282
7284
7285 // We make sure the memory type is i8, which will be used during isel
7286 // to select the proper instruction.
7287 SDValue NewLD =
7289 MVT::i8, MemSD->getMemOperand());
7290
7291 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
7292 NewLD.getValue(0)));
7293 Results.push_back(NewLD.getValue(1));
7294 }
7295 return;
7296 }
7297
7298 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
7299 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
7300 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
7301 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
7302 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
7303 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
7304 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
7305 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
7306 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
7307 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
7308 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
7309 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
7310 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
7311 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
7312 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
7313 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
7314 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
7315 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
7316 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
7317 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
7318 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
7319 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
7320 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
7321 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
7322 if (auto Res = lowerTcgen05Ld(N, DAG)) {
7323 Results.push_back(Res->first);
7324 Results.push_back(Res->second);
7325 }
7326 return;
7327
7328 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
7329 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
7330 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
7331 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
7332 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
7333 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
7334 if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
7335 Results.push_back(Res->first);
7336 Results.push_back(Res->second);
7337 }
7338 return;
7339
7340 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
7341 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
7342 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
7343 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
7344 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
7345 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
7346 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
7347 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
7348 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
7349 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
7350 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
7351 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
7352 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32:
7353 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32:
7354 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32:
7355 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32:
7356 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32:
7357 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32:
7358 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32:
7359 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32:
7360 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32:
7361 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32:
7362 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32:
7363 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32:
7364 if (auto Res = lowerTcgen05LdRed(N, DAG)) {
7365 Results.push_back(std::get<0>(*Res));
7366 Results.push_back(std::get<1>(*Res));
7367 Results.push_back(std::get<2>(*Res));
7368 }
7369 return;
7370 }
7371}
7372
7375 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
7376 // result so that it can pass the legalization
7377 SDLoc DL(N);
7378 SDValue Chain = N->getOperand(0);
7379 SDValue Reg = N->getOperand(1);
7380 SDValue Glue = N->getOperand(2);
7381
7382 assert(Reg.getValueType() == MVT::i128 &&
7383 "Custom lowering for CopyFromReg with 128-bit reg only");
7384 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
7385 N->getValueType(2)};
7386 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
7387
7388 SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
7389 SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
7390 {NewValue.getValue(0), NewValue.getValue(1)});
7391
7392 Results.push_back(Pair);
7393 Results.push_back(NewValue.getValue(2));
7394 Results.push_back(NewValue.getValue(3));
7395}
7396
7398 const TargetLowering &TLI,
7400 SDValue Chain = N->getOperand(0);
7401 SDValue Reg = N->getOperand(1);
7402
7403 MVT VT = TLI.getRegisterType(*DAG.getContext(), Reg.getValueType());
7404
7405 SDValue NewReg = DAG.getAnyExtOrTrunc(Reg, SDLoc(N), VT);
7406 SDValue NewProxy =
7407 DAG.getNode(NVPTXISD::ProxyReg, SDLoc(N), VT, {Chain, NewReg});
7408 SDValue Res = DAG.getAnyExtOrTrunc(NewProxy, SDLoc(N), N->getValueType(0));
7409
7410 Results.push_back(Res);
7411}
7412
7414 const NVPTXSubtarget &STI,
7416 assert(N->getValueType(0) == MVT::i128 &&
7417 "Custom lowering for atomic128 only supports i128");
7418
7420 SDLoc dl(N);
7421
7422 if (!STI.hasAtomSwap128()) {
7425 "Support for b128 atomics introduced in PTX ISA version 8.3 and "
7426 "requires target sm_90.",
7427 dl.getDebugLoc()));
7428
7429 Results.push_back(DAG.getUNDEF(MVT::i128));
7430 Results.push_back(AN->getOperand(0)); // Chain
7431 return;
7432 }
7433
7435 Ops.push_back(AN->getOperand(0)); // Chain
7436 Ops.push_back(AN->getOperand(1)); // Ptr
7437 for (const auto &Op : AN->ops().drop_front(2)) {
7438 // Low part
7439 Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
7440 DAG.getIntPtrConstant(0, dl)));
7441 // High part
7442 Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
7443 DAG.getIntPtrConstant(1, dl)));
7444 }
7445 unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
7448 SDVTList Tys = DAG.getVTList(MVT::i64, MVT::i64, MVT::Other);
7449 SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, Tys, Ops, MVT::i128,
7450 AN->getMemOperand());
7451 Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i128,
7452 {Result.getValue(0), Result.getValue(1)}));
7453 Results.push_back(Result.getValue(2));
7454}
7455
7456void NVPTXTargetLowering::ReplaceNodeResults(
7458 switch (N->getOpcode()) {
7459 default:
7460 report_fatal_error("Unhandled custom legalization");
7461 case ISD::BITCAST:
7462 ReplaceBITCAST(N, DAG, Results);
7463 return;
7464 case ISD::LOAD:
7465 case ISD::MLOAD:
7466 replaceLoadVector(N, DAG, Results, STI);
7467 return;
7470 return;
7471 case ISD::CopyFromReg:
7473 return;
7474 case NVPTXISD::ProxyReg:
7475 replaceProxyReg(N, DAG, *this, Results);
7476 return;
7478 case ISD::ATOMIC_SWAP:
7479 replaceAtomicSwap128(N, DAG, STI, Results);
7480 return;
7481 }
7482}
7483
7486 Type *Ty = AI->getValOperand()->getType();
7487
7488 // Try to lower LLVM atomicrmw fadd to PTX atomic.add. This is complicated
7489 // by the weird FTZ behavior PTX atom.add has:
7490 // - atom.add.f32 on global memory flushes denormals
7491 // - atom.add.f32 on shared memory does not flush denormals
7492 // - atom.add.f16 and atomic.add.bf16 never flush denormals
7493 //
7494 // We lower to atom.add only if the function's FTZ behavior matches that of
7495 // atom.add; otherwise, we lower to a CAS loop. But we always allow
7496 // atomic.add.bf16; even though it never flushes denormals, we never flush
7497 // bf16 denormals when doing regular arithmetic, even when FTZ is enabled.
7498 if (AI->isFloatingPointOperation() &&
7500 const bool FTZ =
7503
7504 // AllowFTZAtomics forces atom.add regardless of the FTZ mismatch.
7505 if (Ty->isFloatTy()) {
7507 switch (AI->getPointerAddressSpace()) {
7509 UseNative |= FTZ;
7510 break;
7513 UseNative |= !FTZ;
7514 break;
7515 }
7516 if (UseNative)
7518 }
7519
7520 if (Ty->isHalfTy() && (!FTZ || AllowFTZAtomics) &&
7521 STI.getSmVersion() >= 70 && STI.getPTXVersion() >= 63)
7523
7524 if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
7525 STI.getPTXVersion() >= 78)
7527
7528 if (Ty->isDoubleTy() && STI.hasAtomAddF64())
7530 }
7531
7532 // PTX's only atomic fp op is `add`; all other ops expand to a CAS loop.
7533 if (AI->isFloatingPointOperation())
7535
7536 assert(Ty->isIntegerTy() && "Ty should be integer at this point");
7537 const unsigned BitWidth = cast<IntegerType>(Ty)->getBitWidth();
7538
7539 switch (AI->getOperation()) {
7540 default:
7543 if (BitWidth == 128)
7545 [[fallthrough]];
7549 switch (BitWidth) {
7550 case 8:
7551 case 16:
7553 case 32:
7555 case 64:
7556 if (STI.hasAtomBitwise64())
7559 case 128:
7561 default:
7562 llvm_unreachable("unsupported width encountered");
7563 }
7570 switch (BitWidth) {
7571 case 8:
7572 case 16:
7574 case 32:
7576 case 64:
7577 if (STI.hasAtomMinMax64())
7580 case 128:
7582 default:
7583 llvm_unreachable("unsupported width encountered");
7584 }
7587 switch (BitWidth) {
7588 case 32:
7590 case 8:
7591 case 16:
7592 case 64:
7593 case 128:
7595 default:
7596 llvm_unreachable("unsupported width encountered");
7597 }
7598 }
7599
7601}
7602
7604 const Instruction *I) const {
7605 // This function returns true iff the operation is emulated using a CAS-loop,
7606 // or if it has the memory order seq_cst (which is not natively supported in
7607 // the PTX `atom` instruction).
7608 //
7609 // atomicrmw and cmpxchg instructions not efficiently supported by PTX
7610 // are lowered to CAS emulation loops that preserve their memory order,
7611 // syncscope, and volatile semantics. For PTX, it is more efficient to use
7612 // atom.cas.relaxed.sco instructions within the loop, and fences before and
7613 // after the loop to restore order.
7614 //
7615 // Atomic instructions efficiently supported by PTX are lowered to
7616 // `atom.<op>.<sem>.<scope` instruction with their corresponding memory order
7617 // and scope. Since PTX does not support seq_cst, we emulate it by lowering to
7618 // a fence.sc followed by an atom according to the PTX atomics ABI
7619 // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
7620 if (auto *CI = dyn_cast<AtomicCmpXchgInst>(I))
7621 return (cast<IntegerType>(CI->getCompareOperand()->getType())
7622 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()) ||
7623 CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent;
7624 if (auto *RI = dyn_cast<AtomicRMWInst>(I))
7626 RI->getOrdering() == AtomicOrdering::SequentiallyConsistent;
7627 return false;
7628}
7629
7631 const Instruction *I) const {
7632 // If the operation is emulated by a CAS-loop, we lower the instruction to
7633 // atom.<op>.relaxed, since AtomicExpandPass will insert fences for enforcing
7634 // the correct memory ordering around the CAS loop.
7635 //
7636 // When the operation is not emulated, but the memory order is seq_cst,
7637 // we must lower to "fence.sc.<scope>; atom.<op>.acquire.<scope>;" to conform
7638 // to the PTX atomics ABI.
7639 // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
7640 // For such cases, emitLeadingFence() will separately insert the leading
7641 // "fence.sc.<scope>;". Here, we only set the memory order to acquire.
7642 //
7643 // Otherwise, the operation is not emulated, and the memory order is not
7644 // seq_cst. In this case, the LLVM memory order is natively supported by the
7645 // PTX `atom` instruction, and we just lower to the corresponding
7646 // `atom.<op>.relaxed|acquire|release|acq_rel". For such cases, this function
7647 // will NOT be called.
7648 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7649 // I before its memory order was modified.
7650 if (auto *CI = dyn_cast<AtomicCmpXchgInst>(I);
7651 CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
7652 cast<IntegerType>(CI->getCompareOperand()->getType())->getBitWidth() >=
7653 STI.getMinCmpXchgSizeInBits())
7655 else if (auto *RI = dyn_cast<AtomicRMWInst>(I);
7656 RI && RI->getOrdering() == AtomicOrdering::SequentiallyConsistent &&
7659
7661}
7662
7664 Instruction *Inst,
7665 AtomicOrdering Ord) const {
7666 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7667 // `Inst` before its memory order was modified. We cannot enforce this with an
7668 // assert, because AtomicExpandPass will have modified the memory order
7669 // between the initial call to shouldInsertFencesForAtomic() and the call to
7670 // this function.
7671 if (!isa<AtomicCmpXchgInst>(Inst) && !isa<AtomicRMWInst>(Inst))
7672 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
7673
7674 // Specialize for cmpxchg and atomicrmw
7675 auto SSID = getAtomicSyncScopeID(Inst);
7676 assert(SSID.has_value() && "Expected an atomic operation");
7677
7678 if (isReleaseOrStronger(Ord))
7679 return Builder.CreateFence(Ord == AtomicOrdering::SequentiallyConsistent
7682 SSID.value());
7683
7684 return nullptr;
7685}
7686
7688 Instruction *Inst,
7689 AtomicOrdering Ord) const {
7690 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7691 // `Inst` before its memory order was modified. See `emitLeadingFence` for why
7692 // this cannot be enforced with an assert. Specialize for cmpxchg and
7693 // atomicrmw
7694 auto *CI = dyn_cast<AtomicCmpXchgInst>(Inst);
7695 auto *RI = dyn_cast<AtomicRMWInst>(Inst);
7696 if (!CI && !RI)
7697 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
7698
7699 auto SSID = getAtomicSyncScopeID(Inst);
7700 assert(SSID.has_value() && "Expected an atomic operation");
7701
7702 bool IsEmulated =
7703 CI ? cast<IntegerType>(CI->getCompareOperand()->getType())
7704 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()
7706
7707 if (isAcquireOrStronger(Ord) && IsEmulated)
7708 return Builder.CreateFence(AtomicOrdering::Acquire, SSID.value());
7709
7710 return nullptr;
7711}
7712
7713// Rather than default to SINT when both UINT and SINT are custom, we only
7714// change the opcode when UINT is not legal and SINT is. UINT is preferred when
7715// both are custom since unsigned CVT instructions can lead to slightly better
7716// SASS code with fewer instructions.
7718 EVT ToVT) const {
7719 if (isOperationLegal(Op, ToVT))
7720 return Op;
7721 switch (Op) {
7722 case ISD::FP_TO_UINT:
7724 return ISD::FP_TO_SINT;
7725 break;
7729 break;
7730 case ISD::VP_FP_TO_UINT:
7731 if (isOperationLegal(ISD::VP_FP_TO_SINT, ToVT))
7732 return ISD::VP_FP_TO_SINT;
7733 break;
7734 default:
7735 break;
7736 }
7737 return Op;
7738}
7739
7740// Pin NVPTXTargetObjectFile's vtables to this file.
7742
7747
7749 const SelectionDAG &DAG, unsigned Depth) {
7750 SDValue A = Op.getOperand(0);
7751 SDValue B = Op.getOperand(1);
7752 ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
7753 unsigned Mode = Op.getConstantOperandVal(3);
7754
7755 if (!Selector)
7756 return;
7757
7758 KnownBits AKnown = DAG.computeKnownBits(A, Depth);
7759 KnownBits BKnown = DAG.computeKnownBits(B, Depth);
7760
7761 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
7762 assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
7763 "PRMT must have i32 operands");
7764 assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
7765 KnownBits BitField = BKnown.concat(AKnown);
7766
7767 APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
7768 for (unsigned I : llvm::seq(4)) {
7769 APInt Sel = SelectorVal.extractBits(4, I * 4);
7770 unsigned Idx = Sel.getLoBits(3).getZExtValue();
7771 unsigned Sign = Sel.getHiBits(1).getZExtValue();
7772 KnownBits Byte = BitField.extractBits(8, Idx * 8);
7773 if (Sign)
7774 Byte = KnownBits::ashr(Byte, KnownBits::makeConstant(APInt(8, 7)));
7775 Known.insertBits(Byte, I * 8);
7776 }
7777}
7778
7779static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
7781
7782 // We can't do anything without knowing the sign bit.
7783 auto ExtType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
7784 if (ExtType == ISD::SEXTLOAD)
7785 return;
7786
7787 // ExtLoading to vector types is weird and may not work well with known bits.
7788 auto DestVT = LD->getValueType(0);
7789 if (DestVT.isVector())
7790 return;
7791
7792 assert(Known.getBitWidth() == DestVT.getSizeInBits());
7793 auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
7794 Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
7795}
7796
7798 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
7799 const SelectionDAG &DAG, unsigned Depth) const {
7800 Known.resetAll();
7801
7802 switch (Op.getOpcode()) {
7803 case NVPTXISD::PRMT:
7804 computeKnownBitsForPRMT(Op, Known, DAG, Depth);
7805 break;
7806 case NVPTXISD::LoadV2:
7807 case NVPTXISD::LoadV4:
7808 case NVPTXISD::LoadV8:
7810 break;
7811 default:
7812 break;
7813 }
7814}
7815
7816static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
7817 const APInt &DemandedBits) {
7818 APInt DemandedLHS = APInt(32, 0);
7819 APInt DemandedRHS = APInt(32, 0);
7820
7821 for (unsigned I : llvm::seq(4)) {
7822 if (DemandedBits.extractBits(8, I * 8).isZero())
7823 continue;
7824
7825 APInt Sel = SelectorVal.extractBits(4, I * 4);
7826 unsigned Idx = Sel.getLoBits(3).getZExtValue();
7827 unsigned Sign = Sel.getHiBits(1).getZExtValue();
7828
7829 APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
7830 unsigned ByteStart = (Idx % 4) * 8;
7831 if (Sign)
7832 Src.setBit(ByteStart + 7);
7833 else
7834 Src.setBits(ByteStart, ByteStart + 8);
7835 }
7836
7837 return {DemandedLHS, DemandedRHS};
7838}
7839
7840// Replace undef with 0 as this is easier for other optimizations such as
7841// known bits.
7843 if (!Op)
7844 return SDValue();
7845 if (Op.isUndef())
7846 return DAG.getConstant(0, SDLoc(), MVT::i32);
7847 return Op;
7848}
7849
7851 const APInt &DemandedBits,
7852 SelectionDAG &DAG,
7853 const TargetLowering &TLI,
7854 unsigned Depth) {
7855 assert(PRMT.getOpcode() == NVPTXISD::PRMT);
7856 SDValue Op0 = PRMT.getOperand(0);
7857 SDValue Op1 = PRMT.getOperand(1);
7858 auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
7859 if (!SelectorConst)
7860 return SDValue();
7861
7862 unsigned Mode = PRMT.getConstantOperandVal(3);
7863 const APInt Selector = getPRMTSelector(SelectorConst->getAPIntValue(), Mode);
7864
7865 // Try to simplify the PRMT to one of the inputs if the used bytes are all
7866 // from the same input in the correct order.
7867 const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
7868 const unsigned SelBits = (4 - LeadingBytes) * 4;
7869 if (Selector.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
7870 return Op0;
7871 if (Selector.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
7872 return Op1;
7873
7874 auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(Selector, DemandedBits);
7875
7876 // Attempt to avoid multi-use ops if we don't need anything from them.
7877 SDValue DemandedOp0 =
7878 TLI.SimplifyMultipleUseDemandedBits(Op0, DemandedLHS, DAG, Depth + 1);
7879 SDValue DemandedOp1 =
7880 TLI.SimplifyMultipleUseDemandedBits(Op1, DemandedRHS, DAG, Depth + 1);
7881
7882 DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
7883 DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
7884 if ((DemandedOp0 && DemandedOp0 != Op0) ||
7885 (DemandedOp1 && DemandedOp1 != Op1)) {
7886 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
7887 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
7888 return getPRMT(Op0, Op1, Selector.getZExtValue(), SDLoc(PRMT), DAG);
7889 }
7890
7891 return SDValue();
7892}
7893
7895 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
7896 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
7897 Known.resetAll();
7898
7899 switch (Op.getOpcode()) {
7900 case NVPTXISD::PRMT:
7902 *this, Depth)) {
7903 TLO.CombineTo(Op, Result);
7904 return true;
7905 }
7906 break;
7907 default:
7908 break;
7909 }
7910
7911 computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
7912 return false;
7913}
return SDValue()
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
constexpr LLT F32
static cl::list< std::string > UseNative("amdgpu-use-native", cl::desc("Comma separated list of functions to replace with native, or all"), cl::CommaSeparated, cl::ValueOptional, cl::Hidden)
AMDGPU Register Bank Select
This file declares a class to represent arbitrary precision floating point values and provide a varie...
This file implements a class to represent arbitrary precision integral constant values and operations...
static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombineWithOperands - Try DAG combinations for an ADD with operands N0 and N1.
static SDValue PerformADDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
static SDValue PerformVSELECTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformMULCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
static SDValue PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget)
PerformBUILD_VECTORCombine - Target-specific dag combine xforms for ISD::BUILD_VECTOR.
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Function Alias Analysis Results
Atomic ordering constants.
This file contains the simple types necessary to represent the attributes associated with functions a...
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file contains the declarations of entities that describe floating point environment and related ...
static bool IsIndirectCall(const MachineInstr *MI)
Module.h This file contains the declarations for the Module class.
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first DebugLoc that has line number information, given a range of instructions.
Register Reg
Register const TargetRegisterInfo * TRI
#define T
NVPTX address space definition.
static SDValue reportInvalidTensormapReplaceUsage(SDValue Op, SelectionDAG &DAG, unsigned Val)
static SDValue combineADDRSPACECAST(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static cl::opt< bool > sched4reg("nvptx-sched4reg", cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false))
static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG, bool hasOffset=false)
static SDValue PerformEXTRACTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static cl::opt< NVPTX::DivPrecisionLevel > UsePrecDivF32("nvptx-prec-divf32", cl::Hidden, cl::desc("NVPTX Specific: Override the precision of the lowering for f32 fdiv"), cl::values(clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"), clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"), clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2", "Use IEEE Compliant F32 div.rnd if available (default)"), clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3", "Use IEEE Compliant F32 div.rnd if available, no FTZ")), cl::init(NVPTX::DivPrecisionLevel::IEEE754))
static bool isConstOne(const SDValue &Operand)
static cl::opt< unsigned > FMAContractLevelOpt("nvptx-fma-level", cl::Hidden, cl::desc("NVPTX Specific: FMA contraction (0: don't do it" " 1: do it 2: do it aggressively"), cl::init(2))
static bool IsPTXVectorType(MVT VT)
static SDValue PerformSELECTShiftCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
Transform patterns like: (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt)) (select (ult...
static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG)
static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG)
static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG, const DataLayout &DL, const TargetLowering &TL)
static SDValue lowerROT(SDValue Op, SelectionDAG &DAG)
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, LLVMContext &Ctx, CallingConv::ID CallConv, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< uint64_t > &Offsets, uint64_t StartingOffset=0)
ComputePTXValueVTs - For the given Type Ty, returns the set of primitive legal-ish MVTs that compose ...
static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI, SmallVectorImpl< SDValue > &Results)
static unsigned getMinMax3Opcode(unsigned MinMax2Opcode)
Get 3-input version of a 2-input min/max opcode.
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG, const NVPTXSubtarget &STI)
static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI)
static void replaceProxyReg(SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, SmallVectorImpl< SDValue > &Results)
static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
#define TCGEN05_LD_RED_INST(SHAPE, NUM, TYPE)
static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG)
static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static unsigned getTcgen05LdRedID(Intrinsic::ID IID)
static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static SDValue combinePackingMovIntoStore(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned Front, unsigned Back)
Fold packing movs into a store.
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, SmallVectorImpl< SDValue > &Results)
static SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl, SelectionDAG &DAG, T GetElement)
static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx, const DataLayout &DL)
static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT, const SDLoc &dl, SelectionDAG &DAG)
static SDValue combineSZExtToMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static unsigned canMergeParamLoadStoresStartingAt(unsigned Idx, uint32_t AccessSize, const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< T > &Offsets, Align ParamAlignment)
static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C)
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG)
static SDValue PerformFMinMaxCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned PTXVersion, unsigned SmVersion)
PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into (fmaxnum3 a, b, c).
static std::optional< unsigned > getScalar3OpcodeForReduction(unsigned ReductionOpcode)
Get 3-input scalar reduction opcode.
static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG)
static bool isNonCoalescableBuildVector(const SDValue &BV)
Check if a v2f32 BUILD_VECTOR provably packs values from non-adjacent register pairs (non-coalescable...
static bool isConstZero(const SDValue &Operand)
static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID)
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG)
static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG)
static bool IsMulWideOperandDemotable(SDValue Op, unsigned OptSize, OperandSignedness &S)
IsMulWideOperandDemotable - Checks if the provided DAG node is an operand that can be demoted to OptS...
static unsigned getTcgen05MMADisableOutputLane(unsigned IID)
static std::pair< APInt, APInt > getPRMTDemandedBits(const APInt &SelectorVal, const APInt &DemandedBits)
static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode)
static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode)
static SDValue PerformREMCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG)
static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG)
static SDValue PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI)
static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known, const SelectionDAG &DAG, unsigned Depth)
static SDValue combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
Fold unpacking movs into a load by increasing the number of return values.
#define TCGEN05_LD_RED_INTR(SHAPE, NUM, TYPE)
static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG)
static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op, SelectionDAG &DAG)
static std::optional< std::pair< SDValue, SDValue > > lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset=false)
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG)
static std::optional< std::pair< SDValue, SDValue > > replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI)
replaceLoadVector - Convert vector loads into multi-output scalar loads.
static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL, unsigned Opcode, SelectionDAG &DAG)
static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS, unsigned OptSize, bool &IsSigned)
AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can be demoted to OptSize bits...
static std::pair< MemSDNode *, uint32_t > convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI)
static SDValue TryMULWIDECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply of M/2 bits that produces...
static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG)
static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT, SDLoc DL, TargetLowering::DAGCombinerInfo &DCI)
static SDValue buildTreeReduction(const SmallVector< SDValue > &Elements, EVT EltTy, ArrayRef< std::pair< unsigned, unsigned > > Ops, const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG)
Reduces the elements using the scalar operations provided.
static SDValue combineProxyReg(SDNode *N, TargetLowering::DAGCombinerInfo &DCI)
static SmallVector< unsigned, 16 > VectorizePTXValueVTs(const SmallVectorImpl< EVT > &ValueVTs, const SmallVectorImpl< T > &Offsets, Align ParamAlignment, bool IsVAArg=false)
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL, SelectionDAG &DAG, unsigned Mode=NVPTX::PTXPrmtMode::NONE)
static SDValue matchMADConstOnePattern(SDValue Add)
static SDValue correctParamType(SDValue V, EVT ExpectedVT, ISD::ArgFlagsTy Flags, SelectionDAG &DAG, SDLoc dl)
static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags)
static cl::opt< bool > UsePrecSqrtF32("nvptx-prec-sqrtf32", cl::Hidden, cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."), cl::init(true))
static cl::opt< bool > AllowFTZAtomics("nvptx-allow-ftz-atomics", cl::Hidden, cl::desc("NVPTX Specific: Lower atomicrmw fadd to atom.add even when its " "FTZ behavior does not match the function's denormal mode."), cl::init(false))
static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known)
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode)
static EVT promoteScalarIntegerPTX(const EVT VT)
PromoteScalarIntegerPTX Used to make sure the arguments/returns are suitable for passing and promote ...
static std::optional< std::tuple< SDValue, SDValue, SDValue > > lowerTcgen05LdRed(SDNode *N, SelectionDAG &DAG)
static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT, const APInt &DemandedBits, SelectionDAG &DAG, const TargetLowering &TLI, unsigned Depth)
static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG)
static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG)
static SDValue sinkProxyReg(SDValue R, SDValue Chain, TargetLowering::DAGCombinerInfo &DCI)
static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG)
static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG)
static SDValue combineIntrinsicWOChain(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const NVPTXSubtarget &STI)
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG)
static SDValue PerformSETCCCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, unsigned int SmVersion)
static std::optional< std::pair< unsigned int, MVT > > getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI, unsigned AddressSpace)
static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG, Intrinsic::ID AddIntrinsicID)
static cl::opt< bool > UseApproxLog2F32("nvptx-approx-log2f32", cl::desc("NVPTX Specific: whether to use lg2.approx for log2"), cl::init(false))
Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it does NOT use lg2....
static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG)
static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const NVPTXSubtarget &STI)
static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const NVPTXSubtarget &STI)
static SDValue PerformSHLCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel)
PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
MachineInstr unsigned OpIdx
uint64_t High
#define P(N)
const SmallVectorImpl< MachineOperand > & Cond
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
Contains matchers for matching SelectionDAG nodes and values.
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
This file describes how to lower LLVM code to machine code.
Value * RHS
Value * LHS
BinaryOperator * Mul
static const fltSemantics & IEEEsingle()
Definition APFloat.h:296
static APFloat getInf(const fltSemantics &Sem, bool Negative=false)
Factory for Positive and Negative Infinity.
Definition APFloat.h:1157
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt getLoBits(unsigned numBits) const
Compute an APInt containing numBits lowbits from this APInt.
Definition APInt.cpp:645
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1414
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:640
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:968
void setBit(unsigned BitPosition)
Set the given bit to 1 whose position is given as "bitPosition".
Definition APInt.h:1353
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool isSignedIntN(unsigned N) const
Check if this APInt has an N-bits signed integer value.
Definition APInt.h:436
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
LLVM_ABI APInt extractBits(unsigned numBits, unsigned bitPosition) const
Return an APInt with the extracted bits [bitPosition,bitPosition+numBits).
Definition APInt.cpp:483
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition APInt.h:1244
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
ArrayRef< T > slice(size_t N, size_t M) const
slice(n, m) - Chop off the first N elements of the array, and keep M elements in the array.
Definition ArrayRef.h:185
an instruction that atomically reads a memory location, combines it with another value,...
@ Add
*p = old + v
@ FAdd
*p = old + v
@ Min
*p = old <signed v ? old : v
@ Sub
*p = old - v
@ And
*p = old & v
@ Xor
*p = old ^ v
@ UIncWrap
Increment one up to a maximum value.
@ Max
*p = old >signed v ? old : v
@ UMin
*p = old <unsigned v ? old : v
@ UMax
*p = old >unsigned v ? old : v
@ UDecWrap
Decrement one until a minimum value or zero.
bool isFloatingPointOperation() const
BinOp getOperation() const
unsigned getPointerAddressSpace() const
Returns the address space of the pointer operand.
This is an SDNode representing atomic operations.
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
FunctionType * getFunctionType() const
const APInt & getAPIntValue() const
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI TypeSize getTypeAllocSize(Type *Ty) const
Returns the offset in bytes between successive objects of the specified type, including alignment pad...
LLVM_ABI Align getPrefTypeAlign(Type *Ty) const
Returns the preferred stack/global alignment for the specified type.
Diagnostic information for unsupported feature in backend.
void addFnAttr(Attribute::AttrKind Kind)
Add function attributes to this function.
Definition Function.cpp:638
DenormalMode getDenormalMode(const fltSemantics &FPType) const
Returns the denormal handling type for the default rounding mode of the function.
Definition Function.cpp:804
Module * getParent()
Get the module that this global value is contained inside of...
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
LLVM_ABI void diagnose(const DiagnosticInfo &DI)
Report a message to the currently installed diagnostic handler.
This class is used to represent ISD::LOAD nodes.
MCSection * getDataSection() const
static constexpr unsigned NoRegister
Definition MCRegister.h:60
Instances of this class represent a uniqued identifier for a section in the current translation unit.
Definition MCSection.h:573
StringRef getName() const
getName - Get the symbol name.
Definition MCSymbol.h:188
Machine Value Type.
static auto integer_fixedlen_vector_valuetypes()
SimpleValueType SimpleTy
unsigned getVectorNumElements() const
bool isVector() const
Return true if this is a vector value type.
bool isScalableVector() const
Return true if this is a vector value type where the runtime length is machine dependent.
static auto integer_valuetypes()
TypeSize getSizeInBits() const
Returns the size of the specified MVT in bits.
static auto fixedlen_vector_valuetypes()
TypeSize getStoreSize() const
Return the number of bytes overwritten by a store of the specified value type.
static MVT getVectorVT(MVT VT, unsigned NumElements)
MVT getVectorElementType() const
static MVT getIntegerVT(unsigned BitWidth)
static auto fp_valuetypes()
MVT getScalarType() const
If this is a vector, return the element type, otherwise return this.
static auto fp_fixedlen_vector_valuetypes()
DenormalMode getDenormalMode(const fltSemantics &FPType) const
Returns the denormal handling type for the default rounding mode of the function.
Function & getFunction()
Return the LLVM function that this machine code represents.
const TargetMachine & getTarget() const
getTarget - Return the target machine this machine code is compiled with
@ EK_Inline
EK_Inline - Jump table entries are emitted inline at their point of use.
@ MODereferenceable
The memory access is dereferenceable (i.e., doesn't trap).
@ MOLoad
The memory access reads data.
@ MOInvariant
The memory access always returns the same value (or traps).
@ MOStore
The memory access writes data.
This SDNode is used for target intrinsics that touch memory and need an associated MachineMemOperand.
This is an abstract virtual class for memory operations.
Align getAlign() const
MachineMemOperand * getMemOperand() const
Return the unique MachineMemOperand object describing the memory reference performed by operation.
EVT getMemoryVT() const
Return the type of the in-memory value.
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
static unsigned getFromTypeWidthForLoad(const MemSDNode *Mem)
bool hasTensormapReplaceSwizzleModeSupport(unsigned value) const
bool hasUsedBytesMaskPragma() const
bool hasTensormapReplaceElemtypeSupport(unsigned value) const
bool hasAtomSwap128() const
bool hasF32x2Instructions() const
bool has256BitVectorLoadStore(unsigned AS) const
AtomicOrdering atomicOperationOrderAfterFenceSplit(const Instruction *I) const override
ConstraintType getConstraintType(StringRef Constraint) const override
getConstraintType - Given a constraint letter, return the type of constraint it is for this target.
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override
This callback is invoked for operations that are unsupported by the target, which are registered to u...
const NVPTXTargetMachine * nvTM
bool SimplifyDemandedBitsForTargetNode(SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth=0) const override
Attempt to simplify any target nodes based on the demanded bits/elts, returning true on success.
AtomicExpansionKind shouldExpandAtomicRMWInIR(const AtomicRMWInst *AI) const override
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
NVPTXTargetLowering(const NVPTXTargetMachine &TM, const NVPTXSubtarget &STI)
std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &, const SmallVectorImpl< ISD::OutputArg > &, std::optional< unsigned > FirstVAArg, const CallBase &CB, unsigned UniqueCallSite) const
unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT, EVT ToVT) const override
bool useF32FTZ(const MachineFunction &MF) const
SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const
SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, bool &UseOneConst, bool Reciprocal) const override
Hooks for building estimates in place of slower divisions and square roots.
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::OutputArg > &Outs, const SmallVectorImpl< SDValue > &OutVals, const SDLoc &dl, SelectionDAG &DAG) const override
This hook must be implemented to lower outgoing return values, described by the Outs array,...
SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl< ISD::InputArg > &Ins, const SDLoc &dl, SelectionDAG &DAG, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower the incoming (formal) arguments, described by the Ins array,...
void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const override
Lower the specified operand into the Ops vector.
SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const
Instruction * emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const override
std::string getParamName(const Function *F, int Idx) const
TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const override
Return the preferred vector type legalization action.
NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF, const SDNode &N) const
bool shouldInsertFencesForAtomic(const Instruction *) const override
Whether AtomicExpandPass should automatically insert fences and reduce ordering for this atomic.
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const
EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx, EVT VT) const override
Return the ValueType of the result of SETCC operations.
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty, unsigned AS, Instruction *I=nullptr) const override
isLegalAddressingMode - Return true if the addressing mode represented by AM is legal for this target...
Instruction * emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const override
Inserts in the IR a target-specific intrinsic specifying a fence.
void getTgtMemIntrinsic(SmallVectorImpl< IntrinsicInfo > &Infos, const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const
bool usePrecSqrtF32(const SDNode *N=nullptr) const
unsigned getJumpTableEncoding() const override
Return the entry encoding for a jump table in the current function.
SDValue LowerCall(CallLoweringInfo &CLI, SmallVectorImpl< SDValue > &InVals) const override
This hook must be implemented to lower calls into the specified DAG.
void computeKnownBitsForTargetNode(const SDValue Op, KnownBits &Known, const APInt &DemandedElts, const SelectionDAG &DAG, unsigned Depth=0) const override
Determine which of the bits specified in Mask are known to be either zero or one and return them in t...
MCSection * SelectSectionForGlobal(const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const override
static LLVM_ABI PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Wrapper class for IR location info (IR ordering and DebugLoc) to be passed into SDNode creation funct...
const DebugLoc & getDebugLoc() const
Represents one node in the SelectionDAG.
ArrayRef< SDUse > ops() const
const APInt & getAsAPIntVal() const
Helper method returns the APInt value of a ConstantSDNode.
unsigned getOpcode() const
Return the SelectionDAG opcode value for this node.
bool hasOneUse() const
Return true if there is exactly one use of this node.
unsigned getIROrder() const
Return the node ordering.
SDNodeFlags getFlags() const
uint64_t getAsZExtVal() const
Helper method returns the zero-extended integer value of a ConstantSDNode.
unsigned getNumValues() const
Return the number of values defined/returned by this operator.
SDVTList getVTList() const
const SDValue & getOperand(unsigned Num) const
bool isUndef() const
Returns true if the node type is UNDEF or POISON.
iterator_range< user_iterator > users()
void setFlags(SDNodeFlags NewFlags)
Represents a use of a SDNode.
Unlike LLVM values, Selection DAG nodes may return multiple values as the result of a computation.
SDNode * getNode() const
get the SDNode which holds the desired result
bool hasOneUse() const
Return true if there is exactly one node using value ResNo of Node.
SDValue getValue(unsigned R) const
EVT getValueType() const
Return the ValueType of the referenced return value.
TypeSize getValueSizeInBits() const
Returns the size of the value in bits.
const SDValue & getOperand(unsigned i) const
uint64_t getScalarValueSizeInBits() const
uint64_t getConstantOperandVal(unsigned i) const
unsigned getOpcode() const
SectionKind - This is a simple POD value that classifies the properties of a section.
Definition SectionKind.h:22
This is used to represent a portion of an LLVM function in a low-level Data Dependence DAG representa...
LLVM_ABI SDValue getExtLoad(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, EVT MemVT, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
const SDValue & getRoot() const
Return the root tag of the SelectionDAG.
LLVM_ABI SDValue getAddrSpaceCast(const SDLoc &dl, EVT VT, SDValue Ptr, unsigned SrcAS, unsigned DestAS)
Return an AddrSpaceCastSDNode.
const TargetSubtargetInfo & getSubtarget() const
LLVM_ABI SDValue getMergeValues(ArrayRef< SDValue > Ops, const SDLoc &dl)
Create a MERGE_VALUES node from the given operands.
LLVM_ABI SDVTList getVTList(EVT VT)
Return an SDVTList that represents the list of values specified.
LLVM_ABI void ExtractVectorElements(SDValue Op, SmallVectorImpl< SDValue > &Args, unsigned Start=0, unsigned Count=0, EVT EltVT=EVT())
Append the extracted elements from Start to Count out of the vector Op in Args.
LLVM_ABI SDValue getFreeze(SDValue V)
Return a freeze using the SDLoc of the value operand.
LLVM_ABI SDValue getSymbolFunctionGlobalAddress(SDValue Op, Function **TargetFunction=nullptr)
Return a GlobalAddress of the function from the current module with name matching the given ExternalS...
LLVM_ABI SDValue getConstantFP(double Val, const SDLoc &DL, EVT VT, bool isTarget=false)
Create a ConstantFPSDNode wrapping a constant value.
LLVM_ABI SDValue getRegister(Register Reg, EVT VT)
LLVM_ABI SDValue getLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, MachinePointerInfo PtrInfo, MaybeAlign Alignment=MaybeAlign(), MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes(), const MDNode *Ranges=nullptr)
Loads are not normal binary operators: their result type is not determined by their operands,...
LLVM_ABI SDValue getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, SDVTList VTList, ArrayRef< SDValue > Ops, EVT MemVT, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags Flags=MachineMemOperand::MOLoad|MachineMemOperand::MOStore, LocationSize Size=LocationSize::precise(0), const AAMDNodes &AAInfo=AAMDNodes())
Creates a MemIntrinsicNode that may produce a result and takes a list of operands.
SDValue getSetCC(const SDLoc &DL, EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond, SDValue Chain=SDValue(), bool IsSignaling=false, SDNodeFlags Flags={})
Helper function to make it easier to build SetCC's if you just have an ISD::CondCode instead of an SD...
LLVM_ABI Align getEVTAlign(EVT MemoryVT) const
Compute the default alignment value for the given type.
LLVM_ABI SDValue getNOT(const SDLoc &DL, SDValue Val, EVT VT)
Create a bitwise NOT operation as (XOR Val, -1).
LLVM_ABI SDNode * MorphNodeTo(SDNode *N, unsigned Opc, SDVTList VTs, ArrayRef< SDValue > Ops)
This mutates the specified node to have the specified return type, opcode, and operands.
SDValue getUNDEF(EVT VT)
Return an UNDEF node. UNDEF does not have a useful SDLoc.
SDValue getCALLSEQ_END(SDValue Chain, SDValue Op1, SDValue Op2, SDValue InGlue, const SDLoc &DL)
Return a new CALLSEQ_END node, which always must have a glue result (to ensure it's not CSE'd).
SDValue getBuildVector(EVT VT, const SDLoc &DL, ArrayRef< SDValue > Ops)
Return an ISD::BUILD_VECTOR node.
LLVM_ABI SDValue getBitcast(EVT VT, SDValue V)
Return a bitcast using the SDLoc of the value operand, and casting to the provided type.
SDValue getSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS, SDValue RHS, SDNodeFlags Flags=SDNodeFlags())
Helper function to make it easier to build Select's if you just have operands and don't want to check...
const DataLayout & getDataLayout() const
LLVM_ABI SDValue getTokenFactor(const SDLoc &DL, SmallVectorImpl< SDValue > &Vals)
Creates a new TokenFactor containing Vals.
LLVM_ABI SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
Create a ConstantSDNode wrapping a constant value.
LLVM_ABI SDValue getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, EVT SVT, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
LLVM_ABI SDValue getStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, MachinePointerInfo PtrInfo, Align Alignment, MachineMemOperand::Flags MMOFlags=MachineMemOperand::MONone, const AAMDNodes &AAInfo=AAMDNodes())
Helper function to build ISD::STORE nodes.
LLVM_ABI SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT, bool isTarget=false, bool isOpaque=false)
SDValue getCALLSEQ_START(SDValue Chain, uint64_t InSize, uint64_t OutSize, const SDLoc &DL)
Return a new CALLSEQ_START node, that starts new call frame, in which InSize bytes are set up inside ...
SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True, SDValue False, ISD::CondCode Cond, SDNodeFlags Flags=SDNodeFlags())
Helper function to make it easier to build SelectCC's if you just have an ISD::CondCode instead of an...
LLVM_ABI SDValue getExternalSymbol(const char *Sym, EVT VT)
LLVM_ABI SDValue getAnyExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either any-extending or truncat...
LLVM_ABI SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL, bool isTarget=false)
LLVM_ABI SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef< SDUse > Ops)
Gets or creates the specified node.
LLVM_ABI SDValue getFPExtendOrRound(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of float type, to the float type VT, by either extending or rounding (by tr...
SDValue getTargetConstant(uint64_t Val, const SDLoc &DL, EVT VT, bool isOpaque=false)
LLVM_ABI SDValue getVectorIdxConstant(uint64_t Val, const SDLoc &DL, bool isTarget=false)
MachineFunction & getMachineFunction() const
LLVM_ABI KnownBits computeKnownBits(SDValue Op, unsigned Depth=0) const
Determine which bits of Op are known to be either zero or one and return them in Known.
LLVM_ABI SDValue getZExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT)
Convert Op, which must be of integer type, to the integer type VT, by either zero-extending or trunca...
SDValue getObjectPtrOffset(const SDLoc &SL, SDValue Ptr, TypeSize Offset)
Create an add instruction with appropriate flags when used for addressing some offset of an object.
LLVMContext * getContext() const
const SDValue & setRoot(SDValue N)
Set the current root tag of the SelectionDAG.
LLVM_ABI SDValue getTargetExternalSymbol(const char *Sym, EVT VT, unsigned TargetFlags=0)
ArrayRef< int > getMask() const
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
This class is used to represent ISD::STORE nodes.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
constexpr size_t size() const
Get the string size.
Definition StringRef.h:144
constexpr const char * data() const
Get a pointer to the start of the string (which may not be null terminated).
Definition StringRef.h:138
Align getStackAlign() const
getStackAlignment - This method returns the number of bytes to which the stack pointer must be aligne...
void setBooleanVectorContents(BooleanContent Ty)
Specify how the target extends the result of a vector boolean value from a vector of i1 to a wider ty...
void setOperationAction(unsigned Op, MVT VT, LegalizeAction Action)
Indicate that the specified operation does not work with the specified type and indicate what to do a...
void setMaxDivRemBitWidthSupported(unsigned SizeInBits)
Set the size in bits of the maximum div/rem the backend supports.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
unsigned MaxStoresPerMemcpyOptSize
Likewise for functions with the OptSize attribute.
const TargetMachine & getTargetMachine() const
virtual unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const
Certain targets require unusual breakdowns of certain types.
virtual MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
void setOperationPromotedToType(unsigned Opc, MVT OrigVT, MVT DestVT)
Convenience method to set an operation to Promote and specify the type in a single call.
LegalizeTypeAction
This enum indicates whether a types are legal for a target, and if not, what action should be used to...
void addBypassSlowDiv(unsigned int SlowBitWidth, unsigned int FastBitWidth)
Tells the code generator which bitwidths to bypass.
virtual unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const
Return the number of registers that this ValueType will eventually require.
void setMaxAtomicSizeInBitsSupported(unsigned SizeInBits)
Set the maximum atomic operation size supported by the backend.
virtual TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const
Return the preferred vector type legalization action.
unsigned MaxStoresPerMemsetOptSize
Likewise for functions with the OptSize attribute.
void setBooleanContents(BooleanContent Ty)
Specify how the target extends the result of integer and floating point boolean values from i1 to a w...
unsigned MaxStoresPerMemmove
Specify maximum number of store instructions per memmove call.
void computeRegisterProperties(const TargetRegisterInfo *TRI)
Once all of the register classes are added, this allows us to compute derived properties we expose.
unsigned MaxStoresPerMemmoveOptSize
Likewise for functions with the OptSize attribute.
void addRegisterClass(MVT VT, const TargetRegisterClass *RC)
Add the specified register class as an available regclass for the specified value type.
bool isTypeLegal(EVT VT) const
Return true if the target has native support for the specified value type.
virtual MVT getPointerTy(const DataLayout &DL, uint32_t AS=0) const
Return the pointer type for the given address space, defaults to the pointer type from the data layou...
bool isOperationLegal(unsigned Op, EVT VT) const
Return true if the specified operation is legal on this target.
unsigned MaxStoresPerMemset
Specify maximum number of store instructions per memset call.
void setTruncStoreAction(MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified truncating store does not work with the specified type and indicate what ...
void setMinCmpXchgSizeInBits(unsigned SizeInBits)
Sets the minimum cmpxchg or ll/sc size supported by the backend.
void AddPromotedToType(unsigned Opc, MVT OrigVT, MVT DestVT)
If Opc/OrigVT is specified as being promoted, the promotion code defaults to trying a larger integer/...
AtomicExpansionKind
Enum that specifies what an atomic load/AtomicRMWInst is expanded to, if at all.
void setCondCodeAction(ArrayRef< ISD::CondCode > CCs, MVT VT, LegalizeAction Action)
Indicate that the specified condition code is or isn't supported on the target and indicate what to d...
void setTargetDAGCombine(ArrayRef< ISD::NodeType > NTs)
Targets should invoke this method for each target independent node that they want to provide a custom...
Align getMinStackArgumentAlignment() const
Return the minimum stack alignment of an argument.
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT, LegalizeAction Action)
Indicate that the specified load with extension does not work with the specified type and indicate wh...
std::vector< ArgListEntry > ArgListTy
virtual Instruction * emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const
virtual Instruction * emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const
Inserts in the IR a target-specific intrinsic specifying a fence.
unsigned MaxStoresPerMemcpy
Specify maximum number of store instructions per memcpy call.
void setSchedulingPreference(Sched::Preference Pref)
Specify the target scheduling preference.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
void setJumpIsExpensive(bool isExpensive=true)
Tells the code generator not to expand logic operations on comparison predicates into separate sequen...
LegalizeAction getOperationAction(unsigned Op, EVT VT) const
Return how this operation should be treated: either it is legal, needs to be promoted to a larger siz...
This class defines information used to lower LLVM code to legal SelectionDAG operators that the targe...
SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, SelectionDAG &DAG, unsigned Depth=0) const
More limited version of SimplifyDemandedBits that can be used to "lookthrough" ops that don't contrib...
virtual ConstraintType getConstraintType(StringRef Constraint) const
Given a constraint, return the type of constraint it is for this target.
virtual std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const
Given a physical register constraint (e.g.
TargetLowering(const TargetLowering &)=delete
SDValue expandRoundInexactToOdd(EVT ResultVT, SDValue Op, const SDLoc &DL, SelectionDAG &DAG) const
Truncate Op to ResultVT.
SDValue expandFP_ROUND(SDNode *Node, SelectionDAG &DAG) const
Expand round(fp) to fp conversion.
virtual void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, std::vector< SDValue > &Ops, SelectionDAG &DAG) const
Lower the specified operand into the Ops vector.
Primary interface to the complete machine description for the target machine.
CodeGenOptLevel getOptLevel() const
Returns the optimization level: None, Less, Default, or Aggressive.
TargetOptions Options
MCSymbol * getSymbol(const GlobalValue *GV) const
FPOpFusion::FPOpFusionMode AllowFPOpFusion
AllowFPOpFusion - This flag is set by the -fp-contract=xxx option.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
virtual const TargetFrameLowering * getFrameLowering() const
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
static constexpr TypeSize getFixed(ScalarTy ExactSize)
Definition TypeSize.h:343
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition Type.h:186
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
bool isVoidTy() const
Return true if this is 'void'.
Definition Type.h:141
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
A raw_ostream that writes to an std::string.
CallInst * Call
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
LLVM_ABI APInt pow(const APInt &X, int64_t N)
Compute X^N for N>=0.
Definition APInt.cpp:3207
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
NodeType
ISD::NodeType enum - This enum defines the target-independent operators for a SelectionDAG.
Definition ISDOpcodes.h:41
@ SETCC
SetCC operator - This evaluates to a true value iff the condition is true.
Definition ISDOpcodes.h:823
@ STACKRESTORE
STACKRESTORE has two operands, an input chain and a pointer to restore to it returns an output chain.
@ STACKSAVE
STACKSAVE - STACKSAVE has one operand, an input chain.
@ POISON
POISON - A poison node.
Definition ISDOpcodes.h:236
@ MLOAD
Masked load and store - consecutive vector load and store operations with additional mask operand tha...
@ SMUL_LOHI
SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing a signed/unsigned value of type i[2...
Definition ISDOpcodes.h:275
@ BSWAP
Byte Swap and Counting operators.
Definition ISDOpcodes.h:783
@ VAEND
VAEND, VASTART - VAEND and VASTART have three operands: an input chain, pointer, and a SRCVALUE.
@ ADDC
Carry-setting nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:294
@ ADD
Simple integer binary arithmetic operators.
Definition ISDOpcodes.h:264
@ LOAD
LOAD and STORE have token chains as their first operand, then the same operands as an LLVM load/store...
@ ANY_EXTEND
ANY_EXTEND - Used for integer types. The high bits are undefined.
Definition ISDOpcodes.h:857
@ FMA
FMA - Perform a * b + c with no intermediate rounding step.
Definition ISDOpcodes.h:518
@ INTRINSIC_VOID
OUTCHAIN = INTRINSIC_VOID(INCHAIN, INTRINSICID, arg1, arg2, ...) This node represents a target intrin...
Definition ISDOpcodes.h:220
@ SINT_TO_FP
[SU]INT_TO_FP - These operators convert integers (whose interpreted sign depends on the first letter)...
Definition ISDOpcodes.h:884
@ CONCAT_VECTORS
CONCAT_VECTORS(VECTOR0, VECTOR1, ...) - Given a number of values of vector type with the same length ...
Definition ISDOpcodes.h:584
@ VECREDUCE_FMAX
FMIN/FMAX nodes can have flags, for NaN/NoNaN variants.
@ FADD
Simple binary floating point operators.
Definition ISDOpcodes.h:417
@ VECREDUCE_FMAXIMUM
FMINIMUM/FMAXIMUM nodes propatate NaNs and signed zeroes using the llvm.minimum and llvm....
@ ABS
ABS - Determine the unsigned absolute value of a signed integer value of the same bitwidth.
Definition ISDOpcodes.h:747
@ SDIVREM
SDIVREM/UDIVREM - Divide two integers and produce both a quotient and remainder result.
Definition ISDOpcodes.h:280
@ BITCAST
BITCAST - This operator converts between integer, vector and FP values, as if the value was stored to...
Definition ISDOpcodes.h:997
@ BUILD_PAIR
BUILD_PAIR - This is the opposite of EXTRACT_ELEMENT in some ways.
Definition ISDOpcodes.h:254
@ CTLZ_ZERO_POISON
Definition ISDOpcodes.h:792
@ SIGN_EXTEND
Conversion operators.
Definition ISDOpcodes.h:848
@ READSTEADYCOUNTER
READSTEADYCOUNTER - This corresponds to the readfixedcounter intrinsic.
@ FNEG
Perform various unary floating-point operations inspired by libm.
@ BR_CC
BR_CC - Conditional branch.
@ SSUBO
Same for subtraction.
Definition ISDOpcodes.h:352
@ BRIND
BRIND - Indirect branch.
@ BR_JT
BR_JT - Jumptable branch.
@ SSUBSAT
RESULT = [US]SUBSAT(LHS, RHS) - Perform saturation subtraction on 2 integers with the same bit width ...
Definition ISDOpcodes.h:374
@ SELECT
Select(COND, TRUEVAL, FALSEVAL).
Definition ISDOpcodes.h:800
@ UNDEF
UNDEF - An undefined node.
Definition ISDOpcodes.h:233
@ EXTRACT_ELEMENT
EXTRACT_ELEMENT - This is used to get the lower or upper (determined by a Constant,...
Definition ISDOpcodes.h:247
@ VACOPY
VACOPY - VACOPY has 5 operands: an input chain, a destination pointer, a source pointer,...
@ CopyFromReg
CopyFromReg - This node indicates that the input value is a virtual or physical register that is defi...
Definition ISDOpcodes.h:230
@ SADDO
RESULT, BOOL = [SU]ADDO(LHS, RHS) - Overflow-aware nodes for addition.
Definition ISDOpcodes.h:348
@ MULHU
MULHU/MULHS - Multiply high - Multiply two integers of type iN, producing an unsigned/signed value of...
Definition ISDOpcodes.h:704
@ SHL
Shift and rotation operations.
Definition ISDOpcodes.h:769
@ VECTOR_SHUFFLE
VECTOR_SHUFFLE(VEC1, VEC2) - Returns a vector, of the same type as VEC1/VEC2.
Definition ISDOpcodes.h:649
@ EXTRACT_SUBVECTOR
EXTRACT_SUBVECTOR(VECTOR, IDX) - Returns a subvector from VECTOR.
Definition ISDOpcodes.h:614
@ FMINNUM_IEEE
FMINNUM_IEEE/FMAXNUM_IEEE - Perform floating-point minimumNumber or maximumNumber on two values,...
@ EXTRACT_VECTOR_ELT
EXTRACT_VECTOR_ELT(VECTOR, IDX) - Returns a single element from VECTOR identified by the (potentially...
Definition ISDOpcodes.h:576
@ CopyToReg
CopyToReg - This node has three operands: a chain, a register number to set to this value,...
Definition ISDOpcodes.h:224
@ ZERO_EXTEND
ZERO_EXTEND - Used for integer types, zeroing the new bits.
Definition ISDOpcodes.h:854
@ DEBUGTRAP
DEBUGTRAP - Trap intended to get the attention of a debugger.
@ SELECT_CC
Select with condition operator - This selects between a true value and a false value (ops #2 and #3) ...
Definition ISDOpcodes.h:815
@ ATOMIC_CMP_SWAP
Val, OUTCHAIN = ATOMIC_CMP_SWAP(INCHAIN, ptr, cmp, swap) For double-word atomic operations: ValLo,...
@ FMINNUM
FMINNUM/FMAXNUM - Perform floating-point minimum maximum on two values, following IEEE-754 definition...
@ SSHLSAT
RESULT = [US]SHLSAT(LHS, RHS) - Perform saturation left shift.
Definition ISDOpcodes.h:386
@ SMULO
Same for multiplication.
Definition ISDOpcodes.h:356
@ DYNAMIC_STACKALLOC
DYNAMIC_STACKALLOC - Allocate some number of bytes on the stack aligned to a specified boundary.
@ SIGN_EXTEND_INREG
SIGN_EXTEND_INREG - This operator atomically performs a SHL/SRA pair to sign extend a small value in ...
Definition ISDOpcodes.h:892
@ SMIN
[US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned integers.
Definition ISDOpcodes.h:727
@ FP_EXTEND
X = FP_EXTEND(Y) - Extend a smaller FP type into a larger FP type.
Definition ISDOpcodes.h:982
@ VSELECT
Select with a vector condition (op #0) and two vector operands (ops #1 and #2), returning a vector re...
Definition ISDOpcodes.h:809
@ UADDO_CARRY
Carry-using nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:328
@ BF16_TO_FP
BF16_TO_FP, FP_TO_BF16 - These operators are used to perform promotions and truncation for bfloat16.
@ FRAMEADDR
FRAMEADDR, RETURNADDR - These nodes represent llvm.frameaddress and llvm.returnaddress on the DAG.
Definition ISDOpcodes.h:110
@ STRICT_FP_TO_UINT
Definition ISDOpcodes.h:478
@ STRICT_FP_TO_SINT
STRICT_FP_TO_[US]INT - Convert a floating point value to a signed or unsigned integer.
Definition ISDOpcodes.h:477
@ FMINIMUM
FMINIMUM/FMAXIMUM - NaN-propagating minimum/maximum that also treat -0.0 as less than 0....
@ FP_TO_SINT
FP_TO_[US]INT - Convert a floating point value to a signed or unsigned integer.
Definition ISDOpcodes.h:930
@ READCYCLECOUNTER
READCYCLECOUNTER - This corresponds to the readcyclecounter intrinsic.
@ AND
Bitwise operators - logical and, logical or, logical xor.
Definition ISDOpcodes.h:739
@ TRAP
TRAP - Trapping instruction.
@ INTRINSIC_WO_CHAIN
RESULT = INTRINSIC_WO_CHAIN(INTRINSICID, arg1, arg2, ...) This node represents a target intrinsic fun...
Definition ISDOpcodes.h:205
@ ADDE
Carry-using nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:304
@ INSERT_VECTOR_ELT
INSERT_VECTOR_ELT(VECTOR, VAL, IDX) - Returns VECTOR with the element at IDX replaced with VAL.
Definition ISDOpcodes.h:565
@ ATOMIC_SWAP
Val, OUTCHAIN = ATOMIC_SWAP(INCHAIN, ptr, amt) Val, OUTCHAIN = ATOMIC_LOAD_[OpName](INCHAIN,...
@ FP_ROUND
X = FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating point type down to the precision of the ...
Definition ISDOpcodes.h:963
@ ADDRSPACECAST
ADDRSPACECAST - This operator converts between pointers of different address spaces.
@ VECREDUCE_FMINIMUM
@ TRUNCATE
TRUNCATE - Completely drop the high bits.
Definition ISDOpcodes.h:860
@ VAARG
VAARG - VAARG has four operands: an input chain, a pointer, a SRCVALUE, and the alignment.
@ SHL_PARTS
SHL_PARTS/SRA_PARTS/SRL_PARTS - These operators are used for expanded integer shift operations.
Definition ISDOpcodes.h:837
@ FCOPYSIGN
FCOPYSIGN(X, Y) - Return the value of X with the sign of Y.
Definition ISDOpcodes.h:534
@ SADDSAT
RESULT = [US]ADDSAT(LHS, RHS) - Perform saturation addition on 2 integers with the same bit width (W)...
Definition ISDOpcodes.h:365
@ FMINIMUMNUM
FMINIMUMNUM/FMAXIMUMNUM - minimumnum/maximumnum that is same with FMINNUM_IEEE and FMAXNUM_IEEE besid...
@ SADDO_CARRY
Carry-using overflow-aware nodes for multiple precision addition and subtraction.
Definition ISDOpcodes.h:338
@ INTRINSIC_W_CHAIN
RESULT,OUTCHAIN = INTRINSIC_W_CHAIN(INCHAIN, INTRINSICID, arg1, ...) This node represents a target in...
Definition ISDOpcodes.h:213
@ ABS_MIN_POISON
ABS with a poison result for INT_MIN.
Definition ISDOpcodes.h:751
@ BUILD_VECTOR
BUILD_VECTOR(ELT0, ELT1, ELT2, ELT3,...) - Return a fixed-width vector with the specified,...
Definition ISDOpcodes.h:556
LLVM_ABI bool allOperandsUndef(const SDNode *N)
Return true if the node has at least one operand and all operands of the specified node are ISD::UNDE...
This namespace contains an enum with a value for every intrinsic/builtin function known by LLVM.
LLVM_ABI StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
@ Bitcast
Perform the operation on a different, but equivalently sized type.
@ ATOMIC_CMP_SWAP_B128
These nodes are used to lower atomic instructions with i128 type.
@ DeviceParam
Definition NVPTX.h:215
@ EntryParam
Definition NVPTX.h:209
bool isPackedVectorTy(EVT VT)
DivPrecisionLevel
Definition NVPTX.h:278
match_combine_or< CastInst_match< OpTy, TruncInst >, OpTy > m_TruncOrSelf(const OpTy &Op)
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
match_deferred< Value > m_Deferred(Value *const &V)
Like m_Specific(), but works if the specific value to match is determined as part of the same match()...
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
@ User
could "use" a pointer
NodeAddr< NodeBase * > Node
Definition RDFGraph.h:381
This is an optimization pass for GlobalISel generic memory operations.
@ Low
Lower the current thread's priority such that it does not affect foreground tasks significantly.
Definition Threading.h:280
@ Offset
Definition DWP.cpp:558
detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)
zip iterator for two or more iteratable types.
Definition STLExtras.h:830
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
MaybeAlign getAlign(const CallInst &I, unsigned Index)
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1668
SDValue peekThroughFreeze(SDValue V)
Return the non-frozen source operand of V if it exists.
LLVM_ABI void ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl< EVT > &ValueVTs, SmallVectorImpl< EVT > *MemVTs=nullptr, SmallVectorImpl< TypeSize > *Offsets=nullptr, TypeSize StartingOffset=TypeSize::getZero())
ComputeValueVTs - Given an LLVM IR type, compute a sequence of EVTs that represent all the individual...
Definition Analysis.cpp:119
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2553
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
uint64_t PowerOf2Ceil(uint64_t A)
Returns the power of two which is greater than or equal to the given value.
Definition MathExtras.h:385
bool isReleaseOrStronger(AtomicOrdering AO)
OutputIt transform(R &&Range, OutputIt d_first, UnaryFunction F)
Wrapper function around std::transform to apply a function to a range and store the result elsewhere.
Definition STLExtras.h:2025
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
std::optional< SyncScope::ID > getAtomicSyncScopeID(const Instruction *I)
A helper function that returns an atomic operation's sync scope; returns std::nullopt if it is not an...
unsigned promoteScalarArgumentSize(unsigned size)
bool none_of(R &&Range, UnaryPredicate P)
Provide wrappers to std::none_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1752
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:163
bool shouldPassAsArray(Type *Ty)
constexpr uint64_t alignTo(uint64_t Size, Align A)
Returns a multiple of A needed to store Size bytes.
Definition Alignment.h:144
CodeGenOptLevel
Code generation optimization level.
Definition CodeGen.h:82
@ Default
-O2, -Os, -Oz
Definition CodeGen.h:85
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
AtomicOrdering
Atomic ordering for LLVM's memory model.
Align getFunctionByValParamAlign(const Function *F, Type *ArgTy, Align InitialAlign, const DataLayout &DL)
@ Sub
Subtraction of integers.
@ Add
Sum of integers.
@ FAdd
Sum of floats.
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
bool isParamGridConstant(const Argument &Arg)
bool isAcquireOrStronger(AtomicOrdering AO)
constexpr unsigned BitWidth
bool isKernelFunction(const Function &F)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
Function * getMaybeBitcastedCallee(const CallBase *CB)
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:201
Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, const DataLayout &DL)
Since function arguments are passed via .param space, we may want to increase their alignment in a wa...
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:863
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
constexpr uint64_t value() const
This is a hole in the type system and should not be abused.
Definition Alignment.h:77
@ PreserveSign
The sign of a flushed-to-zero number is preserved in the sign of 0.
DenormalModeKind Output
Denormal flushing mode for floating point instruction results in the default floating point environme...
Extended Value Type.
Definition ValueTypes.h:35
TypeSize getStoreSize() const
Return the number of bytes overwritten by a store of the specified value type.
Definition ValueTypes.h:418
bool isSimple() const
Test if the given EVT is simple (as opposed to being extended).
Definition ValueTypes.h:145
static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements, bool IsScalable=false)
Returns the EVT that represents a vector NumElements in length, where each element is of type VT.
Definition ValueTypes.h:70
EVT changeTypeToInteger() const
Return the type converted to an equivalently sized integer or vector with integer element type.
Definition ValueTypes.h:129
bool bitsGT(EVT VT) const
Return true if this has more bits than VT.
Definition ValueTypes.h:307
bool bitsLT(EVT VT) const
Return true if this has less bits than VT.
Definition ValueTypes.h:323
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
Definition ValueTypes.h:155
ElementCount getVectorElementCount() const
Definition ValueTypes.h:373
bool is32BitVector() const
Return true if this is a 32-bit vector type.
Definition ValueTypes.h:220
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition ValueTypes.h:396
uint64_t getScalarSizeInBits() const
Definition ValueTypes.h:408
MVT getSimpleVT() const
Return the SimpleValueType held in the specified simple EVT.
Definition ValueTypes.h:339
uint64_t getFixedSizeInBits() const
Return the size of the specified fixed width value type in bits.
Definition ValueTypes.h:404
bool isVector() const
Return true if this is a vector value type.
Definition ValueTypes.h:176
EVT getScalarType() const
If this is a vector type, return the element type, otherwise return this.
Definition ValueTypes.h:346
bool bitsEq(EVT VT) const
Return true if this has the same number of bits as VT.
Definition ValueTypes.h:279
LLVM_ABI Type * getTypeForEVT(LLVMContext &Context) const
This method returns an LLVM type corresponding to the specified EVT.
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition ValueTypes.h:351
EVT changeElementType(LLVMContext &Context, EVT EltVT) const
Return a VT for a type whose attributes match ourselves with the exception of the element type that i...
Definition ValueTypes.h:121
bool isScalarInteger() const
Return true if this is an integer, but not a vector.
Definition ValueTypes.h:165
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition ValueTypes.h:359
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition ValueTypes.h:160
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:315
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
KnownBits concat(const KnownBits &Lo) const
Concatenate the bits from Lo onto the bottom of *this.
Definition KnownBits.h:247
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
void resetAll()
Resets the known state of all bits.
Definition KnownBits.h:72
unsigned countMaxActiveBits() const
Returns the maximum number of bits needed to represent all possible unsigned values with these known ...
Definition KnownBits.h:310
void insertBits(const KnownBits &SubBits, unsigned BitPosition)
Insert the bits from a smaller known bits starting at bitPosition.
Definition KnownBits.h:233
This class contains a discriminated union of information about pointers in memory operands,...
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition Alignment.h:106
These are IR-level optimization flags that may be propagated to SDNodes.
bool hasAllowContract() const
This represents a list of ValueType's that has been intern'd by a SelectionDAG.
This represents an addressing mode of: BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*...
This structure contains all information that is necessary for lowering calls.
SmallVector< ISD::InputArg, 32 > Ins
SmallVector< ISD::OutputArg, 32 > Outs
Type * RetTy
Same as OrigRetTy, or partially legalized for soft float libcalls.
A convenience struct that encapsulates a DAG, and two SDValues for returning information from TargetL...