러닝머신 하는 K-공대생
Segment Tree Implementation (Python) 본문
정말 심심해져서 최근 교내 대회 문제를 업솔빙하고자 '바벨탑' 문제 를 풀다 세그먼트 트리가 기억이 안나서 다시 복습하고 재귀적 방식과 비재귀적 방식으로 구현해 보았다. 개인적으로 탑다운으로 짠게 직관적이고 레이지나 그 외 여러 확장적 측면에서 편한 것 같은데 Python으로 세그트리 문제들 풀어보면 성능적인 면은 확실히 Bottom-Up 방식이 유리한 것 같다.
1. 재귀적 방식 구현
class SegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (4 * self.n)
self.build(arr, 1, 0, self.n - 1)
def f(self, a, b):
return min(a, b)
def build(self, arr, node, l, r):
if l == r:
self.tree[node] = arr[l]
return
mid = (l + r) // 2
self.build(arr, 2 * node, l, mid)
self.build(arr, 2 * node + 1, mid + 1, r)
self.tree[node] = self.f(self.tree[2 * node], self.tree[2 * node + 1])
def query(self, l, r):
return self._query(l, r, 1, 0, self.n - 1)
def _query(self, l, r, node, nl, nr):
if nl >= l and nr <= r:
return self.tree[node]
mid = (nl + nr) // 2
if nr < l or nl > r:
return int(1e9)
return self.f(self._query(l, r, 2 * node, nl, mid),
self._query(l, r, 2 * node + 1, mid + 1, nr))
def update(self, idx, val):
self._update(idx, val, 1, 0, self.n - 1)
def _update(self, idx, val, node, l, r):
if l == r:
self.tree[node] = val
return
mid = (l + r) // 2
if idx <= mid:
self._update(idx, val, 2 * node, l, mid)
else:
self._update(idx, val, 2 * node + 1, mid + 1, r)
self.tree[node] = self.f(self.tree[2 * node], self.tree[2 * node + 1])
arr = [1, 3, 2, 4, 5, 7, 6, 8]
seg = SegmentTree(arr)
print(seg.query(0, 7))
seg.update(0,10)
print(seg.query(0, 7))
2. Botton-Up 방식 구현 (https://heejayaa.tistory.com/45 에서 가져옴)
class SegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (2 * self.n)
self.tree[n:] = arr[:]
self.build()
def f(self, a, b):
return a + b
def build(self):
for i in range(n - 1, 0, -1):
self.tree[i] = self.f(self.tree[i << 1], self.tree[i << 1 | 1])
def query(self, l, r):
l += self.n
r += self.n
res = 0
while l <= r:
if l & 1:
res = self.f(res, self.tree[l])
l += 1
if not r & 1:
res = self.f(res, self.tree[r])
r -= 1
l >>= 1
r >>= 1
return res
def update(self, idx, val):
i = self.n + idx
self.tree[i] = val
while i >= 1:
i >>= 1
self.tree[i] = self.f(self.tree[i << 1], self.tree[i << 1 | 1])
3. Lazy Segment tree
class LazySegTree:
def __init__(self, arr, identity=0):
self.arr = arr
self.n = len(arr)
self.tree = [0] * 4 * self.n
self.lazy = [0] * 4 * self.n
self.identity = 0
self.build(1, self.n, 1)
def f(self, a, b):
return a + b
def build(self, l, r, node):
if l == r:
self.tree[node] = self.arr[l - 1]
return
m = (l + r) // 2
self.build(l, m, node << 1)
self.build(m + 1, r, node << 1 | 1)
self.tree[node] = self.f(self.tree[node << 1], self.tree[node << 1 | 1])
def update(self, nl, nr, val):
self._update(1, self.n, 1, nl, nr, val)
def query(self, l, r):
return self._query(l, r, 1, self.n, 1)
def prop(self, l, r, node): # 덧셈 기준으로 작성
m = (l + r) // 2
self.tree[node << 1] = self.f(self.tree[node << 1], self.lazy[node] * (l - m + 1))
self.tree[node << 1 | 1] = self.f(self.tree[node << 1 | 1], self.lazy[node] * (r - m))
self.lazy[node] = 0
def _update(self, l, r, node, nl, nr, val):
if l > nr or r < nl:
return
if l <= nl and nr <= r:
self.tree[node] = self.f(self.tree[node], val * (nr - nl + 1))
self.lazy[node] = val
return
self.prop(nl, nr, node)
mid = (nl + nr) // 2
self._update(l, r, node << 1, nl, mid, val)
self._update(l, r, node << 1 | 1, mid + 1, nr, val)
self.tree[node] = self.f(self.tree[node << 1], self.tree[node << 1 | 1])
def _query(self, l, r, nl, nr, node):
if l > nr or r < nl:
return self.identity
if l <= nl and nr <= r:
return self.tree[node]
self.prop(nl, nr, node)
mid = (nl + nr) // 2
return self.f(self._query(l, r, nl, mid, node<<1), self._query(l, r, mid + 1, nr, node<<1|1))
4. Node Swap
import sys
input = sys.stdin.readline
class Node:
def __init__(self, v=0, parent=None):
self.parent = parent
self.left = None
self.right = None
self.val = v
class SegTree:
def __init__(self, arr):
self.n = len(arr)
self.arr = arr
self.root = Node(0)
self.build(1, self.n, self.root)
def f(self, a, b):
return a + b
def build(self, l, r, node):
if l == r:
node.val = self.arr[l - 1]
return
m = (l + r) // 2
if not node.left:
node.left = Node(0, parent=node)
if not node.right:
node.right = Node(0, parent=node)
self.build(l, m, node.left)
self.build(m + 1, r, node.right)
node.val = self.f(node.left.val, node.right.val)
def query(self, l, r):
return self._query(l, r, self.root, 1, self.n)
def _query(self, l, r, node, nl, nr):
if nl >= l and nr <= r:
return node.val
mid = (nl + nr) // 2
if nr < l or nl > r:
return 0
return self.f(self._query(l, r, node.left, nl, mid),
self._query(l, r, node.right, mid + 1, nr))
def swap(self, l, r, other):
self.node_list = []
other.node_list = []
self.find_node(l, r, self.root, 1, self.n)
other.find_node(l, r, other.root, 1, other.n)
my_list = self.node_list
other_list = other.node_list
# node always has parent
for i in range(len(my_list)):
my_node = my_list[i]
other_node = other_list[i]
my_parent = my_node.parent
other_parent = other_node.parent
if my_parent:
if my_parent.left == my_node:
my_parent.left = other_node
other_node.parent = my_parent
other_parent.left = my_node
my_node.parent = other_parent
else:
my_parent.right = other_node
other_node.parent = my_parent
other_parent.right = my_node
my_node.parent = other_parent
else:
self.root, self.other.root = self.other.root, self.root
break
# update
node = other_node
while node.parent:
node = node.parent
node.val = self.f(node.left.val, node.right.val)
node = my_node
while node.parent:
node = node.parent
node.val = self.f(node.left.val, node.right.val)
def find_node(self, l, r, node, nl, nr):
if nl >= l and nr <= r:
self.node_list.append(node)
return
mid = (nl + nr) // 2
if nr < l or nl > r:
return
self.find_node(l, r, node.left, nl, mid)
self.find_node(l, r, node.right, mid + 1, nr)
arr1 = [1, 1, 1, 1, 1]
arr2 = [3, 3, 3, 3, 3]
tree1 = SegTree(arr1)
tree2 = SegTree(arr2)
print(tree1.query(1, 5))
print(tree2.query(1, 5))
tree1.swap(1, 3, tree2)
print(tree1.query(1, 3))
print(tree2.query(1, 3))
5. 세그먼트 트리의 활용 예시
26655번: 바벨탑
바벨은 헬스장에서 흔히 볼 수 있는 운동기구이다. 바벨은 바벨 봉과 바벨 원판으로 이루어져 있으며, 봉의 양쪽에 다양한 무게의 원판을 여러 개 끼워서 무게를 조절할 수 있다. 경기북과학고
www.acmicpc.net
- 풀이:
한쪽 바벨 부분만 생각하자. 무게 최소 되야하는 조건을 생각하면 5kg 짜리만 사용할 때도 가능하므로 하나의 운동 중에 기존 바벨을 버리고 새 바벨을 끼우는 행위는 발생하면 안 됨. 즉 잘 분할해야 이동 횟수 최소되는데 그림 그려서 생각해보면 최소값을 기준으로 잘라주고 잘라진 좌우에서 각각 최소값을 기준으로 또 이를 계속 반복하면 됨. 이때 구간 별 최소값을 O(logN)에 구하기 위해 세그먼트 트리를 이용한다. 잘라진 덩어리를 가지고 무게랑 이동할 횟수를 잘 세주는게 중요하고 자르는 규칙은 동일하므로 분할 정복 잘 짜주면 된다.
왜 그런지는 잘 모르겠는데 재귀 구현 방식은 TLE 떴고 비재귀로 구현했을 때 통과된 걸 보면 python으로 제출하는 사람들은 가능하면 Bottom-Up 방식을 이용하는게 좋을 듯 하다.

- 소스 코드 (Python3)
import sys
input = sys.stdin.readline
n = int(input())
A = list(map(int, input().split()))
arr = [(i - 20) // 2 for i in A]
cnt = 0
w = 0
tree = [0] * (2 * n)
for i in range(n):
tree[n + i] = (arr[i], i)
def build():
for i in range(n-1, 0, -1):
if tree[i<<1][0] < tree[i<<1|1][0]:
tree[i] = tree[i<<1]
else:
tree[i] = tree[i<<1|1]
def query(l, r):
l += n
r += n
res = (int(1e9),-1)
while l <= r:
if l&1:
res = res if res[0] < tree[l][0] else tree[l]
l += 1
if not r&1:
res = res if res[0] < tree[r][0] else tree[r]
r -= 1
l >>= 1
r >>= 1
return res
def f(x):
cnt = 0
cnt += x // 20
x %= 20
cnt += x // 15
x %= 15
cnt += x // 10
x %= 10
cnt += x // 5
return cnt
def solve(l, r, prev):
global cnt, w
if l == r:
cnt += f(arr[l] - prev)
w += arr[l] - prev
return
if l > r : return
val, idx = query(l, r)
cnt += f(val - prev)
w += val - prev
solve(l, idx - 1, val)
solve(idx + 1, r, val)
build()
solve(0, n-1, 0)
print(4 * w, 4 * cnt)
6. Segment Tree 공부 자료
IOI Korea 강의 모음
3. 세그먼트 트리의 응용
강의 ppt 자료: https://drive.google.com/file/d/1QWpOb_L0rxS3We2mCy39du-B0oBhUDXU/view