生成對抗網(wǎng)絡(luò)(GANs)的訓(xùn)練效果很大程度上取決于其損失函數(shù)的選擇。本研究首先介紹經(jīng)典GAN損失函數(shù)的理論基礎(chǔ),隨后使用PyTorch實現(xiàn)包括原始GAN、最小二乘GAN(LS-GAN)、Wasserstein GAN(WGAN)及帶梯度懲罰的WGAN(WGAN-GP)在內(nèi)的多種損失函數(shù)。
生成對抗網(wǎng)絡(luò)(GANs)的工作原理堪比一場精妙的藝術(shù)創(chuàng)作過程——生成器(Generator)扮演創(chuàng)作者角色,不斷生成作品;判別器(Discriminator)則如同嚴(yán)苛的評論家,持續(xù)提供改進建議。這種對抗學(xué)習(xí)機制促使兩個網(wǎng)絡(luò)在競爭中共同進步。判別器向生成器提供反饋的方式——即損失函數(shù)的設(shè)計——對整個網(wǎng)絡(luò)的學(xué)習(xí)表現(xiàn)有著決定性影響。

GAN的基本原理與經(jīng)典損失函數(shù)
1、原始GAN
Goodfellow等人于2014年提出的原始GAN采用極小極大博弈(Minimax Game)框架,其損失函數(shù)可表述為:

其中:
- 表示判別器對輸入判定為真實樣本的概率
- 表示生成器將隨機噪聲轉(zhuǎn)換為合成圖像的函數(shù)
- 表示真實數(shù)據(jù)分布
- 表示噪聲先驗分布,通常為標(biāo)準(zhǔn)正態(tài)分布
原始GAN在理論上試圖最小化生成分布與真實分布之間的Jensen-Shannon散度(JS散度),但在實際訓(xùn)練中存在梯度消失、模式崩潰和訓(xùn)練不穩(wěn)定等問題。這些局限性促使研究者開發(fā)了多種改進的損失函數(shù)。
PyTorch實現(xiàn):
import torch
import torch.nn as nn
# 原始GAN損失函數(shù)實現(xiàn)
class OriginalGANLoss:
def __init__(self, device):
self.device = device
self.criterion = nn.BCELoss()
def discriminator_loss(self, real_output, fake_output):
# 真實樣本的目標(biāo)標(biāo)簽為1.0
real_labels = torch.ones_like(real_output, device=self.device)
# 生成樣本的目標(biāo)標(biāo)簽為0.0
fake_labels = torch.zeros_like(fake_output, device=self.device)
# 計算判別器對真實樣本的損失
real_loss = self.criterion(real_output, real_labels)
# 計算判別器對生成樣本的損失
fake_loss = self.criterion(fake_output, fake_labels)
# 總損失為兩部分之和
d_loss = real_loss + fake_loss
return d_loss
def generator_loss(self, fake_output):
# 生成器希望判別器將生成樣本判斷為真實樣本
target_labels = torch.ones_like(fake_output, device=self.device)
g_loss = self.criterion(fake_output, target_labels)
return g_loss
2、非飽和損失函數(shù)(Non-Saturating Loss)
為解決原始GAN中生成器梯度消失問題,Goodfellow提出了非飽和損失,將生成器的目標(biāo)函數(shù)修改為:

這種修改保持了相同的最優(yōu)解,但提供了更強的梯度信號,特別是在訓(xùn)練初期生成樣本質(zhì)量較差時,有效改善了學(xué)習(xí)效率。非飽和損失通過直接最大化判別器對生成樣本的預(yù)測概率,而不是最小化判別器正確分類的概率,從而避免了在生成器表現(xiàn)不佳時梯度趨近于零的問題。
PyTorch實現(xiàn):
class NonSaturatingGANLoss:
def __init__(self, device):
self.device = device
self.criterion = nn.BCELoss()
def discriminator_loss(self, real_output, fake_output):
# 與原始GAN相同
real_labels = torch.ones_like(real_output, device=self.device)
fake_labels = torch.zeros_like(fake_output, device=self.device)
real_loss = self.criterion(real_output, real_labels)
fake_loss = self.criterion(fake_output, fake_labels)
d_loss = real_loss + fake_loss
return d_loss
def generator_loss(self, fake_output):
# 非飽和損失:直接最大化log(D(G(z)))
target_labels = torch.ones_like(fake_output, device=self.device)
# 注意這里使用的是相同的BCE損失,但目標(biāo)是讓D將G(z)判斷為真
g_loss = self.criterion(fake_output, target_labels)
return g_loss
GAN變體實現(xiàn)與原理分析
3、最小二乘GAN(LS-GAN)
LS-GAN通過用最小二乘損失替代標(biāo)準(zhǔn)GAN中的二元交叉熵?fù)p失,有效改善了訓(xùn)練過程:

這一修改使得模型在訓(xùn)練過程中梯度變化更為平滑,顯著降低了訓(xùn)練不穩(wěn)定性。LS-GAN的主要優(yōu)勢在于能夠有效減輕模式崩潰問題(即生成器僅產(chǎn)生有限類型樣本的現(xiàn)象),同時促進學(xué)習(xí)過程的連續(xù)性與穩(wěn)定性,使模型能夠更加漸進地學(xué)習(xí)數(shù)據(jù)分布特征。理論上,LS-GAN試圖最小化Pearson 散度,這對于分布重疊較少的情況提供了更好的訓(xùn)練信號。
PyTorch實現(xiàn):
class LSGANLoss:
def __init__(self, device):
self.device = device
# LS-GAN使用MSE損失而非BCE損失
self.criterion = nn.MSELoss()
def discriminator_loss(self, real_output, fake_output):
# 真實樣本的目標(biāo)值為1.0
real_labels = torch.ones_like(real_output, device=self.device)
# 生成樣本的目標(biāo)值為0.0
fake_labels = torch.zeros_like(fake_output, device=self.device)
# 計算真實樣本的MSE損失
real_loss = self.criterion(real_output, real_labels)
# 計算生成樣本的MSE損失
fake_loss = self.criterion(fake_output, fake_labels)
d_loss = real_loss + fake_loss
return d_loss
def generator_loss(self, fake_output):
# 生成器希望生成的樣本被判別為真實樣本
target_labels = torch.ones_like(fake_output, device=self.device)
g_loss = self.criterion(fake_output, target_labels)
return g_loss
4、Wasserstein GAN(WGAN)
WGAN通過引入Wasserstein距離(也稱為地球移動者距離)作為分布差異度量,從根本上改變了GAN的訓(xùn)練機制:

其中是所有滿足1-Lipschitz約束的函數(shù)集合。與傳統(tǒng)GAN關(guān)注樣本真假二分類不同,WGAN評估的是生成分布與真實分布之間的距離,這一方法提供了更為連續(xù)且有意義的梯度信息。WGAN能夠顯著改善梯度傳播問題,有效防止判別器過度主導(dǎo)訓(xùn)練過程,同時大幅減輕模式崩潰現(xiàn)象,提高生成樣本的多樣性。
原始WGAN通過權(quán)重裁剪(weight clipping)實現(xiàn)Lipschitz約束,具體做法是將判別器參數(shù)限制在某個固定范圍內(nèi),如,但這種方法可能會限制網(wǎng)絡(luò)容量并導(dǎo)致病態(tài)行為。
PyTorch實現(xiàn):
class WGANLoss:
def __init__(self, device, clip_value=0.01):
self.device = device
self.clip_value = clip_value
def discriminator_loss(self, real_output, fake_output):
# WGAN的判別器(稱為critic)直接最大化真實樣本和生成樣本輸出的差值
# 注意這里沒有使用sigmoid激活
d_loss = -torch.mean(real_output) + torch.mean(fake_output)
return d_loss
def generator_loss(self, fake_output):
# 生成器希望最大化critic對生成樣本的評分
g_loss = -torch.mean(fake_output)
return g_loss
def weight_clipping(self, critic):
# 權(quán)重裁剪,限制critic參數(shù)范圍
for p in critic.parameters():
p.data.clamp_(-self.clip_value, self.clip_value)
5、帶梯度懲罰的WGAN(WGAN-GP)
WGAN-GP是對WGAN的進一步優(yōu)化,通過引入梯度懲罰項來滿足Lipschitz連續(xù)性約束:

其中是真實樣本和生成樣本之間的隨機插值點。這一改進避免了原始WGAN中權(quán)重裁剪可能帶來的容量限制和訓(xùn)練不穩(wěn)定問題。梯度懲罰使模型訓(xùn)練過程更加穩(wěn)定,同時減少了生成圖像中的偽影,提高了最終生成結(jié)果的質(zhì)量與真實度。WGAN-GP已成為許多高質(zhì)量圖像生成任務(wù)的首選損失函數(shù)。
PyTorch實現(xiàn):
class WGANGP:
def __init__(self, device, lambda_gp=10):
self.device = device
self.lambda_gp = lambda_gp
def discriminator_loss(self, real_output, fake_output, real_samples, fake_samples, discriminator):
# 基本的Wasserstein距離
d_loss = -torch.mean(real_output) + torch.mean(fake_output)
# 計算梯度懲罰
# 在真實和生成樣本之間隨機插值
alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=self.device)
interpolates = alpha * real_samples + (1 - alpha) * fake_samples
interpolates.requires_grad_(True)
# 計算判別器對插值樣本的輸出
d_interpolates = discriminator(interpolates)
# 計算梯度
fake_outputs = torch.ones_like(d_interpolates, device=self.device, requires_grad=False)
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake_outputs,
create_graph=True,
retain_graph=True,
only_inputs=True
)[0]
# 計算梯度L2范數(shù)
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
# 添加梯度懲罰項
d_loss = d_loss + self.lambda_gp * gradient_penalty
return d_loss
def generator_loss(self, fake_output):
# 與WGAN相同
g_loss = -torch.mean(fake_output)
return g_loss
6、條件生成對抗網(wǎng)絡(luò)(CGAN)
CGAN通過在生成器和判別器中引入條件信息(如類別標(biāo)簽),實現(xiàn)對生成過程的控制:

CGAN能夠生成特定類別的樣本,大大增強了模型的實用性,特別是在醫(yī)學(xué)影像等需要精確控制生成內(nèi)容的應(yīng)用場景中。通過條件控制,CGAN可以引導(dǎo)生成過程,使得生成結(jié)果滿足特定的語義或結(jié)構(gòu)要求,為個性化內(nèi)容生成提供了可靠技術(shù)支持。
PyTorch實現(xiàn):
class CGANLoss:
def __init__(self, device):
self.device = device
self.criterion = nn.BCELoss()
def discriminator_loss(self, real_output, fake_output):
# 條件GAN的判別器損失與原始GAN相似,只是輸入增加了條件信息
real_labels = torch.ones_like(real_output, device=self.device)
fake_labels = torch.zeros_like(fake_output, device=self.device)
real_loss = self.criterion(real_output, real_labels)
fake_loss = self.criterion(fake_output, fake_labels)
d_loss = real_loss + fake_loss
return d_loss
def generator_loss(self, fake_output):
# 與原始GAN相似
target_labels = torch.ones_like(fake_output, device=self.device)
g_loss = self.criterion(fake_output, target_labels)
return g_loss
# CGAN的網(wǎng)絡(luò)結(jié)構(gòu)示例
class ConditionalGenerator(nn.Module):
def __init__(self, latent_dim, n_classes, img_shape):
super(ConditionalGenerator, self).__init__()
self.img_shape = img_shape
self.label_emb = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
# 輸入是噪聲向量與條件拼接后的向量
nn.Linear(latent_dim + n_classes, 128),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(128, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z, labels):
# 條件嵌入
c = self.label_emb(labels)
# 拼接噪聲和條件
x = torch.cat([z, c], 1)
# 生成圖像
img = self.model(x)
img = img.view(img.size(0), *self.img_shape)
return img
7、信息最大化GAN(InfoGAN)
InfoGAN在無監(jiān)督學(xué)習(xí)框架下實現(xiàn)了對生成樣本特定屬性的控制,其核心思想是最大化潛在編碼與生成樣本之間的互信息:

其中是一個輔助網(wǎng)絡(luò),用于近似后驗分布,而表示互信息。InfoGAN能夠在無監(jiān)督的情況下學(xué)習(xí)數(shù)據(jù)的解耦表示,對于醫(yī)學(xué)圖像分析中的特征提取和異常檢測具有潛在價值。
PyTorch實現(xiàn):
class InfoGANLoss:
def __init__(self, device, lambda_info=1.0):
self.device = device
self.criterion = nn.BCELoss()
self.lambda_info = lambda_info
# 對于離散潛變量使用交叉熵?fù)p失
self.discrete_criterion = nn.CrossEntropyLoss()
# 對于連續(xù)潛變量使用高斯分布負(fù)對數(shù)似然
self.continuous_criterion = nn.MSELoss()
def discriminator_loss(self, real_output, fake_output):
# 判別器損失與原始GAN相同
real_labels = torch.ones_like(real_output, device=self.device)
fake_labels = torch.zeros_like(fake_output, device=self.device)
real_loss = self.criterion(real_output, real_labels)
fake_loss = self.criterion(fake_output, fake_labels)
d_loss = real_loss + fake_loss
return d_loss
def generator_info_loss(self, fake_output, q_discrete, q_continuous, c_discrete, c_continuous):
# 生成器損失部分(欺騙判別器)
target_labels = torch.ones_like(fake_output, device=self.device)
g_loss = self.criterion(fake_output, target_labels)
# 互信息損失部分
# 離散潛變量的互信息損失
info_disc_loss = self.discrete_criterion(q_discrete, c_discrete)
# 連續(xù)潛變量的互信息損失
info_cont_loss = self.continuous_criterion(q_continuous, c_continuous)
# 總損失
total_loss = g_loss + self.lambda_info * (info_disc_loss + info_cont_loss)
return total_loss, info_disc_loss, info_cont_loss
8、能量基礎(chǔ)GAN(EBGAN)
EBGAN將判別器視為能量函數(shù),而非傳統(tǒng)的概率函數(shù),其損失函數(shù)為:

其中表示,是邊界參數(shù)。EBGAN通過能量視角重新詮釋GAN訓(xùn)練過程,為模型設(shè)計提供了新的思路,尤其適合處理具有復(fù)雜分布的醫(yī)學(xué)數(shù)據(jù)。EBGAN的判別器不再輸出概率值,而是輸出能量分?jǐn)?shù),真實樣本的能量應(yīng)當(dāng)?shù)陀谏蓸颖尽?/p>
PyTorch實現(xiàn):
class EBGANLoss:
def __init__(self, device, margin=10.0):
self.device = device
self.margin = margin
def discriminator_loss(self, real_energy, fake_energy):
# 判別器的目標(biāo)是降低真實樣本的能量,提高生成樣本的能量(直到邊界值)
# 對生成樣本的損失使用hinge loss
hinge_loss = torch.mean(torch.clamp(self.margin - fake_energy, min=0))
# 總損失
d_loss = torch.mean(real_energy) + hinge_loss
return d_loss
def generator_loss(self, fake_energy):
# 生成器的目標(biāo)是降低生成樣本的能量
g_loss = torch.mean(fake_energy)
return g_loss
9、f-GAN
f-GAN是一種基于f-散度的GAN框架,可以統(tǒng)一多種GAN變體:

其中是凸函數(shù)的Fenchel共軛。通過選擇不同的函數(shù),f-GAN可以實現(xiàn)對不同散度的優(yōu)化,如KL散度、JS散度、Hellinger距離等,為特定應(yīng)用場景提供了更靈活的選擇。f-GAN為GAN提供了一個統(tǒng)一的理論框架,使研究者能夠根據(jù)具體任務(wù)需求設(shè)計最適合的散度度量。
PyTorch實現(xiàn):
class FGANLoss:
def __init__(self, device, divergence_type='kl'):
self.device = device
self.divergence_type = divergence_type
def activation_function(self, x):
# 不同散度對應(yīng)的激活函數(shù)
if self.divergence_type == 'kl': # KL散度
return x
elif self.divergence_type == 'js': # JS散度
return torch.log(1 + torch.exp(x))
elif self.divergence_type == 'hellinger': # Hellinger距離
return 1 - torch.exp(-x)
elif self.divergence_type == 'total_variation': # 總變差距離
return 0.5 * torch.tanh(x)
else:
return x # 默認(rèn)為KL散度
def conjugate_function(self, x):
# 不同散度的Fenchel共軛
if self.divergence_type == 'kl':
return torch.exp(x - 1)
elif self.divergence_type == 'js':
return -torch.log(2 - torch.exp(x))
elif self.divergence_type == 'hellinger':
return x / (1 - x)
elif self.divergence_type == 'total_variation':
return x
else:
return torch.exp(x - 1) # 默認(rèn)為KL散度
def discriminator_loss(self, real_output, fake_output):
# 判別器損失
# 注意:在f-GAN中,通常D的輸出需要經(jīng)過激活函數(shù)處理
activated_real = self.activation_function(real_output)
d_loss = -torch.mean(activated_real) + torch.mean(self.conjugate_function(fake_output))
return d_loss
def generator_loss(self, fake_output):
# 生成器損失
activated_fake = self.activation_function(fake_output)
g_loss = -torch.mean(activated_fake)
return g_loss
總結(jié)
本文通過詳細(xì)分析GAN的經(jīng)典損失函數(shù)及其多種變體,揭示了不同類型損失函數(shù)各自的優(yōu)勢:LS-GAN訓(xùn)練穩(wěn)定性好,WGAN-GP生成圖像清晰度高,而條件類GAN如CGAN則在可控性方面表現(xiàn)突出。
這介紹代碼對于相關(guān)領(lǐng)域的GAN應(yīng)用具有重要參考價值。未來研究可進一步探索損失函數(shù)組合優(yōu)化策略,以及針對特定圖像模態(tài)的自適應(yīng)損失函數(shù)設(shè)計。
https://avoid.overfit.cn/post/70d0b38796174d1c82ac048375ff17c4
熱門跟貼