모든 정점에서 출발해서 도달할 수 있는 정점들의 개수의 합을 구하는 문제이다. 반대로 생각해서 한 정점에 도달할 수 있는 정점들의 개수의 합을 구한다고 생각해 보자.
out-degree가 모두 1이므로, 하나의 연결된 그래프에서 가능한 SCC의 개수는 최대 1개이다. SCC에 속하는 정점들의 경우 SCC의 크기의 제곱을 하면 한 번에 구할 수 있다. SCC에 연결된 가지들의 경우, 맨 끝에서 탐색하며 값을 더해 나가면 답을 구할 수 있다. 가지들이 합쳐지는 부분에서 주의해야 하는데, 여차하면 오답이나 시간초과가 발생할 수 있다. 나는 in-degree를 세는 방법으로 해결했다. in-degree가 1 초과이면 해당 정점에 현재까지의 가중치를 더해 놓고 in-degree를 1 감소시킨다. 이렇게 되면 마지막 가지를 탐색하는 과정에서 모든 값들이 더해져 O(N)에 구현할 수 있다.
가지들에서 SCC에 연결되는 부분을 구현하는 부분이 개인적으로 까다로웠다. 오랜만에 구현하는 타잔 알고리즘에서 많은 시간이 소요되어 문제 난이도에 비해 고전했다.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr ll INF = 1e18;
inline void init() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
}
int main() {
init();
ll n, ans = 0;
cin >> n;
vector<ll> a(n + 1), cnt(n + 1);
vector<vector<ll>> gr(n + 1);
for (ll i = 1; i <= n; i++) {
cin >> a[i];
cnt[a[i]]++;
}
vector<bool> vst(n + 1);
vector<ll> ord(n + 1), scc(n + 1), scc_sz(n + 1), sum(n + 1);
function<ll(ll)> dfs = [&](ll cur) -> ll {
static stack<ll> s;
static ll id = 0, scc_idx = 0;
ord[cur] = ++id;
s.emplace(cur);
ll p = INF;
if (ord[a[cur]] == 0) p = min(p, dfs(a[cur]));
else if (!scc[a[cur]]) p = min(p, ord[a[cur]]);
if (p == ord[cur]) {
ll sz = 0, tmp = -1;
scc_idx++;
while (tmp != cur) {
tmp = s.top();
s.pop();
scc[tmp] = scc_idx;
sz++;
}
scc_sz[scc_idx] = sz;
ans += sz * sz;
}
return p;
};
for (ll i = 1; i <= n; i++) {
if (ord[i]) continue;
dfs(i);
}
queue<ll> q;
for (ll i = 1; i <= n; i++) {
if (cnt[i] == 0) q.emplace(i);
}
while (!q.empty()) {
ll cur = q.front();
q.pop();
ll i = 0;
while (!scc[cur] and cnt[cur] <= 1) {
i += sum[cur];
ans += ++i;
cur = a[cur];
}
if (scc[cur]) {
ans += scc_sz[scc[cur]] * i;
} else {
sum[cur] = i;
cnt[cur]--;
}
}
cout << ans;
return 0;
}