본문 바로가기
Atcoder

ABC 357 E

by gmroh 2025. 2. 5.

모든 정점에서 출발해서 도달할 수 있는 정점들의 개수의 합을 구하는 문제이다. 반대로 생각해서 한 정점에 도달할 수 있는 정점들의 개수의 합을 구한다고 생각해 보자. 

 

out-degree가 모두 1이므로, 하나의 연결된 그래프에서 가능한 SCC의 개수는 최대 1개이다. SCC에 속하는 정점들의 경우 SCC의 크기의 제곱을 하면 한 번에 구할 수 있다. SCC에 연결된 가지들의 경우, 맨 끝에서 탐색하며 값을 더해 나가면 답을 구할 수 있다. 가지들이 합쳐지는 부분에서 주의해야 하는데, 여차하면 오답이나 시간초과가 발생할 수 있다.  나는 in-degree를 세는 방법으로 해결했다. in-degree가 1 초과이면 해당 정점에 현재까지의 가중치를 더해 놓고 in-degree를 1 감소시킨다. 이렇게 되면 마지막 가지를 탐색하는 과정에서 모든 값들이 더해져 O(N)에 구현할 수 있다. 

 

가지들에서 SCC에 연결되는 부분을 구현하는 부분이 개인적으로 까다로웠다. 오랜만에 구현하는 타잔 알고리즘에서 많은 시간이 소요되어 문제 난이도에 비해 고전했다.

 

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

constexpr ll INF = 1e18;

inline void init() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
}

int main() {
    init();

    ll n, ans = 0;

    cin >> n;

    vector<ll> a(n + 1), cnt(n + 1);
    vector<vector<ll>> gr(n + 1);

    for (ll i = 1; i <= n; i++) {
        cin >> a[i];
        cnt[a[i]]++;
    }

    vector<bool> vst(n + 1);
    vector<ll> ord(n + 1), scc(n + 1), scc_sz(n + 1), sum(n + 1);

    function<ll(ll)> dfs = [&](ll cur) -> ll {
        static stack<ll> s;
        static ll id = 0, scc_idx = 0;

        ord[cur] = ++id;
        s.emplace(cur);

        ll p = INF;

        if (ord[a[cur]] == 0) p = min(p, dfs(a[cur]));
        else if (!scc[a[cur]]) p = min(p, ord[a[cur]]);

        if (p == ord[cur]) {
            ll sz = 0, tmp = -1;
            scc_idx++;

            while (tmp != cur) {
                tmp = s.top();
                s.pop();
                scc[tmp] = scc_idx;
                sz++;
            }

            scc_sz[scc_idx] = sz;
            ans += sz * sz;
        }

        return p;
    };

    for (ll i = 1; i <= n; i++) {
        if (ord[i]) continue;
        dfs(i);
    }

    queue<ll> q;

    for (ll i = 1; i <= n; i++) {
        if (cnt[i] == 0) q.emplace(i);
    }

    while (!q.empty()) {
        ll cur = q.front();
        q.pop();

        ll i = 0;

        while (!scc[cur] and cnt[cur] <= 1) {
            i += sum[cur];
            ans += ++i;
            cur = a[cur];
        }

        if (scc[cur]) {
            ans += scc_sz[scc[cur]] * i;
        } else {
            sum[cur] = i;
            cnt[cur]--;
        }
    }

    cout << ans;

    return 0;
}

 

 

'Atcoder' 카테고리의 다른 글

AGC 019 C  (0) 2025.02.14
ABC 355 F  (0) 2025.02.05