Notice: This page requires JavaScript to function properly.
Please enable JavaScript in your browser settings or update your browser.
Leer Training en Optimalisatie | Sectie
Generatieve Deep Learning

bookTraining en Optimalisatie

Veeg om het menu te tonen

Het trainen van generatieve modellen omvat het optimaliseren van vaak instabiele en complexe verlieslandschappen. Deze sectie introduceert verliesfuncties die zijn afgestemd op elk modeltype, optimalisatiestrategieën om de training te stabiliseren, en methoden voor het verfijnen van voorgetrainde modellen voor specifieke toepassingen.

Kernverliesfuncties

Verschillende families van generatieve modellen gebruiken verschillende verliesformuleringen, afhankelijk van hoe zij datadistributies modelleren.

GAN-verliezen

Minimax-verlies (originele GAN)

Adversariële opzet tussen generator GG en discriminator DD (voorbeeld met pythorch-bibliotheek):

loss_D = -torch.mean(torch.log(D(real_data)) + torch.log(1. - D(fake_data)))
loss_G = -torch.mean(torch.log(D(fake_data)))

Least squares GAN (LSGAN)

Gebruikt L2-verlies in plaats van log-verlies om stabiliteit en gradiëntdoorstroming te verbeteren:

loss_D = 0.5 * torch.mean((D(real_data) - 1) ** 2 + D(fake_data) ** 2)
loss_G = 0.5 * torch.mean((D(fake_data) - 1) ** 2)

Wasserstein GAN (WGAN)

Minimaliseert Earth Mover (EM) afstand; vervangt de discriminator door een "critic" en gebruikt gewichtsafkapping of gradiëntpenalty voor Lipschitz-continuïteit:

loss = torch.mean(D(fake_data)) - torch.mean(D(real_data)) + gradient_penalty

VAE-verlies

Evidence Lower Bound (ELBO)

Combineert reconstructie en regularisatie. De KL-divergentie term stimuleert dat de latente posterior dicht bij de prior blijft (meestal standaardnormaal):

recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl_div

Diffusiemodel-verliesfuncties

Ruisvoorspellingsverlies

Modellen leren toegevoegde Gaussische ruis te verwijderen volgens een diffusieschema. Varianten gebruiken snelheidsvoorspelling (bijv. v-prediction in Stable Diffusion v2) of hybride doelstellingen:

noise = torch.randn_like(x)
x_t = q_sample(x, t, noise)
pred_noise = model(x_t, t)
loss = F.mse_loss(pred_noise, noise)

Optimalisatietechnieken

Het trainen van generatieve modellen is vaak instabiel en gevoelig voor hyperparameters. Diverse technieken worden toegepast om convergentie en kwaliteit te waarborgen.

Optimalisatoren en Planners

  • Adam / AdamW: adaptieve gradiëntoptimalisatoren zijn de standaard. Gebruik β1=0.5, β2=0.999\beta_1=0.5,\ \beta_2=0.999 voor GANs;
  • RMSprop: soms gebruikt in WGAN-varianten;
  • Learning rate scheduling:
    • Warm-up-fases voor transformers en diffusiemodellen;
    • Cosinusafname of ReduceLROnPlateau voor stabiele convergentie.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

Stabilisatiemethoden

  • Gradiëntclipping: voorkomt exploderende gradiënten in RNNs of diepe UNets;
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  • Spectrale normalisatie: toegepast op discriminatorlagen in GANs om Lipschitz-voorwaarden af te dwingen;
from torch.nn.utils import spectral_norm
layer = spectral_norm(nn.Linear(100, 100))
  • Label smoothing: verzacht harde labels (bijv. echt = 0,9 in plaats van 1,0) om overmatige zekerheid te verminderen;
  • Two-time-scale update rule (TTUR): gebruik verschillende leersnelheden voor generator en discriminator om de convergentie te verbeteren;
  • Mixed-precision training: maakt gebruik van FP16 (via NVIDIA Apex of PyTorch AMP) voor snellere training op moderne GPU's.
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output = model(input)
    loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Note
Opmerking

Monitor generator- en discriminatorverliezen afzonderlijk. Gebruik periodiek metriek zoals FID of IS om de werkelijke outputkwaliteit te evalueren in plaats van uitsluitend op verlieswaarden te vertrouwen.

Fijn-afstemming van Voorgetrainde Generatieve Modellen

Voorgetrainde generatieve modellen (zoals Stable Diffusion, LLaMA, StyleGAN2) kunnen worden fijn-afgestemd voor domeinspecifieke taken met lichtere trainingsstrategieën.

Transfer Learning-technieken

  • Volledige fijn-afstemming: alle modelgewichten opnieuw trainen. Hoge rekeneisen maar maximale flexibiliteit;
model = AutoModel.from_pretrained('model-name')
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
  • Laag opnieuw bevriezen / geleidelijk ontdooien: begin met het bevriezen van de meeste lagen en ontdooi vervolgens geselecteerde lagen geleidelijk voor betere fijn-afstemming. Dit voorkomt catastrofale vergeten. Het bevriezen van vroege lagen helpt om algemene kenmerken van de voortraining te behouden (zoals randen of woordpatronen), terwijl het ontdooien van latere lagen het model in staat stelt taakspecifieke kenmerken te leren;
for param in model.parameters():
    param.requires_grad = False
# Unfreeze final transformer block or decoder
for param in model.transformer.block[-1].parameters():
    param.requires_grad = True
  • LoRA / adapterlagen: injecteer laag-rang trainbare lagen zonder de parameters van het basismodel bij te werken;
from peft import get_peft_model, LoraConfig, TaskType

config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(base_model, config)
  • DreamBooth / tekstuele inversie (diffusiemodellen):
    • Fijn-afstemming op een klein aantal onderwerp-specifieke afbeeldingen.
    • Gebruik diffusers pipeline:
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipeline.train(texts, images)  # pseudo-call: use DreamBooth training scripts in practice
  • Prompt tuning / p-tuning:
from peft import PromptTuningConfig, get_peft_model
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
model = get_peft_model(base_model, config)

Veelvoorkomende Toepassingen

  • Stijlaanpassing: fijn-afstemming op anime-, strip- of artistieke datasets;
  • Industriespecifieke afstemming: aanpassen van LLM's aan juridische, medische of zakelijke domeinen;
  • Personalisatie: aangepaste identiteit of stemconditionering met behulp van kleine referentiesets.
Note
Opmerking

Gebruik Hugging Face PEFT voor LoRA/adapter-gebaseerde methoden en de Diffusers-bibliotheek voor lichtgewicht fine-tuning pipelines met ingebouwde ondersteuning voor DreamBooth en classifier-free guidance.

Samenvatting

  • Gebruik modelspecifieke verliesfuncties die overeenkomen met trainingsdoelstellingen en modelstructuur;
  • Optimaliseer met adaptieve methoden, stabilisatietechnieken en efficiënte planning;
  • Fijn-afstemmen van voorgetrainde modellen met moderne low-rank- of prompt-gebaseerde transferstrategieën om kosten te verlagen en domeinaanpasbaarheid te vergroten.

1. Wat is een primair doel van het gebruik van regularisatietechnieken tijdens training?

2. Welke van de volgende optimizers wordt vaak gebruikt voor het trainen van deep learning modellen en past het leerrendement aan tijdens training?

3. Wat is de belangrijkste uitdaging bij het trainen van generatieve modellen, vooral in de context van GANs (Generative Adversarial Networks)?

question mark

Wat is een primair doel van het gebruik van regularisatietechnieken tijdens training?

Select the correct answer

question mark

Welke van de volgende optimizers wordt vaak gebruikt voor het trainen van deep learning modellen en past het leerrendement aan tijdens training?

Select the correct answer

question mark

Wat is de belangrijkste uitdaging bij het trainen van generatieve modellen, vooral in de context van GANs (Generative Adversarial Networks)?

Select the correct answer

Was alles duidelijk?

Hoe kunnen we het verbeteren?

Bedankt voor je feedback!

Sectie 1. Hoofdstuk 12

Vraag AI

expand

Vraag AI

ChatGPT

Vraag wat u wilt of probeer een van de voorgestelde vragen om onze chat te starten.

Sectie 1. Hoofdstuk 12
some-alt