[PATCH 73/79] [Backport to 15] Fix SPIR-V global to function replacement for differin...
authorMaksim Shelegov <maksim.shelegov@intel.com>
Wed, 29 Nov 2023 11:58:58 +0000 (03:58 -0800)
committerAndreas Beckmann <anbe@debian.org>
Thu, 8 Feb 2024 21:48:18 +0000 (22:48 +0100)
In some cases, we will see IR with the following

@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

...

%0 = load <6 x i32>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%1 = extractelement <6 x i32> %0, i64 0
Note the global type and load type are different. Change the handling of vector loads from vector globals to reconstruct the global vector type and then bitcast to the load type.

Thanks to @jcranmer-intel for helping me find the simplest solution.

Co-authored-by: Nick Sarnie <sarnex@users.noreply.github.com>
Gbp-Pq: Name 0073-Backport-to-15-Fix-SPIR-V-global-to-function-replace.patch

lib/SPIRV/SPIRVUtil.cpp
test/transcoding/builtin_vars_different_type.ll [new file with mode: 0644]

index 9a45e775e7d3e535be53b26c2e10a4ed75f8c6d4..4c5a11b8c9459b7a52a23eeaccf2452dd0d36fc1 100644 (file)
@@ -1922,19 +1922,20 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
 /// are accumulated in the AccumulatedOffset parameter, which will eventually be
 /// used to figure out which index of a variable is being used.
 static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
-                                    Function *ReplacementFunc) {
+                                    Function *ReplacementFunc,
+                                    GlobalVariable *GV) {
   const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
   SmallVector<Instruction *, 4> InstsToRemove;
   for (User *U : V->users()) {
     if (auto *Cast = dyn_cast<CastInst>(U)) {
-      replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
+      replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc, GV);
       InstsToRemove.push_back(Cast);
     } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
       APInt NewOffset = AccumulatedOffset.sextOrTrunc(
           DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
       if (!GEP->accumulateConstantOffset(DL, NewOffset))
         llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
-      replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
+      replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc, GV);
       InstsToRemove.push_back(GEP);
     } else if (auto *Load = dyn_cast<LoadInst>(U)) {
       // Figure out which index the accumulated offset corresponds to. If we
@@ -1957,7 +1958,12 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
       } else {
         // The function has an index parameter.
         if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
-          if (!Index.isZero())
+          // Reconstruct the original global variable vector because
+          // the load type may not match.
+          // global <3 x i64>, load <6 x i32>
+          VecTy = cast<FixedVectorType>(GV->getValueType());
+          if (!Index.isZero() || DL.getTypeSizeInBits(VecTy) !=
+                                     DL.getTypeSizeInBits(Load->getType()))
             llvm_unreachable("Illegal use of a SPIR-V builtin variable");
           Replacement = UndefValue::get(VecTy);
           for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
@@ -1967,6 +1973,19 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
                     Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
                 Builder.getInt32(I));
           }
+          // Insert a bitcast from the reconstructed vector to the load vector
+          // type in case they are different.
+          // Input:
+          // %1 = load <6 x i32>, ptr addrspace(1) %0, align 32
+          // %2 = extractelement <6 x i32> %1, i32 0
+          // %3 = add i32 5, %2
+          // Modified:
+          // < reconstruct global vector elements 0 and 1 >
+          // %2 = insertelement <3 x i64> %0, i64 %1, i32 2
+          // %3 = bitcast <3 x i64> %2 to <6 x i32>
+          // %4 = extractelement <6 x i32> %3, i32 0
+          // %5 = add i32 5, %4
+          Replacement = Builder.CreateBitCast(Replacement, Load->getType());
         } else if (Load->getType() == ScalarTy) {
           Replacement = setAttrByCalledFunc(Builder.CreateCall(
               ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
@@ -2020,7 +2039,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
     Func->setDoesNotAccessMemory();
   }
 
-  replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
+  replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func, GV);
   return true;
 }
 
diff --git a/test/transcoding/builtin_vars_different_type.ll b/test/transcoding/builtin_vars_different_type.ll
new file mode 100644 (file)
index 0000000..4dc838a
--- /dev/null
@@ -0,0 +1,27 @@
+; RUN: llvm-as %s -o %t.bc
+; RUN: llvm-spirv %t.bc -o %t.spv -spirv-ext=+SPV_INTEL_vector_compute
+; RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o %t.out.bc
+; RUN: llvm-dis %t.out.bc -o - | FileCheck %s --check-prefix=CHECK-SPV-IR
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir-unknown-unknown"
+
+@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
+
+; Function Attrs: nounwind readnone
+define spir_kernel void @f() {
+entry:
+  %0 = load <6 x i32>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
+  %1 = extractelement <6 x i32> %0, i64 0
+  %2 = add i32 5, %1
+  ret void
+; CHECK-SPV-IR: %[[#ID0:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0) #1
+; CHECK-SPV-IR: %[[#ID1:]] = insertelement <3 x i64> undef, i64 %[[#ID0]], i32 0
+; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1) #1
+; CHECK-SPV-IR: %[[#ID3:]] = insertelement <3 x i64> %[[#ID1]], i64 %[[#ID2]], i32 1
+; CHECK-SPV-IR: %[[#ID4:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2) #1
+; CHECK-SPV-IR: %[[#ID5:]] = insertelement <3 x i64> %[[#ID3]], i64 %[[#ID4]], i32 2
+; CHECK-SPV-IR: %[[#ID6:]] = bitcast <3 x i64> %[[#ID5]] to <6 x i32>
+; CHECK-SPV-IR: %[[#ID7:]] = extractelement <6 x i32> %[[#ID6]], i32 0
+; CHECK-SPV-IR: = add i32 5, %[[#ID7]]
+}