From b546576867a7d9f4c6abc2a43e34296f52242ccd Mon Sep 17 00:00:00 2001 From: Mateusz Chudyk Date: Mon, 3 Jul 2023 18:40:05 +0200 Subject: [PATCH] [PATCH 66/79] [Backport to 15] [OpaquePointers] Adjust builtin variable tracking to support i8 geps (#2061) The existing logic for the replacement of builtin variables with calls to functions relies on relatively brittle tracking that is broken when opaque pointers are turned on, and will be even more thoroughly broken if/when typed geps are replaced with i8 geps or ptradd. This patch replaces that logic with a less brittle variant that is able to handle any sequence of bitcast, gep, or addrspacecast instructions between the global variable and the ultimate load instruction. It will still error out if the variable is used in too unusual a fashion (say, trying to load an i32 out of the i64, or a misaligned vector type). Co-authored-by: Joshua Cranmer Gbp-Pq: Name 0066-Backport-to-15-OpaquePointers-Adjust-builtin-variabl.patch --- lib/SPIRV/SPIRVInternal.h | 2 +- lib/SPIRV/SPIRVUtil.cpp | 182 +++++++++++---------------- test/builtin-vars-gep.ll | 42 ++++--- test/transcoding/builtin_vars_gep.ll | 12 -- 4 files changed, 97 insertions(+), 141 deletions(-) diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h index 9eaaf6c..30241c8 100644 --- a/lib/SPIRV/SPIRVInternal.h +++ b/lib/SPIRV/SPIRVInternal.h @@ -1056,7 +1056,7 @@ std::string decodeSPIRVTypeName(StringRef Name, SmallVectorImpl &Strs); // Copy attributes from function to call site. -void setAttrByCalledFunc(CallInst *Call); +CallInst *setAttrByCalledFunc(CallInst *Call); bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind); // Transform builtin variable from GlobalVariable to builtin call. // e.g. 
diff --git a/lib/SPIRV/SPIRVUtil.cpp b/lib/SPIRV/SPIRVUtil.cpp index 1ac4eb6..9a45e77 100644 --- a/lib/SPIRV/SPIRVUtil.cpp +++ b/lib/SPIRV/SPIRVUtil.cpp @@ -1858,14 +1858,15 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) { return true; } -void setAttrByCalledFunc(CallInst *Call) { +CallInst *setAttrByCalledFunc(CallInst *Call) { Function *F = Call->getCalledFunction(); assert(F); if (F->isIntrinsic()) { - return; + return Call; } Call->setCallingConv(F->getCallingConv()); Call->setAttributes(F->getAttributes()); + return Call; } bool isSPIRVBuiltinVariable(GlobalVariable *GV, @@ -1915,6 +1916,75 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV, // %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1 // %5 = insertelement <3 x i64> %3, i64 %4, i32 2 // %6 = extractelement <3 x i64> %5, i32 0 + +/// Recursively look through the uses of a global variable, including casts or +/// gep offsets, to find all loads of the variable. Gep offsets that are non-0 +/// are accumulated in the AccumulatedOffset parameter, which will eventually be +/// used to figure out which index of a variable is being used. +static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset, + Function *ReplacementFunc) { + const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout(); + SmallVector InstsToRemove; + for (User *U : V->users()) { + if (auto *Cast = dyn_cast(U)) { + replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc); + InstsToRemove.push_back(Cast); + } else if (auto *GEP = dyn_cast(U)) { + APInt NewOffset = AccumulatedOffset.sextOrTrunc( + DL.getIndexSizeInBits(GEP->getPointerAddressSpace())); + if (!GEP->accumulateConstantOffset(DL, NewOffset)) + llvm_unreachable("Illegal GEP of a SPIR-V builtin variable"); + replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc); + InstsToRemove.push_back(GEP); + } else if (auto *Load = dyn_cast(U)) { + // Figure out which index the accumulated offset corresponds to. 
If we + // have a weird offset (e.g., trying to load byte 7), bail out. + Type *ScalarTy = ReplacementFunc->getReturnType(); + APInt Index; + uint64_t Remainder; + APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8, + Index, Remainder); + if (Remainder != 0) + llvm_unreachable("Illegal GEP of a SPIR-V builtin variable"); + + IRBuilder<> Builder(Load); + Value *Replacement; + if (ReplacementFunc->getFunctionType()->getNumParams() == 0) { + if (Load->getType() != ScalarTy) + llvm_unreachable("Illegal use of a SPIR-V builtin variable"); + Replacement = + setAttrByCalledFunc(Builder.CreateCall(ReplacementFunc, {})); + } else { + // The function has an index parameter. + if (auto *VecTy = dyn_cast(Load->getType())) { + if (!Index.isZero()) + llvm_unreachable("Illegal use of a SPIR-V builtin variable"); + Replacement = UndefValue::get(VecTy); + for (unsigned I = 0; I < VecTy->getNumElements(); I++) { + Replacement = Builder.CreateInsertElement( + Replacement, + setAttrByCalledFunc( + Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})), + Builder.getInt32(I)); + } + } else if (Load->getType() == ScalarTy) { + Replacement = setAttrByCalledFunc(Builder.CreateCall( + ReplacementFunc, {Builder.getInt32(Index.getZExtValue())})); + } else { + llvm_unreachable("Illegal load type of a SPIR-V builtin variable"); + } + } + Load->replaceAllUsesWith(Replacement); + InstsToRemove.push_back(Load); + } else { + llvm_unreachable("Illegal use of a SPIR-V builtin variable"); + } + } + + for (Instruction *I : InstsToRemove) + I->eraseFromParent(); +} + bool lowerBuiltinVariableToCall(GlobalVariable *GV, SPIRVBuiltinVariableKind Kind) { // There might be dead constant users of GV (for example, SPIRVLowerConstExpr @@ -1950,113 +2020,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV, Func->setDoesNotAccessMemory(); } - // Collect instructions in these containers to remove them later. 
- std::vector Loads; - std::vector Casts; - std::vector GEPs; - - auto Replace = [&](std::vector Arg, Instruction *I) { - auto *Call = CallInst::Create(Func, Arg, "", I); - Call->takeName(I); - setAttrByCalledFunc(Call); - SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call - << '\n';) - I->replaceAllUsesWith(Call); - }; - - // If HasIndexArg is true, we create 3 built-in calls and insertelement to - // get 3-element vector filled with ids and replace uses of Load instruction - // with this vector. - // If HasIndexArg is false, the result of the Load instruction is the value - // which should be replaced with the Func. - // Returns true if Load was replaced, false otherwise. - auto ReplaceIfLoad = [&](User *I) { - auto *LD = dyn_cast(I); - if (!LD) - return false; - std::vector Vectors; - Loads.push_back(LD); - if (HasIndexArg) { - auto *VecTy = cast(GVTy); - Value *EmptyVec = UndefValue::get(VecTy); - Vectors.push_back(EmptyVec); - const DebugLoc &DLoc = LD->getDebugLoc(); - for (unsigned I = 0; I < VecTy->getNumElements(); ++I) { - auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I); - auto *Call = CallInst::Create(Func, {Idx}, "", LD); - if (DLoc) - Call->setDebugLoc(DLoc); - setAttrByCalledFunc(Call); - auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx); - if (DLoc) - Insert->setDebugLoc(DLoc); - Insert->insertAfter(Call); - Vectors.push_back(Insert); - } - - Value *Ptr = LD->getPointerOperand(); - - if (isa(LD->getType())) { - LD->replaceAllUsesWith(Vectors.back()); - } else { - auto *GEP = dyn_cast(Ptr); - assert(GEP && "Unexpected pattern!"); - assert(GEP->getNumIndices() == 2 && "Unexpected pattern!"); - Value *Idx = GEP->getOperand(2); - Value *Vec = Vectors.back(); - auto *NewExtract = ExtractElementInst::Create(Vec, Idx); - NewExtract->insertAfter(cast(Vec)); - LD->replaceAllUsesWith(NewExtract); - } - - } else { - Replace({}, LD); - } - - return true; - }; - - // Go over the GV users, find Load and 
ExtractElement instructions and - // replace them with the corresponding function call. - for (auto *UI : GV->users()) { - // There might or might not be an addrspacecast instruction. - if (auto *ASCast = dyn_cast(UI)) { - Casts.push_back(ASCast); - for (auto *CastUser : ASCast->users()) { - if (ReplaceIfLoad(CastUser)) - continue; - if (auto *GEP = dyn_cast(CastUser)) { - GEPs.push_back(GEP); - for (auto *GEPUser : GEP->users()) { - if (!ReplaceIfLoad(GEPUser)) - llvm_unreachable("Unexpected pattern!"); - } - } else { - llvm_unreachable("Unexpected pattern!"); - } - } - } else if (auto *GEP = dyn_cast(UI)) { - GEPs.push_back(GEP); - for (auto *GEPUser : GEP->users()) { - if (!ReplaceIfLoad(GEPUser)) - llvm_unreachable("Unexpected pattern!"); - } - } else if (!ReplaceIfLoad(UI)) { - llvm_unreachable("Unexpected pattern!"); - } - } - - auto Erase = [](std::vector &ToErase) { - for (Instruction *I : ToErase) { - assert(I->hasNUses(0)); - I->eraseFromParent(); - } - }; - // Order of erasing is important. 
- Erase(Loads); - Erase(GEPs); - Erase(Casts); - + replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func); return true; } diff --git a/test/builtin-vars-gep.ll b/test/builtin-vars-gep.ll index c9ee2a1..0dc3074 100644 --- a/test/builtin-vars-gep.ll +++ b/test/builtin-vars-gep.ll @@ -14,28 +14,32 @@ target triple = "spir64" define spir_func void @foo() { entry: %GroupID = alloca [3 x i64], align 8 - %0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)* - %1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0 + %0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4) + %1 = getelementptr <3 x i64>, ptr addrspace(4) %0, i64 0, i64 0 ; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1 -; CHECK-LLVM: %[[Ins0:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize0]], i32 0 -; CHECK-LLVM: %[[GLocalSize1:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1 -; CHECK-LLVM: %[[Ins1:[0-9]+]] = insertelement <3 x i64> %[[Ins0]], i64 %[[GLocalSize1]], i32 1 + %2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4) + %3 = getelementptr <3 x i64>, ptr addrspace(4) %2, i64 0, i64 2 + %4 = load i64, ptr addrspace(4) %1, align 32 + %5 = load i64, ptr addrspace(4) %3, align 8 ; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1 -; CHECK-LLVM: %[[Ins2:[0-9]+]] = insertelement <3 x i64> %[[Ins1]], i64 %[[GLocalSize2]], i32 2 -; CHECK-LLVM: %[[Extract:[0-9]+]] = extractelement <3 x i64> %[[Ins2]], i64 0 - %2 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)* - %3 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %2, i64 0, i64 2 - %4 = load i64, i64 addrspace(4)* %1, align 32 - %5 = load i64, i64 addrspace(4)* %3, align 8 -; CHECK-LLVM: %[[GLocalSize01:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1 -; CHECK-LLVM: %[[Ins01:[0-9]+]] 
= insertelement <3 x i64> undef, i64 %[[GLocalSize01]], i32 0 -; CHECK-LLVM: %[[GLocalSize11:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1 -; CHECK-LLVM: %[[Ins11:[0-9]+]] = insertelement <3 x i64> %[[Ins01]], i64 %[[GLocalSize11]], i32 1 -; CHECK-LLVM: %[[GLocalSize21:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1 -; CHECK-LLVM: %[[Ins21:[0-9]+]] = insertelement <3 x i64> %[[Ins11]], i64 %[[GLocalSize21]], i32 2 -; CHECK-LLVM: %[[Extract1:[0-9]+]] = extractelement <3 x i64> %[[Ins21]], i64 2 -; CHECK-LLVM: mul i64 %[[Extract]], %[[Extract1]] +; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]] %mul = mul i64 %4, %5 ret void } +; Function Attrs: alwaysinline convergent nounwind mustprogress +define spir_func void @foo_i8gep() { +entry: + %GroupID = alloca [3 x i64], align 8 + %0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4) + %1 = getelementptr i8, ptr addrspace(4) %0, i64 0 +; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1 + %2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4) + %3 = getelementptr i8, ptr addrspace(4) %2, i64 16 + %4 = load i64, ptr addrspace(4) %1, align 32 + %5 = load i64, ptr addrspace(4) %3, align 8 +; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1 +; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]] + %mul = mul i64 %4, %5 + ret void +} diff --git a/test/transcoding/builtin_vars_gep.ll b/test/transcoding/builtin_vars_gep.ll index 133fcf5..2bf1924 100644 --- a/test/transcoding/builtin_vars_gep.ll +++ b/test/transcoding/builtin_vars_gep.ll @@ -23,20 +23,8 @@ define spir_kernel void @f() { entry: %0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32 ; CHECK-OCL-IR: %[[#ID1:]] = call spir_func i64 @_Z13get_global_idj(i32 0) - ; CHECK-OCL-IR: %[[#VEC1:]] = insertelement <3 
x i64> undef, i64 %[[#ID1]], i32 0 - ; CHECK-OCL-IR: %[[#ID2:]] = call spir_func i64 @_Z13get_global_idj(i32 1) - ; CHECK-OCL-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1 - ; CHECK-OCL-IR: %[[#ID3:]] = call spir_func i64 @_Z13get_global_idj(i32 2) - ; CHECK-OCL-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2 - ; CHECK-OCL-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0 ; CHECK-SPV-IR: %[[#ID1:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0) - ; CHECK-SPV-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0 - ; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1) - ; CHECK-SPV-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1 - ; CHECK-SPV-IR: %[[#ID3:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2) - ; CHECK-SPV-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2 - ; CHECK-SPV-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0 ret void } -- 2.30.2