Update test1.py
This commit is contained in:
parent
d08d381cf6
commit
4f4d4025c6
19
test1.py
19
test1.py
|
@ -68,32 +68,39 @@ def test_func(model, data_loader):
|
|||
# 主函数必须叫main,且只能有一个字典参数config
|
||||
def main(config):
|
||||
# Data Setup
|
||||
|
||||
|
||||
mnist_transforms = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))])
|
||||
|
||||
# 自定义读取数据的话,则要访问config["dataset_path"]
|
||||
# 读取自定义数据集的话,则要访问config["dataset"]
|
||||
train_dataset = datasets.MNIST(config["dataset"], train=True, download=False,
|
||||
transform=mnist_transforms)
|
||||
|
||||
test_dataset = datasets.MNIST(config["dataset"], train=False, download=False,
|
||||
transform=mnist_transforms)
|
||||
|
||||
train_loader = DataLoader(
|
||||
datasets.MNIST("/opt/ray/MNIST", train=True, download=True, transform=mnist_transforms),
|
||||
train_dataset,
|
||||
batch_size=64,
|
||||
shuffle=True)
|
||||
test_loader = DataLoader(
|
||||
datasets.MNIST("/opt/ray/MNIST", train=False, transform=mnist_transforms),
|
||||
test_dataset,
|
||||
batch_size=64,
|
||||
shuffle=True)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = ConvNet()
|
||||
|
||||
# 读取自定义预训练模型参数文件的话,则要访问config["model"]
|
||||
pretrained_state_dict = torch.load(config["model"])
|
||||
model.load_state_dict(pretrained_state_dict)
|
||||
|
||||
model.to(device)
|
||||
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(), lr=config["lr"], momentum=config["momentum"])
|
||||
|
||||
|
||||
for i in range(EPOCH):
|
||||
train_func(model, optimizer, train_loader)
|
||||
acc = test_func(model, test_loader)
|
||||
|
|
Loading…
Reference in New Issue