122 lines
3.9 KiB
Python
122 lines
3.9 KiB
Python
import os
|
||
import tempfile
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torch.optim as optim
|
||
from ray import train
|
||
from ray.train import Checkpoint
|
||
from torch.utils.data import DataLoader
|
||
from torchvision import datasets, transforms
|
||
|
||
|
||
class ConvNet(nn.Module):
    """Small CNN: one 3x3 convolution followed by a single linear classifier.

    The linear layer expects 192 features per sample, so ``forward`` reshapes
    the pooled activations to ``(-1, 192)`` and emits 10 log-probabilities
    per row.
    """

    def __init__(self):
        super().__init__()
        # The architecture is kept fixed in this example for simplicity.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return log-softmax scores over the 10 output classes for ``x``."""
        pooled = F.max_pool2d(self.conv1(x), 3)
        features = F.relu(pooled).view(-1, 192)
        logits = self.fc(features)
        return F.log_softmax(logits, dim=1)
|
||
|
||
|
||
# Number of train/evaluate rounds executed by main().
EPOCH = 10
# Sample-count threshold after which train_func() stops consuming batches.
TRAIN_SIZE = 512
# Sample-count threshold after which test_func() stops consuming batches.
TEST_SIZE = 256
|
||
|
||
|
||
def train_func(model, optimizer, train_loader):
    """Run one truncated training pass over ``train_loader``.

    Batches are moved to CUDA when available (CPU otherwise), and the model
    is updated with the negative log-likelihood loss. The pass stops once
    ``batch_index * batch_length`` exceeds ``TRAIN_SIZE``.

    Args:
        model: Module whose output is fed to ``F.nll_loss``.
        optimizer: Optimizer stepping the parameters of ``model``.
        train_loader: Iterable yielding ``(data, target)`` batches.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.train()

    for step, (inputs, labels) in enumerate(train_loader):
        # Cap the amount of data so the example finishes quickly.
        if step * len(inputs) > TRAIN_SIZE:
            break

        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = F.nll_loss(predictions, labels)
        batch_loss.backward()
        optimizer.step()
|
||
|
||
|
||
def test_func(model, data_loader):
    """Evaluate ``model`` on a truncated pass over ``data_loader``.

    Batches are moved to CUDA when available (CPU otherwise). Evaluation
    stops once ``batch_index * batch_length`` exceeds ``TEST_SIZE``.

    Args:
        model: Module producing one score per class for each sample.
        data_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        Fraction of evaluated samples whose highest-scoring class equals the
        target, or ``0.0`` when the loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # We set this just for the example to run quickly.
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            # Gradients are already disabled here, so the legacy ``.data``
            # detour is unnecessary; argmax yields the predicted class index.
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total
|
||
|
||
|
||
# The entry point must be named ``main`` and accept exactly one dict
# argument, ``config``.
def main(config):
    """Fine-tune a pretrained ``ConvNet`` on MNIST and report accuracy to Tune.

    Args:
        config: Dict read with the following keys:
            ``"dataset"``: root directory of an already-present MNIST copy
                (the datasets are opened with ``download=False``).
            ``"model"``: directory containing the pretrained ``model.pth``
                state dict.
            ``"lr"``: learning rate passed to SGD.
            ``"momentum"``: momentum passed to SGD.
    """
    # Data setup
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])

    # To read a custom dataset, access config["dataset"].
    train_dataset = datasets.MNIST(config["dataset"], train=True, download=False,
                                   transform=mnist_transforms)

    test_dataset = datasets.MNIST(config["dataset"], train=False, download=False,
                                  transform=mnist_transforms)

    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True)
    test_loader = DataLoader(
        test_dataset,
        batch_size=64,
        shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ConvNet()

    # To load custom pretrained model weights, access config["model"].
    # map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
    # CPU-only host; the model is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file -- only point config["model"]
    # at trusted checkpoints.
    pretrained_state_dict = torch.load(
        os.path.join(config["model"], 'model.pth'), map_location="cpu")
    model.load_state_dict(pretrained_state_dict)

    model.to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])

    for i in range(EPOCH):
        train_func(model, optimizer, train_loader)
        acc = test_func(model, test_loader)

        # To save a checkpoint, the code must follow the pattern below.
        # The report call stays inside the ``with`` block so the temporary
        # directory still exists while the checkpoint is handed over.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            if (i + 1) % 5 == 0:
                # This saves the model to the trial directory
                torch.save(
                    model.state_dict(),
                    os.path.join(temp_checkpoint_dir, "model.pth")
                )
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Every epoch must report its metrics to Tune, and the metric
            # name must match the one given in the launch command.
            # Send the current training result back to Tune
            train.report({"mean_accuracy": acc}, checkpoint=checkpoint)
|