Dynamic Programming과 DFS spanning tree에 대한 지식 모두를 요하는 문제라 어려울 수 있다.
우선 주어진 입력으로 주어지는 방향그래프가 어떤 모양일지 생각해보자.
주어진 그래프는 n개의 vertex와 n개의 edge로 이루어져 있다.
따라서 connected component 끼리 나눠도 각각의 component에 존재하는 vertex와 edge의 개수는 같다.
따라서 이 1개의 component에 대해서만 확인해 보자.
한 component는 k개의 vertex와 edge로 이루어져있는데, 이는 즉 DFS spanning tree를 만들면 남는 edge가 1개가 생기고 이 edge를 제거하면 전체 그래프는 tree가 된다는 것이다.
즉, 한개의 component는 중심이 되는 cycle이 하나가 존재하고, 이 cycle에 여러개의 tree가 달라붙어 있는 모양을 가지게 된다. 다음은 component의 예시이다.
전체 component를 색칠하는 경우의 수는 (사이클을 색칠하는 경우의 수) * (트리 부분을 색칠하는 경우의 수)이다.
따라서, 이 각각을 구한다면 원하는 경우의 수를 구할 수 있다.
우선 사이클 부분의 경우의 수를 구해보자.
dp(n,k)=(n개의 칸에 k가지 색으로 칠할 때, 이웃한 색깔을 다르게 칠하면서, 첫번째 칸과 n번째 칸을 다르게 칠하는 경우의 수)
이 dp에 대한 점화식을 구하기 위해서 다음과 같이 경우를 나눠보자.
case 1) 1번 노드의 색깔과 n−1번 노드의 색깔이 다를 경우
이 경우는 1번부터 n−1번 까지 칠하는 경우의 수가 dp(n−1,k)와 일대일 대응이 된다.
n번 칸에는 1번과 n−1번과 다른 색깔로 칠해야 하므로 k−2개의 색깔중 하나를 골라서 칠할 수 있다.
따라서 이 경우의 수는 (k−2)×dp(n−1,k)이다.
case 2) 1번 노드의 색깔과 n−1번 노드의 색깔이 같을 경우
이 경우는 1번과 n−1번을 겹쳐서 cycle을 만든다고 생각하면, 1번부터 n−1번 까지 칠하는 경우의 수가 dp(n−2,k)와 일대일 대응이 된다.
n번 칸에는 1번과 n−1번과 다른 색깔로 칠해야 하는데, 두 색깔이 같으므로 k−1개의 색깔중 하나를 골라서 칠할 수 있다.
따라서 이 경우의 수는 (k−1)×dp(n−2,k)이다.
종합해보면 dp(n,k)=(k−2)×dp(n−1,k)+(k−1)×dp(n−2,k)이다.
여기서 초기값을 잘 생각해보아야 하는데, n≥2일 때에는 dp값과 구하고자 하는 사이클의 색칠 값이 같지만, n=1일 때에는 사이클의 색칠하는 경우의 수는 k이지만 dp값은 0이어야 한다.
따라서 dp(1,k)=0,dp(2,k)=k2−k로 dp테이블을 먼저 채워준 후, dp(1,k)=k로 정의해야 한다.
다음은 트리 부분의 경우의 수를 구해보자.
이 부분은 어렵지 않은데, 각 노드들은 유일한 부모노드가 존재하며, 위에서부터 부모의 색깔과 다르게 칠하기만 해면 재귀적으로 칠할 수 있으므로, 각 칸마다 k−1가지의 색깔로 칠할 수 있다.
요약하자면 주어진 input에 t개의 component가 있고 각각의 component를 구성하는 cycle의 크기를 ci라고 하면, 총 경우의 수는 다음과 같은 식으로 표현이 가능하다.
t∏i=1dp(ci,k)×(k−1)(n−∑ti=1ci)
cycle의 탐색은 DFS탐색을 하면서 visitTime을 관리하면 쉽게 할 수 있다.
출발한 노드의 visitTime과 도착한 노드의 visitTime을 비교했을 때, 출발한 노드의 visitTime이 더 크다면, 지금 도착한 노드는 이미 계산이 끝난 노드이므로 무시하고, 도착한 노드의 visitTime이 더 크다면 현재시간과 도착한 노드의 시간의 차이가 cycle의 크기가 된다.
따라서 시간복잡도는 dp 테이블을 채우는데 O(n), DFS를 하면서 cycle 부분의 색칠 수를 곱하는데에 O(n) 나머지 트리 부분의 색칠 수를 곱하는데에 O(n)(이 부분은 O(logn)에도 가능하지만 bottleneck이 아니므로 무시)으로, 총 시간복잡도는 O(n)이다.
Time Complexity : O(n)
코드
#include<iostream>#include<algorithm>#define N 1000004#define MOD 1000000007usingnamespacestd;
int n,k;
int a[N],t[N];
typedeflonglong ll;
ll dp[N];
intmain(){
cin.tie(NULL);
cin.sync_with_stdio(false);
cin>>n>>k;
for(int i=1;i<=n;i++) cin>>a[i];
dp[0]=k;
for(int i=2;i<=n;i++) dp[i] = (dp[i-1]*(k-2)+dp[i-2]*(k-1))%MOD;
dp[1]=k;
int c=1,cnt=0;
ll ans=1;
for(int i=1;i<=n;i++)
{
if(t[i]!=0)continue;
int j=i;
for(;t[j]==0; j=a[j]) t[j]=c++;
if(t[j]>=t[i])
{
int x = c-t[j];
ans = (ans*dp[x])%MOD;
cnt += x;
}
}
for(int i=n-cnt;i>0;i--) ans = (ans * (k-1))%MOD;
cout<<ans;
}