《算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。
文章目录
- 题目描述
- 题解
- C++代码
- Java代码
- Python代码
“ 直径点对” ,链接: http://oj.ecustacm.cn/problem.php?id=1736
题目描述
【题目描述】 给你一个n个节点的树,编号为1到n。求存在多少对节点<u,v>,使得u到v的距离等于这棵树的直径。
树的直径:树上最远的两个点的距离
树上两点的距离:两点之间边的数量
<1,2>和<2,1>属于两对节点
【输入格式】 第一行为正整数n(n≤300000)。接下来n-1行,每行两个数字u和v,表示点u和点v之间存在边
【输出格式】 输出一个数字表示答案。
【输入样例】
4
1 2
1 3
1 4
【输出样例】
6
题解
求树的直径有两种方法[ 《算法竞赛》,清华大学出版社,罗勇军、郭卫斌著,231页,“4.7.2 树的直径”。]:
(1)做两次DFS,第一次求一个任意点的最远点s,第二次求s的最远点t,s和t之间的距离就是树的直径。
(2)树形DP。算法不难,但是解释有点长,见《算法竞赛》233页的说明。请仔细理解如何用树形DP求树的直径。
以上两种方法,算法复杂度都为O(n),即只对每个节点处理O(1)次。
本题用这两种方法都能求解。下面用树形DP求树的直径,求直径的同时统计距离为直径的节点对数量。
【重点】 树形DP。
C++代码
定义状态dp[],dp[u]表示从u出发的最长路径的长度,这条路径的终点是u的一个叶子节点。
定义num[],num[u]表示从u出发的最长路径的数量。
定义maxlen,表示直径的长度,k是经过直径的节点对数量。
详细解释见代码的注释。如果仍然不能理解,可以把代码中与num[]和k有关的第8、20、22、23、26、28、29行删除,剩下的代码就是《算法竞赛》233页的树形DP求直径的模板。然后再加上num[]和k的代码并理解。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 300010;
vector<int>e[N];
int dp[N]; //dp[u]:从u出发的最长路径
int num[N]; //num[u]:从u出发的最长路径数量
ll maxlen = 0, k = 0; //直径的长度maxlen,经过直径的节点对数量k
void dfs(int u, int fa){
dp[u] = 0;
num[u] = 1;
for(auto v : e[u]){
if(v == fa) continue;
dfs(v, u); //继续深入,回溯时带回算好的dp[v]
int now = dp[v] + 1; //从u出发,经过子节点v的最长路径
if(now + dp[u] > maxlen){ //此时dp[u]是不经过v,而经过其他子节点的最长路径
//now+dp[u]是经过u的最长路径
maxlen = now + dp[u]; //更新maxlen为经过u的最长路径
//此时u、v可能在树的直径上。比较所有的maxlen,最大的就是树的直径
k = num[u] * num[v]; //计算k,这个k可能重新赋值
}
else if(now + dp[u] == maxlen) //把此时的len看成树的直径,如果20行的k更新,这里也会重算
k += num[u] * num[v];
if(now > dp[u]){ //u经过v的路径更长
dp[u] = now; //更新dp[u]为经过v的路径
num[u] = num[v]; //v更可能在直径上,把经过u的最长路径数量更新为经过v的数量
}
else if(now == dp[u]) //相等,这也是最长路径
num[u] += num[v];
}
}
int main(){
int n; scanf("%d", &n);
for(int i = 1; i < n; i++){
int u, v; scanf("%d%d", &u, &v);
e[u].push_back(v); //加边
e[v].push_back(u);
}
dfs(1, 0);
cout << k * 2 << endl; //按题意u-v和v-u不同,所以乘以2
return 0;
}
Java代码
import java.util.*;
import java.io.*;
public class Main {
static FastReader scanner = new FastReader();
static int N = 300010;
static ArrayList<Integer>[] e = new ArrayList[N];
static int[] dp = new int[N]; // dp[u]:从u出发的最长路径
static int[] num = new int[N]; // num[u]:从u出发的最长路径数量
static long maxlen = 0, k = 0; // 直径的长度maxlen,经过直径的节点对数量k
public static void main(String[] args) throws IOException {
int n = scanner.nextInt();
for (int i = 1; i <= n; i++) e[i] = new ArrayList<>();
for (int i = 1; i < n; i++) {
int u = scanner.nextInt();
int v = scanner.nextInt();
e[u].add(v); // 加边
e[v].add(u);
}
dfs(1, 0);
System.out.println(k * 2); // 按题意u-v和v-u不同,所以乘以2
}
public static void dfs(int u, int fa) {
dp[u] = 0;
num[u] = 1;
for (int v : e[u]) {
if (v == fa) continue;
dfs(v, u); // 继续深入,回溯时带回算好的dp[v]
int now = dp[v] + 1; // 从u出发,经过子节点v的最长路径
if (now + dp[u] > maxlen) { // 此时dp[u]是不经过v,而经过其他子节点的最长路径
// now+dp[u]是经过u的最长路径
maxlen = now + dp[u]; // 更新maxlen为经过u的最长路径
// 此时u、v可能在树的直径上。比较所有的maxlen,最大的就是树的直径
k = num[u] * num[v]; // 计算k,这个k可能重新赋值
} else if (now + dp[u] == maxlen)
// 把此时的len看成树的直径,如果34行的k更新,这里也会重算
k += num[u] * num[v];
if (now > dp[u]) { // u经过v的路径更长
dp[u] = now; // 更新dp[u]为经过v的路径
// v更可能在直径上,把经过u的最长路径数量更新为经过v的数量
num[u] = num[v];
} else if (now == dp[u]) // 相等,这也是最长路径
num[u] += num[v];
}
}
static class FastReader {
BufferedReader br;
StringTokenizer st;
public FastReader() { br = new BufferedReader(new InputStreamReader(System.in)); }
String next() {
while (st == null || !st.hasMoreElements()) {
try {st = new StringTokenizer(br.readLine());}
catch (IOException e) { e.printStackTrace();}
}
return st.nextToken();
}
int nextInt() { return Integer.parseInt(next()); }
String nextLine() {
String str = "";
try { str = br.readLine();}
catch (IOException e) { e.printStackTrace(); }
return str;
}
}
}
Python代码
#pypy
from collections import defaultdict
import sys
sys.setrecursionlimit(300000)
e = defaultdict(list)
dp = [0] * 300010
num = [0] * 300010
maxlen = 0
k = 0
def dfs(u, fa):
global maxlen, k
dp[u] = 0
num[u] = 1
for v in e[u]:
if v == fa: continue
dfs(v, u)
now = dp[v] + 1
if now + dp[u] > maxlen:
maxlen = now + dp[u]
k = num[u] * num[v]
elif now + dp[u] == maxlen: k += num[u] * num[v]
if now > dp[u]:
dp[u] = now
num[u] = num[v]
elif now == dp[u]: num[u] += num[v]
n = int(input())
for i in range(1, n):
u, v = map(int, input().split())
e[u].append(v)
e[v].append(u)
dfs(1, 0)
print(k * 2)