Post

[BOJ]1167. 트리의 지름

1167. 트리의 지름


❌code1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import sys


def dfs(cur_node, total_cost):
    global ans

    if ans < total_cost:
        ans = total_cost  # 최댓값을 정답으로 갱신

    for node, cost in tree[cur_node]:
        if visited[node]:
            continue  # 이미 방문한 노드 무시
        visited[node] = 1  # 방문 표시
        dfs(node, total_cost + cost)
        visited[node] = 0  # 방문 표시 취소


V = int(sys.stdin.readline())
tree = [[] for _ in range(V + 1)]
for _ in range(V):
    info = list(map(int, sys.stdin.readline().split()))
    for i in range(1, len(info), 2):
        if info[i] == -1:
            break
        # 연결된 노드와 그 비용을 추가
        tree[info[0]].append((info[i], info[i + 1]))

visited = [0] * (V + 1)
ans = 0  # 최종 답
for start_node in range(1, V + 1):  # 모든 노드에 대하여 검사
    visited[start_node] = 1  # 시작 노드 방문 표시
    dfs(start_node, 0)
    visited[start_node] = 0  # 시작 노드 방문 표시 취소
print(ans)

시도

모든 노드에 대하여 깊이 우선 탐색을 진행하여 두 노드 사이의 거리 중 최댓값을 구한다.

문제

시간 초과


❌code2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
import heapq


def dfs(cur_node, total_cost):
    ret = 0
    for node, cost in tree[cur_node]:
        if visited[node]:
            continue  # 이미 방문한 노드 무시

        visited[node] = 1

        # 1번 노드의 경우 도출된 가지의 길이를 힙에 저장한다
        if cur_node == 1:
            heapq.heappush(possible_path, -1 * (dfs(node, total_cost + cost) + cost))
        else:
            ret = max(ret, dfs(node, total_cost + cost) + cost)
        visited[node] = 0

    costs[cur_node] = ret
    return ret


V = int(sys.stdin.readline())
tree = [[] for _ in range(V + 1)]
for _ in range(V):
    info = list(map(int, sys.stdin.readline().split()))
    for i in range(1, len(info), 2):
        if info[i] == -1:
            break
        # 연결된 노드와 그 비용을 추가
        tree[info[0]].append((info[i], info[i + 1]))

visited = [0] * (V + 1)
costs = [-1] * (V + 1)  # [index] 노드가 시작 노드일 때 얻을 수 있는 최대 길이 저장
visited[1] = 1  # 1번 노드 방문 표시
possible_path = []  # 1번 노드에서 갈 수 있는 가지의 최대 길이를 모두 저장
dfs(1, 0)
if len(possible_path) == 1:
    print(-1 * possible_path[0])
else:
    print(-1 * (heapq.heappop(possible_path) + heapq.heappop(possible_path)))

시도

1번 노드에서 깊이 우선 탐색을 진행한다. 1번 노드와 연결된 각 노드를 통해 해당 노드가 만드는 가지의 최대 길이를 구하고 이를 possible_path에 저장한 뒤 그중 가장 긴 길이와 두 번째로 긴 길이를 구하고자 했다. 편의를 위해 최대 힙에 길이를 저장해 그중 가장 긴 길이와 두 번째로 긴 길이를 바로 구할 수 있도록 했다. 디버깅을 위하여 costs 리스트를 만들어 각 가지의 최대 길이를 저장하였다.

문제

해당 코드는 트리의 지름의 양 끝을 맡는 두 노드 중 한 노드가 1번 노드라는 전제를 가지고 있다. 문제에서 제시된 트리는 2번 노드를 제외하고 1-3-4-5로 이어지는 지름을 갖는다. 만약 2번 노드와 1번 노드의 위치가 바뀐다면 1번 노드가 포함된 지름의 최댓값은 10이 되지만 실제로 구할 수 있는 최댓값은 11이다.


⭕code3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# 1167. 트리의 지름


import sys


def dfs(cur_node, total_cost):
   global ans

   max_ret = 0  # cur_node의 가지의 거리를 모두 합했을 때 그중 가장 큰 값
   for node, cost in tree[cur_node]:
       if visited[node]:
           continue  # 이미 방문한 노드 무시

       visited[node] = True  # 방문 표시

       # cur_node에서 node로 시작되는 가지의 거리 총 합
       ret = dfs(node, total_cost + cost) + cost

       # total_ret 값을 계속 갱신하면 어느 시점에 최대 지름을 구할 수 있다
       total_ret = ret + max_ret
       ans = max(ans, total_ret)  # 최종 답에 최댓값 갱신
       max_ret = max(max_ret, ret)  # 최댓값 갱신
       visited[node] = False  # 방문 표시 취소

   return max_ret


V = int(sys.stdin.readline())
tree = [[] for _ in range(V + 1)]
for _ in range(V):
   info = list(map(int, sys.stdin.readline().split()))
   for i in range(1, len(info), 2):
       if info[i] == -1:
           break
       # 연결된 노드와 그 비용을 추가
       tree[info[0]].append((info[i], info[i + 1]))

visited = [False] * (V + 1)  # 방문 표시 리스트
visited[1] = True  # 항상 존재하는 1번 노드를 기준으로 잡고 먼저 방문 표시
ans = 0  # 최종 답
dfs(1, 0)
print(ans)

시도

트리의 지름은 한 개의 노드와 그 노드에 연결된 두 개의 노드가 만드는 가지로 구성되어있다. 따라서 각 노드가 갖는 각각의 가지의 길이를 알면 그중 가장 긴 두 길이를 더해 해당 노드가 포함된 트리의 지름을 구할 수 있고 만들어진 지름 중 최댓값을 구하면 전체 트리의 지름이 될 것이다.

트리에서 임의의 한 노드는 다른 모든 노드와도 연결될 수 있기 때문에 무조건 존재하는 임의의 한 노드(1번 노드)에서 시작하여 연결된 다른 노드를 탐색한다.

현재 노드(cur_node)에 연결된 노드를 찾아 해당 노드가 만드는 가지의 최장 길이를 구해 ret 변수에 저장한다. 현재 노드가 포함된 트리의 최장 지름은 최대 2개의 가지의 길이를 합한 값이다. 따라서 지름의 최댓값을 구하기 위해선 가장 큰 ret 값과 두 번째로 큰 ret 값을 더해야 한다.

연결된 가지가 2개 이상일 때 구한 ret 값과 max_ret 값을 더한다면 어느 시점에 아래 두 경우가 생길 수 있다.

  1. total_ret = (ret: 가장 긴 거리) + (max_ret: 두 번째로 긴 거리)
    1. 이후 가장 긴 거리가 max_ret 값으로 갱신됨
  2. total_ret = (ret: 두 번째로 긴 거리) + (max_ret: 가장 긴 거리)

따라서 total_ret 값을 계속 갱신하면 어느 시점에 최대 지름 total_ret을 구할 수 있다. 이렇게 모든 노드에 대하여 해당 노드가 포함된 트리의 지름 중 최댓값을 구할 수 있게 되었고 여기서 다시 최댓값을 찾으면 전체 트리의 지름이 된다.


요약

트리의 지름은 각 노드에 대하여 해당 노드가 포함된 트리의 지름을 구한 뒤 그중 최댓값을 택하면 된다.

This post is licensed under CC BY 4.0 by the author.