Як проста модель логістичної регресії досягає 92% точності класифікації на MNIST?


68

Незважаючи на те, що всі зображення в наборі даних MNIST розташовані в центрі, з подібним масштабом і зверненими вгору без обертань, вони мають значну варіацію рукописного тексту, яка здивує мене тим, як лінійна модель досягає такої високої точності класифікації.

Наскільки я міг уявити, зважаючи на значну варіацію рукописного тексту, цифри повинні бути лінійно нероздільними в просторі розмірів 784, тобто має бути трохи складна (хоча і не дуже складна) нелінійна межа, яка розділяє різні цифри , подібно до добре цитованого прикладу коли позитивні та негативні класи не можна розділити жодним лінійним класифікатором. Мені здається, що бентежить, як багатокласна логістична регресія виробляє таку високу точність із абсолютно лінійними ознаками (без поліноміальних ознак).XOR

Наприклад, з урахуванням будь-якого пікселя на зображенні, різні рукописні зміни цифр та можуть зробити цей піксель підсвіченим чи ні. Таким чином, за допомогою набору вивчених ваг кожен піксель може робити вигляд цифри як так і . Лише за допомогою комбінації значень пікселів слід сказати, чи є цифра або . Це справедливо для більшості пар цифр. Отже, як логістична регресія, яка сліпо базує своє рішення незалежно від усіх значень пікселів (не враховуючи взагалі ніяких міжпіксельних залежностей), здатна досягти таких високих точностей.232323

Я знаю, що десь помиляюся або просто завищую варіацію зображень. Однак було б чудово, якби хтось міг допомогти мені з інтуїцією щодо того, як цифри «майже» лінійно відокремлюються.


Подивіться на підручник "Статистичне навчання з рідкістю: ласо та узагальнення" 3.3.1 Приклад: рукописні цифри web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Адріан

Мені було цікаво: наскільки добре щось робить на зразок пеніалізованої лінійної моделі (тобто glmnet)? Якщо я пригадую, те, про що ви повідомляєте, - це невизначена точність поза вибіркою.
Кліф АВ

Відповіді:


86

tl; dr Хоча це набір даних про класифікацію зображень, це залишається дуже легким завданням, для якого можна легко знайти пряме відображення від вхідних даних до прогнозів.


Відповідь:

Це дуже цікаве запитання, і завдяки простоті логістичної регресії ви справді можете дізнатися відповідь.

Логістична регресія полягає в тому, щоб кожне зображення прийняло введення і помножило їх на ваги, щоб створити його передбачення. Цікавим є те, що завдяки прямому відображенню між входом і виходом (тобто відсутність прихованого шару) значення кожної ваги відповідає тому, скільки кожного з входів враховується при обчисленні ймовірності кожного класу. Тепер, взявши ваги для кожного класу і переставивши їх на (тобто роздільну здатність зображення), ми можемо сказати, які пікселі є найбільш важливими для обчислення кожного класу .78478428×28

Знову зауважте, що це ваги .

Тепер подивіться на вищезазначене зображення і зосередиться на перших двох цифрах (тобто нуль і одна). Сині ваги означають, що інтенсивність цього пікселя багато сприяє для цього класу, а червоні значення означають, що він робить негативний внесок.

А тепер уявіть, як людина намалює ? Він малює круглу форму, яка порожня між ними. Саме так набирали ваги. Насправді, якщо хтось намалює середину зображення, це вважається негативно нулем. Тож для розпізнавання нулів вам не потрібні деякі складні фільтри та функції високого рівня. Ви можете просто переглянути місця намальованих пікселів і судити відповідно до цього.0

Те саме для . Вона завжди має пряму вертикальну лінію в середині зображення. Все інше рахується негативно.1

Решта цифр трохи складніші, але з невеликими уявами ви можете бачити , , і . Решта чисел трохи складніше, саме це фактично обмежує логістичну регресію від досягнення високих 90-х.2378

Завдяки цьому ви бачите, що логістична регресія має дуже хороший шанс отримати багато зображень правильно, і тому вона набирає такі високі результати.


Код для відтворення вищевказаної фігури трохи датований, але ось ви йдете:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

12
Дякую за ілюстрацію Ці зображення ваги дозволяють зрозуміти, наскільки точність настільки висока. Точне множення рукописного розрядного зображення з ваговим зображенням, що відповідає справжній етикетці зображення, "здається" найвищою порівняно з крапковим продуктом з іншими мітками ваги для більшості (все ще 92% для мене схожі на багато) зображень у програмі MNIST. Однак дивно, що і або і рідко підлягають неправильному класифікації один одного при вивченні матриці плутанини. У будь-якому випадку, це і є. Дані ніколи не брешуть. :)2378
Nitish Agarwal

13
Звичайно, це допомагає тим, що зразки MNIST центрируються, масштабуються та нормалізуються контрастно, перш ніж класифікатор їх коли-небудь побачить. Вам не доведеться вирішувати питання на кшталт "а що, якщо край нуля насправді проходить через середину поля?" тому що попередній процесор вже пройшов довгий шлях до того, щоб усі нулі виглядали однаково.
варення

1
@EricDuminil Я додав похвалу до сценарію з вашою пропозицією. Дякую за вклад! : D
Djib2011

1
@NitishAgarwal! Якщо ви вважаєте, що ця відповідь - це відповідь на ваше запитання, подумайте, як позначити її.
sintax

11
Для тих, хто цікавиться, але не особливо знайомий з цим видом обробки, ця відповідь дає фантастичний інтуїтивний приклад механіки.
chrylis
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.