PyTorch实现Dropout

在模型中加入
nn.Dropout(d)
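其中d是dropout概率,即训练时每个单元被随机置零(丢弃)的概率:当d=0时所有单元均保留,d=1时所有单元均丢弃;被保留的单元会按1/(1-d)进行缩放。注意Dropout只在训练模式(net.train())下生效,在评估模式(net.eval())下不会丢弃任何单元。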

完整示例代码如下:
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils import data

# Layer widths: input features, two hidden layers, output classes.
n_i = 784
n_h1 = 256
n_h2 = 256
n_o = 10
# Drop probabilities used after the first and second hidden layers.
dropout1 = 0.2
dropout2 = 0.5

# Model: a multilayer perceptron with dropout after each hidden activation.
net = nn.Sequential(
    nn.Flatten(),
    # First hidden block
    nn.Linear(n_i, n_h1), nn.ReLU(), nn.Dropout(dropout1),
    # Second hidden block
    nn.Linear(n_h1, n_h2), nn.ReLU(), nn.Dropout(dropout2),
    # Output layer producing one score per class
    nn.Linear(n_h2, n_o),
)
# Initialize model parameters
def init_weights(m):
    """Initialize the weights of a fully connected layer.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    submodule. Modules that are not ``nn.Linear`` are left untouched.

    Args:
        m: A submodule of the network.
    """
    # ``m.type`` is a bound method of nn.Module, so comparing it with
    # ``nn.Linear`` is never true; check the module's class instead.
    if isinstance(m, nn.Linear):
        # ``normal_`` is the in-place initializer; ``nn.init.normal`` is its
        # deprecated alias.
        nn.init.normal_(m.weight, std=0.01)
# Run the initializer over every submodule of the network.
net.apply(init_weights)

# Training hyperparameters
num_epochs = 100
lr = 0.5
batch_size = 256

# Loss function and optimizer
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=lr)

# One training epoch followed by an evaluation pass
def epoch(net, train_iter, test_iter, loss, optimizer):
    """Train ``net`` for one pass over ``train_iter``, then score ``test_iter``.

    Args:
        net: The model to train.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        test_iter: Iterable yielding ``(X, y)`` evaluation batches.
        loss: Callable computing the loss from ``(y_hat, y)``.
        optimizer: Optimizer that updates the parameters of ``net``.

    Returns:
        A tuple ``(train_correct, test_correct)`` holding the sums of the
        values returned by ``evaluate_accuracy`` over the training batches
        and the test batches respectively.
    """
    train_correct, test_correct = 0, 0

    # Training mode: dropout layers are active.
    net.train()
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y)
        optimizer.zero_grad()
        # mean() leaves an already-averaged scalar loss unchanged and
        # reduces a per-sample loss to a scalar.
        l.mean().backward()
        optimizer.step()
        # Scored after the parameter update, still in training mode.
        with torch.no_grad():
            train_correct += evaluate_accuracy(net, X, y)

    # Evaluation mode: dropout must be switched off for the test pass,
    # otherwise units are still dropped at random while scoring.
    net.eval()
    with torch.no_grad():
        for X, y in test_iter:
            test_correct += evaluate_accuracy(net, X, y)
    return train_correct, test_correct

# Count the correct predictions on one batch
def evaluate_accuracy(net, X, y):
    """Return how many samples of the batch ``net`` classifies correctly.

    Despite the name, the result is a count, not a ratio.

    Args:
        net: Callable mapping ``X`` to a tensor of per-class scores; the
            index of the largest score along dimension 1 is the prediction.
        X: Batch of inputs passed to ``net``.
        y: Tensor of target class indices.

    Returns:
        The number of predictions equal to the corresponding target, as a
        float.
    """
    predictions = net(X).argmax(dim=1)
    # Cast to the label dtype so both sides of the comparison match.
    matches = predictions.type(y.dtype) == y
    return float(matches.sum())

# Load the dataset
def load_data(batch_size, root='../data'):
    """Build FashionMNIST data loaders, downloading the data if necessary.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory in which the dataset is stored or downloaded to.

    Returns:
        A tuple ``(train_loader, test_loader)``. The training loader
        shuffles its samples; the test loader keeps the dataset order.
    """
    trans = transforms.ToTensor()
    train_set = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=True)
    test_set = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=True)
    train_loader = data.DataLoader(train_set, batch_size, shuffle=True)
    test_loader = data.DataLoader(test_set, batch_size, shuffle=False)
    return train_loader, test_loader

# Build the training and test data loaders.
train_iter, test_iter = load_data(batch_size=batch_size)

# Training: per-epoch accuracy histories (training, test) filled by train().
t1s = []
t2s = []
def train(net, train_iter, test_iter, loss, num_epochs, optimizer):
    """Train ``net`` for ``num_epochs`` epochs, logging and plotting accuracy.

    After every epoch the correct-prediction counts are printed and the
    corresponding accuracies are appended to the module-level lists ``t1s``
    (training) and ``t2s`` (test). The accuracy curves are displayed every
    ten epochs and once more after the final epoch.

    Args:
        net: The model to train.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        test_iter: Iterable yielding ``(X, y)`` evaluation batches.
        loss: Callable computing the loss from ``(y_hat, y)``.
        num_epochs: Number of epochs to run.
        optimizer: Optimizer that updates the parameters of ``net``.
    """
    for epoch_idx in range(num_epochs):
        t1, t2 = epoch(net, train_iter, test_iter, loss, optimizer)
        print(f'epoch{epoch_idx + 1}: train={t1}, test={t2}')
        # Convert counts to percentages. The divisors are hard-coded for the
        # FashionMNIST split sizes (60000 training / 10000 test samples).
        t1s.append(t1/600)
        t2s.append(t2/100)

        # Draw only when the figure is about to be shown, so each window
        # holds a single pair of curves, and always show the final state.
        is_last = epoch_idx == num_epochs - 1
        if epoch_idx % 10 == 1 or is_last:
            plt.ylim((0, 100))
            # Index by history length so the x and y sizes always agree,
            # even if the histories already hold entries from an earlier run.
            plt.plot(range(len(t1s)), t1s)
            plt.plot(range(len(t2s)), t2s)
            plt.show()

# Run the full training loop with the model, data, loss and optimizer
# configured above.
train(
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    loss=loss,
    num_epochs=num_epochs,
    optimizer=optimizer,
)

    所属分类:机器学习     发表于2022-02-10