二次元画像生成AI python

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

# データセットの変換(リサイズと正規化)
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # ピクセル値を[-1, 1]の範囲にスケーリング
])

# AnimeFaceDatasetのロード
dataset = datasets.ImageFolder(root='C:/Users/tyosu/projects/anime_faces',transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Generator(生成モデル)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 64 * 64 * 3),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 3, 64, 64)

# Discriminator(判別モデル)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(64 * 64 * 3, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Linear(512, 256),
            nn.ReLU(True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        input_flat = input.view(-1, 64 * 64 * 3)
        return self.main(input_flat)

# モデルのインスタンス化
G = Generator()
D = Discriminator()

# ロス関数とオプティマイザ
criterion = nn.BCELoss()
optimizerD = optim.Adam(D.parameters(), lr=0.0002)
optimizerG = optim.Adam(G.parameters(), lr=0.0002)

# ランダムノイズ生成関数
def generate_noise(batch_size):
    return torch.randn(batch_size, 100)

# トレーニングループ
num_epochs = 50
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(tqdm(dataloader)):
        batch_size = real_images.size(0)

        # 本物の画像のラベルは1、偽物の画像のラベルは0
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Discriminatorの学習
        optimizerD.zero_grad()
        outputs = D(real_images)
        real_loss = criterion(outputs, real_labels)

        noise = generate_noise(batch_size)
        fake_images = G(noise)
        outputs = D(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizerD.step()

        # Generatorの学習
        optimizerG.zero_grad()
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)  # 生成画像を本物と認識させたい
        g_loss.backward()
        optimizerG.step()

    print(f'Epoch [{epoch+1}/{num_epochs}] | d_loss: {d_loss.item()} | g_loss: {g_loss.item()}')

    # 生成された画像を表示
    if (epoch + 1) % 10 == 0:
        fake_images = G(generate_noise(64)).detach().cpu()
        plt.imshow(fake_images[0].permute(1, 2, 0) * 0.5 + 0.5)
       

投稿者: chosuke

趣味はゲームやアニメや漫画などです

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です