[CuTeDSL] Add docs of how to call cute.jit functions in C++ via TVM-FFI#3289
[CuTeDSL] Add docs of how to call cute.jit functions in C++ via TVM-FFI#3289kainzhong wants to merge 6 commits into
Conversation
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
|
Thanks a lot! @tqchen |
| When you compile a ``@cute.jit`` function with the ``--enable-tvm-ffi`` option, ``cute.compile`` will return a TVM FFI function object. | ||
| Then you can register it as a `Global Function <https://tvm.apache.org/ffi/guides/export_func_cls.html#global-functions>`_ to make it accessible from other languages. | ||
|
|
||
| In C++, you can load the compiled function from the global registry with ``tvm::ffi::Function::GetGlobal`` (or ``tvm::ffi::Function::GetGlobalRequired``, which will throw if the function is not found). |
There was a problem hiding this comment.
How can we dump the tvm ffi function from cutedsl, can we leverage AOT? Can you explain this part more clearly. Other parts look good to me, thanks
There was a problem hiding this comment.
I add some description on this but really it's just doing AOT compilation and then calling this TVM-FFI API as in my example.
There was a problem hiding this comment.
One thing that may be relevant surfacing is https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp
see how it is not using the global function registry and relies on the symbol being available
There was a problem hiding this comment.
Ah yes. I also added some explanation of how to load symbols from shared libraries. Although I think using the global registry is more straightforward if I'm working with frameworks running in both C++ and python since in this way I don't need to handle the shared library manually.
Also the script path in aot_use_in_cpp_bundle.cpp is a bit outdated. I fixed that a bit as well
| void apply_tvm_function(const std::string& name, at::Tensor &x, at::Tensor &y, at::Tensor &z) { | ||
| tvm::ffi::Function fn = tvm::ffi::Function::GetGlobalRequired(name); | ||
| DLTensor dl_x = {}; | ||
| DLTensor dl_y = {}; |
There was a problem hiding this comment.
one possible option is to actually make use of tvm ffi's export mechanism just like flashinfer
This way exported function also can be used in various tvm-ffi compatible scenarios and callable from pytorch
In this case, you can also do
using ffi = tvm::ffi;
void ApplyTVMFFIFunc(ffi::Function f, ffi::TensorView x, ffi::TensorView y, ffi::TensorView z) {
f(f, x, y, z);
}
TVM_FFI_DLL_EXPORT_TYPED_FUNC(ApplyTVMFFIFunc);There was a problem hiding this comment.
There was a problem hiding this comment.
Hi @tqchen
I think the use case here is to create and export a TVM-FFI function in python, then call it in C++, which would be helpful if you are working with a C++ heavy framework and you just want to leverage CuTeDSL's flexibility and run a kernel without going to python region.
I feel like TVM_FFI_DLL_EXPORT_TYPED_FUNC is more like exporting a C++ function to python, which seems to be the case in flashinfer and it's rather the opposite direction (we want to export in python and call it in C++).
I include the pybind code in my example only because I need to run the C++ function to demonstrate the code is working (as I mentioned in the doc, in practice you probably want to call the C++ function in C++).
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Add some documentation to explain how to call cute.jit function in C++.
In our use case, we would like to avoid python overhead and stay in C++ environment, but also want to use CuTeDSL to implement kernels. Therefore we plan to use this approach so we can completely bypass python. I feel like some brief description here will make it easier for people to just get it working since people who work with kernels might not be familiar with TVM-FFI. This would act as a quick walkthrough so people can quickly get CuTeDSL working in C++.
Also, some explanation here also applies to using TVM-FFI exported function in C++ if you choose to export using an object file, so I figure it would be helpful./