09-DP进阶
💡 核心结论
DP进阶问题特点
双串问题:两个字符串/数组的DP
状态复杂:需要多维状态表示
优化技巧:降维、滚动数组、状态压缩
经典问题:LCS、编辑距离、股票、打家劫舍变种
经典DP问题分类
类型 |
代表问题 |
状态 |
时间 |
|---|---|---|---|
线性DP |
LIS |
dp[i] |
O(n²) |
双串DP |
LCS、编辑距离 |
dp[i][j] |
O(mn) |
区间DP |
最长回文子串 |
dp[i][j] |
O(n²) |
树形DP |
打家劫舍III |
dp[node] |
O(n) |
状态机DP |
股票问题 |
dp[i][k][s] |
O(nk) |
🎯 最长公共子序列(LCS)
问题
找两个字符串的最长公共子序列长度
状态定义
dp[i][j] = text1[0:i]和text2[0:j]的LCS长度
状态转移
if text1[i-1] == text2[j-1]:
dp[i][j] = dp[i-1][j-1] + 1
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
实现
def longest_common_subsequence(text1, text2):
m, n = len(text1), len(text2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if text1[i-1] == text2[j-1]:
dp[i][j] = dp[i-1][j-1] + 1
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
return dp[m][n]
✏️ 编辑距离
问题
将word1转换为word2的最少操作数(插入、删除、替换)
状态定义
dp[i][j] = word1[0:i]转换为word2[0:j]的最少操作数
状态转移
if word1[i-1] == word2[j-1]:
dp[i][j] = dp[i-1][j-1] # 不需要操作
else:
dp[i][j] = 1 + min(
dp[i-1][j], # 删除
dp[i][j-1], # 插入
dp[i-1][j-1] # 替换
)
实现
def min_distance(word1, word2):
m, n = len(word1), len(word2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
# 初始化
for i in range(m + 1):
dp[i][0] = i
for j in range(n + 1):
dp[0][j] = j
# 状态转移
for i in range(1, m + 1):
for j in range(1, n + 1):
if word1[i-1] == word2[j-1]:
dp[i][j] = dp[i-1][j-1]
else:
dp[i][j] = 1 + min(
dp[i-1][j], # 删除
dp[i][j-1], # 插入
dp[i-1][j-1] # 替换
)
return dp[m][n]
💰 股票问题系列
买卖一次
def max_profit(prices):
min_price = float('inf')
max_profit = 0
for price in prices:
min_price = min(min_price, price)
max_profit = max(max_profit, price - min_price)
return max_profit
买卖多次
def max_profit_unlimited(prices):
profit = 0
for i in range(1, len(prices)):
if prices[i] > prices[i-1]:
profit += prices[i] - prices[i-1]
return profit
买卖k次(通用解法)
def max_profit_k(k, prices):
if not prices:
return 0
n = len(prices)
if k >= n // 2:
return max_profit_unlimited(prices)
# dp[i][k][0/1] = 第i天,最多k次交易,持有/不持有
dp = [[[0, 0] for _ in range(k + 1)] for _ in range(n)]
for i in range(n):
for j in range(k, 0, -1):
if i == 0:
dp[i][j][0] = 0
dp[i][j][1] = -prices[i]
else:
dp[i][j][0] = max(dp[i-1][j][0], dp[i-1][j][1] + prices[i])
dp[i][j][1] = max(dp[i-1][j][1], dp[i-1][j-1][0] - prices[i])
return dp[n-1][k][0]
🏠 打家劫舍变种
打家劫舍II(环形)
def rob_circular(nums):
if len(nums) == 1:
return nums[0]
def rob_range(nums, start, end):
prev, curr = 0, 0
for i in range(start, end):
prev, curr = curr, max(curr, prev + nums[i])
return curr
# 分两种情况:偷第一个或不偷
return max(
rob_range(nums, 0, len(nums) - 1), # 不偷最后一个
rob_range(nums, 1, len(nums)) # 不偷第一个
)
打家劫舍III(二叉树)
def rob_tree(root):
def dfs(node):
if not node:
return 0, 0
left = dfs(node.left)
right = dfs(node.right)
# 偷当前节点
rob_curr = node.val + left[1] + right[1]
# 不偷当前节点
not_rob = max(left) + max(right)
return rob_curr, not_rob
return max(dfs(root))
📚 LeetCode练习
LCS变种
编辑距离
股票问题
打家劫舍
💻 完整代码实现
Python 实现
1"""
2DP进阶问题实现
3"""
4
5# ========== 最长公共子序列 ==========
6
7def longest_common_subsequence(text1, text2):
8 """LCS"""
9 m, n = len(text1), len(text2)
10 dp = [[0] * (n + 1) for _ in range(m + 1)]
11
12 for i in range(1, m + 1):
13 for j in range(1, n + 1):
14 if text1[i-1] == text2[j-1]:
15 dp[i][j] = dp[i-1][j-1] + 1
16 else:
17 dp[i][j] = max(dp[i-1][j], dp[i][j-1])
18
19 return dp[m][n]
20
21
22# ========== 编辑距离 ==========
23
24def min_distance(word1, word2):
25 """编辑距离"""
26 m, n = len(word1), len(word2)
27 dp = [[0] * (n + 1) for _ in range(m + 1)]
28
29 for i in range(m + 1):
30 dp[i][0] = i
31 for j in range(n + 1):
32 dp[0][j] = j
33
34 for i in range(1, m + 1):
35 for j in range(1, n + 1):
36 if word1[i-1] == word2[j-1]:
37 dp[i][j] = dp[i-1][j-1]
38 else:
39 dp[i][j] = 1 + min(
40 dp[i-1][j], # 删除
41 dp[i][j-1], # 插入
42 dp[i-1][j-1] # 替换
43 )
44
45 return dp[m][n]
46
47
48# ========== 最长回文子串 ==========
49
50def longest_palindrome(s):
51 """最长回文子串"""
52 n = len(s)
53 dp = [[False] * n for _ in range(n)]
54 start, max_len = 0, 1
55
56 # 初始化
57 for i in range(n):
58 dp[i][i] = True
59
60 # 长度为2
61 for i in range(n - 1):
62 if s[i] == s[i+1]:
63 dp[i][i+1] = True
64 start = i
65 max_len = 2
66
67 # 长度>=3
68 for length in range(3, n + 1):
69 for i in range(n - length + 1):
70 j = i + length - 1
71 if s[i] == s[j] and dp[i+1][j-1]:
72 dp[i][j] = True
73 start = i
74 max_len = length
75
76 return s[start:start + max_len]
77
78
79# ========== 股票问题 ==========
80
81def max_profit_one(prices):
82 """买卖一次"""
83 min_price = float('inf')
84 max_profit = 0
85
86 for price in prices:
87 min_price = min(min_price, price)
88 max_profit = max(max_profit, price - min_price)
89
90 return max_profit
91
92
93def max_profit_unlimited(prices):
94 """买卖多次"""
95 profit = 0
96 for i in range(1, len(prices)):
97 if prices[i] > prices[i-1]:
98 profit += prices[i] - prices[i-1]
99 return profit
100
101
102def max_profit_k(k, prices):
103 """最多k次交易"""
104 if not prices:
105 return 0
106
107 n = len(prices)
108 if k >= n // 2:
109 return max_profit_unlimited(prices)
110
111 # dp[i][j][0/1] = 第i天,最多j次交易,不持有/持有
112 dp = [[[0, 0] for _ in range(k + 1)] for _ in range(n)]
113
114 for i in range(n):
115 for j in range(k, 0, -1):
116 if i == 0:
117 dp[i][j][0] = 0
118 dp[i][j][1] = -prices[i]
119 else:
120 dp[i][j][0] = max(dp[i-1][j][0], dp[i-1][j][1] + prices[i])
121 dp[i][j][1] = max(dp[i-1][j][1], dp[i-1][j-1][0] - prices[i])
122
123 return dp[n-1][k][0]
124
125
126# ========== 打家劫舍II(环形) ==========
127
128def rob_circular(nums):
129 """环形房屋"""
130 if len(nums) == 1:
131 return nums[0]
132
133 def rob_range(start, end):
134 prev, curr = 0, 0
135 for i in range(start, end):
136 prev, curr = curr, max(curr, prev + nums[i])
137 return curr
138
139 return max(
140 rob_range(0, len(nums) - 1),
141 rob_range(1, len(nums))
142 )
143
144
145# ========== 最长递增子序列(二分优化) ==========
146
147def length_of_lis_binary(nums):
148 """LIS O(n log n)"""
149 if not nums:
150 return 0
151
152 tails = []
153
154 for num in nums:
155 left, right = 0, len(tails)
156 while left < right:
157 mid = (left + right) // 2
158 if tails[mid] < num:
159 left = mid + 1
160 else:
161 right = mid
162
163 if left == len(tails):
164 tails.append(num)
165 else:
166 tails[left] = num
167
168 return len(tails)
169
170
171def demo():
172 """演示DP进阶"""
173 print("=== DP进阶演示 ===\n")
174
175 # LCS
176 text1, text2 = "abcde", "ace"
177 print(f"最长公共子序列 '{text1}' 和 '{text2}':")
178 print(f" 长度: {longest_common_subsequence(text1, text2)}\n")
179
180 # 编辑距离
181 word1, word2 = "horse", "ros"
182 print(f"编辑距离 '{word1}' -> '{word2}':")
183 print(f" 最少操作: {min_distance(word1, word2)}\n")
184
185 # 最长回文
186 s = "babad"
187 print(f"最长回文子串 '{s}':")
188 print(f" 结果: {longest_palindrome(s)}\n")
189
190 # 股票
191 prices = [7,1,5,3,6,4]
192 print(f"股票问题 {prices}:")
193 print(f" 买卖一次: {max_profit_one(prices)}")
194 print(f" 买卖多次: {max_profit_unlimited(prices)}")
195 print(f" 最多2次: {max_profit_k(2, prices)}\n")
196
197 # 打家劫舍II
198 nums = [2,3,2]
199 print(f"打家劫舍II(环形){nums}:")
200 print(f" 最大金额: {rob_circular(nums)}\n")
201
202 # LIS
203 nums = [10, 9, 2, 5, 3, 7, 101, 18]
204 print(f"最长递增子序列 {nums}:")
205 print(f" 长度: {length_of_lis_binary(nums)}")
206
207
208if __name__ == '__main__':
209 demo()
210
C++ 实现
1/**
2 * DP进阶问题实现
3 */
4
5#include <iostream>
6#include <vector>
7#include <string>
8#include <algorithm>
9#include <climits>
10
11using namespace std;
12
13// 最长公共子序列
14int longestCommonSubsequence(const string& text1, const string& text2) {
15 int m = text1.length(), n = text2.length();
16 vector<vector<int>> dp(m + 1, vector<int>(n + 1, 0));
17
18 for (int i = 1; i <= m; i++) {
19 for (int j = 1; j <= n; j++) {
20 if (text1[i-1] == text2[j-1]) {
21 dp[i][j] = dp[i-1][j-1] + 1;
22 } else {
23 dp[i][j] = max(dp[i-1][j], dp[i][j-1]);
24 }
25 }
26 }
27
28 return dp[m][n];
29}
30
31// 编辑距离
32int minDistance(const string& word1, const string& word2) {
33 int m = word1.length(), n = word2.length();
34 vector<vector<int>> dp(m + 1, vector<int>(n + 1, 0));
35
36 for (int i = 0; i <= m; i++) dp[i][0] = i;
37 for (int j = 0; j <= n; j++) dp[0][j] = j;
38
39 for (int i = 1; i <= m; i++) {
40 for (int j = 1; j <= n; j++) {
41 if (word1[i-1] == word2[j-1]) {
42 dp[i][j] = dp[i-1][j-1];
43 } else {
44 dp[i][j] = 1 + min({
45 dp[i-1][j], // 删除
46 dp[i][j-1], // 插入
47 dp[i-1][j-1] // 替换
48 });
49 }
50 }
51 }
52
53 return dp[m][n];
54}
55
56// 最长回文子串
57string longestPalindrome(const string& s) {
58 int n = s.length();
59 if (n == 0) return "";
60
61 vector<vector<bool>> dp(n, vector<bool>(n, false));
62 int start = 0, maxLen = 1;
63
64 for (int i = 0; i < n; i++) {
65 dp[i][i] = true;
66 }
67
68 for (int i = 0; i < n - 1; i++) {
69 if (s[i] == s[i+1]) {
70 dp[i][i+1] = true;
71 start = i;
72 maxLen = 2;
73 }
74 }
75
76 for (int len = 3; len <= n; len++) {
77 for (int i = 0; i <= n - len; i++) {
78 int j = i + len - 1;
79 if (s[i] == s[j] && dp[i+1][j-1]) {
80 dp[i][j] = true;
81 start = i;
82 maxLen = len;
83 }
84 }
85 }
86
87 return s.substr(start, maxLen);
88}
89
90// 买卖股票(一次)
91int maxProfitOne(const vector<int>& prices) {
92 int minPrice = INT_MAX;
93 int maxProfit = 0;
94
95 for (int price : prices) {
96 minPrice = min(minPrice, price);
97 maxProfit = max(maxProfit, price - minPrice);
98 }
99
100 return maxProfit;
101}
102
103// 买卖股票(无限次)
104int maxProfitUnlimited(const vector<int>& prices) {
105 int profit = 0;
106
107 for (int i = 1; i < prices.size(); i++) {
108 if (prices[i] > prices[i-1]) {
109 profit += prices[i] - prices[i-1];
110 }
111 }
112
113 return profit;
114}
115
116// 买卖股票(k次)
117int maxProfitK(int k, const vector<int>& prices) {
118 if (prices.empty()) return 0;
119
120 int n = prices.size();
121 if (k >= n / 2) {
122 return maxProfitUnlimited(prices);
123 }
124
125 // dp[i][j][0/1] = 第i天,最多j次交易,不持有/持有
126 vector<vector<vector<int>>> dp(n,
127 vector<vector<int>>(k + 1, vector<int>(2, 0)));
128
129 for (int i = 0; i < n; i++) {
130 for (int j = k; j >= 1; j--) {
131 if (i == 0) {
132 dp[i][j][0] = 0;
133 dp[i][j][1] = -prices[i];
134 } else {
135 dp[i][j][0] = max(dp[i-1][j][0], dp[i-1][j][1] + prices[i]);
136 dp[i][j][1] = max(dp[i-1][j][1], dp[i-1][j-1][0] - prices[i]);
137 }
138 }
139 }
140
141 return dp[n-1][k][0];
142}
143
144// 最长递增子序列(二分优化)
145int lengthOfLIS(const vector<int>& nums) {
146 if (nums.empty()) return 0;
147
148 vector<int> tails;
149
150 for (int num : nums) {
151 auto it = lower_bound(tails.begin(), tails.end(), num);
152 if (it == tails.end()) {
153 tails.push_back(num);
154 } else {
155 *it = num;
156 }
157 }
158
159 return tails.size();
160}
161
162int main() {
163 cout << "=== DP进阶演示 ===" << endl << endl;
164
165 // LCS
166 string text1 = "abcde", text2 = "ace";
167 cout << "最长公共子序列 '" << text1 << "' 和 '" << text2 << "':" << endl;
168 cout << " 长度: " << longestCommonSubsequence(text1, text2) << endl << endl;
169
170 // 编辑距离
171 string word1 = "horse", word2 = "ros";
172 cout << "编辑距离 '" << word1 << "' -> '" << word2 << "':" << endl;
173 cout << " 最少操作: " << minDistance(word1, word2) << endl << endl;
174
175 // 最长回文
176 string s = "babad";
177 cout << "最长回文子串 '" << s << "':" << endl;
178 cout << " 结果: " << longestPalindrome(s) << endl << endl;
179
180 // 股票
181 vector<int> prices = {7,1,5,3,6,4};
182 cout << "股票问题 [7,1,5,3,6,4]:" << endl;
183 cout << " 买卖一次: " << maxProfitOne(prices) << endl;
184 cout << " 买卖多次: " << maxProfitUnlimited(prices) << endl;
185 cout << " 最多2次: " << maxProfitK(2, prices) << endl << endl;
186
187 // LIS
188 vector<int> nums = {10, 9, 2, 5, 3, 7, 101, 18};
189 cout << "最长递增子序列 [10,9,2,5,3,7,101,18]:" << endl;
190 cout << " 长度: " << lengthOfLIS(nums) << endl;
191
192 return 0;
193}