[PATCH 08/79] [Backport to 15] Implement SPV_INTEL_tensor_float32_conversion extensio...
authorStanley Gambarin <stanley.gambarin@intel.com>
Fri, 11 Nov 2022 10:31:34 +0000 (02:31 -0800)
committerAndreas Beckmann <anbe@debian.org>
Thu, 8 Feb 2024 21:48:18 +0000 (22:48 +0100)
This extension adds conversion instruction from float to tensor float (TF32)
data format. TF32 uses 1 bit for a sign, 8 bits for an exponent and 10 bits
for a fraction. This extension doesn’t introduce TF32 type in SPIR-V, instead
instruction below uses 32-bit float type to represent TF32 value.

Spec: https://github.com/intel/llvm/pull/6990

Co-authored-by: Dmitry Sidorov <dmitry.sidorov@intel.com>
Gbp-Pq: Name 0008-Backport-to-15-Implement-SPV_INTEL_tensor_float32_co.patch

include/LLVMSPIRVExtensions.inc
lib/SPIRV/libSPIRV/SPIRVInstruction.h
lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
lib/SPIRV/libSPIRV/spirv_internal.hpp
test/transcoding/SPV_INTEL_tensor_float32_conversion/convert_tensor_float32.ll [new file with mode: 0644]

index 4f970d2a370128f0aeffd1eb95b11e4ab54fb551..edad4ab40a3a394d4fb483fc12f04f741298c98b 100644 (file)
@@ -54,4 +54,5 @@ EXT(SPV_INTEL_global_variable_decorations)
 EXT(SPV_INTEL_non_constant_addrspace_printf)
 EXT(SPV_INTEL_complex_float_mul_div)
 EXT(SPV_INTEL_split_barrier)
-EXT(SPV_INTEL_masked_gather_scatter)
+EXT(SPV_INTEL_tensor_float32_conversion)
+EXT(SPV_INTEL_masked_gather_scatter)
\ No newline at end of file
index 2ffd4c3aa65260f1d12ebc3aeb2ef3e3c9c30319..cd91461c28a0c39e9591169adc7b6c723cf64aee 100644 (file)
@@ -3414,6 +3414,64 @@ _SPIRV_OP(ComplexFMulINTEL)
 _SPIRV_OP(ComplexFDivINTEL)
 #undef _SPIRV_OP
 
+template <Op OC>
+class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
+protected:
+  SPIRVCapVec getRequiredCapability() const override {
+    return getVec(internal::CapabilityTensorFloat32ConversionINTEL);
+  }
+
+  llvm::Optional<ExtensionID> getRequiredExtension() const override {
+    return ExtensionID::SPV_INTEL_tensor_float32_conversion;
+  }
+
+  void validate() const override {
+    SPIRVUnaryInst<OC>::validate();
+
+    SPIRVType *ResCompTy = this->getType();
+    SPIRVWord ResCompCount = 1;
+    if (ResCompTy->isTypeVector()) {
+      ResCompCount = ResCompTy->getVectorComponentCount();
+      ResCompTy = ResCompTy->getVectorComponentType();
+    }
+
+    // validate is a const method, whilst getOperand is non-const method
+    // because it may call a method of class Module that may modify LiteralMap
+    // of Module field. That modification is not impacting validate method for
+    // these instructions, so const_cast is safe here.
+    using SPVTF32ConvTy = SPIRVTensorFloat32ConversionINTELInstBase<OC>;
+    SPIRVValue *Input = const_cast<SPVTF32ConvTy *>(this)->getOperand(0);
+
+    SPIRVType *InCompTy = Input->getType();
+    SPIRVWord InCompCount = 1;
+    if (InCompTy->isTypeVector()) {
+      InCompCount = InCompTy->getVectorComponentCount();
+      InCompTy = InCompTy->getVectorComponentType();
+    }
+
+    auto InstName = OpCodeNameMap::map(OC);
+    SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
+
+    SPVErrLog.checkError(
+        ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
+        InstName + "\nResult value must be a scalar or vector of floating-point"
+                   " 32-bit type\n");
+    SPVErrLog.checkError(InCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
+                         InstName +
+                             "\nInput value must be a scalar or vector of "
+                             "floating-point 32-bit type\n");
+    SPVErrLog.checkError(
+        ResCompCount == InCompCount, SPIRVEC_InvalidInstruction,
+        InstName + "\nInput type must have the same number of components as "
+                   "result type\n");
+  }
+};
+
+#define _SPIRV_OP(x)                                                           \
+  typedef SPIRVTensorFloat32ConversionINTELInstBase<internal::Op##x> SPIRV##x;
+_SPIRV_OP(ConvertFToTF32INTEL)
+#undef _SPIRV_OP
+
 class SPIRVMaskedGatherScatterINTELInstBase : public SPIRVInstTemplateBase {
 protected:
   SPIRVCapVec getRequiredCapability() const override {
index fd8c5790e4b55081f0d7314035d7693103090884..d73fdee7b68636ad2b7d467b48293a6df71e78fb 100644 (file)
@@ -617,6 +617,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
   add(internal::CapabilityNonConstantAddrspacePrintfINTEL,
       "NonConstantAddrspacePrintfINTEL");
   add(internal::CapabilityComplexFloatMulDivINTEL, "ComplexFloatMulDivINTEL");
+  add(internal::CapabilityTensorFloat32ConversionINTEL,
+      "TensorFloat32ConversionINTEL");
   add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
 }
 SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
index 0ed0d855d5e61f01761c179df5a1eef3759f44fa..b5e4cefdf0ff066329a2d219059dba0170b10a6e 100644 (file)
@@ -13,5 +13,6 @@ _SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
                    internal::OpJointMatrixWorkItemLengthINTEL)
 _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
 _SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
+_SPIRV_OP_INTERNAL(ConvertFToTF32INTEL, internal::ConvertFToTF32INTEL)
 _SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
-_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL)
+_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL)
\ No newline at end of file
index 9073eccd74b2073fac73b39d342db66ab0280994..3938ac4cc2eebc2ff09e872acdf636a96b730939 100644 (file)
@@ -46,6 +46,7 @@ enum InternalOp {
   IOpJointMatrixWorkItemLengthINTEL = 6410,
   IOpComplexFMulINTEL = 6415,
   IOpComplexFDivINTEL = 6416,
+  IOpConvertFToTF32INTEL = 6426,
   IOpMaskedGatherINTEL = 6428,
   IOpMaskedScatterINTEL = 6429,
   IOpPrev = OpMax - 2,
@@ -79,6 +80,7 @@ enum InternalCapability {
   ICapGlobalVariableDecorationsINTEL = 6146,
   ICapabilityNonConstantAddrspacePrintfINTEL = 6411,
   ICapabilityComplexFloatMulDivINTEL = 6414,
+  ICapabilityTensorFloat32ConversionINTEL = 6425,
   ICapabilityMaskedGatherScatterINTEL = 6427
 };
 
@@ -125,6 +127,9 @@ _SPIRV_OP(Capability, ComplexFloatMulDivINTEL)
 _SPIRV_OP(Op, ComplexFMulINTEL)
 _SPIRV_OP(Op, ComplexFDivINTEL)
 
+_SPIRV_OP(Capability, TensorFloat32ConversionINTEL)
+_SPIRV_OP(Op, ConvertFToTF32INTEL)
+
 _SPIRV_OP(Capability, MaskedGatherScatterINTEL)
 _SPIRV_OP(Op, MaskedGatherINTEL)
 _SPIRV_OP(Op, MaskedScatterINTEL)
diff --git a/test/transcoding/SPV_INTEL_tensor_float32_conversion/convert_tensor_float32.ll b/test/transcoding/SPV_INTEL_tensor_float32_conversion/convert_tensor_float32.ll
new file mode 100644 (file)
index 0000000..1f02706
--- /dev/null
@@ -0,0 +1,50 @@
+; RUN: llvm-as %s -o %t.bc
+; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_tensor_float32_conversion
+; RUN: llvm-spirv %t.spv -o %t.spt --to-text
+; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
+; RUN: llvm-spirv %t.spv -o %t.rev.bc -r -emit-opaque-pointers --spirv-target-env=SPV-IR
+; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
+; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
+
+; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
+; CHECK-ERROR-NEXT: SPV_INTEL_tensor_float32_conversion
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+; CHECK-SPIRV: Capability TensorFloat32ConversionINTEL
+; CHECK-SPIRV: Extension "SPV_INTEL_tensor_float32_conversion"
+; CHECK-SPIRV: TypeFloat [[#FP32Ty:]] 32
+; CHECK-SPIRV: TypeVector [[#FP32v8Ty:]] [[#FP32Ty]] 8
+; CHECK-SPIRV: Constant [[#FP32Ty]] [[#CONST:]] 1065353216
+
+; CHECK-SPIRV: FunctionParameter [[#FP32Ty]] [[FP32ValId:.*]]
+; CHECK-SPIRV: FunctionParameter [[#FP32v8Ty]] [[FP32v8ValId:.*]]
+
+; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[FP32ValId]]
+; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]]
+; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[#CONST]]
+
+; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float
+; CHECK-LLVM: call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>
+; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
+
+define spir_func void @_Z2opffv8(float %a, <8 x float> %in) {
+  %1 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float %a)
+  %2 = tail call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float> %in)
+  %3 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
+  ret void
+}
+
+declare spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float)
+
+declare spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>)
+
+!opencl.spir.version = !{!0}
+!spirv.Source = !{!1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, i32 2}
+!1 = !{i32 4, i32 100000}
+!2 = !{!"clang version 16.0.0"}