# advse/logisticReg_HWR.py
#
# 43 lines
# 1.4 KiB
# Python

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.linalg import Vectors
from os import listdir
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# Process-wide Spark entry points, created once at import time.
# load_data() builds its DataFrames through `sqlContext`, so both names must stay module-level.
sc = SparkContext(appName="PythonlogExample")
sqlContext = SQLContext(sparkContext=sc)
def _read_image(path, size):
    """Return the top-left ``size`` x ``size`` digits of the file at ``path``.

    The file is read as text, one image row per line. The first ``size``
    characters of each of the first ``size`` lines are converted with
    ``int``. The result is a flat list in row-major order.

    Raises:
        ValueError: if the file has fewer than ``size`` lines, a line shorter
            than ``size`` characters, or a character ``int`` cannot parse.
    """
    pixels = []
    # `with` guarantees the handle is closed; the original never closed it.
    with open(path) as fh:
        for row in range(size):
            line = fh.readline().rstrip("\r\n")
            if len(line) < size:
                raise ValueError(
                    "%s: row %d has %d characters, expected at least %d"
                    % (path, row, len(line), size))
            pixels.extend(int(ch) for ch in line[:size])
    return pixels


def load_data(data_folder, size=32):
    """Load a folder of text-encoded digit images into a Spark DataFrame.

    Each regular file in ``data_folder`` becomes one row. The features are the
    first ``size`` characters of the first ``size`` lines, parsed as ints and
    flattened row by row into a dense vector of length ``size * size``. The
    label is the file-name prefix before the first ``.`` and then before the
    first ``_``, converted to float (e.g. ``"1_23.txt"`` -> ``1.0``).

    Args:
        data_folder: Directory containing one image file per sample.
        size: Side length of the square image. Defaults to 32, the value the
            original hard-coded.

    Returns:
        A DataFrame with columns ``label`` and ``features``.

    Raises:
        ValueError: if the folder contains no regular files, an image file is
            malformed (see ``_read_image``), or a file-name prefix is not a number.
    """
    import os

    # pyspark.ml estimators such as LogisticRegression expect pyspark.ml.linalg
    # vectors on Spark >= 2.0 and reject mllib vectors at fit() time.
    # Spark 1.x has no pyspark.ml.linalg, so fall back to the module-level
    # (mllib) Vectors there.
    try:
        from pyspark.ml.linalg import Vectors as vectors
    except ImportError:
        vectors = Vectors

    rows = []
    # sorted() makes the row order independent of the filesystem's listing order.
    for filename in sorted(os.listdir(data_folder)):
        path = os.path.join(data_folder, filename)
        if not os.path.isfile(path):
            # Sub-directories and other non-regular entries are not samples.
            continue
        label = filename.split('.')[0].split("_")[0]
        rows.append((float(label), vectors.dense(_read_image(path, size))))

    if not rows:
        # createDataFrame cannot infer a schema from an empty list; fail clearly instead.
        raise ValueError("no image files found in %r" % (data_folder,))
    return sqlContext.createDataFrame(rows, ["label", "features"])
if __name__ == "__main__":
    # Train on the images in ./train, then evaluate on the images in ./test.
    train_df = load_data("train")
    lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
    lrModel = lr.fit(train_df)
    test_df = load_data("test")
    predictions = lrModel.transform(test_df)

    # Area under the precision-recall curve, computed from the raw scores.
    evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderPR")
    area_under_pr = evaluator.evaluate(predictions)

    # The test error is the fraction of misclassified rows.
    # `1 - areaUnderPR`, which the original printed, is not an error rate.
    total = predictions.count()
    wrong = predictions.filter(predictions.label != predictions.prediction).count()
    test_error = float(wrong) / total if total else float("nan")

    print("Area under PR = %g" % area_under_pr)
    print("Test Error = %g " % test_error)
    sc.stop()