diff --git a/toolz/curried/__init__.py b/toolz/curried/__init__.py index 356eddbd..e8634b24 100644 --- a/toolz/curried/__init__.py +++ b/toolz/curried/__init__.py @@ -65,6 +65,7 @@ drop = toolz.curry(toolz.drop) excepts = toolz.curry(toolz.excepts) filter = toolz.curry(toolz.filter) +flatten = toolz.curry(toolz.flatten) get = toolz.curry(toolz.get) get_in = toolz.curry(toolz.get_in) groupby = toolz.curry(toolz.groupby) diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index 89931ae5..7bcec912 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -4,7 +4,7 @@ import operator from functools import partial from itertools import filterfalse, zip_longest -from collections.abc import Sequence +from collections.abc import Sequence, Mapping from toolz.utils import no_default @@ -13,7 +13,8 @@ 'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv', 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate', 'sliding_window', 'partition', 'partition_all', 'count', 'pluck', - 'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample') + 'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample', + 'flatten') def remove(predicate, seq): @@ -484,6 +485,7 @@ def concat(seqs): See also: itertools.chain.from_iterable equivalent + flatten """ return itertools.chain.from_iterable(seqs) @@ -1051,3 +1053,56 @@ def random_sample(prob, seq, random_state=None): random_state = Random(random_state) return filter(lambda _: random_state.random() < prob, seq) + + +def _default_descend(x): + return not isinstance(x, (str, bytes, Mapping)) + + +def flatten(level, seq, descend=_default_descend): + """ Flatten a possibly nested sequence + + Inspired by JavaScript's Array.flat(), this is a recursive, + depth-limited flattening generator. A level 0 flattening will + yield the input sequence unchanged. A level -1 flattening + will flatten all possible levels of nesting. 
+ + >>> list(flatten(0, [1, [2], [[3]]])) # flatten 0 levels + [1, [2], [[3]]] + >>> list(flatten(1, [1, [2], [[3]]])) # flatten 1 level + [1, 2, [3]] + >>> list(flatten(2, [1, [2], [[3]]])) + [1, 2, 3] + >>> list(flatten(-1, [1, [[[[2]]]]])) # flatten all levels + [1, 2] + + An optional ``descend`` function can be provided by the user + to determine which iterable objects to recurse into. This function + should return a boolean with True meaning it is permissible to descend + another level of recursion. The recursion limit of the Python interpreter + is the ultimate bounding factor on depth. By default, strings, bytes, + and mappings are exempted. + + >>> list(flatten(-1, ['abc', [{'a': 2}, [b'123']]])) + ['abc', {'a': 2}, b'123'] + + See also: + concat + """ + if level < -1: + raise ValueError("Level must be >= -1") + if not callable(descend): + raise ValueError("descend must be a callable boolean function") + + def flat(level, seq): + if level == 0: + yield from seq + return + + for item in seq: + if isiterable(item) and descend(item): + yield from flat(level - 1, item) + else: + yield item + + return flat(level, seq) diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 25e7d39a..7786a893 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -1,3 +1,4 @@ +import pytest import itertools from itertools import starmap from toolz.utils import raises @@ -13,7 +14,7 @@ reduceby, iterate, accumulate, sliding_window, count, partition, partition_all, take_nth, pluck, join, - diff, topk, peek, peekn, random_sample) + diff, topk, peek, peekn, random_sample, flatten) from operator import add, mul @@ -547,3 +548,37 @@ def test_random_sample(): assert mk_rsample(b"a") == mk_rsample(u"a") assert raises(TypeError, lambda: mk_rsample([])) + + +def test_flat(): + seq = [1, 2, 3, 4] + assert list(flatten(0, seq)) == seq + assert list(flatten(1, seq)) == seq + + seq = [1, [2, [3]]] + assert list(flatten(0, seq)) == seq + 
assert list(flatten(1, seq)) == [1, 2, [3]] + assert list(flatten(2, seq)) == [1, 2, 3] + + # Test mappings + seq = [{'a': 1}, [1, 2, 3]] + assert list(flatten(0, seq)) == seq + assert list(flatten(1, seq)) == [{'a': 1}, 1, 2, 3] + + # Test strings and bytes + seq = ["asgf", b"abcd"] + assert list(flatten(-1, seq)) == seq + + # Test custom descend function + def descend(x): + if isinstance(x, str): + return len(x) != 1 + return False + seq = ["asdf", [1, 2, 3]] + assert list(flatten(1, seq, descend=descend)) == ["a", "s", "d", "f", [1, 2, 3]] + + with pytest.raises(ValueError): + list(flatten(0, [1, 2], descend=True)) + + with pytest.raises(ValueError): + list(flatten(-2, [1, 2]))