树型DP(下)
树型DP在树结构上求解最优子结构,常用于选点、路径等问题,状态在子树间转移,自底向上递推。
树型DP(下)
树型DP(下)
2477. 到达首都的最少油耗
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
class Solution {
public:
int MAXN = 100001;
// 无向图
int MAXM = 200001;
// 链式前向星
// 下标:顶点编号,值:该顶点第一条边的边号
vector<int> head;
// 下标:边号,值:下一条边的边号
vector<int> nxt;
// 下标:边号,值:去往的顶点编号
vector<int> to;
// 边号计数,从 1 开始,0 表示没有边
int cnt;
Solution() {
head.resize(MAXN);
nxt.resize(MAXM);
to.resize(MAXM);
}
void build() {
cnt = 1;
fill(head.begin(), head.end(), 0);
}
void addEdge(int u, int v) {
nxt[cnt] = head[u];
to[cnt] = v;
head[u] = cnt;
cnt++;
}
long long minimumFuelCost(vector<vector<int>> &roads, int seats) {
int n = roads.size() + 1;
// 建立无向图
build();
for (const auto &item: roads) {
addEdge(item[0], item[1]);
addEdge(item[1], item[0]);
}
// 节点总数
vector<int> count(n);
// 油耗总数
vector<long> cost(n);
fc(n, seats, 0, -1, count, cost);
return cost[0];
}
// u 为当前节点,p 为父节点
void fc(int n, int seats, int u, int p, vector<int> &count, vector<long> &cost) {
count[u] = 1;
// 遍历从 u 节点出发的每一条边
for (int ei = head[u]; ei > 0; ei = nxt[ei]) {
// 邻接点编号
int v = to[ei];
// 父节点就跳过,确保是从上往下遍历
if (v == p) continue;
fc(n, seats, v, u, count, cost);
// 加上下游的节点总数与油耗
count[u] += count[v];
cost[u] += cost[v];
// 再加上邻接点到当前节点产生的油耗
cost[u] += (count[v] + seats - 1) / seats;
}
}
};
2246. 相邻字符不同的最长路径
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
struct Info {
// 必须以当前节点往下所形成的路径长度的最大值,包含自身
int maxLenFromRoot;
// 整棵树上的最长路径
// 可能是经过当前节点 v,由以 v 为端的最长路径和次长路径拼接得到
// 也可能不经过 v,由 v 的下游提供
int maxLen;
Info(int m, int maxLen) : maxLenFromRoot(m), maxLen(maxLen) {}
};
class Solution {
public:
int MAXN = 100001;
int MAXM = 200001;
vector<int> head;
vector<int> nxt;
vector<int> to;
int cnt;
Solution() {
head.resize(MAXN);
nxt.resize(MAXM);
to.resize(MAXM);
}
void build() {
cnt = 1;
fill(head.begin(), head.end(), 0);
}
void addEdge(int u, int v) {
nxt[cnt] = head[u];
to[cnt] = v;
head[u] = cnt;
cnt++;
}
int longestPath(vector<int> &parent, string str) {
int n = parent.size();
// 建立无向图
build();
for (int i = 1; i < n; ++i) {
addEdge(i, parent[i]);
addEdge(parent[i], i);
}
vector<char> s(begin(str), end(str));
s.push_back('\0');
return fc(0, parent[0], s).maxLen;
}
// 返回任意一对相邻节点都没有分配到相同字符的最长路径
Info fc(int u, int p, vector<char> &s) {
// 没有邻边时的返回值
if (head[u] <= 0) return Info(1, 1);
// 以当前节点往下所形成的路径长度的最大值和次大值,包含自身
int m1 = 1;
int m2 = 0;
int maxLen = 1;
// 遍历邻边 u->v
for (int ei = head[u]; ei > 0; ei = nxt[ei]) {
int v = to[ei];
// 是父节点就跳过
if (v == p) continue;
Info infoV = fc(v, u, s);
maxLen = max(maxLen, infoV.maxLen);
if (s[u] != s[v]) {
// 字符不一样
if (infoV.maxLenFromRoot + 1 >= m1) {
m2 = m1;
m1 = infoV.maxLenFromRoot + 1;
} else if (infoV.maxLenFromRoot + 1 > m2) {
m2 = infoV.maxLenFromRoot + 1;
}
}
}
maxLen = max(maxLen, m1 + m2 - 1);
return Info(m1, maxLen);
}
};
2458. 移除子树后的二叉树高度
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
struct TreeNode {
int val;
TreeNode *left;
TreeNode *right;
TreeNode() : val(0), left(nullptr), right(nullptr) {}
TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};
int MAXN = 100002;
// 下标为节点的值
vector<int> dfn(MAXN);
// 下标为 dfn 序号,记录根节点到当前节点经过的边数
vector<int> deep(MAXN);
// 下标为 dfn 序号,记录这棵树的节点总数
vector<int> cnt(MAXN);
// 辅助数组,用于快速在 size 数组去除一段后找最值
vector<int> maxL(MAXN, 0);
vector<int> maxR(MAXN, 0);
int dfnCnt;
class Solution {
public:
vector<int> treeQueries(TreeNode *root, vector<int> &queries) {
dfnCnt = 0;
fc(root, 0);
for (int i = 1; i <= dfnCnt; ++i)
maxL[i] = max(maxL[i - 1], deep[i]);
maxR[dfnCnt + 1] = 0;
for (int i = dfnCnt; i >= 1; i--)
maxR[i] = max(maxR[i + 1], deep[i]);
int m = queries.size();
vector<int> res(m);
for (int i = 0; i < m; ++i) {
int index = dfn[queries[i]];
int leftMax = maxL[index - 1];
int rightMax = maxR[index + cnt[index]];
res[i] = max(leftMax, rightMax);
}
return res;
}
// 根节点到当前节点经过了 k 条边
void fc(TreeNode *node, int k) {
// dfn 序号从 1 开始
int i = ++dfnCnt;
// 记录 dfn 序号
dfn[node->val] = i;
deep[i] = k;
cnt[i] = 1;
// 累加上左右子树的节点数
if (node->left != nullptr) {
fc(node->left, k + 1);
cnt[i] += cnt[dfn[node->left->val]];
}
if (node->right != nullptr) {
fc(node->right, k + 1);
cnt[i] += cnt[dfn[node->right->val]];
}
}
};
2322. 从树中删除边的最小分数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <iostream>
#include <vector>
#include <algorithm>
#include <forward_list>
using namespace std;
int MAXN = 1001;
// 下标为原始节点编号,dfn 序号从 1 开始
vector<int> dfn(MAXN);
// 下标为 dfn 序号
vector<int> xorSum(MAXN);
// 下标为 dfn 序号
vector<int> cnt(MAXN);
int dfnCnt;
// 邻接表
vector<forward_list<int>> graph(MAXN);
class Solution {
public:
// 无向图建图
void undirectedGraph(vector<vector<int>> &edges) {
for (const auto &edge: edges) {
graph[edge[0]].emplace_front(edge[1]);
graph[edge[1]].emplace_front(edge[0]);
}
}
int minimumScore(vector<int> &nums, vector<vector<int>> &edges) {
int n = nums.size();
int m = edges.size();
int res = INT_MAX;
// 清理全局变量
dfn.assign(MAXN, 0);
xorSum.assign(MAXN, 0);
cnt.assign(MAXN, 0);
dfnCnt = 0;
for (auto &g: graph) g.clear();
// 生成图
undirectedGraph(edges);
// 遍历并记录
fc(nums, 0);
// 枚举被删除的两条边的所有可能
for (int i = 0; i < m; ++i) {
// 第一条被删除的边连着的两个节点的 dfn 序号较大者
int a = max(dfn[edges[i][0]], dfn[edges[i][1]]);
for (int j = i + 1; j < m; ++j) {
// 第二条被删除的边连着的两个节点的 dfn 序号较大者
int b = max(dfn[edges[j][0]], dfn[edges[j][1]]);
// 判断 a,b 先后
int pre = a < b ? a : b;
int post = a + b - pre;
// 连通子图的异或和
int sum1 = xorSum[post];
int sum2, sum3;
if (post < pre + cnt[pre]) {
// 存在祖先关系
sum2 = xorSum[pre] ^ xorSum[post];
sum3 = xorSum[1] ^ xorSum[pre];
} else {
sum2 = xorSum[pre];
sum3 = xorSum[1] ^ sum1 ^ sum2;
}
int maxVal = max(max(sum1, sum2), sum3);
int minVal = min(min(sum1, sum2), sum3);
res = min(res, maxVal - minVal);
}
}
return res;
}
// u 是原始编号,遍历整棵树,记录 dfn 序号、异或和、节点数
void fc(vector<int> &nums, int u) {
int i = ++dfnCnt;
dfn[u] = i;
xorSum[i] = nums[u];
cnt[i] = 1;
for (const auto &v: graph[u]) {
// v 节点已经处理过
if (dfn[v] != 0) continue;
fc(nums, v);
xorSum[i] ^= xorSum[dfn[v]];
cnt[i] += cnt[dfn[v]];
}
}
};
P2014 [CTSC1997] 选课
- 时间复杂度:O(n * 每个节点的平均孩子数 * (m 的平方))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <iostream>
#include <vector>
#include <forward_list>
using namespace std;
const int MAXN = 301;
// 记录课程的学分
vector<int> nums(MAXN, 0);
// 邻接表,实际课程编号 1~n,增加了一个编号 0 的虚节点,把多个子树都接在 0 号节点下面
vector<vector<int>> graph(MAXN);
// dp[i][j][k]: 当前根节点为 i,在 i 号节点、及其 i 号节点下方的前 j 棵子树上挑选节点
// 共 k 个,且挑选的节点连成一片,返回最大的累加和
vector<vector<vector<int>>> dp(MAXN, vector<vector<int>>(MAXN, vector<int>(MAXN, -1)));
// 返回最大的累加和
int fc(int i, int j, int k) {
// 一个都不选
if (k == 0) return 0;
// 只能选 i 号节点
if (j == 0 || k == 1) return nums[i];
if (dp[i][j][k] != -1) return dp[i][j][k];
// p1: 第 j 个子树上一个都不选
int res = fc(i, j - 1, k);
// 第 j 棵子树根节点 v
int v = graph[i][j - 1];
// p2: 尝试第 j 个子树上选 s 个,在 i 的前 j - 1 棵子树上选 k - s 个
for (int s = 1; s < k; s++)
res = max(res, fc(i, j - 1, k - s) + fc(v, graph[v].size(), s));
dp[i][j][k] = res;
return res;
}
int main() {
int n, m;
cin >> n >> m;
// 多出的一个名额给必须选择的 0 号虚节点
m++;
for (int i = 1, pre; i <= n; i++) {
cin >> pre;
// 录入先导课程编号 pre 与当前 i 号课程所形成的边
graph[pre].push_back(i);
// 以及 i 号课程的学分
cin >> nums[i];
}
cout << fc(0, graph[0].size(), m) << endl;
return 0;
}
- 时间复杂度:O(n * m)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <iostream>
#include <vector>
#include <forward_list>
using namespace std;
int MAXN = 301;
int MAXM = 301;
// 链式前向星
vector<int> head(MAXN);
vector<int> nxt(MAXM);
vector<int> to(MAXM);
// 边序号从 1 开始
int edgeCnt;
void addEdge(int u, int v) {
nxt[edgeCnt] = head[u];
to[edgeCnt] = v;
head[u] = edgeCnt;
edgeCnt++;
}
// 下标为原始节点编号,值为 dfn 序号
vector<int> dfn(MAXN + 1);
// 下标为 dfn 序号,记录这个节点能获得的学分
vector<int> val(MAXN + 1);
// 下标为 dfn 序号,记录节点总数
vector<int> cnt(MAXN + 1);
// dfn 序号,从 1 开始
int dfnCnt;
// dp[i][j]: dfn 序号 i ~ n+1 范围的节点,选择 j 个节点一定要形成有效结构的情况下,最大的累加和
vector<vector<int>> dp(MAXN + 2, vector<int>(MAXN));
// 记录课程的学分
vector<int> nums(MAXN, 0);
void build() {
edgeCnt = 1;
dfnCnt = 0;
fill(begin(head), end(head), 0);
fill(begin(dp), end(dp), vector<int>(MAXN, 0));
}
// 遍历树,记录 dfn 序号、树的节点总数
void fc(int u) {
// 原始节点编号 u 的 dfn 序号
int i = ++dfnCnt;
dfn[u] = i;
val[i] = nums[u];
cnt[i] = 1;
// 遍历邻边 u->v
for (int ei = head[u]; ei > 0; ei = nxt[ei]) {
int v = to[ei];
fc(v);
cnt[i] += cnt[dfn[v]];
}
}
int compute(int n, int m) {
// 原始节点编号 0~n,dfn 序号 1~n+1
// 根据 dfn 序号逆序遍历,对当前的根节点进行如下操作
// 选择的总节点数从 1 尝试到 m
// p1: 不要当前节点
// 那么当前节点及其子树,也就是 dfn 序号 i~i + cnt[i] -1 范围的都不能选,因为不能连成一片
// 只能从 dfn 序号 i + cnt[i] 开始往后选 j 个节点,即 dp[i + cnt[i]][j]
// p2: 要当前节点
// 剩下 j-1 个节点从 dfn 序号 i+1~n+1 范围内选,即 dp[i + 1][j - 1],再加上当前值 val[i]
// 这些节点在加了一个虚拟头节点的情况下,就能连成一片
for (int i = n + 1; i >= 2; i--) {
for (int j = 1; j <= m; j++) {
// p1,p2 两种情况选最大值
dp[i][j] = max(dp[i + cnt[i]][j], val[i] + dp[i + 1][j - 1]);
}
}
return dp[2][m];
}
int main() {
int n, m;
cin >> n >> m;
build();
// 原始节点编号 0~n
for (int i = 1, pre; i <= n; i++) {
cin >> pre;
// 录入先导课程编号 pre 与当前 i 号课程所形成的边
addEdge(pre, i);
// 以及 i 号课程的学分
cin >> nums[i];
}
fc(0);
cout << compute(n, m) << endl;
return 0;
}
本文由作者按照 CC BY 4.0 进行授权