00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
#include "nnsearch.h"
#include <limits.h>
#include <gsl/gsl_sort.h>
#include "pqueue.h"
#include "mathadd.h"
00025
00038 SearchTree* nn_prepare( const double **X, int N, int m, OptArgList *optargs ){
00039 SearchTree *S;
00040 double **D;
00041 int *A;
00042 VectorDistanceFunction distfct;
00043 int max_numel;
00044 void *tmp;
00045 double x;
00046
00047
00048 distfct = vectordist_euclidean;
00049 max_numel = 50;
00050 if( optarglist_has_key( optargs, "metric" ) ){
00051 tmp = optarglist_ptr_by_key( optargs, "metric" );
00052 if( tmp )
00053 distfct = (VectorDistanceFunction) tmp;
00054 }
00055 if( optarglist_has_key( optargs, "max_numel_terminal_node" ) ){
00056 x = optarglist_scalar_by_key( optargs, "max_numel_terminal_node" );
00057 if( !isnan( x ) )
00058 max_numel=(int)x;
00059 }
00060 dprintf("maxnumel=%i\n", max_numel);
00061
00062
00063 dprintf("preparing search tree, %i points of dim %i\n", N, m);
00064 S = searchtree_init(N);
00065 A=S->A;
00066 D = vectordist_distmatrix( distfct, X,
00067 N, m, NULL, NULL, optargs );
00068 S->d=X;
00069 S->distfct = distfct;
00070 S->optargs = optargs;
00071 S->m=m;
00072 dblpp_print( X, N, m );
00073 dblpp_print( (const double**)D, N, N );
00074 S->root = tnode_init();
00075 S->root->c = (int)(random() / (RAND_MAX / N+1));
00076 S->root->start=0;
00077 S->root->end=N-1;
00078 S->root->R = dblp_max( D[S->root->c], N, NULL );
00079 dprintf("root center: %i, radius=%f\n", S->root->c, S->root->R );
00080
00081 build_tree_recursive( S->root, D, N, A, max_numel );
00082 dprintf(" S->N=%i\n", S->N);
00083 return S;
00084 }
00085
00102 void nn_search_k_slow( const double **X, int N, int m, const double *x, int k,
00103 int *nn_idx, double *nn_dist, OptArgList *optargs ){
00104 VectorDistanceFunction distfct;
00105 void *tmp;
00106 int i;
00107 size_t *permut;
00108 double *ds;
00109
00110 permut=(size_t*)malloc(N*sizeof(size_t));
00111 ds = (double*)malloc( N*sizeof(double) );
00112
00113
00114 distfct = vectordist_euclidean;
00115 if( optarglist_has_key( optargs, "metric" ) ){
00116 tmp = optarglist_ptr_by_key( optargs, "metric" );
00117 if( tmp )
00118 distfct = (VectorDistanceFunction) tmp;
00119 }
00120
00121
00122 for( i=0; i<N; i++ ){
00123 ds[i] = distfct( x, X[i], m, optargs);
00124 }
00125 gsl_sort_index( permut, ds, 1, N );
00126
00127 for( i=0; i<k; i++ ){
00128 nn_idx[i]=permut[i];
00129 nn_dist[i] = ds[permut[i]];
00130 }
00131
00132 dprintf("Done\n");
00133
00134 free( ds );
00135 free( permut );
00136 }
00137
00147 void nn_search_k( const SearchTree *S, const double *x, int k, int *nn_idx,
00148 double *nn_dist ){
00149 int i;
00150 size_t *permut;
00151 PriorityQueue *pq=pq_init();
00152 double dmin, dtmp;
00153
00154 permut=(size_t*)malloc(k*sizeof(size_t));
00155
00156 dprintf("searching for %i NN's\n", k );
00157
00158 for( i=0; i<k; i++ ){
00159 nn_idx[i] = (int)(random() / (RAND_MAX / S->N+1));
00160 nn_dist[i]= S->distfct( x, S->d[nn_idx[i]], S->m, S->optargs );
00161 }
00162 #ifdef DEBUG
00163 dprintf(" initial m_k\n");
00164 dblp_print_int( nn_idx, k );
00165 dprintf(" initial d(m_k,q)\n");
00166 dblp_print( nn_dist, k );
00167 #endif
00168
00169
00170 dmin = S->distfct( x, S->d[S->root->c], S->m, S->optargs )-S->root->R;
00171 pq_insert( pq, (void*)S->root, dmin );
00172
00173 TreeNode *C;
00174 PQnode *pqn;
00175 int count, ki;
00176 double dcq, dclq, dcrq;
00177 double pdmin;
00178 while( 1 ){
00179
00180 pqn = pq_pop_node(pq);
00181 C = (TreeNode*)pqn->content;
00182 dmin = pqn->priority;
00183 pqnode_free( pqn );
00184 dcq = S->distfct( x, S->d[C->c], S->m, S->optargs );
00185
00186
00187 for( count=0,i=0; i<k; i++ ){
00188 if( dmin<nn_dist[i] ){
00189 count++;
00190 }
00191 }
00192 if( count==k ){
00193 break;
00194 }
00195
00196
00197 if( tnode_isleaf( C ) ){
00198 for( ki=0; ki<k; ki++ ){
00199 for( i=0; i<C->end-C->start+1; i++ ){
00200 if( nn_dist[ki] < ABS( dcq - C->cdist[i] ) ){
00201 dtmp = S->distfct( x, S->d[S->A[C->start+i]], S->m, S->optargs );
00202 if( dtmp<nn_dist[ki] ){
00203 nn_dist[ki]=dtmp;
00204 nn_idx[ki] =S->A[C->start+i];
00205 }
00206 }
00207 }
00208 }
00209 } else {
00210 pdmin = dmin;
00211 dclq=S->distfct( x, S->d[C->left->c], S->m, S->optargs );
00212 dcrq=S->distfct( x, S->d[C->right->c], S->m, S->optargs );
00213 if( (dtmp=(dclq-dcrq-C->left->g)/2.0) > pdmin ){
00214 dmin = dtmp;
00215 }
00216 if( (dtmp=(dclq - C->left->R)) > pdmin ){
00217 dmin = dtmp;
00218 }
00219 pq_insert( pq, C->left, dmin );
00220 dmin=pdmin;
00221 if( (dtmp=(dcrq-dclq-C->right->g)/2.0) > pdmin ){
00222 dmin = dtmp;
00223 }
00224 if( (dtmp=(dcrq - C->right->R)) > pdmin ){
00225 dmin = dtmp;
00226 }
00227 pq_insert( pq, C->right, dmin);
00228 }
00229 }
00230
00231
00232 gsl_sort_index( permut, nn_dist, 1, k );
00233 int *nnidx_tmp=(int*)malloc( k*sizeof(int));
00234 double *nndist_tmp=(double*)malloc(k*sizeof(double));
00235 memcpy( nnidx_tmp, nn_idx, k*sizeof(int));
00236 memcpy( nndist_tmp, nn_dist, k*sizeof(double));
00237 for( i=0; i<k; i++ ){
00238 nn_idx[i] = nnidx_tmp[permut[i]];
00239 nn_dist[i] = nndist_tmp[permut[i]];
00240 }
00241
00242
00243 free( nnidx_tmp );
00244 free( nndist_tmp );
00245 pq_free( pq );
00246 free( permut );
00247 }
00248
00252 void build_tree_recursive( TreeNode *C, double **D, int N, int *A, int maxel ){
00253 TreeNode *L, *R;
00254 int il,ir;
00255 int i;
00256
00257 C->left=tnode_init();
00258 C->right=tnode_init();
00259 L = C->left; R = C->right;
00260
00261
00262 dprintf("calc left child c\n");
00263 L->c = A[C->start];
00264 for( i=C->start; i<=C->end; i++ ){
00265 if( D[L->c][C->c] < D[A[i]][C->c] ){
00266 L->c = A[i];
00267 }
00268 }
00269
00270 dprintf("calc right child c\n");
00271 R->c = A[C->start];
00272 for( i=C->start; i<C->end; i++ ){
00273 if( D[R->c][L->c] < D[A[i]][L->c] ){
00274 R->c = A[i];
00275 }
00276 }
00277 dprintf("New Child centers: %i, %i\n", L->c, R->c );
00278
00279
00280 L->start = C->start;
00281 L->end = C->start;
00282 R->start = C->end;
00283 R->end = C->end;
00284
00285 il=L->end;
00286 ir=R->start;
00287 double tmp;
00288 while(il <= ir){
00289 while(D[L->c][A[il]] < D[R->c][A[il]]){
00290 if( (tmp=D[R->c][A[il]]-D[L->c][A[il]]) < L->g ){
00291 L->g=tmp;
00292 }
00293 il++;
00294 }
00295 while(D[L->c][A[ir]] > D[R->c][A[ir]]){
00296 if( (tmp=D[L->c][A[ir]]-D[R->c][A[ir]]) < R->g ){
00297 R->g=tmp;
00298 }
00299 ir--;
00300 }
00301 dblp_print_int( A, N );
00302 dprintf("ptrs: il,ir=(%i,%i)\n", il, ir);
00303
00304 if( il<ir ){
00305 SWAPT( int, A[il], A[ir] );
00306 L->end=il;
00307 R->start=ir;
00308 il++; ir--;
00309 } else {
00310 L->end=il-1;
00311 R->start=ir+1;
00312 break;
00313 }
00314 }
00315
00316 dprintf("new cranges: (%i-%i), (%i-%i)\n",
00317 L->start, L->end, R->start, R->end );
00318
00319
00320 L->R = 0;
00321 for( i=L->start; i<=L->end; i++ ){
00322 if( D[A[i]][L->c]>L->R )
00323 L->R = D[A[i]][L->c];
00324 }
00325 R->R = 0;
00326 for( i=R->start; i<=R->end; i++ ){
00327 if( D[A[i]][R->c]>R->R )
00328 R->R = D[A[i]][R->c];
00329 }
00330
00331 dblp_print_int( A, N );
00332 dprintf("L->R=%f, R->R=%f\n", L->R, R->R);
00333
00334
00335 if( L->end-L->start > maxel ){
00336 dprintf("left recursion\n");
00337 build_tree_recursive( L, D, N, A, maxel );
00338 } else {
00339 L->cdist = (double*)malloc( (L->end - L->start + 1)*sizeof(double));
00340 for( i=0; i<L->end-L->start+1; i++ ){
00341 L->cdist[i] = D[L->c][A[L->start+i]];
00342 }
00343 }
00344 if( L->end-L->start > maxel ){
00345 dprintf("right recursion\n");
00346 build_tree_recursive( R, D, N, A, maxel );
00347 } else {
00348 R->cdist = (double*)malloc( (R->end - R->start + 1)*sizeof(double));
00349 for( i=0; i<R->end-R->start+1; i++ ){
00350 R->cdist[i] = D[R->c][A[R->start+i]];
00351 }
00352 }
00353 }
00354
00355 SearchTree* searchtree_init( int n ){
00356 int i;
00357 SearchTree *S=(SearchTree*)malloc(sizeof(SearchTree));
00358 S->A = (int*)malloc(n*sizeof(int));
00359 for( i=0; i<n; i++ ){
00360 S->A[i] = i;
00361 }
00362 S->m = 0;
00363 S->N = n;
00364 S->root = NULL;
00365 return S;
00366 }
00367
00368 TreeNode* tnode_init(){
00369 TreeNode *t=(TreeNode*)malloc(sizeof(TreeNode));
00370 t->c=-1;
00371 t->R=-1;
00372 t->g=DBL_MAX;
00373 t->start=-1;
00374 t->end=-1;
00375 t->cdist=NULL;
00376 t->left=NULL;
00377 t->right=NULL;
00378 return t;
00379 }
00380
00381 bool tnode_isleaf( TreeNode *C ){
00382 if( C->cdist )
00383 return TRUE;
00384 else
00385 return FALSE;
00386 }
00387
00388