35#ifndef LLVM_ANALYSIS_IR2VEC_H
36#define LLVM_ANALYSIS_IR2VEC_H
90 std::vector<double> Data;
94 Embedding(
const std::vector<double> &V) : Data(V) {}
96 Embedding(std::initializer_list<double> IL) : Data(IL) {}
101 size_t size()
const {
return Data.size(); }
102 bool empty()
const {
return Data.empty(); }
105 assert(Itr < Data.size() &&
"Index out of bounds");
110 assert(Itr < Data.size() &&
"Index out of bounds");
124 const std::vector<double> &
getData()
const {
return Data; }
141 double Tolerance = 1e-4)
const;
160 std::vector<std::vector<Embedding>> Sections;
165 size_t TotalSize = 0;
166 unsigned Dimension = 0;
182 size_t size()
const {
return TotalSize; }
186 return static_cast<unsigned>(Sections.size());
190 const std::vector<Embedding> &
operator[](
unsigned SectionId)
const {
191 assert(SectionId < Sections.size() &&
"Invalid section ID");
192 return Sections[SectionId];
199 bool isValid()
const {
return TotalSize > 0; }
204 unsigned SectionId = 0;
205 size_t LocalIndex = 0;
210 : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
228 VocabMap &TargetVocab,
unsigned &Dim);
267 enum class Section :
unsigned {
278 static constexpr unsigned NumICmpPredicates =
281 static constexpr unsigned NumFCmpPredicates =
314#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
315#include "llvm/IR/Instruction.def"
316#undef LAST_OTHER_INST
318 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
326 NumICmpPredicates + NumFCmpPredicates;
347 bool isValid()
const {
return Storage.size() == NumCanonicalEntries; }
351 return Storage.getDimension();
363 return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(
TypeID));
368 unsigned Index =
static_cast<unsigned>(Kind);
370 return OperandKindNames[Index];
381 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
387 return MaxOpcodes +
static_cast<unsigned>(getCanonicalTypeID(
TypeID));
393 return OperandBaseOffset + Index;
397 return PredicateBaseOffset + getPredicateLocalIndex(
P);
402 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
403 return Storage[
static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
408 unsigned LocalIndex =
static_cast<unsigned>(getCanonicalTypeID(
TypeID));
409 return Storage[
static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
413 unsigned LocalIndex =
static_cast<unsigned>(
getOperandKind(&Arg));
415 return Storage[
static_cast<unsigned>(Section::Operands)][LocalIndex];
419 unsigned LocalIndex = getPredicateLocalIndex(
P);
420 return Storage[
static_cast<unsigned>(Section::Predicates)][LocalIndex];
428 return Storage.begin();
435 return Storage.end();
449 ModuleAnalysisManager::Invalidator &Inv)
const;
452 constexpr static unsigned NumCanonicalEntries =
456 constexpr static unsigned OperandBaseOffset =
458 constexpr static unsigned PredicateBaseOffset =
464 getPredicateFromLocalIndex(
unsigned LocalIndex);
468 "FloatTy",
"VoidTy",
"LabelTy",
"MetadataTy",
"VectorTy",
469 "TokenTy",
"IntegerTy",
"ByteTy",
"FunctionTy",
"PointerTy",
470 "StructTy",
"ArrayTy",
"UnknownTy"};
471 static_assert(std::size(CanonicalTypeNames) ==
473 "CanonicalTypeNames array size must match MaxCanonicalType");
476 static constexpr StringLiteral OperandKindNames[] = {
"Function",
"Pointer",
477 "Constant",
"Variable"};
478 static_assert(std::size(OperandKindNames) ==
480 "OperandKindNames array size must match MaxOperandKind");
484 static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
508 static_assert(TypeIDMapping.size() ==
MaxTypeIDs,
509 "TypeIDMapping must cover all Type::TypeID values");
512 static StringRef getVocabKeyForCanonicalTypeID(
CanonicalTypeID CType) {
513 unsigned Index =
static_cast<unsigned>(CType);
515 return CanonicalTypeNames[
Index];
522 return TypeIDMapping[
Index];
530 return getPredicateFromLocalIndex(Index);
533 using VocabMap = std::map<std::string, Embedding>;
536 static VocabStorage buildVocabStorage(
const VocabMap &OpcVocab,
537 const VocabMap &TypeVocab,
538 const VocabMap &ArgVocab);
578 LLVM_ABI static std::unique_ptr<Embedder>
638 std::optional<ir2vec::VocabStorage> Vocab;
640 void emitError(
Error Err);
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
This file defines the DenseMap class.
Provides ErrorOr<T> smart pointer.
This header defines various interfaces for pass management in LLVM.
This file supports working with JSON data.
ModuleAnalysisManager MAM
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
LLVM Basic Block Representation.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Lightweight error class with error context and mandatory checking.
Tagged union holding either a T or a Error.
IR2VecPrinterPass(raw_ostream &OS)
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
This analysis provides the vocabulary for IR2Vec.
IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
IR2VecVocabAnalysis()=default
ir2vec::Vocabulary Result
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
static LLVM_ABI AnalysisKey Key
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
IR2VecVocabPrinterPass(raw_ostream &OS)
This is an important class for using LLVM in a threaded context.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs with the length computed at compile time.
Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
TypeID
Definitions of all of the base types for the Type system.
LLVM Value Representation.
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
virtual Embedding computeEmbeddings(const Instruction &I) const =0
Function to compute the embedding for a given instruction.
Embedder(const Function &F, const Vocabulary &Vocab)
Embedding getBBVector(const BasicBlock &BB) const
Computes and returns the embedding for a given basic block in the function F.
virtual ~Embedder()=default
Embedding getInstVector(const Instruction &I) const
Computes and returns the embedding for a given instruction in the function F.
const float OpcWeight
Weights for different entities (like opcode, arguments, types) in the IR instructions to generate the vector representation.
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
virtual void invalidateEmbeddings()
Invalidate embeddings if cached.
Embedding getFunctionVector() const
Computes and returns the embedding for the current function.
LLVM_ABI Embedding computeEmbeddings() const
Function to compute embeddings.
void invalidateEmbeddings() override
Invalidate embeddings if cached.
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
Iterator support for section-based access.
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
LLVM_ABI bool operator!=(const const_iterator &Other) const
LLVM_ABI const_iterator & operator++()
LLVM_ABI const Embedding & operator*() const
LLVM_ABI bool operator==(const const_iterator &Other) const
Generic storage class for section-based vocabularies.
static LLVM_ABI Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
VocabStorage & operator=(VocabStorage &&)=default
const_iterator end() const
unsigned getNumSections() const
Get number of sections.
VocabStorage & operator=(const VocabStorage &)=delete
unsigned getDimension() const
Get vocabulary dimension.
size_t size() const
Get total number of entries across all sections.
VocabStorage()=default
Default constructor creates empty storage (invalid state)
const_iterator begin() const
bool isValid() const
Check if vocabulary is valid (has data)
VocabStorage(VocabStorage &&)=default
std::map< std::string, Embedding > VocabMap
const std::vector< Embedding > & operator[](unsigned SectionId) const
Section-based access: Storage[sectionId][localIndex].
VocabStorage(const VocabStorage &)=delete
Class for storing and accessing the IR2Vec vocabulary.
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
static LLVM_ABI Expected< Vocabulary > fromFile(StringRef VocabFilePath, float OpcWeight=1.0, float TypeWeight=0.5, float ArgWeight=0.2)
Create a Vocabulary by loading embeddings from a JSON file.
const_iterator begin() const
Vocabulary(Vocabulary &&)=default
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
const ir2vec::Embedding & operator[](Type::TypeID TypeID) const
static StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
Vocabulary & operator=(const Vocabulary &)=delete
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
Vocabulary(VocabStorage &&Storage)
static constexpr unsigned MaxCanonicalTypeIDs
unsigned getDimension() const
static constexpr unsigned MaxOperandKinds
Vocabulary(const Vocabulary &)=delete
const_iterator cbegin() const
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
static constexpr size_t getCanonicalSize()
Total number of entries (opcodes + canonicalized types + operand kinds + predicates)
const ir2vec::Embedding & operator[](CmpInst::Predicate P) const
static unsigned getIndex(Type::TypeID TypeID)
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
static constexpr unsigned MaxTypeIDs
const_iterator end() const
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
static unsigned getIndex(CmpInst::Predicate P)
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID)
Function to get vocabulary key for a given TypeID.
VocabStorage::const_iterator const_iterator
Const Iterator type aliases.
const_iterator cend() const
static unsigned getIndex(unsigned Opcode)
Functions to return flat index.
Vocabulary & operator=(Vocabulary &&Other)=delete
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
static constexpr unsigned MaxPredicateKinds
static unsigned getIndex(const Value &Op)
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
const ir2vec::Embedding & operator[](unsigned Opcode) const
Accessors to get the embedding for a given entity.
const ir2vec::Embedding & operator[](const Value &Arg) const
A Value is a JSON value of unknown type.
This class implements an extremely fast bulk output stream that can only output to a stream.
DenseMap< const Instruction *, Embedding > InstEmbeddingsMap
LLVM_ABI cl::opt< float > ArgWeight
DenseMap< const BasicBlock *, Embedding > BBEmbeddingsMap
LLVM_ABI cl::opt< std::string > VocabFile
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
LLVM_ABI llvm::cl::OptionCategory IR2VecCategory
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Implement std::hash so that hash_code can be used in STL containers.
A CRTP mix-in that provides informational APIs needed for analysis passes.
A special type used by analysis passes to provide an address that identifies that particular analysis pass type.
A CRTP mix-in for passes that should not be skipped.
Embedding is a datatype that wraps std::vector<double>.
const_iterator end() const
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolerance.
const_iterator cbegin() const
std::vector< double >::iterator iterator
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
std::vector< double >::const_iterator const_iterator
LLVM_ABI Embedding operator-(const Embedding &RHS) const
const std::vector< double > & getData() const
Embedding(size_t Size, double InitialValue)
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
const_iterator cend() const
bool isZero() const
Returns true if all elements of the embedding are zero.
LLVM_ABI Embedding operator*(double Factor) const
LLVM_ABI Embedding & operator*=(double Factor)
Embedding(std::initializer_list< double > IL)
Embedding(const std::vector< double > &V)
LLVM_ABI Embedding operator+(const Embedding &RHS) const
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds the Src embedding, scaled by Factor, to this embedding.
Embedding(std::vector< double > &&V)
const double & operator[](size_t Itr) const
LLVM_ABI void print(raw_ostream &OS) const
const_iterator begin() const
double & operator[](size_t Itr)