题解:[CF609E/代码源OJ L4上 倍增 - 981] Minimum spanning tree for each edge
感谢@Cute_SilverWolf为这篇题解作出的贡献。
有关最小生成树的问题有很多,像是今年S组T2。每次用kruskal写出一道题,就总是觉得kruskal的题其实都很简单,但是场上我就是写不出来它的简单结论。
——Hsl_Beat在写这篇文章时对SW说的感叹
题目描述
现有一张 个点, 条边的无向连通图,边带权。对于每个 ,求出如果一个最小生成树中要求必须包括第 条边,最小生成树的边权总和最小值。
,,,,。
思路
就这题而言,我们不知道为什么要写树链剖分(其实是因为截止到目前我们都不会树链剖分)
首先我们可以对着原图用kruskal跑一次MST,这样对于MST里面的边,答案就是MST的值了,我们假设这个值为。
现在问题来到了不在MST里的边怎么做。先来看一张图:

观察这张图,其中红色的边是原图中的最小生成树,我们现在要计算虚线那条非MST里的边的答案。
由于我们想让添加虚线这条边后原图也是一个生成树,那我们就需要从原本生成树里面删掉一条边,假设删掉的边边权为,虚线边权为,连接的两个点分别为,那么答案就是,其中和确定,所以我们想让权值最小化,就需要求出最大的。
比如对于这张图片,最好情况就是把和之间边权为的边删掉。
现在我们来想一想删掉的这条边要满足什么条件。首先删掉它之后,是不能连通的,因为加上虚线边之后就会出现环,不符合生成树的特征。不难想到,删除的这条边肯定在到它们的LCA之间这几条边。
所以我们可以倍增,先求出MST并把MST里面的边放到求LCA的边里去,剩下的和LCA的经典手法差不多,只是的含义是结点往上跳次之后边权的最大值。
查询就是我们在求当前虚线边的的LCA时,每有一个结点往上跳,就代表跳过的这一段都是可以作为我们求最大值答案参考的边,这个时候就要求一下最大值。具体可以见代码。
最后别忘了你的写法要是和我的一样,与他们LCA之间的边也要考虑。
(好吧,我们都很讨厌写倍增,这个代码历经千辛万苦才搞定)
代码
using namespace std;
int f[200005];
int n, m;
void init()
{
for (int i = 0; i <= n; i++) {
f[i] = i;
}
}
int find(int x)
{
if (x == f[x]) {
return x;
}
return f[x] = find(f[x]);
}
void merge(int x, int y)
{
x = find(x);
y = find(y);
f[x] = y;
}
bool same(int x, int y)
{
return find(x) == find(y);
}
bool vis[200005];
struct node
{
int u, v, w, idx;
};
int fa[200005][30], w[200005][30], dep[200005];
vector<pair<int, int>> edges[200005];
vector<node> edge;
void dfs(int x, int f)
{
for (int i = 1; i < 30; i++) {
fa[x][i] = fa[fa[x][i - 1]][i - 1];
w[x][i] = max(w[x][i - 1], w[fa[x][i - 1]][i - 1]);
}
for (auto nex : edges[x]) {
int v = nex.first;
if (v != f) {
fa[v][0] = x;
w[v][0] = nex.second;
dep[v] = dep[x] + 1;
dfs(v, x);
}
}
}
int lca(int x, int y)
{
int ans = 0;
if (dep[x] < dep[y]) {
swap(x, y);
}
for (int i = 29; i >= 0; i--) {
if (dep[fa[x][i]] >= dep[y]) {
ans = max(ans, w[x][i]);
x = fa[x][i];
}
}
if (x == y) {
return ans;
}
for (int i = 29; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
ans = max(ans, max(w[x][i], w[y][i]));
x = fa[x][i];
y = fa[y][i];
}
}
ans = max(ans, max(w[x][0], w[y][0]));
return ans;
}
signed main()
{
cin >> n >> m;
edge.resize(m);
for (int i = 0; i < m; i++) {
cin >> edge[i].u >> edge[i].v >> edge[i].w;
edge[i].idx = i;
}
sort(edge.begin(), edge.end(), [&](node a, node b) {
return a.w < b.w;
});
init();
int cnt1 = 0, cnt2 = 0;
for (int i = 0; i < m; i++) {
if (!same(edge[i].u, edge[i].v)) {
merge(edge[i].u, edge[i].v);
edges[edge[i].u].push_back({edge[i].v, edge[i].w});
edges[edge[i].v].push_back({edge[i].u, edge[i].w});
cnt1 += edge[i].w;
cnt2++;
vis[edge[i].idx] = 1;
}
if (cnt2 == n - 1) {
break;
}
}
dfs(1, 0);
vector<int> anss(m);
for (int i = 0; i < m; i++) {
if (vis[edge[i].idx]) {
anss[edge[i].idx] = cnt1;
continue;
}
int maxx = lca(edge[i].u, edge[i].v);
anss[edge[i].idx] = cnt1 - maxx + edge[i].w;
}
for (int i = 0; i < m; i++) {
cout << anss[i] << '\n';
}
return 0;
}