knn.h
1 //
2 // Copyright (c) 2003-2011, MIST Project, Nagoya University
3 // All rights reserved.
4 //
5 // Redistribution and use in source and binary forms, with or without modification,
6 // are permitted provided that the following conditions are met:
7 //
8 // 1. Redistributions of source code must retain the above copyright notice,
9 // this list of conditions and the following disclaimer.
10 //
11 // 2. Redistributions in binary form must reproduce the above copyright notice,
12 // this list of conditions and the following disclaimer in the documentation
13 // and/or other materials provided with the distribution.
14 //
15 // 3. Neither the name of the Nagoya University nor the names of its contributors
16 // may be used to endorse or promote products derived from this software
17 // without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
20 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
21 // FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
22 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
25 // IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
26 // THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 
28 #ifndef __INCLUDE_MIST_KNN__
29 #define __INCLUDE_MIST_KNN__
30 
31 #ifndef __INCLUDE_MIST_H__
32 #include "../mist.h"
33 #endif
34 
35 #ifndef __INCLUDE_MIST_MATRIX_H__
36 #include "../matrix.h"
37 #endif
38 
39 #include <map>
40 #include <algorithm>
41 
43 
44 namespace knn
45 {
47  class classifier
48  {
49  private:
50  struct distance_pair
51  {
52  int index;
53  double distance;
54 
55  bool operator<( const distance_pair &a ) const
56  {
57  return distance < a.distance;
58  }
59  };
60 
61  protected:
62  matrix< double > in_;
63  array1< int > response_;
64  int category_count_;
65 
66  public:
67  classifier( ) : category_count_( 0 )
68  {
69  }
70 
75  bool train( const matrix< double > &in, const array1< int > &response )
76  {
77  if( in.cols( ) != response.size( ) )
78  {
79  return false;
80  }
81 
82  in_ = in;
83  response_ = response;
84 
85  // calculate # of category
86  std::map< int, bool > rmap;
87  category_count_ = 0;
88  for( size_t i = 0 ; i < response.size() ; ++i )
89  {
90  if( rmap.find( response( i ) ) == rmap.end() )
91  {
92  rmap[ response( i ) ] = true;
93  ++category_count_;
94  }
95  }
96 
97  return true;
98  }
99 
105  int predict( const matrix< double > &in, int k = 1 )
106  {
107  array< distance_pair > mes( in_.cols() );
108  array< int > vote( category_count_ );
109 
110  // calculate distance from training dataset
111  for( size_t i = 0 ; i < in_.cols() ; ++i )
112  {
113  double dist = 0;
114  for( size_t j = 0 ; j < in_.rows() ; ++j )
115  {
116  dist += pow( in_( j, i ) - in( j, 0 ), 2.0 );
117  }
118  mes( i ).index = i;
119  mes( i ).distance = dist;
120  }
121 
122  // sort by ascending order
123  std::sort( mes.begin(), mes.end() );
124 
125  for( int i = 0 ; i < k ; ++i )
126  {
127  ++vote( response_( mes( i ).index ) );
128  }
129 
130  // predict category
131  int max_count = 0;
132  int max_idx = 0;
133  for( int i = 0 ; i < category_count_ ; ++i )
134  {
135  if( vote( i ) > max_count )
136  {
137  max_count = vote( i );
138  max_idx = i;
139  }
140  }
141 
142  return max_idx;
143  }
144  };
145 }
146 
147 _MIST_END
148 
149 #endif
150 

Generated on Wed Nov 12 2014 19:44:17 for MIST by doxygen 1.8.1.2