Skip to content

Commit ffb7ca4

Browse files
authored
Add apmath algorithms and list expression kind (#112)
1 parent 0d9b88b commit ffb7ca4

15 files changed

+550
-63
lines changed

functional_algorithms/algorithms.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
import functools
16+
import warnings
1617
from . import floating_point_algorithms as fpa
1718

1819

@@ -45,6 +46,16 @@ def foo(ctx, ...):
4546
# domain=None)` implements dispatch to real_foo and
4647
# complex_foo based on the arguments domain.
4748
assert 0 # unreachable
49+
50+
Warning: definition registry is global. When reusing definition
51+
class in other modules, make sure to use module specific registry
52+
via::
53+
54+
import functional_algorithms as fa
55+
56+
class definition(fa.algorithms.definition):
57+
_registry = {}
58+
4859
"""
4960

5061
# dict(<domain>=<dict of <native function name>:<definition for domain>>)
@@ -82,8 +93,13 @@ def wrapper(ctx, *args, **kwargs):
8293
if result is NotImplemented:
8394
raise NotImplementedError(f"{self.native_func_name} not implemented for {self.domain} domain: {func.__name__}")
8495

96+
if isinstance(result, list):
97+
result = ctx.list(result)
98+
8599
return result
86100

101+
if self.native_func_name in self.registry:
102+
warnings.warn(f"{self.native_func_name} wrapper is overwritten by {func.__name__} wrapper")
87103
self.registry[self.native_func_name] = wrapper
88104

89105
return wrapper

functional_algorithms/apmath.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def renormalize(ctx, seq, functional=False, fast=False, size=None, dtype=None):
187187
e_lst = vecsum(ctx, seq, fast=fast)
188188
# VecSumErrBranch:
189189
if functional:
190-
zero = ctx.constant(0)
190+
zero = ctx.constant(0, seq[0])
191191

192192
f_lst = []
193193
eps_i = e_lst[0]
@@ -241,20 +241,21 @@ def nztopk(ctx, seq, k):
241241
elif len(seq) == 1:
242242
return seq
243243
elif len(seq) == 2:
244-
flag = seq[0] == ctx.constant(0)
244+
flag = seq[0] == ctx.constant(0, seq[0])
245245
if k == 1:
246246
return [ctx.select(flag, seq[1], seq[0])]
247247
return [ctx.select(flag, seq[1], seq[0]), ctx.select(flag, seq[0], seq[1])]
248248

249-
result = []
250-
zero = ctx.constant(0)
251-
one = ctx.constant(1)
249+
zero = ctx.constant(0, seq[0])
250+
izero = ctx.constant(0)
251+
ione = ctx.constant(1, izero)
252252

253253
isnzero = [a != zero for a in seq]
254-
nzcount = [zero] # ideally, nzcount ought to be a integer sequence
254+
nzcount = [izero]
255255
for b in isnzero[:-1]:
256-
nzcount.append(nzcount[-1] + ctx.select(b, one, zero))
256+
nzcount.append(nzcount[-1] + ctx.select(b, ione, izero))
257257

258+
result = []
258259
for i in range(min(k, len(seq))):
259260
lst = [ctx.select(ctx.logical_and(isnzero[j], nzcount[j] == i), seq[j], zero) for j in range(i, len(seq))]
260261
result.append(sum(lst[:-1], lst[-1]))
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from . import algorithms as faa
2+
from . import apmath
3+
4+
5+
class definition(faa.definition):
6+
_registry = {}
7+
8+
9+
@definition("square", domain="real")
10+
def real_square(ctx, x: list[float, ...], functional: bool = True, size: int = None):
11+
"""Square on real input: x * x"""
12+
return apmath.square(ctx, x, functional=functional, size=2)
13+
14+
15+
@definition("square")
16+
def square(ctx, z: list[float | complex, ...]):
17+
"""Square on floating-point expansion"""
18+
assert 0

functional_algorithms/context.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import typing
55
import warnings
66
from collections import defaultdict
7-
from .utils import UNSPECIFIED, boolean_types, float_types, complex_types
8-
from .expr import Expr, make_constant, make_symbol, make_apply, known_expression_kinds
7+
from .utils import UNSPECIFIED, boolean_types, float_types, complex_types, integer_types
8+
from .expr import Expr, make_constant, make_symbol, make_apply, known_expression_kinds, make_list, make_item, make_len
99
from .typesystem import Type
1010

1111

@@ -194,11 +194,14 @@ def trace(self, func, *args):
194194
if ":" in a:
195195
name, annot = a.split(":", 1)
196196
name = name.strip()
197+
# TODO: parse annot as it may contain typing alias
197198
param = param.replace(name=name if name else param.name, annotation=annot.strip())
198199
else:
199200
param = param.replace(name=a.strip())
200201
elif isinstance(a, type):
201202
param = param.replace(annotation=a.__name__)
203+
elif isinstance(a, types.GenericAlias):
204+
param = param.replace(annotation=a)
202205
else:
203206
raise NotImplementedError((a, type(a)))
204207
if param.annotation is inspect.Parameter.empty:
@@ -207,12 +210,28 @@ def trace(self, func, *args):
207210
typ = param.annotation
208211
if isinstance(typ, types.UnionType):
209212
typ = typing.get_args(typ)[0]
210-
assert isinstance(typ, (type, str)), typ
213+
assert isinstance(typ, (type, str, types.GenericAlias)), (type(typ), typ)
211214
if default_typ is UNSPECIFIED:
212215
default_typ = typ
213-
a = self.symbol(param.name, typ).reference(ref_name=param.name)
216+
if isinstance(typ, types.GenericAlias):
217+
if typ.__name__ == "list":
218+
a = self.list(
219+
[
220+
self.symbol(f"{param.name}_{k}_", t).reference(
221+
ref_name=f"{param.name}_{k}_",
222+
force=False, # reference to item will be defined only when it is used
223+
)
224+
for k, t in enumerate(typ.__args__)
225+
]
226+
)
227+
else:
228+
raise TypeError(f"annotation type must be type of a scalar or list, got {typ}")
229+
else:
230+
a = self.symbol(param.name, typ)
231+
a = a.reference(ref_name=param.name)
214232
new_args.append(a)
215233
args = tuple(new_args)
234+
216235
name = self.symbol(func.__name__).reference(ref_name=func.__name__)
217236
return make_apply(self, name, args, func(self, *args))
218237

@@ -233,6 +252,7 @@ def symbol(self, name, typ=UNSPECIFIED):
233252
if typ is UNSPECIFIED:
234253
like = self.default_like
235254
if like is not None:
255+
assert like.kind == "symbol", like.kind
236256
typ = like.operands[1]
237257
else:
238258
typ = "float"
@@ -243,15 +263,31 @@ def constant(self, value, like_expr=UNSPECIFIED):
243263
if isinstance(value, boolean_types):
244264
like_expr = self.symbol("_boolean_value", "boolean")
245265
elif self._default_constant_type is not None:
266+
# Warning: when specified and default_constant_type is
267+
# float, integer values will be interpreted as floats
246268
like_expr = self.symbol("_value", self._default_constant_type)
269+
elif isinstance(value, integer_types):
270+
like_expr = self.symbol("_integer_value", type(value))
247271
elif isinstance(value, float_types):
248272
like_expr = self.symbol("_float_value", type(value))
249273
elif isinstance(value, complex_types):
250274
like_expr = self.symbol("_complex_value", type(value))
251275
else:
252276
like_expr = self.default_like
277+
elif isinstance(like_expr, (str, type, Type)):
278+
typ = Type.fromobject(self, like_expr)
279+
like_expr = self.symbol(f"_{typ.kind}_value", typ)
253280
return make_constant(self, value, like_expr)
254281

282+
def list(self, items):
283+
return make_list(self, items)
284+
285+
def item(self, container, index):
286+
return make_item(self, container, index)
287+
288+
def len(self, container):
289+
return make_len(self, container)
290+
255291
def call(self, func, args):
256292
"""Apply callable to arguments and return its result.
257293

0 commit comments

Comments
 (0)