题目描述

有一颗 nn 个节点树,其中有 kk 个点是关键点。现在有 mm 次询问,每次询问是否存在一条 uvu \to v 的路径(不一定是简单路径),使得路径上任意两个关键点并且 u,vu,v 和路径上与其最近的关键点的距离小于等于 kk

n,m,k2×105n,m,k \le 2\times10^5

思路

考虑我们会怎么走,显然是:uu \to 关键点 \to 关键点 \to \cdots \to 关键点 v\to v

那我们可以建立并查集,把 kk 步之内可以互达的关键点以及可以到达他们的普通点并在一块。可以使用 BFS 计算。从每个关键点向外 BFS k2\frac{k}{2} 步即可。当 kk 是奇数的时候,我们只需要对每个点建边并且使 kk22 即可。

接下来我们套路的进行树上倍增,对于一次询问 uvu \to v,我们只需要将 uuvvk2\frac{k}{2} 步,再将 vvuuk2\frac{k}{2} 步。然后再判断跳完后的答案是否在并查集同一集合内即可。

时间复杂度 O(nlogn)\mathcal{O}(n \log n)

代码

#include <bits/stdc++.h>
using namespace std;
const int N= 2e5 + 5;

vector<int> G[N << 1];
int fa[N << 1][23], p[N << 1], depth[N << 1];
int n, k, r, Q, vis[N << 1];

int LCA(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    for (int i = 19; i >= 0; i--) 
        if (depth[fa[u][i]] >= depth[v])
            u = fa[u][i];
    if (u == v) return u;
    for (int i = 19; i >= 0; i--)
        if (fa[u][i] != fa[v][i])
            u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}

void dfs(int u, int f) {
    fa[u][0] = f;
    depth[u] = depth[f] + 1;
    for (int v : G[u])
        if (v != f)
            dfs(v, u);
};

int find(int x) {
    return p[x] == x ? x : p[x] = find(p[x]);
}

int move(int u, int k) {
    for (int i = 19; i >= 0; i--) {
        if (1 << i <= k) {
            u = fa[u][i];
            k ^= 1 << i;
        }
    }
    return u;
}

int main() {
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n >> k >> r;
    for (int i = 1; i <= n << 1; i++) p[i] = i;
    for (int i = 1; i < n; i++) {
        static int u, v;
        cin >> u >> v;
        G[u].push_back(n + i); G[n + i].push_back(u);
        G[v].push_back(n + i); G[n + i].push_back(v);
    }
    dfs(1, 0);
    for (int i = 1; i <= 19; i++)
        for (int j = 1; j <= n << 1; j++)
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
    memset(vis, -1, sizeof vis); queue<int> q;
    for (int i = 1; i <= r; i++) {
        static int u;
        cin >> u;
        vis[u] = 0;
        q.push(u);
    }
    while (q.size()) {
        int u = q.front(); q.pop();
        if (vis[u] == k) break;
        for (int v : G[u]) {
            int x = find(u), y = find(v);
            p[x] = y;
            if (vis[v] == -1) {
                q.push(v);
                vis[v] = vis[u] + 1;
            }
        }
    }
    cin >> Q;
    while (Q--) {
        static int u, v, z;
        cin >> u >> v;
        z = LCA(u, v);
        if (depth[u] + depth[v] - 2 * depth[z] <= 2 * k) {
            cout << "YES" << endl;
        } else {
            int t1 = k <= depth[u] - depth[z] ? find(move(u, k)) : find(move(v, depth[u] + depth[v] - 2 * depth[z] - k));
            int t2 = k <= depth[v] - depth[z] ? find(move(v, k)) : find(move(u, depth[u] + depth[v] - 2 * depth[z] - k));
            cout << (t1 == t2 ? "YES" : "NO") << endl;
        }
    }
    return 0;
}