线段树相关

前言

设有一个数组 arr，支持两种操作：

query(L, R)：求 arr[L] 到 arr[R]（含两端）之间所有数字之和。直接逐个累加，复杂度为 O(n)。

update(arr, i, val)：把 arr[i] 中的数字改成 val。直接赋值即可，复杂度为 O(1)。

若想降低query的复杂度

方案1:

前缀和

引入一个前缀和数组 sum[]，其中 sum[i] = arr[0] + arr[1] + … + arr[i]，于是 query(L, R) = sum[R] - sum[L-1]（L = 0 时直接取 sum[R]），

query复杂度减少到O(1)

但是 update 的复杂度上升到了 O(n)：修改 arr[i] 之后，sum[i] 到 sum[n-1] 都要跟着更新。

由此引出线段树，

它使 query 和 update 两个操作的复杂度都变为 O(log n)。

线段树

以如下数组 arr[0…5] 为例：

| 下标 | 0 | 1 | 2 | 3 | 4 | 5  |
|------|---|---|---|---|---|----|
| 值   | 1 | 3 | 5 | 7 | 9 | 11 |

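query 操作：从根节点出发。若当前节点的区间与 [L, R] 不相交，返回 0；若完全落在 [L, R] 内，直接返回该节点保存的区间和；否则递归查询左右两个子节点，并把结果相加。

update 操作：从根节点出发，根据下标 i 落在左半区间还是右半区间向下递归。找到对应的叶子节点后修改其值，回溯时依次重新计算路径上每个节点的和（左孩子之和 + 右孩子之和）。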

如何存储这个二叉树？

用一个数组 tree[] 存储，根节点下标为 0。对于下标为 node 的节点：

左孩子节点序号 left_node = 2 × node + 1

右孩子节点序号 right_node = 2 × node + 2

代码（C++ 实现）：
#include <cstdio>
#include <iostream>

using namespace std;

// Capacity of the tree[] array that main() allocates for the segment tree.
// NOTE(review): with the 2*node+1 / 2*node+2 layout used here, a tree over n
// elements can touch node indices up to roughly 4*n, so N must be sized for
// the largest intended input — confirm 1000 is enough.
const int N = 1000;

// Recursively fills tree[] with the range sums of arr[start..end].
//
// node  : index in tree[] of the node that covers [start, end] (root = 0).
// Children of node live at 2*node+1 (covering [start, mid]) and
// 2*node+2 (covering [mid+1, end]), where mid = (start + end) >> 1.
// After the call, tree[node] == arr[start] + ... + arr[end].
void buildTree(int arr[], int tree[], int node, int start, int end) {
    // A single-element range is a leaf: store the element directly.
    if (start == end) {
        tree[node] = arr[start];
        return;
    }

    const int mid = (start + end) >> 1;
    const int lhs = node * 2 + 1;
    const int rhs = lhs + 1;

    buildTree(arr, tree, lhs, start, mid);
    buildTree(arr, tree, rhs, mid + 1, end);

    // An internal node stores the sum of its two halves.
    tree[node] = tree[lhs] + tree[rhs];
}

// Sets arr[index] = val and refreshes every range sum on the path from node
// down to the leaf for index.
//
// node/start/end describe the current subtree exactly as in buildTree; call
// with node = 0, start = 0, end = size - 1.
// An index outside [start, end] is ignored: nothing is written. Previously
// such an index was routed to the nearest leaf, which wrote arr[index] out of
// bounds and overwrote that leaf's sum with val.
void updateTree(int arr[], int tree[], int node, int start, int end, int index, int val) {
    // Reject out-of-range indices up front. Recursive calls below always pass
    // a sub-range that contains index, so this only triggers at the top level.
    if (index < start || index > end) {
        return;
    }

    // Leaf: start == end == index, so update both the source array and the node.
    if (start == end) {
        arr[index] = val;
        tree[node] = val;
        return;
    }

    const int mid = (start + end) >> 1;
    const int leftNode = 2 * node + 1;
    const int rightNode = 2 * node + 2;

    // index is already known to be >= start, so only the mid split matters.
    if (index <= mid) {
        updateTree(arr, tree, leftNode, start, mid, index, val);
    } else {
        updateTree(arr, tree, rightNode, mid + 1, end, index, val);
    }

    // Recompute this node's sum from its (possibly updated) children.
    tree[node] = tree[leftNode] + tree[rightNode];
}
// Returns the sum of arr[i] for every i in [L, R] that also lies in the
// subtree's range [start, end].
//
// node/start/end describe the current subtree exactly as in buildTree; call
// with node = 0, start = 0, end = size - 1 to get arr[L] + ... + arr[R].
// Returns 0 when [L, R] does not overlap [start, end], including when L > R.
// arr is not read: tree[] alone holds the sums. It is kept so the signature
// stays compatible with existing callers.
int queryTree(int arr[], int tree[], int node, int start, int end, int L, int R) {
    (void)arr;

    // Disjoint: this subtree contributes nothing.
    if (R < start || L > end) {
        return 0;
    }
    // Fully covered: the precomputed sum is the answer for this subtree.
    if (L <= start && end <= R) {
        return tree[node];
    }

    // Partial overlap. start < end is guaranteed here, because a single-element
    // node is always either disjoint or fully covered. That is why the old
    // `start == end` branch could never run and has been removed.
    const int mid = (start + end) >> 1;
    const int leftNode = 2 * node + 1;
    const int rightNode = 2 * node + 2;

    const int sumLeft = queryTree(arr, tree, leftNode, start, mid, L, R);
    const int sumRight = queryTree(arr, tree, rightNode, mid + 1, end, L, R);
    return sumLeft + sumRight;
}

// Prints tree[0..count-1], one "tree[i] = v" line each, then a blank line.
static void printTree(const int tree[], int count) {
    for (int i = 0; i < count; ++i) {
        printf("tree[%d] = %d \n", i, tree[i]);
    }
    printf("\n");
}

// Demo: build a segment tree over a small array, print it, change arr[4]
// to 6, print it again, then print the sum of arr[2..5].
int main() {
    int arr[] = {1, 3, 5, 7, 9, 11};
    // Derive the element count from the initializer so it cannot drift out
    // of sync with arr when the data is edited.
    const int size = static_cast<int>(sizeof(arr) / sizeof(arr[0]));
    int tree[N] = {0};

    // Number of tree slots to dump. For 6 elements the highest node index
    // used is 12, so 15 shows every node (plus a few unused zero slots).
    const int dumpCount = 15;

    buildTree(arr, tree, 0, 0, size - 1);
    printTree(tree, dumpCount);

    updateTree(arr, tree, 0, 0, size - 1, 4, 6);
    printTree(tree, dumpCount);

    int s = queryTree(arr, tree, 0, 0, size - 1, 2, 5);
    printf("%d\n", s);
    return 0;
}