在 MATLAB 中矢量化包含 circshift 的多个嵌套循环

drummy_1

我想提高包含多个嵌套循环的操作的性能。该操作应该给我 4 个矩阵的所有可能的总和组合。目前,我为此使用 circshift,如下所示:

for ss = 1:max_n_steps+1 % max steps + possibility of NO STEP!!
    for tt = 1:max_n_steps+1
        for uu = 1:max_n_steps+1
            curr_perm = horzcat(curr_perm, curr_first + circshift(curr_second, ss-1, 2) + circshift(curr_third, tt-1, 2) + circshift(curr_fourth, uu-1, 2));
        end
    end
end

这会遍历所有四个矩阵(curr_first、curr_second 等),将它们相加,并将它们与之前的组合连接起来。它导致所有可能的列总和的令人满意的数组。然而,我引入的循环越多(例如,当我有 5 个矩阵而不是 4 个矩阵时),这需要越来越多的时间。是否有可能对这个过程进行矢量化?我看到使用 bsxfun 之类的东西,但不知道如何在此处应用它们。

谢谢!

将要

起点的主要低效不是它使用循环而不是向量运算,而是它circshift多次重新计算对相同矩阵和相同偏移量的调用例如,您的表达式circshift(curr_second, ss-1, 2)仅在更改时ss更改,但在内部循环的每次迭代中进行评估:

for tt = 1:max_n_steps+1
    for uu = 1:max_n_steps+1

即使curr_second并且ss在整个循环中都不会改变。

您可以通过评估每个矩阵的所有可能的移位,然后通过索引这些结果来添加这些移位的排列来避免这种情况:

curr_second_shifts = nan([size(curr_second) max_n_steps+1]);
curr_third_shifts  = nan([size(curr_third ) max_n_steps+1]);
curr_fourth_shifts = nan([size(curr_fourth) max_n_steps+1]);

for ii = 1:max_n_steps+1
    curr_second_shifts(:,:,ii) = circshift(curr_second, ii-1, 2);
    curr_third_shifts( :,:,ii) = circshift(curr_third,  ii-1, 2);
    curr_fourth_shifts(:,:,ii) = circshift(curr_fourth, ii-1, 2);
end

for ss = 1:max_n_steps+1
    for tt = 1:max_n_steps+1
        for uu = 1:max_n_steps+1
            curr_perm = horzcat(curr_perm, curr_first + ...
                                           curr_second_shifts(:,:,ss) + ...
                                           curr_third_shifts( :,:,tt) + ...
                                           curr_fourth_shifts(:,:,uu));
         end
    end
end

为了完全矢量化,您可以存储每个矩阵沿不同维度的不同移位;MATLAB 的单例扩展将在添加时用相关的排列填充这个 n 维数组。

curr_second_shifts = nan([size(curr_second) max_n_steps+1]);
curr_third_shifts  = nan([size(curr_third ) 1 max_n_steps+1]);
curr_fourth_shifts = nan([size(curr_fourth) 1 1 max_n_steps+1]);

for ii = 1:max_n_steps+1
    curr_second_shifts(:,:,ii)     = circshift(curr_second, ii-1, 2);
    curr_third_shifts( :,:,1,ii)   = circshift(curr_third,  ii-1, 2);
    curr_fourth_shifts(:,:,1,1,ii) = circshift(curr_fourth, ii-1, 2);
end

curr_perm = curr_first  + ...
            curr_second_shifts + ...
            curr_third_shifts  + ...
            curr_fourth_shifts);

然后要将排列在更高维度上的所有排列重新放入原始列系列中,只需重新索引 2 是最高维度:

curr_perm = curr_perm(:,:);

然而,这可以进一步矢量化以消除每个矩阵的单独变量,这对于您必须添加到这些排列中的维度越多越重要。让我们定义curr一个 3D 矩阵,使得:

curr = cat(3,curr_first,curr_second,curr_third,curr_fourth);

然后可以通过对circshift每个偏移量的一次调用来计算所有偏移:

curr_shift = nan([size(curr(:,:,2:end)) max_n_steps+1]);

for ii = 1:max_n_steps+1
    curr_shift(:,:,:,ii) = circshift(curr(:,:,2:end), ii-1, 2);
end

然后在添加之前使用越来越多的单例维度进行置换:

curr_perm = curr_first;

for ii = 1:size(curr_shift,3)
    curr_perm = curr_perm + permute(curr_shift(:,:,ii,:),[1:3 5:(ii+3) 4]);
end

curr_perm = curr_perm(:,:);

对于早于 R2016b 的 MATLAB,扩展不会隐式发生,您将需要bsxfun按照您的建议:

for ii = 1:size(curr_shift,3)
    curr_perm = bsxfun(@plus, curr_perm, permute(curr_shift(:,:,ii,:),[1:3 5:(ii+3) 4]);
end

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章