Skip to content

Commit fb8bbd4

Browse files
authored
[mlir][Python] use canonical Python isinstance instead of Type.isinstance (#172892)
We've been able to do `isinstance(x, Type)` for a quite a while now (since bfb1ba7) so remove `Type.isinstance` and the the special-casing (`_is_integer_type`, `_is_floating_point_type`, `_is_index_type`) in some places (and therefore support various `fp8`, `fp6`, `fp4` types).
1 parent e826168 commit fb8bbd4

File tree

21 files changed

+224
-216
lines changed

21 files changed

+224
-216
lines changed

mlir/include/mlir-c/Dialect/PDL.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type);
3030

3131
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type);
3232

33+
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLAttributeTypeGetTypeID(void);
34+
3335
MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
3436

3537
//===---------------------------------------------------------------------===//
@@ -38,6 +40,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
3840

3941
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type);
4042

43+
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLOperationTypeGetTypeID(void);
44+
4145
MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
4246

4347
//===---------------------------------------------------------------------===//
@@ -46,6 +50,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
4650

4751
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
4852

53+
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLRangeTypeGetTypeID(void);
54+
4955
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
5056

5157
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
@@ -56,6 +62,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
5662

5763
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type);
5864

65+
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLTypeTypeGetTypeID(void);
66+
5967
MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
6068

6169
//===---------------------------------------------------------------------===//
@@ -64,6 +72,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
6472

6573
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type);
6674

75+
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLValueTypeGetTypeID(void);
76+
6777
MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx);
6878

6979
#ifdef __cplusplus

mlir/include/mlir-c/Dialect/Quant.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate);
103103
/// Returns `true` if the given type is an AnyQuantizedType.
104104
MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type);
105105

106+
MLIR_CAPI_EXPORTED MlirTypeID mlirAnyQuantizedTypeGetTypeID(void);
107+
106108
/// Creates an instance of AnyQuantizedType with the given parameters in the
107109
/// same context as `storageType` and returns it. The instance is owned by the
108110
/// context.
@@ -119,6 +121,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags,
119121
/// Returns `true` if the given type is a UniformQuantizedType.
120122
MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type);
121123

124+
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedTypeGetTypeID(void);
125+
122126
/// Creates an instance of UniformQuantizedType with the given parameters in the
123127
/// same context as `storageType` and returns it. The instance is owned by the
124128
/// context.
@@ -142,6 +146,8 @@ MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type);
142146
/// Returns `true` if the given type is a UniformQuantizedPerAxisType.
143147
MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type);
144148

149+
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void);
150+
145151
/// Creates an instance of UniformQuantizedPerAxisType with the given parameters
146152
/// in the same context as `storageType` and returns it. `scales` and
147153
/// `zeroPoints` point to `nDims` number of elements. The instance is owned
@@ -180,6 +186,8 @@ mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
180186
MLIR_CAPI_EXPORTED bool
181187
mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
182188

189+
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void);
190+
183191
/// Creates a UniformQuantizedSubChannelType with the given parameters.
184192
///
185193
/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
@@ -220,6 +228,8 @@ mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
220228
/// Returns `true` if the given type is a CalibratedQuantizedType.
221229
MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type);
222230

231+
MLIR_CAPI_EXPORTED MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void);
232+
223233
/// Creates an instance of CalibratedQuantizedType with the given parameters
224234
/// in the same context as `expressedType` and returns it. The instance is owned
225235
/// by the context.

mlir/include/mlir/Bindings/Python/IRCore.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -957,12 +957,6 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
957957
auto cls = ClassTy(m, DerivedTy::pyClassName);
958958
cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
959959
nanobind::arg("cast_from_type"));
960-
cls.def_static(
961-
"isinstance",
962-
[](PyType &otherType) -> bool {
963-
return DerivedTy::isaFunction(otherType);
964-
},
965-
nanobind::arg("other"));
966960
cls.def_prop_ro_static(
967961
"static_typeid",
968962
[](nanobind::object & /*class*/) {
@@ -1094,12 +1088,6 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
10941088
}
10951089
cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
10961090
nanobind::arg("cast_from_attr"));
1097-
cls.def_static(
1098-
"isinstance",
1099-
[](PyAttribute &otherAttr) -> bool {
1100-
return DerivedTy::isaFunction(otherAttr);
1101-
},
1102-
nanobind::arg("other"));
11031091
cls.def_prop_ro(
11041092
"type",
11051093
[](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
@@ -1555,12 +1543,6 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
15551543
.c_str()));
15561544
cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),
15571545
nanobind::arg("value"));
1558-
cls.def_static(
1559-
"isinstance",
1560-
[](PyValue &otherValue) -> bool {
1561-
return DerivedTy::isaFunction(otherValue);
1562-
},
1563-
nanobind::arg("other_value"));
15641546
cls.def(
15651547
MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
15661548
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {

mlir/lib/Bindings/Python/DialectPDL.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ struct PDLType : PyConcreteType<PDLType> {
3939

4040
struct AttributeType : PyConcreteType<AttributeType> {
4141
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType;
42+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
43+
mlirPDLAttributeTypeGetTypeID;
4244
static constexpr const char *pyClassName = "AttributeType";
4345
using Base::Base;
4446

@@ -60,6 +62,8 @@ struct AttributeType : PyConcreteType<AttributeType> {
6062

6163
struct OperationType : PyConcreteType<OperationType> {
6264
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType;
65+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
66+
mlirPDLOperationTypeGetTypeID;
6367
static constexpr const char *pyClassName = "OperationType";
6468
using Base::Base;
6569

@@ -81,6 +85,8 @@ struct OperationType : PyConcreteType<OperationType> {
8185

8286
struct RangeType : PyConcreteType<RangeType> {
8387
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType;
88+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
89+
mlirPDLRangeTypeGetTypeID;
8490
static constexpr const char *pyClassName = "RangeType";
8591
using Base::Base;
8692

@@ -109,6 +115,8 @@ struct RangeType : PyConcreteType<RangeType> {
109115

110116
struct TypeType : PyConcreteType<TypeType> {
111117
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType;
118+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
119+
mlirPDLTypeTypeGetTypeID;
112120
static constexpr const char *pyClassName = "TypeType";
113121
using Base::Base;
114122

@@ -130,6 +138,8 @@ struct TypeType : PyConcreteType<TypeType> {
130138

131139
struct ValueType : PyConcreteType<ValueType> {
132140
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType;
141+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
142+
mlirPDLValueTypeGetTypeID;
133143
static constexpr const char *pyClassName = "ValueType";
134144
using Base::Base;
135145

mlir/lib/Bindings/Python/DialectQuant.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
192192

193193
struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
194194
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType;
195+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
196+
mlirAnyQuantizedTypeGetTypeID;
195197
static constexpr const char *pyClassName = "AnyQuantizedType";
196198
using Base::Base;
197199

@@ -221,6 +223,8 @@ struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
221223
struct UniformQuantizedType
222224
: PyConcreteType<UniformQuantizedType, QuantizedType> {
223225
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType;
226+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
227+
mlirUniformQuantizedTypeGetTypeID;
224228
static constexpr const char *pyClassName = "UniformQuantizedType";
225229
using Base::Base;
226230

@@ -273,6 +277,8 @@ struct UniformQuantizedPerAxisType
273277
: PyConcreteType<UniformQuantizedPerAxisType, QuantizedType> {
274278
static constexpr IsAFunctionTy isaFunction =
275279
mlirTypeIsAUniformQuantizedPerAxisType;
280+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
281+
mlirUniformQuantizedPerAxisTypeGetTypeID;
276282
static constexpr const char *pyClassName = "UniformQuantizedPerAxisType";
277283
using Base::Base;
278284

@@ -357,6 +363,8 @@ struct UniformQuantizedSubChannelType
357363
: PyConcreteType<UniformQuantizedSubChannelType, QuantizedType> {
358364
static constexpr IsAFunctionTy isaFunction =
359365
mlirTypeIsAUniformQuantizedSubChannelType;
366+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
367+
mlirUniformQuantizedSubChannelTypeGetTypeID;
360368
static constexpr const char *pyClassName = "UniformQuantizedSubChannelType";
361369
using Base::Base;
362370

@@ -448,6 +456,8 @@ struct CalibratedQuantizedType
448456
: PyConcreteType<CalibratedQuantizedType, QuantizedType> {
449457
static constexpr IsAFunctionTy isaFunction =
450458
mlirTypeIsACalibratedQuantizedType;
459+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
460+
mlirCalibratedQuantizedTypeGetTypeID;
451461
static constexpr const char *pyClassName = "CalibratedQuantizedType";
452462
using Base::Base;
453463

mlir/lib/Bindings/Python/IRAffine.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,6 @@ class PyConcreteAffineExpr : public BaseTy {
118118
static void bind(nb::module_ &m) {
119119
auto cls = ClassTy(m, DerivedTy::pyClassName);
120120
cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
121-
cls.def_static(
122-
"isinstance",
123-
[](PyAffineExpr &otherAffineExpr) -> bool {
124-
return DerivedTy::isaFunction(otherAffineExpr);
125-
},
126-
nb::arg("other"));
127121
DerivedTy::bindDerived(cls);
128122
}
129123

mlir/lib/CAPI/Dialect/PDL.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ bool mlirTypeIsAPDLAttributeType(MlirType type) {
3232
return isa<pdl::AttributeType>(unwrap(type));
3333
}
3434

35+
MlirTypeID mlirPDLAttributeTypeGetTypeID(void) {
36+
return wrap(pdl::AttributeType::getTypeID());
37+
}
38+
3539
MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
3640
return wrap(pdl::AttributeType::get(unwrap(ctx)));
3741
}
@@ -44,6 +48,10 @@ bool mlirTypeIsAPDLOperationType(MlirType type) {
4448
return isa<pdl::OperationType>(unwrap(type));
4549
}
4650

51+
MlirTypeID mlirPDLOperationTypeGetTypeID(void) {
52+
return wrap(pdl::OperationType::getTypeID());
53+
}
54+
4755
MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
4856
return wrap(pdl::OperationType::get(unwrap(ctx)));
4957
}
@@ -56,6 +64,10 @@ bool mlirTypeIsAPDLRangeType(MlirType type) {
5664
return isa<pdl::RangeType>(unwrap(type));
5765
}
5866

67+
MlirTypeID mlirPDLRangeTypeGetTypeID(void) {
68+
return wrap(pdl::RangeType::getTypeID());
69+
}
70+
5971
MlirType mlirPDLRangeTypeGet(MlirType elementType) {
6072
return wrap(pdl::RangeType::get(unwrap(elementType)));
6173
}
@@ -72,6 +84,10 @@ bool mlirTypeIsAPDLTypeType(MlirType type) {
7284
return isa<pdl::TypeType>(unwrap(type));
7385
}
7486

87+
MlirTypeID mlirPDLTypeTypeGetTypeID(void) {
88+
return wrap(pdl::TypeType::getTypeID());
89+
}
90+
7591
MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
7692
return wrap(pdl::TypeType::get(unwrap(ctx)));
7793
}
@@ -84,6 +100,10 @@ bool mlirTypeIsAPDLValueType(MlirType type) {
84100
return isa<pdl::ValueType>(unwrap(type));
85101
}
86102

103+
MlirTypeID mlirPDLValueTypeGetTypeID(void) {
104+
return wrap(pdl::ValueType::getTypeID());
105+
}
106+
87107
MlirType mlirPDLValueTypeGet(MlirContext ctx) {
88108
return wrap(pdl::ValueType::get(unwrap(ctx)));
89109
}

mlir/lib/CAPI/Dialect/Quant.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ bool mlirTypeIsAAnyQuantizedType(MlirType type) {
113113
return isa<quant::AnyQuantizedType>(unwrap(type));
114114
}
115115

116+
MlirTypeID mlirAnyQuantizedTypeGetTypeID(void) {
117+
return wrap(quant::AnyQuantizedType::getTypeID());
118+
}
119+
116120
MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
117121
MlirType expressedType, int64_t storageTypeMin,
118122
int64_t storageTypeMax) {
@@ -129,6 +133,10 @@ bool mlirTypeIsAUniformQuantizedType(MlirType type) {
129133
return isa<quant::UniformQuantizedType>(unwrap(type));
130134
}
131135

136+
MlirTypeID mlirUniformQuantizedTypeGetTypeID(void) {
137+
return wrap(quant::UniformQuantizedType::getTypeID());
138+
}
139+
132140
MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
133141
MlirType expressedType, double scale,
134142
int64_t zeroPoint, int64_t storageTypeMin,
@@ -158,6 +166,10 @@ bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
158166
return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
159167
}
160168

169+
MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void) {
170+
return wrap(quant::UniformQuantizedPerAxisType::getTypeID());
171+
}
172+
161173
MlirType mlirUniformQuantizedPerAxisTypeGet(
162174
unsigned flags, MlirType storageType, MlirType expressedType,
163175
intptr_t nDims, double *scales, int64_t *zeroPoints,
@@ -203,6 +215,10 @@ bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
203215
return isa<quant::UniformQuantizedSubChannelType>(unwrap(type));
204216
}
205217

218+
MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void) {
219+
return wrap(quant::UniformQuantizedSubChannelType::getTypeID());
220+
}
221+
206222
MlirType mlirUniformQuantizedSubChannelTypeGet(
207223
unsigned flags, MlirType storageType, MlirType expressedType,
208224
MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
@@ -258,6 +274,10 @@ bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
258274
return isa<quant::CalibratedQuantizedType>(unwrap(type));
259275
}
260276

277+
MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void) {
278+
return wrap(quant::CalibratedQuantizedType::getTypeID());
279+
}
280+
261281
MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
262282
double max) {
263283
return wrap(

mlir/python/mlir/dialects/arith.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,6 @@
2121
raise RuntimeError("Error loading imports from extension module") from e
2222

2323

24-
def _isa(obj: Any, cls: type):
25-
try:
26-
cls(obj)
27-
except ValueError:
28-
return False
29-
return True
30-
31-
32-
def _is_any_of(obj: Any, classes: List[type]):
33-
return any(_isa(obj, cls) for cls in classes)
34-
35-
36-
def _is_integer_like_type(type: Type):
37-
return _is_any_of(type, [IntegerType, IndexType])
38-
39-
40-
def _is_float_type(type: Type):
41-
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
42-
43-
4424
@_ods_cext.register_operation(_Dialect, replace=True)
4525
class ConstantOp(ConstantOp):
4626
"""Specialization for the constant op class."""
@@ -96,9 +76,9 @@ def value(self):
9676

9777
@property
9878
def literal_value(self) -> Union[int, float]:
99-
if _is_integer_like_type(self.type):
79+
if isinstance(self.type, (IntegerType, IndexType)):
10080
return IntegerAttr(self.value).value
101-
elif _is_float_type(self.type):
81+
elif isinstance(self.type, FloatType):
10282
return FloatAttr(self.value).value
10383
else:
10484
raise ValueError("only integer and float constants have literal values")

0 commit comments

Comments
 (0)