线段树是一种树形结构,主要用于维护区间信息,可以在O(logN)的时间复杂度内实现单点修改、区间修改、区间查询、区间求和等操作。
看一个简单的例子,给定一个数组,要求数组的区间和,比如数组[1, 3, -1, 2, 5],求数组[1,3]区间内元素之和,一般来说可以通过下面两种方式来实现。
目前来看第二种方式更高效一点,但是如果数组是可变的呢?对于第一种方式,修改数组元素只需要O(1)的时间复杂度,而对于第二种方式,修改完数组元素后需要更新前缀和数组,这个更新的操作时间复杂度是O(N)。而线段树就是能够让修改元素和区间求和这个两个操作的时间复杂度都在O(logN)级别,这样就弥补了上述两种办法的缺点,下面看看线段树的逻辑结构。
下面看看线段树的实现,这里我们使用数组来保存线段树节点。
1/**
2 * 线段树,使用数组存储
3 */
4class SegmentTree {
5
6 /**
7 * 原数组
8 */
9 private int[] arr;
10 /**
11 * 保存线段树的数组
12 */
13 private int[] tree;
14
15 /**
16 * 构建线段树
17 *
18 * @param arr 原数组
19 */
20 public void build(int[] arr) {
21 this.arr = arr;
22 /**
23 * 通过数组的长度计算线段树层级高度
24 */
25 int level = (int) Math.ceil(Math.log(arr.length) / Math.log(2)) + 1;
26 int size = (int) Math.pow(2, level);
27 this.tree = new int[size];
28 build(0, 0, arr.length - 1);
29 }
30
31 /**
32 * 构建线段树
33 *
34 * @param node 线段树的节点
35 * @param start 当前节点区间的开始
36 * @param end 当前节点区间的结束
37 */
38 private void build(int node, int start, int end) {
39 /**
40 * 如果到了叶子节点,也就是当前节点区间中只有一个元素
41 */
42 if (start == end) {
43 tree[node] = arr[start];
44 } else {
45 /**
46 * 如果当前节点还有子节点,先计算左右子节点,然后将左右子节点的值相加赋值给当前节点
47 */
48 int mid = (start + end) / 2;
49 int left = 2 * node + 1;
50 int right = 2 * node + 2;
51 build(left, start, mid);
52 build(right, mid + 1, end);
53 tree[node] = tree[left] + tree[right];
54 }
55 }
56
57 /**
58 * 更新元素
59 *
60 * @param index 要更新的元素下标
61 * @param val 更新后的值
62 */
63 public void update(int index, int val) {
64 arr[index] = val;
65 update(0, index, 0, arr.length - 1);
66 }
67
68 /**
69 * 更新元素、线段树
70 *
71 * @param node 线段树节点
72 * @param index 要更新的元素下标
73 * @param start 当前节点区间的开始
74 * @param end 当前节点区间的结束
75 */
76 private void update(int node, int index, int start, int end) {
77 /**
78 * 如果当前节点区间只有一个元素,直接更新
79 */
80 if (start == end) {
81 tree[node] = arr[index];
82 } else {
83 /**
84 * 如果当前节点还有子节点,先更新子节点,然后再更新当前节点
85 */
86 int mid = (start + end) / 2;
87 int left = 2 * node + 1;
88 int right = 2 * node + 2;
89 if (start <= index && index <= mid) {
90 update(left, index, start, mid);
91 } else {
92 update(right, index, mid + 1, end);
93 }
94 tree[node] = tree[left] + tree[right];
95 }
96 }
97
98 /**
99 * 区间查询
100 *
101 * @param L 区间开始位置
102 * @param R 区间结束位置
103 * @return 区间元素之和
104 */
105 public int query(int L, int R) {
106 return query(0, 0, arr.length - 1, L, R);
107 }
108
109 private int query(int node, int start, int end, int L, int R) {
110 /**
111 * 如果当前节点区间和要查询的区间没有交集
112 */
113 if (R < start || L > end) {
114 return 0;
115 }
116 /**
117 * 如果查询的区间完全包含当前节点区间
118 */
119 if (L <= start && end <= R) {
120 return tree[node];
121 }
122 /**
123 * 如果当前节点区间和要查询的区间有交集,但是并不是完全包含
124 */
125 int mid = (start + end) / 2;
126 int left = 2 * node + 1;
127 int right = 2 * node + 2;
128 int sumLeft = query(left, start, mid, L, R);
129 int sumRight = query(right, mid + 1, end, L, R);
130 return sumLeft + sumRight;
131 }
132}
133
134public class Main {
135
136 public static void main(String[] args) {
137 int[] arr = {1, 3, -1, 2, 5};
138 SegmentTree tree = new SegmentTree();
139 tree.build(arr);
140 System.out.println(tree.query(0, 3));
141 tree.update(2, 1);
142 System.out.println(tree.query(0, 3));
143 }
144}
上面的实现是用来计算区间求和的代码,稍微修改一下还可以实现更多功能,比如区间求最值。