线段树

学习数据结构与算法
2021-05-17 14:29 · 阅读时长7分钟
小课

线段树是一种树形结构,主要用于维护区间信息,可以在O(logN)的时间复杂度内实现单点修改、区间修改、区间查询、区间求和等操作。

看一个简单的例子,给定一个数组,要求数组的区间和,比如数组[1, 3, -1, 2, 5],求数组[1,3]区间内元素之和,一般来说可以通过下面两种方式来实现。

  1. 直接遍历数组区间,然后将元素相加,例如,sum[1,3]=3+(-1)+2=4,时间复杂度是O(N)。
  2. 先求出数组的前缀和,然后再通过前缀和求区间和,例如,prefix=[1, 4, 3, 5, 10],sum[1,3]=prefix[3]-prefix[0]=4,求区间和的这个操作时间复杂度是O(1)。

目前来看第二种方式更高效一点,但是如果数组是可变的呢?对于第一种方式,修改数组元素只需要O(1)的时间复杂度,而对于第二种方式,修改完数组元素后需要更新前缀和数组,这个更新的操作时间复杂度是O(N)。而线段树就是能够让修改元素和区间求和这个两个操作的时间复杂度都在O(logN)级别,这样就弥补了上述两种办法的缺点,下面看看线段树的逻辑结构。

线段树

下面看看线段树的实现,这里我们使用数组来保存线段树节点。

1/**
2 * 线段树,使用数组存储
3 */
4class SegmentTree {
5
6    /**
7     * 原数组
8     */
9    private int[] arr;
10    /**
11     * 保存线段树的数组
12     */
13    private int[] tree;
14
15    /**
16     * 构建线段树
17     *
18     * @param arr 原数组
19     */
20    public void build(int[] arr) {
21        this.arr = arr;
22        /**
23         * 通过数组的长度计算线段树层级高度
24         */
25        int level = (int) Math.ceil(Math.log(arr.length) / Math.log(2)) + 1;
26        int size = (int) Math.pow(2, level);
27        this.tree = new int[size];
28        build(0, 0, arr.length - 1);
29    }
30
31    /**
32     * 构建线段树
33     *
34     * @param node  线段树的节点
35     * @param start 当前节点区间的开始
36     * @param end   当前节点区间的结束
37     */
38    private void build(int node, int start, int end) {
39        /**
40         * 如果到了叶子节点,也就是当前节点区间中只有一个元素
41         */
42        if (start == end) {
43            tree[node] = arr[start];
44        } else {
45            /**
46             * 如果当前节点还有子节点,先计算左右子节点,然后将左右子节点的值相加赋值给当前节点
47             */
48            int mid = (start + end) / 2;
49            int left = 2 * node + 1;
50            int right = 2 * node + 2;
51            build(left, start, mid);
52            build(right, mid + 1, end);
53            tree[node] = tree[left] + tree[right];
54        }
55    }
56
57    /**
58     * 更新元素
59     *
60     * @param index 要更新的元素下标
61     * @param val   更新后的值
62     */
63    public void update(int index, int val) {
64        arr[index] = val;
65        update(0, index, 0, arr.length - 1);
66    }
67
68    /**
69     * 更新元素、线段树
70     *
71     * @param node  线段树节点
72     * @param index 要更新的元素下标
73     * @param start 当前节点区间的开始
74     * @param end   当前节点区间的结束
75     */
76    private void update(int node, int index, int start, int end) {
77        /**
78         * 如果当前节点区间只有一个元素,直接更新
79         */
80        if (start == end) {
81            tree[node] = arr[index];
82        } else {
83            /**
84             * 如果当前节点还有子节点,先更新子节点,然后再更新当前节点
85             */
86            int mid = (start + end) / 2;
87            int left = 2 * node + 1;
88            int right = 2 * node + 2;
89            if (start <= index && index <= mid) {
90                update(left, index, start, mid);
91            } else {
92                update(right, index, mid + 1, end);
93            }
94            tree[node] = tree[left] + tree[right];
95        }
96    }
97
98    /**
99     * 区间查询
100     *
101     * @param L 区间开始位置
102     * @param R 区间结束位置
103     * @return 区间元素之和
104     */
105    public int query(int L, int R) {
106        return query(0, 0, arr.length - 1, L, R);
107    }
108
109    private int query(int node, int start, int end, int L, int R) {
110        /**
111         * 如果当前节点区间和要查询的区间没有交集
112         */
113        if (R < start || L > end) {
114            return 0;
115        }
116        /**
117         * 如果查询的区间完全包含当前节点区间
118         */
119        if (L <= start && end <= R) {
120            return tree[node];
121        }
122        /**
123         * 如果当前节点区间和要查询的区间有交集,但是并不是完全包含
124         */
125        int mid = (start + end) / 2;
126        int left = 2 * node + 1;
127        int right = 2 * node + 2;
128        int sumLeft = query(left, start, mid, L, R);
129        int sumRight = query(right, mid + 1, end, L, R);
130        return sumLeft + sumRight;
131    }
132}
133
134public class Main {
135
136    public static void main(String[] args) {
137        int[] arr = {1, 3, -1, 2, 5};
138        SegmentTree tree = new SegmentTree();
139        tree.build(arr);
140        System.out.println(tree.query(0, 3));
141        tree.update(2, 1);
142        System.out.println(tree.query(0, 3));
143    }
144}
注意: 这个Java运行环境不支持自定义包名,并且public class name必须是Main

上面的实现是用来计算区间求和的代码,稍微修改一下还可以实现更多功能,比如区间求最值。

线段树Segment tree