关于回溯。
一个问题两个字符串各选一个组成一个 2 个字符的字符串。最容易想到的是两个循环嵌套。
但是如果说要生成一个长为3、长为 4,或者长度是个参数无法提前知道,那么没法用循环嵌套来表达了。
这个子问题和原问题是很相似的。
这种可以分解的问题,可以使用递归来实现。回溯有一个增量构造答案的过程。
把边界条件和非边界条件的逻辑写对了,那递归代码一定能得到正确的结果。
回溯三问:
对于电话号码组合,用 path[i] 来记录路径上的字母
dfs(i) -> dfs(i+1)
子集型回溯
模板一
子集型回溯,看每个元素选或者不选。
比如计算 $ [1, 2] $ 所有的子集,每个元素都可以选或者不选,总共就有 4 种情况。
回溯三问,站在输入的角度思考
子问题?从下标 $ \ge i $ 的数字中构造子集
下一个子问题?从下标 $ \ge i + 1 $ 的数字中构造子集
#define MAX_RESULT 100000
int n; /* 数组元素个数,终止条件判断 */
int** result; /* 保存结果的数组 */
int result_len; /* 结果个数 */
int *column_size; /* 每个结果的长度 */
int path[10] = {0}; /* 单个结果 */
int path_len; /* 单个结果的长度 */
void save_result(void)
{
result[result_len] = malloc(path_len * sizeof(int));
column_size[result_len] = path_len;
for (int i=0; i<path_len; i++)
result[result_len][i] = path[i];
result_len++;
}
void dfs(int *nums, int i)
{
if ( i == n )
{
save_result();
return;
}
dfs(nums, i+1); /* 当前操作不选,直接进下一层递归 */
path[path_len] = nums[i]; /* 当前操作选, */
path_len++;
dfs(nums, i+1);
path_len--;
}
int** subsets(int* nums, int numsSize, int* returnSize, int** returnColumnSizes)
{
n = numsSize;
result = malloc(MAX_RESULT * sizeof(int *));
result_len = 0;
column_size = malloc(MAX_RESULT * sizeof(int));
path_len = 0;
dfs(nums, 0);
*returnSize = result_len;
*returnColumnSizes = column_size;
return result;
}
时间复杂度,每次递归,只有选或不选两种情况,$ O(2^n) $ 复制的时间是 O(n) 所以,总的复杂度是 $ O(n*2^n) $
模板二
另一个思路,站在答案的角度去考虑,枚举第一个选谁,第二个选谁,每个节点都是答案
当前操作?枚举一个下标 $ \ge i $ 的数字,加入 path
子问题?从下标 $ \ge i $ 的数字中构造子集
下一个子问题?从下标 $ \ge i + 1 $ 的数字中构造子集
#define MAX_RESULT 100000
int n; /* 数组元素个数,终止条件判断 */
int** result; /* 保存结果的数组 */
int result_len; /* 结果个数 */
int *column_size; /* 每个结果的长度 */
int path[10] = {0}; /* 单个结果 */
int path_len; /* 单个结果的长度 */
void save_result(void)
{
result[result_len] = malloc(path_len * sizeof(int));
column_size[result_len] = path_len;
for (int i=0; i<path_len; i++)
result[result_len][i] = path[i];
result_len++;
}
void dfs(int *nums, int i)
{
save_result();
if ( i == n ) return;
for (int j=i; j<n; j++)
{
path[path_len] = nums[j];
path_len++;
dfs(nums, j+1);
path_len--;
}
}
int** subsets(int* nums, int numsSize, int* returnSize, int** returnColumnSizes)
{
n = numsSize;
result = malloc(MAX_RESULT * sizeof(int *));
result_len = 0;
column_size = malloc(MAX_RESULT * sizeof(int));
path_len = 0;
dfs(nums, 0);
*returnSize = result_len;
*returnColumnSizes = column_size;
return result;
}
应用:分割回文串
组合型回溯
组合问题,在子集问题的基础上,增加一个判断逻辑,比如只要选两个数。
相比子集问题,组合问题是可以做一些额外的优化的。
设 path 长度为 m,总共要选 k 个数,那么还需要选 d = k - m 个数。剩下的可以选的数的个数小于 d,那么就不需要递归了,肯定选不出来。这是一个剪枝技巧。
int * single_out = NULL;
int single_out_len = 0;
int **result = NULL;
int result_len = 0;
void backtracking(int n, int k, int start)
{
//保存结果
if (single_out_len == k)
{
int *temp = malloc(k*sizeof(int));
for (int i=0; i<k; i++)
{
temp[i] = single_out[i];
}
result[result_len] = temp;
result_len++;
return; //忘记加了
}
for (int i=start; i<=n; i++)
{
single_out[single_out_len] = i;
single_out_len++;
backtracking(n, k, i+1);
single_out--;
}
}
int** combine(int n, int k, int* returnSize, int** returnColumnSizes) {
single_out = (int*)malloc(k*sizeof(int));
result = malloc(5000*sizeof(int));
result_len = 0; // 注释掉会出错,为什么?
backtracking(n, k, 1);
*returnSize = result_len;
int **column = malloc(result_len*sizeof(int));
for(int i=0; i<result_len; i++)
{
(*column)[i] = k;
}
return result;
}
排列型回溯
全排列,全排列的长度就是数组长度,全排列的个数就是元素个数的阶乘。
N 皇后
char ** single_out = NULL;
int single_Q_num = 0;
char ***result = NULL;
int result_len = 0;
void save_single_result(int n)
{
char **temp = (char**)malloc(sizeof(char) * (n + 1)*n);
for(int i = 0; i < (n + 1)*n; i++) {
**(temp + i) = ** (single_out + i);
}
result[result_len] = temp;
result_len++;
}
int is_ok(int x, int y, int n) {
char ** path = single_out;
int i, j;
//检查同一行以及同一列是否有效
for(i = 0; i < n; ++i) {
if(path[y][i] == 'Q' || path[i][x] == 'Q')
return 0;
}
//下面两个for循环检查斜角45度是否有效
i = y - 1;
j = x - 1;
while(i >= 0 && j >= 0) {
if(path[i][j] == 'Q')
return 0;
--i, --j;
}
i = y + 1;
j = x + 1;
while(i < n && j < n) {
if(path[i][j] == 'Q')
return 0;
++i, ++j;
}
//下面两个for循环检查135度是否有效
i = y - 1;
j = x + 1;
while(i >= 0 && j < n) {
if(path[i][j] == 'Q')
return 0;
--i, ++j;
}
i = y + 1;
j = x -1;
while(j >= 0 && i < n) {
if(path[i][j] == 'Q')
return 0;
++i, --j;
}
return 1;
}
void backtracking(int n, int j)
{
if (single_Q_num == n)
{
save_single_result(n);
return;
}
for (int i=0; i<n; i++)
{
if (is_ok(i, j, n) == 1)
{
single_out[j][i] = 'Q';
single_Q_num++;
backtracking(n, j + 1);
single_out[j][i] = '.';
single_Q_num--;
}
}
}
char*** solveNQueens(int n, int* returnSize, int** returnColumnSizes) {
single_Q_num = 0;
result_len = 0;
single_out = (char **)malloc(sizeof(char)*(n+1)*n);
result = (char ***)malloc(sizeof(char**)*100);
memset(single_out, '.', (n+1)*n);
for (int i=0;i<n;i++)
single_out[i][n] = '\0';
backtracking(n, 0);
*returnSize = result_len;
*returnColumnSizes = (int*)malloc(sizeof(int) * result_len);
for(int i; i < result_len; i++) {
(*returnColumnSizes)[i] = n;
}
return result;
}