Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/how_do_i.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,37 @@ stream(..., client=client)

## 🪣 Read & write files on a remote filesystem with `fsspec_fs`

To read and write files on a remote filesystem, use `fsspec_fs` to specify the filesystem.

A scratch directory must be provided; be sure to prefix the bucket name.

```python
fs = fsspec.filesystem('s3', anon=False)
stream(..., fsspec_fs=fs, scratch_dir="bucket-name/streamjoy_scratch")
```

## 🚗 Use a custom webdriver to render HoloViews

By default, StreamJoy uses Firefox as the default headless webdriver to render HoloViews objects into images.

If you want to use Chrome instead, you can pass `webdriver="chrome"`.

If you want to use a different webdriver, you can pass a custom function to `webdriver`.

```python
def get_webdriver():
from selenium.webdriver.firefox.options import Options
from selenium.webdriver.firefox.webdriver import Service, WebDriver
from webdriver_manager.firefox import GeckoDriverManager

options = Options()
options.add_argument("--headless")
options.add_argument("--disable-extensions")
executable_path = GeckoDriverManager().install()
driver = WebDriver(
service=Service(executable_path), options=options
)
return driver

stream(..., webdriver=get_webdriver)
```
51 changes: 51 additions & 0 deletions streamjoy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
except ImportError:
xr = None

try:
from selenium.webdriver.remote.webdriver import BaseWebDriver
except ImportError:
BaseWebDriver = None


def update_logger(
level: str | None = None,
Expand Down Expand Up @@ -332,3 +337,49 @@ def subset_resources_renderer_iterables(
iterable[: len(resources)] for iterable in renderer_iterables or []
]
return resources, renderer_iterables


def get_webdriver_path(webdriver: str):
if webdriver.lower() == "chrome":
from webdriver_manager.chrome import ChromeDriverManager

webdriver_path = ChromeDriverManager().install()
elif webdriver.lower() == "firefox":
from webdriver_manager.firefox import GeckoDriverManager

webdriver_path = GeckoDriverManager().install()
return webdriver_path


def get_webdriver(webdriver: tuple[str, str] | Callable) -> BaseWebDriver:
if isinstance(webdriver, Callable):
return webdriver()

webdriver_key, webdriver_path = webdriver
if webdriver_key.lower() == "chrome":
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.webdriver import Service, WebDriver

options = Options()
options.add_argument("--headless")
options.add_argument("--disable-extensions")
webdriver_path = webdriver_path or get_webdriver_path("chrome")
driver = WebDriver(service=Service(webdriver_path), options=options)

elif webdriver_key.lower() == "firefox":
from selenium.webdriver.firefox.options import Options
from selenium.webdriver.firefox.webdriver import Service, WebDriver

options = Options()
options.add_argument("--headless")
options.add_argument("--disable-extensions")
webdriver_path = webdriver_path or get_webdriver_path("firefox")
driver = WebDriver(service=Service(webdriver_path), options=options)

else:
raise NotImplementedError(
f"Webdriver {webdriver_key} not supported; "
f"use 'chrome' or 'firefox', or pass a custom callable."
)

return driver
2 changes: 2 additions & 0 deletions streamjoy/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def serialize_polars(
in_memory=kwargs.get("in_memory"),
scratch_dir=kwargs.get("scratch_dir"),
fsspec_fs=kwargs.get("fsspec_fs"),
webdriver=renderer_kwargs.pop("webdriver", None),
)(default_polars_renderer)
numeric_cols = [
col
Expand Down Expand Up @@ -379,6 +380,7 @@ def _select_element(hv_obj, key):
in_memory=kwargs.get("in_memory"),
scratch_dir=kwargs.get("scratch_dir"),
fsspec_fs=kwargs.get("fsspec_fs"),
webdriver=renderer_kwargs.pop("webdriver", None),
)(default_holoviews_renderer)
clims = {}
for hv_el in hv_obj.traverse(full_breadth=False):
Expand Down
2 changes: 2 additions & 0 deletions streamjoy/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"max_files": 2,
# matplotlib
"max_open_warning": 100,
# holoviews
"webdriver": "firefox",
# output
"in_memory": False,
"scratch_dir": "streamjoy_scratch",
Expand Down
21 changes: 10 additions & 11 deletions streamjoy/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def wrap_holoviews(
in_memory: bool = False,
scratch_dir: str | Path | None = None,
fsspec_fs: Any | None = None,
webdriver: str | Callable | None = None,
) -> Callable:
"""
Wraps a function used to render a holoviews object so that
Expand All @@ -87,10 +88,16 @@ def wrap_holoviews(
Args:
in_memory: Whether to render the object in-memory.
scratch_dir: The scratch directory to use.
fsspec_fs: The fsspec filesystem to use.

Returns:
The wrapped function.
"""

webdriver = _utils.get_config_default("webdriver", webdriver, warn=False)
if isinstance(webdriver, str):
webdriver = (webdriver, _utils.get_webdriver_path(webdriver))

if in_memory:
raise ValueError("Holoviews renderer does not support in-memory rendering.")

Expand All @@ -116,18 +123,10 @@ def wrapped(*args, **kwargs) -> Path | BytesIO:
)
if backend == "bokeh":
from bokeh.io.export import get_screenshot_as_png
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.webdriver import Service, WebDriver
from webdriver_manager.chrome import ChromeDriverManager

options = Options()
options.add_argument("--headless")
options.add_argument("--disable-extensions")
with WebDriver(
service=Service(ChromeDriverManager().install()), options=options
) as webdriver:

with _utils.get_webdriver(webdriver) as driver:
image = get_screenshot_as_png(
hv.render(hv_obj, backend=backend), driver=webdriver
hv.render(hv_obj, backend=backend), driver=driver
)
if fsspec_fs:
with fsspec_fs.open(uri, "wb") as f:
Expand Down