From d3fd142f9eab041215b84e935a6fcc25b2f7fd94 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 9 Jan 2023 11:35:46 -0500 Subject: [PATCH] [PATCH] [mlir][spirv] Account for type conversion failures in scf-to-spirv Fixes: https://github.com/llvm/llvm-project/issues/59136 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D141292 Gbp-Pq: Name CVE-2023-29934.patch --- mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 5 +++++ mlir/test/Conversion/SCFToSPIRV/if.mlir | 14 ++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 3a67428cc5..5dceaa794d 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; @@ -286,6 +287,10 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, SmallVector returnTypes; for (auto result : ifOp.getResults()) { auto convertedType = typeConverter.convertType(result.getType()); + if (!convertedType) + return failure(); + + returnTypes.push_back(convertedType); } replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, diff --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir index 224549b539..6ca883686d 100644 --- a/mlir/test/Conversion/SCFToSPIRV/if.mlir +++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir @@ -153,4 +153,18 @@ func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32> return } +// Memrefs without a spirv storage class are not supported. The conversion +// should preserve the `scf.if` and not crash. +func.func @unsupported_yield_type(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %c : i1) { +// CHECK-LABEL: @unsupported_yield_type +// CHECK-NEXT: scf.if +// CHECK: spirv.Return + %r = scf.if %c -> (memref<8xi32>) { + scf.yield %arg0 : memref<8xi32> + } else { + scf.yield %arg1 : memref<8xi32> + } + return +} + } // end module -- 2.30.2