Навчання та оптимізація
Свайпніть щоб показати меню
Навчання генеративних моделей включає оптимізацію часто нестабільних і складних ландшафтів функцій втрат. У цьому розділі розглядаються функції втрат, адаптовані до кожного типу моделі, стратегії оптимізації для стабілізації навчання та методи тонкого налаштування попередньо навчених моделей для індивідуальних випадків використання.
Основні функції втрат
Різні сімейства генеративних моделей використовують різні формулювання функцій втрат залежно від способу моделювання розподілів даних.
Втрати GAN
Мінімакс-втрата (оригінальний GAN)
Змагальна взаємодія між генератором G і дискримінатором D (приклад з бібліотекою pythorch):
loss_D = -torch.mean(torch.log(D(real_data)) + torch.log(1. - D(fake_data)))
loss_G = -torch.mean(torch.log(D(fake_data)))
Least squares GAN (LSGAN)
Використання L2-втрат замість логарифмічних втрат для підвищення стабільності та покращення градієнтного потоку:
loss_D = 0.5 * torch.mean((D(real_data) - 1) ** 2 + D(fake_data) ** 2)
loss_G = 0.5 * torch.mean((D(fake_data) - 1) ** 2)
Wasserstein GAN (WGAN)
Мінімізація відстані Земельного перевізника (Earth Mover, EM); дискримінатор замінюється на "критика" та використовується обрізання ваг або штраф за градієнт для забезпечення властивості Ліпшиця:
loss = torch.mean(D(fake_data)) - torch.mean(D(real_data)) + gradient_penalty
Втрати VAE
Нижня межа доказу (ELBO)
Поєднання реконструкції та регуляризації. Член дивергенції Кульбака-Лейблера (KL) стимулює латентний постеріор залишатися близьким до апріорного розподілу (зазвичай стандартного нормального):
recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl_div
Втрати дифузійної моделі
Втрата передбачення шуму
Моделі навчаються видаляти доданий гаусівський шум протягом дифузійного розкладу. Варіанти використовують передбачення швидкості (наприклад, v-prediction у Stable Diffusion v2) або гібридні цілі:
noise = torch.randn_like(x)
x_t = q_sample(x, t, noise)
pred_noise = model(x_t, t)
loss = F.mse_loss(pred_noise, noise)
Техніки оптимізації
Навчання генеративних моделей часто є нестабільним і чутливим до гіперпараметрів. Для забезпечення збіжності та якості застосовуються різні техніки.
Оптимізатори та планувальники
- Adam / AdamW: адаптивні оптимізатори градієнта є стандартом де-факто. Для GAN використовуйте β1=0.5, β2=0.999;
- RMSprop: іноді використовується у варіантах WGAN;
- Планування швидкості навчання:
- Фази розігріву для трансформерів і дифузійних моделей;
- Косинусне спадання або ReduceLROnPlateau для стабільної збіжності.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
Методи стабілізації
- Обрізання градієнта: запобігання вибуху градієнтів у RNN або глибоких UNet;
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
- Спектральна нормалізація: застосовується до шарів дискримінатора в GAN для забезпечення обмеження Ліпшиця;
from torch.nn.utils import spectral_norm
layer = spectral_norm(nn.Linear(100, 100))
- Згладжування міток: пом'якшення жорстких міток (наприклад, реальні = 0.9 замість 1.0) для зменшення надмірної впевненості;
- Правило оновлення з двома часовими масштабами (TTUR): використання різних швидкостей навчання для генератора та дискримінатора для покращення збіжності;
- Навчання зі змішаною точністю: використання FP16 (через NVIDIA Apex або PyTorch AMP) для пришвидшення навчання на сучасних GPU.
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Окремий моніторинг втрат генератора та дискримінатора. Періодичне використання метрик, таких як FID або IS, для оцінки фактичної якості результату замість орієнтації лише на значення втрат.
Тонке налаштування попередньо навчених генеративних моделей
Попередньо навчені генеративні моделі (наприклад, Stable Diffusion, LLaMA, StyleGAN2) можна тонко налаштовувати для задач певної предметної області за допомогою полегшених стратегій навчання.
Техніки перенавчання (Transfer Learning)
- Повне тонке налаштування: перенавчання всіх ваг моделі. Високі обчислювальні витрати, але максимальна гнучкість;
model = AutoModel.from_pretrained('model-name')
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
- Повторне заморожування шарів / поступове розморожування: початково заморожуються більшість шарів, потім поступово розморожуються вибрані шари для кращого тонкого налаштування. Це дозволяє уникнути катастрофічного забування. Заморожування ранніх шарів допомагає зберегти загальні ознаки з попереднього навчання (наприклад, краї або шаблони слів), а розморожування пізніших дозволяє моделі навчатися ознакам, специфічним для задачі;
for param in model.parameters():
param.requires_grad = False
# Unfreeze final transformer block or decoder
for param in model.transformer.block[-1].parameters():
param.requires_grad = True
- LoRA / адаптерні шари: впровадження тренованих шарів низького рангу без оновлення параметрів базової моделі;
from peft import get_peft_model, LoraConfig, TaskType
config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(base_model, config)
- DreamBooth / текстова інверсія (дифузійні моделі):
- Донавчання на невеликій кількості зображень, специфічних для об'єкта.
- Використання пайплайну
diffusers:
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipeline.train(texts, images) # pseudo-call: use DreamBooth training scripts in practice
- Тюнінг підказок / p-tuning:
from peft import PromptTuningConfig, get_peft_model
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
model = get_peft_model(base_model, config)
Поширені випадки використання
- Адаптація стилю: донавчання на аніме, коміксах або художніх датасетах;
- Галузеве налаштування: адаптація LLM до юридичних, медичних або корпоративних сфер;
- Персоналізація: налаштування ідентичності або голосу за допомогою невеликих референсних наборів.
Використання Hugging Face PEFT для методів на основі LoRA/адаптерів, а також бібліотеки Diffusers для легких пайплайнів донавчання з вбудованою підтримкою DreamBooth і classifier-free guidance.
Підсумок
- Використання специфічних для моделі функцій втрат, що відповідають цілям навчання та структурі моделі;
- Оптимізація за допомогою адаптивних методів, стабілізаційних технік і ефективного планування;
- Донавчання попередньо навчених моделей із використанням сучасних low-rank або prompt-based стратегій трансферу для зниження вартості та підвищення адаптивності до домену.
1. Яка з наведених є основною метою використання технік регуляризації під час навчання?
2. Який із наведених оптимізаторів часто використовується для навчання моделей глибокого навчання та адаптує швидкість навчання під час тренування?
3. Яка основна проблема при навчанні генеративних моделей, особливо у контексті GAN (Generative Adversarial Networks)?
Дякуємо за ваш відгук!
Запитати АІ
Запитати АІ
Запитайте про що завгодно або спробуйте одне із запропонованих запитань, щоб почати наш чат