Skip to content

Commit bc77907

Browse files
authored
Merge pull request #26 from ahuang11/fix_max_frames
Fix max frames
2 parents 2c14def + a0dd967 commit bc77907

File tree

10 files changed

+184
-17
lines changed

10 files changed

+184
-17
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: Build
22

3-
on: [push, pull_request]
3+
on: [pull_request]
44

55
jobs:
66
test:

docs/supported_formats.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
StreamJoy supports a variety of input types!
44

5-
## 📋 List of Images or URLs
5+
## 📋 List of Images, GIFs, Videos, or URLs
66

77
```python
88
from streamjoy import stream
@@ -15,7 +15,7 @@ stream([URL_FMT.format(day=day) for day in range(1, 31)], uri="2024_jan_sea_ice.
1515
<source src="https://github.com/ahuang11/streamjoy/assets/15331990/7c933cd4-aa15-461a-af79-f508d9d76aa5" type="video/mp4">
1616
</video>
1717

18-
## 📁 Directory of Images or URLs
18+
## 📁 Directory of Images, GIFs, Videos, or URLs
1919

2020
```python
2121
from streamjoy import stream
@@ -29,6 +29,20 @@ stream(URL_DIR, uri="air_temperature.mp4", pattern="air.sig995.194*.nc")
2929
<source src="https://github.com/ahuang11/streamjoy/assets/15331990/93cb0c1b-46d3-48e6-be2c-e3b1487f9117" type="video/mp4">
3030
</video>
3131

32+
## 🧮 Numpy Array
33+
34+
```python
35+
from streamjoy import stream
36+
import imageio.v3 as iio
37+
38+
array = iio.imread("imageio:newtonscradle.gif") # is a 4D numpy array
39+
stream(array, max_frames=-1).write("newtonscradle.mp4")
40+
```
41+
42+
<video controls="true" allowfullscreen="true">
43+
<source src="https://github.com/ahuang11/streamjoy/assets/15331990/7687e951-654c-4719-b50a-4aabc0ddf2e4" type="video/mp4">
44+
</video>
45+
3246
## 🐼 Pandas DataFrame or Series
3347

3448
```python

streamjoy/_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,6 @@ def imread_with_pause(
327327
def subset_resources_renderer_iterables(
328328
resources: Any, renderer_iterables: list[Any], max_frames: int
329329
):
330-
if len(resources) > max_frames and max_frames != -1:
331-
color = config["logging_warning_color"]
332-
reset = config["logging_reset_color"]
333-
logging.warning(
334-
f"There are a total of {len(resources)} frames, "
335-
f"but streaming only {color}{max_frames}{reset}. "
336-
f"Set max_frames to -1 to stream all frames."
337-
)
338330
resources = resources[: max_frames or max_frames]
339331
renderer_iterables = [
340332
iterable[: len(resources)] for iterable in renderer_iterables or []

streamjoy/serializers.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,50 @@ def _select_obj_handler(resources: Any) -> MediaStream:
4848
if isinstance(resources, (Path, str)):
4949
return serialize_paths
5050

51+
resources_type = type(resources)
52+
module = getattr(resources_type, "__module__").split(".", maxsplit=1)[0]
53+
type_ = resources_type.__name__
5154
for class_or_package_name, function_name in obj_handlers.items():
52-
module = getattr(resources, "__module__", "").split(".", maxsplit=1)[0]
53-
type_ = type(resources).__name__
5455
if (
5556
f"{module}.{type_}" == class_or_package_name
5657
or module == class_or_package_name
5758
):
5859
return globals()[function_name]
5960

6061
raise ValueError(
61-
f"Could not find a method to handle {type(resources)}; "
62+
f"Could not find a method to handle {resources_type}; "
6263
f"supported classes/packages are {list(obj_handlers.keys())}."
6364
)
6465

6566

67+
def serialize_numpy(
68+
stream_cls,
69+
resources: np.ndarray,
70+
renderer: Callable | None = None,
71+
renderer_iterables: list[Any] | None = None,
72+
renderer_kwargs: dict | None = None,
73+
**kwargs,
74+
) -> Serialized:
75+
"""
76+
Serialize numpy arrays for streaming or rendering.
77+
78+
Args:
79+
stream_cls: The class reference used for logging and utility functions.
80+
resources: The numpy array to be serialized.
81+
renderer: The rendering function to use on the array.
82+
renderer_iterables: Additional iterable arguments to pass to the renderer.
83+
renderer_kwargs: Additional keyword arguments to pass to the renderer.
84+
**kwargs: Additional keyword arguments, including 'dim' and 'var' for xarray selection.
85+
86+
Returns:
87+
A tuple containing the serialized resources, renderer, renderer_iterables, renderer_kwargs, and any additional keyword arguments.
88+
"""
89+
resources = [resource for resource in resources]
90+
renderer_kwargs = renderer_kwargs or {}
91+
renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))
92+
return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)
93+
94+
6695
def serialize_xarray(
6796
stream_cls,
6897
resources: xr.Dataset | xr.DataArray,
@@ -230,8 +259,8 @@ def serialize_polars(
230259
groupby = kwargs.get("groupby")
231260

232261
if groupby:
233-
group_sizes = resources.groupby(groupby).agg(pl.count()).sort(by="count")
234-
total_frames = group_sizes.select(pl.col("count").max()).to_numpy()[0, 0]
262+
group_sizes = resources.groupby(groupby).agg(pl.len())
263+
total_frames = group_sizes.select(pl.col("len").max()).to_numpy()[0, 0]
235264
else:
236265
total_frames = len(resources)
237266

@@ -276,6 +305,12 @@ def serialize_polars(
276305
if "ylabel" not in renderer_kwargs:
277306
renderer_kwargs["ylabel"] = renderer_kwargs["y"].title().replace("_", " ")
278307

308+
if kwargs.get("processes"):
309+
logging.warning(
310+
"Polars (HoloViews) rendering does not support processes; "
311+
"setting processes=False."
312+
)
313+
kwargs["processes"] = False
279314
return Serialized(
280315
resources_expanded, renderer, renderer_iterables, renderer_kwargs, kwargs
281316
)

streamjoy/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"pandas.Series": "serialize_pandas",
4949
"holoviews": "serialize_holoviews",
5050
"polars.DataFrame": "serialize_polars",
51-
"polars.Series": "serialize_polars",
51+
"numpy.ndarray": "serialize_numpy",
5252
}
5353

5454
file_handlers = {

streamjoy/streams.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,25 @@ def __init__(self, resources: list[Any] | None = None, **params) -> None:
196196

197197
super().__init__(**params)
198198

199+
@classmethod
200+
def from_numpy(
201+
cls,
202+
array: np.ndarray,
203+
renderer: Callable | None = None,
204+
renderer_iterables: list[Any] | None = None,
205+
renderer_kwargs: dict | None = None,
206+
**kwargs,
207+
) -> MediaStream:
208+
serialized = serialize_appropriately(
209+
cls,
210+
resources=array,
211+
renderer=renderer,
212+
renderer_iterables=renderer_iterables,
213+
renderer_kwargs=renderer_kwargs,
214+
**kwargs,
215+
)
216+
return cls(**serialized.param.values(), **serialized.kwargs)
217+
199218
@classmethod
200219
def from_xarray(
201220
cls,
@@ -538,6 +557,13 @@ def write(
538557
if resources is None:
539558
resources = self.resources
540559
renderer_iterables = self.renderer_iterables
560+
max_frames = _utils.get_max_frames(
561+
len(resources), kwargs.get("max_frames", self.max_frames)
562+
)
563+
kwargs["max_frames"] = max_frames
564+
resources, renderer_iterables = _utils.subset_resources_renderer_iterables(
565+
resources, renderer_iterables, max_frames
566+
)
541567
else:
542568
serialized = serialize_appropriately(
543569
self, resources, renderer, renderer_kwargs, **kwargs

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

33
import hvplot.xarray # noqa: F401
4+
import imageio.v3 as iio
45
import pandas as pd
56
import polars as pl
67
import pytest
@@ -16,6 +17,11 @@
1617
PARQUET_PATH = DATA_DIR / "gapminder.parquet"
1718

1819

20+
@pytest.fixture
21+
def array():
22+
return iio.imread("imageio:newtonscradle.gif")
23+
24+
1925
@pytest.fixture
2026
def ds():
2127
return xr.open_zarr(ZARR_PATH)

tests/test_polars.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from io import BytesIO
2+
3+
import streamjoy.polars # noqa: F401
4+
5+
6+
class TestPolars:
7+
def test_dataframe(self, pl_df):
8+
stream = pl_df.streamjoy(groupby="Country")
9+
assert stream.renderer_kwargs == {
10+
"groupby": "Country",
11+
"x": "Year",
12+
"y": "fertility",
13+
"xlabel": "Year",
14+
"ylabel": "Fertility",
15+
}
16+
assert isinstance(stream.write(), BytesIO)

tests/test_serializers.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from streamjoy.models import Serialized
2+
from streamjoy.serializers import (
3+
serialize_holoviews,
4+
serialize_numpy,
5+
serialize_pandas,
6+
serialize_polars,
7+
serialize_xarray,
8+
)
9+
10+
11+
class TestSerializeNumpy:
12+
def test_serialize_numpy(self, array):
13+
serialized = serialize_numpy(None, array)
14+
assert isinstance(serialized, Serialized)
15+
assert len(serialized.resources) == 36
16+
assert isinstance(serialized.resources, list)
17+
assert not serialized.renderer
18+
assert serialized.renderer_iterables is None
19+
assert isinstance(serialized.renderer_kwargs, dict)
20+
assert isinstance(serialized.kwargs, dict)
21+
22+
23+
class TestSerializeXarray:
24+
def test_serialize_xarray(self, ds):
25+
serialized = serialize_xarray(None, ds)
26+
assert isinstance(serialized, Serialized)
27+
assert len(serialized.resources) == 3
28+
assert isinstance(serialized.resources, list)
29+
assert callable(serialized.renderer)
30+
assert serialized.renderer_iterables is None
31+
assert isinstance(serialized.renderer_kwargs, dict)
32+
assert isinstance(serialized.kwargs, dict)
33+
34+
35+
class TestSerializePandas:
36+
def test_serialize_pandas(self, df):
37+
serialized = serialize_pandas(None, df)
38+
assert isinstance(serialized, Serialized)
39+
assert len(serialized.resources) == 3
40+
assert isinstance(serialized.resources, list)
41+
assert callable(serialized.renderer)
42+
assert serialized.renderer_iterables is None
43+
assert isinstance(serialized.renderer_kwargs, dict)
44+
assert isinstance(serialized.kwargs, dict)
45+
46+
47+
class TestSerializePolars:
48+
def test_serialize_polars(self, pl_df):
49+
serialized = serialize_polars(None, pl_df)
50+
assert isinstance(serialized, Serialized)
51+
assert len(serialized.resources) == 3
52+
assert isinstance(serialized.resources, list)
53+
assert callable(serialized.renderer)
54+
assert serialized.renderer_iterables is None
55+
assert isinstance(serialized.renderer_kwargs, dict)
56+
assert isinstance(serialized.kwargs, dict)
57+
58+
59+
class TestSerializeHoloviews:
60+
def test_serialize_holoviews(self, hmap):
61+
serialized = serialize_holoviews(None, hmap)
62+
assert isinstance(serialized, Serialized)
63+
assert len(serialized.resources) == 20
64+
assert isinstance(serialized.resources, list)
65+
assert callable(serialized.renderer)
66+
assert serialized.renderer_iterables is None
67+
assert isinstance(serialized.renderer_kwargs, dict)
68+
assert isinstance(serialized.kwargs, dict)

tests/test_streams.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ def _assert_stream_and_props(self, sj, stream_cls):
1414
props.n_images == 3
1515
return props
1616

17+
def test_from_numpy(self, stream_cls, array):
18+
sj = stream_cls.from_numpy(array)
19+
self._assert_stream_and_props(sj, stream_cls)
20+
1721
def test_from_pandas(self, stream_cls, df):
1822
sj = stream_cls.from_pandas(df)
1923
self._assert_stream_and_props(sj, stream_cls)
@@ -69,6 +73,12 @@ def test_holoviews_bokeh_backend(self, stream_cls, ds):
6973
props = self._assert_stream_and_props(sj, stream_cls)
7074
assert props.shape[1] == 300
7175

76+
def test_write_max_frames(self, stream_cls, df):
77+
sj = stream_cls.from_pandas(df, max_frames=3)
78+
buf = sj.write(max_frames=2)
79+
props = improps(buf)
80+
assert props.n_images == 2
81+
7282

7383
class TestGifStream(AbstractTestMediaStream):
7484
@pytest.fixture(scope="class")

0 commit comments

Comments
 (0)