return true;
}
-void setAttrByCalledFunc(CallInst *Call) {
+CallInst *setAttrByCalledFunc(CallInst *Call) { // Returns Call so a call-creation expression can be wrapped inline.
Function *F = Call->getCalledFunction();
assert(F);
if (F->isIntrinsic()) {
- return;
+ return Call; // Intrinsic callee: calling convention and attributes are left untouched.
}
Call->setCallingConv(F->getCallingConv());
Call->setAttributes(F->getAttributes());
+ return Call;
}
bool isSPIRVBuiltinVariable(GlobalVariable *GV,
// %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
// %6 = extractelement <3 x i64> %5, i32 0
+
+/// Rewrite every load that reads a SPIR-V builtin variable through \p V into
+/// calls to \p ReplacementFunc, then erase the instructions that were looked
+/// through or replaced.
+///
+/// The users of \p V are walked recursively:
+///  - casts are looked through with the offset unchanged;
+///  - GEPs must have a constant offset, which is added to the running byte
+///    offset before recursing;
+///  - loads are replaced by one call (scalar load) or by a vector assembled
+///    from one call per element (fixed-vector load).
+/// Any other user, or a load whose offset/type cannot be mapped onto the
+/// builtin, is reported through llvm_unreachable.
+///
+/// \param V The builtin variable itself, or a cast/GEP derived from it.
+/// \param AccumulatedOffset Byte offset of \p V relative to the variable.
+/// \param ReplacementFunc Function called in place of the loads. It either
+///        takes no parameters or is called with a single i32 element index;
+///        its return type is treated as the element type of the variable.
+static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
+                                    Function *ReplacementFunc) {
+  const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
+  // Erasing is deferred until the walk over V's users is finished so the use
+  // list is not modified while it is being iterated.
+  SmallVector<Instruction *, 4> InstsToRemove;
+  for (User *U : V->users()) {
+    if (auto *Cast = dyn_cast<CastInst>(U)) {
+      replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
+      InstsToRemove.push_back(Cast);
+    } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      // accumulateConstantOffset requires the offset to have the index width
+      // of the GEP's address space.
+      APInt NewOffset = AccumulatedOffset.sextOrTrunc(
+          DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
+      if (!GEP->accumulateConstantOffset(DL, NewOffset))
+        llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
+      replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
+      InstsToRemove.push_back(GEP);
+    } else if (auto *Load = dyn_cast<LoadInst>(U)) {
+      // A negative offset points before the start of the variable. The
+      // unsigned division below would otherwise turn it into a huge, bogus
+      // element index instead of diagnosing it.
+      if (AccumulatedOffset.isNegative())
+        llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
+
+      // Figure out which index the accumulated offset corresponds to. If we
+      // have a weird offset (e.g., trying to load byte 7), bail out.
+      Type *ScalarTy = ReplacementFunc->getReturnType();
+      APInt Index;
+      uint64_t Remainder;
+      APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
+                     Index, Remainder);
+      if (Remainder != 0)
+        llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
+
+      // Insert the replacement right before the load it replaces.
+      IRBuilder<> Builder(Load);
+      Value *Replacement;
+      if (ReplacementFunc->getFunctionType()->getNumParams() == 0) {
+        if (Load->getType() != ScalarTy)
+          llvm_unreachable("Illegal use of a SPIR-V builtin variable");
+        Replacement =
+            setAttrByCalledFunc(Builder.CreateCall(ReplacementFunc, {}));
+      } else {
+        // The function has an index parameter.
+        if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
+          // A vector load must start at element 0, and because the vector is
+          // rebuilt from ReplacementFunc results its element type has to be
+          // the function's return type; otherwise the insertelement below
+          // would be ill-typed.
+          if (!Index.isZero() || VecTy->getElementType() != ScalarTy)
+            llvm_unreachable("Illegal use of a SPIR-V builtin variable");
+          Replacement = UndefValue::get(VecTy);
+          for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
+            Replacement = Builder.CreateInsertElement(
+                Replacement,
+                setAttrByCalledFunc(
+                    Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
+                Builder.getInt32(I));
+          }
+        } else if (Load->getType() == ScalarTy) {
+          Replacement = setAttrByCalledFunc(Builder.CreateCall(
+              ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
+        } else {
+          llvm_unreachable("Illegal load type of a SPIR-V builtin variable");
+        }
+      }
+      Load->replaceAllUsesWith(Replacement);
+      InstsToRemove.push_back(Load);
+    } else {
+      llvm_unreachable("Illegal use of a SPIR-V builtin variable");
+    }
+  }
+
+  for (Instruction *I : InstsToRemove)
+    I->eraseFromParent();
+}
+
bool lowerBuiltinVariableToCall(GlobalVariable *GV,
SPIRVBuiltinVariableKind Kind) {
// There might be dead constant users of GV (for example, SPIRVLowerConstExpr
Func->setDoesNotAccessMemory();
}
- // Collect instructions in these containers to remove them later.
- std::vector<Instruction *> Loads;
- std::vector<Instruction *> Casts;
- std::vector<Instruction *> GEPs;
-
- auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
- auto *Call = CallInst::Create(Func, Arg, "", I);
- Call->takeName(I);
- setAttrByCalledFunc(Call);
- SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call
- << '\n';)
- I->replaceAllUsesWith(Call);
- };
-
- // If HasIndexArg is true, we create 3 built-in calls and insertelement to
- // get 3-element vector filled with ids and replace uses of Load instruction
- // with this vector.
- // If HasIndexArg is false, the result of the Load instruction is the value
- // which should be replaced with the Func.
- // Returns true if Load was replaced, false otherwise.
- auto ReplaceIfLoad = [&](User *I) {
- auto *LD = dyn_cast<LoadInst>(I);
- if (!LD)
- return false;
- std::vector<Value *> Vectors;
- Loads.push_back(LD);
- if (HasIndexArg) {
- auto *VecTy = cast<FixedVectorType>(GVTy);
- Value *EmptyVec = UndefValue::get(VecTy);
- Vectors.push_back(EmptyVec);
- const DebugLoc &DLoc = LD->getDebugLoc();
- for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
- auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I);
- auto *Call = CallInst::Create(Func, {Idx}, "", LD);
- if (DLoc)
- Call->setDebugLoc(DLoc);
- setAttrByCalledFunc(Call);
- auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx);
- if (DLoc)
- Insert->setDebugLoc(DLoc);
- Insert->insertAfter(Call);
- Vectors.push_back(Insert);
- }
-
- Value *Ptr = LD->getPointerOperand();
-
- if (isa<FixedVectorType>(LD->getType())) {
- LD->replaceAllUsesWith(Vectors.back());
- } else {
- auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
- assert(GEP && "Unexpected pattern!");
- assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
- Value *Idx = GEP->getOperand(2);
- Value *Vec = Vectors.back();
- auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
- NewExtract->insertAfter(cast<Instruction>(Vec));
- LD->replaceAllUsesWith(NewExtract);
- }
-
- } else {
- Replace({}, LD);
- }
-
- return true;
- };
-
- // Go over the GV users, find Load and ExtractElement instructions and
- // replace them with the corresponding function call.
- for (auto *UI : GV->users()) {
- // There might or might not be an addrspacecast instruction.
- if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
- Casts.push_back(ASCast);
- for (auto *CastUser : ASCast->users()) {
- if (ReplaceIfLoad(CastUser))
- continue;
- if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
- GEPs.push_back(GEP);
- for (auto *GEPUser : GEP->users()) {
- if (!ReplaceIfLoad(GEPUser))
- llvm_unreachable("Unexpected pattern!");
- }
- } else {
- llvm_unreachable("Unexpected pattern!");
- }
- }
- } else if (auto *GEP = dyn_cast<GetElementPtrInst>(UI)) {
- GEPs.push_back(GEP);
- for (auto *GEPUser : GEP->users()) {
- if (!ReplaceIfLoad(GEPUser))
- llvm_unreachable("Unexpected pattern!");
- }
- } else if (!ReplaceIfLoad(UI)) {
- llvm_unreachable("Unexpected pattern!");
- }
- }
-
- auto Erase = [](std::vector<Instruction *> &ToErase) {
- for (Instruction *I : ToErase) {
- assert(I->hasNUses(0));
- I->eraseFromParent();
- }
- };
- // Order of erasing is important.
- Erase(Loads);
- Erase(GEPs);
- Erase(Casts);
-
+ replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
return true;
}
define spir_func void @foo() {
entry:
%GroupID = alloca [3 x i64], align 8
- %0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
- %1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
+ %0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+ %1 = getelementptr <3 x i64>, ptr addrspace(4) %0, i64 0, i64 0 ; address of vector element 0
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
-; CHECK-LLVM: %[[Ins0:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize0]], i32 0
-; CHECK-LLVM: %[[GLocalSize1:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
-; CHECK-LLVM: %[[Ins1:[0-9]+]] = insertelement <3 x i64> %[[Ins0]], i64 %[[GLocalSize1]], i32 1
+ %2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+ %3 = getelementptr <3 x i64>, ptr addrspace(4) %2, i64 0, i64 2 ; address of vector element 2
+ %4 = load i64, ptr addrspace(4) %1, align 32
+ %5 = load i64, ptr addrspace(4) %3, align 8
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
-; CHECK-LLVM: %[[Ins2:[0-9]+]] = insertelement <3 x i64> %[[Ins1]], i64 %[[GLocalSize2]], i32 2
-; CHECK-LLVM: %[[Extract:[0-9]+]] = extractelement <3 x i64> %[[Ins2]], i64 0
- %2 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
- %3 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %2, i64 0, i64 2
- %4 = load i64, i64 addrspace(4)* %1, align 32
- %5 = load i64, i64 addrspace(4)* %3, align 8
-; CHECK-LLVM: %[[GLocalSize01:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
-; CHECK-LLVM: %[[Ins01:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize01]], i32 0
-; CHECK-LLVM: %[[GLocalSize11:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
-; CHECK-LLVM: %[[Ins11:[0-9]+]] = insertelement <3 x i64> %[[Ins01]], i64 %[[GLocalSize11]], i32 1
-; CHECK-LLVM: %[[GLocalSize21:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
-; CHECK-LLVM: %[[Ins21:[0-9]+]] = insertelement <3 x i64> %[[Ins11]], i64 %[[GLocalSize21]], i32 2
-; CHECK-LLVM: %[[Extract1:[0-9]+]] = extractelement <3 x i64> %[[Ins21]], i64 2
-; CHECK-LLVM: mul i64 %[[Extract]], %[[Extract1]]
+; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
%mul = mul i64 %4, %5
ret void
}
+; Function Attrs: alwaysinline convergent nounwind mustprogress
+define spir_func void @foo_i8gep() {
+entry:
+ %GroupID = alloca [3 x i64], align 8
+ %0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+ %1 = getelementptr i8, ptr addrspace(4) %0, i64 0 ; byte offset 0; the load of %1 is expected to become the index-0 call
+; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
+ %2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+ %3 = getelementptr i8, ptr addrspace(4) %2, i64 16 ; byte offset 16 = 2 * 8-byte i64; the load of %3 is expected to become the index-2 call
+ %4 = load i64, ptr addrspace(4) %1, align 32
+ %5 = load i64, ptr addrspace(4) %3, align 8
+; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
+; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
+ %mul = mul i64 %4, %5
+ ret void
+}