Update test1.py

This commit is contained in:
fanshuai 2025-01-23 08:26:57 +08:00
parent d08d381cf6
commit 4f4d4025c6
1 changed files with 121 additions and 114 deletions

235
test1.py
View File

@ -1,114 +1,121 @@
import os
import tempfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ray import train
from ray.train import Checkpoint
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
# In this example, we don't change the model architecture
# due to simplicity.
self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
self.fc = nn.Linear(192, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 3))
x = x.view(-1, 192)
x = self.fc(x)
return F.log_softmax(x, dim=1)
EPOCH = 10
TRAIN_SIZE = 512
TEST_SIZE = 256
def train_func(model, optimizer, train_loader):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
# We set this just for the example to run quickly.
if batch_idx * len(data) > TRAIN_SIZE:
return
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
def test_func(model, data_loader):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(data_loader):
# We set this just for the example to run quickly.
if batch_idx * len(data) > TEST_SIZE:
break
data, target = data.to(device), target.to(device)
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
return correct / total
# 主函数必须叫main且只能有一个字典参数config
def main(config):
# Data Setup
mnist_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# 自定义读取数据的话则要访问config["dataset_path"]
train_loader = DataLoader(
datasets.MNIST("/opt/ray/MNIST", train=True, download=True, transform=mnist_transforms),
batch_size=64,
shuffle=True)
test_loader = DataLoader(
datasets.MNIST("/opt/ray/MNIST", train=False, transform=mnist_transforms),
batch_size=64,
shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet()
model.to(device)
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
for i in range(EPOCH):
train_func(model, optimizer, train_loader)
acc = test_func(model, test_loader)
# 欲保存checkpoint则必须按照以下代码规范编写
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
checkpoint = None
if (i + 1) % 5 == 0:
# This saves the model to the trial directory
torch.save(
model.state_dict(),
os.path.join(temp_checkpoint_dir, "model.pth")
)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
# 每一个epoch必须向tune报告指标指标名称必须与启动命令里的指标名称相同
# Send the current training result back to Tune
train.report({"mean_accuracy": acc}, checkpoint=checkpoint)
import os
import tempfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ray import train
from ray.train import Checkpoint
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
# In this example, we don't change the model architecture
# due to simplicity.
self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
self.fc = nn.Linear(192, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 3))
x = x.view(-1, 192)
x = self.fc(x)
return F.log_softmax(x, dim=1)
EPOCH = 10
TRAIN_SIZE = 512
TEST_SIZE = 256
def train_func(model, optimizer, train_loader):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
# We set this just for the example to run quickly.
if batch_idx * len(data) > TRAIN_SIZE:
return
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
def test_func(model, data_loader):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(data_loader):
# We set this just for the example to run quickly.
if batch_idx * len(data) > TEST_SIZE:
break
data, target = data.to(device), target.to(device)
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
return correct / total
# 主函数必须叫main且只能有一个字典参数config
def main(config):
# Data Setup
mnist_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# 读取自定义数据集的话则要访问config["dataset"]
train_dataset = datasets.MNIST(config["dataset"], train=True, download=False,
transform=mnist_transforms)
test_dataset = datasets.MNIST(config["dataset"], train=False, download=False,
transform=mnist_transforms)
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True)
test_loader = DataLoader(
test_dataset,
batch_size=64,
shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet()
# 读取自定义预训练模型参数文件的话则要访问config["model"]
pretrained_state_dict = torch.load(config["model"])
model.load_state_dict(pretrained_state_dict)
model.to(device)
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
for i in range(EPOCH):
train_func(model, optimizer, train_loader)
acc = test_func(model, test_loader)
# 欲保存checkpoint则必须按照以下代码规范编写
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
checkpoint = None
if (i + 1) % 5 == 0:
# This saves the model to the trial directory
torch.save(
model.state_dict(),
os.path.join(temp_checkpoint_dir, "model.pth")
)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
# 每一个epoch必须向tune报告指标指标名称必须与启动命令里的指标名称相同
# Send the current training result back to Tune
train.report({"mean_accuracy": acc}, checkpoint=checkpoint)