러닝머신 하는 K-공대생
Segment Tree Implementation (Python) 본문
정말 심심해져서 최근 교내 대회 문제를 업솔빙하고자 '바벨탑' 문제 를 풀다 세그먼트 트리가 기억이 안나서 다시 복습하고 재귀적 방식과 비재귀적 방식으로 구현해 보았다. 개인적으로 탑다운으로 짠게 직관적이고 레이지나 그 외 여러 확장적 측면에서 편한 것 같은데 Python으로 세그트리 문제들 풀어보면 성능적인 면은 확실히 Bottom-Up 방식이 유리한 것 같다.
1. 재귀적 방식 구현
class SegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (4 * self.n)
self.build(arr, 1, 0, self.n - 1)
def f(self, a, b):
return min(a, b)
def build(self, arr, node, l, r):
if l == r:
self.tree[node] = arr[l]
return
mid = (l + r) // 2
self.build(arr, 2 * node, l, mid)
self.build(arr, 2 * node + 1, mid + 1, r)
self.tree[node] = self.f(self.tree[2 * node], self.tree[2 * node + 1])
def query(self, l, r):
return self._query(l, r, 1, 0, self.n - 1)
def _query(self, l, r, node, nl, nr):
if nl >= l and nr <= r:
return self.tree[node]
mid = (nl + nr) // 2
if nr < l or nl > r:
return int(1e9)
return self.f(self._query(l, r, 2 * node, nl, mid),
self._query(l, r, 2 * node + 1, mid + 1, nr))
def update(self, idx, val):
self._update(idx, val, 1, 0, self.n - 1)
def _update(self, idx, val, node, l, r):
if l == r:
self.tree[node] = val
return
mid = (l + r) // 2
if idx <= mid:
self._update(idx, val, 2 * node, l, mid)
else:
self._update(idx, val, 2 * node + 1, mid + 1, r)
self.tree[node] = self.f(self.tree[2 * node], self.tree[2 * node + 1])
arr = [1, 3, 2, 4, 5, 7, 6, 8]
seg = SegmentTree(arr)
print(seg.query(0, 7))
seg.update(0,10)
print(seg.query(0, 7))
2. Botton-Up 방식 구현 (https://heejayaa.tistory.com/45 에서 가져옴)
class SegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (2 * self.n)
self.tree[n:] = arr[:]
self.build()
def f(self, a, b):
return a + b
def build(self):
for i in range(n - 1, 0, -1):
self.tree[i] = self.f(self.tree[i << 1], self.tree[i << 1 | 1])
def query(self, l, r):
l += self.n
r += self.n
res = 0
while l <= r:
if l & 1:
res = self.f(res, self.tree[l])
l += 1
if not r & 1:
res = self.f(res, self.tree[r])
r -= 1
l >>= 1
r >>= 1
return res
def update(self, idx, val):
i = self.n + idx
self.tree[i] = val
while i >= 1:
i >>= 1
self.tree[i] = self.f(self.tree[i << 1], self.tree[i << 1 | 1])
3. 세그먼트 트리의 활용 예시
- 풀이:
한쪽 바벨 부분만 생각하자. 무게 최소 되야하는 조건을 생각하면 5kg 짜리만 사용할 때도 가능하므로 하나의 운동 중에 기존 바벨을 버리고 새 바벨을 끼우는 행위는 발생하면 안 됨. 즉 잘 분할해야 이동 횟수 최소되는데 그림 그려서 생각해보면 최소값을 기준으로 잘라주고 잘라진 좌우에서 각각 최소값을 기준으로 또 이를 계속 반복하면 됨. 이때 구간 별 최소값을 O(logN)에 구하기 위해 세그먼트 트리를 이용한다. 잘라진 덩어리를 가지고 무게랑 이동할 횟수를 잘 세주는게 중요하고 자르는 규칙은 동일하므로 분할 정복 잘 짜주면 된다.
왜 그런지는 잘 모르겠는데 재귀 구현 방식은 TLE 떴고 비재귀로 구현했을 때 통과된 걸 보면 python으로 제출하는 사람들은 가능하면 Bottom-Up 방식을 이용하는게 좋을 듯 하다.
- 소스 코드 (Python3)
import sys
input = sys.stdin.readline
n = int(input())
A = list(map(int, input().split()))
arr = [(i - 20) // 2 for i in A]
cnt = 0
w = 0
tree = [0] * (2 * n)
for i in range(n):
tree[n + i] = (arr[i], i)
def build():
for i in range(n-1, 0, -1):
if tree[i<<1][0] < tree[i<<1|1][0]:
tree[i] = tree[i<<1]
else:
tree[i] = tree[i<<1|1]
def query(l, r):
l += n
r += n
res = (int(1e9),-1)
while l <= r:
if l&1:
res = res if res[0] < tree[l][0] else tree[l]
l += 1
if not r&1:
res = res if res[0] < tree[r][0] else tree[r]
r -= 1
l >>= 1
r >>= 1
return res
def f(x):
cnt = 0
cnt += x // 20
x %= 20
cnt += x // 15
x %= 15
cnt += x // 10
x %= 10
cnt += x // 5
return cnt
def solve(l, r, prev):
global cnt, w
if l == r:
cnt += f(arr[l] - prev)
w += arr[l] - prev
return
if l > r : return
val, idx = query(l, r)
cnt += f(val - prev)
w += val - prev
solve(l, idx - 1, val)
solve(idx + 1, r, val)
build()
solve(0, n-1, 0)
print(4 * w, 4 * cnt)
4. Segment Tree 공부 자료
IOI Korea 강의 모음
3. 세그먼트 트리의 응용
강의 ppt 자료: https://drive.google.com/file/d/1QWpOb_L0rxS3We2mCy39du-B0oBhUDXU/view