kmeans.h
1 //
2 // Copyright (c) 2003-2011, MIST Project, Nagoya University
3 // All rights reserved.
4 //
5 // Redistribution and use in source and binary forms, with or without modification,
6 // are permitted provided that the following conditions are met:
7 //
8 // 1. Redistributions of source code must retain the above copyright notice,
9 // this list of conditions and the following disclaimer.
10 //
11 // 2. Redistributions in binary form must reproduce the above copyright notice,
12 // this list of conditions and the following disclaimer in the documentation
13 // and/or other materials provided with the distribution.
14 //
15 // 3. Neither the name of the Nagoya University nor the names of its contributors
16 // may be used to endorse or promote products derived from this software
17 // without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
20 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
21 // FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
22 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
25 // IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
26 // THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 
28 #ifndef __INCLUDE_MIST_KMEANS__
29 #define __INCLUDE_MIST_KMEANS__
30 
31 #ifndef __INCLUDE_MIST_H__
32 #include "../mist.h"
33 #endif
34 
35 #ifndef __INCLUDE_MIST_RANDOM__
36 #include "../random.h"
37 #endif
38 
39 #ifndef __INCLUDE_MIST_CRITERIA__
40 #include "criteria.h"
41 #endif
42 
43 #include <cstdio>
44 
46 
47 namespace kmeans
48 {
49  namespace detail
50  {
51  double nearestCenter( const matrix< double > &in, int col, const matrix< double > &center, int n, int &idx )
52  {
53  double minv = 1e12;
54 
55  for( int i = 0 ; i < n ; ++i )
56  {
57  double dist = 0.0;
58  for( int j = 0 ; j < static_cast< int >( in.rows() ) ; ++j )
59  {
60  dist += pow( in( j, col ) - center( j, i ), 2.0 );
61  }
62  if( dist < minv )
63  {
64  minv = dist;
65  idx = i;
66  }
67  }
68 
69  return minv;
70  }
71  }
72 
80  inline void clustering(
81  const matrix< double > &in,
82  int &k,
83  array1< int > &response,
84  matrix< double > &center,
85  criteria crt = criteria( criteria::iteration, 0.0, 20 ) )
86  {
87  response.resize( in.cols() );
88  center.resize( in.rows(), k );
89 
90  // calculate value range(min-max)
91  matrix< double > minv( in.rows(), 1 );
92  matrix< double > maxv( in.rows(), 1 );
93  for( size_t i = 0 ; i < in.rows() ; ++i )
94  {
95  minv( i, 0 ) = 1e12;
96  maxv( i, 0 ) = -1e12;
97  for( size_t j = 0 ; j < in.cols() ; ++j )
98  {
99  minv( i, 0 ) = min( minv( i, 0 ), in( i, j ) );
100  maxv( i, 0 ) = max( maxv( i, 0 ), in( i, j ) );
101  }
102  }
103  /*
104  // initialize centroid[0-1]
105  uniform::random rnd;
106  for( int i = 0 ; i < k ; ++i )
107  {
108  for( size_t j = 0 ; j < in.rows() ; ++j )
109  {
110  center( j, i ) = minv( j, 0 ) + rnd.real3() * ( maxv( j, 0 ) - minv( j, 0 ) );
111  }
112  }
113  */
114 
115  // initialize centroid using kmeans++
116  // kmeans++: David Arthur, Sergei Vassilvitskii, "k-means++: The Advantages of Careful Seeding," Proc. SODA, 2007
117  uniform::random rnd;
118  for( int i = 0 ; i < k ; ++i )
119  {
120  if( i == 0 )
121  {
122  for( int j = 0 ; j < static_cast< int >( in.rows() ) ; ++j )
123  {
124  center( j, i ) = minv( j, 0 ) + rnd.real3() * ( maxv( j, 0 ) - minv( j, 0 ) );
125  }
126  }
127  else
128  {
129  /*
130  double total = 0.0;
131  for( int l = 0 ; l < in.cols() ; ++l )
132  {
133  total += detail::nearestCenter( in, l, center, i );
134  }
135  */
136  double max_prob = 0.0;
137  int max_idx = 0;
138  for( int l = 0 ; l < static_cast< int >( in.cols() ) ; ++l )
139  {
140  int idx;
141  double prob = detail::nearestCenter( in, l, center, i, idx );// / total;
142  if( prob > max_prob )
143  {
144  max_prob = prob;
145  max_idx = l;
146  }
147  }
148 
149  for( int j = 0 ; j < static_cast< int >( in.rows() ) ; ++j )
150  {
151  center( j, i ) = in( j, max_idx );
152  }
153  }
154  }
155 
156  array1< bool > is_set( k );
157  for( int m = 0 ; ( crt.type & criteria::iteration ) ? ( m < crt.max_itr ) : true ; ++m )
158  {
159  bool is_finish = true;
160  is_set.fill( false );
161 
162  // allocate to centroid
163  for( size_t i = 0 ; i < in.cols() ; ++i )
164  {
165  int minidx = 0;
166  /*
167  double mindist = 1e12;
168  for( int j = 0 ; j < k ; ++j )
169  {
170  double dist = 0;
171  for( size_t l = 0 ; l < in.rows() ; ++l )
172  {
173  dist += pow( center( l, j ) - in( l, i ), 2.0 );
174  }
175 
176  if( dist < mindist )
177  {
178  mindist = dist;
179  minidx = j;
180  }
181  }
182  */
183  detail::nearestCenter( in, i, center, k, minidx );
184  //
185  if( response( i ) != minidx )
186  {
187  response( i ) = minidx;
188  is_finish = false;
189  }
190  is_set( minidx ) = true;
191  }
192 
193  // re-calculate centroid
194  for( int i = 0 ; i < k ; ++i )
195  {
196  for( size_t j = 0 ; j < center.rows() ; ++j )
197  {
198  center( j, i ) = 0;
199  }
200 
201  int cnt = 0;
202  for( size_t j = 0 ; j < in.cols() ; ++j )
203  {
204  if( response( j ) == i )
205  {
206  for( size_t l = 0 ; l < center.rows() ; ++l )
207  {
208  center( l, i ) += in( l, j );
209  }
210  ++cnt;
211  }
212  }
213 
214  if( cnt > 0 )
215  {
216  for( size_t j = 0 ; j < center.rows() ; ++j )
217  {
218  center( j, i ) /= cnt;
219  }
220  }
221  }
222 
223  // check finish condition
224  if( is_finish )
225  {
226  break;
227  }
228  }
229 
230 
231  array1< int > idx_map( k );
232  int cnt = 0;
233  for( size_t i = 0 ; i < is_set.size() ; ++i )
234  {
235  if( is_set( i ) )
236  {
237  idx_map( i ) = cnt++;
238  }
239  else
240  {
241  idx_map( i ) = -1;
242  }
243  }
244 
245  if( k > cnt )
246  {
247  // reduce k
248  matrix< double > tmp = center;
249  center.resize( in.rows(), cnt );
250 
251  int c = 0;
252  for( size_t i = 0 ; i < tmp.cols() ; ++i )
253  {
254  if( is_set( i ) )
255  {
256  for( size_t j = 0 ; j < tmp.rows() ; ++j )
257  {
258  center( j, c ) = tmp( j, i );
259  }
260  ++c;
261  }
262  }
263 
264  for( size_t i = 0 ; i < response.size() ; ++i )
265  {
266  response( i ) = idx_map( response( i ) );
267  }
268 
269  k = cnt;
270  }
271  }
272 }
273 
274 
275 _MIST_END
276 
277 #endif
278 
279 

Generated on Wed Nov 12 2014 19:44:17 for MIST by doxygen 1.8.1.2