[PATCH 49/79] [Backport to 15] Split JointMatrixMadINTEL instruction into 4 (#1833)
authorDmitry Sidorov <dmitry.sidorov@intel.com>
Wed, 15 Feb 2023 18:49:45 +0000 (19:49 +0100)
committerAndreas Beckmann <anbe@debian.org>
Thu, 14 Mar 2024 19:01:08 +0000 (20:01 +0100)
JointMatrixMadINTEL will stand for signed/signed Matrix type
JointMatrixSUMadINTEL will stand for signed/signed Matrix type
JointMatrixUSMadINTEL will stand for unsigned/signed Matrix type
JointMatrixUUMadINTEL will stand for unsigned/unsigned Matrix type

Spec update:
intel/llvm#8175

Signed-off-by: Dmitry Sidorov dmitry.sidorov@intel.com
Gbp-Pq: Name 0049-Backport-to-15-Split-JointMatrixMadINTEL-instruction.patch

lib/SPIRV/SPIRVWriter.cpp
lib/SPIRV/libSPIRV/SPIRVInstruction.h
lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
lib/SPIRV/libSPIRV/spirv_internal.hpp
test/transcoding/SPV_INTEL_joint_matrix/joint_matrix.ll

index c923642b95d22dbc36f0879bbb0ddd811be07340..cf1482e2691e5c094768db261a7067019836456d 100644 (file)
@@ -606,14 +606,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
   return TranslatedTy;
 }
 
-// Representation in LLVM IR before the translator is a pointer array wrapped
-// in a structure:
-// %struct.__spirv_JointMatrixINTEL = type { [R x [C x [L x [S x type]]]]* }
-// where R = Rows, C = Columnts, L = Layout + 1, S = Scope + 1
-// this '+1' for the Layout and Scope is required because both of them can
-// be '0', but array size can not be '0'.
-// The result should look like SPIR-V friendly LLVM IR:
-// %spirv.JointMatrixINTEL._char_2_2_0_3
+// Representation in LLVM IR before the translator is a pointer to an opaque
+// structure:
+// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
 // Here we check the structure name yet again. Another option would be to
 // check SPIR-V friendly function calls (by their name) and obtain return
 // or their parameter types, assuming, that the appropriate types are Matrix
index 8ee9aaf09d596c3617f3347a7b391ecfbfaf3c3a..8127717e7fa82156cb3f39a7c6b317c34b80b5fb 100644 (file)
@@ -3330,6 +3330,9 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
 _SPIRV_OP(JointMatrixLoad, true, 6, true)
 _SPIRV_OP(JointMatrixStore, false, 5, true)
 _SPIRV_OP(JointMatrixMad, true, 7)
+_SPIRV_OP(JointMatrixSUMad, true, 7)
+_SPIRV_OP(JointMatrixUSMad, true, 7)
+_SPIRV_OP(JointMatrixUUMad, true, 7)
 _SPIRV_OP(JointMatrixWorkItemLength, true, 4)
 #undef _SPIRV_OP
 
index ea888d8aad0680f5ef75c4884ddead130a1162dd..2682f86937f35c344d01726f730f19d92def7c6a 100644 (file)
@@ -9,6 +9,9 @@ _SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
 _SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
 _SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
 _SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)
+_SPIRV_OP_INTERNAL(JointMatrixSUMadINTEL, internal::OpJointMatrixSUMadINTEL)
+_SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL)
+_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL)
 _SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
                    internal::OpJointMatrixWorkItemLengthINTEL)
 _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
index e03a3e25ac647c0dd6b5f15a546d33ba84faa7b9..c4524196bebcb586d48acfd766a1039769ec4bce 100644 (file)
@@ -65,6 +65,9 @@ enum InternalOp {
   IOpJointMatrixLoadINTEL = 6120,
   IOpJointMatrixStoreINTEL = 6121,
   IOpJointMatrixMadINTEL = 6122,
+  IOpJointMatrixSUMadINTEL = 6128,
+  IOpJointMatrixUSMadINTEL = 6129,
+  IOpJointMatrixUUMadINTEL = 6130,
   IOpArithmeticFenceINTEL = 6145,
   IOpJointMatrixWorkItemLengthINTEL = 6410,
   IOpComplexFMulINTEL = 6415,
@@ -138,6 +141,9 @@ _SPIRV_OP(Op, TypeJointMatrixINTEL)
 _SPIRV_OP(Op, JointMatrixLoadINTEL)
 _SPIRV_OP(Op, JointMatrixStoreINTEL)
 _SPIRV_OP(Op, JointMatrixMadINTEL)
+_SPIRV_OP(Op, JointMatrixSUMadINTEL)
+_SPIRV_OP(Op, JointMatrixUSMadINTEL)
+_SPIRV_OP(Op, JointMatrixUUMadINTEL)
 _SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL)
 _SPIRV_OP(Capability, HWThreadQueryINTEL)
 _SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
index 89c3e21729df15f1740f9cb397073c562f6af504..9088d7394179be88126ea0c5f7838992a82cfba5 100644 (file)
 ; CHECK-SPIRV: JointMatrixLoadINTEL [[#ATy]] [[#A:]] [[#Aptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
 ; CHECK-SPIRV: JointMatrixLoadINTEL [[#BTy]] [[#B:]] [[#Bptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
 ; CHECK-SPIRV: JointMatrixMadINTEL [[#CTy]] [[#CMad]] [[#A]] [[#B]] [[#C]] [[#Three]]
+; CHECK-SPIRV: JointMatrixSUMadINTEL [[#CTy]] [[#UnusedMad1:]] [[#A]] [[#B]] [[#C]] [[#Three]]
+; CHECK-SPIRV: JointMatrixUSMadINTEL [[#CTy]] [[#UnusedMad2:]] [[#A]] [[#B]] [[#C]] [[#Three]]
+; CHECK-SPIRV: JointMatrixUUMadINTEL [[#CTy]] [[#UnusedMad3:]] [[#A]] [[#B]] [[#C]] [[#Three]]
+
 ; CHECK-SPIRV: JointMatrixStoreINTEL [[#Cptr:]] [[#C]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
 ; CHECK-SPIRV: CompositeConstruct [[#CTy]] [[#Cnew:]] [[#FortyTwo]]
 ; CHECK-SPIRV: Store [[#PtrToZero:]] [[#Zero]]
 ; CHECK-LLVM: [[C:%.*]] = phi %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [ [[CLoaded]], %entry ], [ [[CMad:%.*]], %for.body.i ]
 ; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* @_Z79__spirv_JointMatrixLoadINTEL_RPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
 ; CHECK-LLVM: [[B:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS4cliii(i8 addrspace(4)* [[BPtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
-; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
+; CHECK-LLVM: [[CMad1:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
+; CHECK-LLVM: [[CMad2:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixSUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
+; CHECK-LLVM: [[CMad3:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUSMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
+; CHECK-LLVM: [[CMad4:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
+
 ; CHECK-LLVM: call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS4sPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3liii(i16 addrspace(4)* [[CPtr]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i64 [[Stride]], i32 0, i32 3, i32 0)
 ; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42)
 ; CHECK-LLVM: store i32 0, i32 addrspace(4)* [[StoredZero:%.*]], align 4
@@ -115,6 +123,9 @@ for.body.i:                                       ; preds = %for.cond.i
   %add.ptr17.i = addrspacecast i8 addrspace(1)* %add.ptr17.i56 to i8 addrspace(4)*
   %call18.i = tail call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* @_Z28__spirv_JointMatrixLoadINTELIaLm16ELm2ELN5__spv12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEEPS5_mS1_S3_i(i8 addrspace(4)* %add.ptr17.i, i64 %_arg_1, i32 0, i32 3, i32 0) #3
   %call19.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
+  %call20.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
+  %call21.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
+  %call22.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
   %add.i = add nuw nsw i32 %k.0.i, 16
   br label %for.cond.i, !llvm.loop !19
 
@@ -142,6 +153,15 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*
 ; Function Attrs: convergent
 declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
 
+; Function Attrs: convergent
+declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
+
+; Function Attrs: convergent
+declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
+
+; Function Attrs: convergent
+declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
+
 ; Function Attrs: convergent
 declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIsLm2ELm2ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(i16 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1