File tree Expand file tree Collapse file tree 5 files changed +15
-10
lines changed
examples/python-extension Expand file tree Collapse file tree 5 files changed +15
-10
lines changed Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ name = "tch_ext"
17
17
crate-type = [" cdylib" ]
18
18
19
19
[dependencies ]
20
- pyo3 = { version = " 0.21 " , features = [" extension-module" ] }
20
+ pyo3 = { version = " 0.24 " , features = [" extension-module" ] }
21
21
pyo3-tch = { path = " ../../pyo3-tch" , version = " 0.20.0" }
22
22
tch = { path = " ../.." , features = [" python-extension" ], version = " 0.20.0" }
23
- torch-sys = { path = " ../../torch-sys" , features = [" python-extension" ], version = " 0.20.0" }
23
+ torch-sys = { path = " ../../torch-sys" , features = [" python-extension" ], version = " 0.20.0" }
Original file line number Diff line number Diff line change @@ -11,7 +11,7 @@ Python environment that has torch installed from the root of the github repo.
11
11
12
12
``` bash
13
13
LIBTORCH_USE_PYTORCH=1 cargo build -p tch-ext && cp -f target/debug/libtch_ext.so tch_ext.so
14
- python examples/python-extension/main .py
14
+ PYTHONPATH=. python examples/python-extension/test .py
15
15
```
16
16
17
17
It is recommended to run the build with ` LIBTORCH_USE_PYTORCH ` set, this will
Original file line number Diff line number Diff line change @@ -11,7 +11,7 @@ fn add_one(tensor: PyTensor) -> PyResult<PyTensor> {
11
11
/// objects.
12
12
#[ pymodule]
13
13
fn tch_ext ( py : Python < ' _ > , m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
14
- py. import_bound ( "torch" ) ?;
14
+ py. import ( "torch" ) ?;
15
15
m. add_function ( wrap_pyfunction ! ( add_one, m) ?) ?;
16
16
Ok ( ( ) )
17
17
}
Original file line number Diff line number Diff line change @@ -14,4 +14,4 @@ license = "MIT/Apache-2.0"
14
14
[dependencies ]
15
15
tch = { path = " .." , features = [" python-extension" ], version = " 0.20.0" }
16
16
torch-sys = { path = " ../torch-sys" , features = [" python-extension" ], version = " 0.20.0" }
17
- pyo3 = { version = " 0.21 " , features = [" extension-module" ] }
17
+ pyo3 = { version = " 0.24 " , features = [" extension-module" ] }
Original file line number Diff line number Diff line change @@ -18,7 +18,7 @@ pub fn wrap_tch_err(err: tch::TchError) -> PyErr {
18
18
}
19
19
20
20
impl < ' source > FromPyObject < ' source > for PyTensor {
21
- fn extract ( ob : & ' source PyAny ) -> PyResult < Self > {
21
+ fn extract_bound ( ob : & Bound < ' source , PyAny > ) -> PyResult < Self > {
22
22
let ptr = ob. as_ptr ( ) as * mut tch:: python:: CPyObject ;
23
23
let tensor = unsafe { tch:: Tensor :: pyobject_unpack ( ptr) } ;
24
24
tensor
@@ -31,13 +31,18 @@ impl<'source> FromPyObject<'source> for PyTensor {
31
31
}
32
32
}
33
33
34
- impl IntoPy < PyObject > for PyTensor {
35
- fn into_py ( self , py : Python < ' _ > ) -> PyObject {
34
+ impl < ' py > IntoPyObject < ' py > for PyTensor {
35
+ type Output = Bound < ' py , Self :: Target > ;
36
+ type Target = PyAny ;
37
+ type Error = PyErr ;
38
+
39
+ fn into_pyobject ( self , py : Python < ' py > ) -> Result < Self :: Output , Self :: Error > {
36
40
// There is no fallible alternative to ToPyObject/IntoPy at the moment so we return
37
41
// None on errors. https://github.com/PyO3/pyo3/issues/1813
38
- self . 0 . pyobject_wrap ( ) . map_or_else (
42
+ let v = self . 0 . pyobject_wrap ( ) . map_or_else (
39
43
|_| py. None ( ) ,
40
44
|ptr| unsafe { PyObject :: from_owned_ptr ( py, ptr as * mut pyo3:: ffi:: PyObject ) } ,
41
- )
45
+ ) ;
46
+ Ok ( v. into_pyobject ( py) ?)
42
47
}
43
48
}
You can’t perform that action at this time.
0 commit comments