
Conversation

@Abhishek-Varma
Contributor

-- This commit is the eighth in the series of adding matchers
   for linalg.*conv*/*pool*. Refer: llvm#163724
-- In this commit all variants of Pooling ops have been added.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
@llvmbot
Member

llvmbot commented Dec 15, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- This commit is the eighth in the series of adding matchers
for linalg.conv/pool. Refer: #163724
-- In this commit all variants of Pooling ops have been added.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>


Patch is 20.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172351.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+12)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+315)
  • (modified) mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir (+157)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 397e322a64dea..0c7b998ffcab9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -316,6 +316,18 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
   CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
   CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNchwSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNchwMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNcwSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNcwMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMinOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNwcMinUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMinOp);
 #undef CONV_OP_SPECIALIZER
   return failure();
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5c4a359dac4a4..055a7bdcf61b8 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -1684,6 +1684,321 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
       .matchBody();
 }
 
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNchwSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, C, H, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNchwMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, C, H, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNcwSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, C, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMaxUnsignedOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MaxUnsigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNcwMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr C = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, C, W}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMinOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MinSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNwcMinUnsignedOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
+                       PoolingType::MinUnsigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{N, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
+                       PoolingType::Sum);
+  AffineExpr N = m.dim(0);
+  AffineExpr D = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr d = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{N, D, H, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
+                       PoolingType::MaxSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr D = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr d = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{N, D, H, W, C}})
+      .matchBody();
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNdhwcMinOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
+                       PoolingType::MinSigned);
+  AffineExpr N = m.dim(0);
+  AffineExpr D = m.dim(1);
+  AffineExpr H = m.dim(2);
+  AffineExpr W = m.dim(3);
+  AffineExpr C = m.dim(4);
+  AffineExpr d = m.dim(5);
+  AffineExpr h = m.dim(6);
+  AffineExpr w = m.dim(7);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{N, D, H, W, C}})
+      .matchBody();
+}
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index ac9a33b0528b0..1d01d2dad3105 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -491,3 +491,160 @@ func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter:
 //      CHECK: @pooling_nhwc_min_unsigned_float
 //      CHECK:   linalg.pooling_nhwc_min
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nchw_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_sum
+         {dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>}
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+         outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nchw_sum
+//      CHECK:   linalg.pooling_nchw_sum
+// CHECK-SAME:      dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nchw_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_max
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+         ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+         outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nchw_max
+//      CHECK:   linalg.pooling_nchw_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nwc_sum(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_sum
+         {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_nwc_sum
+//      CHECK:   linalg.pooling_nwc_sum
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_ncw_sum(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_sum
+         {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_ncw_sum
+//      CHECK:   linalg.pooling_ncw_sum
+// CHECK-SAME:      dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_nwc_max(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_max
+         {dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_nwc_max
+//      CHECK:   linalg.pooling_nwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_nwc_max_unsigned(%input: tensor<?x?x?xi8>, %filter: tensor<?xi8>, %output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
+  %0 = linalg.pooling_nwc_max_unsigned
+         {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xi8>, tensor<?xi8>)
+         outs (%output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+//      CHECK: @pooling_nwc_max_unsigned
+//      CHECK:   linalg.pooling_nwc_max_unsigned
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_ncw_max(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_max
+         {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_ncw_max
+//      CHECK:   linalg.pooling_ncw_max
+// CHECK-SAME:      dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_nwc_min(%input: tensor<?x?x?xf32>, %filter: tensor<?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_nwc_min
+         {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_nwc_min
+//      CHECK:   linalg.pooling_nwc_min
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_nwc_min_unsigned(%input: tensor<?x?x?xi8>, %filter: tensor<?xi8>, %output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
+  %0 = linalg.pooling_nwc_min_unsigned
+         {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xi8>, tensor<?xi8>)
+         outs (%output: tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+//      CHECK: @pooling_nwc_min_unsigned
+//      CHECK:   linalg.pooling_nwc_min_unsigned
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
+func.func @pooling_ndhwc_sum(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_sum
+         {dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+         ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+         outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_sum
+//      CHECK:   linalg.pooling_ndhwc_sum
+// CHECK-SAME:      dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>
+
+// -----
+
+func.func @pooling_ndhwc_max(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_max
+         {dilations = dense<1> : tensor<3xi64>, strides = dense<2> : tensor<3xi64>}
+         ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+         outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_max
+//      CHECK:   linalg.pooling_ndhwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<2> : tensor<3xi64>
+
+// -----
+
+func.func @pooling_ndhwc_min(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_min
+         {dilations = dense<[1, 2, 3]> : tensor<3xi64>, strides = dense<[4, 5, 6]> : tensor<3xi64>}
+         ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+         outs (%output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_min
+//  ...
[truncated]

Contributor

@hanhanW left a comment

Do we need all these matchers for our work? I'm worried about tech debt. When can we start working on fixing the matchers in other upstream patterns? E.g.,

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                     Conv1DNwcWcfOp>,
               DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                     Conv1DNcwFcwOp>,
               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
      patterns.getContext(), benefit);
  patterns.add<
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
      patterns.getContext(), benefit);
}

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

@Abhishek-Varma
Contributor Author

Do we need all these matchers for our work? I'm worried about tech debt. When can we start working on fixing the matchers in other upstream patterns?

Hi @hanhanW - the upstream patterns above cover roughly 60% of the Pooling ops this PR aims to add. Since we have already added matchers for all Conv1D/2D/3D variants, this PR, for consistency, completes matcher support for all the Convolution op variants.

Broadly, the pooling ops come in 5 layouts - nwc, ncw, nhwc, nchw, and ndhwc - whose variants differ only in the summarizing operator. nhwc support was already added via #163724, so this PR essentially adds support for the other 4 layouts.

When can we start working on fixing the matchers in other upstream patterns?

This PR is the final one for adding the matchers. The PR right after this one will make use of these matchers so that the upstream patterns can work with linalg.generic forms as well.
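
To make that follow-up concrete, here is a minimal sketch (my illustration, not code from this PR) of how such a pattern could consult one of the new matchers so it also fires on linalg.generic forms. The function name matchNchwSumPooling is made up, and I'm assuming the matcher is exposed from mlir::linalg as in Utils.h:

#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Recognize an NCHW sum-pooling computation whether it is spelled as
// linalg.pooling_nchw_sum or as an equivalent linalg.generic, and recover
// its dilations/strides from the indexing maps.
static LogicalResult matchNchwSumPooling(linalg::LinalgOp op,
                                         SmallVector<int64_t> &dilations,
                                         SmallVector<int64_t> &strides) {
  // The matchers assert ConvolutionOpInterface, so gate on it first.
  if (!linalg::isaConvolutionOpInterface(op))
    return failure();
  if (!linalg::isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(op, &dilations,
                                                                &strides))
    return failure();
  // dilations/strides now carry the recovered values and can drive the
  // existing pooling-specific rewrite.
  return success();
}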
