반응형

 

세그먼트 트리(Segment Tree)란?

알고리즘 문제를 풀다 보면 정렬되어 있지 않은 구간 내의 합이나 최솟값들을 빠르게 찾아야 하는 경우가 많습니다. 질문이 1개일 경우 간단하게 for문을 통해서 O(N)의 시간 복잡도로 최솟값이나 합을 구하는 것은 쉽습니다. 그러나 만약 질문의 수가 수 만개 혹은 수 십만 개라면? 실행시간 초과가 발생할 경우가 큽니다.

그러한 문제점을 해결하기 위해 세그먼트 트리가 사용되며 "배열의 구간정보를 담고 있는 트리"입니다. 세그먼트 트리를 이용할 경우 기존의 for문이 O(N)의 시간 복잡도보다 빠른 O(logN)의 시간 복잡도로 답을 빠르게 찾을 수 있게 됩니다.

 

예를 들어보겠습니다.

먼저 입력으로 받을 배열을 input[],
세그먼트 트리로 저장할 배열을 seg[]라고 할 때,

input = {7, 3, 2, 6, 5, 8, 1, 4}라고 합시다.

INDEX 0 1 2 3 4 5 6 7
input[ ] 7 3 2 6 5 8 1 4

입력된 배열로 segment tree를 만들어보면 다음과 같습니다.

 

먼저 트리의 맨 하단에는 input배열의 원소들이 있습니다.

그리고 최종적으로 만들어지는 세그먼트 트리는 다음과 같습니다.

각각의 세그먼트 트리들은 자식 노드들의 최솟값을 담고 있으며, 세그먼트 트리의 최상단인 seg[1]은 전 구간에서 최솟값인 "1"을 담고 있습니다.

 

이런 식으로 세그먼트 트리를 만드는 함수를 이제부터 init() 함수라고 하겠습니다.

int init(int node, int s, int e) 
{
    if (s == e) return seg[node] = input[s]; // start 와 end 의 위치가 일치하면 input[start] 값을 넣어준다.
    int mid = (s + e) / 2;
    return seg[node] = min(init(2 * node, s, mid), init(2 * node + 1, mid + 1, e)); // 다른 노드들의 정보를 
}

init함수에서 파라미터로 받는 node는 부모의 노드, s는 구간의 시작, e는 구간의 끝을 의미합니다. 그리고 부모 노드를 node라고 할 때 왼쪽의 자식 노드의 번호는 2 * node, 그리고 오른쪽 자식의 노드 번호는 2 * node + 1이 됩니다.

return문을 살펴보면 seg[node] 즉, 현재의 노드 값은 왼쪽의 자식 노드의 값과 오른쪽의 자식 노드의 값 중에서 min값을 가져오게 됩니다. 그리고 init함수가 재귀적으로 내려갈 때, 맨 아래에서 s==e가 되는 순간 seg[node]의 값은 input[s]의 값으로 저장됩니다.

 

실제로 위의 함수를 이용해서 세그먼트 트리가 잘 구해졌는지 확인해보겠습니다.

<실행 코드>

#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;

#define INF 1000000000 // 10억

int input[8];
int seg[16];

int init(int node, int s, int e) 
{
    if (s == e) return seg[node] = input[s]; // start 와 end 의 위치가 일치하면 input[start] 값을 넣어준다.
    int mid = (s + e) / 2;
    return seg[node] = min(init(2 * node, s, mid), init(2 * node + 1, mid + 1, e)); // 다른 노드들의 정보를 
}

int main()
{
    for (int i = 0; i < 8; i++)
    {
        cin >> input[i];
    }

    init(1, 0, 7); // 세그먼트 트리 만들기

    
    for (int i = 1; i < 16; i++)
    {
        cout << "seg[" << i << "] : " << seg[i] << endl;
    }
    
}

<실행 결과>

네 원하는 값들이 잘 구해졌습니다!

 

이런 식으로 세그먼트 트리를 생성하였다면 이제 구간정보가 주어졌을 때 해당 구간 내에 최솟값을 찾는 함수가 필요합니다.

 

이제부터 그 함수를 query() 함수라고 하겠습니다.

int query(int node, int s, int e, int l, int r) 
{
    if (e < l || r < s) return INF; // 찾아야하는 구간과 노드구간이 겹치지 않을 때
    if (l <= s && e <= r) return seg[node]; // 찾아야하는 구간내에 노드구간이 포함될 때
    int mid = (s + e) / 2;
    // 찾아야하는 구간이 노드구간에 포함되거나, 부분적으로 겹치는 경우
    return min(query(2 * node, s, mid, l, r), query(2 * node + 1, mid + 1, e, l, r)); 
}

query함수에서 파라미터로 받는 node, s, e는 init함수와 동일하고 l과 r은 "left", "right"의 앞글자로 찾을 구간을 의미합니다. 

 

1. 찾아야 하는 구간과 노드 구간이 겹치지 않을 때 

if (e < l || r < s) return INF;

찾지 않아도 되는 구간입니다. min함수에서 자동적으로 걸러지도록 INF를 리턴해줍니다.

 

2. 찾아야 하는 구간 내에 노드 구간이 포함될 때 

if (l <= s && e <= r) return seg[node];

더 이상 재귀적으로 들어가지 않아도 됩니다. 해당하는 노드의 segment 값을 리턴해줍니다.

 

3. 찾아야 하는 구간이 노드 구간에 포함되거나, 부분적으로 겹치는 경우 

return min(query(2 * node, s, mid, l, r), query(2 * node + 1, mid + 1, e, l, r)); 

위의 1번 경우나 2번 경우가 나올 때까지 재귀적으로 계속해서 들어가 줍니다.

 

INDEX 0 1 2 3 4 5 6 7
input[ ] 7 3 2 6 5 8 1 4

이제 찾을 구간정보를 받으면 해당 구간 내의 최솟값을 정말 반환하는지 확인해보겠습니다. 

구간정보를 2~6으로 주어진다고 생각해봅시다. 일단 위의 input배열에서 INDEX가 2~6일 때 input값의 최소는 "1"임을 쉽게 알 수 있습니다. 여기서 세그먼트 트리를 활용한다면 다음과 같은 그림으로 구하게 될 것입니다.

 

쿼리 함수를 통해 재귀적으로 들어가 보겠습니다. 먼저 1번 세그먼트(seg[1])에서 출발을 하게 됩니다. 1번 세그먼트는 (0~7)의 구간정보를 담고 있고 찾으려는 구간정보가 여기에 포함되므로 양쪽의 자식 노드로 recursion 합니다.

이제 세그먼트 2번과 3번의 경우 둘 다 찾으려는 구간정보와 포함하고 있는 구간정보가 겹치므로 다시 recursion 합니다.

이제 세그먼트 4,5,6,7번을 살펴봅시다. 

세그먼트 4번의 경우 찾아야 하는 구간과 노드 구간이 겹치지 않으므로 자동적으로 걸러지도록 INF를 리턴.
(return INF)

세그먼트 5, 6번의 경우 찾아야 하는 구간 내에 노드 구간이 포함되므로 해당 seg값 리턴. (return seg[5], seg[6])

세그먼트 7번의 경우 찾아야 하는 구간과 노드 구간이 겹치므로 다시 recursion 해서 찾아야 하는 구간 내에 노드 구간이 포함되는 seg값을 리턴합니다.
(return seg[14])

 

실제로 코드가 정확한지 확인해보겠습니다.

<완성된 최종 코드>

#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;

#define INF 1000000000 // 10억

int a, b;
int input[8];
int seg[16];

int init(int node, int s, int e) 
{
    if (s == e) return seg[node] = input[s]; // start 와 end 의 위치가 일치하면 input[start] 값을 넣어준다.
    int mid = (s + e) / 2;
    return seg[node] = min(init(2 * node, s, mid), init(2 * node + 1, mid + 1, e)); // 다른 노드들의 정보를 
}

int query(int node, int s, int e, int l, int r) {
    if (e < l || r < s) return INF; // 찾아야하는 구간과 노드구간이 겹치지 않을 때
    if (l <= s && e <= r) return seg[node]; // 찾아야하는 구간내에 노드구간이 포함될 때
    int mid = (s + e) / 2;
    // 찾아야하는 구간이 노드구간에 포함되거나, 부분적으로 겹치는 경우
    return min(query(2 * node, s, mid, l, r), query(2 * node + 1, mid + 1, e, l, r)); 
}


int main()
{
    for (int i = 0; i < 8; i++)
    {
        cin >> input[i];
    }

    init(1, 0, 7); // 세그먼트 트리 만들기

    
    cin >> a >> b;

    cout << query(1, 0, 7, a, b);
    
}

<실행 결과>

INDEX 0 1 2 3 4 5 6 7
input[ ] 7 3 2 6 5 8 1 4

 

 

이상입니다. 틀린 부분이 있거나 궁금하신 점을 댓글 부탁드리겠습니다. 읽어주셔서 감사드립니다!

반응형

+ Recent posts