Notice: This page requires JavaScript to function properly.
Please enable JavaScript in your browser settings or update your browser.
Вивчайте Навчання та Оптимізація | Секція
Генеративне глибинне навчання

bookНавчання та Оптимізація

Свайпніть щоб показати меню

Навчання генеративних моделей передбачає оптимізацію часто нестабільних і складних ландшафтів функцій втрат. У цьому розділі розглядаються функції втрат, адаптовані до кожного типу моделі, стратегії оптимізації для стабілізації навчання та методи тонкого налаштування попередньо навчених моделей для індивідуальних завдань.

Основні функції втрат

Різні сімейства генеративних моделей використовують окремі формулювання функцій втрат залежно від способу моделювання розподілів даних.

Втрати GAN

Мінімакс-втрата (оригінальний GAN)

Змагальна взаємодія між генератором GG та дискримінатором DD (приклад з бібліотекою pythorch):

loss_D = -torch.mean(torch.log(D(real_data)) + torch.log(1. - D(fake_data)))
loss_G = -torch.mean(torch.log(D(fake_data)))

Least squares GAN (LSGAN)

Використовує L2-втрату замість логарифмічної втрати для підвищення стабільності та покращення градієнтного потоку:

loss_D = 0.5 * torch.mean((D(real_data) - 1) ** 2 + D(fake_data) ** 2)
loss_G = 0.5 * torch.mean((D(fake_data) - 1) ** 2)

Wasserstein GAN (WGAN)

Мінімізує відстань Земельного перевізника (Earth Mover, EM); замінює дискримінатор на "критика" та використовує обрізання ваг або штраф за градієнт для забезпечення ліпшицевої неперервності:

loss = torch.mean(D(fake_data)) - torch.mean(D(real_data)) + gradient_penalty

Втрата VAE

Нижня межа доказу (ELBO)

Поєднує реконструкцію та регуляризацію. Член KL-дивергенції стимулює латентний постеріор залишатися близьким до апріорного розподілу (зазвичай стандартне нормальне):

recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl_div

Функції втрат для дифузійних моделей

Втрата передбачення шуму

Моделі навчаються видаляти доданий гаусівський шум протягом дифузійного розкладу. Існують варіанти з передбаченням швидкості (наприклад, v-prediction у Stable Diffusion v2) або гібридні цілі:

noise = torch.randn_like(x)
x_t = q_sample(x, t, noise)
pred_noise = model(x_t, t)
loss = F.mse_loss(pred_noise, noise)

Техніки оптимізації

Навчання генеративних моделей часто є нестабільним і чутливим до гіперпараметрів. Для забезпечення збіжності та якості використовуються різні методи.

Оптимізатори та планувальники

  • Adam / AdamW: адаптивні оптимізатори градієнта є стандартом де-факто. Для GAN використовуйте β1=0.5, β2=0.999\beta_1=0.5,\ \beta_2=0.999;
  • RMSprop: іноді використовується у варіантах WGAN;
  • Планування швидкості навчання:
    • Фази розігріву для трансформерів і дифузійних моделей;
    • Косинусне зменшення або ReduceLROnPlateau для стабільної збіжності.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

Методи стабілізації

  • Обрізання градієнта: запобігання вибуху градієнтів у RNN або глибоких UNet;
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  • Спектральна нормалізація: застосовується до шарів дискримінатора в GAN для забезпечення обмежень Ліпшиця;
from torch.nn.utils import spectral_norm
layer = spectral_norm(nn.Linear(100, 100))
  • Згладжування міток: пом'якшує жорсткі мітки (наприклад, реальні = 0.9 замість 1.0) для зменшення надмірної впевненості;
  • Правило оновлення з двома часовими масштабами (TTUR): використання різних швидкостей навчання для генератора та дискримінатора для покращення збіжності;
  • Навчання зі змішаною точністю: використання FP16 (через NVIDIA Apex або PyTorch AMP) для прискорення навчання на сучасних GPU.
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output = model(input)
    loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Note
Примітка

Окремий моніторинг втрат генератора та дискримінатора. Періодичне використання метрик, таких як FID або IS, для оцінки фактичної якості результату замість орієнтації лише на значення втрат.

Тонке налаштування попередньо навчених генеративних моделей

Попередньо навчені генеративні моделі (наприклад, Stable Diffusion, LLaMA, StyleGAN2) можна адаптувати для задач певної предметної області за допомогою полегшених стратегій навчання.

Техніки перенесення навчання

  • Повне тонке налаштування: повторне навчання всіх ваг моделі. Висока обчислювальна вартість, але максимальна гнучкість;
model = AutoModel.from_pretrained('model-name')
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
  • Повторне заморожування шарів / поступове розморожування: спочатку заморожуються більшість шарів, потім поступово розморожуються вибрані шари для кращого тонкого налаштування. Це дозволяє уникнути катастрофічного забування. Заморожування ранніх шарів допомагає зберегти загальні ознаки з попереднього навчання (наприклад, краї або шаблони слів), а розморожування пізніших дозволяє моделі вивчати ознаки, специфічні для задачі;
for param in model.parameters():
    param.requires_grad = False
# Unfreeze final transformer block or decoder
for param in model.transformer.block[-1].parameters():
    param.requires_grad = True
  • LoRA / адаптерні шари: впровадження тренованих шарів низького рангу без оновлення параметрів базової моделі;
from peft import get_peft_model, LoraConfig, TaskType

config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(base_model, config)
  • DreamBooth / текстова інверсія (дифузійні моделі):
    • Дотонування на невеликій кількості зображень, специфічних для об'єкта.
    • Використання пайплайну diffusers:
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipeline.train(texts, images)  # pseudo-call: use DreamBooth training scripts in practice
  • Тюнінг підказок / p-tuning:
from peft import PromptTuningConfig, get_peft_model
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
model = get_peft_model(base_model, config)

Поширені випадки використання

  • Адаптація стилю: донавчання на аніме, коміксах або художніх датасетах;
  • Галузеве налаштування: адаптація LLM до юридичних, медичних або корпоративних сфер;
  • Персоналізація: налаштування ідентичності або голосу за допомогою невеликих референсних наборів.
Note
Примітка

Використовуйте Hugging Face PEFT для методів на основі LoRA/адаптерів, а також бібліотеку Diffusers для легких пайплайнів донавчання з вбудованою підтримкою DreamBooth і classifier-free guidance.

Підсумок

  • Використання специфічних для моделі функцій втрат, що відповідають навчальним цілям і структурі моделі;
  • Оптимізація за допомогою адаптивних методів, стабілізаційних технік і ефективного планування;
  • Донавчання попередньо навчених моделей із використанням сучасних low-rank або prompt-based стратегій трансферу для зниження витрат і підвищення адаптивності до домену.

1. Яка з наведених є основною метою використання технік регуляризації під час навчання?

2. Який із наведених оптимізаторів зазвичай використовується для навчання моделей глибокого навчання та адаптує швидкість навчання під час тренування?

3. Яка основна проблема виникає під час навчання генеративних моделей, особливо у контексті GAN (Generative Adversarial Networks)?

question mark

Яка з наведених є основною метою використання технік регуляризації під час навчання?

Select the correct answer

question mark

Який із наведених оптимізаторів зазвичай використовується для навчання моделей глибокого навчання та адаптує швидкість навчання під час тренування?

Select the correct answer

question mark

Яка основна проблема виникає під час навчання генеративних моделей, особливо у контексті GAN (Generative Adversarial Networks)?

Select the correct answer

Все було зрозуміло?

Як ми можемо покращити це?

Дякуємо за ваш відгук!

Секція 1. Розділ 12

Запитати АІ

expand

Запитати АІ

ChatGPT

Запитайте про що завгодно або спробуйте одне із запропонованих запитань, щоб почати наш чат

Секція 1. Розділ 12
some-alt