import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
# Dataset transforms: resize to 64x64, convert to a tensor, and scale
# pixel values to [-1, 1] so they match the generator's Tanh output range.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    # Per-channel mean/std for RGB. The original single-element tuple
    # (0.5,) only worked via broadcasting (and errors on old torchvision);
    # an explicit 3-tuple matches the 3-channel images ImageFolder yields.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Anime-face dataset; ImageFolder expects class subdirectories under root.
# NOTE(review): hard-coded absolute Windows path — consider making this a
# CLI argument or environment variable so the script is portable.
dataset = datasets.ImageFolder(root='C:/Users/tyosu/projects/anime_faces', transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
# Generator(生成モデル)
class Generator(nn.Module):
    """MLP generator: latent noise vector -> RGB image in [-1, 1].

    The keyword arguments are backward-compatible generalizations of the
    original hard-coded architecture (100-dim latent, 3x64x64 output).

    Args:
        latent_dim: Size of the input noise vector (default 100).
        img_channels: Number of output image channels (default 3).
        img_size: Output image height/width in pixels (default 64).
    """

    def __init__(self, latent_dim=100, img_channels=3, img_size=64):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_channels = img_channels
        self.img_size = img_size
        out_features = img_channels * img_size * img_size
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, out_features),
            # Tanh squashes outputs to [-1, 1], matching the Normalize
            # scaling applied to the real training images.
            nn.Tanh(),
        )

    def forward(self, input):
        """Map latent vectors of shape (N, latent_dim) to images of
        shape (N, img_channels, img_size, img_size)."""
        output = self.main(input)
        return output.view(-1, self.img_channels, self.img_size, self.img_size)
# Discriminator(判別モデル)
class Discriminator(nn.Module):
    """MLP discriminator: image -> probability the image is real.

    The keyword arguments are backward-compatible generalizations of the
    original hard-coded 3x64x64 input.

    Args:
        img_channels: Number of input image channels (default 3).
        img_size: Input image height/width in pixels (default 64).
    """

    def __init__(self, img_channels=3, img_size=64):
        super().__init__()
        self.in_features = img_channels * img_size * img_size
        self.main = nn.Sequential(
            nn.Linear(self.in_features, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Linear(512, 256),
            nn.ReLU(True),
            nn.Linear(256, 1),
            # Sigmoid yields a probability in [0, 1], as required by BCELoss.
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score a batch of images of shape (N, C, H, W); returns (N, 1)
        real-vs-fake probabilities."""
        # Flatten each image to a single feature vector for the MLP.
        input_flat = input.view(-1, self.in_features)
        return self.main(input_flat)
# Model instantiation.
G = Generator()
D = Discriminator()

# Loss and optimizers. beta1=0.5 (instead of Adam's default 0.9) is the
# standard setting for adversarial training from the DCGAN paper and the
# official PyTorch DCGAN tutorial; the default momentum is known to
# destabilize GAN training.
criterion = nn.BCELoss()
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
# ランダムノイズ生成関数
def generate_noise(batch_size):
return torch.randn(batch_size, 100)
# Training loop: alternate one discriminator step and one generator step
# per mini-batch.
num_epochs = 50
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(tqdm(dataloader)):
        batch_size = real_images.size(0)

        # Real images are labeled 1, fake images 0.
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # --- Discriminator step: learn to separate real from fake ---
        optimizerD.zero_grad()
        real_outputs = D(real_images)
        real_loss = criterion(real_outputs, real_labels)

        noise = generate_noise(batch_size)
        fake_images = G(noise)
        # detach() blocks gradients from flowing into G during the D update.
        fake_outputs = D(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizerD.step()

        # --- Generator step: make D classify the fakes as real ---
        optimizerG.zero_grad()
        gen_outputs = D(fake_images)
        g_loss = criterion(gen_outputs, real_labels)
        g_loss.backward()
        optimizerG.step()

    print(f'Epoch [{epoch+1}/{num_epochs}] | d_loss: {d_loss.item()} | g_loss: {g_loss.item()}')

    # Periodically visualize one generated sample.
    if (epoch + 1) % 10 == 0:
        # Sampling needs no autograd graph; no_grad saves memory.
        with torch.no_grad():
            fake_images = G(generate_noise(64)).cpu()
        # Un-normalize from [-1, 1] back to [0, 1] for display; clamp guards
        # against tiny numerical overshoot outside the valid image range.
        img = (fake_images[0].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
        plt.imshow(img)
        plt.axis('off')
        # BUG FIX: the original never rendered the figure — without show()
        # (or savefig) the plot is silently discarded when run as a script.
        plt.show()