[PATCH 66/79] [Backport to 15] [OpaquePointers] Adjust builtin variable tracking...
authorMateusz Chudyk <mateuszchudyk@gmail.com>
Mon, 3 Jul 2023 16:40:05 +0000 (18:40 +0200)
committerAndreas Beckmann <anbe@debian.org>
Thu, 14 Mar 2024 19:01:08 +0000 (20:01 +0100)
The existing logic for the replacement of builtin variables with calls to
functions relies on relatively brittle tracking that is broken when opaque
pointers is turned on, and will be even more thoroughly broken if/when typed
geps are replaced with i8 geps or ptradd. This patch replaces that logic with a
less brittle variant that is able to handle any sequence of bitcast, gep, or
addrspacecast instructions between the global variable and the ultimate load
instruction.

It still will error out if the variable is used in too insane of a fashion (say,
trying to load an i32 out of the i64, or a misaligned vector type).

Co-authored-by: Joshua Cranmer <joshua.cranmer@intel.com>
Gbp-Pq: Name 0066-Backport-to-15-OpaquePointers-Adjust-builtin-variabl.patch

lib/SPIRV/SPIRVInternal.h
lib/SPIRV/SPIRVUtil.cpp
test/builtin-vars-gep.ll
test/transcoding/builtin_vars_gep.ll

index 9eaaf6c0fa02a1ed7a00f90e720070fb3a0b949a..30241c89b9f8f02aafa88291cb00546cd67bfc0f 100644 (file)
@@ -1056,7 +1056,7 @@ std::string decodeSPIRVTypeName(StringRef Name,
                                 SmallVectorImpl<std::string> &Strs);
 
 // Copy attributes from function to call site.
-void setAttrByCalledFunc(CallInst *Call);
+CallInst *setAttrByCalledFunc(CallInst *Call);
 bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind);
 // Transform builtin variable from GlobalVariable to builtin call.
 // e.g.
index 1ac4eb62e51b18a97f9b97891ab69ddb15786975..9a45e775e7d3e535be53b26c2e10a4ed75f8c6d4 100644 (file)
@@ -1858,14 +1858,15 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
   return true;
 }
 
-void setAttrByCalledFunc(CallInst *Call) {
+CallInst *setAttrByCalledFunc(CallInst *Call) {
   Function *F = Call->getCalledFunction();
   assert(F);
   if (F->isIntrinsic()) {
-    return;
+    return Call;
   }
   Call->setCallingConv(F->getCallingConv());
   Call->setAttributes(F->getAttributes());
+  return Call;
 }
 
 bool isSPIRVBuiltinVariable(GlobalVariable *GV,
@@ -1915,6 +1916,75 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
 // %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
 // %5 = insertelement <3 x i64> %3, i64 %4, i32 2
 // %6 = extractelement <3 x i64> %5, i32 0
+
+/// Recursively look through the uses of a global variable, including casts or
+/// gep offsets, to find all loads of the variable. Gep offsets that are non-0
+/// are accumulated in the AccumulatedOffset parameter, which will eventually be
+/// used to figure out which index of a variable is being used.
+static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
+                                    Function *ReplacementFunc) {
+  const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
+  SmallVector<Instruction *, 4> InstsToRemove;
+  for (User *U : V->users()) {
+    if (auto *Cast = dyn_cast<CastInst>(U)) {
+      replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
+      InstsToRemove.push_back(Cast);
+    } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      APInt NewOffset = AccumulatedOffset.sextOrTrunc(
+          DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
+      if (!GEP->accumulateConstantOffset(DL, NewOffset))
+        llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
+      replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
+      InstsToRemove.push_back(GEP);
+    } else if (auto *Load = dyn_cast<LoadInst>(U)) {
+      // Figure out which index the accumulated offset corresponds to. If we
+      // have a weird offset (e.g., trying to load byte 7), bail out.
+      Type *ScalarTy = ReplacementFunc->getReturnType();
+      APInt Index;
+      uint64_t Remainder;
+      APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
+                     Index, Remainder);
+      if (Remainder != 0)
+        llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
+
+      IRBuilder<> Builder(Load);
+      Value *Replacement;
+      if (ReplacementFunc->getFunctionType()->getNumParams() == 0) {
+        if (Load->getType() != ScalarTy)
+          llvm_unreachable("Illegal use of a SPIR-V builtin variable");
+        Replacement =
+            setAttrByCalledFunc(Builder.CreateCall(ReplacementFunc, {}));
+      } else {
+        // The function has an index parameter.
+        if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
+          if (!Index.isZero())
+            llvm_unreachable("Illegal use of a SPIR-V builtin variable");
+          Replacement = UndefValue::get(VecTy);
+          for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
+            Replacement = Builder.CreateInsertElement(
+                Replacement,
+                setAttrByCalledFunc(
+                    Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
+                Builder.getInt32(I));
+          }
+        } else if (Load->getType() == ScalarTy) {
+          Replacement = setAttrByCalledFunc(Builder.CreateCall(
+              ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
+        } else {
+          llvm_unreachable("Illegal load type of a SPIR-V builtin variable");
+        }
+      }
+      Load->replaceAllUsesWith(Replacement);
+      InstsToRemove.push_back(Load);
+    } else {
+      llvm_unreachable("Illegal use of a SPIR-V builtin variable");
+    }
+  }
+
+  for (Instruction *I : InstsToRemove)
+    I->eraseFromParent();
+}
+
 bool lowerBuiltinVariableToCall(GlobalVariable *GV,
                                 SPIRVBuiltinVariableKind Kind) {
   // There might be dead constant users of GV (for example, SPIRVLowerConstExpr
@@ -1950,113 +2020,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
     Func->setDoesNotAccessMemory();
   }
 
-  // Collect instructions in these containers to remove them later.
-  std::vector<Instruction *> Loads;
-  std::vector<Instruction *> Casts;
-  std::vector<Instruction *> GEPs;
-
-  auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
-    auto *Call = CallInst::Create(Func, Arg, "", I);
-    Call->takeName(I);
-    setAttrByCalledFunc(Call);
-    SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call
-                    << '\n';)
-    I->replaceAllUsesWith(Call);
-  };
-
-  // If HasIndexArg is true, we create 3 built-in calls and insertelement to
-  // get 3-element vector filled with ids and replace uses of Load instruction
-  // with this vector.
-  // If HasIndexArg is false, the result of the Load instruction is the value
-  // which should be replaced with the Func.
-  // Returns true if Load was replaced, false otherwise.
-  auto ReplaceIfLoad = [&](User *I) {
-    auto *LD = dyn_cast<LoadInst>(I);
-    if (!LD)
-      return false;
-    std::vector<Value *> Vectors;
-    Loads.push_back(LD);
-    if (HasIndexArg) {
-      auto *VecTy = cast<FixedVectorType>(GVTy);
-      Value *EmptyVec = UndefValue::get(VecTy);
-      Vectors.push_back(EmptyVec);
-      const DebugLoc &DLoc = LD->getDebugLoc();
-      for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
-        auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I);
-        auto *Call = CallInst::Create(Func, {Idx}, "", LD);
-        if (DLoc)
-          Call->setDebugLoc(DLoc);
-        setAttrByCalledFunc(Call);
-        auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx);
-        if (DLoc)
-          Insert->setDebugLoc(DLoc);
-        Insert->insertAfter(Call);
-        Vectors.push_back(Insert);
-      }
-
-      Value *Ptr = LD->getPointerOperand();
-
-      if (isa<FixedVectorType>(LD->getType())) {
-        LD->replaceAllUsesWith(Vectors.back());
-      } else {
-        auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-        assert(GEP && "Unexpected pattern!");
-        assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
-        Value *Idx = GEP->getOperand(2);
-        Value *Vec = Vectors.back();
-        auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
-        NewExtract->insertAfter(cast<Instruction>(Vec));
-        LD->replaceAllUsesWith(NewExtract);
-      }
-
-    } else {
-      Replace({}, LD);
-    }
-
-    return true;
-  };
-
-  // Go over the GV users, find Load and ExtractElement instructions and
-  // replace them with the corresponding function call.
-  for (auto *UI : GV->users()) {
-    // There might or might not be an addrspacecast instruction.
-    if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
-      Casts.push_back(ASCast);
-      for (auto *CastUser : ASCast->users()) {
-        if (ReplaceIfLoad(CastUser))
-          continue;
-        if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
-          GEPs.push_back(GEP);
-          for (auto *GEPUser : GEP->users()) {
-            if (!ReplaceIfLoad(GEPUser))
-              llvm_unreachable("Unexpected pattern!");
-          }
-        } else {
-          llvm_unreachable("Unexpected pattern!");
-        }
-      }
-    } else if (auto *GEP = dyn_cast<GetElementPtrInst>(UI)) {
-      GEPs.push_back(GEP);
-      for (auto *GEPUser : GEP->users()) {
-        if (!ReplaceIfLoad(GEPUser))
-          llvm_unreachable("Unexpected pattern!");
-      }
-    } else if (!ReplaceIfLoad(UI)) {
-      llvm_unreachable("Unexpected pattern!");
-    }
-  }
-
-  auto Erase = [](std::vector<Instruction *> &ToErase) {
-    for (Instruction *I : ToErase) {
-      assert(I->hasNUses(0));
-      I->eraseFromParent();
-    }
-  };
-  // Order of erasing is important.
-  Erase(Loads);
-  Erase(GEPs);
-  Erase(Casts);
-
+  replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
   return true;
 }
 
index c9ee2a105a9ee33734e2f8269c7f776e4a7792c4..0dc307471197bca424b0817f1a28597b43be115c 100644 (file)
@@ -14,28 +14,32 @@ target triple = "spir64"
 define spir_func void @foo() {
 entry:
   %GroupID = alloca [3 x i64], align 8
-  %0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
-  %1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
+  %0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+  %1 = getelementptr <3 x i64>, ptr addrspace(4) %0, i64 0, i64 0
 ; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
-; CHECK-LLVM: %[[Ins0:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize0]], i32 0
-; CHECK-LLVM: %[[GLocalSize1:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
-; CHECK-LLVM: %[[Ins1:[0-9]+]] = insertelement <3 x i64> %[[Ins0]], i64 %[[GLocalSize1]], i32 1
+  %2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+  %3 = getelementptr <3 x i64>, ptr addrspace(4) %2, i64 0, i64 2
+  %4 = load i64, ptr addrspace(4) %1, align 32
+  %5 = load i64, ptr addrspace(4) %3, align 8
 ; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
-; CHECK-LLVM: %[[Ins2:[0-9]+]] = insertelement <3 x i64> %[[Ins1]], i64 %[[GLocalSize2]], i32 2
-; CHECK-LLVM: %[[Extract:[0-9]+]] = extractelement <3 x i64> %[[Ins2]], i64 0
-  %2 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
-  %3 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %2, i64 0, i64 2
-  %4 = load i64, i64 addrspace(4)* %1, align 32
-  %5 = load i64, i64 addrspace(4)* %3, align 8
-; CHECK-LLVM: %[[GLocalSize01:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
-; CHECK-LLVM: %[[Ins01:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize01]], i32 0
-; CHECK-LLVM: %[[GLocalSize11:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
-; CHECK-LLVM: %[[Ins11:[0-9]+]] = insertelement <3 x i64> %[[Ins01]], i64 %[[GLocalSize11]], i32 1
-; CHECK-LLVM: %[[GLocalSize21:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
-; CHECK-LLVM: %[[Ins21:[0-9]+]] = insertelement <3 x i64> %[[Ins11]], i64 %[[GLocalSize21]], i32 2
-; CHECK-LLVM: %[[Extract1:[0-9]+]] = extractelement <3 x i64> %[[Ins21]], i64 2
-; CHECK-LLVM:  mul i64 %[[Extract]], %[[Extract1]]
+; CHECK-LLVM:  mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
   %mul = mul i64 %4, %5
   ret void
 }
 
+; Function Attrs: alwaysinline convergent nounwind mustprogress
+define spir_func void @foo_i8gep() {
+entry:
+  %GroupID = alloca [3 x i64], align 8
+  %0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+  %1 = getelementptr i8, ptr addrspace(4) %0, i64 0
+; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
+  %2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
+  %3 = getelementptr i8, ptr addrspace(4) %2, i64 16
+  %4 = load i64, ptr addrspace(4) %1, align 32
+  %5 = load i64, ptr addrspace(4) %3, align 8
+; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
+; CHECK-LLVM:  mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
+  %mul = mul i64 %4, %5
+  ret void
+}
index 133fcf561770a7a0250f66460024642978735f45..2bf192410b3cc304e20ff5fd644878babe9a8138 100644 (file)
@@ -23,20 +23,8 @@ define spir_kernel void @f() {
 entry:
   %0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32
   ; CHECK-OCL-IR: %[[#ID1:]] = call spir_func i64 @_Z13get_global_idj(i32 0)
-  ; CHECK-OCL-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
-  ; CHECK-OCL-IR: %[[#ID2:]] = call spir_func i64 @_Z13get_global_idj(i32 1)
-  ; CHECK-OCL-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
-  ; CHECK-OCL-IR: %[[#ID3:]] = call spir_func i64 @_Z13get_global_idj(i32 2)
-  ; CHECK-OCL-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
-  ; CHECK-OCL-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0
 
   ; CHECK-SPV-IR: %[[#ID1:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0)
-  ; CHECK-SPV-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
-  ; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1)
-  ; CHECK-SPV-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
-  ; CHECK-SPV-IR: %[[#ID3:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2)
-  ; CHECK-SPV-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
-  ; CHECK-SPV-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0
 
   ret void
 }