Merge pull request #222 from xuehuachunsheng/dev

修复rf-python3.6.py中文件打开没有关闭的问题
This commit is contained in:
片刻 2018-05-22 12:21:03 +08:00 committed by GitHub
commit 11f84f0a10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 13 deletions

View File

@ -19,7 +19,7 @@ import os.path
import time
# 数据路径
data_dir = '/media/wsw/B634091A3408DF6D/data/kaggle/datasets/getting-started/digit-recognizer/'
data_dir = '/Users/wuyanxue/Documents/GitHub/datasets/getting-started/digit-recognizer/'
# 加载数据
def opencsv():
@ -57,8 +57,8 @@ def dRPCA(data, COMPONENT_NUM=100):
def trainModel(X_train, y_train):
print('Train RF...')
clf = RandomForestClassifier(
n_estimators=140,
max_depth=20,
n_estimators=10,
max_depth=10,
min_samples_split=2,
min_samples_leaf=1,
random_state=34)
@ -99,16 +99,13 @@ def getModel(filename):
# 结果输出保存
def saveResult(result, csvName):
i = 0
fw = open(csvName, 'w')
with open(os.path.join(data_dir, 'output/sample_submission.csv')
) as pred_file:
n = len(result)
print('the size of test set is {}'.format(n))
with open(os.path.join(data_dir, 'output/Result_sklearn_RF.csv'), 'w') as fw:
fw.write('{},{}\n'.format('ImageId', 'Label'))
for line in pred_file.readlines()[1:]:
splits = line.strip().split(',')
fw.write('{},{}\n'.format(splits[0], result[i]))
i += 1
fw.close()
print('Result saved successfully...')
for i in range(1, n + 1):
fw.write('{},{}\n'.format(i, result[i - 1]))
print('Result saved successfully... and the path = {}'.format(csvName))
def trainRF():
@ -151,7 +148,7 @@ def preRF():
result = clf.predict(pcaPreData)
# 结果的输出
saveResult(result,os.path.join(data_dir, 'output/Result_sklearn_rf.csv'))
saveResult(result, os.path.join(data_dir, 'output/Result_sklearn_rf.csv'))
print("finish!")
stopTime = time.time()
print('PreModel load time used:%f s' % (stopTime - startTime))