文章

动态规划中优化枚举(下)

动态规划中优化枚举减少状态转移复杂度,通过单调性、二分、单调队列等技巧加速计算,提高效率。

动态规划中优化枚举(下)

动态规划中优化枚举

1235. 规划兼职工作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

class Solution {
public:
    int jobScheduling(vector<int> &startTime, vector<int> &endTime, vector<int> &profit) {
        int n = startTime.size();
        vector<vector<int>> jobs(n, vector<int>(3));
        for (int i = 0; i < n; ++i) {
            jobs[i][0] = startTime[i];
            jobs[i][1] = endTime[i];
            jobs[i][2] = profit[i];
        }
        // 根据结束时间排序
        sort(begin(jobs), end(jobs), [](vector<int> &a, vector<int> &b) { return a[1] < b[1]; });
        // 表示在排完序的工作中,0~i 号中进行选择,能获得的最大报酬
        vector<int> dp(n);
        dp[0] = jobs[0][2];
        for (int i = 1; i < n; ++i) {
            int start = jobs[i][0];
            // p1: 选 i 号工作
            dp[i] = jobs[i][2];
            // 再加上在比 i 号工作的开始时间跟早的工作中选择,能获得的最大报酬
            // 如果 jobs[0][1] > start 说明之前每个工作的结束时间都在 start 之后,都不能选
            // 用二分优化枚举
            if (jobs[0][1] <= start)
                dp[i] += dp[binarySearch(jobs, i - 1, start)];
            // p2: 不选 i 号工作,即dp[i - 1]。取 p1、p2 的最大值
            dp[i] = max(dp[i], dp[i - 1]);
        }
        return dp[n - 1];
    }

    // 在 0~right 号工作中找结束时间小于等于 target 的,且是最右边的
    int binarySearch(vector<vector<int>> &jobs, int right, int target) {
        int l = 0;
        int r = right;
        int m;
        while (l <= r) {
            m = l + (r - l) / 2;
            if (jobs[m][1] <= target) {
                l = m + 1;
            } else {
                r = m - 1;
            }
        }
        return r;
    }
};

629. K 个逆序对数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

class Solution {
public:
    const int MOD = 1e9 + 7;

    int kInversePairs(int n, int k) {
        // dp[i][j]: 1~i 形成的排列中逆序对为 j 个的排列有多少种
        vector<vector<int>> dp(n + 1, vector<int>(k + 1, 0));
        dp[0][0] = 1;
        for (int i = 1; i <= n; ++i) {
            dp[i][0] = 1;
            for (int j = 1; j <= k; ++j) {
                if (i > j) {
                    for (int p = 0; p <= j; ++p)
                        dp[i][j] = (dp[i][j] + dp[i - 1][p]) % MOD;
                } else {
                    for (int p = j - i + 1; p <= j; ++p)
                        dp[i][j] = (dp[i][j] + dp[i - 1][p]) % MOD;
                }
            }
        }
        return dp[n][k];
    }
};
  • 优化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

class Solution {
public:
    const int MOD = 1e9 + 7;

    int kInversePairs(int n, int k) {
        // dp[i][j]: 1~i 形成的排列中逆序对为 j 个的排列有多少种
        vector<vector<int>> dp(n + 1, vector<int>(k + 1, 0));
        dp[0][0] = 1;
        for (int i = 1; i <= n; ++i) {
            dp[i][0] = 1;
            // 窗口的累加和
            int window = 1;
            for (int j = 1; j <= k; ++j) {
                if (i > j) {
                    window = (window + dp[i - 1][j]) % MOD;
                } else {
                    window = ((window + dp[i - 1][j]) % MOD - dp[i - 1][j - i] + MOD) % MOD;
                }
                dp[i][j] = window;
            }
        }
        return dp[n][k];
    }
};

514. 自由之路

  • 枚举所有位置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <iostream>
#include <vector>
#include <string>
#include <climits>
#include <algorithm>

using namespace std;

class Solution {
public:
    // 用于存储 ring 转换为字符编号(0-25)
    vector<int> ring;
    // 用于存储 key 转换为字符编号(0-25)
    vector<int> key;
    // where[c] 存放字符 c(0-25)在 ring 中出现的位置(有序)
    vector<vector<int>> where;
    // 记忆化数组 dp[i][j] 表示从 ring[i] 开始,搞定 key[j..] 的最小代价
    vector<vector<int>> dp;
    // ring 和 key 的长度
    int n, m;

    // 预处理函数:构建 ring、key、where 和初始化 dp 表
    void build(const string &r, const string &k) {
        n = r.size();
        m = k.size();

        ring.resize(n);
        key.resize(m);
        where.assign(26, vector<int>());  // 每个字符出现的位置列表

        for (int i = 0; i < n; ++i) {
            int c = r[i] - 'a';
            ring[i] = c;
            where[c].push_back(i);
        }

        for (int i = 0; i < m; ++i) {
            key[i] = k[i] - 'a';
        }

        // 初始化 dp 表为 -1(表示尚未计算)
        dp.assign(n, vector<int>(m, -1));
    }

    // 记忆化搜索函数
    // 当前指针在 ring 的位置 i,key 从 j 开始要处理的所有字符
    int dfs(int i, int j) {
        if (j == m) return 0;  // 所有 key 字符都搞定了
        if (dp[i][j] != -1) return dp[i][j];  // 返回已计算的值

        int res = INT_MAX;
        // 枚举 key[j] 的所有可能位置
        for (int next: where[key[j]]) {
            // 计算从 i 到 next 的最短旋转距离(顺或逆)
            int delta = abs(i - next);
            int step = min(delta, n - delta);
            // 递归处理下一个 key 字符,累加旋转步数
            res = min(res, step + dfs(next, j + 1));
        }

        // +1 是按下按钮的代价
        dp[i][j] = res + 1;
        return dp[i][j];
    }

    int findRotateSteps(string r, string k) {
        build(r, k);
        return dfs(0, 0);
    }
};
  • 只枚举顺时针和逆时针的最近位置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <iostream>
#include <vector>
#include <cstring>
#include <algorithm>

using namespace std;

class Solution {
public:
    static const int MAXC = 26;

    vector<int> ring;            // 保存ring中字符的位置(索引)
    vector<int> key;             // 保存key中字符的对应值(字符映射为数字)
    vector<int> cnts;            // 每个字符(a-z)在ring中出现的次数
    vector<vector<int>> where;   // 存储每个字符在ring中的所有位置
    vector<vector<int>> dp;      // 动态规划表,存储从某个位置处理某个key时的最少旋转步数
    int n, m;                    // n为ring的长度,m为key的长度

    void build(string r, string k) {
        n = r.length();
        m = k.length();
        ring.resize(n);
        key.resize(m);
        cnts.assign(MAXC, 0);
        where.resize(MAXC, vector<int>(n));
        dp.resize(n, vector<int>(m, -1));

        for (int i = 0, v; i < n; i++) {
            v = r[i] - 'a';
            where[v][cnts[v]++] = i;
            ring[i] = v;
        }
        for (int i = 0; i < m; i++) {
            key[i] = k[i] - 'a';
        }
    }

    // 主函数,计算最少的旋转步数
    int findRotateSteps(string r, string k) {
        build(r, k);
        return fc(0, 0);
    }

    // 当前指针在 ring 的位置 i,key 从 j 开始要处理的所有字符
    int fc(int i, int j) {
        if (j == m) return 0;  // 如果所有key字符已经处理完,返回0步
        if (dp[i][j] != -1) return dp[i][j];

        int res;
        if (ring[i] == key[j]) {
            // 如果ring[i]已经是key[j],直接按下按钮
            res = 1 + fc(i, j + 1);
        } else {
            // 顺时针跳到目标字符的最小步数
            int jump1 = clock(i, key[j]);
            int distance1 = (jump1 > i ? (jump1 - i) : (n - i + jump1));

            // 逆时针跳到目标字符的最小步数
            int jump2 = counterClock(i, key[j]);
            int distance2 = (i > jump2 ? (i - jump2) : (i + n - jump2));

            // 选择最小的步数,并递归处理
            res = min(distance1 + fc(jump1, j), distance2 + fc(jump2, j));
        }

        dp[i][j] = res;
        return res;
    }

    // 顺时针找到从位置i开始,字符v在ring中的最近位置
    int clock(int i, int v) {
        int l = 0;
        int r = cnts[v] - 1, mid;
        int find = -1;
        vector<int> &sorted = where[v];  // 获取字符v在ring中的所有位置

        // 二分查找,找到第一个大于i的索引
        while (l <= r) {
            mid = (l + r) / 2;
            if (sorted[mid] > i) {
                find = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }

        // 如果找到合适位置,返回该位置,否则返回最小位置
        return find != -1 ? sorted[find] : sorted[0];
    }


    // 逆时针找到从位置i开始,字符v在ring中的最近位置
    int counterClock(int i, int v) {
        int l = 0;
        int r = cnts[v] - 1, mid;
        int find = -1;
        vector<int> &sorted = where[v];  // 获取字符v在ring中的所有位置

        // 二分查找,找到第一个小于i的索引
        while (l <= r) {
            mid = (l + r) / 2;
            if (sorted[mid] < i) {
                find = mid;
                l = mid + 1;
            } else {
                r = mid - 1;
            }
        }

        // 如果找到合适位置,返回该位置,否则返回最大位置
        return find != -1 ? sorted[find] : sorted[cnts[v] - 1];
    }
};

未排序数组中累加和小于或等于给定值的最长子数组长度

  • 带二分的 O(nlogn) 实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int n, k;
vector<int> nums;

// 在 sums 中查找第一个大于等于 target 的下标
int binarySearch(vector<int> &sums, int target) {
    int l = 0;
    int r = sums.size() - 1;
    int m;
    while (l <= r) {
        m = l + (r - l) / 2;
        if (sums[m] >= target) {
            r = m - 1;
        } else {
            l = m + 1;
        }
    }
    return l;
}

// 返回累加和不超过 k 的最长子数组长度
int fc() {
    // 前缀最大和数组 sums,sums[i] 表示从下标 0 到 i-1 的最大前缀和
    vector<int> sums(n + 1, 0);
    for (int i = 0, sum = 0; i < n; i++) {
        sum += nums[i]; // 计算当前前缀和
        sums[i + 1] = max(sums[i], sum); //前 i+1 个数的最大前缀和
    }

    int res = 0;
    for (int i = 0, sum = 0; i < n; i++) {
        sum += nums[i]; // sum 表示从 nums[0] 到 nums[i] 的前缀和
        // 找最早的一个前缀和 pre,使得 sum - pre <= k,即 pre >= sum - k
        int pre = binarySearch(sums, sum - k);
        int len = pre == -1 ? 0 : i - pre + 1; // 若找到,更新长度(i - pre + 1)
        res = max(res, len);
    }
    return res;
}

int main() {
    while (cin >> n >> k) {
        nums.resize(n);
        for (int i = 0; i < n; ++i) {
            cin >> nums[i];
        }
        cout << fc() << endl;
    }
    return 0;
}
  • 最优 O(n) 解法,贪心 + 滑动窗口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int n, k;
vector<int> nums;

// O(n) 解法,
int fc() {
    vector<int> minSums(n), minSumEnds(n);
    // minSums[i] 表示从 i 开始往右延伸的最小累加和(尽可能延长)
    // minSumEnds[i] 表示这个最小累加和右边界的位置(闭区间)
    minSums[n - 1] = nums[n - 1];
    minSumEnds[n - 1] = n - 1;
    for (int i = n - 2; i >= 0; i--) {
        // 如果后缀最小和是负数,继续延伸
        if (minSums[i + 1] < 0) {
            minSums[i] = nums[i] + minSums[i + 1];
            minSumEnds[i] = minSumEnds[i + 1];
        } else {
            // 否则从当前位置单独开始
            minSums[i] = nums[i];
            minSumEnds[i] = i;
        }
    }

    int res = 0;       // 记录最长子数组长度
    int sum = 0;       // 当前窗口内累加和
    int end = 0;       // 窗口右边界

    // 从左到右滑动窗口的起始位置 i
    for (int i = 0; i < n; i++) {
        // 不断尝试扩展右边界,只要扩展后的 sum 不超过 k
        while (end < n && sum + minSums[end] <= k) {
            sum += minSums[end];
            end = minSumEnds[end] + 1; // 直接跳到下一个可扩展位置
        }

        if (end > i) {
            // 窗口 [i, end - 1] 是合法的,更新答案
            res = max(res, end - i);
            sum -= nums[i]; // 移动窗口左边界,减去 nums[i]
        } else {
            // 窗口无法扩展,end 右移,i 也要右移(跳过此位置)
            end = i + 1;
        }
    }
    return res;
}

int main() {
    while (cin >> n >> k) {
        nums.resize(n);
        for (int i = 0; i < n; ++i)
            cin >> nums[i];
        cout << fc() << endl;
    }
    return 0;
}
本文由作者按照 CC BY 4.0 进行授权