У мене 1D масив у numpy, і я хочу знайти позицію індексу, де значення перевищує значення в numpy масиві.
Напр
aa = range(-10,10)
Знайдіть позицію, aa
де значення 5
перевищується.
У мене 1D масив у numpy, і я хочу знайти позицію індексу, де значення перевищує значення в numpy масиві.
Напр
aa = range(-10,10)
Знайдіть позицію, aa
де значення 5
перевищується.
Відповіді:
Це трохи швидше (і виглядає приємніше)
np.argmax(aa>5)
Оскільки argmax
зупиняється на першому True
("У разі декількох зустрічей максимальних значень, індекси, що відповідають першому входженню, повертаються.") І не зберігають інший список.
In [2]: N = 10000
In [3]: aa = np.arange(-N,N)
In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop
In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop
In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop
argmax
схоже, не зупиняється спочатку True
. (Це можна перевірити, створивши булеві масиви з одинарним True
в різних положеннях.) Швидкість, ймовірно, пояснюється тим, що argmax
не потрібно створювати вихідний список.
argmax
.
aa
сортується, як у відповіді @ Майкла).
argmax
на 10-мільйонних булевих масивах з одним синглом True
в різних позиціях, використовуючи NumPy 1.11.2, і положення True
важливого. Отже, 1.11.2, argmax
здається, "коротке замикання" на булевих масивах.
Враховуючи відсортований вміст масиву, існує ще більш швидкий метод: пошук за посиланням .
import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]
# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop
+1
допомогоюnp.searchsorted(..., side='right')
side
аргумент має значення лише в тому випадку, якщо в відсортованому масиві є повторні значення. Це не змінює значення повернутого індексу, який завжди є індексом, на який можна вставити значення запиту, зміщуючи всі наступні записи праворуч і підтримуючи відсортований масив.
side
має ефект, коли однакове значення є і в відсортованому, і в вставленому масиві, незалежно від повторних значень в будь-якому. Повторні значення в відсортованому масиві просто перебільшують ефект (різниця між сторонами - це кількість разів, коли значення, яке вставляється, з'являється в відсортованому масиві). side
це змінити значення, що повертається індексу, хоча він не змінює результуючий масив з вставки значення в відсортований масив на цих індексах. Тонка, але важлива відмінність; насправді ця відповідь дає неправильний показник, якщо N/2
його немає aa
.
N/2
її немає aa
. Правильна форма була б np.searchsorted(aa, N/2, side='right')
(без +1
). Обидві форми дають однаковий показник інакше. Розглянемо тестовий випадок N
як непарний (і N/2.0
примусити плавати, якщо використовується python 2).
Мене це також зацікавило і я порівняв усі запропоновані відповіді з perfplot . (Відмова: Я автор перфплоту.)
Якщо ви знаєте, що масив, який ви шукаєте, вже відсортований , то
numpy.searchsorted(a, alpha)
для вас. Це операція постійного часу, тобто швидкість не залежить від розміру масиву. Ви не можете отримати швидше, ніж це.
Якщо ви нічого не знаєте про свій масив, ви не помилитесь
numpy.argmax(a > alpha)
Вже відсортовано:
Несортовано:
Код для відтворення сюжету:
import numpy
import perfplot
alpha = 0.5
def argmax(data):
return numpy.argmax(data > alpha)
def where(data):
return numpy.where(data > alpha)[0][0]
def nonzero(data):
return numpy.nonzero(data > alpha)[0][0]
def searchsorted(data):
return numpy.searchsorted(data, alpha)
out = perfplot.show(
# setup=numpy.random.rand,
setup=lambda n: numpy.sort(numpy.random.rand(n)),
kernels=[
argmax, where,
nonzero,
searchsorted
],
n_range=[2**k for k in range(2, 20)],
logx=True,
logy=True,
xlabel='len(array)'
)
np.searchsorted
не є постійним часом. Це насправді O(log(n))
. Але ваш тестовий випадок насправді відміряє найкращий випадок searchsorted
(який є O(1)
).
searchsorted
(або будь-який алгоритм) може перемогти O(log(n))
двійковий пошук відсортованих рівномірно розподілених даних. EDIT: searchsorted
це двійковий пошук.
У випадку з range
будь-яким іншим лінійно зростаючим масивом ви можете просто обчислити індекс програмно, взагалі не потрібно ітераціювати масив:
def first_index_calculate_range_like(val, arr):
if len(arr) == 0:
raise ValueError('no value greater than {}'.format(val))
elif len(arr) == 1:
if arr[0] > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
first_value = arr[0]
step = arr[1] - first_value
# For linearly decreasing arrays or constant arrays we only need to check
# the first element, because if that does not satisfy the condition
# no other element will.
if step <= 0:
if first_value > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
calculated_position = (val - first_value) / step
if calculated_position < 0:
return 0
elif calculated_position > len(arr) - 1:
raise ValueError('no value greater than {}'.format(val))
return int(calculated_position) + 1
Можливо, можна було б трохи покращити це. Я переконався, що він працює правильно для декількох вибіркових масивів і значень, але це не означає, що там не могло бути помилок, особливо враховуючи, що він використовує floats ...
>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16] # double check
6
>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15
Враховуючи, що він може обчислити позицію без будь-якої ітерації, це буде постійний час ( O(1)
) і, ймовірно, може обіграти всі інші згадані підходи. Однак він вимагає постійного кроку в масиві, інакше він дасть неправильні результати.
Більш загальним підходом було б використання функції numba:
@nb.njit
def first_index_numba(val, arr):
for idx in range(len(arr)):
if arr[idx] > val:
return idx
return -1
Це буде працювати для будь-якого масиву, але він повинен перебирати масив, тому в середньому випадку це буде O(n)
:
>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16
Навіть незважаючи на те, що Ніко Шльомер вже надав деякі орієнтири, я вважав, що може бути корисним включити мої нові рішення та протестувати на різні "значення".
Тестова установка:
import numpy as np
import math
import numba as nb
def first_index_using_argmax(val, arr):
return np.argmax(arr > val)
def first_index_using_where(val, arr):
return np.where(arr > val)[0][0]
def first_index_using_nonzero(val, arr):
return np.nonzero(arr > val)[0][0]
def first_index_using_searchsorted(val, arr):
return np.searchsorted(arr, val) + 1
def first_index_using_min(val, arr):
return np.min(np.where(arr > val))
def first_index_calculate_range_like(val, arr):
if len(arr) == 0:
raise ValueError('empty array')
elif len(arr) == 1:
if arr[0] > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
first_value = arr[0]
step = arr[1] - first_value
if step <= 0:
if first_value > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
calculated_position = (val - first_value) / step
if calculated_position < 0:
return 0
elif calculated_position > len(arr) - 1:
raise ValueError('no value greater than {}'.format(val))
return int(calculated_position) + 1
@nb.njit
def first_index_numba(val, arr):
for idx in range(len(arr)):
if arr[idx] > val:
return idx
return -1
funcs = [
first_index_using_argmax,
first_index_using_min,
first_index_using_nonzero,
first_index_calculate_range_like,
first_index_numba,
first_index_using_searchsorted,
first_index_using_where
]
from simple_benchmark import benchmark, MultiArgument
і сюжети були створені за допомогою:
%matplotlib notebook
b.plot()
b = benchmark(
funcs,
{2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")
Функція numba найкраще виконує функцію обчислення та функцію, на яку здійснюється пошук. Інші рішення діють набагато гірше.
b = benchmark(
funcs,
{2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")
Для малих масивів функція numba виконує надзвичайно швидко, однак для більших масивів вона перевершує функцію обчислення та функцію, на яку здійснюється пошук.
b = benchmark(
funcs,
{2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")
Це цікавіше. Знову numba і обчислювальна функція чудово справляються, однак це насправді викликає найгірший випадок пошуку, який насправді не працює в цьому випадку.
Ще один цікавий момент - як вони поводяться, якщо немає значення, індекс якого слід повернути:
arr = np.ones(100)
value = 2
for func in funcs:
print(func.__name__)
try:
print('-->', func(value, arr))
except Exception as e:
print('-->', e)
З цим результатом:
first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0
Пошукові, аргументовані та numba просто повертають неправильне значення. Тим НЕ менше , searchsorted
і numba
повертає індекс , який не є допустимим індексом для масиву.
Функції where
, min
, nonzero
і calculate
кинути виняток. Однак лише виняток calculate
насправді говорить про щось корисне.
Це означає, що потрібно фактично перетворити ці виклики у відповідну функцію обгортки, яка фіксує винятки або недійсні значення повернення та обробляє належним чином, принаймні, якщо ви не впевнені, чи може це значення знаходитись у масиві.
Примітка. Обчислення та searchsorted
параметри працюють лише в особливих умовах. Функція "обчислити" вимагає постійного кроку, а пошуковий запит вимагає сортування масиву. Таким чином, вони можуть бути корисними за правильних обставин, але не є загальним рішенням цієї проблеми. У випадку, якщо ви маєте справу з відсортованими списками Python, вам варто поглянути на бісект- модуль замість того, щоб використовувати пошукові запити Numpys.