Змусьте Керас працювати за допомогою багатоядерної багатоядерної процесорної системи


11

Я працюю над моделлю Seq2Seq, використовуючи LSTM від Keras (використовуючи фон Theano), і я хотів би паралелізувати процеси, тому що навіть на кілька МБ даних потрібно кілька годин для навчання.

Зрозуміло, що GPU набагато кращі в паралелізації, ніж процесори. На даний момент у мене є лише процесори, з якими можна працювати. Я міг отримати доступ до 16 процесорів (2 нитки на ядро ​​X 4 ядра на сокет X 2 розетки)

З документа про багатоядерну підтримку в Theano, мені вдалося використати всі чотири ядра одного сокета. Таким чином, в основному ЦП використовується 400% використання з використовуваними 4CPU, а решта 12 процесорів залишаються невикористаними. Як я їх також використовую. Tensorflow також може використовуватися замість фону Theano, якщо він працює.

введіть тут опис зображення

Відповіді:


7

Для того щоб встановити кількість потоків, використовуваних у Theano (а отже, і кількість процесорних ядер), вам потрібно буде встановити кілька параметрів у середовищі:

import os
os.environ['MKL_NUM_THREADS'] = '16'
os.environ['GOTO_NUM_THREADS'] = '16'
os.environ['OMP_NUM_THREADS'] = '16'
os.eviron['openmp'] = 'True'

Це повинно дозволяти вам використовувати всі ядра всіх процесорів.

Звичайно, це можна зробити і в Тенсдорфлоу:

import tensorflow as tf
from keras.backend import tensorflow_backend as K

with tf.Session(config=tf.ConfigProto(
                    intra_op_parallelism_threads=16)) as sess:
    K.set_session(sess)
    <Your Keras code>

Я буду вдячний за цей код за встановлення кількості ядер в R (я використовую кери в R, і він використовує лише частину ядер ...). Особливо для tensorflow. Спасибі, Мілан
gutompf

2
Ви можете використовувати Sys.setenv () для встановлення змінних середовища в R, подібно до використання os.environ [var] в Python.
Томас Кліберг
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.