Skip to content

Commit d250952

Browse files
Fix streamlit-folium incompatibility (add layer to map with new class) (#1834)
* Add Layer to map with MacroElement class * run ruff * add ElementAddToElement only once * Fix marker cluster test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6040f42 commit d250952

File tree

3 files changed

+47
-38
lines changed

3 files changed

+47
-38
lines changed

folium/elements.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Tuple
22

3-
from branca.element import CssLink, Element, Figure, JavascriptLink
3+
from branca.element import CssLink, Element, Figure, JavascriptLink, MacroElement
4+
from jinja2 import Template
45

56

67
class JSCSSMixin(Element):
@@ -22,3 +23,20 @@ def render(self, **kwargs) -> None:
2223
figure.header.add_child(CssLink(url), name=name)
2324

2425
super().render(**kwargs)
26+
27+
28+
class ElementAddToElement(MacroElement):
29+
"""Abstract class to add an element to another element."""
30+
31+
_template = Template(
32+
"""
33+
{% macro script(this, kwargs) %}
34+
{{ this.element_name }}.addTo({{ this.element_parent_name }});
35+
{% endmacro %}
36+
"""
37+
)
38+
39+
def __init__(self, element_name: str, element_parent_name: str):
40+
super().__init__()
41+
self.element_name = element_name
42+
self.element_parent_name = element_parent_name

folium/map.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from branca.element import Element, Figure, Html, MacroElement
1010
from jinja2 import Template
1111

12+
from folium.elements import ElementAddToElement
1213
from folium.utilities import (
1314
TypeBounds,
1415
TypeJsonValue,
1516
camelize,
1617
escape_backticks,
17-
get_and_assert_figure_root,
1818
parse_options,
1919
validate_location,
2020
)
@@ -51,24 +51,15 @@ def __init__(
5151
self.show = show
5252

5353
def render(self, **kwargs):
54-
super().render(**kwargs)
5554
if self.show:
56-
self._add_layer_to_map()
57-
58-
def _add_layer_to_map(self, **kwargs):
59-
"""Show the layer on the map by adding it to its parent in JS."""
60-
template = Template(
61-
"""
62-
{%- macro script(this, kwargs) %}
63-
{{ this.get_name() }}.addTo({{ this._parent.get_name() }});
64-
{%- endmacro %}
65-
"""
66-
)
67-
script = template.module.__dict__["script"]
68-
figure = get_and_assert_figure_root(self)
69-
figure.script.add_child(
70-
Element(script(self, kwargs)), name=self.get_name() + "_add"
71-
)
55+
self.add_child(
56+
ElementAddToElement(
57+
element_name=self.get_name(),
58+
element_parent_name=self._parent.get_name(),
59+
),
60+
name=self.get_name() + "_add",
61+
)
62+
super().render(**kwargs)
7263

7364

7465
class FeatureGroup(Layer):

tests/plugins/test_marker_cluster.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,7 @@ def test_marker_cluster():
2323
m = folium.Map([45.0, 3.0], zoom_start=4)
2424
mc = plugins.MarkerCluster(data).add_to(m)
2525

26-
out = normalize(m._parent.render())
27-
28-
# We verify that imports
29-
assert (
30-
'<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/leaflet.markercluster.js"></script>' # noqa
31-
in out
32-
) # noqa
33-
assert (
34-
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.css"/>' # noqa
35-
in out
36-
) # noqa
37-
assert (
38-
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.Default.css"/>' # noqa
39-
in out
40-
) # noqa
41-
42-
# Verify the script part is okay.
43-
tmpl = Template(
26+
tmpl_for_expected = Template(
4427
"""
4528
var {{this.get_name()}} = L.markerClusterGroup(
4629
{{ this.options|tojson }}
@@ -60,7 +43,24 @@ def test_marker_cluster():
6043
{{ this.get_name() }}.addTo({{ this._parent.get_name() }});
6144
"""
6245
)
63-
expected = normalize(tmpl.render(this=mc))
46+
expected = normalize(tmpl_for_expected.render(this=mc))
47+
48+
out = normalize(m._parent.render())
49+
50+
# We verify that imports
51+
assert (
52+
'<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/leaflet.markercluster.js"></script>' # noqa
53+
in out
54+
) # noqa
55+
assert (
56+
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.css"/>' # noqa
57+
in out
58+
) # noqa
59+
assert (
60+
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.Default.css"/>' # noqa
61+
in out
62+
) # noqa
63+
6464
assert expected in out
6565

6666
bounds = m.get_bounds()

0 commit comments

Comments
 (0)