[PATCH 01/79] Integer dot product 4x8 packed translation (#1654)
authorJakub Czarnecki <jakub.czarnecki@intel.com>
Mon, 17 Oct 2022 15:24:28 +0000 (17:24 +0200)
committerAndreas Beckmann <anbe@debian.org>
Thu, 14 Mar 2024 19:01:08 +0000 (20:01 +0100)
Changed the integer dot translation to use the correct function names
(i.e. dot_4x8packed or dot_acc_sat_4x8packed) to translate them into
proper OpCodes. Additionally removed unused variables from visitCallDot

Gbp-Pq: Name 0001-Integer-dot-product-4x8-packed-translation-1654.patch

lib/SPIRV/OCLToSPIRV.cpp
lib/SPIRV/OCLUtil.h
test/transcoding/SPV_KHR_integer_dot_product_OCLtoSPIRV_int.ll

index 95e52cc8387f6494fe5af40233fe9361bd4b626d..7db6d37f9a886c79c68cfc6e2fcd97378d8e6a61 100644 (file)
@@ -328,7 +328,9 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) {
     return;
   }
   if (DemangledName == kOCLBuiltinName::Dot ||
-      DemangledName == kOCLBuiltinName::DotAccSat) {
+      DemangledName == kOCLBuiltinName::DotAccSat ||
+      DemangledName.startswith(kOCLBuiltinName::Dot4x8PackedPrefix) ||
+      DemangledName.startswith(kOCLBuiltinName::DotAccSat4x8PackedPrefix)) {
     if (CI.getOperand(0)->getType()->isVectorTy()) {
       auto *VT = (VectorType *)(CI.getOperand(0)->getType());
       if (!isa<llvm::IntegerType>(VT->getElementType())) {
@@ -1323,19 +1325,11 @@ void OCLToSPIRVBase::visitCallDot(CallInst *CI, StringRef MangledName,
   // translation for dot function calls,
   // to differentiate between integer dot products
 
-  SmallVector<Value *, 3> Args;
-  Args.push_back(CI->getOperand(0));
-  Args.push_back(CI->getOperand(1));
   bool IsFirstSigned, IsSecondSigned;
   bool IsDot = DemangledName == kOCLBuiltinName::Dot;
-  std::string FunName = (IsDot) ? "DotKHR" : "DotAccSatKHR";
-  if (CI->arg_size() > 2) {
-    Args.push_back(CI->getOperand(2));
-  }
-  if (CI->arg_size() > 3) {
-    Args.push_back(CI->getOperand(3));
-  }
-  if (CI->getOperand(0)->getType()->isVectorTy()) {
+  bool IsAccSat = DemangledName.contains(kOCLBuiltinName::DotAccSat);
+  bool IsPacked = CI->getOperand(0)->getType()->isIntegerTy();
+  if (!IsPacked) {
     if (IsDot) {
       // dot(char4, char4) _Z3dotDv4_cS_
       // dot(char4, uchar4) _Z3dotDv4_cDv4_h
@@ -1376,21 +1370,28 @@ void OCLToSPIRVBase::visitCallDot(CallInst *CI, StringRef MangledName,
     }
   } else {
     // for packed format
-    // dot(int, int, int) _Z3dotiii
-    // dot(int, uint, int) _Z3dotiji
-    // dot(uint, int, int) _Z3dotjii
-    // dot(uint, uint, int) _Z3dotjji
+    // dot_4x8packed_ss_int(uint, uint) _Z20dot_4x8packed_ss_intjj
+    // dot_4x8packed_su_int(uint, uint) _Z20dot_4x8packed_su_intjj
+    // dot_4x8packed_us_int(uint, uint) _Z20dot_4x8packed_us_intjj
+    // dot_4x8packed_uu_uint(uint, uint) _Z21dot_4x8packed_uu_uintjj
     // or
-    // dot_acc_sat(int, int, int, int) _Z11dot_acc_satiiii
-    // dot_acc_sat(int, uint, int, int) _Z11dot_acc_satijii
-    // dot_acc_sat(uint, int, int, int) _Z11dot_acc_satjiii
-    // dot_acc_sat(uint, uint, int, int) _Z11dot_acc_satjjii
-    assert(MangledName.startswith("_Z3dot") ||
-           MangledName.startswith("_Z11dot_acc_sat"));
-    IsFirstSigned = (IsDot) ? (MangledName[MangledName.size() - 3] == 'i')
-                            : (MangledName[MangledName.size() - 4] == 'i');
-    IsSecondSigned = (IsDot) ? (MangledName[MangledName.size() - 2] == 'i')
-                             : (MangledName[MangledName.size() - 3] == 'i');
+    // dot_acc_sat_4x8packed_ss_int(uint, uint, int)
+    // _Z28dot_acc_sat_4x8packed_ss_intjji
+    // dot_acc_sat_4x8packed_su_int(uint, uint, int)
+    // _Z28dot_acc_sat_4x8packed_su_intjji
+    // dot_acc_sat_4x8packed_us_int(uint, uint, int)
+    // _Z28dot_acc_sat_4x8packed_us_intjji
+    // dot_acc_sat_4x8packed_uu_uint(uint, uint, uint)
+    // _Z29dot_acc_sat_4x8packed_uu_uintjjj
+    assert(MangledName.startswith("_Z20dot_4x8packed") ||
+           MangledName.startswith("_Z21dot_4x8packed") ||
+           MangledName.startswith("_Z28dot_acc_sat_4x8packed") ||
+           MangledName.startswith("_Z29dot_acc_sat_4x8packed"));
+    size_t SignIndex = IsAccSat
+                           ? strlen(kOCLBuiltinName::DotAccSat4x8PackedPrefix)
+                           : strlen(kOCLBuiltinName::Dot4x8PackedPrefix);
+    IsFirstSigned = DemangledName[SignIndex] == 's';
+    IsSecondSigned = DemangledName[SignIndex + 1] == 's';
   }
   AttributeList Attrs = CI->getCalledFunction()->getAttributes();
   mutateCallInstSPIRV(
@@ -1403,7 +1404,7 @@ void OCLToSPIRVBase::visitCallDot(CallInst *CI, StringRef MangledName,
           std::swap(Args[0], Args[1]);
         }
         Op OC;
-        if (IsDot) {
+        if (!IsAccSat) {
           OC = (IsFirstSigned != IsSecondSigned
                     ? OpSUDot
                     : ((IsFirstSigned) ? OpSDot : OpUDot));
@@ -1412,6 +1413,14 @@ void OCLToSPIRVBase::visitCallDot(CallInst *CI, StringRef MangledName,
                     ? OpSUDotAccSat
                     : ((IsFirstSigned) ? OpSDotAccSat : OpUDotAccSat));
         }
+        if (IsPacked) {
+          // As per SPIRV specification the dot OpCodes
+          // which use scalar integers to represent
+          // packed vectors need additional argument
+          // specified - the Packed Vector Format
+          Args.push_back(
+              getInt32(M, PackedVectorFormatPackedVectorFormat4x8BitKHR));
+        }
         return getSPIRVFuncName(OC);
       },
       &Attrs);
index 6497ecb6ed1f1bdb0cd94af02d81c01f39073330..1c0d8605432d61a9bb3895c03895d3f6ada0db91 100644 (file)
@@ -234,6 +234,8 @@ const static char Clamp[] = "clamp";
 const static char ConvertPrefix[] = "convert_";
 const static char Dot[] = "dot";
 const static char DotAccSat[] = "dot_acc_sat";
+const static char Dot4x8PackedPrefix[] = "dot_4x8packed_";
+const static char DotAccSat4x8PackedPrefix[] = "dot_acc_sat_4x8packed_";
 const static char EnqueueKernel[] = "enqueue_kernel";
 const static char FixedSqrtINTEL[] = "intel_arbitrary_fixed_sqrt";
 const static char FixedRecipINTEL[] = "intel_arbitrary_fixed_recip";
index 7a7b5cab2daa3ba9f9a7b2723335672d4faecfa8..bb4fb700146b7eeb13ec37a27ce03fc8577c1d41 100644 (file)
@@ -31,40 +31,40 @@ target triple = "spir"
 ; Function Attrs: convergent norecurse nounwind
 define spir_kernel void @test1(i32 %ia, i32 %ua, i32 %ib, i32 %ub, i32 %ires, i32 %ures) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !5 !kernel_arg_type_qual !6 {
 entry:
-  %call = tail call spir_func i32 @_Z3dotiii(i32 %ia, i32 %ib, i32 0) #2
-  %call1 = tail call spir_func i32 @_Z3dotiji(i32 %ia, i32 %ub, i32 0) #2
-  %call2 = tail call spir_func i32 @_Z3dotjii(i32 %ua, i32 %ib, i32 0) #2
-  %call3 = tail call spir_func i32 @_Z3dotjji(i32 %ua, i32 %ub, i32 0) #2
-  %call4 = tail call spir_func i32 @_Z11dot_acc_satiiii(i32 %ia, i32 %ib, i32 %ires, i32 0) #2
-  %call5 = tail call spir_func i32 @_Z11dot_acc_satijii(i32 %ia, i32 %ub, i32 %ires, i32 0) #2
-  %call6 = tail call spir_func i32 @_Z11dot_acc_satjiii(i32 %ua, i32 %ib, i32 %ires, i32 0) #2
-  %call7 = tail call spir_func i32 @_Z11dot_acc_satjjji(i32 %ua, i32 %ub, i32 %ures, i32 0) #2
+  %call = tail call spir_func i32 @_Z20dot_4x8packed_ss_intjj(i32 %ia, i32 %ib) #2
+  %call1 = tail call spir_func i32 @_Z20dot_4x8packed_su_intjj(i32 %ia, i32 %ub) #2
+  %call2 = tail call spir_func i32 @_Z20dot_4x8packed_us_intjj(i32 %ua, i32 %ib) #2
+  %call3 = tail call spir_func i32 @_Z21dot_4x8packed_uu_uintjj(i32 %ua, i32 %ub) #2
+  %call4 = tail call spir_func i32 @_Z28dot_acc_sat_4x8packed_ss_intjji(i32 %ia, i32 %ib, i32 %ires) #2
+  %call5 = tail call spir_func i32 @_Z28dot_acc_sat_4x8packed_su_intjji(i32 %ia, i32 %ub, i32 %ires) #2
+  %call6 = tail call spir_func i32 @_Z28dot_acc_sat_4x8packed_us_intjji(i32 %ua, i32 %ib, i32 %ires) #2
+  %call7 = tail call spir_func i32 @_Z29dot_acc_sat_4x8packed_uu_uintjjj(i32 %ua, i32 %ub, i32 %ures) #2
   ret void
 }
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z3dotiii(i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z20dot_4x8packed_ss_intjj(i32, i32) local_unnamed_addr #1
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z3dotiji(i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z20dot_4x8packed_su_intjj(i32, i32) local_unnamed_addr #1
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z3dotjii(i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z20dot_4x8packed_us_intjj(i32, i32) local_unnamed_addr #1
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z3dotjji(i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z21dot_4x8packed_uu_uintjj(i32, i32) local_unnamed_addr #1
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z11dot_acc_satiiii(i32, i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z28dot_acc_sat_4x8packed_ss_intjji(i32, i32, i32) local_unnamed_addr #1
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z11dot_acc_satijii(i32, i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z28dot_acc_sat_4x8packed_su_intjji(i32, i32, i32) local_unnamed_addr #1
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z11dot_acc_satjiii(i32, i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z28dot_acc_sat_4x8packed_us_intjji(i32, i32, i32) local_unnamed_addr #1
 
 ; Function Attrs: convergent
-declare spir_func i32 @_Z11dot_acc_satjjji(i32, i32, i32, i32) local_unnamed_addr #1
+declare spir_func i32 @_Z29dot_acc_sat_4x8packed_uu_uintjjj(i32, i32, i32) local_unnamed_addr #1
 
 attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
 attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }