Modified the data path of svm-python3.6.py and add the data visualization module
This commit is contained in:
parent
2cd68ce77b
commit
0b79de545d
|
@ -14,11 +14,13 @@ import csv
|
|||
import time
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.svm import SVC
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
# 数据路径
|
||||
data_dir = '/Users/wuyanxue/Documents/GitHub/datasets/getting-started/digit-recognizer/'
|
||||
|
||||
# 加载数据
|
||||
|
@ -59,6 +61,7 @@ def dRCsv(x_train, x_test, preData, COMPONENT_NUM):
|
|||
return pcaTrainData, pcaTestData, pcaPreData
|
||||
|
||||
|
||||
|
||||
# 训练模型
|
||||
def trainModel(trainData, trainLabel):
|
||||
print('Train SVM...')
|
||||
|
@ -181,7 +184,6 @@ def getModel(filename):
|
|||
fr = open(filename, 'rb')
|
||||
return pickle.load(fr)
|
||||
|
||||
|
||||
def trainDRSVM():
|
||||
startTime = time.time()
|
||||
|
||||
|
@ -190,8 +192,8 @@ def trainDRSVM():
|
|||
# 模型训练 (数据预处理-降维)
|
||||
optimalSVMClf, pcaPreData = getOptimalAccuracy(trainData, trainLabel, preData)
|
||||
|
||||
storeModel(optimalSVMClf, os.path.join(data_dir, '/ouput/Result_sklearn_SVM.model'))
|
||||
storeModel(pcaPreData, os.path.join(data_dir, '/ouput/Result_sklearn_SVM.pcaPreData'))
|
||||
storeModel(optimalSVMClf, os.path.join(data_dir, 'output/Result_sklearn_SVM.model'))
|
||||
storeModel(pcaPreData, os.path.join(data_dir, 'output/Result_sklearn_SVM.pcaPreData'))
|
||||
|
||||
print("finish!")
|
||||
stopTime = time.time()
|
||||
|
@ -201,8 +203,8 @@ def trainDRSVM():
|
|||
def preDRSVM():
|
||||
startTime = time.time()
|
||||
# 加载模型和数据
|
||||
optimalSVMClf = getModel(os.path.join(data_dir, '/ouput/Result_sklearn_SVM.model'))
|
||||
pcaPreData = getModel(os.path.join(data_dir, '/ouput/Result_sklearn_SVM.pcaPreData'))
|
||||
optimalSVMClf = getModel(os.path.join(data_dir, 'output/Result_sklearn_SVM.model'))
|
||||
pcaPreData = getModel(os.path.join(data_dir, 'output/Result_sklearn_SVM.pcaPreData'))
|
||||
|
||||
# 结果预测
|
||||
testLabel = optimalSVMClf.predict(pcaPreData)
|
||||
|
@ -213,13 +215,30 @@ def preDRSVM():
|
|||
stopTime = time.time()
|
||||
print('PreModel load time used:%f s' % (stopTime - startTime))
|
||||
|
||||
# 数据可视化
|
||||
def dataVisulization(data, labels):
|
||||
pca = PCA(n_components=2, whiten=True) # 使用PCA方法降到2维
|
||||
pca.fit(data)
|
||||
pcaData = pca.transform(data)
|
||||
uniqueClasses = set(labels)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(1, 1, 1)
|
||||
for cClass in uniqueClasses:
|
||||
plt.scatter(pcaData[labels==cClass, 0], pcaData[labels==cClass, 1])
|
||||
plt.xlabel('$x_1$')
|
||||
plt.ylabel('$x_2$')
|
||||
plt.title('MNIST visualization')
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
trainData, trainLabel, preData = opencsv()
|
||||
dataVisulization(trainData, trainLabel)
|
||||
|
||||
|
||||
# 训练并保存模型
|
||||
trainDRSVM()
|
||||
#trainDRSVM()
|
||||
|
||||
# 分析数据
|
||||
analyse_data(trainData)
|
||||
#analyse_data(trainData)
|
||||
# 加载预测数据集
|
||||
preDRSVM()
|
||||
#preDRSVM()
|
||||
|
|
Loading…
Reference in New Issue