From: Jakub Kuderski Date: Mon, 9 Jan 2023 16:35:46 +0000 (-0500) Subject: [PATCH] [mlir][spirv] Account for type conversion failures in scf-to-spirv X-Git-Tag: archive/raspbian/1%14.0.6-16+rpi1^2~2 X-Git-Url: https://dgit.raspbian.org/?a=commitdiff_plain;h=d3fd142f9eab041215b84e935a6fcc25b2f7fd94;p=llvm-toolchain-14.git [PATCH] [mlir][spirv] Account for type conversion failures in scf-to-spirv Fixes: https://github.com/llvm/llvm-project/issues/59136 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D141292 Gbp-Pq: Name CVE-2023-29934.patch --- diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 3a67428cc5..5dceaa794d 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; @@ -286,6 +287,10 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, SmallVector returnTypes; for (auto result : ifOp.getResults()) { auto convertedType = typeConverter.convertType(result.getType()); + if (!convertedType) + return failure(); + + returnTypes.push_back(convertedType); } replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, diff --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir index 224549b539..6ca883686d 100644 --- a/mlir/test/Conversion/SCFToSPIRV/if.mlir +++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir @@ -153,4 +153,18 @@ func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32> return } +// Memrefs without a spirv storage class are not supported. The conversion +// should preserve the `scf.if` and not crash. +func.func @unsupported_yield_type(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %c : i1) { +// CHECK-LABEL: @unsupported_yield_type +// CHECK-NEXT: scf.if +// CHECK: spirv.Return + %r = scf.if %c -> (memref<8xi32>) { + scf.yield %arg0 : memref<8xi32> + } else { + scf.yield %arg1 : memref<8xi32> + } + return +} + } // end module