ray_test/test1.py

122 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import tempfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ray import train
from ray.train import Checkpoint
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class ConvNet(nn.Module):
    """Minimal CNN classifier: one convolution, max-pooling, one linear layer.

    The convolution maps 1 input channel to 3 output channels and the linear
    layer maps 192 features to 10 classes. ``forward`` reshapes the pooled
    activations to rows of 192 values, so inputs must produce a multiple of
    192 features per batch (presumably 28x28 single-channel images, which
    give 3 * 8 * 8 = 192 per sample -- confirm against the dataset in use).
    """

    def __init__(self):
        super().__init__()
        # The architecture is intentionally fixed to keep the example simple.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return log-softmax scores over the 10 classes for batch ``x``."""
        pooled = F.max_pool2d(self.conv1(x), 3)
        features = F.relu(pooled).view(-1, 192)
        logits = self.fc(features)
        return F.log_softmax(logits, dim=1)
# Number of train/evaluate rounds executed by main().
EPOCH = 10
# Early-exit thresholds that keep the example fast: train_func() and
# test_func() stop once batch_idx * len(batch) exceeds the respective value.
TRAIN_SIZE = 512
TEST_SIZE = 256
def train_func(model, optimizer, train_loader):
    """Run one shortened training pass over ``train_loader``.

    Puts ``model`` in training mode and performs one optimizer step per batch
    using the negative log-likelihood loss. Iteration stops as soon as
    ``batch index * batch length`` exceeds ``TRAIN_SIZE``.

    Args:
        model: Module whose output is passed to ``F.nll_loss``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        train_loader: Iterable yielding ``(data, target)`` pairs.

    Returns:
        None.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        # Cap the amount of data seen so the example finishes quickly.
        if step * len(inputs) > TRAIN_SIZE:
            break
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
def test_func(model, data_loader):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(data_loader):
# We set this just for the example to run quickly.
if batch_idx * len(data) > TEST_SIZE:
break
data, target = data.to(device), target.to(device)
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
return correct / total
# The entry function must be named `main` and accept exactly one dict
# argument, `config`.
def main(config):
    """Fine-tune a pretrained ``ConvNet`` on MNIST, reporting accuracy to Tune.

    For each of ``EPOCH`` epochs the model is trained with ``train_func``,
    evaluated with ``test_func``, and the accuracy is reported under the
    metric name ``"mean_accuracy"``. Every fifth epoch the model weights are
    also saved and attached to the report as a checkpoint.

    Args:
        config: Dict providing
            ``"dataset"``: root directory passed to ``datasets.MNIST``
            (``download=False``, so the data is expected to exist already);
            ``"model"``: directory containing the pretrained ``model.pth``;
            ``"lr"`` and ``"momentum"``: SGD hyperparameters.
    """
    # Data setup.
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # A custom dataset location is read from config["dataset"].
    train_dataset = datasets.MNIST(
        config["dataset"], train=True, download=False, transform=mnist_transforms
    )
    test_dataset = datasets.MNIST(
        config["dataset"], train=False, download=False, transform=mnist_transforms
    )
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ConvNet()
    # Custom pretrained weights are read from the directory in config["model"].
    # map_location makes weights saved from a CUDA device loadable on a
    # CPU-only host (and vice versa).
    # NOTE(review): torch.load unpickles the file; only load trusted weights,
    # or pass weights_only=True if the installed torch version supports it.
    pretrained_state_dict = torch.load(
        os.path.join(config["model"], "model.pth"), map_location=device
    )
    model.load_state_dict(pretrained_state_dict)
    model.to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"]
    )

    for i in range(EPOCH):
        train_func(model, optimizer, train_loader)
        acc = test_func(model, test_loader)

        # Checkpoints must be saved following the pattern below.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            if (i + 1) % 5 == 0:
                # This saves the model to the trial directory.
                torch.save(
                    model.state_dict(),
                    os.path.join(temp_checkpoint_dir, "model.pth"),
                )
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Every epoch must report its metrics to Tune, and the metric name
            # must match the one given in the launch command. The report stays
            # inside the `with` block so the temporary checkpoint directory
            # still exists while it is being reported.
            train.report({"mean_accuracy": acc}, checkpoint=checkpoint)