JCHub

The KNN Algorithm

October 20, 2013

KNN, the k-Nearest Neighbor algorithm, is a theoretically mature method and one of the simplest machine learning algorithms; it can be applied, for example, to activity recognition based on the sensors of smart devices. The idea behind KNN is as follows:
1. Assume the samples in a sample space can be divided into several classes.
2. We are given an unlabeled sample to classify.
3. Basic approach:
  a. Birds of a feather flock together.
  b. The sample is assigned to whichever class of samples it most resembles.
  c. "Resemblance" is judged against the K nearest samples.
  d. The sample is assigned to the class that the majority of those K samples belong to.
The Wikipedia entry on KNN contains a classic illustration ("KNN algorithm illustration"). In that figure:
- there are sample data of two classes:
  - one class is blue squares,
  - the other is red triangles;
- the green circle is the sample to classify.
If K = 3, the nearest neighbors of the green point are 2 red triangles and 1 blue square; these 3 points vote, so the green point is classified as a red triangle.
If K = 5, the nearest neighbors of the green point are 2 red triangles and 3 blue squares; these 5 points vote, so the green point is classified as a blue square.
KNN implementation:
Knn.h

#pragma once

// k-nearest-neighbor classifier. The training set is an m x n matrix:
// each of the m rows holds n-1 feature values followed by a class label.
class Knn
{
private:
    double** trainingDataset;
    double* arithmeticMean;    // per-feature mean, for z-score rescaling
    double* standardDeviation; // per-feature standard deviation
    int m, n;                  // m samples, n columns per sample
    void RescaleDistance(double* row);
    void RescaleTrainingDataset();
    void ComputeArithmeticMean();
    void ComputeStandardDeviation();
    double Distance(double* x, double* y);
public:
    Knn(double** trainingDataset, int m, int n);
    ~Knn();
    double Vote(double* test, int k);
};

Knn.cpp

#include "Knn.h"
#include <cmath>
#include <map>

using namespace std;

Knn::Knn(double** trainingDataset, int m, int n)
{
    this->trainingDataset = trainingDataset;
    this->m = m;
    this->n = n;
    ComputeArithmeticMean();
    ComputeStandardDeviation();
    RescaleTrainingDataset();
}

// Mean of each of the n-1 feature columns, averaged over the m samples.
void Knn::ComputeArithmeticMean()
{
    arithmeticMean = new double[n - 1];
    double sum;
    for(int i = 0; i < n - 1; i++)
    {
        sum = 0;
        for(int j = 0; j < m; j++)
        {
            sum += trainingDataset[j][i];
        }
        arithmeticMean[i] = sum / m;  // average over the m samples, not n
    }
}

// Population standard deviation of each feature column.
void Knn::ComputeStandardDeviation()
{
    standardDeviation = new double[n - 1];
    double sum, temp;
    for(int i = 0; i < n - 1; i++)
    {
        sum = 0;
        for(int j = 0; j < m; j++)
        {
            temp = trainingDataset[j][i] - arithmeticMean[i];
            sum += temp * temp;
        }
        standardDeviation[i] = sqrt(sum / m);  // divide by m, not n
    }
}

// Z-score normalization: (x - mean) / stddev, so features on different
// scales contribute comparably to the distance.
void Knn::RescaleDistance(double* row)
{
    for(int i = 0; i < n - 1; i++)
    {
        row[i] = (row[i] - arithmeticMean[i]) / standardDeviation[i];
    }
}

void Knn::RescaleTrainingDataset()
{
    for(int i = 0; i < m; i++)
    {
        RescaleDistance(trainingDataset[i]);
    }
}

Knn::~Knn()
{
    delete[] arithmeticMean;
    delete[] standardDeviation;
}

// Euclidean distance over the n-1 feature columns.
double Knn::Distance(double* x, double* y)
{
    double sum = 0, temp;
    for(int i = 0; i < n - 1; i++)
    {
        temp = x[i] - y[i];
        sum += temp * temp;
    }
    return sqrt(sum);
}

// Classify `test` by majority vote among its k nearest training samples.
// Writes the winning label into test[n-1] and returns it.
double Knn::Vote(double* test, int k)
{
    RescaleDistance(test);
    double distance;
    // mins maps training-row index -> distance for the k closest rows so far.
    map<int, double> mins;
    map<int, double>::iterator max;
    for(int i = 0; i < m; i++)
    {
        distance = Distance(test, trainingDataset[i]);
        if((int)mins.size() < k)
            mins.insert(map<int, double>::value_type(i, distance));
        else
        {
            // Find the farthest of the current k candidates...
            max = mins.begin();
            for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
            {
                if(it->second > max->second)
                    max = it;
            }
            // ...and replace it if this sample is closer.
            if(distance < max->second)
            {
                mins.erase(max);
                mins.insert(map<int, double>::value_type(i, distance));
            }
        }
    }
    // Tally one vote per neighbor, keyed by its class label.
    map<double, int> votes;
    double label;
    for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
    {
        label = trainingDataset[it->first][n - 1];
        map<double, int>::iterator voteIt = votes.find(label);
        if(voteIt != votes.end())
            voteIt->second++;
        else
            votes.insert(map<double, int>::value_type(label, 1));
    }
    map<double, int>::iterator maxVote = votes.begin();
    for(map<double, int>::iterator it = votes.begin(); it != votes.end(); it++)
    {
        if(it->second > maxVote->second)
            maxVote = it;
    }
    test[n - 1] = maxVote->first;
    return maxVote->first;
}

main.cpp

#include <iostream>
#include "Knn.h"

using namespace std;

int main()
{
    // 14 training samples, each with 4 features plus a class label.
    const int m = 14, n = 5;
    double** train = new double*[m];
    for(int i = 0; i < m; i++)
        train[i] = new double[n];
    double trainArray[14][5] =
    {
        {0, 0, 0, 0, 0},
        {0, 0, 0, 1, 0},
        {1, 0, 0, 0, 1},
        {2, 1, 0, 0, 1},
        {2, 2, 1, 0, 1},
        {2, 2, 1, 1, 0},
        {1, 2, 1, 1, 1},
        {0, 1, 0, 0, 0},
        {0, 2, 1, 0, 1},
        {2, 1, 1, 0, 1},
        {0, 1, 1, 1, 1},
        {1, 1, 0, 1, 1},
        {1, 0, 1, 0, 1},
        {2, 1, 0, 1, 0}
    };
    for(int i = 0; i < m; i++)
        for(int j = 0; j < n; j++)
            train[i][j] = trainArray[i][j];
    Knn knn(train, m, n);
    double test[5] = {2, 2, 0, 1, 0};  // last element is overwritten by Vote
    cout << knn.Vote(test, 3) << endl;
    for(int i = 0; i < m; i++)
        delete[] train[i];
    delete[] train;
    return 0;
}

This article is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License.
Last updated: December 23, 2018

Author: Jeff (administrator). Code as My Sword, Lost in Obsession.


COPYRIGHT © 2026 jianchihu.net. ALL RIGHTS RESERVED.
