Computer Science/Problem Solving

[BOJ 10167, KOI 2014 중등부 4번] 금광 - 고인물들이 웰노운이라고 하는 금광 세그트리가 뭘까?

리유나 2022. 12. 4. 20:54

https://www.acmicpc.net/problem/10167

 

10167번: 금광

첫 줄에는 금광들의 개수 N (1 ≤ N ≤ 3,000)이 주어진다. 이어지는 N개의 줄 각각에는 금광의 좌표 (x, y)를 나타내는 음이 아닌 두 정수 x와 y(0 ≤ x, y ≤ 109), 그리고 금광을 개발하면 얻게 되는 이

www.acmicpc.net

정말로 유명한 문제입니다. 당시 중등부에 이정도 난이도 문제가 나와서 말이 많았다고도 하고, 풀이 자체는 여러 문제에서 유용하게 쓰일 수 있기 때문에 세그트리 응용을 공부할 때 꼭 한번씩은 짚고 넘어가는 문제입니다.

 

이 문제의 프리퀄 격이라 할 수 있는 연속합과 쿼리(예전에 쓴 블로그 글)를 풀어보신적이 있으시다면 더욱 이해가 쉬울 것입니다.

 

문제를 요약하면, "어떤 좌표평면에 점 N개가 주어지고, 그 점들의 가치가 정수로 주어질 때, 포함된 점들의 가치가 가장 큰 각 변이 x축 혹은 y축과 평행한 직사각형을 찾으시오." 라고 할 수 있겠네요. 그냥 간단해보이는 문제지만 막상 풀어보면 생각보다 알고리즘을 내는 데에도, 구현하는 데에도 큰 어려움이 따른다는 것을 알 수 있습니다.

작년 이맘때의 실행 시간을 줄이기 위한 고군분투...

자명한 풀이에서부터 시작해서, 시간 복잡도를 점차 줄여나가는 식으로 풀이를 설명하겠습니다.

 

O(N^5)

그림판으로 대충 그린 것이니 너른 양해 바랍니다.

그림과 같이 금광들이 존재하고 이를 감싸는 가장 이득을 많이 볼 수 있는 직사각형을 그린다면(검은색), 같은 범위를 포함하는 작은 직사각형(빨간색)을 언제나 만들 수 있습니다. 따라서, 직사각형의 모든 가로/세로에 적어도 한개씩은 금광이 올라와 있는 사각형만 생각해도 됩니다. 이렇게 엄청나게 넓은 범위의 좌표평면/공간에 한정된 점이나 도형이 있을 때, 그것들만을 보는 테크닉을 좌표 압축이라고 합니다. 아무튼 이렇게 하면 만들 수 있는 사각형의 개수는 O(N^4)개가 됩니다.

 

이후, 각각의 사각형 안에 포함된 금광들의 가치가 얼마인지를 확인하면 됩니다. 안에 들어갈 수 있는 금광은 최대 N개이므로, 직접 확인하는 데에는 O(N)의 시간이 걸립니다.(모든 좌표를 확인한다면 O(N^2)가 걸리겠지만, 이는 아무리 생각해도 조금 아닌 듯합니다. ㅎㅎ;) 따라서, 총 시간복잡도는 O(N^5)이 됩니다.

 

O(N^4)

당연히 O(N^5)같은 복잡도로는 아무 것도 할 수 없습니다. 조금 더 관찰을 해봅시다.

널리 알려진 테크닉 중에서 "구간 합"이라는 것이 있습니다. 원래는 1차원 배열에서 누적 합을 가지고 연속한 구간의 합을 쉽게 구할 수 있는데, 이를 조금만 응용하면 2차원에서도 할 수 있습니다!

 

위 상황과 거의 같은 상황을 5*5 배열로 나타내 보았는데요, 5와 2를 딱 끝에 걸치게 하는 3*2짜리 작은 직사각형 내부의 합을 모두 알고 싶고, 우리가 이미 부분합들은 모두 구해놓았다고 해봅시다.

 

우선은 왼쪽 아래부터 빨간색으로 색칠된 경계까지의 모든 합을 더하고, 회색으로 덧칠해진 사각형 두개를 빼 주면 되겠죠.

여기서 중요할 점은, 회색으로 두번 덧칠해진 사각형입니다. 한 번 잘못 들어갔다는(?)이유만으로 두 번이나 빠지는 불운을 겪게 되었는데, 실제로 계산하는데도 저러면 안되니까 왼쪽 아래 겹쳐지는 부분은 다시 더해줍니다.

 

이 때 원점부터 (x, y)번째까지의 누적합을 Psum[x][y]라고 가정하고, 위 직사각형은 x좌표는 l에서 r까지, y좌표는 u에서 d까지의 범위를 커버한다고 하면 직사각형 내부 숫자들의 합은 다음과 같습니다.

 

(직사각형 내부 숫자들의 합) = Psum[r][u] - Psum[l-1][u] - Psum[r][d-1] + Psum[l-1][d-1]

 

2차원 prefix sum들을 미리 구해 놓았다면 이 과정을 어떤 직사각형에 대해서 하든 고작 O(1)의 시간이 걸립니다. 이렇게 구현하면 O(N^4)로 구현을 할 수 있고, N=100인 경우에는 어떻게 잘 넘어가겠네요.

 

O(N^3)

그런데 과연, 우리는 무조건 모든 경우를 다 따져봐야 할까요? 이런 문제가 1차원으로 주어졌을 때를 생각해 봅시다.

 

물론 위 링크 문제는 여러 가지 방법으로 풀립니다. N이 1000이니까 시작점, 끝점으로 브루트 포스를 해도 O(N^2)이니까 충분히 시간 안에 돌고, 앞서 말한 누적합을 사용해도 됩니다. 하지만 DP를 사용하는 풀이를 생각해 봅시다.

 

이렇게 1차원으로 주어진 경우에는, dp[i]를 i번까지 더할 때의 최댓값이라고 가정하면 dp[i] = max(dp[i-1]+L[i], L[i]) 정도로 표현할 수 있겠네요. 그렇게 dp 배열을 채우는 데에 O(N), 채우는 과정에서 최댓값을 저장해도 괜찮고, 채운 뒤에 탐색한다고 해도 O(N)의 시간이 걸리니 아주 간단하게 선형 시간으로 해결 가능합니다.

 

비슷한 방식으로 이 문제도 O(N^3)에 해결 가능합니다! 2차원 배열에서 세로로 누적 합을 먼저 미리 만들어둡시다.

 

 

직사각형의 가로선만 두 개 고른다고 해봅시다. 그 가로선 사이에 있는 금광들을 가로선 두개로 짜부러뜨린다고(!)생각합시다. 쨘! 이렇게 1차원으로 만들 수 있습니다.

 

짜부러뜨렸을 때 생기는 새로운 1차원 배열은 위에서 만든 세로 누적합을 통해서 간단히 O(N)만에 구할 수 있네요. 이제 이 배열에서 위에 링크한 문제를 풀면 푸는 데 걸리는 시간은 O(N), 이 작업을 O(N^2)개의 모든 가로선 쌍에 대해서 해야 하니까 전체 시간은 O(N^3). 이렇게 풀이를 만들면 N이 500 이하인 테스트 케이스도 단숨에 통과 가능하고, 모든 좌표에 대해서 y=0을 만족하는(=1차원 배열인)테스트 케이스도 통과해서 49점을 받을 수 있습니다.

 

하지만, 우리의 목표는 100점이잖아요?

 

O(N^2 logN)

 

이 부분 설명을 들어가기 전에, 우선 세그먼트 트리가 무엇인지에 대해서는 알고 있다고 생각하겠습니다. 또한 이 문제를 풀어 보셨거나, 이 글을 읽고 오셨다면 더욱 좋습니다.

 

링크된 문제와 풀이는 거의 비슷합니다! 위 풀이를 응용하면 됩니다.

 

가로선 두 개를 고른다고 했는데, 그 중 하나를 고정시키고 나머지 하나만을 처음부터 끝까지 옮긴다고 해봅시다.(for문을 짜면 웬만해서는 이런 식이 되겠네요.) 그 과정에서, 새로 추가되는 줄들이 생길 때마다 prefix sum들을 싸그리 처음 구하는 게 아니라, 추가되는 점들만 고려하면 됩니다.

 

이 과정에서 기존 리스트의 최대 증가 부분 수열은 쓸모가 없어질까요? 그렇다면 조금 아쉽겠지만, 이런 때를 위해서 우리에겐(골드 4부터 루비 보스 대장들까지 전부 커버하는 파도 파도 끝이없는)Segment Tree라는 자료구조가 있습니다!

 

Segment Tree 각각에 원소를 4개씩 집어 넣습니다. "왼쪽에서 시작한 부분합 중 최댓값", "오른쪽에서 끝나는 부분합 중 최댓값", "그냥 부분합 최댓값", "전체 합"을 각각 저장해두고, 구간을 합칠 때마다 나오는 새 구간들을 만들면서 위 정보를 활용하면 됩니다! 이것을 총 O(N)개의 고정된 줄에 대해서 하고, 각각 옮길 때마다 세그먼트 트리 업데이트는 O(logN)이니까, 모든 구간에 대해서 확인하는 데에 고작 O(N^2 logN)의 시간밖에 사용하지 않습니다! 이를 그대로 구현하면, N=3000이니 빡빡하게나마 시간 제한 안에 들어갈 수 있습니다.

 

이 문제를 푸셨다면, 거의 비슷한 문제들인 이거나 하위호환인 이것 등등을 쉽게 풀 수 있습니다.

 

Python 코드

 

아름다운 양심으로 이 코드를 그대로 복사/붙여넣기 하신다면 어차피 49점을 받게 됩니다 ㅎㅎ... 일단 O(N^2logN)은 맞지만 상수 커팅이 잘 되지 않아서 시간 초과가 나는 코드입니다. 로직 참고하실 때 보시면 좋습니다!

 

import sys
from functools import cmp_to_key
input=sys.stdin.readline

class sgTree:
    def __init__(self,L):
        self.len=len(L)
        newL=[]
        for i in L:newL.append([i,i,i,i])
        self.tree=[[0,0,0,0]for i in range(self.len)]+newL
        for i in range(self.len-1, 0, -1):
            L1=self.tree[i]
            left=self.tree[2*i]
            right=self.tree[2*i+1]
            #0th:lsum 1st:rsum 2nd:tot 3rd:max
            L1[0]=max([left[0],left[2]+right[0]])
            L1[1]=max([right[1],right[2]+left[1]])
            L1[2]=left[2]+right[2]
            L1[3]=max([left[1]+right[0],left[3],right[3],L1[0],L1[1]])

    def res(self, now, start, end, l, r):
        if end<l or r<start:return[-1000000]*4
        if l<=start and end<=r:return self.tree[now]
        mid=(start+end)//2
        L1=self.res(now*2,start,mid,l,r)
        L2=self.res(now*2+1,mid+1,end,l,r)
        L=[]
        L.append(max(L1[0],L1[2]+L2[0]))
        L.append(max(L2[1],L2[2]+L1[1]))
        L.append(L1[2]+L2[2])
        L.append(max([L1[1]+L2[0],L1[3],L2[3],L[0],L[1]]))
        return L
    #add d in point pt
    def update(self, pt, d):
        i=self.len+pt
        self.tree[i][0]+=d
        self.tree[i][1]+=d
        self.tree[i][2]+=d
        self.tree[i][3]+=d
        while i>1:
            i//=2
            L1=self.tree[i]
            left=self.tree[2*i]
            right=self.tree[2*i+1]
            #0th:lsum 1st:rsum 2nd:tot 3rd:max
            L1[0]=max([left[0],left[2]+right[0]])
            L1[1]=max([right[1],right[2]+left[1]])
            L1[2]=left[2]+right[2]
            L1[3]=max([left[1]+right[0],left[3],right[3],L1[0],L1[1]])
        


n=int(input())
L=[]
dx=set()
dy=set()
for i in ' '*n:
    x,y,w=map(int,input().split())
    dx.add(x)
    dy.add(y)
    L.append([x,y,w])

dx=sorted(list(dx))
dy=sorted(list(dy))
ddx=dict()
ddy=dict()
for i in range(len(dx)):
    ddx[dx[i]]=i
for i in range(len(dy)):
    ddy[dy[i]]=i

for i in range(n):
    L[i][0],L[i][1]=ddx[L[i][0]], ddy[L[i][1]]

def ptcmp(L1, L2):
    if L1[1]>L2[1]:return 1
    elif L1[1]==L2[1]:
        if L1[0]>L2[0]:return 1
        elif L1[0]==L2[0]:return 0
        else:return -1
    else:
        return -1
L=sorted(L, key=cmp_to_key(ptcmp))
numx=len(dx)
c=1
while c<numx:c*=2
segL=[0]*c
s=sgTree(segL)
ans=0
for i in range(n):
    if i and L[i][1]==L[i-1][1]:continue
    segL=[0]*c
    s=sgTree(segL)
    for j in range(i, n):
        s.update(L[j][0], L[j][2])
        if j==n-1 or L[j][1]!=L[j+1][1]:
            ans=max(ans, s.tree[1][-1])

print(ans)