diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 397e322a64dea..0c7b998ffcab9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -316,6 +316,18 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
   CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
   CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNchwSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNchwMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNcwSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNcwMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMinOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMinUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMinOp);
 #undef CONV_OP_SPECIALIZER
   return failure();
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5c4a359dac4a4..055a7bdcf61b8 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -1684,6 +1684,321 @@ bool isaConvolutionOpOfType(
       .matchBody();
 }
 
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNchwSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, C, H, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNchwMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, C, H, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNcwSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, C, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMaxUnsignedOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MaxUnsigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNcwMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, C, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMinOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MinSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMinUnsignedOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MinUnsigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr D = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr d = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{N, D, H, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr D = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr d = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{N, D, H, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcMinOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
+                       PoolingType::MinSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr D = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr d = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{N, D, H, W, C}})
+      .matchBody();
+}
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index ac9a33b0528b0..1d01d2dad3105 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -491,3 +491,160 @@ func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter:
 // CHECK: @pooling_nhwc_min_unsigned_float
 // CHECK: linalg.pooling_nhwc_min
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nchw_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_sum
+    {dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nchw_sum
+// CHECK: linalg.pooling_nchw_sum
+// CHECK-SAME: dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nchw_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_max
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nchw_max
+// CHECK: linalg.pooling_nchw_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nwc_sum(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_sum
+    {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_nwc_sum
+// CHECK: linalg.pooling_nwc_sum
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_ncw_sum(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_ncw_sum
+// CHECK: linalg.pooling_ncw_sum
+// CHECK-SAME: dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_nwc_max(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_max
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_nwc_max
+// CHECK: linalg.pooling_nwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+// NOTE(review): integer element type assumed here — the retained
+// @pooling_nhwc_min_unsigned_float context above shows unsigned pooling on
+// floats specializes to the signed op, while this CHECK expects the unsigned
+// op to survive the roundtrip; confirm against the original patch.
+func.func @pooling_nwc_max_unsigned(%input: tensor<?x?x?xi32>, %filter: tensor<?xi32>, %output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
+  %0 = linalg.pooling_nwc_max_unsigned
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins (%input, %filter: tensor<?x?x?xi32>, tensor<?xi32>)
+    outs (%output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+// CHECK: @pooling_nwc_max_unsigned
+// CHECK: linalg.pooling_nwc_max_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_ncw_max(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_max
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_ncw_max
+// CHECK: linalg.pooling_ncw_max
+// CHECK-SAME: dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_nwc_min(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_min
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_nwc_min
+// CHECK: linalg.pooling_nwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+// NOTE(review): integer element type assumed — see @pooling_nwc_max_unsigned;
+// an i32 input is required for the unsigned op to round-trip unchanged.
+func.func @pooling_nwc_min_unsigned(%input: tensor<?x?x?xi32>, %filter: tensor<?xi32>, %output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
+  %0 = linalg.pooling_nwc_min_unsigned
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins (%input, %filter: tensor<?x?x?xi32>, tensor<?xi32>)
+    outs (%output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+// CHECK: @pooling_nwc_min_unsigned
+// CHECK: linalg.pooling_nwc_min_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_ndhwc_sum(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_sum
+    {dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+    ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @pooling_ndhwc_sum
+// CHECK: linalg.pooling_ndhwc_sum
+// CHECK-SAME: dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>
+
+// -----
+
+func.func @pooling_ndhwc_max(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_max
+    {dilations = dense<1> : tensor<3xi64>, strides = dense<2> : tensor<3xi64>}
+    ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @pooling_ndhwc_max
+// CHECK: linalg.pooling_ndhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<2> : tensor<3xi64>
+
+// -----
+
+func.func @pooling_ndhwc_min(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_min
+    {dilations = dense<[1, 2, 3]> : tensor<3xi64>, strides = dense<[4, 5, 6]> : tensor<3xi64>}
+    ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @pooling_ndhwc_min
+// CHECK: linalg.pooling_ndhwc_min
+// CHECK-SAME: dilations = dense<[1, 2, 3]> : tensor<3xi64>, strides =
+// CHECK-SAME: dense<[4, 5, 6]> : tensor<3xi64>