Implementing a Multilayer Perceptron in PyTorch

import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
from torch.utils import data
import matplotlib.pyplot as plt

# Define the model: flatten each 1x28x28 image into a 784-dim vector,
# one hidden layer of 1024 units with ReLU, 10 output classes
i_size, h_size, o_size = 784, 1024, 10
net = nn.Sequential(nn.Flatten(), nn.Linear(i_size, h_size), nn.ReLU(), nn.Linear(h_size, o_size))

# Initialize model parameters
def init_weights(m):
    if type(m) == nn.Linear:  # only the linear layers carry weights to initialize
        nn.init.normal_(m.weight, std=0.01)  # fill the weights from a normal distribution

net.apply(init_weights)

# Training hyperparameters
b_size, lr, n_epochs = 256, 0.1, 1000

# Cross-entropy loss: expects raw logits (hence no softmax layer in the model)
# and averages over the batch by default
loss = nn.CrossEntropyLoss()

# Optimizer: minibatch SGD over all model parameters
optimizer = torch.optim.SGD(net.parameters(), lr=lr)

# Load the Fashion-MNIST dataset and wrap it in data loaders
def load_data(batch_size):
    trans = transforms.ToTensor()
    train_set = torchvision.datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)
    test_set = torchvision.datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)
    return data.DataLoader(train_set, batch_size, shuffle=True), data.DataLoader(test_set, batch_size, shuffle=False)

train_iter, test_iter = load_data(b_size)
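
# Optional sanity check (a minimal sketch): each Fashion-MNIST image is 1x28x28,
# so nn.Flatten() turns a batch into shape (batch_size, 784), matching i_size above.
X0, y0 = next(iter(train_iter))
print(X0.shape, y0.shape)  # expected: torch.Size([256, 1, 28, 28]) torch.Size([256])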

# One training epoch: update the parameters on the training set, then count
# correct predictions on both the training and test sets
def train_epoch(net, train_iter, test_iter, loss, optimizer):
    net.train()
    train_correct = 0
    test_correct = 0
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y)
        optimizer.zero_grad()
        l.backward()  # CrossEntropyLoss already returns the batch mean
        optimizer.step()
        with torch.no_grad():
            train_correct += evaluate_accuracy(net, X, y)
    net.eval()  # evaluation mode for the test pass (no effect here, but good practice)
    with torch.no_grad():
        for X, y in test_iter:
            test_correct += evaluate_accuracy(net, X, y)
    return train_correct, test_correct

# Count the number of correct predictions for one batch
def evaluate_accuracy(net, X, y):
    y_hat = net(X).argmax(axis=1)  # predicted class index per example
    correct = y_hat.type(y.dtype) == y
    return float(correct.type(y.dtype).sum())


t1s, t2s = [], []  # per-epoch accuracy (%) on the training and test sets

def train(net, train_iter, test_iter, loss, num_epochs, optimizer):
    for epoch in range(num_epochs):
        t1, t2 = train_epoch(net, train_iter, test_iter, loss, optimizer)
        print(f'epoch {epoch + 1}: train={t1}, test={t2}')
        t1s.append(t1 / 600)  # 60,000 training images -> percent
        t2s.append(t2 / 100)  # 10,000 test images -> percent
    # Plot the accuracy curves once training has finished
    plt.ylim((0, 100))
    plt.plot(range(1, num_epochs + 1), t1s, label='train acc (%)')
    plt.plot(range(1, num_epochs + 1), t2s, label='test acc (%)')
    plt.xlabel('epoch')
    plt.legend()
    plt.show()

train(net, train_iter, test_iter, loss, n_epochs, optimizer)
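
# A minimal sketch of running inference with the trained net: take one test
# batch and compare the predicted class indices against the labels.
net.eval()
with torch.no_grad():
    X, y = next(iter(test_iter))
    preds = net(X).argmax(axis=1)
    print('predicted:', preds[:10].tolist())
    print('labels:   ', y[:10].tolist())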
