Skip to content

Commit 0264d42

Browse files
authored
[mlir][CAPI][python] bind CallSiteLoc, FileLineColRange, FusedLoc, NameLoc (#129351)
This PR extends the python bindings for CallSiteLoc, FileLineColRange, FusedLoc, NameLoc with field accessors. It also adds the missing `value.location` accessor. I also did some "spring cleaning" here (`cast` -> `dyn_cast`) after running into some of my own illegal casts.
1 parent e0442bd commit 0264d42

File tree

5 files changed

+370
-55
lines changed

5 files changed

+370
-55
lines changed

mlir/include/mlir-c/IR.h

+80
Original file line numberDiff line numberDiff line change
@@ -261,22 +261,96 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(
261261
MlirContext context, MlirStringRef filename, unsigned start_line,
262262
unsigned start_col, unsigned end_line, unsigned end_col);
263263

264+
/// Getter for filename of FileLineColRange.
265+
MLIR_CAPI_EXPORTED MlirIdentifier
266+
mlirLocationFileLineColRangeGetFilename(MlirLocation location);
267+
268+
/// Getter for start_line of FileLineColRange.
269+
MLIR_CAPI_EXPORTED int
270+
mlirLocationFileLineColRangeGetStartLine(MlirLocation location);
271+
272+
/// Getter for start_column of FileLineColRange.
273+
MLIR_CAPI_EXPORTED int
274+
mlirLocationFileLineColRangeGetStartColumn(MlirLocation location);
275+
276+
/// Getter for end_line of FileLineColRange.
277+
MLIR_CAPI_EXPORTED int
278+
mlirLocationFileLineColRangeGetEndLine(MlirLocation location);
279+
280+
/// Getter for end_column of FileLineColRange.
281+
MLIR_CAPI_EXPORTED int
282+
mlirLocationFileLineColRangeGetEndColumn(MlirLocation location);
283+
284+
/// TypeID Getter for FileLineColRange.
285+
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFileLineColRangeGetTypeID(void);
286+
287+
/// Checks whether the given location is an FileLineColRange.
288+
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location);
289+
264290
/// Creates a call site location with a callee and a caller.
265291
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee,
266292
MlirLocation caller);
267293

294+
/// Getter for callee of CallSite.
295+
MLIR_CAPI_EXPORTED MlirLocation
296+
mlirLocationCallSiteGetCallee(MlirLocation location);
297+
298+
/// Getter for caller of CallSite.
299+
MLIR_CAPI_EXPORTED MlirLocation
300+
mlirLocationCallSiteGetCaller(MlirLocation location);
301+
302+
/// TypeID Getter for CallSite.
303+
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationCallSiteGetTypeID(void);
304+
305+
/// Checks whether the given location is an CallSite.
306+
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location);
307+
268308
/// Creates a fused location with an array of locations and metadata.
269309
MLIR_CAPI_EXPORTED MlirLocation
270310
mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
271311
MlirLocation const *locations, MlirAttribute metadata);
272312

313+
/// Getter for number of locations fused together.
314+
MLIR_CAPI_EXPORTED unsigned
315+
mlirLocationFusedGetNumLocations(MlirLocation location);
316+
317+
/// Getter for locations of Fused. Requires pre-allocated memory of
318+
/// #fusedLocations X sizeof(MlirLocation).
319+
MLIR_CAPI_EXPORTED void
320+
mlirLocationFusedGetLocations(MlirLocation location,
321+
MlirLocation *locationsCPtr);
322+
323+
/// Getter for metadata of Fused.
324+
MLIR_CAPI_EXPORTED MlirAttribute
325+
mlirLocationFusedGetMetadata(MlirLocation location);
326+
327+
/// TypeID Getter for Fused.
328+
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFusedGetTypeID(void);
329+
330+
/// Checks whether the given location is an Fused.
331+
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location);
332+
273333
/// Creates a name location owned by the given context. Providing null location
274334
/// for childLoc is allowed and if childLoc is null location, then the behavior
275335
/// is the same as having unknown child location.
276336
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context,
277337
MlirStringRef name,
278338
MlirLocation childLoc);
279339

340+
/// Getter for name of Name.
341+
MLIR_CAPI_EXPORTED MlirIdentifier
342+
mlirLocationNameGetName(MlirLocation location);
343+
344+
/// Getter for childLoc of Name.
345+
MLIR_CAPI_EXPORTED MlirLocation
346+
mlirLocationNameGetChildLoc(MlirLocation location);
347+
348+
/// TypeID Getter for Name.
349+
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationNameGetTypeID(void);
350+
351+
/// Checks whether the given location is an Name.
352+
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location);
353+
280354
/// Creates a location with unknown position owned by the given context.
281355
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);
282356

@@ -978,6 +1052,12 @@ mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
9781052
intptr_t numExceptions,
9791053
MlirOperation *exceptions);
9801054

1055+
/// Gets the location of the value.
1056+
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v);
1057+
1058+
/// Gets the context that a value was created with.
1059+
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v);
1060+
9811061
//===----------------------------------------------------------------------===//
9821062
// OpOperand API.
9831063
//===----------------------------------------------------------------------===//

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

+10
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,16 @@ struct type_caster<MlirType> {
321321
}
322322
};
323323

324+
/// Casts MlirStringRef -> object.
325+
template <>
326+
struct type_caster<MlirStringRef> {
327+
NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef"))
328+
static handle from_cpp(MlirStringRef s, rv_policy,
329+
cleanup_list *cleanup) noexcept {
330+
return nanobind::str(s.data, s.length).release();
331+
}
332+
};
333+
324334
} // namespace detail
325335
} // namespace nanobind
326336

mlir/lib/Bindings/Python/IRCore.cpp

+41-6
Original file line numberDiff line numberDiff line change
@@ -2943,6 +2943,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
29432943
nb::arg("callee"), nb::arg("frames"),
29442944
nb::arg("context").none() = nb::none(),
29452945
kContextGetCallSiteLocationDocstring)
2946+
.def("is_a_callsite", mlirLocationIsACallSite)
2947+
.def_prop_ro("callee", mlirLocationCallSiteGetCallee)
2948+
.def_prop_ro("caller", mlirLocationCallSiteGetCaller)
29462949
.def_static(
29472950
"file",
29482951
[](std::string filename, int line, int col,
@@ -2967,6 +2970,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
29672970
nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
29682971
nb::arg("end_line"), nb::arg("end_col"),
29692972
nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
2973+
.def("is_a_file", mlirLocationIsAFileLineColRange)
2974+
.def_prop_ro("filename",
2975+
[](MlirLocation loc) {
2976+
return mlirIdentifierStr(
2977+
mlirLocationFileLineColRangeGetFilename(loc));
2978+
})
2979+
.def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
2980+
.def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
2981+
.def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
2982+
.def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
29702983
.def_static(
29712984
"fused",
29722985
[](const std::vector<PyLocation> &pyLocations,
@@ -2984,6 +2997,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
29842997
nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
29852998
nb::arg("context").none() = nb::none(),
29862999
kContextGetFusedLocationDocstring)
3000+
.def("is_a_fused", mlirLocationIsAFused)
3001+
.def_prop_ro("locations",
3002+
[](MlirLocation loc) {
3003+
unsigned numLocations =
3004+
mlirLocationFusedGetNumLocations(loc);
3005+
std::vector<MlirLocation> locations(numLocations);
3006+
if (numLocations)
3007+
mlirLocationFusedGetLocations(loc, locations.data());
3008+
return locations;
3009+
})
29873010
.def_static(
29883011
"name",
29893012
[](std::string name, std::optional<PyLocation> childLoc,
@@ -2998,6 +3021,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
29983021
nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
29993022
nb::arg("context").none() = nb::none(),
30003023
kContextGetNameLocationDocString)
3024+
.def("is_a_name", mlirLocationIsAName)
3025+
.def_prop_ro("name_str",
3026+
[](MlirLocation loc) {
3027+
return mlirIdentifierStr(mlirLocationNameGetName(loc));
3028+
})
3029+
.def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
30013030
.def_static(
30023031
"from_attr",
30033032
[](PyAttribute &attribute, DefaultingPyMlirContext context) {
@@ -3148,9 +3177,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
31483177
auto &concreteOperation = self.getOperation();
31493178
concreteOperation.checkValid();
31503179
MlirOperation operation = concreteOperation.get();
3151-
MlirStringRef name =
3152-
mlirIdentifierStr(mlirOperationGetName(operation));
3153-
return nb::str(name.data, name.length);
3180+
return mlirIdentifierStr(mlirOperationGetName(operation));
31543181
})
31553182
.def_prop_ro("operands",
31563183
[](PyOperationBase &self) {
@@ -3738,8 +3765,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
37383765
.def_prop_ro(
37393766
"name",
37403767
[](PyNamedAttribute &self) {
3741-
return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3742-
mlirIdentifierStr(self.namedAttr.name).length);
3768+
return mlirIdentifierStr(self.namedAttr.name);
37433769
},
37443770
"The name of the NamedAttribute binding")
37453771
.def_prop_ro(
@@ -3972,7 +3998,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
39723998
nb::arg("with"), nb::arg("exceptions"),
39733999
kValueReplaceAllUsesExceptDocstring)
39744000
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3975-
[](PyValue &self) { return self.maybeDownCast(); });
4001+
[](PyValue &self) { return self.maybeDownCast(); })
4002+
.def_prop_ro(
4003+
"location",
4004+
[](MlirValue self) {
4005+
return PyLocation(
4006+
PyMlirContext::forContext(mlirValueGetContext(self)),
4007+
mlirValueGetLocation(self));
4008+
},
4009+
"Returns the source location the value");
4010+
39764011
PyBlockArgument::bind(m);
39774012
PyOpResult::bind(m);
39784013
PyOpOperand::bind(m);

mlir/lib/CAPI/IR/IR.cpp

+108-6
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
259259
}
260260

261261
MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
262-
return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute))));
262+
return wrap(Location(llvm::dyn_cast<LocationAttr>(unwrap(attribute))));
263263
}
264264

265265
MlirLocation mlirLocationFileLineColGet(MlirContext context,
@@ -278,10 +278,64 @@ mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename,
278278
startLine, startCol, endLine, endCol)));
279279
}
280280

281+
MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) {
282+
return wrap(llvm::dyn_cast<FileLineColRange>(unwrap(location)).getFilename());
283+
}
284+
285+
int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) {
286+
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
287+
return loc.getStartLine();
288+
return -1;
289+
}
290+
291+
int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) {
292+
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
293+
return loc.getStartColumn();
294+
return -1;
295+
}
296+
297+
int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) {
298+
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
299+
return loc.getEndLine();
300+
return -1;
301+
}
302+
303+
int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) {
304+
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
305+
return loc.getEndColumn();
306+
return -1;
307+
}
308+
309+
MlirTypeID mlirLocationFileLineColRangeGetTypeID() {
310+
return wrap(FileLineColRange::getTypeID());
311+
}
312+
313+
bool mlirLocationIsAFileLineColRange(MlirLocation location) {
314+
return isa<FileLineColRange>(unwrap(location));
315+
}
316+
281317
MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
282318
return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
283319
}
284320

321+
MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) {
322+
return wrap(
323+
Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCallee()));
324+
}
325+
326+
MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) {
327+
return wrap(
328+
Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCaller()));
329+
}
330+
331+
MlirTypeID mlirLocationCallSiteGetTypeID() {
332+
return wrap(CallSiteLoc::getTypeID());
333+
}
334+
335+
bool mlirLocationIsACallSite(MlirLocation location) {
336+
return isa<CallSiteLoc>(unwrap(location));
337+
}
338+
285339
MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
286340
MlirLocation const *locations,
287341
MlirAttribute metadata) {
@@ -290,6 +344,30 @@ MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
290344
return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
291345
}
292346

347+
unsigned mlirLocationFusedGetNumLocations(MlirLocation location) {
348+
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location)))
349+
return locationsArrRef.getLocations().size();
350+
return 0;
351+
}
352+
353+
void mlirLocationFusedGetLocations(MlirLocation location,
354+
MlirLocation *locationsCPtr) {
355+
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) {
356+
for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations()))
357+
locationsCPtr[i] = wrap(location);
358+
}
359+
}
360+
361+
MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) {
362+
return wrap(llvm::dyn_cast<FusedLoc>(unwrap(location)).getMetadata());
363+
}
364+
365+
MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); }
366+
367+
bool mlirLocationIsAFused(MlirLocation location) {
368+
return isa<FusedLoc>(unwrap(location));
369+
}
370+
293371
MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
294372
MlirLocation childLoc) {
295373
if (mlirLocationIsNull(childLoc))
@@ -299,6 +377,21 @@ MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
299377
StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
300378
}
301379

380+
MlirIdentifier mlirLocationNameGetName(MlirLocation location) {
381+
return wrap((llvm::dyn_cast<NameLoc>(unwrap(location)).getName()));
382+
}
383+
384+
MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) {
385+
return wrap(
386+
Location(llvm::dyn_cast<NameLoc>(unwrap(location)).getChildLoc()));
387+
}
388+
389+
MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); }
390+
391+
bool mlirLocationIsAName(MlirLocation location) {
392+
return isa<NameLoc>(unwrap(location));
393+
}
394+
302395
MlirLocation mlirLocationUnknownGet(MlirContext context) {
303396
return wrap(Location(UnknownLoc::get(unwrap(context))));
304397
}
@@ -975,25 +1068,26 @@ bool mlirValueIsAOpResult(MlirValue value) {
9751068
}
9761069

9771070
MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
978-
return wrap(llvm::cast<BlockArgument>(unwrap(value)).getOwner());
1071+
return wrap(llvm::dyn_cast<BlockArgument>(unwrap(value)).getOwner());
9791072
}
9801073

9811074
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
9821075
return static_cast<intptr_t>(
983-
llvm::cast<BlockArgument>(unwrap(value)).getArgNumber());
1076+
llvm::dyn_cast<BlockArgument>(unwrap(value)).getArgNumber());
9841077
}
9851078

9861079
void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
987-
llvm::cast<BlockArgument>(unwrap(value)).setType(unwrap(type));
1080+
if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
1081+
blockArg.setType(unwrap(type));
9881082
}
9891083

9901084
MlirOperation mlirOpResultGetOwner(MlirValue value) {
991-
return wrap(llvm::cast<OpResult>(unwrap(value)).getOwner());
1085+
return wrap(llvm::dyn_cast<OpResult>(unwrap(value)).getOwner());
9921086
}
9931087

9941088
intptr_t mlirOpResultGetResultNumber(MlirValue value) {
9951089
return static_cast<intptr_t>(
996-
llvm::cast<OpResult>(unwrap(value)).getResultNumber());
1090+
llvm::dyn_cast<OpResult>(unwrap(value)).getResultNumber());
9971091
}
9981092

9991093
MlirType mlirValueGetType(MlirValue value) {
@@ -1047,6 +1141,14 @@ void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
10471141
oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
10481142
}
10491143

1144+
MlirLocation mlirValueGetLocation(MlirValue v) {
1145+
return wrap(unwrap(v).getLoc());
1146+
}
1147+
1148+
MlirContext mlirValueGetContext(MlirValue v) {
1149+
return wrap(unwrap(v).getContext());
1150+
}
1151+
10501152
//===----------------------------------------------------------------------===//
10511153
// OpOperand API.
10521154
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)