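"""k-nearest-neighbours classification on PySpark.

Distances from each query point to every labeled training vector are computed
as a distributed Spark job; the k nearest neighbours then vote on the label.
Written against the Spark 1.x-era API (SQLContext, pyspark.mllib.linalg).

Assumed layout (not stated in the source): ``train/`` and ``test/`` directories
of 32x32 text-digit files in the working directory, launched e.g. with
``spark-submit knn.py``.
"""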
from collections import Counter
from os import listdir

import numpy as np

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.linalg import Vectors


def computeDist(a, b):
    """Euclidean distance between two vectors (as numpy arrays)."""
    return np.sqrt(np.sum((a - b) ** 2))


class KNN:
    """k-nearest-neighbours classifier over a Spark DataFrame of labeled points."""

    def __init__(self, featuresCol="features", labelCol="label"):
        self.featuresCol, self.labelCol = featuresCol, labelCol

    def classify(self, inXs, dataSet, k=10):
        """Classify a single unlabeled point.

        :param inXs: feature vector of the point to be classified.
        :param dataSet: DataFrame of (features, label) rows of labeled points.
        :param k: number of nearest neighbours used to determine the label,
            defaulting to 10 when not set.
        """
        if len(inXs) != len(dataSet.first()[0].values):
            print("feature length of inXs does not match the dataset's")
            return None
        # Distributed step: distance from every labeled point to inXs,
        # kept as (label, distance) pairs.
        dis = dataSet.rdd.map(
            lambda row: (row[1], computeDist(row[0].toArray(), inXs.toArray())))
        # The k nearest neighbours, by ascending distance.
        orderedDis = dis.takeOrdered(k, key=lambda x: x[1])
        # Majority vote; the k pairs are already local to the driver, so a
        # plain Counter is cheaper than launching another Spark job over them.
        return Counter(label for label, _ in orderedDis).most_common(1)[0][0]


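# Input format (an assumption based on the parsing below, matching the classic
# text handwritten-digits dataset): each file holds one digit as a 32x32 grid
# of single-digit characters, and the filename carries the class label before
# the first underscore, e.g. 3_107.txt.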
def load_data(data_folder):
    """Read every digit file in data_folder into (features, label) pairs."""
    datas = []
    for filename in listdir(data_folder):
        # Flatten the 32x32 character grid into a 1024-dimensional vector.
        data_in_line = []
        with open('%s/%s' % (data_folder, filename)) as fr:
            for _ in range(32):
                line_str = fr.readline()
                for k in range(32):
                    data_in_line.append(int(line_str[k]))
        label = filename.split('.')[0].split("_")[0]
        datas.append((Vectors.dense(data_in_line), float(label)))
    return datas


if __name__ == "__main__":
|
|
sc = SparkContext(appName="KNN")
|
|
sqlContext = SQLContext(sc)
|
|
|
|
count,errorCount = 0,0
|
|
knn = KNN()
|
|
|
|
datasetDF = sqlContext.createDataFrame(load_data("train"),["features","label"]).cache()
|
|
testData = load_data("test")
|
|
|
|
for x in testData:
|
|
prediction = knn.classify(x[0],datasetDF,10)
|
|
print "%d-%d" %(x[1],prediction)
|
|
if prediction != x[1]:
|
|
errorCount += 1
|
|
count += 1
|
|
print "error rate is %f(%d/%d)" % (1.0 * errorCount / count,errorCount,count)
|
|
|
|
|