LLVM 23.0.0git
SplitModuleByCategory.cpp
Go to the documentation of this file.
1//===-------- SplitModuleByCategory.cpp - split a module by categories ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// See comments in the header.
9//===----------------------------------------------------------------------===//
10
#include "llvm/Transforms/Utils/SplitModuleByCategory.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <cassert>
#include <map>
#include <memory>
#include <optional>
#include <utility>
24
25using namespace llvm;
26
27#define DEBUG_TYPE "split-module-by-category"
28
29namespace {
30
31// A vector that contains a group of function with the same category.
32using EntryPointSet = SetVector<const Function *>;
33
34/// Represents a group of functions with one category.
35struct EntryPointGroup {
36 int ID;
37 EntryPointSet Functions;
38
39 EntryPointGroup() = default;
40
41 EntryPointGroup(int ID, EntryPointSet &&Functions = EntryPointSet())
42 : ID(ID), Functions(std::move(Functions)) {}
43
44 void clear() { Functions.clear(); }
45
46#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
47 LLVM_DUMP_METHOD void dump() const {
48 constexpr size_t INDENT = 4;
49 dbgs().indent(INDENT) << "ENTRY POINTS"
50 << " " << ID << " {\n";
51 for (const Function *F : Functions)
52 dbgs().indent(INDENT) << " " << F->getName() << "\n";
53
54 dbgs().indent(INDENT) << "}\n";
55 }
56#endif
57};
58
59/// Annotates an llvm::Module with information necessary to perform and track
60/// the result of code (llvm::Module instances) splitting:
61/// - entry points group from the module.
62class ModuleDesc {
63 std::unique_ptr<Module> M;
64 EntryPointGroup EntryPoints;
65
66public:
67 ModuleDesc(std::unique_ptr<Module> M,
68 EntryPointGroup &&EntryPoints = EntryPointGroup())
69 : M(std::move(M)), EntryPoints(std::move(EntryPoints)) {
70 assert(this->M && "Module should be non-null");
71 }
72
73 Module &getModule() { return *M; }
74 const Module &getModule() const { return *M; }
75
76 std::unique_ptr<Module> releaseModule() {
77 EntryPoints.clear();
78 return std::move(M);
79 }
80
81#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
82 LLVM_DUMP_METHOD void dump() const {
83 dbgs() << "ModuleDesc[" << M->getName() << "] {\n";
84 EntryPoints.dump();
85 dbgs() << "}\n";
86 }
87#endif
88};
89
90// Represents "dependency" or "use" graph of global objects (functions and
91// global variables) in a module. It is used during code split to
92// understand which global variables and functions (other than entry points)
93// should be included into a split module.
94//
95// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent
96// the fact that if "A" is included into a module, then "B" should be included
97// as well.
98//
99// Examples of dependencies which are represented in this graph:
100// - Function FA calls function FB
101// - Function FA uses global variable GA
102// - Global variable GA references (initialized with) function FB
103// - Function FA stores address of a function FB somewhere
104//
105// The following cases are treated as dependencies between global objects:
106// 1. Global object A is used by a global object B in any way (store,
107// bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the
108// graph;
109// 2. function A performs an indirect call of a function with signature S and
110// there is a function B with signature S. "A" -> "B" edge will be added to
111// the graph;
112class DependencyGraph {
113public:
114 using GlobalSet = SmallPtrSet<const GlobalValue *, 16>;
115
116 DependencyGraph(const Module &M) {
117 // Group functions by their signature to handle case (2) described above
119 FuncTypeToFuncsMap;
120 for (const Function &F : M.functions()) {
121 // Kernels can't be called (either directly or indirectly).
122 if (F.hasKernelCallingConv())
123 continue;
124
125 FuncTypeToFuncsMap[F.getFunctionType()].insert(&F);
126 }
127
128 for (const Function &F : M.functions()) {
129 // case (1), see comment above the class definition
130 for (const Value *U : F.users())
131 addUserToGraphRecursively(cast<const User>(U), &F);
132
133 // case (2), see comment above the class definition
134 for (const Instruction &I : instructions(F)) {
135 const CallBase *CB = dyn_cast<CallBase>(&I);
136 if (!CB || !CB->isIndirectCall()) // Direct calls were handled above
137 continue;
138
139 const FunctionType *Signature = CB->getFunctionType();
140 GlobalSet &PotentialCallees = FuncTypeToFuncsMap[Signature];
141 Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end());
142 }
143 }
144
145 // And every global variable (but their handling is a bit simpler)
146 for (const GlobalVariable &GV : M.globals())
147 for (const Value *U : GV.users())
148 addUserToGraphRecursively(cast<const User>(U), &GV);
149 }
150
152 dependencies(const GlobalValue *Val) const {
153 auto It = Graph.find(Val);
154 return (It == Graph.end())
155 ? make_range(EmptySet.begin(), EmptySet.end())
156 : make_range(It->second.begin(), It->second.end());
157 }
158
159private:
160 void addUserToGraphRecursively(const User *Root, const GlobalValue *V) {
162 WorkList.push_back(Root);
163
164 while (!WorkList.empty()) {
165 const User *U = WorkList.pop_back_val();
166 if (const auto *I = dyn_cast<const Instruction>(U)) {
167 const Function *UFunc = I->getFunction();
168 Graph[UFunc].insert(V);
169 } else if (isa<const Constant>(U)) {
170 if (const auto *GV = dyn_cast<const GlobalVariable>(U))
171 Graph[GV].insert(V);
172 // This could be a global variable or some constant expression (like
173 // bitcast or gep). We trace users of this constant further to reach
174 // global objects they are used by and add them to the graph.
175 for (const User *UU : U->users())
176 WorkList.push_back(UU);
177 } else {
178 llvm_unreachable("Unhandled type of function user");
179 }
180 }
181 }
182
185};
186
187void collectFunctionsAndGlobalVariablesToExtract(
189 const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) {
190 // We start with module entry points
191 for (const Function *F : ModuleEntryPoints.Functions)
192 GVs.insert(F);
193
194 // Non-discardable global variables are also include into the initial set
195 for (const GlobalVariable &GV : M.globals())
196 if (!GV.isDiscardableIfUnused())
197 GVs.insert(&GV);
198
199 // GVs has SetVector type. This type inserts a value only if it is not yet
200 // present there. So, recursion is not expected here.
201 size_t Idx = 0;
202 while (Idx < GVs.size()) {
203 const GlobalValue *Obj = GVs[Idx++];
204
205 for (const GlobalValue *Dep : DG.dependencies(Obj)) {
206 if (const auto *Func = dyn_cast<const Function>(Dep)) {
207 if (!Func->isDeclaration())
208 GVs.insert(Func);
209 } else {
210 GVs.insert(Dep); // Global variables are added unconditionally
211 }
212 }
213 }
214}
215
216ModuleDesc extractSubModule(const Module &M,
218 EntryPointGroup &&ModuleEntryPoints) {
220 // Clone definitions only for needed globals. Others will be added as
221 // declarations and removed later.
222 std::unique_ptr<Module> SubM = CloneModule(
223 M, VMap, [&](const GlobalValue *GV) { return GVs.contains(GV); });
224 // Replace entry points with cloned ones.
225 EntryPointSet NewEPs;
226 const EntryPointSet &EPs = ModuleEntryPoints.Functions;
228 EPs, [&](const Function *F) { NewEPs.insert(cast<Function>(VMap[F])); });
229 ModuleEntryPoints.Functions = std::move(NewEPs);
230 return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)};
231}
232
233// The function produces a copy of input LLVM IR module M with only those
234// functions and globals that can be called from entry points that are specified
235// in ModuleEntryPoints vector, in addition to the entry point functions.
236ModuleDesc extractCallGraph(const Module &M,
237 EntryPointGroup &&ModuleEntryPoints,
238 const DependencyGraph &DG) {
240 collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG);
241
242 ModuleDesc SplitM = extractSubModule(M, GVs, std::move(ModuleEntryPoints));
243 LLVM_DEBUG(SplitM.dump());
244 return SplitM;
245}
246
247using EntryPointGroupVec = SmallVector<EntryPointGroup>;
248
249/// Module Splitter.
250/// It gets a module and a collection of entry points groups.
251/// Each group specifies subset entry points from input module that should be
252/// included in a split module.
253class ModuleSplitter {
254private:
255 std::unique_ptr<Module> M;
256 EntryPointGroupVec Groups;
257 DependencyGraph DG;
258
259private:
260 EntryPointGroup drawEntryPointGroup() {
261 assert(Groups.size() > 0 && "Reached end of entry point groups list.");
262 EntryPointGroup Group = std::move(Groups.back());
263 Groups.pop_back();
264 return Group;
265 }
266
267public:
268 ModuleSplitter(std::unique_ptr<Module> Module, EntryPointGroupVec &&GroupVec)
269 : M(std::move(Module)), Groups(std::move(GroupVec)), DG(*M) {
270 assert(!Groups.empty() && "Entry points groups collection is empty!");
271 }
272
273 /// Gets next subsequence of entry points in an input module and provides
274 /// split submodule containing these entry points and their dependencies.
275 ModuleDesc getNextSplit() {
276 return extractCallGraph(*M, drawEntryPointGroup(), DG);
277 }
278
279 /// Check that there are still submodules to split.
280 bool hasMoreSplits() const { return Groups.size() > 0; }
281};
282
283EntryPointGroupVec selectEntryPointGroups(
284 const Module &M, function_ref<std::optional<int>(const Function &F)> EPC) {
285 // std::map is used here to ensure stable ordering of entry point groups,
286 // which is based on their contents, this greatly helps LIT tests
287 // Note: EPC is allowed to return big identifiers. Therefore, we use
288 // std::map + SmallVector approach here.
289 std::map<int, EntryPointSet> EntryPointsMap;
290
291 for (const auto &F : M.functions())
292 if (std::optional<int> Category = EPC(F); Category)
293 EntryPointsMap[*Category].insert(&F);
294
295 EntryPointGroupVec Groups;
296 Groups.reserve(EntryPointsMap.size());
297 for (auto &[Key, EntryPoints] : EntryPointsMap)
298 Groups.emplace_back(Key, std::move(EntryPoints));
299
300 return Groups;
301}
302
303} // namespace
304
306 std::unique_ptr<Module> M,
307 function_ref<std::optional<int>(const Function &F)> EntryPointCategorizer,
308 function_ref<void(std::unique_ptr<Module> Part)> Callback) {
309 EntryPointGroupVec Groups = selectEntryPointGroups(*M, EntryPointCategorizer);
310 ModuleSplitter Splitter(std::move(M), std::move(Groups));
311 while (Splitter.hasMoreSplits()) {
312 ModuleDesc MD = Splitter.getNextSplit();
313 Callback(MD.releaseModule());
314 }
315}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Expand Atomic instructions
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
static SmallVector< const DIVariable *, 2 > dependencies(DbgVariable *Var)
Return all DIVariables that appear in count: expressions.
static ThreadSafeModule extractSubModule(ThreadSafeModule &TSM, StringRef Suffix, GVPredicate ShouldExtract)
Module.h This file contains the declarations for the Module class.
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static const X86InstrFMA3Group Groups[]
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
LLVM_ABI bool isIndirectCall() const
Return true if the callsite is an indirect call.
FunctionType * getFunctionType() const
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A vector that has set insertion semantics.
Definition SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:103
bool contains(const_arg_type key) const
Check if the SetVector contains the given key.
Definition SetVector.h:252
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:151
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM Value Representation.
Definition Value.h:75
iterator_range< user_iterator > users()
Definition Value.h:427
An efficient, type-erasing, non-owning reference to a callable.
A range adaptor for a pair of iterators.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
UnaryFunction for_each(R &&Range, UnaryFunction F)
Provide wrappers to std::for_each which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1732
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ABI void splitModuleTransitiveFromEntryPoints(std::unique_ptr< Module > M, function_ref< std::optional< int >(const Function &F)> EntryPointCategorizer, function_ref< void(std::unique_ptr< Module > Part)> Callback)
Splits the given module M into parts.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
ValueMap< const Value *, WeakTrackingVH > ValueToValueMapTy
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
LLVM_ABI std::unique_ptr< Module > CloneModule(const Module &M)
Return an exact copy of the specified module.