[PATCH 06/79] [Backport to 15] Add SPV_INTEL_masked_gather_scatter extension (#1580...
authorStanley Gambarin <stanley.gambarin@intel.com>
Tue, 8 Nov 2022 16:04:36 +0000 (08:04 -0800)
committerAndreas Beckmann <anbe@debian.org>
Thu, 14 Mar 2024 19:01:08 +0000 (20:01 +0100)
This extension allows TypeVector to have a Physical Pointer Type
Component Type and introduces gather/scatter instructions.
It will be useful for explicitly vectorized kernels.

Spec: https://github.com/intel/llvm/pull/6613

Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
Co-authored-by: Dmitry Sidorov <dmitry.sidorov@intel.com>
Gbp-Pq: Name 0006-Backport-to-15-Add-SPV_INTEL_masked_gather_scatter-e.patch

13 files changed:
include/LLVMSPIRVExtensions.inc
lib/SPIRV/SPIRVReader.cpp
lib/SPIRV/SPIRVRegularizeLLVM.cpp
lib/SPIRV/SPIRVWriter.cpp
lib/SPIRV/libSPIRV/SPIRVInstruction.h
lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
lib/SPIRV/libSPIRV/SPIRVType.cpp
lib/SPIRV/libSPIRV/SPIRVType.h
lib/SPIRV/libSPIRV/spirv_internal.hpp
test/transcoding/SPV_INTEL_function_pointers/vector_elem.ll
test/transcoding/SPV_INTEL_masked_gather_scatter/intel-basic-vector-pointers.ll [new file with mode: 0644]
test/transcoding/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll [new file with mode: 0644]

index 3d0eae8ce0e3fdb9284b4076d69ea01cf47db129..4f970d2a370128f0aeffd1eb95b11e4ab54fb551 100644 (file)
@@ -54,3 +54,4 @@ EXT(SPV_INTEL_global_variable_decorations)
 EXT(SPV_INTEL_non_constant_addrspace_printf)
 EXT(SPV_INTEL_complex_float_mul_div)
 EXT(SPV_INTEL_split_barrier)
+EXT(SPV_INTEL_masked_gather_scatter)
index cbfc74998dd04ba710389e90e236874bea202d71..500bebb9893ec1e6b7daef831c58629678dc6d02 100644 (file)
@@ -2177,7 +2177,12 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
   case OpInBoundsPtrAccessChain: {
     auto AC = static_cast<SPIRVAccessChainBase *>(BV);
     auto Base = transValue(AC->getBase(), F, BB);
-    Type *BaseTy = transType(AC->getBase()->getType()->getPointerElementType());
+    SPIRVType *BaseSPVTy = AC->getBase()->getType();
+    Type *BaseTy =
+        BaseSPVTy->isTypeVector()
+            ? transType(
+                  BaseSPVTy->getVectorComponentType()->getPointerElementType())
+            : transType(BaseSPVTy->getPointerElementType());
     auto Index = transValue(AC->getIndices(), F, BB);
     if (!AC->hasPtrIndex())
       Index.insert(Index.begin(), getInt32(M, 0));
@@ -2590,6 +2595,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
     return mapValue(
         BV, Builder.CreateIntrinsic(Intrinsic::arithmetic_fence, RetTy, Val));
   }
+  case internal::OpMaskedGatherINTEL: {
+    IRBuilder<> Builder(BB);
+    auto *Inst = static_cast<SPIRVMaskedGatherINTELInst *>(BV);
+    Type *RetTy = transType(Inst->getType());
+    Value *PtrVector = transValue(Inst->getOperand(0), F, BB);
+    uint32_t Alignment = Inst->getOpWord(1);
+    Value *Mask = transValue(Inst->getOperand(2), F, BB);
+    Value *FillEmpty = transValue(Inst->getOperand(3), F, BB);
+    return mapValue(BV, Builder.CreateMaskedGather(RetTy, PtrVector,
+                                                   Align(Alignment), Mask,
+                                                   FillEmpty));
+  }
+
+  case internal::OpMaskedScatterINTEL: {
+    IRBuilder<> Builder(BB);
+    auto *Inst = static_cast<SPIRVMaskedScatterINTELInst *>(BV);
+    Value *InputVector = transValue(Inst->getOperand(0), F, BB);
+    Value *PtrVector = transValue(Inst->getOperand(1), F, BB);
+    uint32_t Alignment = Inst->getOpWord(2);
+    Value *Mask = transValue(Inst->getOperand(3), F, BB);
+    return mapValue(BV, Builder.CreateMaskedScatter(InputVector, PtrVector,
+                                                    Align(Alignment), Mask));
+  }
 
   default: {
     auto OC = BV->getOpCode();
@@ -3197,7 +3225,11 @@ std::string getSPIRVFuncSuffix(SPIRVInstruction *BI) {
   }
   if (BI->getOpCode() == OpGenericCastToPtrExplicit) {
     Suffix += kSPIRVPostfix::Divider;
-    auto GenericCastToPtrInst = BI->getType()->getPointerStorageClass();
+    auto *Ty = BI->getType();
+    auto GenericCastToPtrInst =
+        Ty->isTypeVectorPointer()
+            ? Ty->getVectorComponentType()->getPointerStorageClass()
+            : Ty->getPointerStorageClass();
     switch (GenericCastToPtrInst) {
     case StorageClassCrossWorkgroup:
       Suffix += std::string(kSPIRVPostfix::ToGlobal);
index d0e3ae1ec289f9483e428abc6ef9d429fe95c683..e1e2f70763b6f2644cfde67d24503ae3e6d7a5a6 100644 (file)
@@ -539,11 +539,19 @@ bool SPIRVRegularizeLLVMBase::regularize() {
         // Add an additional bitcast in case address space cast also changes
         // pointer element type.
         if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(&II)) {
-          PointerType *DestTy = cast<PointerType>(ASCast->getDestTy());
-          PointerType *SrcTy = cast<PointerType>(ASCast->getSrcTy());
-          if (!DestTy->hasSameElementTypeAs(SrcTy)) {
-            PointerType *InterTy = PointerType::getWithSamePointeeType(
-                DestTy, SrcTy->getPointerAddressSpace());
+          Type *DestTy = ASCast->getDestTy();
+          Type *SrcTy = ASCast->getSrcTy();
+          if (!II.getContext().supportsTypedPointers())
+            continue;
+          if (DestTy->getScalarType()->getNonOpaquePointerElementType() !=
+              SrcTy->getScalarType()->getNonOpaquePointerElementType()) {
+            Type *InterTy = PointerType::getWithSamePointeeType(
+                cast<PointerType>(DestTy->getScalarType()),
+                cast<PointerType>(SrcTy->getScalarType())
+                    ->getPointerAddressSpace());
+            if (DestTy->isVectorTy())
+              InterTy = VectorType::get(
+                  InterTy, cast<VectorType>(DestTy)->getElementCount());
             BitCastInst *NewBCast = new BitCastInst(
                 ASCast->getPointerOperand(), InterTy, /*NameStr=*/"", ASCast);
             AddrSpaceCastInst *NewASCast =
index fd890a9334e0e4c96f9aa69c70e61bf2a2b9c285..3593375a2fcecc2681bcaaf3d8705c8f18bf4abe 100644 (file)
@@ -342,9 +342,28 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
     return transPointerType(ET, AddrSpc);
   }
 
-  if (auto *VecTy = dyn_cast<FixedVectorType>(T))
+  if (auto *VecTy = dyn_cast<FixedVectorType>(T)) {
+    if (VecTy->getElementType()->isPointerTy()) {
+      // SPV_INTEL_masked_gather_scatter extension changes 2.16.1. Universal
+      // Validation Rules:
+      // Vector types must be parameterized only with numerical types,
+      // [Physical Pointer Type] types or the [OpTypeBool] type.
+      // Without it vector of pointers is not allowed in SPIR-V.
+      if (!BM->isAllowedToUseExtension(
+              ExtensionID::SPV_INTEL_masked_gather_scatter)) {
+        BM->getErrorLog().checkError(
+            false, SPIRVEC_RequiresExtension,
+            "SPV_INTEL_masked_gather_scatter\n"
+            "NOTE: LLVM module contains vector of pointers, translation "
+            "of which requires this extension");
+        return nullptr;
+      }
+      BM->addExtension(ExtensionID::SPV_INTEL_masked_gather_scatter);
+      BM->addCapability(internal::CapabilityMaskedGatherScatterINTEL);
+    }
     return mapType(T, BM->addVectorType(transType(VecTy->getElementType()),
                                         VecTy->getNumElements()));
+  }
 
   if (T->isArrayTy()) {
     // SPIR-V 1.3 s3.32.6: Length is the number of elements in the array.
@@ -3824,6 +3843,47 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
     }
     return Op;
   }
+  case Intrinsic::masked_gather: {
+    if (!BM->isAllowedToUseExtension(
+            ExtensionID::SPV_INTEL_masked_gather_scatter)) {
+      BM->getErrorLog().checkError(
+          BM->isUnknownIntrinsicAllowed(II), SPIRVEC_InvalidFunctionCall, II,
+          "Translation of llvm.masked.gather intrinsic requires "
+          "SPV_INTEL_masked_gather_scatter extension or "
+          "-spirv-allow-unknown-intrinsics option.");
+      return nullptr;
+    }
+    SPIRVType *Ty = transType(II->getType());
+    auto *PtrVector = transValue(II->getArgOperand(0), BB);
+    uint32_t Alignment =
+        cast<ConstantInt>(II->getArgOperand(1))->getZExtValue();
+    auto *Mask = transValue(II->getArgOperand(2), BB);
+    auto *FillEmpty = transValue(II->getArgOperand(3), BB);
+    std::vector<SPIRVWord> Ops = {PtrVector->getId(), Alignment, Mask->getId(),
+                                  FillEmpty->getId()};
+    return BM->addInstTemplate(internal::OpMaskedGatherINTEL, Ops, BB, Ty);
+  }
+  case Intrinsic::masked_scatter: {
+    if (!BM->isAllowedToUseExtension(
+            ExtensionID::SPV_INTEL_masked_gather_scatter)) {
+      BM->getErrorLog().checkError(
+          BM->isUnknownIntrinsicAllowed(II), SPIRVEC_InvalidFunctionCall, II,
+          "Translation of llvm.masked.scatter intrinsic requires "
+          "SPV_INTEL_masked_gather_scatter extension or "
+          "-spirv-allow-unknown-intrinsics option.");
+      return nullptr;
+    }
+    auto *InputVector = transValue(II->getArgOperand(0), BB);
+    auto *PtrVector = transValue(II->getArgOperand(1), BB);
+    uint32_t Alignment =
+        cast<ConstantInt>(II->getArgOperand(2))->getZExtValue();
+    auto *Mask = transValue(II->getArgOperand(3), BB);
+    std::vector<SPIRVWord> Ops = {InputVector->getId(), PtrVector->getId(),
+                                  Alignment, Mask->getId()};
+    return BM->addInstTemplate(internal::OpMaskedScatterINTEL, Ops, BB,
+                               nullptr);
+  }
+
   default:
     if (BM->isUnknownIntrinsicAllowed(II))
       return BM->addCallInst(
index 5c83b1ecba577e20aa0bc569ca4b4265169e2d9f..2ffd4c3aa65260f1d12ebc3aeb2ef3e3c9c30319 100644 (file)
@@ -3413,6 +3413,154 @@ class SPIRVComplexFloatInst
 _SPIRV_OP(ComplexFMulINTEL)
 _SPIRV_OP(ComplexFDivINTEL)
 #undef _SPIRV_OP
+
+class SPIRVMaskedGatherScatterINTELInstBase : public SPIRVInstTemplateBase {
+protected:
+  SPIRVCapVec getRequiredCapability() const override {
+    return getVec(internal::CapabilityMaskedGatherScatterINTEL);
+  }
+  llvm::Optional<ExtensionID> getRequiredExtension() const override {
+    return ExtensionID::SPV_INTEL_masked_gather_scatter;
+  }
+};
+
+class SPIRVMaskedGatherINTELInst
+    : public SPIRVMaskedGatherScatterINTELInstBase {
+  void validate() const override {
+    SPIRVInstruction::validate();
+    SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
+    std::string InstName = "MaskedGatherINTEL";
+    // Operands: 0 PtrVector, 1 Alignment (literal), 2 Mask, 3 FillEmpty.
+    SPIRVType *ResTy = this->getType();
+    SPVErrLog.checkError(ResTy->isTypeVector(), SPIRVEC_InvalidInstruction,
+                         InstName + "\nResult must be a vector type\n");
+    SPIRVWord ResCompCount = ResTy->getVectorComponentCount();
+    SPIRVType *ResCompTy = ResTy->getVectorComponentType();
+
+    SPIRVValue *PtrVec =
+        const_cast<SPIRVMaskedGatherINTELInst *>(this)->getOperand(0);
+    SPIRVType *PtrVecTy = PtrVec->getType();
+    SPVErrLog.checkError(
+        PtrVecTy->isTypeVectorPointer(), SPIRVEC_InvalidInstruction,
+        InstName + "\nPtrVector must be a vector of pointers type\n");
+    SPIRVWord PtrVecCompCount = PtrVecTy->getVectorComponentCount();
+    SPIRVType *PtrVecCompTy = PtrVecTy->getVectorComponentType();
+    SPIRVType *PtrElemTy = PtrVecCompTy->getPointerElementType();
+
+    SPVErrLog.checkError(
+        this->isOperandLiteral(1), SPIRVEC_InvalidInstruction,
+        InstName + "\nAlignment must be a constant expression integer\n");
+    // Alignment is encoded as a literal word, not an <id>, so read the raw
+    // word at index 1. Index 2 holds the Mask, which is not a scalar
+    // constant and must not be interpreted as the alignment.
+    const uint32_t Align = this->getOpWord(1);
+    SPVErrLog.checkError(
+        ((Align & (Align - 1)) == 0), SPIRVEC_InvalidInstruction,
+        InstName + "\nAlignment must be 0 or power-of-two integer\n");
+
+    SPIRVValue *Mask =
+        const_cast<SPIRVMaskedGatherINTELInst *>(this)->getOperand(2);
+    SPIRVType *MaskTy = Mask->getType();
+    SPVErrLog.checkError(MaskTy->isTypeVector(), SPIRVEC_InvalidInstruction,
+                         InstName + "\nMask must be a vector type\n");
+    SPIRVType *MaskCompTy = MaskTy->getVectorComponentType();
+    SPVErrLog.checkError(MaskCompTy->isTypeBool(), SPIRVEC_InvalidInstruction,
+                         InstName + "\nMask must be a boolean vector type\n");
+    SPIRVWord MaskCompCount = MaskTy->getVectorComponentCount();
+
+    SPIRVValue *FillEmpty =
+        const_cast<SPIRVMaskedGatherINTELInst *>(this)->getOperand(3);
+    SPIRVType *FillEmptyTy = FillEmpty->getType();
+    SPVErrLog.checkError(FillEmptyTy->isTypeVector(),
+                         SPIRVEC_InvalidInstruction,
+                         InstName + "\nFillEmpty must be a vector type\n");
+    SPIRVWord FillEmptyCompCount = FillEmptyTy->getVectorComponentCount();
+    SPIRVType *FillEmptyCompTy = FillEmptyTy->getVectorComponentType();
+
+    SPVErrLog.checkError(
+        ResCompCount == PtrVecCompCount &&
+            PtrVecCompCount == FillEmptyCompCount &&
+            FillEmptyCompCount == MaskCompCount,
+        SPIRVEC_InvalidInstruction,
+        InstName + "\nResult, PtrVector, Mask and FillEmpty vectors must have "
+                   "the same size\n");
+
+    SPVErrLog.checkError(
+        ResCompTy == PtrElemTy && PtrElemTy == FillEmptyCompTy,
+        SPIRVEC_InvalidInstruction,
+        InstName + "\nComponent Type of Result and FillEmpty vector must be "
+                   "same as base type of PtrVector the same base type\n");
+  }
+};
+
+class SPIRVMaskedScatterINTELInst
+    : public SPIRVMaskedGatherScatterINTELInstBase {
+  void validate() const override {
+    SPIRVInstruction::validate();
+    SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
+    std::string InstName = "MaskedScatterINTEL";
+    // Operands: 0 InputVector, 1 PtrVector, 2 Alignment (literal), 3 Mask.
+    SPIRVValue *InputVec =
+        const_cast<SPIRVMaskedScatterINTELInst *>(this)->getOperand(0);
+    SPIRVType *InputVecTy = InputVec->getType();
+    SPVErrLog.checkError(
+        InputVecTy->isTypeVector(), SPIRVEC_InvalidInstruction,
+        InstName + "\nInputVector must be a vector type\n");
+    SPIRVWord InputVecCompCount = InputVecTy->getVectorComponentCount();
+    SPIRVType *InputVecCompTy = InputVecTy->getVectorComponentType();
+
+    SPIRVValue *PtrVec =
+        const_cast<SPIRVMaskedScatterINTELInst *>(this)->getOperand(1);
+    SPIRVType *PtrVecTy = PtrVec->getType();
+    SPVErrLog.checkError(
+        PtrVecTy->isTypeVectorPointer(), SPIRVEC_InvalidInstruction,
+        InstName + "\nPtrVector must be a vector of pointers type\n");
+    SPIRVWord PtrVecCompCount = PtrVecTy->getVectorComponentCount();
+    SPIRVType *PtrVecCompTy = PtrVecTy->getVectorComponentType();
+    SPIRVType *PtrElemTy = PtrVecCompTy->getPointerElementType();
+
+    SPVErrLog.checkError(
+        this->isOperandLiteral(2), SPIRVEC_InvalidInstruction,
+        InstName + "\nAlignment must be a constant expression integer\n");
+    // Alignment is encoded as a literal word, not an <id>, so read the raw
+    // word at index 2 directly instead of casting a resolved operand to
+    // SPIRVConstant.
+    const uint32_t Align = this->getOpWord(2);
+    SPVErrLog.checkError(
+        ((Align & (Align - 1)) == 0), SPIRVEC_InvalidInstruction,
+        InstName + "\nAlignment must be 0 or power-of-two integer\n");
+
+    SPIRVValue *Mask =
+        const_cast<SPIRVMaskedScatterINTELInst *>(this)->getOperand(3);
+    SPIRVType *MaskTy = Mask->getType();
+    SPVErrLog.checkError(MaskTy->isTypeVector(), SPIRVEC_InvalidInstruction,
+                         InstName + "\nMask must be a vector type\n");
+    SPIRVType *MaskCompTy = MaskTy->getVectorComponentType();
+    SPVErrLog.checkError(MaskCompTy->isTypeBool(), SPIRVEC_InvalidInstruction,
+                         InstName + "\nMask must be a boolean vector type\n");
+    SPIRVWord MaskCompCount = MaskTy->getVectorComponentCount();
+
+    SPVErrLog.checkError(
+        InputVecCompCount == PtrVecCompCount &&
+            PtrVecCompCount == MaskCompCount,
+        SPIRVEC_InvalidInstruction,
+        InstName + "\nInputVector, PtrVector and Mask vectors must have "
+                   "the same size\n");
+
+    SPVErrLog.checkError(
+        InputVecCompTy == PtrElemTy, SPIRVEC_InvalidInstruction,
+        InstName + "\nComponent Type of InputVector must be "
+                   "same as base type of PtrVector the same base type\n");
+  }
+};
+
+#define _SPIRV_OP(x, ...)                                                      \
+  typedef SPIRVInstTemplate<SPIRVMaskedGatherScatterINTELInstBase,             \
+                            internal::Op##x##INTEL, __VA_ARGS__>               \
+      SPIRV##x##INTEL;
+_SPIRV_OP(MaskedGather, true, 7)
+_SPIRV_OP(MaskedScatter, false, 5)
+#undef _SPIRV_OP
 } // namespace SPIRV
 
 #endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H
index 88071dd7fd2f7f08d24288c49217f27dd349c2a0..fd8c5790e4b55081f0d7314035d7693103090884 100644 (file)
@@ -617,6 +617,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
   add(internal::CapabilityNonConstantAddrspacePrintfINTEL,
       "NonConstantAddrspacePrintfINTEL");
   add(internal::CapabilityComplexFloatMulDivINTEL, "ComplexFloatMulDivINTEL");
+  add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
 }
 SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
 
index 22d3aabe8d6e63dc80ebda0c18f171919f5c2b85..0ed0d855d5e61f01761c179df5a1eef3759f44fa 100644 (file)
@@ -13,3 +13,5 @@ _SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
                    internal::OpJointMatrixWorkItemLengthINTEL)
 _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
 _SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
+_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
+_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL)
index add2d4597fa4a4bfec1423e2ff3a982fe1448e49..0314b54d1ac3dd7b852e83288b1a6648d0137978 100644 (file)
@@ -218,6 +218,10 @@ bool SPIRVType::isTypeVectorOrScalarBool() const {
   return isTypeBool() || isTypeVectorBool();
 }
 
+bool SPIRVType::isTypeVectorPointer() const {
+  return isTypeVector() && getVectorComponentType()->isTypePointer();
+}
+
 bool SPIRVType::isTypeSubgroupAvcINTEL() const {
   return isSubgroupAvcINTELTypeOpCode(OpCode);
 }
index 312aeb2fe27dbe87fd1245fba2827f4e22648b4c..af1a86d9a7b3edaec86bead9272e7f328c3d9089 100644 (file)
@@ -102,6 +102,7 @@ public:
   bool isTypeVectorOrScalarInt() const;
   bool isTypeVectorOrScalarFloat() const;
   bool isTypeVectorOrScalarBool() const;
+  bool isTypeVectorPointer() const;
   bool isTypeSubgroupAvcINTEL() const;
   bool isTypeSubgroupAvcMceINTEL() const;
 };
index dc41a04f3f6575e48f75a50e123fda9c32c5709b..9073eccd74b2073fac73b39d342db66ab0280994 100644 (file)
@@ -46,6 +46,8 @@ enum InternalOp {
   IOpJointMatrixWorkItemLengthINTEL = 6410,
   IOpComplexFMulINTEL = 6415,
   IOpComplexFDivINTEL = 6416,
+  IOpMaskedGatherINTEL = 6428,
+  IOpMaskedScatterINTEL = 6429,
   IOpPrev = OpMax - 2,
   IOpForward
 };
@@ -76,7 +78,8 @@ enum InternalCapability {
   ICapFPArithmeticFenceINTEL = 6144,
   ICapGlobalVariableDecorationsINTEL = 6146,
   ICapabilityNonConstantAddrspacePrintfINTEL = 6411,
-  ICapabilityComplexFloatMulDivINTEL = 6414
+  ICapabilityComplexFloatMulDivINTEL = 6414,
+  ICapabilityMaskedGatherScatterINTEL = 6427
 };
 
 enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 };
@@ -121,6 +124,10 @@ _SPIRV_OP(Capability, NonConstantAddrspacePrintfINTEL)
 _SPIRV_OP(Capability, ComplexFloatMulDivINTEL)
 _SPIRV_OP(Op, ComplexFMulINTEL)
 _SPIRV_OP(Op, ComplexFDivINTEL)
+
+_SPIRV_OP(Capability, MaskedGatherScatterINTEL)
+_SPIRV_OP(Op, MaskedGatherINTEL)
+_SPIRV_OP(Op, MaskedScatterINTEL)
 #undef _SPIRV_OP
 
 constexpr Op OpForward = static_cast<Op>(IOpForward);
index ab832f9d5d3da58cca4a3083853c80749c34a322..421a3466676c95e8d0f677093aa0ab2f932a4a93 100644 (file)
@@ -1,4 +1,4 @@
-; RUN: llvm-as < %s | llvm-spirv -spirv-text --spirv-ext=+SPV_INTEL_function_pointers | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: llvm-as < %s | llvm-spirv -spirv-text --spirv-ext=+SPV_INTEL_function_pointers,+SPV_INTEL_masked_gather_scatter | FileCheck %s --check-prefix=CHECK-SPIRV
 
 ; CHECK-SPIRV-DAG: 6 Name [[F1:[0-9+]]] "_Z2f1u2CMvb32_j"
 ; CHECK-SPIRV-DAG: 6 Name [[F2:[0-9+]]] "_Z2f2u2CMvb32_j"
diff --git a/test/transcoding/SPV_INTEL_masked_gather_scatter/intel-basic-vector-pointers.ll b/test/transcoding/SPV_INTEL_masked_gather_scatter/intel-basic-vector-pointers.ll
new file mode 100644 (file)
index 0000000..0d2d93c
--- /dev/null
@@ -0,0 +1,70 @@
+; RUN: llvm-as %s -o %t.bc
+; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_masked_gather_scatter -o %t.spv
+; RUN: llvm-spirv %t.spv --to-text -o %t.spt
+; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
+
+; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
+; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
+
+; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
+; CHECK-ERROR-NEXT: SPV_INTEL_masked_gather_scatter
+; CHECK-ERROR-NEXT: NOTE: LLVM module contains vector of pointers, translation of which requires this extension
+
+
+; CHECK-SPIRV-DAG: Capability MaskedGatherScatterINTEL
+; CHECK-SPIRV-DAG: Extension "SPV_INTEL_masked_gather_scatter"
+
+; CHECK-SPIRV-DAG: TypeInt [[#INTTYPE1:]] 32 0
+; CHECK-SPIRV-DAG: TypeInt [[#INTTYPE2:]] 8 0
+; CHECK-SPIRV-DAG: TypePointer [[#PTRTYPE1:]] 5 [[#INTTYPE1]]
+; CHECK-SPIRV-DAG: TypeVector [[#VECTYPE1:]] [[#PTRTYPE1]] 4
+; CHECK-SPIRV-DAG: TypePointer [[#PTRTYPE2:]] 8 [[#INTTYPE2]]
+; CHECK-SPIRV-DAG: TypeVector [[#VECTYPE2:]] [[#PTRTYPE2]] 4
+; CHECK-SPIRV-DAG: TypePointer [[#PTRTOVECTYPE:]] 7 [[#VECTYPE2]]
+; CHECK-SPIRV-DAG: TypePointer [[#PTRTYPE3:]] 8 [[#INTTYPE1]]
+; CHECK-SPIRV-DAG: TypeVector [[#VECTYPE3:]] [[#PTRTYPE3]] 4
+
+; CHECK-SPIRV: Variable [[#PTRTOVECTYPE]]
+; CHECK-SPIRV: Variable [[#PTRTOVECTYPE]]
+; CHECK-SPIRV: Load [[#VECTYPE2]]
+; CHECK-SPIRV: Store
+; CHECK-SPIRV: Bitcast [[#VECTYPE3]]
+; CHECK-SPIRV: GenericCastToPtr [[#VECTYPE1]]
+; CHECK-SPIRV: FunctionCall [[#VECTYPE1]]
+; CHECK-SPIRV: InBoundsPtrAccessChain [[#PTRTYPE1]]
+
+; CHECK-LLVM: alloca <4 x i8 addrspace(4)*>
+; CHECK-LLVM-NEXT: alloca <4 x i8 addrspace(4)*>
+; CHECK-LLVM-NEXT: load <4 x i8 addrspace(4)*>, <4 x i8 addrspace(4)*>*
+; CHECK-LLVM-NEXT: store <4 x i8 addrspace(4)*> %[[#]], <4 x i8 addrspace(4)*>*
+; CHECK-LLVM-NEXT: bitcast <4 x i8 addrspace(4)*> %[[#]] to <4 x i32 addrspace(4)*>
+; CHECK-LLVM-NEXT: addrspacecast <4 x i32 addrspace(4)*> %{{.*}} to <4 x i32 addrspace(1)*>
+; CHECK-LLVM-NEXT: call spir_func <4 x i32 addrspace(1)*> @boo(<4 x i32 addrspace(1)*>
+; CHECK-LLVM-NEXT: getelementptr inbounds i32, <4 x i32 addrspace(1)*> %{{.*}}, i32 1
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir"
+
+; Function Attrs: nounwind readnone
+define spir_kernel void @foo() {
+entry:
+  %arg1 = alloca <4 x i8 addrspace(4)*>
+  %arg2 = alloca <4 x i8 addrspace(4)*>
+  %0 = load <4 x i8 addrspace(4)*>, <4 x i8 addrspace(4)*>* %arg1
+  store <4 x i8 addrspace(4)*> %0, <4 x i8 addrspace(4)*>* %arg2
+  %tmp1 = bitcast <4 x i8 addrspace(4)*> %0 to <4 x i32 addrspace(4)*>
+  %tmp2 = addrspacecast <4 x i32 addrspace(4)*> %tmp1 to  <4 x i32 addrspace(1)*>
+  %tmp3 = call <4 x i32 addrspace(1)*> @boo(<4 x i32 addrspace(1)*> %tmp2)
+  %tmp4 = getelementptr inbounds i32, <4 x i32 addrspace(1)*> %tmp3, i32 1
+  %tmp5 = addrspacecast <4 x i32 addrspace(4)*> %tmp1 to <4 x i8 addrspace(1)*>
+  ret void
+}
+
+declare <4 x i32 addrspace(1)*> @boo(<4 x i32 addrspace(1)*> %a)
+
+!llvm.module.flags = !{!0}
+!opencl.spir.version = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 2}
diff --git a/test/transcoding/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll b/test/transcoding/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll
new file mode 100644 (file)
index 0000000..78a9fab
--- /dev/null
@@ -0,0 +1,67 @@
+; RUN: llvm-as %s -o %t.bc
+; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_masked_gather_scatter -o %t.spv
+; RUN: llvm-spirv %t.spv --to-text -o %t.spt
+; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
+
+; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
+; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
+
+; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
+; CHECK-ERROR-NEXT: SPV_INTEL_masked_gather_scatter
+; CHECK-ERROR-NEXT: NOTE: LLVM module contains vector of pointers, translation of which requires this extension
+
+; CHECK-SPIRV-DAG: Capability MaskedGatherScatterINTEL
+; CHECK-SPIRV-DAG: Extension "SPV_INTEL_masked_gather_scatter"
+
+; CHECK-SPIRV-DAG: TypeInt [[#TYPEINT:]] 32 0
+; CHECK-SPIRV-DAG: TypePointer [[#TYPEPTRINT:]] [[#]] [[#TYPEINT]]
+; CHECK-SPIRV-DAG: TypeVector [[#TYPEVECPTR:]] [[#TYPEPTRINT]] 4
+; CHECK-SPIRV-DAG: TypeVector [[#TYPEVECINT:]] [[#TYPEINT]] 4
+
+; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#CONST4:]] 4
+; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#CONST0:]] 0
+; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#CONST1:]] 1
+; CHECK-SPIRV-DAG: ConstantTrue [[#]] [[#TRUE:]]
+; CHECK-SPIRV-DAG: ConstantFalse [[#]] [[#FALSE:]]
+; CHECK-SPIRV-DAG: ConstantComposite [[#]] [[#MASK1:]] [[#TRUE]] [[#FALSE]] [[#TRUE]] [[#TRUE]]
+; CHECK-SPIRV-DAG: ConstantComposite [[#]] [[#FILL:]] [[#CONST4]] [[#CONST0]] [[#CONST1]] [[#CONST0]]
+; CHECK-SPIRV-DAG: ConstantComposite [[#]] [[#MASK2:]] [[#TRUE]] [[#TRUE]] [[#TRUE]] [[#TRUE]]
+
+; CHECK-SPIRV: Load [[#TYPEVECPTR]] [[#VECGATHER:]]
+; CHECK-SPIRV: Load [[#TYPEVECPTR]] [[#VECSCATTER:]]
+; CHECK-SPIRV: MaskedGatherINTEL [[#TYPEVECINT]] [[#GATHER:]] [[#VECGATHER]] 4 [[#MASK1]] [[#FILL]]
+; CHECK-SPIRV: MaskedScatterINTEL [[#GATHER]] [[#VECSCATTER]] 4 [[#MASK2]]
+
+; CHECK-LLVM: %[[#VECGATHER:]] = load <4 x i32 addrspace(4)*>, <4 x i32 addrspace(4)*>*
+; CHECK-LLVM: %[[#VECSCATTER:]] = load <4 x i32 addrspace(4)*>, <4 x i32 addrspace(4)*>*
+; CHECK-LLVM: %[[GATHER:[a-z0-9]+]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p4i32(<4 x i32 addrspace(4)*> %[[#VECGATHER]], i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 true>, <4 x i32> <i32 4, i32 0, i32 1, i32 0>)
+; CHECK-LLVM: call void @llvm.masked.scatter.v4i32.v4p4i32(<4 x i32> %[[GATHER]], <4 x i32 addrspace(4)*> %[[#VECSCATTER]], i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+
+; CHECK-LLVM-DAG: declare <4 x i32> @llvm.masked.gather.v4i32.v4p4i32(<4 x i32 addrspace(4)*>, i32 immarg, <4 x i1>, <4 x i32>)
+; CHECK-LLVM-DAG: declare void @llvm.masked.scatter.v4i32.v4p4i32(<4 x i32>, <4 x i32 addrspace(4)*>, i32 immarg, <4 x i1>)
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir"
+
+; Function Attrs: nounwind readnone
+define spir_kernel void @foo() {
+entry:
+  %arg0 = alloca <4 x i32 addrspace(4)*>
+  %arg1 = alloca <4 x i32 addrspace(4)*>
+  %0 = load <4 x i32 addrspace(4)*>, <4 x i32 addrspace(4)*>* %arg0
+  %1 = load <4 x i32 addrspace(4)*>, <4 x i32 addrspace(4)*>* %arg1
+  %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p4i32(<4 x i32 addrspace(4)*> %0, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 true>, <4 x i32> <i32 4, i32 0, i32 1, i32 0>)
+  call void @llvm.masked.scatter.v4i32.v4p4i32(<4 x i32> %res, <4 x i32 addrspace(4)*> %1, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p4i32(<4 x i32 addrspace(4)*>, i32, <4 x i1>, <4 x i32>)
+
+declare void @llvm.masked.scatter.v4i32.v4p4i32(<4 x i32>, <4 x i32 addrspace(4)*>, i32, <4 x i1>)
+
+!llvm.module.flags = !{!0}
+!opencl.spir.version = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 2}