Problem Solving

[백준][C++] 1197번 최소 스패닝 트리

wisdom11 2022. 4. 13. 21:46

문제 링크: https://www.acmicpc.net/problem/1197

 

1197번: 최소 스패닝 트리

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이

www.acmicpc.net

 

 

문제

 

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

 

 

난이도: 골드4

 

문제가 아주 심플하다.

문제 제목처럼 최소 스패닝 트리(Minimum Spanning Tree, MST)를 구현하고 가중치를 출력해주면 된다.

 

 


문제 해결

 

최소 스패닝 트리 (Minimum Spanning Tree)를 구현하는 방법은 2가지가 있다.

첫 번째, Prim의 알고리즘
두 번째, Kruskal의 알고리즘

 

Prim의 알고리즘: 우선순위 큐 이용

Kruskal의 알고리즘: 분리 집합(disjoint-set) 이용

 

Kruskal의 알고리즘이 구현이 더 간단하기 때문에, 이 알고리즘을 구현하는 방식으로 문제를 해결하였다.

 

 

분리 집합 (Disjoint-set)

Union-Find 를 통해 구현할 수 있다.

노드 x의 부모 노드가 parent[x]에 저장되어 있다고 가정하면 다음과 같은 방식으로 구현할 수 있다.

 

1. Find

특정 원소의 집합의 루트 노드를 찾는다.

int Find(int x) {
    if(parent[x] == x) return x;
    return parent[x] = Find(parent[x]);
}

 

2. Union

두 집합을 합친다.

void Union(int x, int y) {
    x = Find(x);
    y = Find(y);
    parent[x] = y;
}

 

Kruskal의 알고리즘

1. 가중치가 가장 작은 간선을 선택한다.

2. 선택한 간선에 연결된 2개의 노드의 루트 노드를 찾는다.

3. 두 노드의 루트 노드가 서로 다르다면 (= 서로 다른 집합이라면) union 연산을 한다.

이 과정을 반복한다.

 


소스코드

#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
struct edge {
    int start, end, cost;
};
int parent[10010];
vector<edge> edges;

int Find(int x) {
    if(parent[x] == x) return x;
    return parent[x] = Find(parent[x]);
}
void Union(int x, int y) {
    x = Find(x);
    y = Find(y);
    parent[x] = y;
}

bool cmp(edge e1, edge e2) {
    return e1.cost < e2.cost;
}

int main() {
    int v, e;
    scanf("%d %d", &v, &e);

    // 부모 노드 초기화
    for(int i=1; i<=v; i++) parent[i] = i;

    for(int i=0; i<e; i++) {
        edge e;
        scanf("%d %d %d", &e.start, &e.end, &e.cost);
        edges.push_back(e);
    }

    // kruskal의 알고리즘
    sort(edges.begin(), edges.end(), cmp);
    int ans = 0;
    for(int i=0; i<e; i++) {
        int x = Find(edges[i].start);
        int y = Find(edges[i].end);
        if(x != y) {
            Union(x, y);
            ans += edges[i].cost;
        }
    }

    printf("%d\n", ans);
    return 0;
}
728x90