Uploaded by Rinki

desnoinig

advertisement
6/17/23, 11:47 PM
Untitled1
In [74]: import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
# Define the VAE model
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Flatten()
)
self.fc_mu = nn.Linear(32 * 7 * 7, 32)
self.fc_logvar = nn.Linear(32 * 7 * 7, 32)
self.decoder = nn.Sequential(
nn.Linear(32, 32 * 7 * 7),
nn.Unflatten(1, (32, 7, 7)),
nn.ReLU(),
nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),
nn.Sigmoid()
)
def encode(self, x):
x = self.encoder(x)
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)
return mu, logvar
def decode(self, z):
x = self.decoder(z)
return x
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std
return z
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_recon = self.decode(z)
return x_recon, mu, logvar
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set hyperparameters
file:///C:/Users/rinki/Downloads/Untitled1.html
1/3
6/17/23, 11:47 PM
Untitled1
batch_size = 128
epochs = 2
learning_rate = 1e-3
image_size = 28
# Load the MNIST dataset
transform = transforms.Compose([
transforms.Grayscale(), # Convert image to grayscale
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transfo
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Initialize the VAE model
model = VAE().to(device)
# Define the loss function
def loss_function(x_recon, x, mu, logvar):
reconstruction_loss = nn.functional.binary_cross_entropy(x_recon, x, reducti
kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return reconstruction_loss + kl_divergence
# Set the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loop
model.train()
for epoch in range(1, epochs + 1):
total_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
# Forward pass
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
# Calculate loss
loss = loss_function(recon_batch, data, mu, logvar)
# Backward pass and optimization
loss.backward()
optimizer.step()
total_loss += loss.item()
# Print progress
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch, epochs, total_loss / len(t
# Save the trained model
torch.save(model.state_dict(), 'vae_model.pth')
# Load the trained model
model.load_state_dict(torch.load('vae_model.pth'))
# Load your own image
your_image = Image.open('m.webp').convert('L') # Convert to grayscale
your_image = your_image.resize((image_size, image_size))
your_image = transform(your_image).unsqueeze(0).to(device)
file:///C:/Users/rinki/Downloads/Untitled1.html
2/3
6/17/23, 11:47 PM
Untitled1
# Denoise your image
model.eval()
with torch.no_grad():
denoised_image, _, _ = model(your_image)
# Save the denoised image
save_image(denoised_image.view(1, 1, image_size, image_size), 'denoised_image.jp
# Display the original and denoised images
fig, axs = plt.subplots(1, 2)
axs[0].imshow(your_image.squeeze().cpu().numpy())
axs[0].set_title('Original Image')
axs[1].imshow(denoised_image.squeeze().cpu().numpy())
axs[1].set_title('Denoised Image')
plt.show()
Epoch [1/2], Loss: -3876699.6050
Epoch [2/2], Loss: -4502466.7196
file:///C:/Users/rinki/Downloads/Untitled1.html
3/3
Download