Skip to content

Fix type hints after adding Branca type checking #2060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 29, 2024
2 changes: 1 addition & 1 deletion folium/elements.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to rename these methods, since they really do something different than branca.Element.render. See the discussion I started on the signature redefinition.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand all the typing issues fully, but in general it looks good. I don't mind going ahead with this as is to fix the immediate issues. We can discuss my proposal to rename render from MacroElement and its children at a later moment. Whatever you think is best for now.

Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class JSCSSMixin(Element):
default_js: List[Tuple[str, str]] = []
default_css: List[Tuple[str, str]] = []

def render(self, **kwargs) -> None:
def render(self, **kwargs):
figure = self.get_root()
assert isinstance(
figure, Figure
Expand Down
78 changes: 50 additions & 28 deletions folium/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,24 @@
import numpy as np
import requests
from branca.colormap import ColorMap, LinearColormap, StepColormap
from branca.element import Element, Figure, Html, IFrame, JavascriptLink, MacroElement
from branca.element import (
Div,
Element,
Figure,
Html,
IFrame,
JavascriptLink,
MacroElement,
)
from branca.utilities import color_brewer

from folium.elements import JSCSSMixin
from folium.folium import Map
from folium.map import FeatureGroup, Icon, Layer, Marker, Popup, Tooltip
from folium.template import Template
from folium.utilities import (
TypeBoundsReturn,
TypeContainer,
TypeJsonValue,
TypeLine,
TypePathOptions,
Expand Down Expand Up @@ -165,7 +175,7 @@ def __init__(
self.top = _parse_size(top)
self.position = position

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
super().render(**kwargs)

Expand Down Expand Up @@ -284,9 +294,15 @@ def __init__(
self.top = _parse_size(top)
self.position = position

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self._parent.html.add_child(
parent = self._parent
if not isinstance(parent, (Figure, Div, Popup)):
raise TypeError(
"VegaLite elements can only be added to a Figure, Div, or Popup"
)

parent.html.add_child(
Element(
Template(
"""
Expand Down Expand Up @@ -331,7 +347,7 @@ def render(self, **kwargs) -> None:
embed_vegalite = embed_mapping.get(
self.vegalite_major_version, self._embed_vegalite_v2
)
embed_vegalite(figure)
embed_vegalite(figure=figure, parent=parent)

@property
def vegalite_major_version(self) -> Optional[int]:
Expand All @@ -342,8 +358,8 @@ def vegalite_major_version(self) -> Optional[int]:

return int(schema.split("/")[-1].split(".")[0].lstrip("v"))

def _embed_vegalite_v5(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v5(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
Expand All @@ -356,8 +372,8 @@ def _embed_vegalite_v5(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v4(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v4(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
Expand All @@ -370,8 +386,8 @@ def _embed_vegalite_v4(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v3(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v3(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@4"), name="vega"
Expand All @@ -384,8 +400,8 @@ def _embed_vegalite_v3(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v2(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v2(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@3"), name="vega"
Expand All @@ -398,8 +414,8 @@ def _embed_vegalite_v2(self, figure: Figure) -> None:
name="vega-embed",
)

def _vega_embed(self) -> None:
self._parent.script.add_child(
def _vega_embed(self, parent: TypeContainer) -> None:
parent.script.add_child(
Element(
Template(
"""
Expand All @@ -412,8 +428,8 @@ def _vega_embed(self) -> None:
name=self.get_name(),
)

def _embed_vegalite_v1(self, figure: Figure) -> None:
self._parent.script.add_child(
def _embed_vegalite_v1(self, figure: Figure, parent: TypeContainer) -> None:
parent.script.add_child(
Element(
Template(
"""
Expand All @@ -436,19 +452,19 @@ def _embed_vegalite_v1(self, figure: Figure) -> None:
figure.header.add_child(
JavascriptLink("https://cdnjs.cloudflare.com/ajax/libs/vega/2.6.5/vega.js"),
name="vega",
) # noqa
)
figure.header.add_child(
JavascriptLink(
"https://cdnjs.cloudflare.com/ajax/libs/vega-lite/1.3.1/vega-lite.js"
),
name="vega-lite",
) # noqa
)
figure.header.add_child(
JavascriptLink(
"https://cdnjs.cloudflare.com/ajax/libs/vega-embed/2.2.0/vega-embed.js"
),
name="vega-embed",
) # noqa
)


class GeoJson(Layer):
Expand Down Expand Up @@ -820,7 +836,7 @@ def _get_self_bounds(self) -> List[List[Optional[float]]]:
"""
return get_bounds(self.data, lonlat=True)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
self.parent_map = get_obj_in_upper_tree(self, Map)
# Need at least one feature, otherwise style mapping fails
if (self.style or self.highlight) and self.data["features"]:
Expand Down Expand Up @@ -1041,12 +1057,12 @@ def recursive_get(data, keys):
self.style_function(feature)
) # noqa

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self.style_data()
super().render(**kwargs)

def get_bounds(self) -> List[List[float]]:
def get_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]]
Expand Down Expand Up @@ -1146,6 +1162,7 @@ def __init__(

def warn_for_geometry_collections(self) -> None:
"""Checks for GeoJson GeometryCollection features to warn user about incompatibility."""
assert isinstance(self._parent, GeoJson)
geom_collections = [
feature.get("properties") if feature.get("properties") is not None else key
for key, feature in enumerate(self._parent.data["features"])
Expand All @@ -1160,7 +1177,7 @@ def warn_for_geometry_collections(self) -> None:
UserWarning,
)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
figure = self.get_root()
if isinstance(self._parent, GeoJson):
Expand Down Expand Up @@ -1565,7 +1582,7 @@ def __init__(
color_range = color_brewer(fill_color, n=nb_bins)
self.color_scale = StepColormap(
color_range,
index=bin_edges,
index=list(bin_edges),
vmin=bins_min,
vmax=bins_max,
caption=legend_name,
Expand Down Expand Up @@ -1625,7 +1642,7 @@ def highlight_function(x):
return {"weight": line_weight + 2, "fillOpacity": fill_opacity + 0.2}

if topojson:
self.geojson = TopoJson(
self.geojson: Union[TopoJson, GeoJson] = TopoJson(
geo_data,
topojson,
style_function=style_function,
Expand Down Expand Up @@ -1657,7 +1674,7 @@ def _get_by_key(cls, obj: Union[dict, list], key: str) -> Union[float, str, None
else:
return value

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Render the GeoJson/TopoJson and color scale objects."""
if self.color_scale:
# ColorMap needs Map as its parent
Expand Down Expand Up @@ -1963,8 +1980,13 @@ def __init__(
vmin=min(colors),
vmax=max(colors),
).to_step(nb_steps)
else:
elif isinstance(colormap, StepColormap):
cm = colormap
else:
raise TypeError(
f"Unexpected type for argument `colormap`: {type(colormap)}"
)

out: Dict[str, List[List[List[float]]]] = {}
for (lat1, lng1), (lat2, lng2), color in zip(coords[:-1], coords[1:], colors):
out.setdefault(cm(color), []).append([[lat1, lng1], [lat2, lng2]])
Expand Down
2 changes: 1 addition & 1 deletion folium/folium.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _repr_png_(self) -> Optional[bytes]:
return None
return self._to_png()

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
figure = self.get_root()
assert isinstance(
Expand Down
13 changes: 7 additions & 6 deletions folium/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Optional, Sequence, Union, cast

from branca.element import Element, Figure, Html, MacroElement

Expand All @@ -14,6 +14,7 @@
from folium.utilities import (
JsCode,
TypeBounds,
TypeBoundsReturn,
TypeJsonValue,
escape_backticks,
parse_options,
Expand Down Expand Up @@ -221,7 +222,7 @@ def reset(self) -> None:
self.base_layers = OrderedDict()
self.overlays = OrderedDict()

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self.reset()
for item in self._parent._children.values():
Expand Down Expand Up @@ -396,15 +397,15 @@ def __init__(
tooltip if isinstance(tooltip, Tooltip) else Tooltip(str(tooltip))
)

def _get_self_bounds(self) -> List[List[float]]:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""Computes the bounds of the object itself.

Because a marker has only single coordinates, we repeat them.
"""
assert self.location is not None
return [self.location, self.location]
return cast(TypeBoundsReturn, [self.location, self.location])

def render(self) -> None:
def render(self):
if self.location is None:
raise ValueError(
f"{self._name} location must be assigned when added directly to map."
Expand Down Expand Up @@ -492,7 +493,7 @@ def __init__(
**kwargs,
)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
for name, child in self._children.items():
child.render(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion folium/plugins/overlapping_marker_spiderfier.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def add_to(
) -> Element:
self._parent = parent
self.markers = self._get_all_markers(parent)
super().add_to(parent, name=name, index=index)
return super().add_to(parent, name=name, index=index)

def _get_all_markers(self, element: Element) -> list:
markers = []
Expand Down
16 changes: 9 additions & 7 deletions folium/raster_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from folium.template import Template
from folium.utilities import (
TypeBounds,
TypeBoundsReturn,
TypeJsonValue,
image_to_url,
mercator_transform,
normalize_bounds_type,
parse_options,
remove_empty,
)
Expand Down Expand Up @@ -246,7 +248,7 @@ class ImageOverlay(Layer):
* If string, it will be written directly in the output file.
* If file, it's content will be converted as embedded in the output file.
* If array-like, it will be converted to PNG base64 string and embedded in the output.
bounds: list
bounds: list/tuple of list/tuple of float
Image bounds on the map in the form
[[lat_min, lon_min], [lat_max, lon_max]]
opacity: float, default Leaflet's default (1.0)
Expand Down Expand Up @@ -319,7 +321,7 @@ def __init__(

self.url = image_to_url(image, origin=origin, colormap=colormap)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
super().render()

figure = self.get_root()
Expand All @@ -344,13 +346,13 @@ def render(self, **kwargs) -> None:
Element(pixelated), name="leaflet-image-layer"
) # noqa

def _get_self_bounds(self) -> TypeBounds:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]].

"""
return self.bounds
return normalize_bounds_type(self.bounds)


class VideoOverlay(Layer):
Expand All @@ -361,7 +363,7 @@ class VideoOverlay(Layer):
----------
video_url: str
URL of the video
bounds: list
bounds: list/tuple of list/tuple of float
Video bounds on the map in the form
[[lat_min, lon_min], [lat_max, lon_max]]
autoplay: bool, default True
Expand Down Expand Up @@ -411,10 +413,10 @@ def __init__(
self.bounds = bounds
self.options = remove_empty(autoplay=autoplay, loop=loop, **kwargs)

def _get_self_bounds(self) -> TypeBounds:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]]

"""
return self.bounds
return normalize_bounds_type(self.bounds)
Loading
Loading