Skip to content

Commit cbdc1dc

Browse files
authored
Refactor template engine: Extract math & statistical functions into MathExtension (home-assistant#152266)
1 parent b203a83 commit cbdc1dc

File tree

6 files changed

+736
-876
lines changed

6 files changed

+736
-876
lines changed

homeassistant/helpers/template/__init__.py

Lines changed: 2 additions & 271 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import pathlib
1919
import random
2020
import re
21-
import statistics
2221
from struct import error as StructError, pack, unpack_from
2322
import sys
2423
from types import CodeType, TracebackType
@@ -37,7 +36,7 @@
3736

3837
from awesomeversion import AwesomeVersion
3938
import jinja2
40-
from jinja2 import pass_context, pass_environment, pass_eval_context
39+
from jinja2 import pass_context, pass_eval_context
4140
from jinja2.runtime import AsyncLoopContext, LoopContext
4241
from jinja2.sandbox import ImmutableSandboxedEnvironment
4342
from jinja2.utils import Namespace
@@ -2047,121 +2046,11 @@ def returns(value):
20472046
return wrapper
20482047

20492048

2050-
def logarithm(value, base=math.e, default=_SENTINEL):
2051-
"""Filter and function to get logarithm of the value with a specific base."""
2052-
try:
2053-
base_float = float(base)
2054-
except (ValueError, TypeError):
2055-
if default is _SENTINEL:
2056-
raise_no_default("log", base)
2057-
return default
2058-
try:
2059-
value_float = float(value)
2060-
return math.log(value_float, base_float)
2061-
except (ValueError, TypeError):
2062-
if default is _SENTINEL:
2063-
raise_no_default("log", value)
2064-
return default
2065-
2066-
2067-
def sine(value, default=_SENTINEL):
2068-
"""Filter and function to get sine of the value."""
2069-
try:
2070-
return math.sin(float(value))
2071-
except (ValueError, TypeError):
2072-
if default is _SENTINEL:
2073-
raise_no_default("sin", value)
2074-
return default
2075-
2076-
2077-
def cosine(value, default=_SENTINEL):
2078-
"""Filter and function to get cosine of the value."""
2079-
try:
2080-
return math.cos(float(value))
2081-
except (ValueError, TypeError):
2082-
if default is _SENTINEL:
2083-
raise_no_default("cos", value)
2084-
return default
2085-
2086-
2087-
def tangent(value, default=_SENTINEL):
2088-
"""Filter and function to get tangent of the value."""
2089-
try:
2090-
return math.tan(float(value))
2091-
except (ValueError, TypeError):
2092-
if default is _SENTINEL:
2093-
raise_no_default("tan", value)
2094-
return default
2095-
2096-
2097-
def arc_sine(value, default=_SENTINEL):
2098-
"""Filter and function to get arc sine of the value."""
2099-
try:
2100-
return math.asin(float(value))
2101-
except (ValueError, TypeError):
2102-
if default is _SENTINEL:
2103-
raise_no_default("asin", value)
2104-
return default
2105-
2106-
2107-
def arc_cosine(value, default=_SENTINEL):
2108-
"""Filter and function to get arc cosine of the value."""
2109-
try:
2110-
return math.acos(float(value))
2111-
except (ValueError, TypeError):
2112-
if default is _SENTINEL:
2113-
raise_no_default("acos", value)
2114-
return default
2115-
2116-
2117-
def arc_tangent(value, default=_SENTINEL):
2118-
"""Filter and function to get arc tangent of the value."""
2119-
try:
2120-
return math.atan(float(value))
2121-
except (ValueError, TypeError):
2122-
if default is _SENTINEL:
2123-
raise_no_default("atan", value)
2124-
return default
2125-
2126-
2127-
def arc_tangent2(*args, default=_SENTINEL):
2128-
"""Filter and function to calculate four quadrant arc tangent of y / x.
2129-
2130-
The parameters to atan2 may be passed either in an iterable or as separate arguments
2131-
The default value may be passed either as a positional or in a keyword argument
2132-
"""
2133-
try:
2134-
if 1 <= len(args) <= 2 and isinstance(args[0], (list, tuple)):
2135-
if len(args) == 2 and default is _SENTINEL:
2136-
# Default value passed as a positional argument
2137-
default = args[1]
2138-
args = args[0]
2139-
elif len(args) == 3 and default is _SENTINEL:
2140-
# Default value passed as a positional argument
2141-
default = args[2]
2142-
2143-
return math.atan2(float(args[0]), float(args[1]))
2144-
except (ValueError, TypeError):
2145-
if default is _SENTINEL:
2146-
raise_no_default("atan2", args)
2147-
return default
2148-
2149-
21502049
def version(value):
21512050
"""Filter and function to get version object of the value."""
21522051
return AwesomeVersion(value)
21532052

21542053

2155-
def square_root(value, default=_SENTINEL):
2156-
"""Filter and function to get square root of the value."""
2157-
try:
2158-
return math.sqrt(float(value))
2159-
except (ValueError, TypeError):
2160-
if default is _SENTINEL:
2161-
raise_no_default("sqrt", value)
2162-
return default
2163-
2164-
21652054
def timestamp_custom(value, date_format=DATE_STR_FORMAT, local=True, default=_SENTINEL):
21662055
"""Filter to convert given timestamp to format."""
21672056
try:
@@ -2315,118 +2204,6 @@ def fail_when_undefined(value):
23152204
return value
23162205

23172206

2318-
def min_max_from_filter(builtin_filter: Any, name: str) -> Any:
2319-
"""Convert a built-in min/max Jinja filter to a global function.
2320-
2321-
The parameters may be passed as an iterable or as separate arguments.
2322-
"""
2323-
2324-
@pass_environment
2325-
@wraps(builtin_filter)
2326-
def wrapper(environment: jinja2.Environment, *args: Any, **kwargs: Any) -> Any:
2327-
if len(args) == 0:
2328-
raise TypeError(f"{name} expected at least 1 argument, got 0")
2329-
2330-
if len(args) == 1:
2331-
if isinstance(args[0], Iterable):
2332-
return builtin_filter(environment, args[0], **kwargs)
2333-
2334-
raise TypeError(f"'{type(args[0]).__name__}' object is not iterable")
2335-
2336-
return builtin_filter(environment, args, **kwargs)
2337-
2338-
return pass_environment(wrapper)
2339-
2340-
2341-
def average(*args: Any, default: Any = _SENTINEL) -> Any:
2342-
"""Filter and function to calculate the arithmetic mean.
2343-
2344-
Calculates of an iterable or of two or more arguments.
2345-
2346-
The parameters may be passed as an iterable or as separate arguments.
2347-
"""
2348-
if len(args) == 0:
2349-
raise TypeError("average expected at least 1 argument, got 0")
2350-
2351-
# If first argument is iterable and more than 1 argument provided but not a named
2352-
# default, then use 2nd argument as default.
2353-
if isinstance(args[0], Iterable):
2354-
average_list = args[0]
2355-
if len(args) > 1 and default is _SENTINEL:
2356-
default = args[1]
2357-
elif len(args) == 1:
2358-
raise TypeError(f"'{type(args[0]).__name__}' object is not iterable")
2359-
else:
2360-
average_list = args
2361-
2362-
try:
2363-
return statistics.fmean(average_list)
2364-
except (TypeError, statistics.StatisticsError):
2365-
if default is _SENTINEL:
2366-
raise_no_default("average", args)
2367-
return default
2368-
2369-
2370-
def median(*args: Any, default: Any = _SENTINEL) -> Any:
2371-
"""Filter and function to calculate the median.
2372-
2373-
Calculates median of an iterable of two or more arguments.
2374-
2375-
The parameters may be passed as an iterable or as separate arguments.
2376-
"""
2377-
if len(args) == 0:
2378-
raise TypeError("median expected at least 1 argument, got 0")
2379-
2380-
# If first argument is a list or tuple and more than 1 argument provided but not a named
2381-
# default, then use 2nd argument as default.
2382-
if isinstance(args[0], Iterable):
2383-
median_list = args[0]
2384-
if len(args) > 1 and default is _SENTINEL:
2385-
default = args[1]
2386-
elif len(args) == 1:
2387-
raise TypeError(f"'{type(args[0]).__name__}' object is not iterable")
2388-
else:
2389-
median_list = args
2390-
2391-
try:
2392-
return statistics.median(median_list)
2393-
except (TypeError, statistics.StatisticsError):
2394-
if default is _SENTINEL:
2395-
raise_no_default("median", args)
2396-
return default
2397-
2398-
2399-
def statistical_mode(*args: Any, default: Any = _SENTINEL) -> Any:
2400-
"""Filter and function to calculate the statistical mode.
2401-
2402-
Calculates mode of an iterable of two or more arguments.
2403-
2404-
The parameters may be passed as an iterable or as separate arguments.
2405-
"""
2406-
if not args:
2407-
raise TypeError("statistical_mode expected at least 1 argument, got 0")
2408-
2409-
# If first argument is a list or tuple and more than 1 argument provided but not a named
2410-
# default, then use 2nd argument as default.
2411-
if len(args) == 1 and isinstance(args[0], Iterable):
2412-
mode_list = args[0]
2413-
elif isinstance(args[0], list | tuple):
2414-
mode_list = args[0]
2415-
if len(args) > 1 and default is _SENTINEL:
2416-
default = args[1]
2417-
elif len(args) == 1:
2418-
raise TypeError(f"'{type(args[0]).__name__}' object is not iterable")
2419-
else:
2420-
mode_list = args
2421-
2422-
try:
2423-
return statistics.mode(mode_list)
2424-
except (TypeError, statistics.StatisticsError):
2425-
if default is _SENTINEL:
2426-
raise_no_default("statistical_mode", args)
2427-
return default
2428-
2429-
24302207
def forgiving_float(value, default=_SENTINEL):
24312208
"""Try to convert value to a float."""
24322209
try:
@@ -2549,21 +2326,6 @@ def regex_findall(value, find="", ignorecase=False):
25492326
return _regex_cache(find, flags).findall(value)
25502327

25512328

2552-
def bitwise_and(first_value, second_value):
2553-
"""Perform a bitwise and operation."""
2554-
return first_value & second_value
2555-
2556-
2557-
def bitwise_or(first_value, second_value):
2558-
"""Perform a bitwise or operation."""
2559-
return first_value | second_value
2560-
2561-
2562-
def bitwise_xor(first_value, second_value):
2563-
"""Perform a bitwise xor operation."""
2564-
return first_value ^ second_value
2565-
2566-
25672329
def struct_pack(value: Any | None, format_string: str) -> bytes | None:
25682330
"""Pack an object into a bytes object."""
25692331
try:
@@ -3065,45 +2827,29 @@ def __init__(
30652827
self.add_extension("jinja2.ext.do")
30662828
self.add_extension("homeassistant.helpers.template.extensions.Base64Extension")
30672829
self.add_extension("homeassistant.helpers.template.extensions.CryptoExtension")
2830+
self.add_extension("homeassistant.helpers.template.extensions.MathExtension")
30682831

3069-
self.globals["acos"] = arc_cosine
30702832
self.globals["as_datetime"] = as_datetime
30712833
self.globals["as_function"] = as_function
30722834
self.globals["as_local"] = dt_util.as_local
30732835
self.globals["as_timedelta"] = as_timedelta
30742836
self.globals["as_timestamp"] = forgiving_as_timestamp
3075-
self.globals["asin"] = arc_sine
3076-
self.globals["atan"] = arc_tangent
3077-
self.globals["atan2"] = arc_tangent2
3078-
self.globals["average"] = average
30792837
self.globals["bool"] = forgiving_boolean
30802838
self.globals["combine"] = combine
3081-
self.globals["cos"] = cosine
30822839
self.globals["difference"] = difference
3083-
self.globals["e"] = math.e
30842840
self.globals["flatten"] = flatten
30852841
self.globals["float"] = forgiving_float
30862842
self.globals["iif"] = iif
30872843
self.globals["int"] = forgiving_int
30882844
self.globals["intersect"] = intersect
30892845
self.globals["is_number"] = is_number
3090-
self.globals["log"] = logarithm
3091-
self.globals["max"] = min_max_from_filter(self.filters["max"], "max")
3092-
self.globals["median"] = median
30932846
self.globals["merge_response"] = merge_response
3094-
self.globals["min"] = min_max_from_filter(self.filters["min"], "min")
30952847
self.globals["pack"] = struct_pack
3096-
self.globals["pi"] = math.pi
30972848
self.globals["set"] = _to_set
30982849
self.globals["shuffle"] = shuffle
3099-
self.globals["sin"] = sine
31002850
self.globals["slugify"] = slugify
3101-
self.globals["sqrt"] = square_root
3102-
self.globals["statistical_mode"] = statistical_mode
31032851
self.globals["strptime"] = strptime
31042852
self.globals["symmetric_difference"] = symmetric_difference
3105-
self.globals["tan"] = tangent
3106-
self.globals["tau"] = math.pi * 2
31072853
self.globals["timedelta"] = timedelta
31082854
self.globals["tuple"] = _to_tuple
31092855
self.globals["typeof"] = typeof
@@ -3113,25 +2859,16 @@ def __init__(
31132859
self.globals["version"] = version
31142860
self.globals["zip"] = zip
31152861

3116-
self.filters["acos"] = arc_cosine
31172862
self.filters["add"] = add
31182863
self.filters["apply"] = apply
31192864
self.filters["as_datetime"] = as_datetime
31202865
self.filters["as_function"] = as_function
31212866
self.filters["as_local"] = dt_util.as_local
31222867
self.filters["as_timedelta"] = as_timedelta
31232868
self.filters["as_timestamp"] = forgiving_as_timestamp
3124-
self.filters["asin"] = arc_sine
3125-
self.filters["atan"] = arc_tangent
3126-
self.filters["atan2"] = arc_tangent2
3127-
self.filters["average"] = average
3128-
self.filters["bitwise_and"] = bitwise_and
3129-
self.filters["bitwise_or"] = bitwise_or
3130-
self.filters["bitwise_xor"] = bitwise_xor
31312869
self.filters["bool"] = forgiving_boolean
31322870
self.filters["combine"] = combine
31332871
self.filters["contains"] = contains
3134-
self.filters["cos"] = cosine
31352872
self.filters["difference"] = difference
31362873
self.filters["flatten"] = flatten
31372874
self.filters["float"] = forgiving_float_filter
@@ -3142,8 +2879,6 @@ def __init__(
31422879
self.filters["intersect"] = intersect
31432880
self.filters["is_defined"] = fail_when_undefined
31442881
self.filters["is_number"] = is_number
3145-
self.filters["log"] = logarithm
3146-
self.filters["median"] = median
31472882
self.filters["multiply"] = multiply
31482883
self.filters["ord"] = ord
31492884
self.filters["ordinal"] = ordinal
@@ -3156,12 +2891,8 @@ def __init__(
31562891
self.filters["regex_search"] = regex_search
31572892
self.filters["round"] = forgiving_round
31582893
self.filters["shuffle"] = shuffle
3159-
self.filters["sin"] = sine
31602894
self.filters["slugify"] = slugify
3161-
self.filters["sqrt"] = square_root
3162-
self.filters["statistical_mode"] = statistical_mode
31632895
self.filters["symmetric_difference"] = symmetric_difference
3164-
self.filters["tan"] = tangent
31652896
self.filters["timestamp_custom"] = timestamp_custom
31662897
self.filters["timestamp_local"] = timestamp_local
31672898
self.filters["timestamp_utc"] = timestamp_utc

homeassistant/helpers/template/extensions/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22

33
from .base64 import Base64Extension
44
from .crypto import CryptoExtension
5+
from .math import MathExtension
56

6-
__all__ = ["Base64Extension", "CryptoExtension"]
7+
__all__ = ["Base64Extension", "CryptoExtension", "MathExtension"]

homeassistant/helpers/template/extensions/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TemplateFunction:
1919
"""Definition for a template function, filter, or test."""
2020

2121
name: str
22-
func: Callable[..., Any]
22+
func: Callable[..., Any] | Any
2323
as_global: bool = False
2424
as_filter: bool = False
2525
as_test: bool = False

0 commit comments

Comments
 (0)