알고리즘/백준

[Java | 트리 DP | 24232번 | 골1] 망가진 나무

이진지니지니진 2025. 1. 27. 18:25

☑️ 문제

https://www.acmicpc.net/problem/24232

욕심쟁이 판다가 나무를 갉아 먹어서 나무가 망가졌다!

입력으로 방향 그래프가 주어진다. 이 그래프의 모든 간선을 양방향으로 바꾸면 트리가 된다.

당신은 임의로 간선의 방향을 뒤집을 수 있다. 당신의 목적은 간선을 뒤집는 횟수를 최소로 하여 다음을 만족하는 정점이 존재하도록 하는 것이다.

이 정점에서 모든 정점에 도달할 수 있다.

 

입력
첫째 줄에 정점의 개수 N이 주어진다. (2≤N≤100,000)
다음 줄부터 N-1개의 줄에 두 개의 정수 u, v가 주어진다. 이는 정점 u에서 정점 v로 향하는 간선을 의미한다. (1≤u,v≤N, u≠v)
정점의 번호는 1부터 N까지이다. 그래프의 모든 간선을 양방향으로 바꾸면 트리가 됨이 보장된다.

 

출력
첫째 줄에 뒤집어야하는 간선을 N-1자리 이진수로 출력한다. 왼쪽에서 i번째 비트는 i번째 간선을 뒤집어야 하면 1, 아니면 0이다. 이진수에 등장하는 1의 개수가 최소가 되도록 해야 한다.

가능한 답이 여러 가지일 경우, 아무거나 출력하면 된다.

 

예제 입력1

5
2 4
2 3
3 1
5 1

예제 출력1

0001

 

예제 입력2

6
5 1
3 1
3 4
2 4
2 6

예제 출력2

10100

 

☄️ 정답코드

import java.io.*;
import java.util.*;

public class Main {
    static int N;
    static List<List<Edge>> graph;
    static boolean[] flip;
    static boolean[] result;
    static int minFlip;

    static class Edge {
        int to, index;
        boolean dir;

        public Edge(int to, boolean dir, int index) {
            this.to = to;
            this.dir = dir;
            this.index = index;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        N = Integer.parseInt(br.readLine());
        graph = new ArrayList<>();
        for (int i = 0; i <= N; i++) {
            graph.add(new ArrayList<>());
        }

        for (int i = 0; i < N - 1; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());

            graph.get(u).add(new Edge(v, false, i));
            graph.get(v).add(new Edge(u, true, i));
        }

        flip = new boolean[N - 1];
        result = new boolean[N - 1];
        minFlip = Integer.MAX_VALUE;

        int flipCnt = init(1, 0);

        dfs(1, 0, flipCnt);

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < N - 1; i++) {
            sb.append(result[i] ? 1 : 0);
        }
        System.out.println(sb);
    }

    static int init(int node, int parent) {
        int sum = 0;
        for (Edge edge : graph.get(node)) {
            int next = edge.to;
            boolean dir = edge.dir;
            int idx = edge.index;
            if (next == parent) continue;

            flip[idx] = dir;
            sum += init(next, node) + (dir ? 1 : 0);
        }
        return sum;
    }

    static void dfs(int node, int parent, int flipCount) {
        if (flipCount < minFlip) {
            minFlip = flipCount;
            for (int i = 0; i < N - 1; i++) {
                result[i] = flip[i];
            }
        }

        for (Edge edge : graph.get(node)) {
            int next = edge.to;
            boolean dir = edge.dir;
            int idx = edge.index;
            if (next == parent) continue;

            if (dir) {
                flip[idx] = false;
                dfs(next, node, flipCount - 1);
                flip[idx] = true;
            } else {
                flip[idx] = true;
                dfs(next, node, flipCount + 1);
                flip[idx] = false;
            }
        }
    }
}

 

🪐 풀이

검색해보지 않았다면, ,, 절대 생각하지 못했을 것 같다.

 

🚨 첫번째 시도 - 메모리 초과(너무나도 .. 당연한 결과)

더보기

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class Main {
    static int N;
    static List<int[]>[] tree;
    static int ans;
    static List<int[]> inputs;
    static int[] cnts;
    static int[][] parents;

    static void makeTree(int start) {
        boolean[] visited = new boolean[N + 1];
        Queue<int[]> q = new LinkedList<>();
        q.add(new int[] {start, 0});

        while (!q.isEmpty()) {
            int[] now = q.poll();

            visited[now[0]] = true;
            parents[start][now[0]] = now[1];

            for(int[] nxt: tree[now[0]]) {
                if(visited[nxt[0]]) continue;

                if(nxt[1] == -1) cnts[start]++;
                q.add(new int[] {nxt[0], now[1]+1});
            }
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        N = Integer.parseInt(br.readLine());
        ans = Integer.MAX_VALUE;
        inputs = new ArrayList<>();
        cnts = new int[N+1];
        tree = new ArrayList[N + 1];
        parents = new int[N + 1][N + 1];

        for (int i = 1; i < N+1; i++) {
            tree[i] = new ArrayList<>();
        }

        for (int i = 0; i < N-1; i++) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            tree[u].add(new int[]{v, 0});
            tree[v].add(new int[]{u, -1});
            inputs.add(new int[]{u, v});
        }

        for (int i = 1; i < N+1; i++) {
            makeTree(i);
        }
        
        int n = 0;
        int cnt = Integer.MAX_VALUE;
        for (int i = 1; i < N+1; i++) {
            if(cnts[i] < cnt) {
                n = i;
                cnt = cnts[i];
            }
        }

        StringBuilder sb = new StringBuilder();
        for(int[] input: inputs) {
            if(parents[n][input[0]] < parents[n][input[1]]) sb.append(0);
            else sb.append(1);
        }

        System.out.println(sb);
    }
}

parents만 봐도 크기가 (N+1)*(N+1) 인데 N이 100,000이니까 ... ~~ 당연히 메모리 터지겠져? ☄️

대충 로직은 모든 노드를 루트로 설정하여 각각의 트리에 대해 BFS 돌면서 부모-자식 관계를 `parent`에 저장하는 방식이다.

쉽게 말하자면 완전 탐색 방식으로 푼거다

 

이미 트리 dp라는 유형을 알고 푼 문제이지만, 여기서 어떻게 더 나아가야 할 지 감이 잡히지 않아 검색을 통해 풀이를 찾아봤다 ..!

 

✅ 최종 코드

초기 상태(1번 루트가 노드인 경우)를 기준으로, 모든 간선의 뒤집힘 여부를 체크하며, 메모이제이션 방식으로 최소 뒤집기 횟수를 계산하였다.

 

이 말만 보면 어렵게 느껴지는데, 코드를 천천히 뜯어보면 이해를 할 수 있다!

 

1️⃣ Edge

static class Edge {
    int to;
    boolean dir;    	// 뒤집히면 true, 뒤집히지 않으면 false
    int index;		// 간선이 입력된 순서

    public Edge(int to, boolean dir, int index) {
    	this.to = to;
        this.dir = dir;
        this.index = index;
    }
}

 

2️⃣ graph

// 입력 예제1
graph = [
    [],
    [Edge(to=3, dir=true, index=2), Edge(to=5, dir=true, index=3)], // 1번 노드
    [Edge(to=4, dir=false, index=0), Edge(to=3, dir=false, index=1)], // 2번 노드
    [Edge(to=1, dir=false, index=2), Edge(to=2, dir=true, index=1)], // 3번 노드
    [Edge(to=2, dir=true, index=0)], // 4번 노드
    [Edge(to=1, dir=false, index=3)] // 5번 노드
]

 

3️⃣ init

1번 노드를 루트로 설정했을 때, 뒤집혀야 할 간선의 개수를 계산

 

왜 필요한지 아래 그림으로 알아보자! (입력 예제1 기준)

init 함수로 1번 노드가 루트면 flipCnt를 계산할 수 있다. 현재는 3이다.

이 값을 기준으로 간선이 뒤집힌 경우 flipCount + 1, 간선이 뒤집히지 않은 경우 flipCount - 1 하면서 최소 뒤집기 횟수를 계산할 수 있다.

(이 값을 본격적으로 계산하는 함수는 dfs)

 

 

static int init(int node, int parent) {
    int sum = 0;
    for (Edge edge : graph.get(node)) {
        int next = edge.to;
        boolean dir = edge.dir;
        int idx = edge.index;
        if (next == parent) continue;

        flip[idx] = dir;
        sum += init(next, node) + (dir ? 1 : 0);
    }
    return sum;
}

 

4️⃣ dfs

본격적으로 최소 뒤집기 횟수를 구하는 함수 (백트래킹 한다고 생각하면 된다)

 

탐색 과정에서 위에서 설명했듯이

- 간선을 뒤집지 않는 경우: flipCount - 1

- 간선을 뒤집는 경우: flipCount + 1

static void dfs(int node, int parent, int flipCount) {
    if (flipCount < minFlip) {
        minFlip = flipCount;
        for (int i = 0; i < N - 1; i++) {
            result[i] = flip[i];
        }
    }

    for (Edge edge : graph.get(node)) {
        int next = edge.to;
        boolean dir = edge.dir;
        int idx = edge.index;
        if (next == parent) continue;

        if (dir) {
            flip[idx] = false;
            dfs(next, node, flipCount - 1);
            flip[idx] = true;
        } else {
            flip[idx] = true;
            dfs(next, node, flipCount + 1);
            flip[idx] = false;
        }
    }
}

 

 

참고

https://burningfalls.github.io/algorithm/boj-24232/