From 61f1097c8897bd1c9556dda18175826834bb4094 Mon Sep 17 00:00:00 2001 From: Stanley Gambarin Date: Fri, 11 Nov 2022 02:31:34 -0800 Subject: [PATCH] [PATCH 08/79] [Backport to 15] Implement SPV_INTEL_tensor_float32_conversion extension (#1656) (#1700) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit This extension adds conversion instruction from float to tensor float (TF32) data format. TF32 uses 1 bit for a sign, 8 bits for an exponent and 10 bits for a fraction. This extension doesn’t introduce TF32 type in SPIR-V, instead instruction below uses 32-bit float type to represent TF32 value. Spec: https://github.com/intel/llvm/pull/6990 Co-authored-by: Dmitry Sidorov Gbp-Pq: Name 0008-Backport-to-15-Implement-SPV_INTEL_tensor_float32_co.patch --- include/LLVMSPIRVExtensions.inc | 3 +- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 58 +++++++++++++++++++ lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h | 2 + lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h | 3 +- lib/SPIRV/libSPIRV/spirv_internal.hpp | 5 ++ .../convert_tensor_float32.ll | 50 ++++++++++++++++ 6 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 test/transcoding/SPV_INTEL_tensor_float32_conversion/convert_tensor_float32.ll diff --git a/include/LLVMSPIRVExtensions.inc b/include/LLVMSPIRVExtensions.inc index 4f970d2..edad4ab 100644 --- a/include/LLVMSPIRVExtensions.inc +++ b/include/LLVMSPIRVExtensions.inc @@ -54,4 +54,5 @@ EXT(SPV_INTEL_global_variable_decorations) EXT(SPV_INTEL_non_constant_addrspace_printf) EXT(SPV_INTEL_complex_float_mul_div) EXT(SPV_INTEL_split_barrier) -EXT(SPV_INTEL_masked_gather_scatter) +EXT(SPV_INTEL_tensor_float32_conversion) +EXT(SPV_INTEL_masked_gather_scatter) \ No newline at end of file diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 2ffd4c3..cd91461 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3414,6 +3414,64 @@ _SPIRV_OP(ComplexFMulINTEL) _SPIRV_OP(ComplexFDivINTEL) #undef _SPIRV_OP +template +class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst { +protected: + SPIRVCapVec getRequiredCapability() const override { + return getVec(internal::CapabilityTensorFloat32ConversionINTEL); + } + + llvm::Optional getRequiredExtension() const override { + return ExtensionID::SPV_INTEL_tensor_float32_conversion; + } + + void validate() const override { + SPIRVUnaryInst::validate(); + + SPIRVType *ResCompTy = this->getType(); + SPIRVWord ResCompCount = 1; + if (ResCompTy->isTypeVector()) { + ResCompCount = ResCompTy->getVectorComponentCount(); + ResCompTy = ResCompTy->getVectorComponentType(); + } + + // validate is a const method, whilst getOperand is non-const method + // because it may call a method of class Module that may modify LiteralMap + // of Module field. That modification is not impacting validate method for + // these instructions, so const_cast is safe here. + using SPVTF32ConvTy = SPIRVTensorFloat32ConversionINTELInstBase; + SPIRVValue *Input = const_cast(this)->getOperand(0); + + SPIRVType *InCompTy = Input->getType(); + SPIRVWord InCompCount = 1; + if (InCompTy->isTypeVector()) { + InCompCount = InCompTy->getVectorComponentCount(); + InCompTy = InCompTy->getVectorComponentType(); + } + + auto InstName = OpCodeNameMap::map(OC); + SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog(); + + SPVErrLog.checkError( + ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction, + InstName + "\nResult value must be a scalar or vector of floating-point" + " 32-bit type\n"); + SPVErrLog.checkError(InCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction, + InstName + + "\nInput value must be a scalar or vector of " + "floating-point 32-bit type\n"); + SPVErrLog.checkError( + ResCompCount == InCompCount, SPIRVEC_InvalidInstruction, + InstName + "\nInput type must have the same number of components as " + "result type\n"); + } +}; + +#define _SPIRV_OP(x) \ + typedef SPIRVTensorFloat32ConversionINTELInstBase SPIRV##x; +_SPIRV_OP(ConvertFToTF32INTEL) +#undef _SPIRV_OP + class SPIRVMaskedGatherScatterINTELInstBase : public SPIRVInstTemplateBase { protected: SPIRVCapVec getRequiredCapability() const override { diff --git a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index fd8c579..d73fdee 100644 --- a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -617,6 +617,8 @@ template <> inline void SPIRVMap::init() { add(internal::CapabilityNonConstantAddrspacePrintfINTEL, "NonConstantAddrspacePrintfINTEL"); add(internal::CapabilityComplexFloatMulDivINTEL, "ComplexFloatMulDivINTEL"); + add(internal::CapabilityTensorFloat32ConversionINTEL, + "TensorFloat32ConversionINTEL"); add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL"); } SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap) diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h index 0ed0d85..b5e4cef 100644 --- a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h +++ b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h @@ -13,5 +13,6 @@ _SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL, internal::OpJointMatrixWorkItemLengthINTEL) _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL) _SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL) +_SPIRV_OP_INTERNAL(ConvertFToTF32INTEL, internal::ConvertFToTF32INTEL) _SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL) -_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL) +_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL) \ No newline at end of file diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index 9073ecc..3938ac4 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -46,6 +46,7 @@ enum InternalOp { IOpJointMatrixWorkItemLengthINTEL = 6410, IOpComplexFMulINTEL = 6415, IOpComplexFDivINTEL = 6416, + IOpConvertFToTF32INTEL = 6426, IOpMaskedGatherINTEL = 6428, IOpMaskedScatterINTEL = 6429, IOpPrev = OpMax - 2, @@ -79,6 +80,7 @@ enum InternalCapability { ICapGlobalVariableDecorationsINTEL = 6146, ICapabilityNonConstantAddrspacePrintfINTEL = 6411, ICapabilityComplexFloatMulDivINTEL = 6414, + ICapabilityTensorFloat32ConversionINTEL = 6425, ICapabilityMaskedGatherScatterINTEL = 6427 }; @@ -125,6 +127,9 @@ _SPIRV_OP(Capability, ComplexFloatMulDivINTEL) _SPIRV_OP(Op, ComplexFMulINTEL) _SPIRV_OP(Op, ComplexFDivINTEL) +_SPIRV_OP(Capability, TensorFloat32ConversionINTEL) +_SPIRV_OP(Op, ConvertFToTF32INTEL) + _SPIRV_OP(Capability, MaskedGatherScatterINTEL) _SPIRV_OP(Op, MaskedGatherINTEL) _SPIRV_OP(Op, MaskedScatterINTEL) diff --git a/test/transcoding/SPV_INTEL_tensor_float32_conversion/convert_tensor_float32.ll b/test/transcoding/SPV_INTEL_tensor_float32_conversion/convert_tensor_float32.ll new file mode 100644 index 0000000..1f02706 --- /dev/null +++ b/test/transcoding/SPV_INTEL_tensor_float32_conversion/convert_tensor_float32.ll @@ -0,0 +1,50 @@ +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_tensor_float32_conversion +; RUN: llvm-spirv %t.spv -o %t.spt --to-text +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV +; RUN: llvm-spirv %t.spv -o %t.rev.bc -r -emit-opaque-pointers --spirv-target-env=SPV-IR +; RUN: llvm-dis %t.rev.bc -o %t.rev.ll +; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM + +; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension: +; CHECK-ERROR-NEXT: SPV_INTEL_tensor_float32_conversion + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +; CHECK-SPIRV: Capability TensorFloat32ConversionINTEL +; CHECK-SPIRV: Extension "SPV_INTEL_tensor_float32_conversion" +; CHECK-SPIRV: TypeFloat [[#FP32Ty:]] 32 +; CHECK-SPIRV: TypeVector [[#FP32v8Ty:]] [[#FP32Ty]] 8 +; CHECK-SPIRV: Constant [[#FP32Ty]] [[#CONST:]] 1065353216 + +; CHECK-SPIRV: FunctionParameter [[#FP32Ty]] [[FP32ValId:.*]] +; CHECK-SPIRV: FunctionParameter [[#FP32v8Ty]] [[FP32v8ValId:.*]] + +; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[FP32ValId]] +; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]] +; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[#CONST]] + +; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float +; CHECK-LLVM: call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float> +; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00) + +define spir_func void @_Z2opffv8(float %a, <8 x float> %in) { + %1 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float %a) + %2 = tail call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float> %in) + %3 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00) + ret void +} + +declare spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float) + +declare spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>) + +!opencl.spir.version = !{!0} +!spirv.Source = !{!1} +!llvm.ident = !{!2} + +!0 = !{i32 1, i32 2} +!1 = !{i32 4, i32 100000} +!2 = !{!"clang version 16.0.0"} -- 2.30.2