优化了cnn路径

This commit is contained in:
wang-sw 2018-05-29 11:14:24 +08:00
parent 93238d0fbf
commit 13a6156ead
2 changed files with 14 additions and 9 deletions

View File

@ -17,13 +17,16 @@ from keras.models import Sequential
from keras.optimizers import RMSprop
from keras.preprocessing.image import ImageDataGenerator
from keras.utils.np_utils import to_categorical # convert to one-hot-encoding
import os
np.random.seed(2)
# 数据路径
data_dir = '/media/wsw/B634091A3408DF6D/data/kaggle/datasets/getting-started/digit-recognizer/'
# Load the data
train = pd.read_csv(
r'datasets/getting-started/digit-recognizer/input/train.csv')
test = pd.read_csv(r'datasets/getting-started/digit-recognizer/input/test.csv')
train = pd.read_csv(os.path.join(data_dir, 'input/train.csv'))
test = pd.read_csv(os.path.join(data_dir, 'input/test.csv'))
X_train = train.values[:, 1:]
Y_train = train.values[:, 0]
@ -129,6 +132,5 @@ submission = pd.concat(
[pd.Series(
range(1, 28001), name="ImageId"), results], axis=1)
submission.to_csv(
"datasets/getting-started/digit-recognizer/ouput/Result_keras_CNN.csv",
index=False)
submission.to_csv(os.path.join(data_dir, "output/Result_keras_CNN.csv",index=False))
print('finished')

View File

@ -17,14 +17,17 @@ import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import os.path
# 数据路径
data_dir = '/media/wsw/B634091A3408DF6D/data/kaggle/datasets/getting-started/digit-recognizer/'
class CustomedDataSet(Dataset):
def __init__(self, train=True):
self.train = train
if self.train:
trainX = pd.read_csv(
'/opt/data/kaggle/getting-started/digit-recognizer/input/train.csv'
os.path.join(data_dir, 'input/train.csv')
# names=["ImageId", "Label"]
)
trainY = trainX.label.as_matrix().tolist()
@ -34,7 +37,7 @@ class CustomedDataSet(Dataset):
self.labellist = trainY
else:
testX = pd.read_csv(
'/opt/data/kaggle/getting-started/digit-recognizer/input/test.csv'
os.path.join(data_dir, 'input/test.csv')
)
self.testID = testX.index
testX = testX.as_matrix().reshape(testX.shape[0], 1, 28, 28)
@ -178,6 +181,6 @@ submission_df = pd.DataFrame(
'Label': testLabel})
# print(submission_df.head(10))
submission_df.to_csv(
'/opt/data/kaggle/getting-started/digit-recognizer/output/Result_pytorch_CNN.csv',
os.path.join(data_dir, 'output/Result_pytorch_CNN.csv'),
columns=["ImageId", "Label"],
index=False)