Update test1.py

This commit is contained in:
fanshuai 2025-01-23 08:26:57 +08:00
parent d08d381cf6
commit 4f4d4025c6
1 changed files with 121 additions and 114 deletions

View File

@ -68,32 +68,39 @@ def test_func(model, data_loader):
# 主函数必须叫main且只能有一个字典参数config
def main(config):
# Data Setup
mnist_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# 自定义读取数据的话则要访问config["dataset_path"]
# 读取自定义数据集的话则要访问config["dataset"]
train_dataset = datasets.MNIST(config["dataset"], train=True, download=False,
transform=mnist_transforms)
test_dataset = datasets.MNIST(config["dataset"], train=False, download=False,
transform=mnist_transforms)
train_loader = DataLoader(
datasets.MNIST("/opt/ray/MNIST", train=True, download=True, transform=mnist_transforms),
train_dataset,
batch_size=64,
shuffle=True)
test_loader = DataLoader(
datasets.MNIST("/opt/ray/MNIST", train=False, transform=mnist_transforms),
test_dataset,
batch_size=64,
shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet()
# 读取自定义预训练模型参数文件的话则要访问config["model"]
pretrained_state_dict = torch.load(config["model"])
model.load_state_dict(pretrained_state_dict)
model.to(device)
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
for i in range(EPOCH):
train_func(model, optimizer, train_loader)
acc = test_func(model, test_loader)