MNIST_Example/inference.py

104 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
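The dataset used by this example is MNISTData.zip.
The dataset is laid out as follows:
MNISTData.zip
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte
The model file used by this example is checkpoint_lenet-1_1875.ckpt.
Usage notes:
1. This example needs three user-defined arguments: --multi_data_url, --pretrain_url and --result_url. All three must be defined in the task.
Their meanings are:
--multi_data_url is the OBS path of the dataset selected on the OpenI platform
--pretrain_url is the OBS path of the pretrained model file selected on the OpenI platform
--result_url is the OBS path to which the results are sent back to the OpenI platform
2. The dataset and the pretrained model file are copied in, and the results are sent back, by calling helper functions of the openi module; this script imports openi_multidataset_to_env (aliased as DatasetToEnv), pretrain_to_env and env_to_openi for that purpose.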
"""
import os
import argparse
import mindspore.nn as nn
import numpy as np
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore import Tensor
from dataset import create_dataset
from config import mnist_cfg as cfg
from lenet import LeNet5
from openi import openi_multidataset_to_env as DatasetToEnv
from openi import env_to_openi
from openi import pretrain_to_env
# Command-line interface of the inference script. Every option has a default,
# so the parser accepts an empty command line.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')

# Location of the selected dataset(s).
parser.add_argument(
    '--multi_data_url',
    help='path where the dataset is saved',
    default='[{}]',
    type=str,
)

# Location of the pretrained model file(s).
parser.add_argument(
    '--pretrain_url',
    default='[{}]',
    help='model to save/load',
)

# Destination the result folder is sent back to.
parser.add_argument(
    '--result_url',
    default='',
    help='result folder to save/load',
)

# Backend passed to the MindSpore context; restricted to the listed choices.
parser.add_argument(
    '--device_target',
    choices=['Ascend', 'GPU', 'CPU'],
    default="Ascend",
    type=str,
    help='device where the code will be implemented (default: Ascend)',
)
if __name__ == "__main__":
    # Script entry point: fetch the dataset and the pretrained checkpoint,
    # run LeNet-5 on the first sample of the MNIST test split, write the
    # predicted class to result.txt and send the result folder back.

    # parse_known_args tolerates extra flags; the unrecognized ones are
    # deliberately ignored.
    args, _ = parser.parse_known_args()

    # Working directories inside the inference image.
    data_dir = '/cache/data'
    pretrain_dir = '/cache/pretrain'
    result_dir = '/cache/result'
    for work_dir in (data_dir, pretrain_dir, result_dir):
        # exist_ok=True replaces the race-prone "check, then create" pattern.
        os.makedirs(work_dir, exist_ok=True)

    # Copy the dataset into the local environment.
    DatasetToEnv(args.multi_data_url, data_dir)
    # Copy the pretrained model file into the local environment.
    pretrain_to_env(args.pretrain_url, pretrain_dir)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    network = LeNet5(cfg.num_classes)
    # NOTE(review): loss and optimizer are handed to Model although this
    # script only calls predict(); kept so the Model is built exactly as before.
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    model = Model(network, net_loss, net_opt, metrics={"accuracy"}, amp_level="O2")

    print("============== Starting Testing ==============")
    # The checkpoint file name is fixed; it must be present in pretrain_dir
    # after pretrain_to_env has run.
    param_dict = load_checkpoint(os.path.join(pretrain_dir, "checkpoint_lenet-1_1875.ckpt"))
    load_param_into_net(network, param_dict)

    # batch_size=1, so the first item of the iterator holds exactly one sample.
    test_path = os.path.join(data_dir, "MNISTData", "test")
    ds_test = create_dataset(test_path, batch_size=1).create_dict_iterator()
    data = next(ds_test)
    labels = data["label"].asnumpy()

    print('Tensor:', Tensor(data['image']))
    output = model.predict(Tensor(data['image']))
    # Index of the highest score along axis 1 is the predicted class.
    predicted = np.argmax(output.asnumpy(), axis=1)
    print('predicted:', predicted)
    # The same array is logged a second time under 'pred:' to keep the
    # existing log output unchanged (it used to be a duplicate argmax).
    print('pred:', predicted)
    print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')

    # Append the prediction to the result file.
    # NOTE(review): the class index is written with two decimals (e.g. "7.00");
    # the format is kept as-is so existing consumers of result.txt still work.
    file_path = os.path.join(result_dir, 'result.txt')
    with open(file_path, 'a', encoding='utf-8') as file:
        file.write(" {}: {:.2f} \n".format("Predicted", predicted[0]))

    # Upload the result folder to the OpenI platform.
    env_to_openi(result_dir, args.result_url)