Трансферное обучение, или Transfer Learning (TL) — это метод в машинном обучении, при котором модель, обученная для одной задачи, переиспользуется для другой, связанной задачи.
Представим, что человек умеет играть на гитаре и хочет освоить укулеле. Его навыки помогут сделать это быстрее, потому что техники и приемы игры похожи. Так же работает трансферное обучение. Вместо того чтобы обучать модель с нуля, этот способ помогает использовать существующие знания в виде сохраненных предобученных моделей.
Вместе с Марией Жаровой, дата-сайентистом в Альфа-Банке, рассказываем, как работает TL, где оно используется, и разбираем конкретные кейсы.
Чем трансферное обучение отличается от обычного
В целом, Transfer Learning — это просто дообучение ML-алгоритмов для решения других задач. Можно выделить несколько различий между ML и TL:
Трансферное обучение помогает:
- Экономить ресурсы. Поскольку модель не нужно обучать с нуля, это сокращает трудозатраты и снижает требования к вычислительным мощностям, которых может не быть у компании и тем более у конкретного специалиста при выполнении пет-проектов.
- Повысить качество результатов. На ограниченном наборе данных под конкретную задачу TL-модели могут дать лучшие результаты по сравнению с обычным обучением.
- Ускорить обучение. Поскольку предобученная модель уже знает общие признаки, дообучить ее на небольшом наборе данных получится значительно быстрее.
Как развивалось трансферное обучение
В 1976 году С. Божиновски и А. Фулгоси опубликовали статью о передаче знаний между нейронными сетями. Это один из ранних примеров исследований в области, которая позже станет известна как Transfer Learning. Однако сам термин был сформулирован и популяризирован значительно позже.
Современное понимание TL сформировалось, когда ученые начали активно исследовать и использовать перенос знаний между различными задачами и моделями. Так в 2012 году AlexNet, архитектура нейросети, разработанная Алексом Крижевским, Ильей Суцкевером и Джеффри Хинтоном, выиграла в конкурсе ImageNet Large-Scale Visual Recognition Challenge. Эту нейросеть, предобученную на большом наборе данных, можно было потом дообучить на новом наборе.
В 2014 году Google представил архитектуру Inception (GoogLeNet), которая также использовала идеи трансферного обучения для улучшения производительности в задачах компьютерного зрения (CV).
В 2018 году Google представил модель Bidirectional Encoder Representations from Transformers (BERT), которую предварительно обучили на огромном пласте текстов, а затем дообучили для выполнения специфичных задач по обработке естественного языка.
Принципы и механизм работы трансферного обучения
Перед применением трансферного обучения нужно выяснить:
- Нуждается ли модель в дообучении — следует провести предварительный анализ, чтобы определить, будет ли использование TL эффективным и не ухудшит ли это результаты.
- Какую часть знаний можно передать от TL-модели к финальной, чтобы она выполняла поставленную задачу, — не все знания из исходной модели могут быть полезны. Например, в задачах компьютерного зрения модель может переносить общие признаки, такие как формы и текстуры, которые нужны для распознавания разных объектов. Придется экспериментировать, чтобы определить, какие именно части знаний можно и нужно перенести.
- Как переносить знания из исходной модели — разные задачи и предметные области могут требовать разных методов переноса. Поэтому перед применением TL нужно выбрать подходящий, чтобы он обеспечивал наилучшие результаты.
Само обучение строится по двум основным принципам:
- Предобучение модели (pre-training): сначала модель обучается на большом наборе данных, который не обязательно связан с конкретной задачей, но предоставляет много информации для обучения. Например, в случае обработки изображений модель может обучаться на большом наборе фотографий с различными объектами.
- Адаптация к новой задаче (fine-tuning): после предобучения модель дообучается на более специализированном наборе данных. Например, если нужно распознавать растения из Красной книги, модель будет дообучена на изображениях этих растений.
Как это устроено внутри модели
Модели ML состоят из слоев. Это множество нейронов, которые выполняют определенные вычисления над входными данными.
Каждый слой извлекает из данных признаки разного уровня сложности:
- Начальные слои: обрабатывают исходные данные и извлекают базовые признаки.
- Глубокие слои: обрабатывают данные, полученные от начальных слоев, и выявляют более сложные и абстрактные паттерны, такие как формы объектов и их взаимосвязи.
- Последний слой: отвечает за классификацию исходных классов — категорий и меток объектов данных (например, кошка или собака).
В ML также существует понятие домена:
- Домен предметной области — область знаний, в которой применяется модель машинного обучения (например, медицина, финансы, обработка естественного языка).
- Домен данных — это пространство или набор значений, которые могут принимать данные, используемые для обучения модели (например, возраст от 0 до 80 лет).
В контексте трансферного обучения домен — это область знаний или набор данных, на которых обучена модель (исходный домен) и в которых она будет применяться после дообучения (целевой домен).
Эти домены могут относиться к разным предметным областям и содержать данные, различающиеся по характеристикам. Трансферное обучение предполагает перенос знаний из исходного домена в целевой. Этот процесс включает в себя адаптацию слоев модели («заморозку», замену, если это необходимо, и дообучение) к новому домену данных и, возможно, к новому предметному домену, с учетом различий между ними.
Например, если модель обучалась на анализе рентгеновских снимков людей (исходный домен), она может быть адаптирована для анализа рентгеновских снимков животных (целевой домен). Несмотря на различие в предметных доменах (медицина и ветеринария), в обоих случаях данные представляют собой изображения медицинских сканирований и основные принципы обработки изображений могут быть схожими.
Трансферное обучение на примере CV
Рассмотрим процесс обучения на примере компьютерного зрения. Его можно разделить на семь основных этапов:
- Получение предобученной модели.
На первом этапе получают модель, предварительно обученную на большом наборе данных, таком как ImageNet. Иначе говоря, берут модель, которая прошла pre-training. Она уже умеет распознавать общие визуальные признаки, такие как формы, текстуры и объекты, а также может различать более специфичные параметры.
- Определение базовой модели.
Модель готовят для дальнейшего обучения: анализируют ее архитектуру и выясняют, какие изменения нужно будет внести — какие слои могут быть заменены (если это необходимо) или дообучены для целевого домена.
Также настраивают базовые гиперпараметры — те, которые устанавливают перед началом обучения, чтобы потом управлять его процессом. К ним относятся размер пакета данных (batch size) и скорость обучения (learning rate). Затем готовят программы и оборудование.
Следующие этапы представляют из себя fine-tuning (тонкую настройку, или дообучение):
- «Заморозка» слоев
Сперва «замораживают» выбранные начальные и глубокие слои модели, которые уже были обучены распознавать определенные признаки. Они останутся неизменными в процессе дальнейшего обучения.
«Заморозка» позволяет сохранить приобретенные знания имеющихся слоев и сосредоточиться на дообучении последних слоев модели, которые будут адаптированы к специфике новой задачи.
- Модификация последних слоев
После того как нужные слои «заморозили», предстоит заменить или добавить последние слои, чтобы модель смогла решать новую задачу. Количество классов объектов может быть другим, поэтому старый слой заменяют на новый.
Например, можно добавить новый слой с 10 нейронами, который будет отвечать за классификацию по 10 классам целевого домена.
- Продолжение обучения сети
На этом этапе модель продолжает обучение с новым набором данных, чтобы адаптироваться к конкретной задаче.
- Оценка точности модели
На тестовом наборе данных оценивают точность модели — насколько хорошо она справляется с задачей. Обычно есть несколько параллельно обучаемых моделей, отличных друг от друга.
- Выбор лучшей модели и интеграция
После сравнения нескольких моделей выбирают ту, которая лучше обучена, и интегрируют ее в систему для решения реальных задач.
Подходы и технологии трансферного обучения
Основные подходы к TL
Однородное трансферное обучение (Homogeneous Transfer Learning) — применяется, когда исходный и целевой домены имеют одинаковую или очень схожую структуру данных и решают аналогичные задачи.
Пример: модель, обученная на распознавании объектов на фотографиях, может быть адаптирована для распознавания объектов в видеороликах. Несмотря на различие в формате данных, задача распознавания остается схожей, потому что видеоролики состоят из последовательностей изображений.
Перенос экземпляра (Instance Transfer) — это метод трансферного обучения, при котором выбираются и повторно используются отдельные примеры (экземпляры) из исходного набора данных. Даже если данные из исходного и целевого доменов немного отличаются, схожие примеры могут помочь модели быстрее адаптироваться и лучше справляться с новой задачей.
Пример: модель обучена распознавать фотографии продуктов из интернет-магазина, таких как овощи, фрукты, упаковки продуктов и т. д. Теперь нужно адаптировать ее для распознавания блюд на фотографиях из ресторанов. Хотя изображения отличаются по стилю и контексту, отдельные продукты, такие как овощи или мясо, все еще могут быть узнаваемыми для модели.
Функция-передача представления (Feature Representation Transfer) — это метод, при котором модель, обученная на одном наборе данных, использует свои навыки выделения важных признаков (характеристик) для работы с новыми данными.
Пример: если модель научилась хорошо распознавать общие текстуры и формы на изображениях автомобилей (исходный домен), она может использовать эти навыки, чтобы лучше распознавать текстуры и формы на новых изображениях, например мебели или одежды (целевой домен).
Популярные модели трансферного обучения
- ResNet (Residual Neural Network) — тип архитектуры нейронной сети, разработанный для эффективного обучения глубоких сетей. Ключевая особенность — резидуальные связи (residual connections), которые позволяют передавать информацию, минуя один или несколько промежуточных слоев.
- Трансформеры (Transformers) — архитектура нейронной сети, которая стала стандартом для задач обработки естественного языка (NLP) и других последовательных данных. В основе трансформеров лежит механизм внимания (attention), который позволяет модели фокусироваться на важных частях входных данных, игнорируя менее важные.
Примеры трансформеров — Bidirectional Encoder Representations from Transformers (BERT) и Generative Pre-trained Transformer (GPT). Обе применяют функцию передачи представлений (feature representation transfer) — извлекают и используют полезные признаки из текста.
Инструменты трансферного обучения
- TensorFlow — библиотека с открытым исходным кодом для машинного обучения и глубокого обучения, разработанная Google. Поддерживает множество методов трансферного обучения и предоставляет удобные инструменты для дообучения предварительно обученных моделей.
TensorFlow Hub — платформа для поиска предварительно обученных моделей.
TensorFlow Model Garden — коллекция моделей для различных задач, включая модели для трансферного обучения.
- PyTorch — фреймворк для глубокого обучения, разработанный Facebook. Предоставляет инструменты для трансферного обучения и работы с предварительно обученными моделями.
torchvision.models — пакет моделей для задач компьютерного зрения с такими архитектурами, как AlexNet, ResNet, DenseNet и др.
PyTorch Hub — хранилище предварительно обученных моделей для разных задач. - Hugging Face — библиотека, которая изначально специализировалась на моделях для обработки естественного языка. Сейчас она предоставляет инструменты для трансферного обучения моделей из самых разных областей: CV, NLP, audio.
- Transformers library — поддержка предобученных моделей для различных задач.
- Tokenizers — токенизаторы для предварительной обработки текстов. Эти компоненты преобразуют текст в данные, которые потом будет обрабатывать модель.
- Keras — высокоуровневый API для глубокого обучения, который является частью TensorFlow. Предлагает простой и удобный синтаксис для создания и дообучения моделей.
tf.keras.applications — набор предварительно обученных моделей, таких как VGG, ResNet, Inception, для задач компьютерного зрения. - Preprocessing and Data Augmentation (класс ImageDataGenerator) — инструменты для подготовки данных и их аугментации (генерации новых данных на основе уже существующих).
Где используют трансферное обучение
Обработка естественного языка (NLP)
Модели-трансформеры — BERT и GPT — сначала обучаются на больших объемах текстов, чтобы понять контексты и грамматику. После этого их дообучают на специфических задачах, чтобы добиться более глубокого понимания текста.
Примеры:
- Анализ тональности текста. Модели BERT, например, могут быть дообучены для анализа тональности в отзывах на товары. Так они смогут различать положительные и отрицательные мнения.
- Классификация текстов. Модели GPT можно адаптировать к классификации текстов по категориям, таким как новости, научные публикации или поэзия. Это помогает улучшить автоматическую сортировку и анализ больших объемов текстов.
Компьютерное зрение (CV)
Для решения любых задач классификации в CV часто используют модели такие как Inception v4, предварительно обученные на крупных датасетах, например ImageNet и CIFAR. Их легко дообучить под более узкие задачи. Это особенно полезно, когда у нас есть много данных для одной задачи, но мало для другой (как, например, в медицине).
Примеры:
- Медицинские изображения. Например, модели ResNet или VGG могут быть дообучены для определения заболеваний на рентгеновских снимках.
- Распознавание объектов. Модели также можно адаптировать к распознаванию конкретных объектов, таких как транспортные средства на видеозаписях с камер наблюдения.
Распознавание речи
Модели для распознавания речи, такие как DeepSpeech или Wav2Vec, обучаются на больших объемах аудиоданных и умеют обрабатывать различные акценты и произношения. После этого они также могут быть дообучены для конкретных задач.
Примеры применения:
- Медицина. Речевые модели можно обучить отвечать на вопросы из медицинской практики. В будущем возможно появление специальных цифровых ассистентов для врачей и пациентов.
- Юридические консультации. Модели можно дообучить для работы с юридической терминологией — это поможет создавать точные транскрипции судебных заседаний и других юридических документов.
Кейс: классификация изображений кошек и собак
Шаг 1. Выбор и загрузка предобученной модели
Инструмент: TensorFlow, PyTorch
Загрузите предобученную модель, такую как ResNet или VGG, которая уже обучена на наборе данных ImageNet.
Шаг 2. Замена последнего слоя
Замените последний слой модели на новый, который будет иметь два выхода (по одному для кошек и собак).
Например, в Keras это может выглядеть так:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(2, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
Шаг 3. Обучение модели на новой задаче
Дообучите модель на наборе данных с изображениями кошек и собак. Например, используя Keras:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10, validation_data=(val_images, val_labels))
Кейс: анализ тональности отзывов
Шаг 1. Выбор и загрузка предобученной модели
Инструмент: Hugging Face Transformers
Загрузите предобученную модель, такую как BERT, которая уже обучена на большом корпусе текстов:
from transformers import BertTokenizer, BertForSequenceClassification
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
Шаг 2. Подготовка данных для обучения
Инструмент: Hugging Face Datasets, TensorFlow, PyTorch
Подготовьте набор данных с отзывами, преобразуйте текстовые данные в формат для работы с моделью BERT:
def tokenize_function(examples):
return tokenizer(examples['text'], padding='max_length', truncation=True)
train_dataset = train_dataset.map(tokenize_function, batched=True)
Шаг 3. Дообучение модели на новой задачи
Инструмент: Hugging Face Transformers, TensorFlow, PyTorch
Обучите модель на данных с отзывами, чтобы она могла классифицировать их по тональности:
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir='./results',
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=3,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
Полезные ресурсы для изучения
- Документация TensorFlow и PyTorch — очень много примеров и обучающих туториалов.
- «Глубокое обучение», Я. Гудфеллоу, И. Бенджио, А. Курвилль — книга для знакомства с основами глубокого обучения, включая трансферное. Отлично подойдет для создания фундамента в этой области.
- «Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems», Aurélien Géron — книга, которая дает основы машинного обучения и объясняет, как трансферное обучение применяется на практике в контексте реальных задач.
- Hugging Face — мини-курсы по работе с NLP, аудио и CV.
Подведем итоги
- TL позволяет использовать знания, полученные из одной задачи, для улучшения производительности в другой, похожей задаче.
- Обучение моделей с нуля требует большого объема данных, в то время как TL использует уже обученные модели и адаптирует их к новым задачам, что экономит время и ресурсы.
- TL широко используется в CV, NLP и распознавании речи.
- Процесс TL состоит из получения предобученной модели, создания базовой модели, «заморозки» слоев, замены последнего слоя и последующего обучения до конечного результата.
- Для TL используют такие инструменты, как TensorFlow, Keras, PyTorch, Hugging Face Transformers и др.