Notice: This page requires JavaScript to function properly.
Please enable JavaScript in your browser settings or update your browser.
Aprenda Treinamento e Otimização | Seção
/
Aprendizado Profundo Generativo

bookTreinamento e Otimização

Deslize para mostrar o menu

O treinamento de modelos generativos envolve a otimização de paisagens de perda frequentemente instáveis e complexas. Esta seção apresenta funções de perda adaptadas a cada tipo de modelo, estratégias de otimização para estabilizar o treinamento e métodos para ajuste fino de modelos pré-treinados para casos de uso personalizados.

Funções de Perda Principais

Diferentes famílias de modelos generativos utilizam formulações de perda distintas, dependendo de como modelam distribuições de dados.

Perdas em GANs

Perda minimax (GAN original)

Configuração adversarial entre gerador GG e discriminador DD (exemplo com a biblioteca pythorch):

loss_D = -torch.mean(torch.log(D(real_data)) + torch.log(1. - D(fake_data)))
loss_G = -torch.mean(torch.log(D(fake_data)))

Least squares GAN (LSGAN)

Utiliza perda L2 em vez de log loss para melhorar a estabilidade e o fluxo do gradiente:

loss_D = 0.5 * torch.mean((D(real_data) - 1) ** 2 + D(fake_data) ** 2)
loss_G = 0.5 * torch.mean((D(fake_data) - 1) ** 2)

Wasserstein GAN (WGAN)

Minimiza a distância Earth Mover (EM); substitui o discriminador por um "crítico" e utiliza weight clipping ou penalidade de gradiente para garantir a continuidade de Lipschitz:

loss = torch.mean(D(fake_data)) - torch.mean(D(real_data)) + gradient_penalty

Perda VAE

Evidence Lower Bound (ELBO)

Combina reconstrução e regularização. O termo de divergência KL incentiva o posterior latente a permanecer próximo ao prior (geralmente normal padrão):

recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl_div

Perdas em Modelos de Difusão

Perda de Predição de Ruído

Os modelos aprendem a remover o ruído gaussiano adicionado ao longo de um cronograma de difusão. Variantes utilizam predição de velocidade (por exemplo, v-prediction no Stable Diffusion v2) ou objetivos híbridos:

noise = torch.randn_like(x)
x_t = q_sample(x, t, noise)
pred_noise = model(x_t, t)
loss = F.mse_loss(pred_noise, noise)

Técnicas de Otimização

O treinamento de modelos generativos é frequentemente instável e sensível a hiperparâmetros. Diversas técnicas são empregadas para garantir convergência e qualidade.

Otimizadores e Agendadores

  • Adam / AdamW: otimizadores adaptativos de gradiente são o padrão de fato. Utilizar β1=0.5, β2=0.999\beta_1=0.5,\ \beta_2=0.999 para GANs;
  • RMSprop: às vezes utilizado em variantes de WGAN;
  • Agendamento da taxa de aprendizado:
    • Fases de aquecimento para transformers e modelos de difusão;
    • Decaimento cosseno ou ReduceLROnPlateau para convergência estável.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

Métodos de Estabilização

  • Clipping de gradiente: evitar explosão de gradientes em RNNs ou UNets profundas;
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  • Normalização espectral: aplicada às camadas do discriminador em GANs para impor restrições de Lipschitz;
from torch.nn.utils import spectral_norm
layer = spectral_norm(nn.Linear(100, 100))
  • Suavização de rótulos: suaviza rótulos rígidos (por exemplo, real = 0,9 em vez de 1,0) para reduzir o excesso de confiança;
  • Regra de atualização em duas escalas de tempo (TTUR): utiliza taxas de aprendizado diferentes para o gerador e o discriminador para melhorar a convergência;
  • Treinamento com precisão mista: utiliza FP16 (via NVIDIA Apex ou PyTorch AMP) para treinamento mais rápido em GPUs modernas.
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output = model(input)
    loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Note
Nota

Monitorar separadamente as perdas do gerador e do discriminador. Utilizar métricas como FID ou IS periodicamente para avaliar a qualidade real da saída em vez de depender apenas dos valores de perda.

Ajuste Fino de Modelos Generativos Pré-Treinados

Modelos generativos pré-treinados (por exemplo, Stable Diffusion, LLaMA, StyleGAN2) podem ser ajustados para tarefas específicas de domínio utilizando estratégias de treinamento mais leves.

Técnicas de Aprendizado por Transferência

  • Ajuste fino completo: re-treinamento de todos os pesos do modelo. Alto custo computacional, porém máxima flexibilidade;
model = AutoModel.from_pretrained('model-name')
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
  • Re-congelamento de camadas / descongelamento gradual: iniciar congelando a maioria das camadas e, em seguida, descongelar gradualmente camadas selecionadas para um ajuste fino mais eficiente. Isso evita o esquecimento catastrófico. Congelar as camadas iniciais ajuda a preservar características gerais do pré-treinamento (como bordas ou padrões de palavras), enquanto descongelar as camadas finais permite que o modelo aprenda características específicas da tarefa;
for param in model.parameters():
    param.requires_grad = False
# Unfreeze final transformer block or decoder
for param in model.transformer.block[-1].parameters():
    param.requires_grad = True
  • LoRA / camadas adaptadoras: injeção de camadas treináveis de baixa classificação sem atualização dos parâmetros do modelo base;
from peft import get_peft_model, LoraConfig, TaskType

config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(base_model, config)
  • DreamBooth / inversão textual (modelos de difusão):
    • Ajuste fino em um pequeno conjunto de imagens específicas do sujeito.
    • Utilização do pipeline diffusers:
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipeline.train(texts, images)  # pseudo-call: use DreamBooth training scripts in practice
  • Ajuste de prompt / p-tuning:
from peft import PromptTuningConfig, get_peft_model
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
model = get_peft_model(base_model, config)

Casos de Uso Comuns

  • Adaptação de estilo: ajuste fino em conjuntos de dados de anime, quadrinhos ou artísticos;
  • Ajuste específico para a indústria: adaptação de LLMs para domínios jurídicos, médicos ou corporativos;
  • Personalização: condicionamento de identidade ou voz personalizada usando pequenos conjuntos de referência.
Note
Nota

Utilize Hugging Face PEFT para métodos baseados em LoRA/adapter, e a biblioteca Diffusers para pipelines de ajuste fino leves com suporte integrado para DreamBooth e orientação livre de classificadores.

Resumo

  • Utilização de funções de perda específicas do modelo que correspondam aos objetivos de treinamento e à estrutura do modelo;
  • Otimização com métodos adaptativos, técnicas de estabilização e agendamento eficiente;
  • Ajuste fino de modelos pré-treinados utilizando estratégias modernas de transferência de baixa ordem ou baseadas em prompt para reduzir custos e aumentar a adaptabilidade ao domínio.

1. Qual das alternativas a seguir é um dos principais objetivos do uso de técnicas de regularização durante o treinamento?

2. Qual dos seguintes otimizadores é comumente utilizado para treinar modelos de deep learning e adapta a taxa de aprendizado durante o treinamento?

3. Qual é o principal desafio ao treinar modelos generativos, especialmente no contexto de GANs (Redes Geradoras Adversariais)?

question mark

Qual das alternativas a seguir é um dos principais objetivos do uso de técnicas de regularização durante o treinamento?

Select the correct answer

question mark

Qual dos seguintes otimizadores é comumente utilizado para treinar modelos de deep learning e adapta a taxa de aprendizado durante o treinamento?

Select the correct answer

question mark

Qual é o principal desafio ao treinar modelos generativos, especialmente no contexto de GANs (Redes Geradoras Adversariais)?

Select the correct answer

Tudo estava claro?

Como podemos melhorá-lo?

Obrigado pelo seu feedback!

Seção 1. Capítulo 12

Pergunte à IA

expand

Pergunte à IA

ChatGPT

Pergunte o que quiser ou experimente uma das perguntas sugeridas para iniciar nosso bate-papo

Seção 1. Capítulo 12
some-alt