Modify the data path in svm-python3.6.py and add a data visualization module

This commit is contained in:
xuehuachunsheng 2018-05-17 21:20:17 +08:00
parent 2cd68ce77b
commit 0b79de545d
1 changed files with 27 additions and 8 deletions

View File

@ -14,11 +14,13 @@ import csv
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
# Data path
data_dir = '/Users/wuyanxue/Documents/GitHub/datasets/getting-started/digit-recognizer/'
# Load the data
@ -59,6 +61,7 @@ def dRCsv(x_train, x_test, preData, COMPONENT_NUM):
return pcaTrainData, pcaTestData, pcaPreData
# Train the model
def trainModel(trainData, trainLabel):
print('Train SVM...')
@ -181,7 +184,6 @@ def getModel(filename):
fr = open(filename, 'rb')
return pickle.load(fr)
def trainDRSVM():
startTime = time.time()
@ -190,8 +192,8 @@ def trainDRSVM():
# 模型训练 (数据预处理-降维)
optimalSVMClf, pcaPreData = getOptimalAccuracy(trainData, trainLabel, preData)
storeModel(optimalSVMClf, os.path.join(data_dir, '/ouput/Result_sklearn_SVM.model'))
storeModel(pcaPreData, os.path.join(data_dir, '/ouput/Result_sklearn_SVM.pcaPreData'))
storeModel(optimalSVMClf, os.path.join(data_dir, 'output/Result_sklearn_SVM.model'))
storeModel(pcaPreData, os.path.join(data_dir, 'output/Result_sklearn_SVM.pcaPreData'))
print("finish!")
stopTime = time.time()
@ -201,8 +203,8 @@ def trainDRSVM():
def preDRSVM():
startTime = time.time()
# 加载模型和数据
optimalSVMClf = getModel(os.path.join(data_dir, '/ouput/Result_sklearn_SVM.model'))
pcaPreData = getModel(os.path.join(data_dir, '/ouput/Result_sklearn_SVM.pcaPreData'))
optimalSVMClf = getModel(os.path.join(data_dir, 'output/Result_sklearn_SVM.model'))
pcaPreData = getModel(os.path.join(data_dir, 'output/Result_sklearn_SVM.pcaPreData'))
# 结果预测
testLabel = optimalSVMClf.predict(pcaPreData)
@ -213,13 +215,30 @@ def preDRSVM():
stopTime = time.time()
print('PreModel load time used:%f s' % (stopTime - startTime))
# Data visualization
def dataVisulization(data, labels):
    """Scatter-plot ``data`` in 2-D after a PCA projection, one colour per class.

    The samples are reduced to two whitened principal components and each
    distinct value found in ``labels`` is drawn as its own scatter series,
    identified in the legend. The figure is displayed with ``plt.show()``.

    Args:
        data: Samples to project; passed straight to ``PCA``
            (assumed to be a 2-D array-like of shape (n_samples, n_features)
            -- TODO confirm against the caller).
        labels: One class label per sample, in the same order as ``data``.
            Any array-like is accepted (list, NumPy array, pandas Series).

    Returns:
        None.
    """
    # Comparing a plain Python list with ``==`` yields a single bool rather
    # than an element-wise mask, which would silently select no points.
    # Normalising to a flat NumPy array makes the mask below reliable.
    labels = np.asarray(labels).ravel()

    # Reduce to 2 dimensions with PCA.
    pca = PCA(n_components=2, whiten=True)
    pcaData = pca.fit_transform(data)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    # np.unique returns the classes sorted, so the colour assigned to each
    # class is the same on every run.
    for cClass in np.unique(labels):
        mask = labels == cClass
        ax.scatter(pcaData[mask, 0], pcaData[mask, 1], label=str(cClass))
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_title('MNIST visualization')
    ax.legend()
    plt.show()
# Script entry point: load the CSV data once, then run the selected steps.
if __name__ == '__main__':
trainData, trainLabel, preData = opencsv()
dataVisulization(trainData, trainLabel)
# Train and save the model
trainDRSVM()
#trainDRSVM()
# Analyse the data
analyse_data(trainData)
#analyse_data(trainData)
# Load the prediction dataset
preDRSVM()
#preDRSVM()