Using numpy where on multidimensional array but need to control indexing

Jim Parker

I need to modify elements of a 3D array if they exceed some threshold value. The modification is based on related elements of another array. More concretely:

A_ijk = A_ijk,                     if A_ijk <= threshold
      = (B_(i-1)jk + B_ijk) / 2,   otherwise

numpy.where provides most of the functionality I need, but I don't know how to iterate over the first index without an explicit loop. The following code does what I want, but uses a loop. Is there a better way? Assume A and B have the same shape.

# Loop over the first axis; note that at i = 0, B[i - 1] refers to B[-1]
for i in range(A.shape[0]):
    A[i] = numpy.where(A[i] <= threshold, A[i], (B[i - 1] + B[i]) / 2)

To address the comments below: the first few rows of A are guaranteed to be below the threshold, so they are never replaced. This keeps the i - 1 index at i = 0 from wrapping around and pulling in the last entry of B.
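The wrap-around concern comes from Python's negative indexing. A minimal sketch, using a small made-up array B purely for illustration, shows that B[i - 1] at i = 0 silently selects the last slice instead of raising an error:

```python
import numpy as np

# Hypothetical small array, used only to illustrate the indexing behaviour.
B = np.arange(3 * 2 * 2).reshape(3, 2, 2)

i = 0
# B[i - 1] is B[-1], i.e. the last slice along the first axis, not an error.
wraps_to_last = np.array_equal(B[i - 1], B[B.shape[0] - 1])
print(wraps_to_last)
```

This is why the guarantee about the first rows matters: without it, the loop would average B[-1] and B[0] for any element of A[0] above the threshold.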

Andras Deak

You can vectorize your operation by using boolean indexing to replace the elements of A that are above the threshold. A little care has to be taken, since the auxiliary array corresponding to (B[i-1] + B[i])/2 is one element shorter than A along the first dimension, so we have to explicitly skip the first row of A (knowing that its elements are all below the threshold, as explained in the question):

import numpy as np

# some dummy data
A = np.random.rand(3,4,5)
B = np.random.rand(3,4,5)
threshold = 0.5
A[0,:] *= threshold # put the first dummy row below threshold
mask = A[1:] > threshold # to be overwritten, shape (2,4,5)

replace = (B[:-1] + B[1:])/2 # to overwrite elements in A from, shape (2,4,5)

# replace corresponding elements where `mask` is True
A[1:][mask] = replace[mask]
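
As a sanity check, the same idea can also be written with numpy.where on the sliced arrays and compared against the explicit loop from the question. This is only a sketch: the helper name replace_above_threshold is made up for illustration, and it assumes A and B have the same shape and that A[0] never exceeds the threshold.

```python
import numpy as np

def replace_above_threshold(A, B, threshold):
    """Return a copy of A with entries above threshold replaced.

    Hypothetical helper: assumes A and B have the same shape and that
    no element of A[0] exceeds the threshold, as stated in the question.
    """
    out = A.copy()
    averaged = (B[:-1] + B[1:]) / 2  # one slice shorter, aligned with A[1:]
    out[1:] = np.where(A[1:] <= threshold, A[1:], averaged)
    return out

# Dummy data, seeded so the check is reproducible
rng = np.random.default_rng(0)
A = rng.random((3, 4, 5))
B = rng.random((3, 4, 5))
threshold = 0.5
A[0] *= threshold  # keep the first slice below the threshold

# Reference result from the explicit loop in the question
expected = A.copy()
for i in range(A.shape[0]):
    expected[i] = np.where(A[i] <= threshold, A[i], (B[i - 1] + B[i]) / 2)

result = replace_above_threshold(A, B, threshold)
print(np.allclose(result, expected))
```

Unlike the in-place assignment above, this variant returns a new array and leaves A untouched, which can be convenient when the original values are still needed.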
