线段树

前言

记得大一的时候耶学习过线段树, 当时这是我唯一认真学习过的数据结构, 至于什么 AVL、红黑树. 当时基本上要么是没看, 要么就是知难而退了, 唯独线段树自己认真啃完了大牛的博客, 也自己用 C++实现过了一遍.

但是应该怎么也没料到, 自己会在大三的一次学校组织的 ACM 竞赛中被线段树绊住了脚. 这次就来认真学习一遍线段树, 争取下次需要用到的时候能够胸有成竹.

题目

查询数组某个区间的最大值、并在每次查询后更新单个值.

线段树详解

  1. 什么是线段树 线段树所记录的是数组(一般是数组)区间内的信息,比如区间内的元素求和、连乘之类的。比如树中的根节点就记录了数组 A 中从索引 0 到索引 9 的元素信息,比如求和、求区间最大值。线段树的叶子节点只有一个元素。但不是完全二叉树,是一棵平衡二叉树(树的最大深度至多比最小深度多 1)。但是一般的,我们认为线段树是一棵满二叉树(为了能像堆那样用索引来表示左、右孩子,只不过堆不一定是满二叉树),只需在不存在元素的地方用 None 填充就好了。线段树一般是以中点索引来分割左、右子树的。
  2. 为什么要使用线段树 对于需要频繁查询数组区间的区间值、并且需要频繁修改数组的值的情况下. 就需要使用到线段树, 因为线段树采用了空间换取时间的思想, 每次查询和每次修改后维护的时间复杂度都只是 log(n) 级别, 更为高效.
  3. 使用线段树的限制 线段树无法处理数组元素的添加或者删除操作, 只支持空间内的元素值的修改.

线段树的思想

线段树的基础思想就是通过额外的多个空间来记录数组的小段区间值, 并在数组修改后能够自底向上进行修改所有需要修改的区间、而不需要修改的区间就不会被访问, 因此修改的复杂度减小了许多.

Python 实现

from typing import List


class SegmentTree:
    def __init__(self, alist: List[int], merger_):
        """初始化线段树
        Args:
            alist (List[int]): 初始化时传入的一个list,为需要进行区间索引的数组
            merger_ (function): merge函数,用于对实现两个数合成一个数对功能
        """
        self.data = alist
        self.size = len(self.data)
        self.tree = [-float('inf')] * 4 * len(self.data)
        self.merger = merger_
        self._buildSegmentTree(0, 0, len(self.data) - 1)

    def _buildSegmentTree(self, treeIndex: int, left: int, right: int):
        """构建线段树,通过递归构建线段树
        Args:
            treeIndex (int): 当前根节点的索引
            left (int): 当前根节点的左边界
            right (int): 当前根节点的右边界
        """
        if left == right:
            self.tree[treeIndex] = self.data[left]
            return

        leftChild_index = 2 * treeIndex + 1
        rightChild_index = 2 * treeIndex + 2

        mid = (left + right) // 2
        self._buildSegmentTree(leftChild_index, left, mid)
        self._buildSegmentTree(rightChild_index, mid + 1, right)
        self.tree[treeIndex] = self.merger(
            self.tree[leftChild_index], self.tree[rightChild_index])

    def quary(self, left: int, right: int) -> int:
        """接口函数,用于查询区间值
        Args:
            left (int): 区间左边界
            right (int): 区间右边界
        Raises:
            Exception: 数组索引越界
        Returns:
            int: 返回区间值
        """
        if 0 <= left < self.size and 0 <= right < self.size:
            return self._find(0, 0, self.size, left, right)
        else:
            raise Exception('The indexes is illegal!')

    def _find(self, treeIndex: int, left: int, right: int,
              quaryL: int, quaryR: int):
        """非接口函数,递归查询区间值
        Args:
            treeIndex (int): 当前根节点索引
            left (int): 当前根节点左边界
            right (int): 当前根节点右边界
            quaryL (int): 查询区间左边界
            quaryR (int): 查询区间右边界
        Returns:
            [type]: 返回区间值
        """
        if left == quaryL and right == quaryR:
            return self.tree[treeIndex]
        mid = (left + right) // 2
        leftChild_index = 2 * treeIndex + 1
        rightChild_index = 2 * treeIndex + 2

        if quaryL > mid:
            return self._find(rightChild_index, mid + 1, right, quaryL, quaryR)
        elif quaryR <= mid:
            return self._find(leftChild_index, left, mid, quaryL, quaryR)

        leftResult = self._find(leftChild_index, left, mid, quaryL, mid)
        rightResult = self._find(
            rightChild_index, mid + 1, right, mid + 1, quaryR)

        return self.merger(leftResult, rightResult)

    def update(self, index: int, value):
        """接口函数,单点更新
        Args:
            index (int): 更新点的位置
            value (any): 更新的值
        Raises:
            Exception: 数组索引越界
        """
        if 0 <= index < self.size:
            self.data[index] = value
            self._maintain(0, 0, self.size - 1, index, value)
        else:
            raise Exception('The index is illegal!')

    def _maintain(self, treeIndex: int, left: int, right: int,
                  index: int, value):
        """非接口函数,递归维护所有包含区间的值
        Args:
            treeIndex (int): 当前根节点的索引
            left (int): 当前根节点左边界
            right (int): 当前根节点右边界
            index (int): 更新点的位置
            value (any): 更新的值
        """
        if left == right:
            self.tree[treeIndex] = value
            return

        mid = (left + right) // 2
        leftChild_index = 2 * treeIndex + 1
        rightChild_index = 2 * treeIndex + 2

        if index <= mid:
            self._maintain(leftChild_index, left, mid, index, value)
        else:
            self._maintain(rightChild_index, mid + 1, right, index, value)
        self.tree[treeIndex] = self.merger(
            self.tree[leftChild_index], self.tree[rightChild_index])

正是你花费在玫瑰上的时间才使得你的玫瑰花珍贵无比