09-DP进阶

💡 核心结论

DP进阶问题特点

  • 双串问题:两个字符串/数组的DP

  • 状态复杂:需要多维状态表示

  • 优化技巧:降维、滚动数组、状态压缩

  • 经典问题:LCS、编辑距离、股票、打家劫舍变种

经典DP问题分类

类型

代表问题

状态

时间

线性DP

LIS

dp[i]

O(n²)

双串DP

LCS、编辑距离

dp[i][j]

O(mn)

区间DP

最长回文子串

dp[i][j]

O(n²)

树形DP

打家劫舍III

dp[node]

O(n)

状态机DP

股票问题

dp[i][k][s]

O(nk)

🎯 最长公共子序列(LCS)

问题

找两个字符串的最长公共子序列长度

状态定义

dp[i][j] = text1[0:i]和text2[0:j]的LCS长度

状态转移

if text1[i-1] == text2[j-1]:
    dp[i][j] = dp[i-1][j-1] + 1
else:
    dp[i][j] = max(dp[i-1][j], dp[i][j-1])

实现

def longest_common_subsequence(text1, text2):
    m, n = len(text1), len(text2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if text1[i-1] == text2[j-1]:
                dp[i][j] = dp[i-1][j-1] + 1
            else:
                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
    
    return dp[m][n]

✏️ 编辑距离

问题

将word1转换为word2的最少操作数(插入、删除、替换)

状态定义

dp[i][j] = word1[0:i]转换为word2[0:j]的最少操作数

状态转移

if word1[i-1] == word2[j-1]:
    dp[i][j] = dp[i-1][j-1]  # 不需要操作
else:
    dp[i][j] = 1 + min(
        dp[i-1][j],    # 删除
        dp[i][j-1],    # 插入
        dp[i-1][j-1]   # 替换
    )

实现

def min_distance(word1, word2):
    m, n = len(word1), len(word2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    
    # 初始化
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    
    # 状态转移
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if word1[i-1] == word2[j-1]:
                dp[i][j] = dp[i-1][j-1]
            else:
                dp[i][j] = 1 + min(
                    dp[i-1][j],    # 删除
                    dp[i][j-1],    # 插入
                    dp[i-1][j-1]   # 替换
                )
    
    return dp[m][n]

💰 股票问题系列

买卖一次

def max_profit(prices):
    min_price = float('inf')
    max_profit = 0
    
    for price in prices:
        min_price = min(min_price, price)
        max_profit = max(max_profit, price - min_price)
    
    return max_profit

买卖多次

def max_profit_unlimited(prices):
    profit = 0
    for i in range(1, len(prices)):
        if prices[i] > prices[i-1]:
            profit += prices[i] - prices[i-1]
    return profit

买卖k次(通用解法)

def max_profit_k(k, prices):
    if not prices:
        return 0
    
    n = len(prices)
    if k >= n // 2:
        return max_profit_unlimited(prices)
    
    # dp[i][k][0/1] = 第i天,最多k次交易,持有/不持有
    dp = [[[0, 0] for _ in range(k + 1)] for _ in range(n)]
    
    for i in range(n):
        for j in range(k, 0, -1):
            if i == 0:
                dp[i][j][0] = 0
                dp[i][j][1] = -prices[i]
            else:
                dp[i][j][0] = max(dp[i-1][j][0], dp[i-1][j][1] + prices[i])
                dp[i][j][1] = max(dp[i-1][j][1], dp[i-1][j-1][0] - prices[i])
    
    return dp[n-1][k][0]

🏠 打家劫舍变种

打家劫舍II(环形)

def rob_circular(nums):
    if len(nums) == 1:
        return nums[0]
    
    def rob_range(nums, start, end):
        prev, curr = 0, 0
        for i in range(start, end):
            prev, curr = curr, max(curr, prev + nums[i])
        return curr
    
    # 分两种情况:偷第一个或不偷
    return max(
        rob_range(nums, 0, len(nums) - 1),  # 不偷最后一个
        rob_range(nums, 1, len(nums))       # 不偷第一个
    )

打家劫舍III(二叉树)

def rob_tree(root):
    def dfs(node):
        if not node:
            return 0, 0
        
        left = dfs(node.left)
        right = dfs(node.right)
        
        # 偷当前节点
        rob_curr = node.val + left[1] + right[1]
        # 不偷当前节点
        not_rob = max(left) + max(right)
        
        return rob_curr, not_rob
    
    return max(dfs(root))

📚 LeetCode练习

LCS变种

编辑距离

股票问题

打家劫舍

💻 完整代码实现

Python 实现

  1"""
  2DP进阶问题实现
  3"""
  4
  5# ========== 最长公共子序列 ==========
  6
  7def longest_common_subsequence(text1, text2):
  8    """LCS"""
  9    m, n = len(text1), len(text2)
 10    dp = [[0] * (n + 1) for _ in range(m + 1)]
 11    
 12    for i in range(1, m + 1):
 13        for j in range(1, n + 1):
 14            if text1[i-1] == text2[j-1]:
 15                dp[i][j] = dp[i-1][j-1] + 1
 16            else:
 17                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
 18    
 19    return dp[m][n]
 20
 21
 22# ========== 编辑距离 ==========
 23
 24def min_distance(word1, word2):
 25    """编辑距离"""
 26    m, n = len(word1), len(word2)
 27    dp = [[0] * (n + 1) for _ in range(m + 1)]
 28    
 29    for i in range(m + 1):
 30        dp[i][0] = i
 31    for j in range(n + 1):
 32        dp[0][j] = j
 33    
 34    for i in range(1, m + 1):
 35        for j in range(1, n + 1):
 36            if word1[i-1] == word2[j-1]:
 37                dp[i][j] = dp[i-1][j-1]
 38            else:
 39                dp[i][j] = 1 + min(
 40                    dp[i-1][j],    # 删除
 41                    dp[i][j-1],    # 插入
 42                    dp[i-1][j-1]   # 替换
 43                )
 44    
 45    return dp[m][n]
 46
 47
 48# ========== 最长回文子串 ==========
 49
 50def longest_palindrome(s):
 51    """最长回文子串"""
 52    n = len(s)
 53    dp = [[False] * n for _ in range(n)]
 54    start, max_len = 0, 1
 55    
 56    # 初始化
 57    for i in range(n):
 58        dp[i][i] = True
 59    
 60    # 长度为2
 61    for i in range(n - 1):
 62        if s[i] == s[i+1]:
 63            dp[i][i+1] = True
 64            start = i
 65            max_len = 2
 66    
 67    # 长度>=3
 68    for length in range(3, n + 1):
 69        for i in range(n - length + 1):
 70            j = i + length - 1
 71            if s[i] == s[j] and dp[i+1][j-1]:
 72                dp[i][j] = True
 73                start = i
 74                max_len = length
 75    
 76    return s[start:start + max_len]
 77
 78
 79# ========== 股票问题 ==========
 80
 81def max_profit_one(prices):
 82    """买卖一次"""
 83    min_price = float('inf')
 84    max_profit = 0
 85    
 86    for price in prices:
 87        min_price = min(min_price, price)
 88        max_profit = max(max_profit, price - min_price)
 89    
 90    return max_profit
 91
 92
 93def max_profit_unlimited(prices):
 94    """买卖多次"""
 95    profit = 0
 96    for i in range(1, len(prices)):
 97        if prices[i] > prices[i-1]:
 98            profit += prices[i] - prices[i-1]
 99    return profit
100
101
102def max_profit_k(k, prices):
103    """最多k次交易"""
104    if not prices:
105        return 0
106    
107    n = len(prices)
108    if k >= n // 2:
109        return max_profit_unlimited(prices)
110    
111    # dp[i][j][0/1] = 第i天,最多j次交易,不持有/持有
112    dp = [[[0, 0] for _ in range(k + 1)] for _ in range(n)]
113    
114    for i in range(n):
115        for j in range(k, 0, -1):
116            if i == 0:
117                dp[i][j][0] = 0
118                dp[i][j][1] = -prices[i]
119            else:
120                dp[i][j][0] = max(dp[i-1][j][0], dp[i-1][j][1] + prices[i])
121                dp[i][j][1] = max(dp[i-1][j][1], dp[i-1][j-1][0] - prices[i])
122    
123    return dp[n-1][k][0]
124
125
126# ========== 打家劫舍II(环形) ==========
127
128def rob_circular(nums):
129    """环形房屋"""
130    if len(nums) == 1:
131        return nums[0]
132    
133    def rob_range(start, end):
134        prev, curr = 0, 0
135        for i in range(start, end):
136            prev, curr = curr, max(curr, prev + nums[i])
137        return curr
138    
139    return max(
140        rob_range(0, len(nums) - 1),
141        rob_range(1, len(nums))
142    )
143
144
145# ========== 最长递增子序列(二分优化) ==========
146
147def length_of_lis_binary(nums):
148    """LIS O(n log n)"""
149    if not nums:
150        return 0
151    
152    tails = []
153    
154    for num in nums:
155        left, right = 0, len(tails)
156        while left < right:
157            mid = (left + right) // 2
158            if tails[mid] < num:
159                left = mid + 1
160            else:
161                right = mid
162        
163        if left == len(tails):
164            tails.append(num)
165        else:
166            tails[left] = num
167    
168    return len(tails)
169
170
171def demo():
172    """演示DP进阶"""
173    print("=== DP进阶演示 ===\n")
174    
175    # LCS
176    text1, text2 = "abcde", "ace"
177    print(f"最长公共子序列 '{text1}' 和 '{text2}':")
178    print(f"  长度: {longest_common_subsequence(text1, text2)}\n")
179    
180    # 编辑距离
181    word1, word2 = "horse", "ros"
182    print(f"编辑距离 '{word1}' -> '{word2}':")
183    print(f"  最少操作: {min_distance(word1, word2)}\n")
184    
185    # 最长回文
186    s = "babad"
187    print(f"最长回文子串 '{s}':")
188    print(f"  结果: {longest_palindrome(s)}\n")
189    
190    # 股票
191    prices = [7,1,5,3,6,4]
192    print(f"股票问题 {prices}:")
193    print(f"  买卖一次: {max_profit_one(prices)}")
194    print(f"  买卖多次: {max_profit_unlimited(prices)}")
195    print(f"  最多2次: {max_profit_k(2, prices)}\n")
196    
197    # 打家劫舍II
198    nums = [2,3,2]
199    print(f"打家劫舍II(环形){nums}:")
200    print(f"  最大金额: {rob_circular(nums)}\n")
201    
202    # LIS
203    nums = [10, 9, 2, 5, 3, 7, 101, 18]
204    print(f"最长递增子序列 {nums}:")
205    print(f"  长度: {length_of_lis_binary(nums)}")
206
207
208if __name__ == '__main__':
209    demo()
210

C++ 实现

  1/**
  2 * DP进阶问题实现
  3 */
  4
  5#include <iostream>
  6#include <vector>
  7#include <string>
  8#include <algorithm>
  9#include <climits>
 10
 11using namespace std;
 12
 13// 最长公共子序列
 14int longestCommonSubsequence(const string& text1, const string& text2) {
 15    int m = text1.length(), n = text2.length();
 16    vector<vector<int>> dp(m + 1, vector<int>(n + 1, 0));
 17    
 18    for (int i = 1; i <= m; i++) {
 19        for (int j = 1; j <= n; j++) {
 20            if (text1[i-1] == text2[j-1]) {
 21                dp[i][j] = dp[i-1][j-1] + 1;
 22            } else {
 23                dp[i][j] = max(dp[i-1][j], dp[i][j-1]);
 24            }
 25        }
 26    }
 27    
 28    return dp[m][n];
 29}
 30
 31// 编辑距离
 32int minDistance(const string& word1, const string& word2) {
 33    int m = word1.length(), n = word2.length();
 34    vector<vector<int>> dp(m + 1, vector<int>(n + 1, 0));
 35    
 36    for (int i = 0; i <= m; i++) dp[i][0] = i;
 37    for (int j = 0; j <= n; j++) dp[0][j] = j;
 38    
 39    for (int i = 1; i <= m; i++) {
 40        for (int j = 1; j <= n; j++) {
 41            if (word1[i-1] == word2[j-1]) {
 42                dp[i][j] = dp[i-1][j-1];
 43            } else {
 44                dp[i][j] = 1 + min({
 45                    dp[i-1][j],    // 删除
 46                    dp[i][j-1],    // 插入
 47                    dp[i-1][j-1]   // 替换
 48                });
 49            }
 50        }
 51    }
 52    
 53    return dp[m][n];
 54}
 55
 56// 最长回文子串
 57string longestPalindrome(const string& s) {
 58    int n = s.length();
 59    if (n == 0) return "";
 60    
 61    vector<vector<bool>> dp(n, vector<bool>(n, false));
 62    int start = 0, maxLen = 1;
 63    
 64    for (int i = 0; i < n; i++) {
 65        dp[i][i] = true;
 66    }
 67    
 68    for (int i = 0; i < n - 1; i++) {
 69        if (s[i] == s[i+1]) {
 70            dp[i][i+1] = true;
 71            start = i;
 72            maxLen = 2;
 73        }
 74    }
 75    
 76    for (int len = 3; len <= n; len++) {
 77        for (int i = 0; i <= n - len; i++) {
 78            int j = i + len - 1;
 79            if (s[i] == s[j] && dp[i+1][j-1]) {
 80                dp[i][j] = true;
 81                start = i;
 82                maxLen = len;
 83            }
 84        }
 85    }
 86    
 87    return s.substr(start, maxLen);
 88}
 89
 90// 买卖股票(一次)
 91int maxProfitOne(const vector<int>& prices) {
 92    int minPrice = INT_MAX;
 93    int maxProfit = 0;
 94    
 95    for (int price : prices) {
 96        minPrice = min(minPrice, price);
 97        maxProfit = max(maxProfit, price - minPrice);
 98    }
 99    
100    return maxProfit;
101}
102
103// 买卖股票(无限次)
104int maxProfitUnlimited(const vector<int>& prices) {
105    int profit = 0;
106    
107    for (int i = 1; i < prices.size(); i++) {
108        if (prices[i] > prices[i-1]) {
109            profit += prices[i] - prices[i-1];
110        }
111    }
112    
113    return profit;
114}
115
116// 买卖股票(k次)
117int maxProfitK(int k, const vector<int>& prices) {
118    if (prices.empty()) return 0;
119    
120    int n = prices.size();
121    if (k >= n / 2) {
122        return maxProfitUnlimited(prices);
123    }
124    
125    // dp[i][j][0/1] = 第i天,最多j次交易,不持有/持有
126    vector<vector<vector<int>>> dp(n, 
127        vector<vector<int>>(k + 1, vector<int>(2, 0)));
128    
129    for (int i = 0; i < n; i++) {
130        for (int j = k; j >= 1; j--) {
131            if (i == 0) {
132                dp[i][j][0] = 0;
133                dp[i][j][1] = -prices[i];
134            } else {
135                dp[i][j][0] = max(dp[i-1][j][0], dp[i-1][j][1] + prices[i]);
136                dp[i][j][1] = max(dp[i-1][j][1], dp[i-1][j-1][0] - prices[i]);
137            }
138        }
139    }
140    
141    return dp[n-1][k][0];
142}
143
144// 最长递增子序列(二分优化)
145int lengthOfLIS(const vector<int>& nums) {
146    if (nums.empty()) return 0;
147    
148    vector<int> tails;
149    
150    for (int num : nums) {
151        auto it = lower_bound(tails.begin(), tails.end(), num);
152        if (it == tails.end()) {
153            tails.push_back(num);
154        } else {
155            *it = num;
156        }
157    }
158    
159    return tails.size();
160}
161
162int main() {
163    cout << "=== DP进阶演示 ===" << endl << endl;
164    
165    // LCS
166    string text1 = "abcde", text2 = "ace";
167    cout << "最长公共子序列 '" << text1 << "' 和 '" << text2 << "':" << endl;
168    cout << "  长度: " << longestCommonSubsequence(text1, text2) << endl << endl;
169    
170    // 编辑距离
171    string word1 = "horse", word2 = "ros";
172    cout << "编辑距离 '" << word1 << "' -> '" << word2 << "':" << endl;
173    cout << "  最少操作: " << minDistance(word1, word2) << endl << endl;
174    
175    // 最长回文
176    string s = "babad";
177    cout << "最长回文子串 '" << s << "':" << endl;
178    cout << "  结果: " << longestPalindrome(s) << endl << endl;
179    
180    // 股票
181    vector<int> prices = {7,1,5,3,6,4};
182    cout << "股票问题 [7,1,5,3,6,4]:" << endl;
183    cout << "  买卖一次: " << maxProfitOne(prices) << endl;
184    cout << "  买卖多次: " << maxProfitUnlimited(prices) << endl;
185    cout << "  最多2次: " << maxProfitK(2, prices) << endl << endl;
186    
187    // LIS
188    vector<int> nums = {10, 9, 2, 5, 3, 7, 101, 18};
189    cout << "最长递增子序列 [10,9,2,5,3,7,101,18]:" << endl;
190    cout << "  长度: " << lengthOfLIS(nums) << endl;
191    
192    return 0;
193}