본문 바로가기
알고리즘 문제 풀이

[백준 11437] LCA

by 다빈치코딩 2023. 10. 10.

목차

    반응형

    문제 출처 : https://www.acmicpc.net/problem/11437

     

    11437번: LCA

    첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정

    www.acmicpc.net

    가장 기본적인 형태의 LCA 문제 입니다. LCA 알고리즘의 설명은 이전에 게시한 알고리즘 설명으로 대신하겠습니다.

    2023.10.06 - [알고리즘 설명] - 최소공통조상(LCA)

     

    최소공통조상(LCA)

    LCA란? 최소 공통 조상(Lowest Common Ancestor) 줄여서 LCA로 불리는 이 알고리즘은 트리에서 두 정점이 가지고 있는 가장 가까운 공통 조상을 찾는 알고리즘 입니다. 위와 같은 그래프가 있을 때 6번 정

    davincicoding.tistory.com

     

    입력 받기

    mii = lambda : map(int, input().split())
    N = int(input())
    
    tree = [[] for _ in range(N+1)]
    for _ in range(N-1):
        u, v = mii()
        tree[u].append(v)
        tree[v].append(u)
    
    M = int(input())
    for _ in range(M):
        u, v = mii()
        print(lca(u, v))
    

    노드의 개수 N을 입력 받고 N - 1개의 트리의 연결 정보를 받습니다. 트리의 연결 정보를 통해 tree를 구성해 줍니다. 다음으로 공통 조상을 알고 싶은 M개의 정점 쌍을 입력 받습니다. 두 정점 u, v를 입력 받아 lca 함수로 공통 조상을 찾아 출력해 줍니다. lca 함수는 뒤에 작성하겠습니다.

    트리 구성하기

    import sys
    sys.setrecursionlimit(10 ** 5)
    
    parent = [-1] * (N+1)
    level = [-1] * (N+1)
    
    def set_tree(node, lv):
        level[node] = lv
    
        for child in tree[node]:
            if level[child] == -1: 
                parent[child] = node
                set_tree(child, lv + 1)
    
    set_tree(1, 0)
    

    노드의 부모 정보를 저장하는 parent 리스트와 트리의 깊이를 저장할 level 리스트를 만들어 줍니다. 재귀 함수를 사용해야 하기 때문에 트리의 깊이를 지정하는 setrecursionlimit함수를 사용하여 재귀의 한도를 늘려줍니다.

    set_tree(1, 0)
    

    set_tree 함수는 트리의 정보를 구성하여 줍니다. 1번 노드를 0번 깊이를 시작으로 재귀함수로 트리 정보를 구성해 줍니다.

    레벨 지정하기

    level[node] = lv
    

    입력받은 레벨을 현재 노드의 레벨로 지정해 줍니다. 즉 1번 노드는 0의 깊이를 가지게 됩니다.

    부모 지정하기

        for child in tree[node]:
            if level[child] == -1: 
                parent[child] = node
                set_tree(child, lv + 1)
    

    DFS의 형태로 각 노드들의 부모 노드를 저장합니다. 이 함수를 통해 노드들의 부모가 누구인지 바로바로 알 수 있습니다. 노드의 자식 정점은 깊이가 부모 노드 보다 깊이가 1씩 늘어 납니다. 모든 정점들의 깊이가 모두 저장될 때까지 반복을 계속해 줍니다.

    lca 함수 만들기

    def lca(a, b):
        if level[a] < level[b]:
            a, b = b, a
        while True:
            if level[a] == level[b]:
                break 
            a = parent[a]
        
        for _ in range(level[a]):
            if a == b:
                return a
    
            a = parent[a]
            b = parent[b]
        
        return a
    

    두 정점의 정보를 가지고 공통 조상을 찾는 lca 함수 입니다. 위 함수를 하나하나 뜯어 보겠습니다.

    a, b 깊이 확인하기

    먼저 노드들의 정보를 가지고 깊이를 같게 만들어 주어야 합니다.

        if level[a] < level[b]:
            a, b = b, a
    

    leval[a]와 leval[b]를 비교하여 깊이가 더 깊은 노드를 a로 만들어 줍니다. a가 더 깊기 때문에 a를 먼저 b와 같은 높이로 만들어 주어야 합니다.

    a, b 깊이 같게 만들기

        while True:
            if level[a] == level[b]:
                break 
            a = parent[a]
    

    a가 더 깊은 레벨에 있기 때문에 a의 부모를 찾아 a로 바꿔 줍니다. 이것은 a와 b의 깊이가 같아질 때까지 반복합니다.

    a, b 공통 조상 찾기

        for _ in range(level[a]):
            if a == b:
                return a
    
            a = parent[a]
            b = parent[b]
        
        return a
    

    level[a] 의 깊이만큼 반복하며 a, b 모두 부모를 찾아 올라갑니다. 그러다 a, b의 값이 같아지면 그것이 바로 최소 공통 조상이 되고, 그것을 리턴해 줍니다. 만약 끝까지 반복했다면 a값은 루트인 1번 노드가 되고 루트는 최초 노드이기 때문에 그냥 이것을 리턴해 주면 됩니다.

    전체 코드

    전체 코드를 확인해 보겠습니다.

    import sys
    sys.setrecursionlimit(10 ** 5)
    
    mii = lambda : map(int, input().split())
    N = int(input())
    
    tree = [[] for _ in range(N+1)]
    parent = [-1] * (N+1)
    level = [-1] * (N+1)
    
    for _ in range(N-1):
        u, v = mii()
        tree[u].append(v)
        tree[v].append(u)
    
    def set_tree(node, lv):
        level[node] = lv
    
        for child in tree[node]:
            if level[child] == -1: 
                parent[child] = node
                set_tree(child, lv + 1)
    
    set_tree(1, 0)
    
    def lca(a, b):
        if level[a] < level[b]:
            a, b = b, a
        while True:
            if level[a] == level[b]:
                break 
            a = parent[a]
        
        for _ in range(level[a]):
            if a == b:
                return a
    
            a = parent[a]
            b = parent[b]
        
        return a
    
    M = int(input())
    for _ in range(M):
        u, v = mii()
        print(lca(u, v))
    
    반응형