Увага - це метод агрегації набору векторів лише в один вектор, часто через вектор . Зазвичай є або входами до моделі, або прихованими станами попередніх часових кроків, або прихованими станами на один рівень вниз (у випадку складених LSTM).viuvi
Результат часто називають контекстним вектором , оскільки він містить контекст, відповідний поточному етапу часу.c
Цей додатковий векторний контекст також подається в RNN / LSTM (він може бути просто з'єднаний з вихідним входом). Тому контекст може бути використаний, щоб допомогти з передбаченням.c
Найпростіший спосіб зробити це - обчислити вектор вірогідності і де - конкатенація всіх попередніх . Загальним вектором пошуку є поточний прихований стан .p=softmax(VTu)c=∑ipiviVviuht
Існує багато варіацій щодо цього, і ви можете зробити такі речі складними, як хочете. Наприклад, замість того, щоб використовувати як логити, можна вибрати замість , де - довільна нейронна мережа.vTiuf(vi,u)f
Загальний механізм уваги для моделей послідовності до послідовності використовує , де - приховані стани кодера, а - поточний прихований стан декодера. і обидва s - параметри.p=softmax(qTtanh(W1vi+W2ht))vhtqW
Деякі документи, які демонструють різні варіанти ідеї уваги:
Мережі вказівників використовують увагу на опорні входи для вирішення проблем комбінаторної оптимізації.
Поточні мережеві сутності підтримують окремі стани пам'яті для різних об'єктів (людей / об'єктів) під час читання тексту та оновлюють правильний стан пам'яті, використовуючи увагу.
Моделі трансформаторів також широко використовують увагу. Формулювання їх уваги є дещо більш загальним, а також включає ключові вектори : ваги уваги фактично обчислюються між клавішами та пошуком, а контекст потім будується за допомогою .kipvi
Ось швидка реалізація однієї форми уваги, хоча я не можу гарантувати правильність поза тим, що вона пройшла кілька простих тестів.
Основна RNN:
def rnn(inputs_split):
bias = tf.get_variable('bias', shape = [hidden_dim, 1])
weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])
hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
for i, input in enumerate(inputs_split):
input = tf.reshape(input, (batch, in_dim, 1))
last_state = hidden_states[-1]
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
hidden_states.append(hidden)
return hidden_states[-1]
З увагою ми додаємо лише кілька рядків до обчислення нового прихованого стану:
if len(hidden_states) > 1:
logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
probs = tf.nn.softmax(logits)
probs = tf.reshape(probs, (batch, -1, 1, 1))
context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
else:
context = tf.zeros_like(last_state)
last_state = tf.concat([last_state, context], axis = 1)
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
повний код