Python Forum
Extracting rows based on condition on one column

Forum: Python Coding > Data Science (https://python-forum.io/forum-44.html), thread: https://python-forum.io/thread-28858.html



Extracting rows based on condition on one column - Robotguy - Aug-06-2020

Hi Everyone,

Here is an interesting problem I am trying to solve.

I have an N x 4 array and want to extract the rows whose third-column element lies in a certain range. Is there existing functionality for this in NumPy? Below is a simple example.

PS: I know I could use a for loop that compares each element of column 3 and saves the rows that meet the condition. But I want to use NumPy here (slicing etc., which is promisingly fast). In reality, the arrays I use are large, and an extra Python loop would cost a lot of computation time.

For example:
input = [[1,2,-97,4],
         [5,6,93,8],
         [9,10,-105,12],
         [11,12,105,14]]

output = [[1,2,-97,4], # desired output: rows whose third-column element is greater than -100 and less than 100
          [5,6,93,8]]



RE: Extracting rows based on condition on one column - scidam - Aug-07-2020

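import numpy as np

input = [[1,2,-97,4],
         [5,6,93,8],
         [9,10,-105,12],
         [11,12,105,14]]
input = np.array(input)

# Boolean mask: True where the third column (index 2) lies strictly between -100 and 100
mask = (-100 < input[:, 2]) & (input[:, 2] < 100)
print(input[mask])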
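The one-liner works in three steps: comparing a column with a scalar gives a boolean array, `&` combines two boolean arrays element-wise, and indexing with a boolean array keeps only the rows where it is True. The parentheses are required because `&` binds more tightly than `<` in Python. A minimal sketch spelling this out; the helper name `rows_in_range` and its parameters are hypothetical, chosen for illustration, and not part of NumPy:

```python
import numpy as np


def rows_in_range(arr, col, low, high):
    """Return the rows of a 2-D array whose value in column `col` satisfies low < value < high.

    Hypothetical helper for illustration; the bounds are exclusive, as in the thread.
    """
    column = arr[:, col]                     # 1-D view of the chosen column
    mask = (low < column) & (column < high)  # boolean array, one entry per row
    return arr[mask]                         # boolean indexing returns a copy of the matching rows


data = np.array([[1, 2, -97, 4],
                 [5, 6, 93, 8],
                 [9, 10, -105, 12],
                 [11, 12, 105, 14]])

mask = (-100 < data[:, 2]) & (data[:, 2] < 100)
print(mask)                              # [ True  True False False]
print(rows_in_range(data, 2, -100, 100))
# [[  1   2 -97   4]
#  [  5   6  93   8]]
```

`np.logical_and(low < column, column < high)` is an equivalent way to build the same mask.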



RE: Extracting rows based on condition on one column - Robotguy - Aug-07-2020

Thanks, it worked!

Out of curiosity, here is a little test I did comparing execution times. In my run, the NumPy method was about 75x faster than the loop. Do you know what makes NumPy fast? Does it store the array in some efficient manner, or is it something else?

import time
import numpy as np

input = np.arange(4*10**7).reshape((10**7, 4))

# First method: using NumPy (boolean mask, evaluated in compiled code)
start_time = time.perf_counter()
result = input[(-1000 < input[:, 2]) & (input[:, 2] < 10000)]
print(time.perf_counter() - start_time)  # time only the filtering, not the printing
print(result)

# Second method: plain Python loop over the rows
start_time = time.perf_counter()
diff = []
for row in range(input.shape[0]):
    if -1000 < input[row, 2] < 10000:
        diff.append(input[row, :])
print(time.perf_counter() - start_time)
print(diff)
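On the question of why NumPy is faster: a NumPy array keeps its elements in one contiguous buffer of fixed-size machine values (all of one dtype), and operations such as the comparisons, `&`, and boolean indexing above run as loops in compiled code over that buffer. The Python loop instead performs one interpreted iteration per row, and every `arr[row, 2]` creates a new NumPy scalar object. The sketch below inspects the array's storage and times both approaches with the standard-library `timeit` module. The array size `N`, the `number`/`repeat` settings, and the function names are arbitrary choices for illustration, and the measured ratio will depend on the machine.

```python
import timeit

import numpy as np

N = 10**5  # arbitrary; smaller than the thread's 10**7 so the loop finishes quickly
arr = np.arange(4 * N).reshape((N, 4))

# How the data is stored: one contiguous buffer of fixed-size integers
print(arr.dtype, arr.itemsize)    # element type and bytes per element (platform dependent)
print(arr.strides)                # bytes to step to the next row / next column
print(arr.flags['C_CONTIGUOUS'])  # True: rows are laid out back to back in memory


def numpy_filter():
    # Comparisons, & and boolean indexing all run as compiled loops over the buffer
    return arr[(-1000 < arr[:, 2]) & (arr[:, 2] < 10000)]


def loop_filter():
    # One interpreted iteration per row; each arr[row, 2] creates a NumPy scalar object
    return [arr[row, :] for row in range(arr.shape[0]) if -1000 < arr[row, 2] < 10000]


# Both approaches select the same rows
assert np.array_equal(numpy_filter(), np.array(loop_filter()))

# Take the best of a few repeats to reduce noise from other processes
t_numpy = min(timeit.repeat(numpy_filter, number=5, repeat=3))
t_loop = min(timeit.repeat(loop_filter, number=5, repeat=3))
print(f"NumPy: {t_numpy:.4f} s, loop: {t_loop:.4f} s, ratio: {t_loop / t_numpy:.0f}x")
```

Timing outside of any `print` calls and taking the minimum over repeats gives a fairer comparison than a single `time.time()` measurement; the exact ratio (about 75x in the run reported above) will vary with array size, dtype, and hardware.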