剑痴乎

剑痴乎
代码为剑,如痴如醉
  1. 首页
  2. 编程之美
  3. 正文

KNN算法

2013年10月20日 1962点热度 0人点赞 0条评论

KNN算法即k-Nearest Neighbor algorithm,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一,可应用于基于智能终端传感器的活动识别。KNN算法思路如下:
1、假设一个样本空间里的样本可分成几个类型
2、给定一个未知类型的待分类样本
3、基本思路:
a人以类聚、物以群归
b待分类样本与哪一类的样本比较相近,就归属于哪一类
c具体的相近比较基于最近的K个样本
d归属于K个样本中的多数样本所属的类
维基百科上的KNN词条中有一个比较经典的图如右:
从右图中可以看到
-图中的有两个类型的样本数据 KNN算法图解
-一类是蓝色的正方形
-一类是红色的三角形。
-绿色的圆形是待分类的数据。
如果K=3,那么离绿色点最近的有2个红色三角形和1个蓝色的正方形,这3个点投票,于是绿色的这个待分类点属于红色的三角形。
如果K=5,那么离绿色点最近的有2个红色三角形和3个蓝色的正方形,这5个点投票,于是绿色的这个待分类点属于蓝色的正方形。
KNN算法实现:
Knn.h

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#pragma once
class Knn
{
private:
double** trainingDataset;
double* arithmeticMean;
double* standardDeviation;
int m, n;
void RescaleDistance(double* row);
void RescaleTrainingDataset();
void ComputeArithmeticMean();
void ComputeStandardDeviation();
double Distance(double* x, double* y);
public:
Knn(double** trainingDataset, int m, int n);
~Knn();
double Vote(double* test, int k);
};

Knn.cpp

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include "Knn.h"
#include <cmath>
#include <map>
using namespace std;
Knn::Knn(double** trainingDataset, int m, int n)
{
this->trainingDataset = trainingDataset;
this->m = m;
this->n = n;
ComputeArithmeticMean();
ComputeStandardDeviation();
RescaleTrainingDataset();
}
void Knn::ComputeArithmeticMean()
{
arithmeticMean = new double[n - 1];
double sum;
for(int i = 0; i < n - 1; i++)
{
  sum = 0;
  for(int j = 0; j < m; j++)
  {
   sum += trainingDataset[j][i];
  }
  arithmeticMean[i] = sum / n;
}
}
void Knn::ComputeStandardDeviation()
{
standardDeviation = new double[n - 1];
double sum, temp;
for(int i = 0; i < n - 1; i++)
{
  sum = 0;
  for(int j = 0; j < m; j++)
  {
   temp = trainingDataset[j][i] - arithmeticMean[i];
   sum += temp * temp;
  }
  standardDeviation[i] = sqrt(sum / n);
}
}
void Knn::RescaleDistance(double* row)
{
for(int i = 0; i < n - 1; i++)
{
  row[i] = (row[i] - arithmeticMean[i]) / standardDeviation[i];
}
}
void Knn::RescaleTrainingDataset()
{
for(int i = 0; i < m; i++)
{
  RescaleDistance(trainingDataset[i]);
}
}
Knn::~Knn()
{
delete[] arithmeticMean;
delete[] standardDeviation;
}
double Knn::Distance(double* x, double* y)
{
double sum = 0, temp;
for(int i = 0; i < n - 1; i++)
{
  temp = (x[i] - y[i]);
  sum += temp * temp;
}
return sqrt(sum);
}
double Knn::Vote(double* test, int k)
{
RescaleDistance(test);
double distance;
map<int, double>::iterator max;
map<int, double> mins;
for(int i = 0; i < m; i++)
{
  distance = Distance(test, trainingDataset[i]);
  if(mins.size() < k)
   mins.insert(map<int, double>::value_type(i, distance));
  else
  {
   max = mins.begin();
   for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
   {
    if(it->second > max->second)
     max = it;
   }
   if(distance < max->second)
   {
    mins.erase(max);
    mins.insert(map<int, double>::value_type(i, distance));
   }
  }
}
map<double, int> votes;
double temp;
for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
{
  temp = trainingDataset[it->first][n-1];
  map<double, int>::iterator voteIt = votes.find(temp);
  if(voteIt != votes.end())
   voteIt->second ++;
  else
   votes.insert(map<double, int>::value_type(temp, 1));
}
map<double, int>::iterator maxVote = votes.begin();
for(map<double, int>::iterator it = votes.begin(); it != votes.end(); it++)
{
  if(it->second > maxVote->second)
   maxVote = it;
}
test[n-1] = maxVote->first;
return maxVote->first;
}

main.cpp

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <iostream>
#include "Knn.h"
using namespace std;
int main(const int& argc, const char* argv[])
{
double** train = new double* [14];
for(int i = 0; i < 14; i ++)
  train[i] = new double[5];
double trainArray[14][5] =
{
  {0, 0, 0, 0, 0},
  {0, 0, 0, 1, 0},
  {1, 0, 0, 0, 1},
  {2, 1, 0, 0, 1},
  {2, 2, 1, 0, 1},
  {2, 2, 1, 1, 0},
  {1, 2, 1, 1, 1},
  {0, 1, 0, 0, 0},
  {0, 2, 1, 0, 1},
  {2, 1, 1, 0, 1},
  {0, 1, 1, 1, 1},
  {1, 1, 0, 1, 1},
  {1, 0, 1, 0, 1},
  {2, 1, 0, 1, 0}
};
for(int i = 0; i < 14; i ++)
  for(int j = 0; j < 5; j ++)
   train[i][j] = trainArray[i][j];
Knn knn(train, 14, 5);
double test[5] = {2, 2, 0, 1, 0};
cout<<knn.Vote(test, 3)<<endl;
for(int i = 0; i < 14; i ++)
  delete[] train[i];
delete[] train;
return 0;
}

本作品采用 知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议 进行许可
标签: 暂无
最后更新:2018年12月23日

Jeff

管理员——代码为剑,如痴如醉

打赏 点赞
< 上一篇
下一篇 >

文章评论

razz evil exclaim smile redface biggrin eek confused idea lol mad twisted rolleyes wink cool arrow neutral cry mrgreen drooling persevering
取消回复

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理。

版权声明

为支持原创,创作更好的文章,未经许可,禁止任何形式的转载与抄袭,如需转载请邮件私信!本人保留所有法定权利。违者必究!

最近评论
ztt 发布于 3 周前(04月05日) 你好,想看里面的视频和图片为什么没有显示呢?需要下flash吗还是什么。
huowa222 发布于 1 个月前(03月26日) 同问
邱国禄 发布于 2 个月前(02月17日) Receive Delta以0.25ms为单位,reference time以64ms为单位,kDe...
啊非 发布于 4 个月前(12月30日) 大神,请教一个问题: constexpr int kBaseScaleFactor = Tran...
啊非 发布于 4 个月前(12月30日) reference time:3字节,表示参考时间,以64ms为单位,但是 代码里面是 Trans...
相关文章
  • Google ProtoBuf协议介绍
  • Intel Media SDK 内存优化(转)
  • 网络字节转换到本地字节的函数模板
  • 解决Ubuntu下vlc无法播放文件
  • MFC WebBrowser控件如何实现滚动条滑动

COPYRIGHT © 2024 jianchihu.net. ALL RIGHTS RESERVED.

Theme Kratos Made By Seaton Jiang