Skip to content

Commit c5182c1

Browse files
authored
[MLIR][Interfaces] Allow non-builtin scalar types in IndexingMapOpInterface (#188774)
The scalar type check in `IndexingMapOpInterface::verifyImpl` and its helper `verifyIndexingMapOperandType` used `isIntOrIndexOrFloat() || isa<ComplexType>()`, which only accepted builtin scalar types. This rejected valid custom-dialect scalar types such as pointer types (`\!ptr.ptr<...>`) or other non-shaped dialect types. The `isScalar` method in `DestinationStyleOpInterface` already treats any non-MemRef/non-Tensor type as scalar. Align `IndexingMapOpInterface` with this definition by treating any non-ShapedType as a rank-0 scalar, regardless of whether it is a builtin type. Fixes #183606 Assisted-by: Claude Code
1 parent 331c1c0 commit c5182c1

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

mlir/lib/Interfaces/IndexingMapOpInterface.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ namespace mlir {
1616

1717
static LogicalResult verifyIndexingMapOperandType(Operation *op, Type t,
1818
unsigned operandNumber) {
19-
// Scalars are allowed (treated as rank-0). verifyImpl checks the rank.
20-
if (t.isIntOrIndexOrFloat() || isa<ComplexType>(t))
19+
// Non-shaped types are treated as scalars (rank-0). This includes builtin
20+
// types (integer, float, complex) as well as custom dialect types.
21+
if (!isa<ShapedType>(t))
2122
return success();
2223

2324
// Vectors are allowed.
@@ -67,8 +68,9 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
6768
if (indexingMap.getNumSymbols() != 0)
6869
return this->emitOpError("unexpected symbols in indexing_map #")
6970
<< opOperand.getOperandNumber();
70-
// Handle scalars.
71-
if (ty.isIntOrIndexOrFloat() || isa<ComplexType>(ty)) {
71+
// Handle scalars (non-shaped types: integer, float, complex, custom types,
72+
// etc.).
73+
if (!isa<ShapedType>(ty)) {
7274
int64_t rank = 0;
7375
if (indexingMap.getNumResults() != rank)
7476
return this->emitOpError("expected operand #")
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt %s | FileCheck %s
2+
// This test verifies that linalg.fill and linalg.generic accept non-builtin
3+
// scalar types (e.g., custom dialect types) as operands.
4+
5+
// CHECK-LABEL: @fill_non_builtin_scalar_type
6+
func.func @fill_non_builtin_scalar_type(%src: !test.test_type, %dst: tensor<4x!test.test_type>) -> tensor<4x!test.test_type> {
7+
// CHECK: linalg.fill
8+
%result = linalg.fill ins(%src : !test.test_type) outs(%dst : tensor<4x!test.test_type>) -> tensor<4x!test.test_type>
9+
return %result : tensor<4x!test.test_type>
10+
}

0 commit comments

Comments
 (0)