Skip to content
Go back

Segment Tree

Edit Page

https://leetcode.com/problems/range-sum-query-mutable/

The tree is stored in a 1-indexed array. The ith element’s left child is at 2*i and its right child is at 2*i+1. The parent of the node at i is at i // 2. The first value, tree[1], is the root and stores the aggregate for the range of the whole array.

Time complexity: O(n) to build; O(log n) for each update and each range query.

Space: O(n) — 4*n slots for the usual recursive approach, 2*n slots for the iterative one.

Recursive

class SegmentTreeRecursive:
    """Range-sum segment tree that is built, updated and queried recursively.

    Nodes are kept in a 1-indexed list: node ``x`` has its children at
    ``2 * x`` and ``2 * x + 1``, and ``tree[1]`` holds the sum of the whole
    array. Building is O(n); each update and each query is O(log n).
    """

    def __init__(self, nums: list[int]):
        """Build the tree over ``nums`` (an empty list is allowed)."""
        self.nums = nums
        self.n = len(nums)
        # Slot 0 is never used, so a perfect binary tree needs one slot more
        # than its node count. 4 * n is the usual safe bound for the
        # recursive layout and covers every node index build() can reach.
        self.tree = [0] * (4 * self.n)
        if self.n:
            self.build()

    def build(self, x=1, left=0, right=None):
        """Fill node ``x`` (covering ``nums[left..right]``) and its subtree.

        Callers normally pass no arguments; ``x``, ``left`` and ``right``
        carry the recursion state.
        """
        if right is None:
            right = self.n - 1
        if left == right:
            self.tree[x] = self.nums[left]
            return
        mid = (left + right) // 2
        self.build(x * 2, left, mid)
        self.build(x * 2 + 1, mid + 1, right)
        self.tree[x] = self.tree[x * 2] + self.tree[x * 2 + 1]

    def update(self, target, val, x=1, left=0, right=None):
        """Set the element at index ``target`` to ``val`` and refresh the sums.

        ``x``, ``left`` and ``right`` carry the recursion state and are
        normally left at their defaults.

        Raises:
            IndexError: if ``target`` lies outside the range covered by the
                current node (for a top-level call: outside ``[0, n - 1]``).
        """
        if right is None:
            right = self.n - 1
        # Without this check an out-of-range target would walk down to the
        # first or last leaf and silently overwrite it.
        if not left <= target <= right:
            raise IndexError("segment tree index out of range")
        if left == right:
            self.tree[x] = val  # leaf node
            return
        mid = (left + right) // 2
        if target <= mid:
            self.update(target, val, x * 2, left, mid)
        else:
            self.update(target, val, x * 2 + 1, mid + 1, right)
        # Recompute this node from its (now updated) children on the way up.
        self.tree[x] = self.tree[x * 2] + self.tree[x * 2 + 1]

    def sumRange(self, i, j, x=1, left=0, right=None):
        """Return the sum of the elements with indices in ``[i, j]``.

        Only the part of ``[i, j]`` that overlaps the array contributes, so
        an out-of-range or empty query yields 0 for the missing part.
        ``x``, ``left`` and ``right`` carry the recursion state.
        """
        if right is None:
            right = self.n - 1
        # left > right only happens for an empty tree; the other two tests
        # mean the node's range and the query do not overlap.
        if left > right or left > j or right < i:
            return 0
        if i <= left and j >= right:
            # Node range is fully inside the query: use the stored sum.
            return self.tree[x]
        mid = (left + right) // 2
        return self.sumRange(i, j, x * 2, left, mid) + self.sumRange(
            i, j, x * 2 + 1, mid + 1, right
        )

Taking an array from 0 to 12 as an example, the segment tree looks like this:

Recursive segment tree

Iterative

class SegmentTree:
    """Bottom-up (iterative) range-sum segment tree stored in ``2 * n`` slots.

    The leaves occupy ``tree[n:2 * n]`` in input order, and each internal
    node ``i`` (``1 <= i < n``) holds ``tree[2 * i] + tree[2 * i + 1]``.
    Slot 0 is unused. Building is O(n); update and query are O(log n).
    """

    def __init__(self, nums: list[int]):
        """Build the tree over ``nums`` (an empty list is allowed)."""
        self.n = len(nums)
        self.tree = [0] * (2 * self.n)
        self.nums = nums
        self.build()

    def build(self):
        """Copy ``nums`` into the leaves, then fill internal nodes bottom-up."""
        n = self.n
        # init leaf nodes
        for pos, value in zip(range(n, 2 * n), self.nums):
            self.tree[pos] = value

        # Walk downwards so both children are final before their parent.
        for i in range(n - 1, 0, -1):
            self.tree[i] = self.tree[i * 2] + self.tree[i * 2 + 1]

    def update(self, idx: int, val: int) -> None:
        """Set the element at index ``idx`` to ``val`` and refresh its ancestors.

        Raises:
            IndexError: if ``idx`` is not in ``[0, n - 1]``. A negative index
                would otherwise land on an internal node and corrupt the sums.
        """
        if not 0 <= idx < self.n:
            raise IndexError("segment tree index out of range")
        pos = idx + self.n
        self.tree[pos] = val  # leaf node
        while pos > 1:
            pos //= 2
            self.tree[pos] = self.tree[pos * 2] + self.tree[pos * 2 + 1]

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum of the elements with indices in ``[left, right]``.

        The query is clamped to ``[0, n - 1]``, so only the part that overlaps
        the array contributes; an empty or fully out-of-range query returns 0.
        """
        # Clamp first: an unclamped bound would address slots outside the
        # leaf level and either raise or add unrelated internal nodes.
        lo = max(left, 0) + self.n
        hi = min(right, self.n - 1) + self.n
        total = 0
        while lo <= hi:
            if lo % 2 == 1:  # is a right child: its parent covers too much
                total += self.tree[lo]
                lo += 1
            if hi % 2 == 0:  # is a left child: its parent covers too much
                total += self.tree[hi]
                hi -= 1
            lo //= 2
            hi //= 2
        return total

N = 13. Iterative segment tree:

Iterative segment tree

The graph might look counter-intuitive, but the segment tree works the same way.

Application

Aside from calculating range sums, the query function can also compute the minimum/maximum of a range.

https://leetcode.com/problems/count-number-of-teams/

For problem 1395, Count Number of Teams, we convert counting the number of smaller/greater numbers into a range sum by counting frequencies.


Edit Page
Share this post on:

Previous Post
Setting Up A New Mac
Next Post
Optimizing Hive Query