import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# LeNet5 is defined in a sibling model.py; a plain import works when this
# script is run directly (a relative import would fail under __main__).
import model

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        # The model is expected to return log-probabilities, so the
        # negative log-likelihood loss applies directly.
        loss = F.nll_loss(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

if __name__ == '__main__':
    BATCH_SIZE = 512
    EPOCHS = 20
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.LeNet5().to(DEVICE)
    optimizer = optim.Adam(model.parameters())

    # MNIST training set, normalized with the dataset's usual mean and std.
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)

    for epoch in range(1, EPOCHS + 1):
        train(model, DEVICE, train_loader, optimizer, epoch)

    torch.save(model.state_dict(), 'model.ckpt')
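
# ---------------------------------------------------------------------------
# The actual model.py is not part of this script. The commented sketch below
# is only an assumption about what that module might contain; it is not the
# author's definition. It illustrates the contract the training loop relies
# on: LeNet5.forward() must return log-probabilities, because the loss above
# is F.nll_loss.
#
# # model.py (hypothetical sketch)
# import torch.nn as nn
# import torch.nn.functional as F
#
# class LeNet5(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv1 = nn.Conv2d(1, 6, 5, padding=2)   # 1x28x28 -> 6x28x28
#         self.conv2 = nn.Conv2d(6, 16, 5)             # 6x14x14 -> 16x10x10
#         self.fc1 = nn.Linear(16 * 5 * 5, 120)
#         self.fc2 = nn.Linear(120, 84)
#         self.fc3 = nn.Linear(84, 10)
#
#     def forward(self, x):
#         x = F.max_pool2d(F.relu(self.conv1(x)), 2)   # -> 6x14x14
#         x = F.max_pool2d(F.relu(self.conv2(x)), 2)   # -> 16x5x5
#         x = x.view(x.size(0), -1)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         return F.log_softmax(self.fc3(x), dim=1)     # log-probs for nll_loss
# ---------------------------------------------------------------------------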