ホームに戻る
セグメント木に関する知見
0、はじめに
セグメント木はある区間の値をO(logN)で変更することができます。
使用例:
・N=100000個の数のうちl番目からr番目までの数の最大値をO(logN)で求める。
・N=100000個の数のうちl番目からr番目までに数xをO(logN)で加算する。
・N=100000個の数のうちl番目からr番目までの和、積をO(logN)で求める。
・N=100000個の数のうちl番目からr番目までのデータをO(logN)で結合する。
以上のようなことが出来ます。
N=100000個のデータに対してO(logN)で操作ができるのは不思議ですが、
セグメント木を使えばこのような操作が実際に可能になります。
競技プログラミングでセグメント木を使う場合はN=100000前後であることが多い。
なぜならN=100程度だとO(logN)してもそれほど高速化のメリットは得られず、
またN=1000000000だと使用するメモリが足らなくなるという理由がある。
よって、N=100000で”ある区間”の値を変更するということをしたかった場合に、
セグメント木が想定であることが多い。
また区間の値を変換する回数も100000回前後であることが多い。
これはN=100000の区間を1回更新するだけでは0(N)にしかならないため。
書き換え回数を100000にすうrことでO(N^2)にさせる。
セグメント木を使うことでO(N^2)をO(NlogN)にする問題はよく見られる。
セグメント木と似たアルゴリズムに平方分割がある。
平方分割はO(N^(1/2))なのでセグメント木よりも遅い。
しかし、制限時間が2秒ならO(NlogN)もO(NN^(1/2))も間に合う。
セグメント木を考えるときは平方分割することも考えると良い。
セグメント木はテンプレートを貼るだけでは対応出来ない問題が多く、
場合に応じてテンプレートを書き換えてやる必要がある。
よって、基本的な考え方と実装について細かく書いていきたい。
1、セグメント木の構造
まずデータ数が2の累乗数でないといけない制限があります。
しかしこれはデータ数を大きめにとれば良いだけなので大した制限ではありません。
たとえばN=100000なら2^17=131072とし100001以降は使わなければいい。
セグメント木はデータの数がNであればちょうどN*2-1のメモリを使用する。
実際には2*Nのメモリを用意してメモリの0位置から使用する場合と、
メモリの1位置から使用する場合に分かれるようである。
どちらでも良いが以降は0位置から使用することで統一する。
統一しないと間違いの元になります。
セグメント木の初期化は最初に極値で埋めるか、
数値を入れる場合は1つずつ入れるのでO(NlogN)で可能となる。
しかし、以下のようにすればO(N)で埋まるので覚えておきたい。
//N=100000で0から99999まで0で埋める。
//最小値を求めるセグメント木の場合。
#define N 131072
long long dat[2*N-1];
for(long long i = N-1; i < 2*N-1; i++){
dat[i] = 0;
}
for(long long i = N-2; i >= 0; i--){
dat[i] = min(dat[i*2+1],dat[i*2+2]);
}
この初期化がなぜこれで良いのかを確認することは、
セグメント木の理解につながるのでぜひやって欲しい。
2、最も簡単なセグメント木
最もシンプルな最大値、最小値を求めるセグメント木を考える。
このセグメント木の特徴は、
・1点を更新する。
・区間の最大値、最小値を求める。
以上である。
例として区間の最大値を求めるセグメント木を考える。
初期化として最初に全ての値を-INFにしておく必要がある。
1点を更新する方法は以下のようになる。
セグメント木でなければO(1)だがセグメント木ではO(logN)となる。
//位置kに値aを入れる。
void update(int k, long long a){
k += N-1;
dat[k] = a;
while(k > 0){
k = (k-1) / 2;
dat[k] = max(dat[k*2+1],dat[k*2+2]);
}
}
では区間の最大値を求めてみる。O(logN)である。
区間は a<=x<bでa,bを指定する。k=0,l=0,r=Nとする。
long long query_mx(int a, int b, int k, int l, int r){
// 区間に入っていなければ-INFを返す。
if(r <= a || b <= l)return -INF;
// 区間にピッタリ収まっていればそのときの値を返す。
if(a <= l && r <= b){
return dat[k];
}
// 区間をはみ出す場合はさらに細かい区間に進む。
else{
long long vl = query_mx(a, b, k*2+1, l, (l+r)/2);
long long vr = query_mx(a, b, k*2+2, (l+r)/2, r);
return max(vl, vr);
}
}
大きな区間から見ていき区間からはみ出す場合にはさらに細かい区間を調べる。
区間がピッタリ収まる区間にきたらそのときの最大値を返す。
この更新と区間を調べるやり方は基本なのでメモリのどの位置にいて、
どの位置にどのような情報があるかの詳細な理解が必要です。
3、データを行列にしたセグメント木
セグメント木の1つ1つの情報を行列にする場合は良くあります。
例えば次のような問題を考えてみます。
問題
最初、B(i)は全て0である。
A(i)については以下の漸化式で求められる。
A(0)=1,A(i+1)=B(i)*A(i)+1
次のような指示が与えられる。
・B(i)の値を更新する。
・その時点でA(i)の値を求める。
0<=i<=100000で指示は最大100000まで順に与えられる。
さて、この問題をどう考えるか?
難点はB(i)の値を1つ更新するとA(i+1)以降の答えが全て変わる点です。
これをセグメント木を使って解決する方法を考えます。
まず漸化式なので行列で更新することを考えます。
更新の行列は次のようなものだと思います。
|B(i) 1||A(i)|
|0 1||1 |
よって、2行2列の情報をセグメント木で扱えるようにしたいと思います。
先に、2*2と2*2の行列の掛け算をする関数を書いておきます。
void mul(long long d1[4], long long d2[4], long long d3[4]){
d3[0] = d1[0]*d2[0]+d1[1]*d2[2];
d3[1] = d1[0]*d2[1]+d1[1]*d2[3];
d3[2] = d1[2]*d2[0]+d1[3]*d2[2];
d3[3] = d1[2]*d2[1]+d1[3]*d2[3];
}
メモリは以下のようでいいですね。
#define N 8
long long dat[2*N-1][4];
最初B(i)は全て0ですので0で初期化します。
for(long long i = N-1; i < 2*N-1; i++){
dat[i][0] = 0;
dat[i][1] = 1;
dat[i][2] = 0;
dat[i][3] = 1;
}
for(long long i = N-2; i >= 0; i--){
mul(dat[i*2+2],dat[i*2+1],dat[i]);
}
行列なので掛け算の順番は大切です、順番に気をつけましょう。
行列G[0]から順にG[3]まで掛け算していきたいとき,
G[0]からG[3]の行列をまとめるとG[3]G[2]G[1]G[0]と逆になります。
ではB(i)を更新してみましょう。
void update(int k, long long a){
k += N-1;
dat[k][0] = a;
while(k > 0){
k = (k-1) / 2;
mul(dat[k*2+2],dat[k*2+1],dat[k]);
}
}
区間に関しては区間での行列積を得ます。
void query_ith(int a, int b, int k, int l, int r, long long d3[4]){
// 区間に入っていなければ単位行列を返す。
if(r <= a || b <= l){
d3[0]=1;d3[1]=0;d3[2]=0;d3[3]=1;
return;
}
// 区間にピッタリ収まっていればそのときの値を返す。
if(a <= l && r <= b){
for(int i = 0; i < 4; i++){
d3[i] = dat[k][i];
}
}
// 区間をはみ出す場合はさらに細かい区間に進む。
else{
ll d1[4],d2[4];
query_ith(a, b, k*2+1, l, (l+r)/2, d1);
query_ith(a, b, k*2+2, (l+r)/2, r, d2);
mul(d2, d1, d3);
}
}
次のようにしてB(i)の値をxに更新します。
update(i,x);
以下のようにしてA[i+1]が求まります。
A[i+1]=query_ith(0,i+1,0,0,N)*A[0]
次のように使って見ました。
update(0,2); // B(0)を2に更新
update(1,1); // B(1)を1に更新
// A(1)からA(N)を表示
for(long long i = 1; i <= N; i++){
long long d1[4];
query_ith(0,i,0,0,N,d1);
cout << d1[0]+d1[1] << endl;
}
4、遅延評価セグメント木を書いてみる。
遅延評価セグメント木というと難しそうです。
できることは例えば、
・O(logN)である区間に値を加算できる。
・O(logN)である区間の総和を計算できる。
N=100000のときにO(logN)で全ての領域に1を加算などができます。
もちろん実際に加算することはできませんので、
すべての区間に加算しているように見せるという工夫です。
ポイントは記憶するセグメント木を2本使うところです。
まず1本は実際の総和を入れておくセグメント木です。
もう1本はその奥への更新を保留しておくセグメント木です。
例えば、ある区間に加算をするとき、
すべての区間に1つずつ加算することはできません。
大きな区間で総和を記録したらその先の値の更新は保留にします。
実際にその先の値が欲しかったら保留ぶんを更新しつつ取りに行きます。
値を取りに行くまで評価が遅延するのです。
簡単のため区間の和を記録する遅延評価セグメント木を書きます。
実際の総和を入れる配列をd1、保留を入れる配列をd2とします。
d1には足し算後の和が入っています。
d2には和をする前の足す数が入っています。
初期化は以下の通り。
#define N 8
long long d1[2*N-1];
long long d2[2*N-1];
for(long long i = N-1; i < 2*N-1; i++){
d1[i] = 0;
d2[i] = 0;
}
for(long long i = N-2; i >= 0; i--){
d1[i] = d1[i*2+1]+d1[i*2+2];
d2[i] = 0;
}
区間への加算をします。
void add(int a, int b, long long x, int k, int l, int r) {
if (a <= l && r <= b) {
d1[k] += (r-l)*x;
// さらに奥があれば遅延させる値を1つ奥に押し込む
if(k<N-1){
d2[k*2+1] += x;
d2[k*2+2] += x;
}
}
else if (l < b && a < r) {
d1[k] += (min(b, r) - max(a, l)) * x;
add(a, b, x, k * 2 + 1, l, (l + r) / 2);
add(a, b, x, k * 2 + 2, (l + r) / 2, r);
}
}
区間の総和を求めます。
long long sum(int a, int b, int k, int l, int r) {
if (b <= l || r <= a)return 0LL;
else if (a <= l && r <= b) {
return d2[k] * (r-l) + d1[k];
}
else {
// 保留していたデータを確定し1つ先に送る。
d1[k] += (r-l)*d2[k];
if(k<N-1){
d2[k*2+1] += d2[k];
d2[k*2+2] += d2[k];
}
d2[k] = 0;
long long ret = 0LL;
ret += sum(a, b, k * 2 + 1, l, (l + r) / 2);
ret += sum(a, b, k * 2 + 2, (l + r) / 2, r);
return ret;
}
}
5、Starry Sky Treeを書いてみる。
できることは例えば、
・O(logN)である区間に値を加算できる。
・O(logN)である区間の値を取得できる。
遅延評価セグメント木と違うのは、
取得する範囲のデータが一律でないといけないところです。
区間内の数の総和など区間での値が不揃いな場合は使えません。
Starry Sky Treeは値を更新した際に手前に値を持って帰ります。
評価するときには値は動きません。
Starry Sky Treeもセグメント木を2本使います。
まず1本は区間に対して変更した値を入れておくセグメント木です。
もう1本はその手前に戻るときに更新されるセグメント木です。
まずある区間で更新をするとします。
その区間でピッタリ収まる領域に値を加算します。
そして、再起で戻ってくる際に奥の両側から値を持ってきて、
手前のデータを更新していきます。
実際に区間を参照する際は手前の情報しか使いません。
区間に加算できて区間の最小値を求めるStarrt Sky Treeを書きます。
初期化は次のようです。
初期値は全て1にしてみます。(初期値はd1に入れること。)
#define N 8
long long d1[2*N-1];
long long d2[2*N-1];
for(long long i = N-1; i < 2*N-1; i++){
d1[i] = 1; // 例として初期値は1とする。
d2[i] = 0;
}
for(long long i = N-2; i >= 0; i--){
d1[i] = min(d1[i*2+1] + d2[i*2+1], d1[i*2+2] + d2[i*2+2]);
d2[i] = 0;
}
区間に対する値の加算。
void add(int a, int b, long long x, int k, int l, int r) {
if (r <= a || b <= l)return;
// 区間にピッタリ収まっていればd2を更新する。
if (a <= l && r <= b) {
d2[k] += x;
return;
}
add(a, b, x, k * 2 + 1, l, (l + r) / 2);
add(a, b, x, k * 2 + 2, (l + r) / 2, r);
// 戻ってくるときにd1を更新する。
d1[k] = min(d1[k*2+1] + d2[k*2+1], d1[k*2+2] + d2[k*2+2]);
}
区間の最小値を求める。
long long sst(int a, int b, int k, int l, int r) {
// 範囲外で評価させたくないので INF を返す。
if (b <= l || r <= a)return INF;
if (a <= l && r <= b) {
return d1[k] + d2[k];
}
else {
long long vl = sst(a, b, k * 2 + 1, l, (l + r) / 2);
long long vr = sst(a, b, k * 2 + 2, (l + r) / 2, r);
return min(vl,vr) + d2[k];
}
}
6、例題
配列A1,A2,...,Anがある。
配列Biはフィボナッチ数列のAi番目の項です。
フィボナッチ数はF(0)=0,F(1)=1,F(i+2)=F(i+1)+F(i)です。
例えばAの配列が{1,3,4,6}ならBの配列は{1,2,3,8}になります。
次の指示があります。
・AaからAbの区間にxを加算する。
・その時点でのBaからBbの区間の値の総和を求める。
指示に対して対応せよ。
制約は次のようです。
1<=n<=100000、1<=指示の数<=100000。
1<=Ai<=1000000000。1<=x<=1000000000。
Starry Sky Treeを使うと正解できる。
以下は、Starry Sky Treeの部分のみ抜粋してある。
// Matは行列クラス
// addMat、mulMat、powMatはそれぞれ加算、掛算、乗算。
// Mat()、unitMat()はそれぞれゼロ行列、単位行列。
#define N (1<<17)
Mat d1[2*N-1];
Mat d2[2*N-1];
void add(int a, int b, Mat x, int k, int l, int r) {
if (r <= a || b <= l)return;
if (a <= l && r <= b) {
d2[k] = mulMat(d2[k],x);
return;
}
add(a, b, x, k * 2 + 1, l, (l + r) / 2);
add(a, b, x, k * 2 + 2, (l + r) / 2, r);
d1[k] = addMat(mulMat(d1[k*2+1],d2[k*2+1]),mulMat(d1[k*2+2],d2[k*2+2]));
}
Mat sst(int a, int b, int k, int l, int r) {
if (b <= l || r <= a)return Mat();
if (a <= l && r <= b) {
return mulMat(d1[k],d2[k]);
}
else {
Mat vl = sst(a, b, k * 2 + 1, l, (l + r) / 2);
Mat vr = sst(a, b, k * 2 + 2, (l + r) / 2, r);
return mulMat(d2[k],addMat(vl,vr));
}
}
// 初期化
for(long long i = 0; i < 2*N-1; i++){
d2[i] = unitMat();
}
for(long long i = N-1; i < 2*N-1; i++){
d1[i] = unitMat(); // 1項目で初期化
}
for(long long i = N - 2; i >= 0; i--) {
d1[i] = addMat(mulMat(d1[i*2+1],d2[i*2+1]),mulMat(d1[i*2+2],d2[i*2+2]));
}