MNIST_Example/inference.py

156 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
######################## LeNet inference example ########################
Run LeNet inference on the test dataset using a trained checkpoint (model) file.
"""
"""
######################## 推理环境使用说明 ########################
1、在推理环境中，需要先将数据集和模型文件从 obs 拷贝到推理镜像中；推理完成以后，需要将输出的结果拷贝回 obs。
(1)将数据集从obs拷贝到推理镜像中（数据集是文件夹，使用 mox.file.copy_parallel）
obs_data_url = args.data_url
args.data_url = '/home/work/user-job-dir/data/'
if not os.path.exists(args.data_url):
    os.mkdir(args.data_url)
try:
    mox.file.copy_parallel(obs_data_url, args.data_url)
    print("Successfully Download {} to {}".format(obs_data_url,
                                                   args.data_url))
except Exception as e:
    print('moxing download {} to {} failed: '.format(
        obs_data_url, args.data_url) + str(e))
(2)将模型文件从obs拷贝到推理镜像中（模型是单个文件，使用 mox.file.copy）
obs_ckpt_url = args.ckpt_url
args.ckpt_url = '/home/work/user-job-dir/checkpoint.ckpt'
try:
    mox.file.copy(obs_ckpt_url, args.ckpt_url)
    print("Successfully Download {} to {}".format(obs_ckpt_url,
                                                   args.ckpt_url))
except Exception as e:
    print('moxing download {} to {} failed: '.format(
        obs_ckpt_url, args.ckpt_url) + str(e))
(3)将输出的结果拷贝回obs
obs_result_url = args.result_url
args.result_url = '/home/work/user-job-dir/result/'
if not os.path.exists(args.result_url):
    os.mkdir(args.result_url)
# …推理完成、结果写入 args.result_url 之后，再执行下面的上传…
try:
    mox.file.copy_parallel(args.result_url, obs_result_url)
    print("Successfully Upload {} to {}".format(args.result_url, obs_result_url))
except Exception as e:
    print('moxing upload {} to {} failed: '.format(args.result_url, obs_result_url) + str(e))
详细代码可参考以下示例代码:
"""
import os
import argparse
import moxing as mox
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore import Tensor
import numpy as np
from glob import glob
from dataset import create_dataset
from config import mnist_cfg as cfg
from lenet import LeNet5
def mox_transfer(copy_fn, src, dst, action):
    """Copy ``src`` to ``dst`` with ``copy_fn`` and print the outcome.

    Transfers are best-effort: a failure is printed rather than raised, so
    the caller decides whether it can continue.

    Args:
        copy_fn: callable taking ``(src, dst)``. This script passes
            ``mox.file.copy_parallel`` (directories) or ``mox.file.copy``
            (single files).
        src: source path or OBS URL.
        dst: destination path or OBS URL.
        action: ``"download"`` or ``"upload"``; used only in the log message.

    Returns:
        True if ``copy_fn`` returned without raising, False otherwise.
    """
    try:
        copy_fn(src, dst)
    except Exception as e:  # best-effort: report and let the caller continue
        print('moxing {} {} to {} failed: '.format(action, src, dst) + str(e))
        return False
    print("Successfully {} {} to {}".format(action.capitalize(), src, dst))
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--data_url',
                        type=str,
                        default="./Data",
                        help='path where the dataset is saved')
    parser.add_argument('--ckpt_url',
                        help='model to save/load',
                        default='./ckpt_url')
    parser.add_argument('--result_url',
                        help='result folder to save/load',
                        default='./result')
    args = parser.parse_args()

    # Copy the dataset (a directory) from OBS into the inference container.
    # Directories are copied with mox.file.copy_parallel, single files with
    # mox.file.copy.
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/data/'
    os.makedirs(args.data_url, exist_ok=True)
    mox_transfer(mox.file.copy_parallel, obs_data_url, args.data_url, "download")

    # Copy the checkpoint (a single file) from OBS into the container.
    obs_ckpt_url = args.ckpt_url
    args.ckpt_url = '/home/work/user-job-dir/checkpoint.ckpt'
    mox_transfer(mox.file.copy, obs_ckpt_url, args.ckpt_url, "download")

    # Local output directory; its contents are uploaded to obs_result_url at the end.
    obs_result_url = args.result_url
    args.result_url = '/home/work/user-job-dir/result/'
    os.makedirs(args.result_url, exist_ok=True)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    network = LeNet5(cfg.num_classes)
    # Only model.predict() is used below, so no loss, optimizer or metrics are needed.
    model = Model(network)

    print("============== Starting Testing ==============")
    print("args.load_ckpt_url is {}".format(args.ckpt_url))
    param_dict = load_checkpoint(args.ckpt_url)
    load_param_into_net(network, param_dict)

    # batch_size=1, so each iterator step yields a single test image.
    ds_test = create_dataset(os.path.join(args.data_url, "test"), batch_size=1).create_dict_iterator()
    data = next(ds_test)
    # "image" holds the input batch, "label" its ground-truth class.
    images = data["image"]
    labels = data["label"].asnumpy()
    print('Tensor:', images)

    # Predict the class of the image: the index of the largest output per row.
    output = model.predict(images)
    predicted = np.argmax(output.asnumpy(), axis=1)
    print('predicted:', predicted)
    print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')

    # Append the predicted and actual class to result.txt in the output directory.
    file_path = os.path.join(args.result_url, 'result.txt')
    with open(file_path, 'a', encoding='utf-8') as result_file:
        result_file.write(" Predicted: {}, Actual: {} \n".format(predicted[0], labels[0]))

    ######################## Upload results to OBS (fixed boilerplate) ########################
    # Copy the inference results from the local environment back to OBS, where the
    # corresponding inference task on the platform offers them for download.
    mox_transfer(mox.file.copy_parallel, args.result_url, obs_result_url, "upload")
    ######################## Copy output model to OBS ########################