回溯问题

  • 组合问题:N个数里面按一定规则找出k个数的集合

  • 切割问题:一个字符串按一定规则有几种切割方式

  • 子集问题:一个N个数的集合里有多少符合条件的子集

  • 排列问题:N个数按一定规则全排列,有几种排列方式

  • 棋盘问题:N皇后,解数独等等

关于回溯。

一个问题两个字符串各选一个组成一个 2 个字符的字符串。最容易想到的是两个循环嵌套。

但是如果说要生成一个长为3、长为 4,或者长度是个参数无法提前知道,那么没法用循环嵌套来表达了。

  • 原问题:构造长为 n 的字符串。

  • 枚举一个字母

  • 子问题:构造长为 n - 1 的字符串。

这个子问题和原问题是很相似的。

这种可以分解的问题,可以使用递归来实现。回溯有一个增量构造答案的过程。

把边界条件和非边界条件的逻辑写对了,那递归代码一定能得到正确的结果。

回溯三问:

  • 当前操作是什么?

  • 子问题是什么?

  • 下一个子问题?

对于电话号码组合,用 path[i] 来记录路径上的字母

  • 枚举 path[i] 处要填入的字母

  • 构造字符串 >= i 的部分

  • 构造字符串 >= i + 1 的部分

dfs(i) -> dfs(i+1)

子集型回溯

模板一

子集型回溯,看每个元素或者不选

比如计算 $ [1, 2] $ 所有的子集,每个元素都可以选或者不选,总共就有 4 种情况。

回溯三问,站在输入的角度思考

  • 当前操作?每个第 $ i $ 个数选或不选

  • 子问题?从下标 $ \ge i $ 的数字中构造子集

  • 下一个子问题?从下标 $ \ge i + 1 $ 的数字中构造子集

#define MAX_RESULT  100000

int n;              /* 数组元素个数,终止条件判断 */

int** result;       /* 保存结果的数组 */
int result_len;     /* 结果个数 */
int *column_size;   /* 每个结果的长度 */

int path[10] = {0}; /* 单个结果 */
int path_len;       /* 单个结果的长度 */

void save_result(void)
{
    result[result_len] = malloc(path_len * sizeof(int));
    column_size[result_len] = path_len;

    for (int i=0; i<path_len; i++)
        result[result_len][i] = path[i];

    result_len++;
}

void dfs(int *nums, int i)
{
    if ( i == n )
    {
        save_result();
        return;
    }

    dfs(nums, i+1);   /* 当前操作不选,直接进下一层递归 */

    path[path_len] = nums[i];   /* 当前操作选, */
    path_len++;
    dfs(nums, i+1);

    path_len--;
}

int** subsets(int* nums, int numsSize, int* returnSize, int** returnColumnSizes) 
{
    n           = numsSize;
    result      = malloc(MAX_RESULT * sizeof(int *));
    result_len  = 0;
    column_size = malloc(MAX_RESULT * sizeof(int));
    path_len    = 0;

    dfs(nums, 0);

    *returnSize = result_len;
    *returnColumnSizes = column_size;
    return result;
}

时间复杂度,每次递归,只有选或不选两种情况,$ O(2^n) $ 复制的时间是 O(n) 所以,总的复杂度是 $ O(n*2^n) $

模板二

另一个思路,站在答案的角度去考虑,枚举第一个选谁,第二个选谁,每个节点都是答案

  • 当前操作?枚举一个下标 $ \ge i $ 的数字,加入 path

  • 子问题?从下标 $ \ge i $ 的数字中构造子集

  • 下一个子问题?从下标 $ \ge i + 1 $ 的数字中构造子集

#define MAX_RESULT  100000

int n;              /* 数组元素个数,终止条件判断 */

int** result;       /* 保存结果的数组 */
int result_len;     /* 结果个数 */
int *column_size;   /* 每个结果的长度 */

int path[10] = {0}; /* 单个结果 */
int path_len;       /* 单个结果的长度 */

void save_result(void)
{
    result[result_len] = malloc(path_len * sizeof(int));
    column_size[result_len] = path_len;

    for (int i=0; i<path_len; i++)
        result[result_len][i] = path[i];

    result_len++;
}

void dfs(int *nums, int i)
{
    save_result();

    if ( i == n )   return;
        
    for (int j=i; j<n; j++)
    {
        path[path_len] = nums[j];
        path_len++;
        dfs(nums, j+1);
        path_len--;
    }
}

int** subsets(int* nums, int numsSize, int* returnSize, int** returnColumnSizes) 
{
    n           = numsSize;
    result      = malloc(MAX_RESULT * sizeof(int *));
    result_len  = 0;
    column_size = malloc(MAX_RESULT * sizeof(int));
    path_len    = 0;

    dfs(nums, 0);

    *returnSize = result_len;
    *returnColumnSizes = column_size;
    return result;
}

应用:分割回文串

组合型回溯

组合问题,在子集问题的基础上,增加一个判断逻辑,比如只要选两个数。

相比子集问题,组合问题是可以做一些额外的优化的。

设 path 长度为 m,总共要选 k 个数,那么还需要选 d = k - m 个数。剩下的可以选的数的个数小于 d,那么就不需要递归了,肯定选不出来。这是一个剪枝技巧。

int * single_out = NULL;
int single_out_len = 0;

int **result = NULL;
int result_len = 0;

void backtracking(int n, int k, int start)
{
    //保存结果
    if (single_out_len == k)
    {
        int *temp = malloc(k*sizeof(int));

        for (int i=0; i<k; i++)
        {
            temp[i] = single_out[i];
        }

        result[result_len] = temp;
        result_len++;

        return; //忘记加了
    }

    for (int i=start; i<=n; i++)
    {
        single_out[single_out_len] = i;
        single_out_len++;

        backtracking(n, k, i+1);

        single_out--;
    }
}

int** combine(int n, int k, int* returnSize, int** returnColumnSizes) {

    single_out = (int*)malloc(k*sizeof(int));
    result = malloc(5000*sizeof(int));

    result_len = 0; // 注释掉会出错,为什么?

    backtracking(n, k, 1);

    *returnSize = result_len;

    int **column = malloc(result_len*sizeof(int));

    for(int i=0; i<result_len; i++)
    {
        (*column)[i] = k;
    }

    return result;
}

排列型回溯

全排列,全排列的长度就是数组长度,全排列的个数就是元素个数的阶乘。

N 皇后

char ** single_out = NULL;
int single_Q_num = 0;

char ***result = NULL;
int result_len = 0;


void save_single_result(int n) 
{
    char **temp = (char**)malloc(sizeof(char) * (n + 1)*n);

    for(int i = 0; i < (n + 1)*n; i++) {
        **(temp + i) = ** (single_out + i);
    }

    result[result_len] = temp;
    result_len++;
}


int is_ok(int x, int y, int n) {

    char ** path =  single_out;

    int i, j;
    //检查同一行以及同一列是否有效
    for(i = 0; i < n; ++i) {
        if(path[y][i] == 'Q' || path[i][x] == 'Q')
            return 0;
    }
    //下面两个for循环检查斜角45度是否有效
    i = y - 1;
    j = x - 1;
    while(i >= 0 && j >= 0) {
        if(path[i][j] == 'Q')
            return 0;
        --i, --j;
    }

    i = y + 1;
    j = x + 1;
    while(i < n && j < n) {
        if(path[i][j] == 'Q')
            return 0;
        ++i, ++j;
    }

    //下面两个for循环检查135度是否有效
    i = y - 1;
    j = x + 1;
    while(i >= 0 && j < n) {
        if(path[i][j] == 'Q')
            return 0;
        --i, ++j;
    }

    i = y + 1;
    j = x -1;
    while(j >= 0 && i < n) {
        if(path[i][j] == 'Q')
            return 0;
        ++i, --j;
    }
    return 1;
}

void backtracking(int n, int j)
{
    if (single_Q_num == n)
    {
        save_single_result(n);

        return;
    }

    for (int i=0; i<n; i++)
    {
        if (is_ok(i, j, n) == 1)
        {
            single_out[j][i] = 'Q';
            single_Q_num++;

            backtracking(n, j + 1);

            single_out[j][i] = '.';
            single_Q_num--;

        }
    }
}


char*** solveNQueens(int n, int* returnSize, int** returnColumnSizes) {

    single_Q_num = 0;
    result_len = 0;

    single_out = (char **)malloc(sizeof(char)*(n+1)*n);

    result = (char ***)malloc(sizeof(char**)*100);

    memset(single_out, '.', (n+1)*n);

    for (int i=0;i<n;i++)
        single_out[i][n] = '\0';

    backtracking(n, 0);

    *returnSize = result_len;

    *returnColumnSizes = (int*)malloc(sizeof(int) * result_len);
    for(int i; i < result_len; i++) {
        (*returnColumnSizes)[i] = n;
    }

    return result;

}

最后更新于