Skip to content

Commit 1729c81

Browse files
authored
fix: writer support for numpy scalars (#72)
1 parent b5558b3 commit 1729c81

File tree

4 files changed

+88
-24
lines changed

4 files changed

+88
-24
lines changed

src/reader.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use crate::{
66
use delegate::delegate;
77
use num_traits::Zero;
88
use numpy::{
9-
ndarray::{self, Array0},
10-
Element, PyArray0,
9+
ndarray::{self},
10+
Element,
1111
};
1212
use omfiles_rs::{
1313
reader::{OmFileArray as OmFileArrayRs, OmFileReader as OmFileReaderRs},
@@ -124,10 +124,23 @@ impl OmFileReader {
124124
.map_err(|_| Self::only_scalars_error())?;
125125

126126
let value = scalar_reader.read_scalar::<T>();
127-
let array_base = Array0::from_elem([], value.unwrap());
128-
let py_scalar = PyArray0::from_owned_array(py, array_base);
129127

130-
return Ok(py_scalar.into_any().unbind());
128+
let numpy = py.import("numpy")?;
129+
let np_type = match std::any::type_name::<T>() {
130+
"f32" => numpy.getattr("float32")?,
131+
"f64" => numpy.getattr("float64")?,
132+
"i8" => numpy.getattr("int8")?,
133+
"u8" => numpy.getattr("uint8")?,
134+
"i16" => numpy.getattr("int16")?,
135+
"u16" => numpy.getattr("uint16")?,
136+
"i32" => numpy.getattr("int32")?,
137+
"u32" => numpy.getattr("uint32")?,
138+
"i64" => numpy.getattr("int64")?,
139+
"u64" => numpy.getattr("uint64")?,
140+
_ => return Err(PyErr::new::<PyValueError, _>("Unsupported type")),
141+
};
142+
let py_scalar = np_type.call1((value,))?;
143+
Ok(py_scalar.into())
131144
})
132145
}
133146
}
@@ -536,7 +549,7 @@ impl OmFileReader {
536549
/// Read the scalar value of the variable.
537550
///
538551
/// Returns:
539-
/// object: The scalar value as a Python object (str, int, or float).
552+
/// object: The scalar value as a numpy scalar or a Python string.
540553
///
541554
/// Raises:
542555
/// ValueError: If the variable is not a scalar.

src/reader_async.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use async_lock::RwLock;
77
use delegate::delegate;
88
use num_traits::Zero;
99
use numpy::{
10-
ndarray::{self, Array0},
11-
Element, PyArray0,
10+
ndarray::{self},
11+
Element,
1212
};
1313
use omfiles_rs::{
1414
reader_async::OmFileReaderAsync as OmFileReaderAsyncRs,
@@ -135,10 +135,23 @@ impl OmFileReaderAsync {
135135
.map_err(|_| Self::only_scalars_error())?;
136136

137137
let value = scalar_reader.read_scalar::<T>();
138-
let array_base = Array0::from_elem([], value.unwrap());
139-
let py_scalar = PyArray0::from_owned_array(py, array_base);
140138

141-
return Ok(py_scalar.into_any().unbind());
139+
let numpy = py.import("numpy")?;
140+
let np_type = match std::any::type_name::<T>() {
141+
"f32" => numpy.getattr("float32")?,
142+
"f64" => numpy.getattr("float64")?,
143+
"i8" => numpy.getattr("int8")?,
144+
"u8" => numpy.getattr("uint8")?,
145+
"i16" => numpy.getattr("int16")?,
146+
"u16" => numpy.getattr("uint16")?,
147+
"i32" => numpy.getattr("int32")?,
148+
"u32" => numpy.getattr("uint32")?,
149+
"i64" => numpy.getattr("int64")?,
150+
"u64" => numpy.getattr("uint64")?,
151+
_ => return Err(PyErr::new::<PyValueError, _>("Unsupported type")),
152+
};
153+
let py_scalar = np_type.call1((value,))?;
154+
Ok(py_scalar.into())
142155
})
143156
}
144157
}
@@ -486,7 +499,7 @@ impl OmFileReaderAsync {
486499
/// Read the scalar value of the variable.
487500
///
488501
/// Returns:
489-
/// object: The scalar value as a Python object (str, int, or float).
502+
/// object: The scalar value as a numpy scalar or a Python string.
490503
///
491504
/// Raises:
492505
/// ValueError: If the variable is not a scalar.

src/writer.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,35 @@ impl OmFileWriter {
332332
.map(Into::into)
333333
.collect();
334334

335+
let py = value.py();
336+
337+
// make an instance check against numpy scalar types
338+
macro_rules! check_numpy_type {
339+
($numpy:expr, $type_name:literal, $rust_type:ty) => {
340+
if let Ok(numpy_type) = $numpy.getattr($type_name) {
341+
if value.is_instance(&numpy_type)? {
342+
let scalar_value: $rust_type = value.call_method0("item")?.extract()?;
343+
return self.store_scalar(scalar_value, name, &children);
344+
}
345+
}
346+
};
347+
}
348+
349+
// Try to import numpy and check for numpy scalar types
350+
if let Ok(numpy) = py.import("numpy") {
351+
check_numpy_type!(numpy, "int8", i8);
352+
check_numpy_type!(numpy, "uint8", u8);
353+
check_numpy_type!(numpy, "int16", i16);
354+
check_numpy_type!(numpy, "uint16", u16);
355+
check_numpy_type!(numpy, "int32", i32);
356+
check_numpy_type!(numpy, "uint32", u32);
357+
check_numpy_type!(numpy, "int64", i64);
358+
check_numpy_type!(numpy, "uint64", u64);
359+
check_numpy_type!(numpy, "float32", f32);
360+
check_numpy_type!(numpy, "float64", f64);
361+
}
362+
363+
// Fall back to Python built-in types
335364
let result = if let Ok(_value) = value.extract::<String>() {
336365
self.store_scalar(value.to_string(), name, &children)?
337366
} else if let Ok(value) = value.extract::<f64>() {

tests/test_read_write.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def test_write_hierarchical_file(empty_temp_om_file):
7171
child2_var = writer.write_array(child2_data, chunks=[1, 1], name="child2", scale_factor=100000.0)
7272

7373
# Write attributes and get their variables
74-
meta1_var = writer.write_scalar(42.0, name="metadata1")
75-
meta2_var = writer.write_scalar(123, name="metadata2")
74+
meta1_var = writer.write_scalar(np.float32(42.0), name="metadata1")
75+
meta2_var = writer.write_scalar(np.int32(123), name="metadata2")
7676
meta3_var = writer.write_scalar("blub", name="metadata3")
7777

7878
# Write child1 array with attribute children
@@ -115,21 +115,30 @@ def test_write_hierarchical_file(empty_temp_om_file):
115115
assert read_child2.dtype == np.float32
116116

117117
# Verify metadata attributes
118-
metadata_reader0 = child1_reader.get_child_by_index(0)
119-
metadata = metadata_reader0.read_scalar()
120-
assert metadata == 42.0
121-
assert metadata_reader0.dtype == np.float64
122-
123-
metadata_reader2 = child1_reader.get_child_by_index(2)
124-
metadata = metadata_reader2.read_scalar()
125-
assert metadata == "blub"
126-
assert metadata_reader2.dtype == str
118+
metadata_reader1 = child1_reader.get_child_by_index(0)
119+
metadata1 = metadata_reader1.read_scalar()
120+
assert metadata1 == 42.0
121+
assert type(metadata1) == np.float32
122+
assert metadata_reader1.dtype == np.float32
123+
124+
metadata_reader2 = child1_reader.get_child_by_index(1)
125+
metadata2 = metadata_reader2.read_scalar()
126+
assert metadata2 == 123
127+
assert type(metadata2) == np.int32
128+
assert metadata_reader2.dtype == np.int32
129+
130+
metadata_reader3 = child1_reader.get_child_by_index(2)
131+
metadata3 = metadata_reader3.read_scalar()
132+
assert metadata3 == "blub"
133+
assert type(metadata3) == str
134+
assert metadata_reader3.dtype == str
127135

128136
reader.close()
129137
child1_reader.close()
130138
child2_reader.close()
131-
metadata_reader0.close()
139+
metadata_reader1.close()
132140
metadata_reader2.close()
141+
metadata_reader3.close()
133142

134143

135144
@pytest.mark.asyncio

0 commit comments

Comments
 (0)