04-递归与回溯

💡 核心结论

递归本质

  • 定义:函数调用自己,将大问题分解为小问题

  • 三要素:递归边界、递归规则、返回值

  • 关键:明确函数定义,相信递归,不要跳进递归

  • 代价:栈空间O(递归深度),可能栈溢出

  • 优化:记忆化、尾递归、改迭代

回溯本质

  • 定义:暴力搜索 + 剪枝,试探所有可能

  • 模板:选择→递归→撤销选择

  • 关键:路径、选择列表、结束条件

  • 剪枝:提前排除不可能的分支

  • 应用:全排列、组合、子集、N皇后

递归 vs 迭代

特性

递归

迭代

代码

简洁优雅

相对复杂

空间

O(递归深度)

O(1)

性能

函数调用开销

更快

适用

树、分治、回溯

简单循环

回溯模板(背下来)

result = []

def backtrack(路径, 选择列表):
    if 满足结束条件:
        result.add(路径)
        return
    
    for 选择 in 选择列表:
        做选择
        backtrack(路径, 新选择列表)
        撤销选择

🎯 经典递归问题

1. 阶乘

def factorial(n):
    if n <= 1:  # 递归边界
        return 1
    return n * factorial(n - 1)  # 递归规则

2. 斐波那契

def fib(n):
    if n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

3. 二叉树遍历

def preorder(root):
    if not root:
        return []
    return [root.val] + preorder(root.left) + preorder(root.right)

4. 汉诺塔

def hanoi(n, source, target, auxiliary):
    if n == 1:
        print(f"Move disk from {source} to {target}")
        return
    
    hanoi(n - 1, source, auxiliary, target)
    print(f"Move disk from {source} to {target}")
    hanoi(n - 1, auxiliary, target, source)

🔙 回溯算法

1. 全排列

def permute(nums):
    result = []
    
    def backtrack(path, choices):
        if len(path) == len(nums):
            result.append(path[:])
            return
        
        for i in range(len(choices)):
            # 做选择
            path.append(choices[i])
            # 递归
            backtrack(path, choices[:i] + choices[i+1:])
            # 撤销选择
            path.pop()
    
    backtrack([], nums)
    return result

2. 组合

def combine(n, k):
    result = []
    
    def backtrack(start, path):
        if len(path) == k:
            result.append(path[:])
            return
        
        for i in range(start, n + 1):
            path.append(i)
            backtrack(i + 1, path)
            path.pop()
    
    backtrack(1, [])
    return result

3. 子集

def subsets(nums):
    result = []
    
    def backtrack(start, path):
        result.append(path[:])  # 每个状态都是一个子集
        
        for i in range(start, len(nums)):
            path.append(nums[i])
            backtrack(i + 1, path)
            path.pop()
    
    backtrack(0, [])
    return result

4. N皇后

def solve_n_queens(n):
    result = []
    board = [['.'] * n for _ in range(n)]
    
    def is_valid(row, col):
        # 检查列
        for i in range(row):
            if board[i][col] == 'Q':
                return False
        
        # 检查左上对角线
        i, j = row - 1, col - 1
        while i >= 0 and j >= 0:
            if board[i][j] == 'Q':
                return False
            i -= 1
            j -= 1
        
        # 检查右上对角线
        i, j = row - 1, col + 1
        while i >= 0 and j < n:
            if board[i][j] == 'Q':
                return False
            i -= 1
            j += 1
        
        return True
    
    def backtrack(row):
        if row == n:
            result.append([''.join(row) for row in board])
            return
        
        for col in range(n):
            if is_valid(row, col):
                board[row][col] = 'Q'
                backtrack(row + 1)
                board[row][col] = '.'
    
    backtrack(0)
    return result

5. 括号生成

def generate_parenthesis(n):
    result = []
    
    def backtrack(path, left, right):
        if len(path) == 2 * n:
            result.append(path)
            return
        
        if left < n:
            backtrack(path + '(', left + 1, right)
        if right < left:
            backtrack(path + ')', left, right + 1)
    
    backtrack('', 0, 0)
    return result

🎯 剪枝优化

1. 提前返回

def backtrack(path):
    if 当前路径不可能产生解:
        return  # 剪枝
    
    if 找到解:
        result.append(path)
        return
    
    for choice in choices:
        backtrack(...)

2. 去重

def permute_unique(nums):
    nums.sort()  # 先排序
    result = []
    used = [False] * len(nums)
    
    def backtrack(path):
        if len(path) == len(nums):
            result.append(path[:])
            return
        
        for i in range(len(nums)):
            if used[i]:
                continue
            # 剪枝:跳过重复元素
            if i > 0 and nums[i] == nums[i-1] and not used[i-1]:
                continue
            
            used[i] = True
            path.append(nums[i])
            backtrack(path)
            path.pop()
            used[i] = False
    
    backtrack([])
    return result

📚 LeetCode练习

递归

回溯

💡 解题技巧

递归三问

  1. 递归函数的定义是什么?

  2. 递归的终止条件是什么?

  3. 递归如何缩小问题规模?

回溯三步

  1. 路径:已做的选择

  2. 选择列表:当前可以做的选择

  3. 结束条件:到达决策树底层

优化方向

  1. 剪枝:提前排除无效分支

  2. 去重:避免重复计算

  3. 记忆化:存储子问题结果

  4. 改DP:自底向上

💻 完整代码实现

Python 实现

  1"""
  2递归与回溯算法实现
  3"""
  4
  5# ========== 1. 全排列 ==========
  6
  7def permute(nums):
  8    """全排列"""
  9    result = []
 10    
 11    def backtrack(path, choices):
 12        if len(path) == len(nums):
 13            result.append(path[:])
 14            return
 15        
 16        for i in range(len(choices)):
 17            path.append(choices[i])
 18            backtrack(path, choices[:i] + choices[i+1:])
 19            path.pop()
 20    
 21    backtrack([], nums)
 22    return result
 23
 24
 25# ========== 2. 组合 ==========
 26
 27def combine(n, k):
 28    """从1到n选k个数的所有组合"""
 29    result = []
 30    
 31    def backtrack(start, path):
 32        if len(path) == k:
 33            result.append(path[:])
 34            return
 35        
 36        for i in range(start, n + 1):
 37            path.append(i)
 38            backtrack(i + 1, path)
 39            path.pop()
 40    
 41    backtrack(1, [])
 42    return result
 43
 44
 45# ========== 3. 子集 ==========
 46
 47def subsets(nums):
 48    """所有子集"""
 49    result = []
 50    
 51    def backtrack(start, path):
 52        result.append(path[:])
 53        
 54        for i in range(start, len(nums)):
 55            path.append(nums[i])
 56            backtrack(i + 1, path)
 57            path.pop()
 58    
 59    backtrack(0, [])
 60    return result
 61
 62
 63# ========== 4. N皇后 ==========
 64
 65def solve_n_queens(n):
 66    """N皇后问题"""
 67    result = []
 68    board = [['.'] * n for _ in range(n)]
 69    
 70    def is_valid(row, col):
 71        # 检查列
 72        for i in range(row):
 73            if board[i][col] == 'Q':
 74                return False
 75        
 76        # 检查左上对角线
 77        i, j = row - 1, col - 1
 78        while i >= 0 and j >= 0:
 79            if board[i][j] == 'Q':
 80                return False
 81            i -= 1
 82            j -= 1
 83        
 84        # 检查右上对角线
 85        i, j = row - 1, col + 1
 86        while i >= 0 and j < n:
 87            if board[i][j] == 'Q':
 88                return False
 89            i -= 1
 90            j += 1
 91        
 92        return True
 93    
 94    def backtrack(row):
 95        if row == n:
 96            result.append([''.join(row) for row in board])
 97            return
 98        
 99        for col in range(n):
100            if is_valid(row, col):
101                board[row][col] = 'Q'
102                backtrack(row + 1)
103                board[row][col] = '.'
104    
105    backtrack(0)
106    return result
107
108
109# ========== 5. 括号生成 ==========
110
111def generate_parenthesis(n):
112    """生成n对有效括号"""
113    result = []
114    
115    def backtrack(path, left, right):
116        if len(path) == 2 * n:
117            result.append(path)
118            return
119        
120        if left < n:
121            backtrack(path + '(', left + 1, right)
122        if right < left:
123            backtrack(path + ')', left, right + 1)
124    
125    backtrack('', 0, 0)
126    return result
127
128
129# ========== 6. 组合总和 ==========
130
131def combination_sum(candidates, target):
132    """数字可重复使用"""
133    result = []
134    
135    def backtrack(start, path, total):
136        if total == target:
137            result.append(path[:])
138            return
139        if total > target:
140            return
141        
142        for i in range(start, len(candidates)):
143            path.append(candidates[i])
144            backtrack(i, path, total + candidates[i])  # i不是i+1,可重复
145            path.pop()
146    
147    backtrack(0, [], 0)
148    return result
149
150
151# ========== 7. 单词搜索 ==========
152
153def exist(board, word):
154    """在网格中查找单词"""
155    m, n = len(board), len(board[0])
156    
157    def backtrack(i, j, k):
158        if k == len(word):
159            return True
160        
161        if i < 0 or i >= m or j < 0 or j >= n or board[i][j] != word[k]:
162            return False
163        
164        temp = board[i][j]
165        board[i][j] = '#'  # 标记已访问
166        
167        found = (backtrack(i+1, j, k+1) or
168                backtrack(i-1, j, k+1) or
169                backtrack(i, j+1, k+1) or
170                backtrack(i, j-1, k+1))
171        
172        board[i][j] = temp  # 恢复
173        return found
174    
175    for i in range(m):
176        for j in range(n):
177            if backtrack(i, j, 0):
178                return True
179    return False
180
181
182def demo():
183    """演示回溯算法"""
184    print("=== 回溯算法演示 ===\n")
185    
186    # 全排列
187    print("全排列 [1,2,3]:")
188    print(permute([1, 2, 3]))
189    print()
190    
191    # 组合
192    print("C(4,2) 组合:")
193    print(combine(4, 2))
194    print()
195    
196    # 子集
197    print("子集 [1,2,3]:")
198    print(subsets([1, 2, 3]))
199    print()
200    
201    # N皇后
202    print("4皇后问题:")
203    solutions = solve_n_queens(4)
204    print(f"共{len(solutions)}种解法")
205    for i, solution in enumerate(solutions):
206        print(f"\n解法{i+1}:")
207        for row in solution:
208            print(row)
209    print()
210    
211    # 括号生成
212    print("3对括号:")
213    print(generate_parenthesis(3))
214    print()
215    
216    # 组合总和
217    print("组合总和 candidates=[2,3,6,7], target=7:")
218    print(combination_sum([2, 3, 6, 7], 7))
219
220
221if __name__ == '__main__':
222    demo()
223

C++ 实现

  1/**
  2 * 递归与回溯算法实现
  3 */
  4
  5#include <iostream>
  6#include <vector>
  7#include <string>
  8
  9using namespace std;
 10
 11// ========== 全排列 ==========
 12
 13void permuteHelper(vector<int>& nums, vector<bool>& used, 
 14                   vector<int>& path, vector<vector<int>>& result) {
 15    if (path.size() == nums.size()) {
 16        result.push_back(path);
 17        return;
 18    }
 19    
 20    for (int i = 0; i < nums.size(); i++) {
 21        if (used[i]) continue;
 22        
 23        used[i] = true;
 24        path.push_back(nums[i]);
 25        permuteHelper(nums, used, path, result);
 26        path.pop_back();
 27        used[i] = false;
 28    }
 29}
 30
 31vector<vector<int>> permute(vector<int>& nums) {
 32    vector<vector<int>> result;
 33    vector<int> path;
 34    vector<bool> used(nums.size(), false);
 35    permuteHelper(nums, used, path, result);
 36    return result;
 37}
 38
 39// ========== 组合 ==========
 40
 41void combineHelper(int n, int k, int start, vector<int>& path, 
 42                   vector<vector<int>>& result) {
 43    if (path.size() == k) {
 44        result.push_back(path);
 45        return;
 46    }
 47    
 48    for (int i = start; i <= n; i++) {
 49        path.push_back(i);
 50        combineHelper(n, k, i + 1, path, result);
 51        path.pop_back();
 52    }
 53}
 54
 55vector<vector<int>> combine(int n, int k) {
 56    vector<vector<int>> result;
 57    vector<int> path;
 58    combineHelper(n, k, 1, path, result);
 59    return result;
 60}
 61
 62// ========== 子集 ==========
 63
 64void subsetsHelper(vector<int>& nums, int start, vector<int>& path,
 65                   vector<vector<int>>& result) {
 66    result.push_back(path);
 67    
 68    for (int i = start; i < nums.size(); i++) {
 69        path.push_back(nums[i]);
 70        subsetsHelper(nums, i + 1, path, result);
 71        path.pop_back();
 72    }
 73}
 74
 75vector<vector<int>> subsets(vector<int>& nums) {
 76    vector<vector<int>> result;
 77    vector<int> path;
 78    subsetsHelper(nums, 0, path, result);
 79    return result;
 80}
 81
 82// ========== N皇后 ==========
 83
 84bool isValid(const vector<string>& board, int row, int col, int n) {
 85    // 检查列
 86    for (int i = 0; i < row; i++) {
 87        if (board[i][col] == 'Q') return false;
 88    }
 89    
 90    // 检查左上对角线
 91    for (int i = row - 1, j = col - 1; i >= 0 && j >= 0; i--, j--) {
 92        if (board[i][j] == 'Q') return false;
 93    }
 94    
 95    // 检查右上对角线
 96    for (int i = row - 1, j = col + 1; i >= 0 && j < n; i--, j++) {
 97        if (board[i][j] == 'Q') return false;
 98    }
 99    
100    return true;
101}
102
103void solveNQueensHelper(int n, int row, vector<string>& board,
104                        vector<vector<string>>& result) {
105    if (row == n) {
106        result.push_back(board);
107        return;
108    }
109    
110    for (int col = 0; col < n; col++) {
111        if (isValid(board, row, col, n)) {
112            board[row][col] = 'Q';
113            solveNQueensHelper(n, row + 1, board, result);
114            board[row][col] = '.';
115        }
116    }
117}
118
119vector<vector<string>> solveNQueens(int n) {
120    vector<vector<string>> result;
121    vector<string> board(n, string(n, '.'));
122    solveNQueensHelper(n, 0, board, result);
123    return result;
124}
125
126// ========== 括号生成 ==========
127
128void generateParenthesisHelper(int n, int left, int right,
129                               string& path, vector<string>& result) {
130    if (path.length() == 2 * n) {
131        result.push_back(path);
132        return;
133    }
134    
135    if (left < n) {
136        path.push_back('(');
137        generateParenthesisHelper(n, left + 1, right, path, result);
138        path.pop_back();
139    }
140    
141    if (right < left) {
142        path.push_back(')');
143        generateParenthesisHelper(n, left, right + 1, path, result);
144        path.pop_back();
145    }
146}
147
148vector<string> generateParenthesis(int n) {
149    vector<string> result;
150    string path;
151    generateParenthesisHelper(n, 0, 0, path, result);
152    return result;
153}
154
155// 打印结果
156template<typename T>
157void printVector(const vector<T>& vec) {
158    cout << "[";
159    for (size_t i = 0; i < vec.size(); i++) {
160        cout << vec[i];
161        if (i < vec.size() - 1) cout << ", ";
162    }
163    cout << "]";
164}
165
166int main() {
167    cout << "=== 回溯算法演示 ===" << endl << endl;
168    
169    // 全排列
170    vector<int> nums = {1, 2, 3};
171    cout << "全排列 [1,2,3]:" << endl;
172    vector<vector<int>> perms = permute(nums);
173    for (const auto& perm : perms) {
174        cout << "  ";
175        printVector(perm);
176        cout << endl;
177    }
178    cout << endl;
179    
180    // 组合
181    cout << "C(4,2) 组合:" << endl;
182    vector<vector<int>> combs = combine(4, 2);
183    for (const auto& comb : combs) {
184        cout << "  ";
185        printVector(comb);
186        cout << endl;
187    }
188    cout << endl;
189    
190    // 子集
191    vector<int> nums2 = {1, 2, 3};
192    cout << "子集 [1,2,3]:" << endl;
193    vector<vector<int>> subs = subsets(nums2);
194    cout << "  共" << subs.size() << "个子集" << endl << endl;
195    
196    // N皇后
197    cout << "4皇后问题:" << endl;
198    vector<vector<string>> queens = solveNQueens(4);
199    cout << "  共" << queens.size() << "种解法" << endl << endl;
200    
201    // 括号生成
202    cout << "3对括号:" << endl;
203    vector<string> parens = generateParenthesis(3);
204    for (const auto& p : parens) {
205        cout << "  " << p << endl;
206    }
207    
208    return 0;
209}