Merged
Changes from 3 commits
8 changes: 4 additions & 4 deletions bindings/python/Cargo.toml
@@ -10,12 +10,12 @@ crate-type = ["cdylib"]

[dependencies]
rayon = "1.10"
serde = { version = "1.0", features = [ "rc", "derive" ]}
serde = { version = "1.0", features = ["rc", "derive"] }
serde_json = "1.0"
libc = "0.2"
env_logger = "0.11"
pyo3 = { version = "0.21" }
numpy = "0.21"
pyo3 = { version = "0.22", features = ["py-clone"] }
numpy = "0.22"
ndarray = "0.15"
itertools = "0.12"

@@ -24,7 +24,7 @@ path = "../../tokenizers"

[dev-dependencies]
tempfile = "3.10"
pyo3 = { version = "0.21", features = ["auto-initialize"] }
pyo3 = { version = "0.22", features = ["auto-initialize", "py-clone"] }

[features]
defaut = ["pyo3/extension-module"]
10 changes: 5 additions & 5 deletions bindings/python/py_src/tokenizers/models/__init__.pyi
@@ -33,7 +33,7 @@ class Model:
"""
pass

def save(self, folder, prefix):
def save(self, folder, prefix, name):
"""
Save the current model

@@ -204,7 +204,7 @@ class BPE(Model):
"""
pass

def save(self, folder, prefix):
def save(self, folder, prefix, name):
"""
Save the current model

@@ -286,7 +286,7 @@ class Unigram(Model):
"""
pass

def save(self, folder, prefix):
def save(self, folder, prefix, name):
"""
Save the current model

@@ -414,7 +414,7 @@ class WordLevel(Model):
"""
pass

def save(self, folder, prefix):
def save(self, folder, prefix, name):
"""
Save the current model

@@ -544,7 +544,7 @@ class WordPiece(Model):
"""
pass

def save(self, folder, prefix):
def save(self, folder, prefix, name):
"""
Save the current model

4 changes: 2 additions & 2 deletions bindings/python/src/decoders.rs
@@ -88,9 +88,9 @@ impl PyDecoder {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.decoder = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Decoder: {}",
e
4 changes: 2 additions & 2 deletions bindings/python/src/encoding.rs
@@ -41,9 +41,9 @@ impl PyEncoding {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.encoding = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.encoding = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Encoding: {}",
e
8 changes: 4 additions & 4 deletions bindings/python/src/models.rs
@@ -109,9 +109,9 @@ impl PyModel {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.model = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.model = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Model: {}",
e
@@ -181,7 +181,7 @@ impl PyModel {
///
/// Returns:
/// :obj:`List[str]`: The list of saved files
#[pyo3(text_signature = "(self, folder, prefix)")]
#[pyo3(signature = (folder, prefix=None, name=None), text_signature = "(self, folder, prefix, name)")]
fn save<'a>(
&self,
py: Python<'_>,
@@ -835,7 +835,7 @@ pub struct PyUnigram {}
#[pymethods]
impl PyUnigram {
#[new]
#[pyo3(text_signature = "(self, vocab, unk_id, byte_fallback)")]
#[pyo3(signature = (vocab=None, unk_id=None, byte_fallback=None), text_signature = "(self, vocab, unk_id, byte_fallback)")]
fn new(
vocab: Option<Vec<(String, f64)>>,
unk_id: Option<usize>,
4 changes: 2 additions & 2 deletions bindings/python/src/normalizers.rs
@@ -118,9 +118,9 @@ impl PyNormalizer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.normalizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.normalizer = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Normalizer: {}",
e
4 changes: 2 additions & 2 deletions bindings/python/src/pre_tokenizers.rs
@@ -122,9 +122,9 @@ impl PyPreTokenizer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
let unpickled = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PreTokenizer: {}",
e
8 changes: 4 additions & 4 deletions bindings/python/src/processors.rs
@@ -82,9 +82,9 @@ impl PyPostProcessor {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.processor = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PostProcessor: {}",
e
@@ -272,7 +272,7 @@ impl From<PySpecialToken> for SpecialToken {
}

impl FromPyObject<'_> for PySpecialToken {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(v) = ob.extract::<(String, u32)>() {
Ok(Self(v.into()))
} else if let Ok(v) = ob.extract::<(u32, String)>() {
@@ -312,7 +312,7 @@ impl From<PyTemplate> for Template {
}

impl FromPyObject<'_> for PyTemplate {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(s) = ob.extract::<&str>() {
Ok(Self(
s.try_into().map_err(exceptions::PyValueError::new_err)?,
46 changes: 23 additions & 23 deletions bindings/python/src/tokenizer.rs
@@ -2,7 +2,7 @@ use serde::Serialize;
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};

use numpy::{npyffi, PyArray1};
use numpy::{npyffi, PyArray1, PyArrayMethods};
use pyo3::class::basic::CompareOp;
use pyo3::exceptions;
use pyo3::intern;
@@ -156,7 +156,7 @@ impl PyAddedToken {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyDict>(py) {
match state.downcast_bound::<PyDict>(py) {
Ok(state) => {
for (key, value) in state {
let key: &str = key.extract()?;
@@ -172,7 +172,7 @@
}
Ok(())
}
Err(e) => Err(e),
Err(e) => Err(e.into()),
}
}

@@ -263,10 +263,10 @@ impl PyAddedToken {

struct TextInputSequence<'s>(tk::InputSequence<'s>);
impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
if let Ok(s) = ob.downcast::<PyString>() {
Ok(Self(s.to_string_lossy().into()))
if let Ok(s) = ob.extract::<String>() {
Ok(Self(s.into()))
} else {
Err(err)
}
@@ -280,7 +280,7 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {

struct PyArrayUnicode(Vec<String>);
impl FromPyObject<'_> for PyArrayUnicode {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
// SAFETY Making sure the pointer is a valid numpy array requires calling numpy C code
if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
@@ -291,8 +291,8 @@ impl FromPyObject<'_> for PyArrayUnicode {
let desc = (*arr).descr;
(
(*desc).type_num,
(*desc).elsize as usize,
(*desc).alignment as usize,
npyffi::PyDataType_ELSIZE(ob.py(), desc) as usize,
npyffi::PyDataType_ALIGNMENT(ob.py(), desc) as usize,
(*arr).data,
(*arr).nd,
(*arr).flags,
@@ -347,7 +347,7 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {

struct PyArrayStr(Vec<String>);
impl FromPyObject<'_> for PyArrayStr {
fn extract(ob: &PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
let array = ob.downcast::<PyArray1<PyObject>>()?;
let seq = array
.readonly()
@@ -370,7 +370,7 @@ impl From<PyArrayStr> for tk::InputSequence<'_> {

struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
return Ok(Self(seq.into()));
}
@@ -400,17 +400,17 @@ impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {

struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
if let Ok(i) = ob.extract::<TextInputSequence>() {
return Ok(Self(i.into()));
}
if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
return Ok(Self((i1, i2).into()));
}
if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
if let Ok(arr) = ob.downcast::<PyList>() {
if arr.len() == 2 {
let first = arr[0].extract::<TextInputSequence>()?;
let second = arr[1].extract::<TextInputSequence>()?;
let first = arr.get_item(0)?.extract::<TextInputSequence>()?;
let second = arr.get_item(1)?.extract::<TextInputSequence>()?;
return Ok(Self((first, second).into()));
}
}
@@ -426,18 +426,18 @@ impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
}
struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult<Self> {
if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() {
return Ok(Self(i.into()));
}
if let Ok((i1, i2)) = ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
{
return Ok(Self((i1, i2).into()));
}
if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
if let Ok(arr) = ob.downcast::<PyList>() {
if arr.len() == 2 {
let first = arr[0].extract::<PreTokenizedInputSequence>()?;
let second = arr[1].extract::<PreTokenizedInputSequence>()?;
let first = arr.get_item(0)?.extract::<PreTokenizedInputSequence>()?;
let second = arr.get_item(1)?.extract::<PreTokenizedInputSequence>()?;
return Ok(Self((first, second).into()));
}
}
@@ -498,9 +498,9 @@ impl PyTokenizer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
self.tokenizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
self.tokenizer = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle Tokenizer: {}",
e
@@ -1030,7 +1030,7 @@ impl PyTokenizer {
fn encode_batch(
&self,
py: Python<'_>,
input: Vec<&PyAny>,
input: Bound<'_, PyList>,
Collaborator: actually this broke tokenizers because it only supports PyList now 😓 looking into a fix!

Collaborator: @diliop we probably need this to be PySequence, but I am not sure about the fix

Contributor Author: Hey, I just saw this :( Yeah, I assumed that Vec only accepted a list from Python, hence PyList, but if you need tuples then PySequence is indeed the way to go. Py* types give you the benefit of the Python type check at almost zero cost, as you mentioned in your PR, so they should be preferred where possible. That said, there should be tests covering this, so I can pick this up over the weekend and make sure anything else I changed is also covered.

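To make that Py*-level check concrete, here is a minimal sketch, under my own assumptions, of the PySequence idea floated above (later comments walk this back in favor of a different approach); the function name is hypothetical and not part of the crate.

```rust
// Minimal sketch: accept any Python sequence (list or tuple) via a
// downcast to PySequence, so the type check happens at the Python C-API
// level instead of eagerly converting the data into a Vec.
use pyo3::prelude::*;
use pyo3::types::PySequence;

#[pyfunction]
fn batch_len(input: Bound<'_, PyAny>) -> PyResult<usize> {
    // `downcast` relies on the sequence protocol check, so lists and tuples
    // both pass, while non-sequence inputs raise a TypeError on the Python side.
    let seq = input.downcast::<PySequence>()?;
    seq.len()
}
```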
Collaborator: Cool, yeah, I was in a rush to fix this and forgot about the tests; it would be super nice if you want to add them 🤗

Contributor Author: Took a quick stab at adding tests, but from the looks of it I will need to spend a bit more time here to do this right, since PySequence is not the right solution after all. The TL;DR is that the change I made was not caught by tests because the test covering this line was turned off (here). Turning the test back on now fails on parsing ndarray as an additional input type. So encode_batch and encode_batch_fast need to support list, tuple, and ndarray. I think I can support all three by changing the input arg type to something like Vec<Bound<'_, PyAny>>, with some changes in PreTokenizedEncodeInput and TextEncodeInput. I will have some time to work on this over the weekend, so hopefully I will have a fix soon.

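A rough sketch of that direction, with hypothetical names (item_to_text stands in for the crate's TextEncodeInput/PreTokenizedEncodeInput conversions): taking Vec<Bound<'_, PyAny>> lets pyo3 iterate a list, tuple, or 1-D ndarray into the Vec, and each element is then converted on its own.

```rust
// Sketch only: each batch item arrives as an untyped Bound<PyAny> and is
// converted per element, so the outer container can be any Python
// sequence (list, tuple, or a 1-D ndarray of str).
use pyo3::prelude::*;

fn item_to_text(item: &Bound<'_, PyAny>) -> PyResult<String> {
    // Stand-in for the real per-item extraction logic.
    item.extract::<String>()
}

#[pyfunction]
fn encode_batch_sketch(input: Vec<Bound<'_, PyAny>>) -> PyResult<Vec<String>> {
    input.iter().map(item_to_text).collect()
}
```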
Contributor Author: #1679 should be the fix. I am still planning to add more tests, but my hope is this would be enough to restore the previous functionality.

Collaborator: Thanks a lot for diving into it!

is_pretokenized: bool,
add_special_tokens: bool,
) -> PyResult<Vec<PyEncoding>> {
@@ -1091,7 +1091,7 @@ impl PyTokenizer {
fn encode_batch_fast(
&self,
py: Python<'_>,
input: Vec<&PyAny>,
input: Bound<'_, PyList>,
is_pretokenized: bool,
add_special_tokens: bool,
) -> PyResult<Vec<PyEncoding>> {
4 changes: 2 additions & 2 deletions bindings/python/src/trainers.rs
@@ -55,9 +55,9 @@ impl PyTrainer {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
match state.extract::<&[u8]>(py) {
Ok(s) => {
let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
let unpickled = serde_json::from_slice(s).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PyTrainer: {}",
e
6 changes: 3 additions & 3 deletions bindings/python/src/utils/normalization.rs
@@ -60,7 +60,7 @@ pub enum PyRange<'s> {
#[pyo3(annotation = "Tuple[uint, uint]")]
Range(usize, usize),
#[pyo3(annotation = "slice")]
Slice(&'s PySlice),
Slice(Bound<'s, PySlice>),
}
impl PyRange<'_> {
pub fn to_range(&self, max_len: usize) -> PyResult<std::ops::Range<usize>> {
@@ -83,7 +83,7 @@ impl PyRange<'_> {
}
PyRange::Range(s, e) => Ok(*s..*e),
PyRange::Slice(s) => {
let r = s.indices(max_len as std::os::raw::c_long)?;
let r = s.indices(max_len.try_into()?)?;
Ok(r.start as usize..r.stop as usize)
}
}
@@ -94,7 +94,7 @@ impl PyRange<'_> {
pub struct PySplitDelimiterBehavior(pub SplitDelimiterBehavior);

impl FromPyObject<'_> for PySplitDelimiterBehavior {
fn extract(obj: &PyAny) -> PyResult<Self> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<&str>()?;

Ok(Self(match s {
6 changes: 3 additions & 3 deletions bindings/python/src/utils/pretokenization.rs
@@ -56,7 +56,7 @@ fn tokenize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResul
ToPyResult(pretok.tokenize(|normalized| {
let output = func.call((normalized.get(),), None)?;
Ok(output
.extract::<&PyList>()?
.extract::<Bound<PyList>>()?
.into_iter()
.map(|obj| Ok(Token::from(obj.extract::<PyToken>()?)))
.collect::<PyResult<Vec<_>>>()?)
@@ -69,7 +69,7 @@
#[derive(Clone)]
pub struct PyOffsetReferential(OffsetReferential);
impl FromPyObject<'_> for PyOffsetReferential {
fn extract(obj: &PyAny) -> PyResult<Self> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<&str>()?;

Ok(Self(match s {
@@ -85,7 +85,7 @@ impl FromPyObject<'_> for PyOffsetReferential {
#[derive(Clone)]
pub struct PyOffsetType(OffsetType);
impl FromPyObject<'_> for PyOffsetType {
fn extract(obj: &PyAny) -> PyResult<Self> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<&str>()?;

Ok(Self(match s {