Як працює параметр class_weight в scikit-learn?


116

У мене виникають багато проблем з розумінням того, як функціонує class_weightпараметр логістичної регресії scikit-learn.

Ситуація

Я хочу використовувати логістичну регресію, щоб зробити бінарну класифікацію на дуже незбалансованому наборі даних. Класи мають позначення 0 (негативний) та 1 (позитивний), і спостережувані дані знаходяться у співвідношенні приблизно 19: 1, більшість зразків мають негативний результат.

Перша спроба: підготовка даних про навчання вручну

Я розділив отримані нами дані на непересічні набори для навчання та тестування (приблизно 80/20). Потім я випадковим чином відбирав дані тренувань вручну, щоб отримати дані тренувань у різних пропорціях, ніж 19: 1; від 2: 1 -> 16: 1.

Потім я тренував логістичну регресію на цих різних підмножинах даних тренувань і будував графічний відкликання (= TP / (TP + FN)) як функцію різних пропорцій тренувань. Звичайно, відкликання було обчислено на непересічних зразках TEST, які мали дотриману пропорцію 19: 1. Зауважте, хоча я тренував різні моделі за різними даними про навчання, я обчислював виклик для всіх їх на одних і тих же (непересічних) даних тесту.

Результати були такими, як і очікувалося: відкликання становило близько 60% при пропорціях тренувань 2: 1 і впало досить швидко до того моменту, коли він досяг 16: 1. Було кілька пропорцій 2: 1 -> 6: 1, коли відкликання було гідно вище 5%.

Друга спроба: пошук сітки

Далі я хотів перевірити різні параметри регуляризації, і тому я використав GridSearchCV і зробив сітку з кількох значень Cпараметра, а також class_weightпараметра. Щоб перевести мої n: m пропорції негативні: позитивні навчальні зразки на мову словника, class_weightя подумав, що я просто вкажу кілька словників так:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

і я також включав Noneі auto.

Цього разу результати були повністю пробиті. Усі мої відкликання виходили крихітними (<0,05) за кожне значення, за class_weightвинятком auto. Тож я можу лише припустити, що моє розуміння того, як встановити class_weightсловник, є неправильним. Цікаво, що class_weightзначення "автоматичного" в пошуку в сітці становило близько 59% для всіх значень C, і я здогадався, що він відповідає 1: 1?

Мої запитання

  1. Як ви правильно використовуєте class_weightдля досягнення різних балансів у навчальних даних від того, що ви насправді надаєте? Зокрема, до якого словника я переходжу, щоб class_weightвикористовувати n: m пропорції мінус: позитивні навчальні зразки?

  2. Якщо ви class_weightпередаєте різні словники до GridSearchCV, під час перехресної перевірки вона повторно врівноважить дані тренувальної складки відповідно до словника, але використовуватиме справжні задані пропорції пропорцій для обчислення моєї функції оцінювання на тестовій складці? Це критично важливо, оскільки будь-який показник корисний для мене, лише якщо він виходить із даних у спостережуваних пропорціях.

  3. Що означає autoзначення, class_weightщо стосується пропорцій? Я читаю документацію, і я припускаю, що "врівноважує дані, обернено пропорційні їх частоті", це означає, що це робить 1: 1. Це правильно? Якщо ні, то хтось може уточнити?


Коли користується class_weight, функція втрати змінюється. Наприклад, замість хрестової ентропії вона стає зваженою хрестовою ентропією. в напрямку доdatascience.com/…
prashanth

Відповіді:


123

По-перше, це може бути не добре, щоб просто піти на згадку в самоті. Ви можете просто досягти 100% відкликання, класифікуючи все як позитивний клас. Зазвичай я пропоную використовувати AUC для вибору параметрів, а потім знайти поріг для робочої точки (скажімо, заданого рівня точності), який вас цікавить.

Як class_weightпрацює: Це карає на помилки у зразках class[i]з class_weight[i]замість 1. Отже, більша вага класу означає, що ви хочете зробити більше уваги на класі. З того, що ви говорите, здається, що клас 0 в 19 разів частіше, ніж клас 1. Тому вам слід збільшити class_weightклас 1 щодо класу 0, скажімо, {0: .1, 1: .9}. Якщо class_weightзначення не дорівнює 1, це в основному змінить параметр регуляризації.

Як це class_weight="auto"працює, ви можете ознайомитись із цією дискусією . У версії dev можна використовувати class_weight="balanced", що простіше зрозуміти: це в основному означає реплікацію меншого класу, поки у вас не буде стільки зразків, скільки у більшому, але неявно.


1
Дякую! Швидке запитання: я згадав про відкликання для ясності, і насправді я намагаюся вирішити, який AUC використовувати як міру. Я розумію, що для пошуку параметрів я повинен бути або максимізувати площу під кривою ROC, або площу під кривою відкликання проти точності. Після вибору параметрів таким чином, я вважаю, що вибираю поріг класифікації, ковзаючи по кривій. Це ви мали на увазі? Якщо так, то яка з двох кривих має найбільш сенс дивитись, якщо моя мета - захопити якомога більше ТП? Також дякую за вашу роботу та внесок у scikit-learn !!!
kilgoretrout

1
Я думаю, що використання ROC було б більш стандартним шляхом, але я не думаю, що буде велика різниця. Для вибору точки на кривій вам потрібен певний критерій.
Андреас Мюллер

3
@MiNdFrEaK Я думаю, що Андрій означає, що оцінювач копіює зразки класу меншості, щоб вибірки різних класів були врівноваженими. Це просто пересимплінг неявно.
Шон ТІАН

8
@MiNdFrEaK та Shawn Tian: Класифікатори на основі SV не створюють більше зразків менших класів, коли ви використовуєте 'збалансований'. Це буквально карає помилки, допущені на менших класах. Якщо сказати інакше, це помилка і вводить в оману, особливо у великих наборах даних, коли ви не можете дозволити собі створення більше зразків. Цю відповідь необхідно відредагувати.
Пабло Рівас

4
scikit-learn.org/dev/glossary.html#term-class- вага Ваги класів будуть використовуватися по-різному в залежності від алгоритму: для лінійних моделей (таких як лінійна SVM або логістична регресія) ваги класів змінять функцію втрат на зважування втрат кожного зразка на його класну вагу. Для алгоритмів на основі дерев ваги класів будуть використані для зважування критерію розщеплення. Однак зауважте, що цей баланс не враховує вагу зразків у кожному класі.
Prashanth

2

Перша відповідь хороша для розуміння того, як це працює. Але я хотів зрозуміти, як я маю це використовувати на практиці.

ПІДСУМОК

  • для помірно незбалансованих даних БЕЗ шуму не велика різниця у застосуванні вагових класів
  • для помірно незбалансованих даних З шумом і сильно незбалансованих, краще застосовувати ваги класів
  • парам class_weight="balanced"працює пристойно за відсутності бажаючих оптимізувати вручну
  • з class_weight="balanced"вами захопити більше справжні події (вище ІСТИНА згадування) , але і ви, швидше за все , щоб отримати помилкові сигнали (знизити ИСТИНУ точності)
    • як результат, загальний% ІСТИНА може бути вище фактичного через усі помилкові позитиви
    • AUC може ввести вас в оману, якщо помилкові сигнали тривожать
  • не потрібно змінювати поріг рішення на дисбаланс%, навіть для сильного дисбалансу, добре, щоб утримувати 0,5 (або десь навколо цього, залежно від того, що вам потрібно)

NB

Результат може відрізнятися при використанні RF або GBM. sklearn не має class_weight="balanced" для GBM, але lightgbm маєLGBMClassifier(is_unbalance=False)

КОД

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.