LLVM 20.0.0git
BottomUpVec.cpp
Go to the documentation of this file.
1//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19
20namespace llvm {
21
22static cl::opt<unsigned>
23 OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
24 cl::desc("Override the vector register size in bits, "
25 "which is otherwise found by querying TTI."));
26static cl::opt<bool>
27 AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
28 cl::desc("Allow non-power-of-2 vectorization."));
29
30namespace sandboxir {
31
33 : FunctionPass("bottom-up-vec"),
34 RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
35
37 unsigned OpIdx) {
39 for (Value *BndlV : Bndl) {
40 auto *BndlI = cast<Instruction>(BndlV);
41 Operands.push_back(BndlI->getOperand(OpIdx));
42 }
43 return Operands;
44}
45
48 // TODO: Use the VecUtils function for getting the bottom instr once it lands.
49 auto *BotI = cast<Instruction>(
50 *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
51 return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
52 }));
53 // If Bndl contains Arguments or Constants, use the beginning of the BB.
54 return std::next(BotI->getIterator());
55}
56
57Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
59 auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
61 assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
62 "Expect Instructions!");
63 auto &Ctx = Bndl[0]->getContext();
64
65 Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
66 auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
67
69
70 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
71 switch (Opcode) {
72 case Instruction::Opcode::ZExt:
73 case Instruction::Opcode::SExt:
74 case Instruction::Opcode::FPToUI:
75 case Instruction::Opcode::FPToSI:
76 case Instruction::Opcode::FPExt:
77 case Instruction::Opcode::PtrToInt:
78 case Instruction::Opcode::IntToPtr:
79 case Instruction::Opcode::SIToFP:
80 case Instruction::Opcode::UIToFP:
81 case Instruction::Opcode::Trunc:
82 case Instruction::Opcode::FPTrunc:
83 case Instruction::Opcode::BitCast: {
84 assert(Operands.size() == 1u && "Casts are unary!");
85 return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
86 "VCast");
87 }
88 case Instruction::Opcode::FCmp:
89 case Instruction::Opcode::ICmp: {
90 auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
92 [Pred](auto *SBV) {
93 return cast<CmpInst>(SBV)->getPredicate() == Pred;
94 }) &&
95 "Expected same predicate across bundle.");
96 return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
97 "VCmp");
98 }
99 case Instruction::Opcode::Select: {
100 return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
101 Ctx, "Vec");
102 }
103 case Instruction::Opcode::FNeg: {
104 auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
105 auto OpC = UOp0->getOpcode();
107 WhereIt, Ctx, "Vec");
108 }
109 case Instruction::Opcode::Add:
110 case Instruction::Opcode::FAdd:
111 case Instruction::Opcode::Sub:
112 case Instruction::Opcode::FSub:
113 case Instruction::Opcode::Mul:
114 case Instruction::Opcode::FMul:
115 case Instruction::Opcode::UDiv:
116 case Instruction::Opcode::SDiv:
117 case Instruction::Opcode::FDiv:
118 case Instruction::Opcode::URem:
119 case Instruction::Opcode::SRem:
120 case Instruction::Opcode::FRem:
121 case Instruction::Opcode::Shl:
122 case Instruction::Opcode::LShr:
123 case Instruction::Opcode::AShr:
124 case Instruction::Opcode::And:
125 case Instruction::Opcode::Or:
126 case Instruction::Opcode::Xor: {
127 auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
128 auto *LHS = Operands[0];
129 auto *RHS = Operands[1];
131 BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
132 }
133 case Instruction::Opcode::Load: {
134 auto *Ld0 = cast<LoadInst>(Bndl[0]);
135 Value *Ptr = Ld0->getPointerOperand();
136 return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
137 "VecL");
138 }
139 case Instruction::Opcode::Store: {
140 auto Align = cast<StoreInst>(Bndl[0])->getAlign();
141 Value *Val = Operands[0];
142 Value *Ptr = Operands[1];
143 return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
144 }
145 case Instruction::Opcode::Br:
146 case Instruction::Opcode::Ret:
147 case Instruction::Opcode::PHI:
148 case Instruction::Opcode::AddrSpaceCast:
149 case Instruction::Opcode::Call:
150 case Instruction::Opcode::GetElementPtr:
151 llvm_unreachable("Unimplemented");
152 break;
153 default:
154 llvm_unreachable("Unimplemented");
155 break;
156 }
157 llvm_unreachable("Missing switch case!");
158 // TODO: Propagate debug info.
159 };
160
161 auto *VecI = CreateVectorInstr(Bndl, Operands);
162 if (VecI != nullptr) {
163 Change = true;
164 IMaps->registerVector(Bndl, VecI);
165 }
166 return VecI;
167}
168
169void BottomUpVec::tryEraseDeadInstrs() {
170 // Erase candidates that ended up with zero uses. We visit them
171 // bottom-to-top so that users are erased before their operands,
172 // which lets chains of dead instructions be removed in one pass.
173 SmallVector<Instruction *> Candidates(DeadInstrCandidates.begin(),
174 DeadInstrCandidates.end());
175 // Sort in descending program order; a forward walk then starts at the
176 // bottom-most instruction.
177 sort(Candidates,
178 [](Instruction *A, Instruction *B) { return B->comesBefore(A); });
179 for (Instruction *DeadI : Candidates)
180 if (DeadI->hasNUses(0))
181 DeadI->eraseFromParent();
182 DeadInstrCandidates.clear();
183}
181
182Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask) {
184 return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
185 VecOp->getContext(), "VShuf");
186}
187
188Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
190
191 Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
192 unsigned Lanes = VecUtils::getNumLanes(ToPack);
193 Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
194
195 // Create a series of pack instructions.
196 Value *LastInsert = PoisonValue::get(VecTy);
197
198 Context &Ctx = ToPack[0]->getContext();
199
200 unsigned InsertIdx = 0;
201 for (Value *Elm : ToPack) {
202 // An element can be either scalar or vector. We need to generate different
203 // IR for each case.
204 if (Elm->getType()->isVectorTy()) {
205 unsigned NumElms =
206 cast<FixedVectorType>(Elm->getType())->getNumElements();
207 for (auto ExtrLane : seq<int>(0, NumElms)) {
208 // We generate extract-insert pairs, for each lane in `Elm`.
209 Constant *ExtrLaneC =
211 // This may return a Constant if Elm is a Constant.
212 auto *ExtrI =
213 ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
214 if (!isa<Constant>(ExtrI))
215 WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
216 Constant *InsertLaneC =
217 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
218 // This may also return a Constant if ExtrI is a Constant.
219 auto *InsertI = InsertElementInst::create(
220 LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
221 if (!isa<Constant>(InsertI)) {
222 LastInsert = InsertI;
223 WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
224 }
225 }
226 } else {
227 Constant *InsertLaneC =
228 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
229 // This may be folded into a Constant if LastInsert is a Constant. In
230 // that case we only collect the last constant.
231 LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
232 WhereIt, Ctx, "Pack");
233 if (auto *NewI = dyn_cast<Instruction>(LastInsert))
234 WhereIt = std::next(NewI->getIterator());
235 }
236 }
237 return LastInsert;
238}
239
240void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
241 // Every vectorized scalar is a candidate for later erasure.
242 for (Value *V : Bndl)
243 DeadInstrCandidates.insert(cast<Instruction>(V));
244 // For loads/stores the pointer operands (typically GEPs) of all lanes but
245 // the first may also become dead once the scalars are erased.
246 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
247 if (Opcode == Instruction::Opcode::Load) {
248 for (Value *V : drop_begin(Bndl)) {
249 Value *PtrOp = cast<LoadInst>(V)->getPointerOperand();
250 if (auto *PtrI = dyn_cast<Instruction>(PtrOp))
251 DeadInstrCandidates.insert(PtrI);
252 }
253 } else if (Opcode == Instruction::Opcode::Store) {
254 for (Value *V : drop_begin(Bndl)) {
255 Value *PtrOp = cast<StoreInst>(V)->getPointerOperand();
256 if (auto *PtrI = dyn_cast<Instruction>(PtrOp))
257 DeadInstrCandidates.insert(PtrI);
258 }
259 }
260}
264
265Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
266 Value *NewVec = nullptr;
267 const auto &LegalityRes = Legality->canVectorize(Bndl);
268 switch (LegalityRes.getSubclassID()) {
270 auto *I = cast<Instruction>(Bndl[0]);
271 SmallVector<Value *, 2> VecOperands;
272 switch (I->getOpcode()) {
273 case Instruction::Opcode::Load:
274 // Don't recurse towards the pointer operand.
275 VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
276 break;
277 case Instruction::Opcode::Store: {
278 // Don't recurse towards the pointer operand.
279 auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
280 VecOperands.push_back(VecOp);
281 VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
282 break;
283 }
284 default:
285 // Visit all operands.
286 for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
287 auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
288 VecOperands.push_back(VecOp);
289 }
290 break;
291 }
292 NewVec = createVectorInstr(Bndl, VecOperands);
293
294 // Collect any potentially dead scalar instructions, including the original
295 // scalars and pointer operands of loads/stores.
296 if (NewVec != nullptr)
297 collectPotentiallyDeadInstrs(Bndl);
298 break;
299 }
301 NewVec = cast<DiamondReuse>(LegalityRes).getVector();
302 break;
303 }
305 auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
306 const ShuffleMask &Mask =
307 cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
308 NewVec = createShuffle(VecOp, Mask);
309 break;
310 }
312 // If we can't vectorize the seeds then just return.
313 if (Depth == 0)
314 return nullptr;
315 NewVec = createPack(Bndl);
316 break;
317 }
318 }
319 return NewVec;
320}
321
322bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
323 // Reset per-attempt state before recursing.
324 Legality->clear();
325 DeadInstrCandidates.clear();
326 // Recursively vectorize the seed bundle; `Change` is updated when a
327 // vector instruction is actually created.
328 vectorizeRec(Bndl, /*Depth=*/0);
329 // Remove scalars (and their pointer operands) made dead by vectorization.
330 tryEraseDeadInstrs();
331 return Change;
332}
329
331 IMaps = std::make_unique<InstrMaps>(F.getContext());
332 Legality = std::make_unique<LegalityAnalysis>(
333 A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
334 F.getContext(), *IMaps);
335 Change = false;
336 const auto &DL = F.getParent()->getDataLayout();
337 unsigned VecRegBits =
340 : A.getTTI()
342 .getFixedValue();
343
344 // TODO: Start from innermost BBs first
345 for (auto &BB : F) {
346 SeedCollector SC(&BB, A.getScalarEvolution());
347 for (SeedBundle &Seeds : SC.getStoreSeeds()) {
348 unsigned ElmBits =
350 Seeds[Seeds.getFirstUnusedElementIdx()])),
351 DL);
352
353 auto DivideBy2 = [](unsigned Num) {
354 auto Floor = VecUtils::getFloorPowerOf2(Num);
355 if (Floor == Num)
356 return Floor / 2;
357 return Floor;
358 };
359 // Try to create the largest vector supported by the target. If it fails
360 // reduce the vector size by half.
361 for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
362 Seeds.getNumUnusedBits() / ElmBits);
363 SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
364 if (Seeds.allUsed())
365 break;
366 // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
367 // the slice. This could be quite expensive, so we enforce a limit.
368 for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
369 OE = Seeds.size();
370 Offset + 1 < OE; Offset += 1) {
371 // Seeds are getting used as we vectorize, so skip them.
372 if (Seeds.isUsed(Offset))
373 continue;
374 if (Seeds.allUsed())
375 break;
376
377 auto SeedSlice =
378 Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
379 if (SeedSlice.empty())
380 continue;
381
382 assert(SeedSlice.size() >= 2 && "Should have been rejected!");
383
384 // TODO: If vectorization succeeds, run the RegionPassManager on the
385 // resulting region.
386
387 // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
388 SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
389 SeedSlice.end());
390 Change |= tryVectorize(SeedSliceVals);
391 }
392 }
393 }
394 }
395 return Change;
396}
397
398} // namespace sandboxir
399} // namespace llvm
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
This pass exposes codegen information to IR-level passes.
Value * RHS
Value * LHS
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
iterator end() const
Definition: ArrayRef.h:157
iterator begin() const
Definition: ArrayRef.h:156
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:177
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
LLVM Value Representation.
Definition: Value.h:74
static Value * createWithCopiedFlags(Instruction::Opcode Op, Value *LHS, Value *RHS, Value *CopyFrom, InsertPosition Pos, Context &Ctx, const Twine &Name="")
BottomUpVec(StringRef Pipeline)
Definition: BottomUpVec.cpp:32
bool runOnFunction(Function &F, const Analyses &A) final
\Returns true if it modifies F.
static Value * create(Type *DestTy, Opcode Op, Value *Operand, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static Value * create(Predicate Pred, Value *S1, Value *S2, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static ConstantInt * getSigned(IntegerType *Ty, int64_t V)
Return a ConstantInt with the specified value for the specified type.
Definition: Constant.cpp:57
static Value * create(Value *Vec, Value *Idx, InsertPosition Pos, Context &Ctx, const Twine &Name="")
A pass that runs on a sandbox::Function.
Definition: Pass.h:75
static Value * create(Value *Vec, Value *NewElt, Value *Idx, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static LoadInst * create(Type *Ty, Value *Ptr, MaybeAlign Align, InsertPosition Pos, bool IsVolatile, Context &Ctx, const Twine &Name="")
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constant.cpp:238
A set of candidate Instructions for vectorizing together.
Definition: SeedCollector.h:27
static Value * create(Value *Cond, Value *True, Value *False, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static Value * create(Value *V1, Value *V2, Value *Mask, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static StoreInst * create(Value *V, Value *Ptr, MaybeAlign Align, InsertPosition Pos, bool IsVolatile, Context &Ctx)
static Type * getInt32Ty(Context &Ctx)
static Value * createWithCopiedFlags(Instruction::Opcode Op, Value *OpV, Value *CopyFrom, InsertPosition Pos, Context &Ctx, const Twine &Name="")
static unsigned getNumBits(Type *Ty, const DataLayout &DL)
\Returns the number of bits of Ty.
Definition: Utils.h:64
static Type * getExpectedType(const Value *V)
\Returns the expected type of Value V.
Definition: Utils.h:30
A SandboxIR Value has users. This is the base class.
Definition: Value.h:63
static Type * getCommonScalarType(ArrayRef< Value * > Bndl)
Similar to tryGetCommonScalarType() but will assert that there is a common type.
Definition: VecUtils.h:129
static unsigned getNumLanes(Type *Ty)
\Returns the number of vector lanes of Ty or 1 if not a vector.
Definition: VecUtils.h:72
static Type * getWideType(Type *ElemTy, unsigned NumElts)
\Returns <NumElts x ElemTy>.
Definition: VecUtils.h:95
static Type * getElementType(Type *Ty)
Returns Ty if scalar or its element type if vector.
Definition: VecUtils.h:32
static unsigned getFloorPowerOf2(unsigned Num)
\Returns the first integer power of 2 that is <= Num.
Definition: VecUtils.cpp:13
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:125
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
Type
MessagePack types as defined in the standard, with the exception of Integer being divided into a sign...
Definition: MsgPackReader.h:53
@ DiamondReuse
‍Vectorize by combining scalars to a vector.
@ DiamondReuseWithShuffle
‍Don't generate new code, reuse existing vector.
@ Widen
‍Collect scalar values.
static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef< Value * > Instrs)
Definition: BottomUpVec.cpp:47
static SmallVector< Value *, 4 > getOperand(ArrayRef< Value * > Bndl, unsigned OpIdx)
Definition: BottomUpVec.cpp:36
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
void sort(IteratorTy Start, IteratorTy End)
Definition: STLExtras.h:1664
static cl::opt< unsigned > OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden, cl::desc("Override the vector register size in bits, " "which is otherwise found by querying TTI."))
static cl::opt< bool > AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden, cl::desc("Allow non-power-of-2 vectorization."))