优化了cnn路径
This commit is contained in:
parent
93238d0fbf
commit
13a6156ead
|
@ -17,13 +17,16 @@ from keras.models import Sequential
|
|||
from keras.optimizers import RMSprop
|
||||
from keras.preprocessing.image import ImageDataGenerator
|
||||
from keras.utils.np_utils import to_categorical # convert to one-hot-encoding
|
||||
import os
|
||||
|
||||
np.random.seed(2)
|
||||
|
||||
# 数据路径
|
||||
data_dir = '/media/wsw/B634091A3408DF6D/data/kaggle/datasets/getting-started/digit-recognizer/'
|
||||
|
||||
# Load the data
|
||||
train = pd.read_csv(
|
||||
r'datasets/getting-started/digit-recognizer/input/train.csv')
|
||||
test = pd.read_csv(r'datasets/getting-started/digit-recognizer/input/test.csv')
|
||||
train = pd.read_csv(os.path.join(data_dir, 'input/train.csv'))
|
||||
test = pd.read_csv(os.path.join(data_dir, 'input/test.csv'))
|
||||
|
||||
X_train = train.values[:, 1:]
|
||||
Y_train = train.values[:, 0]
|
||||
|
@ -129,6 +132,5 @@ submission = pd.concat(
|
|||
[pd.Series(
|
||||
range(1, 28001), name="ImageId"), results], axis=1)
|
||||
|
||||
submission.to_csv(
|
||||
"datasets/getting-started/digit-recognizer/ouput/Result_keras_CNN.csv",
|
||||
index=False)
|
||||
submission.to_csv(os.path.join(data_dir, "output/Result_keras_CNN.csv",index=False))
|
||||
print('finished')
|
||||
|
|
|
@ -17,14 +17,17 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import os.path
|
||||
|
||||
# 数据路径
|
||||
data_dir = '/media/wsw/B634091A3408DF6D/data/kaggle/datasets/getting-started/digit-recognizer/'
|
||||
|
||||
class CustomedDataSet(Dataset):
|
||||
def __init__(self, train=True):
|
||||
self.train = train
|
||||
if self.train:
|
||||
trainX = pd.read_csv(
|
||||
'/opt/data/kaggle/getting-started/digit-recognizer/input/train.csv'
|
||||
os.path.join(data_dir, 'input/train.csv')
|
||||
# names=["ImageId", "Label"]
|
||||
)
|
||||
trainY = trainX.label.as_matrix().tolist()
|
||||
|
@ -34,7 +37,7 @@ class CustomedDataSet(Dataset):
|
|||
self.labellist = trainY
|
||||
else:
|
||||
testX = pd.read_csv(
|
||||
'/opt/data/kaggle/getting-started/digit-recognizer/input/test.csv'
|
||||
os.path.join(data_dir, 'input/test.csv')
|
||||
)
|
||||
self.testID = testX.index
|
||||
testX = testX.as_matrix().reshape(testX.shape[0], 1, 28, 28)
|
||||
|
@ -178,6 +181,6 @@ submission_df = pd.DataFrame(
|
|||
'Label': testLabel})
|
||||
# print(submission_df.head(10))
|
||||
submission_df.to_csv(
|
||||
'/opt/data/kaggle/getting-started/digit-recognizer/output/Result_pytorch_CNN.csv',
|
||||
os.path.join(data_dir, 'output/Result_pytorch_CNN.csv'),
|
||||
columns=["ImageId", "Label"],
|
||||
index=False)
|
||||
|
|
Loading…
Reference in New Issue