Skip to content

enhanced segment tree implementation and more pythonic #1715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 2, 2020
237 changes: 237 additions & 0 deletions data_structures/binary_tree/segment_tree_other.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""
Segment_tree creates a segment tree with a given array and function,
allowing queries to be done later in log(N) time
function takes 2 values and returns a same type value
"""

from queue import Queue
from collections.abc import Sequence


class SegmentTreeNode(object):
def __init__(self, start, end, val, left=None, right=None):
self.start = start
self.end = end
self.val = val
self.mid = (start + end) // 2
self.left = left
self.right = right

def __str__(self):
return 'val: %s, start: %s, end: %s' % (self.val, self.start, self.end)


class SegmentTree(object):
"""
>>> import operator
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
>>> for node in num_arr.traverse():
... print(node)
...
val: 15, start: 0, end: 4
val: 8, start: 0, end: 2
val: 7, start: 3, end: 4
val: 3, start: 0, end: 1
val: 5, start: 2, end: 2
val: 3, start: 3, end: 3
val: 4, start: 4, end: 4
val: 2, start: 0, end: 0
val: 1, start: 1, end: 1
>>>
>>> num_arr.update(1, 5)
>>> for node in num_arr.traverse():
... print(node)
...
val: 19, start: 0, end: 4
val: 12, start: 0, end: 2
val: 7, start: 3, end: 4
val: 7, start: 0, end: 1
val: 5, start: 2, end: 2
val: 3, start: 3, end: 3
val: 4, start: 4, end: 4
val: 2, start: 0, end: 0
val: 5, start: 1, end: 1
>>>
>>> num_arr.query_range(3, 4)
7
>>> num_arr.query_range(2, 2)
5
>>> num_arr.query_range(1, 3)
13
>>>
>>> max_arr = SegmentTree([2, 1, 5, 3, 4], max)
>>> for node in max_arr.traverse():
... print(node)
...
val: 5, start: 0, end: 4
val: 5, start: 0, end: 2
val: 4, start: 3, end: 4
val: 2, start: 0, end: 1
val: 5, start: 2, end: 2
val: 3, start: 3, end: 3
val: 4, start: 4, end: 4
val: 2, start: 0, end: 0
val: 1, start: 1, end: 1
>>>
>>> max_arr.update(1, 5)
>>> for node in max_arr.traverse():
... print(node)
...
val: 5, start: 0, end: 4
val: 5, start: 0, end: 2
val: 4, start: 3, end: 4
val: 5, start: 0, end: 1
val: 5, start: 2, end: 2
val: 3, start: 3, end: 3
val: 4, start: 4, end: 4
val: 2, start: 0, end: 0
val: 5, start: 1, end: 1
>>>
>>> max_arr.query_range(3, 4)
4
>>> max_arr.query_range(2, 2)
5
>>> max_arr.query_range(1, 3)
5
>>>
>>> min_arr = SegmentTree([2, 1, 5, 3, 4], min)
>>> for node in min_arr.traverse():
... print(node)
...
val: 1, start: 0, end: 4
val: 1, start: 0, end: 2
val: 3, start: 3, end: 4
val: 1, start: 0, end: 1
val: 5, start: 2, end: 2
val: 3, start: 3, end: 3
val: 4, start: 4, end: 4
val: 2, start: 0, end: 0
val: 1, start: 1, end: 1
>>>
>>> min_arr.update(1, 5)
>>> for node in min_arr.traverse():
... print(node)
...
val: 2, start: 0, end: 4
val: 2, start: 0, end: 2
val: 3, start: 3, end: 4
val: 2, start: 0, end: 1
val: 5, start: 2, end: 2
val: 3, start: 3, end: 3
val: 4, start: 4, end: 4
val: 2, start: 0, end: 0
val: 5, start: 1, end: 1
>>>
>>> min_arr.query_range(3, 4)
3
>>> min_arr.query_range(2, 2)
5
>>> min_arr.query_range(1, 3)
3
>>>

"""
def __init__(self, collection: Sequence, function):
self.collection = collection
self.fn = function
if self.collection:
self.root = self._build_tree(0, len(collection) - 1)

def update(self, i, val):
"""
Update an element in log(N) time
:param i: position to be update
:param val: new value
>>> import operator
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
>>> num_arr.update(1, 5)
>>> num_arr.query_range(1, 3)
13
"""
self._update_tree(self.root, i, val)

def query_range(self, i, j):
"""
Get range query value in log(N) time
:param i: left element index
:param j: right element index
:return: element combined in the range [i, j]
>>> import operator
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
>>> num_arr.update(1, 5)
>>> num_arr.query_range(3, 4)
7
>>> num_arr.query_range(2, 2)
5
>>> num_arr.query_range(1, 3)
13
>>>
"""
return self._query_range(self.root, i, j)

def _build_tree(self, start, end):
if start == end:
return SegmentTreeNode(start, end, self.collection[start])
mid = (start + end) // 2
left = self._build_tree(start, mid)
right = self._build_tree(mid + 1, end)
return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right)

def _update_tree(self, node, i, val):
if node.start == i and node.end == i:
node.val = val
return
if i <= node.mid:
self._update_tree(node.left, i, val)
else:
self._update_tree(node.right, i, val)
node.val = self.fn(node.left.val, node.right.val)

def _query_range(self, node, i, j):
if node.start == i and node.end == j:
return node.val

if i <= node.mid:
if j <= node.mid:
# range in left child tree
return self._query_range(node.left, i, j)
else:
# range in left child tree and right child tree
return self.fn(self._query_range(node.left, i, node.mid), self._query_range(node.right, node.mid + 1, j))
else:
# range in right child tree
return self._query_range(node.right, i, j)

def traverse(self):
if self.root is not None:
queue = Queue()
queue.put(self.root)
while not queue.empty():
node = queue.get()
yield node

if node.left is not None:
queue.put(node.left)

if node.right is not None:
queue.put(node.right)


if __name__ == '__main__':
import operator
for fn in [operator.add, max, min]:
print('*' * 50)
arr = SegmentTree([2, 1, 5, 3, 4], fn)
for node in arr.traverse():
print(node)
print()

arr.update(1, 5)
for node in arr.traverse():
print(node)
print()

print(arr.query_range(3, 4)) # 7
print(arr.query_range(2, 2)) # 5
print(arr.query_range(1, 3)) # 13
print()