hiho一下第112周《Total Highway Distance》题目分析

1
1

题目大意

给定一颗有N个节点的带权树,之后进行M次操作:

Q操作:询问树上所有点对之间的距离之和

E操作:修改树上某一条边的权值

解题思路

本题的考察点是对树结构的理解。

要求解所有点对的距离之和,一般的思路是使用Floyd算法计算出距离矩阵,再在矩阵内求和。

但是对于题目给出的节点数目N=100,000,显然Floyd算法O(N^3)的时间复杂度是无法实现的。

因此我们需要从其它角度来思考解决办法。

对于树中任意一条边e,一定满足这样一个情况:

假设A集合有x个节点,则B集合有N-x个节点。

而任何两个节点i,j,若i属于A,j属于B,则其连接路径中一定包含有边e。

i,j的组合有x*(N-x)种,因此可以确定一共有x*(N-x)条路径中包含有边e。换句话说,在所有点对的距离和之中,边e一共计算了x*(N-x)次。

也即是

所有点对之间的距离之和 = sigma(e的权值 * e一边节点数 * e另一边的节点数)

利用这种算法来计算的结果,也能很方便的对修改进行维护。

当一条边e的权值增加delta(可以为负)之后,我们也只需要将总的权值增加delta乘上两边节点数量的乘积。


有了算法之后,我们考虑如何来实现这个算法。

首先我们任意选择一个节点作为根,对整棵树进行遍历。遍历的过程中我们需要做3件事:

  1. 标记每个节点的层次:方便之后修改操作快速查找边

  2. 记录子树的节点数量:为了计算一条边左右的节点总数。

  3. 计算出所有点对距离之和total的初值

其实现的伪代码为:

total = 0 // 距离之和
root = 1 // 我们选择1作为根节点
level[] = 0 // 初始化都为0
node[] = 0 // 记录子树节点数量,初始化为0
path[] = 0 // path[i]表示节点i到它父亲节点的边的长度

level[root] = 1 // 标记根节点的层数

DFS(rt):
    node[rt] = 1 // 至少包含的当前节点
    For each e(v,w) is connect to rt
        // 枚举所有连接到rt的边
        // e(v,w) 表示这条rt连接出去的边的另一端节点编号和权值
        If (level[v] == 0) Then // v未被访问过,则v属于子节点
            level[v] = level[rt] + 1 // 层数加1
            DFS(v) // 迭代处理子树
            node[rt] = node[rt] + node[v] // 加上该子树的节点
            total = total + w*node[v]*(N-node[v]) // 将这条边的的值加入总的距离
        End If
    End For

经过一次遍历之后,我们就得到最初的所有点对距离之和total

接下来考虑如何维护,假设我们读入了一条修改信息x y key,我们按照如下方式来处理:

edit(x, y, key):
    // 首先我们确保x是y的儿子节点
    If (level[x] < level[y]) Then
        Swap(x, y)
    End If
    // 此时path[x]也就是我们要修改的边的原长度
    delta = key - path[x] // 获取增量
    total = total + delta * node[x] * (N - node[x]) // 更新总长度
    path[x] = key // 更新边

该算法下第一次遍历花费的时间复杂度为O(N),每一次询问维护的时间复杂度为_O(1)。

因此总的时间复杂度为O(N+M),能够顺利的通过所有的数据。

  • 题目中没有标明形成的图是树呀,high way完全可以形成环路,在形成环路过后貌似就不能这样做了

  • 第一段说了N个点N-1条边,并且是联通的。所以一定是树。

  • 添加评论
  • reply

4 answer(s)

0

实现的伪代码里是不是缺了对path这个数组的赋值?

0

求教为什么会T啊

#include <cstdio>
#include <cmath>
#include <algorithm>
#define MAXN 100010
#define MAXM 50010
#define ll long long
using namespace std;
int m,e[MAXM << 1],nx[MAXM << 1],ed[MAXN],cnt,f[MAXN],path[MAXN],v[MAXM << 1],siz[MAXN],n;
ll ans;
void add(int x,int y,int z)
{
    e[cnt] = y; v[cnt] = z; nx[cnt] = ed[x]; ed[x] = cnt++; 
}
ll calc(int x,int y)// fa = x   , son = y
{
    return (ll)siz[y]*((ll)n-(ll)siz[y]);
}
void dfs(int p,int fa)
{
    siz[p] = 1;
    f[p] = fa;
    int i = ed[p];
    while (i > -1)
    {
        if (e[i] != fa)
        {
            path[e[i]] = i;
            dfs(e[i],p);
            siz[p] += siz[e[i]];
            ans += calc(p,e[i])*(ll)v[i];
        }
        i = nx[i];
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 1;i <= n;i++) ed[i] = -1;
    for (int i = 1;i <= n-1;i++)
    {
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
    }
    dfs(1,0);
    while (m--)
    {
        char s[10];
        scanf("%s",&s);
        switch (s[0])
        {
            case 'Q' : printf("%lld\n",ans); break;
            case 'E' :
                int x,y,z;
                scanf("%d%d%d",&x,&y,&z);
                if (f[x] == y) swap(x,y);
                int i = path[y];
                ans += calc(x,y)*((ll)z-(ll)v[i]);
                v[i^1] = v[i] = z;
                break;
        }
    }
    return 0;
}
0
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;

public class Main {
    public static int N,M,i;
    public static long total=0;
    public static HashMap<Integer,Integer>[] map;
    public static int[] node,level;
    public static void main(String[] args) {
        Scanner scan = new Scanner(System.in);
        N = scan.nextInt();
        M = scan.nextInt();
        map = new HashMap[N+1];
        for(int i=1;i<=N;i++)   map[i]=new HashMap();
        node = new int[N+1];
        level = new int[N+1];
        int u,v,k;
        for(i=2;i<=N;i++){
            u=scan.nextInt();
            v=scan.nextInt();
            k=scan.nextInt();
            map[u].put(v,k);
            map[v].put(u,k);
        }
        level[1]=1;
        dfs(1); 
        for(i=1;i<=M;i++){
            String input = scan.next();
            if(input.equals("QUERY"))
                System.out.println(total);
            else{
                edit(scan.nextInt(),scan.nextInt(),scan.nextInt());
            }
        }
    }
    public static void dfs(int u){
        node[u]=1;
        Iterator entries = map[u].entrySet().iterator();
        while(entries.hasNext()){
             Map.Entry entry = (Map.Entry) entries.next();               
             int v = (int)entry.getKey();  
             if(level[v]==0){
                int k = (int)entry.getValue(); 
                level[v]=level[u]+1;
                dfs(v);
                node[u]+=node[v];
                total+=k*node[v]*(N-node[v]);
             }    
        }
    }
    public static void edit(int u, int v, int k){
        int child = level[u]>level[v]?u:v;  
        total+=(k-map[u].get(v))*node[child]*(N-node[child]);
        map[u].remove(v);map[v].remove(u);
        map[u].put(v,k);map[v].put(u,k);
    }
}
  • 虽然你total定义成long了,但是计算过程也可能超过int范围。比如(k-map[u].get(v))*node[child]*(N-node[child]) 和 k*node[v]*(N-node[v])

  • 多谢 对这两处加了转型之后就ok了

  • 多谢,也是类型卡死了。。。。

  • 添加评论
  • reply
0

想问一下为何我的代码只能通过前60%,之后不是超时而是Wrong Anwser呢?下面是我的代码

write answer 切换为英文 切换为中文


转发分享