Notice
Recent Posts
Recent Comments
Link
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
Tags more
Archives
Today
Total
관리 메뉴

러닝머신 하는 K-공대생

Segment Tree Implementation (Python) 본문

Problem Solving/Algorithms

Segment Tree Implementation (Python)

prgmti1 2023. 2. 12. 06:26

정말 심심해져서 최근 교내 대회 문제를 업솔빙하고자 '바벨탑' 문제 를 풀다 세그먼트 트리가 기억이 안나서 다시 복습하고 재귀적 방식과 비재귀적 방식으로 구현해 보았다. 개인적으로 탑다운으로 짠게 직관적이고 레이지나 그 외 여러 확장적 측면에서 편한 것 같은데 Python으로 세그트리 문제들 풀어보면 성능적인 면은 확실히 Bottom-Up 방식이 유리한 것 같다. 

 

1. 재귀적 방식 구현

class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.build(arr, 1, 0, self.n - 1)

    def f(self, a, b):
        return min(a, b)

    def build(self, arr, node, l, r):
        if l == r:
            self.tree[node] = arr[l]
            return
        mid = (l + r) // 2
        self.build(arr, 2 * node, l, mid)
        self.build(arr, 2 * node + 1, mid + 1, r)
        self.tree[node] = self.f(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, l, r):
        return self._query(l, r, 1, 0, self.n - 1)

    def _query(self, l, r, node, nl, nr):
        if nl >= l and nr <= r:
            return self.tree[node]
        mid = (nl + nr) // 2
        if nr < l or nl > r:
            return int(1e9)
        return self.f(self._query(l, r, 2 * node, nl, mid),
                      self._query(l, r, 2 * node + 1, mid + 1, nr))

    def update(self, idx, val):
        self._update(idx, val, 1, 0, self.n - 1)

    def _update(self, idx, val, node, l, r):
        if l == r:
            self.tree[node] = val
            return
        mid = (l + r) // 2
        if idx <= mid:
            self._update(idx, val, 2 * node, l, mid)
        else:
            self._update(idx, val, 2 * node + 1, mid + 1, r)
        self.tree[node] = self.f(self.tree[2 * node], self.tree[2 * node + 1])
        

arr = [1, 3, 2, 4, 5, 7, 6, 8]
seg = SegmentTree(arr)
print(seg.query(0, 7))
seg.update(0,10)
print(seg.query(0, 7))

 

2. Botton-Up 방식 구현 (https://heejayaa.tistory.com/45 에서 가져옴)

class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (2 * self.n)
        self.tree[n:] = arr[:]
        self.build()

    def f(self, a, b):
        return a + b

    def build(self):
        for i in range(n - 1, 0, -1):
            self.tree[i] = self.f(self.tree[i << 1], self.tree[i << 1 | 1])

    def query(self, l, r):
        l += self.n
        r += self.n
        res = 0
        while l <= r:
            if l & 1:
                res = self.f(res, self.tree[l])
                l += 1
            if not r & 1:
                res = self.f(res, self.tree[r])
                r -= 1
            l >>= 1
            r >>= 1
        return res

    def update(self, idx, val):
        i = self.n + idx
        self.tree[i] = val
        while i >= 1:
            i >>= 1
            self.tree[i] = self.f(self.tree[i << 1], self.tree[i << 1 | 1])

 

 

3. 세그먼트 트리의 활용 예시

 

 

26655번: 바벨탑

바벨은 헬스장에서 흔히 볼 수 있는 운동기구이다. 바벨은 바벨 봉과 바벨 원판으로 이루어져 있으며, 봉의 양쪽에 다양한 무게의 원판을 여러 개 끼워서 무게를 조절할 수 있다. 경기북과학고

www.acmicpc.net

- 풀이:

더보기

한쪽 바벨 부분만 생각하자. 무게 최소 되야하는 조건을 생각하면 5kg 짜리만 사용할 때도 가능하므로 하나의 운동 중에 기존 바벨을 버리고 새 바벨을 끼우는 행위는 발생하면 안 됨. 즉 잘 분할해야 이동 횟수 최소되는데 그림 그려서 생각해보면 최소값을 기준으로 잘라주고 잘라진 좌우에서 각각 최소값을 기준으로 또 이를 계속 반복하면 됨. 이때 구간 별 최소값을 O(logN)에 구하기 위해 세그먼트 트리를 이용한다. 잘라진 덩어리를 가지고 무게랑 이동할 횟수를 잘 세주는게 중요하고 자르는 규칙은 동일하므로 분할 정복 잘 짜주면 된다.

 

왜 그런지는 잘 모르겠는데 재귀 구현 방식은 TLE 떴고 비재귀로 구현했을 때 통과된 걸 보면 python으로 제출하는 사람들은 가능하면 Bottom-Up 방식을 이용하는게 좋을 듯 하다.

 

 

 

- 소스 코드 (Python3)

import sys

input = sys.stdin.readline
n = int(input())
A = list(map(int, input().split()))
arr = [(i - 20) // 2 for i in A]
cnt = 0
w = 0
tree = [0] * (2 * n)
for i in range(n):
    tree[n + i] = (arr[i], i)

def build():
    for i in range(n-1, 0, -1):
        if tree[i<<1][0] < tree[i<<1|1][0]:
            tree[i] = tree[i<<1]
        else:
            tree[i] = tree[i<<1|1]

def query(l, r):
    l += n
    r += n
    res = (int(1e9),-1)
    while l <= r:
        if l&1:
            res = res if res[0] < tree[l][0] else tree[l]
            l += 1
        if not r&1:
            res = res if res[0] < tree[r][0] else tree[r]
            r -= 1
        l >>= 1
        r >>= 1
    return res

def f(x):
    cnt = 0
    cnt += x // 20
    x %= 20
    cnt += x // 15
    x %= 15
    cnt += x // 10
    x %= 10
    cnt += x // 5
    return cnt

def solve(l, r, prev):
    global cnt, w
    if l == r:
        cnt += f(arr[l] - prev)
        w += arr[l] - prev
        return
    if l > r : return
    val, idx = query(l, r)
    cnt += f(val - prev)
    w += val - prev
    solve(l, idx - 1, val)
    solve(idx + 1, r, val)

build()
solve(0, n-1, 0)
print(4 * w, 4 * cnt)

  

 

4. Segment Tree 공부 자료

 

IOI Korea 강의 모음

1. 세그먼트 트리의 개념 및 필요성

2. 세그먼트 트리의 재귀 및 비재귀 구현

3. 세그먼트 트리의 응용

4. Lazy Propagation

 

강의 ppt 자료: https://drive.google.com/file/d/1QWpOb_L0rxS3We2mCy39du-B0oBhUDXU/view

 

Comments