# advse/KNN.py

import numpy as np
from operator import add
from os import listdir

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.linalg import Vectors


def computeDist(a, b):
    """Euclidean distance between two 1-D numpy arrays."""
    return np.sqrt(np.sum((a - b) ** 2))
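
# e.g. computeDist(np.array([0.0, 0.0]), np.array([3.0, 4.0])) -> 5.0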


class KNN:
    """
    KNN classifier: predicts a label by majority vote among the k nearest
    labeled points.
    """
    def __init__(self, featuresCol="features", labelCol="label"):
        self.featuresCol, self.labelCol = featuresCol, labelCol

    def classify(self, inXs, dataSet, k=10):
        """
        Classify the unlabeled point inXs.

        :param inXs: feature vector of the point to be classified.
        :param dataSet: DataFrame of already labeled (features, label) rows.
        :param k: number of nearest neighbours used to decide the label
                  (defaults to 10).
        """
        if len(inXs) != len(dataSet.first()[0].values):
            print("feature length of inXs does not match the dataset's")
            return None
        # (label, distance) pair for every labeled point
        dis = dataSet.rdd.map(
            lambda row: (row[1], computeDist(row[0].toArray(), inXs.toArray())))
        # the k nearest neighbours, closest first
        orderedDis = dis.takeOrdered(k, key=lambda x: x[1])
        # majority vote over the k labels; uses the module-level SparkContext
        # `sc` created in __main__
        groupLabel = (sc.parallelize(orderedDis)
                      .map(lambda row: (row[0], 1))
                      .reduceByKey(add)
                      .takeOrdered(1, key=lambda row: -row[1])[0][0])
        return groupLabel
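
# Usage sketch (mirrors the __main__ block below; `trainDF` and the all-zero
# 1024-dimensional vector are illustrative placeholders, not names from this
# file):
#
#   knn = KNN()
#   predicted = knn.classify(Vectors.dense([0] * 1024), trainDF, k=10)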


def load_data(data_folder):
    """Read every file in data_folder into a (DenseVector, label) pair."""
    datas = list()
    for filename in listdir(data_folder):
        data_in_line = list()
        with open('%s/%s' % (data_folder, filename)) as fr:
            # every file holds a 32x32 grid of single-digit characters
            for j in range(32):
                line_str = fr.readline()
                for k in range(32):
                    data_in_line.append(int(line_str[k]))
        # the class label is the part of the file name before the first "_"
        label = filename.split('.')[0].split("_")[0]
        datas.append((Vectors.dense(data_in_line), float(label)))
    return datas
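
# Data layout assumed by load_data (inferred from the parsing above): each file
# in the folder encodes one 32x32 digit image as 32 lines of 32 characters,
# flattened into a 1024-dimensional DenseVector, and the numeric class label is
# taken from the file name prefix before the first "_".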


if __name__ == "__main__":
    sc = SparkContext(appName="KNN")
    sqlContext = SQLContext(sc)
    count, errorCount = 0, 0
    knn = KNN()
    # cached training set: one row per labeled digit image
    datasetDF = sqlContext.createDataFrame(load_data("train"),
                                           ["features", "label"]).cache()
    testData = load_data("test")
    for x in testData:
        prediction = knn.classify(x[0], datasetDF, 10)
        print("%d-%d" % (x[1], prediction))
        if prediction != x[1]:
            errorCount += 1
        count += 1
    print("error rate is %f(%d/%d)" % (1.0 * errorCount / count, errorCount, count))