SHOGUN  6.0.0
CARTree.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2014 Parijat Mazumdar
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are those
27  * of the authors and should not be interpreted as representing official policies,
28  * either expressed or implied, of the Shogun Development Team.
29  */
30 
35 
36 using namespace Eigen;
37 using namespace shogun;
38 
// Sentinel marking a missing feature entry (largest representable real).
39 const float64_t CCARTree::MISSING=CMath::MAX_REAL_NUMBER;
// Tolerance used when comparing feature values for equality during splitting.
40 const float64_t CCARTree::EQ_DELTA=1e-7;
// Minimum gain a candidate split must exceed to be accepted.
41 const float64_t CCARTree::MIN_SPLIT_GAIN=1e-7;
42 
// Default constructor: initializes all members to their defaults via init().
// NOTE(review): extraction dropped original line 44 (presumably the
// base-class initializer list) — confirm against the original source file.
43 CCARTree::CCARTree()
45 {
46  init();
47 }
48 
// Constructor taking the per-feature type flags (true = nominal) and the
// problem type (multiclass classification or regression).
// NOTE(review): extraction dropped original line 50 (presumably the
// base-class initializer list) — confirm against the original source file.
49 CCARTree::CCARTree(SGVector<bool> attribute_types, EProblemType prob_type)
51 {
52  init();
53  set_feature_types(attribute_types);
54  set_machine_problem_type(prob_type);
55 }
56 
// Constructor additionally configuring cross-validation pruning: the number
// of folds and whether CV pruning is enabled at all.
// NOTE(review): extraction dropped original line 58 (presumably the
// base-class initializer list) — confirm against the original source file.
57 CCARTree::CCARTree(SGVector<bool> attribute_types, EProblemType prob_type, int32_t num_folds, bool cv_prune)
59 {
60  init();
61  set_feature_types(attribute_types);
62  set_machine_problem_type(prob_type);
63  set_num_folds(num_folds);
64  if (cv_prune)
65  set_cv_pruning(cv_prune);
66 }
67 
// NOTE(review): the signature line was lost in extraction (original lines
// 68/70 dropped) — presumably the destructor CCARTree::~CCARTree(), whose
// body (line 70) likely released owned members. Confirm against the
// original source file.
69 {
71 }
72 
// Validates the label type and stores the labels for training.
// NOTE(review): extraction dropped original line 73 (signature — presumably
// void CCARTree::set_labels(CLabels* lab)), lines 76/78 (the branch bodies,
// which presumably set m_mode for multiclass/regression) and line 83
// (presumably SG_UNREF of the previously held m_labels). Confirm upstream.
74 {
75  if (lab->get_label_type()==LT_MULTICLASS)
77  else if (lab->get_label_type()==LT_REGRESSION)
79  else
// Only multiclass and regression labels are supported.
80  SG_ERROR("label type supplied is not supported\n")
81 
// Take a reference on the new labels before storing them.
82  SG_REF(lab);
84  m_labels=lab;
85 }
86 
// Stores the configured problem type (PT_MULTICLASS or PT_REGRESSION).
// NOTE(review): extraction dropped the signature line (original line 87) —
// presumably void CCARTree::set_machine_problem_type(EProblemType mode).
88 {
89  m_mode=mode;
90 }
91 
// Returns true iff the supplied labels match the configured problem type.
// NOTE(review): extraction dropped the signature (original line 92) and the
// first condition (line 94 — presumably
// if (m_mode==PT_MULTICLASS && lab->get_label_type()==LT_MULTICLASS)).
93 {
95  return true;
96  else if (m_mode==PT_REGRESSION && lab->get_label_type()==LT_REGRESSION)
97  return true;
98  else
99  return false;
100 }
101 
// Classifies each vector in data by routing it from the root down the
// trained tree; returns the predictions as multiclass labels.
// NOTE(review): extraction dropped the signature line (original line 102) —
// presumably CMulticlassLabels* CCARTree::apply_multiclass(CFeatures* data).
103 {
104  REQUIRE(data, "Data required for classification in apply_multiclass\n")
105 
106  // apply multiclass starting from root
107  bnode_t* current=dynamic_cast<bnode_t*>(get_root());
108 
// Refuse to predict before a tree has been trained.
109  REQUIRE(current, "Tree machine not yet trained.\n");
110  CLabels* ret=apply_from_current_node(dynamic_cast<CDenseFeatures<float64_t>*>(data), current);
111 
112  SG_UNREF(current);
113  return dynamic_cast<CMulticlassLabels*>(ret);
114 }
115 
// Computes a regression prediction for each vector by routing it from the
// root down the trained tree.
// NOTE(review): extraction dropped the signature line (original line 116) —
// presumably CRegressionLabels* CCARTree::apply_regression(CFeatures* data).
// NOTE(review): unlike apply_multiclass above, there is no REQUIRE on
// `current` here — an untrained tree would pass NULL to
// apply_from_current_node; consider adding the same guard.
117 {
118  REQUIRE(data, "Data required for classification in apply_multiclass\n")
119 
120  // apply regression starting from root
121  bnode_t* current=dynamic_cast<bnode_t*>(get_root());
122  CLabels* ret=apply_from_current_node(dynamic_cast<CDenseFeatures<float64_t>*>(data), current);
123 
124  SG_UNREF(current);
125  return dynamic_cast<CRegressionLabels*>(ret);
126 }
127 
// Selects, among the pruned subtrees produced by prune_tree(this), the one
// with the smallest weighted error on the supplied test dataset, and makes
// it the new root of this machine.
// NOTE(review): extraction dropped original line 128 (signature — presumably
// prune_using_test_dataset(CDenseFeatures<float64_t>* feats,
// CLabels* gnd_truth, SGVector<float64_t> weights)) and line 139
// (presumably the initialization of min_error). Confirm upstream.
129 {
// Default to unit weights when the caller supplied none.
130  if (weights.vlen==0)
131  {
132  weights=SGVector<float64_t>(feats->get_num_vectors());
133  weights.fill_vector(weights.vector,weights.vlen,1);
134  }
135 
// Candidate subtrees from cost-complexity pruning, one per alpha value.
136  CDynamicObjectArray* pruned_trees=prune_tree(this);
137 
138  int32_t min_index=0;
140  for (int32_t i=0;i<m_alphas->get_num_elements();i++)
141  {
142  CSGObject* element=pruned_trees->get_element(i);
143  bnode_t* root=NULL;
144  if (element!=NULL)
145  root=dynamic_cast<bnode_t*>(element);
146  else
147  SG_ERROR("%d element is NULL\n",i);
148 
// Evaluate this candidate subtree on the held-out data.
149  CLabels* labels=apply_from_current_node(feats, root);
150  float64_t error=compute_error(labels,gnd_truth,weights);
151  if (error<min_error)
152  {
153  min_index=i;
154  min_error=error;
155  }
156 
157  SG_UNREF(labels);
158  SG_UNREF(element);
159  }
160 
// Install the best-scoring subtree as the new root.
161  CSGObject* element=pruned_trees->get_element(min_index);
162  bnode_t* root=NULL;
163  if (element!=NULL)
164  root=dynamic_cast<bnode_t*>(element);
165  else
166  SG_ERROR("%d element is NULL\n",min_index);
167 
168  this->set_root(root);
169 
170  SG_UNREF(pruned_trees);
171  SG_UNREF(element);
172 }
173 
// Stores per-vector training weights and marks them as user-supplied.
// NOTE(review): extraction dropped the signature line (original line 174) —
// presumably void CCARTree::set_weights(SGVector<float64_t> w).
175 {
176  m_weights=w;
177  m_weights_set=true;
178 }
179 
// Returns the current per-vector training weights.
// NOTE(review): extraction dropped the signature line (original line 180) —
// presumably SGVector<float64_t> CCARTree::get_weights() const.
181 {
182  return m_weights;
183 }
184 
186 {
188  m_weights_set=false;
189 }
190 
// Stores the per-feature type flags (true = nominal/categorical,
// false = continuous) and marks them as user-supplied.
// NOTE(review): extraction dropped the signature line (original line 191) —
// presumably void CCARTree::set_feature_types(SGVector<bool> ft).
192 {
193  m_nominal=ft;
194  m_types_set=true;
195 }
196 
// Returns the per-feature type flags (true = nominal).
// NOTE(review): extraction dropped the signature line (original line 197) —
// presumably SGVector<bool> CCARTree::get_feature_types() const.
198 {
199  return m_nominal;
200 }
201 
// Clears the feature-type flag so training treats all features as
// continuous (with a warning).
// NOTE(review): extraction dropped the signature (original line 202) and
// line 204 (presumably resetting m_nominal to an empty vector). Confirm
// against the original source file.
203 {
205  m_types_set=false;
206 }
207 
208 int32_t CCARTree::get_num_folds() const
209 {
210  return m_folds;
211 }
212 
213 void CCARTree::set_num_folds(int32_t folds)
214 {
215  REQUIRE(folds>1,"Number of folds is expected to be greater than 1. Supplied value is %d\n",folds)
216  m_folds=folds;
217 }
218 
219 int32_t CCARTree::get_max_depth() const
220 {
221  return m_max_depth;
222 }
223 
224 void CCARTree::set_max_depth(int32_t depth)
225 {
226  REQUIRE(depth>0,"Max allowed tree depth should be greater than 0. Supplied value is %d\n",depth)
227  m_max_depth=depth;
228 }
229 
// Returns the minimum number of vectors a node must contain to be split.
// NOTE(review): extraction dropped the signature line (original line 230) —
// presumably int32_t CCARTree::get_min_node_size() const.
231 {
232  return m_min_node_size;
233 }
234 
235 void CCARTree::set_min_node_size(int32_t nsize)
236 {
237  REQUIRE(nsize>0,"Min allowed node size should be greater than 0. Supplied value is %d\n",nsize)
238  m_min_node_size=nsize;
239 }
240 
// Sets the tolerance within which two regression labels are considered
// equal when forming unique-label groups.
// NOTE(review): extraction dropped the signature line (original line 241) —
// presumably void CCARTree::set_label_epsilon(float64_t ep).
242 {
243  REQUIRE(ep>=0,"Input epsilon value is expected to be greater than or equal to 0\n")
244  m_label_epsilon=ep;
245 }
246 
// Trains the CART tree: validates inputs, fills in default weights/feature
// types when unset, then (per the dropped lines, presumably) builds the tree
// via CARTtrain and optionally applies cross-validation pruning.
// NOTE(review): extraction dropped original line 247 (signature — presumably
// bool CCARTree::train_machine(CFeatures* data)), line 264 (presumably
// filling m_weights with 1), line 276 (presumably filling m_nominal with
// false), line 279 (presumably set_root(CARTtrain(...))) and line 284
// (presumably the prune_by_cross_validation call). Confirm upstream.
248 {
249  REQUIRE(data,"Data required for training\n")
250  REQUIRE(data->get_feature_class()==C_DENSE,"Dense data required for training\n")
251 
252  int32_t num_features=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_features();
253  int32_t num_vectors=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_num_vectors();
254 
// User-supplied weights must cover every training vector.
255  if (m_weights_set)
256  {
257  REQUIRE(m_weights.vlen==num_vectors,"Length of weights vector (currently %d) should be same as"
258  " number of vectors in data (presently %d)",m_weights.vlen,num_vectors)
259  }
260  else
261  {
262  // all weights are equal to 1
263  m_weights=SGVector<float64_t>(num_vectors);
265  }
266 
// Feature-type flags, if supplied, must cover every feature.
267  if (m_types_set)
268  {
269  REQUIRE(m_nominal.vlen==num_features,"Length of m_nominal vector (currently %d) should "
270  "be same as number of features in data (presently %d)",m_nominal.vlen,num_features)
271  }
272  else
273  {
274  SG_WARNING("Feature types are not specified. All features are considered as continuous in training")
275  m_nominal=SGVector<bool>(num_features);
277  }
278 
280 
281  if (m_apply_cv_pruning)
282  {
283  CDenseFeatures<float64_t>* feats=dynamic_cast<CDenseFeatures<float64_t>*>(data);
285  }
286 
287  return true;
288 }
289 
291 {
292  m_pre_sort=true;
293  m_sorted_features=sorted_feats;
294  m_sorted_indices=sorted_indices;
295 }
296 
298 {
299  SGMatrix<float64_t> mat=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_feature_matrix();
300  sorted_feats = SGMatrix<float64_t>(mat.num_cols, mat.num_rows);
301  sorted_indices = SGMatrix<index_t>(mat.num_cols, mat.num_rows);
302  for(int32_t i=0; i<sorted_indices.num_cols; i++)
303  for(int32_t j=0; j<sorted_indices.num_rows; j++)
304  sorted_indices(j,i)=j;
305 
306  Map<MatrixXd> map_sorted_feats(sorted_feats.matrix, mat.num_cols, mat.num_rows);
307  Map<MatrixXd> map_data(mat.matrix, mat.num_rows, mat.num_cols);
308 
309  map_sorted_feats=map_data.transpose();
310 
311  #pragma omp parallel for
312  for(int32_t i=0; i<sorted_feats.num_cols; i++)
313  CMath::qsort_index(sorted_feats.get_column_vector(i), sorted_indices.get_column_vector(i), sorted_feats.num_rows);
314 
315 }
316 
// Recursively grows one CART node: computes the node's label and
// resubstitution error, checks the stopping rules (max depth, min node
// size, no usable split), otherwise finds the best split attribute,
// resolves missing values via surrogate splits, partitions the data with
// subsets and recurses into left/right children.
// NOTE(review): extraction dropped the signature line (original line 317) —
// presumably CBinaryTreeMachineNode<CARTreeNodeData>* CCARTree::CARTtrain(
// CFeatures* data, SGVector<float64_t> weights, CLabels* labels,
// int32_t level). Confirm against the original source file.
318 {
319  REQUIRE(labels,"labels have to be supplied\n");
320  REQUIRE(data,"data matrix has to be supplied\n");
321 
322  bnode_t* node=new bnode_t();
323  SGVector<float64_t> labels_vec=(dynamic_cast<CDenseLabels*>(labels))->get_labels();
324  SGMatrix<float64_t> mat=(dynamic_cast<CDenseFeatures<float64_t>*>(data))->get_feature_matrix();
325  int32_t num_feats=mat.num_rows;
326  int32_t num_vecs=mat.num_cols;
327 
328  // calculate node label
329  switch(m_mode)
330  {
331  case PT_REGRESSION:
332  {
// Weighted sum of labels; node label is the weighted mean.
333  float64_t sum=0;
334  for (int32_t i=0;i<labels_vec.vlen;i++)
335  sum+=labels_vec[i]*weights[i];
336 
337  // lsd*total_weight=sum_of_squared_deviation
338  float64_t tot=0;
// NOTE(review): `tot` is both read and written in this expression —
// least_squares_deviation sets it by reference. Before C++17 the
// evaluation order of the two operands of `*` is unspecified, so this
// relies on the right operand being evaluated first; verify upstream.
339  node->data.weight_minus_node=tot*least_squares_deviation(labels_vec,weights,tot);
340  node->data.node_label=sum/tot;
341  node->data.total_weight=tot;
342 
343  break;
344  }
345  case PT_MULTICLASS:
346  {
// Sort a copy of the labels so equal labels are adjacent, then scan for
// the label with the largest total weight (the majority label).
347  SGVector<float64_t> lab=labels_vec.clone();
348  CMath::qsort(lab);
349  // stores max total weight for a single label
350  int32_t max=weights[0];
351  // stores one of the indices having max total weight
352  int32_t maxi=0;
353  int32_t c=weights[0];
354  for (int32_t i=1;i<lab.vlen;i++)
355  {
356  if (lab[i]==lab[i-1])
357  {
358  c+=weights[i];
359  }
360  else if (c>max)
361  {
362  max=c;
363  maxi=i-1;
364  c=weights[i];
365  }
366  else
367  {
368  c=weights[i];
369  }
370  }
371 
// The final run of equal labels is not closed inside the loop.
372  if (c>max)
373  {
374  max=c;
375  maxi=lab.vlen-1;
376  }
377 
378  node->data.node_label=lab[maxi];
379 
380  // resubstitution error calculation
381  node->data.total_weight=weights.sum(weights);
382  node->data.weight_minus_node=node->data.total_weight-max;
383  break;
384  }
385  default :
386  SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");
387  }
388 
389  // check stopping rules
390  // case 1 : max tree depth reached if max_depth set
391  if ((m_max_depth>0) && (level==m_max_depth))
392  {
393  node->data.num_leaves=1;
394  node->data.weight_minus_branch=node->data.weight_minus_node;
395  return node;
396  }
397 
398  // case 2 : min node size violated if min_node_size specified
399  if ((m_min_node_size>1) && (labels_vec.vlen<=m_min_node_size))
400  {
401  node->data.num_leaves=1;
402  node->data.weight_minus_branch=node->data.weight_minus_node;
403  return node;
404  }
405 
406  // choose best attribute
407  // transit_into_values for left child
408  SGVector<float64_t> left(num_feats);
409  // transit_into_values for right child
410  SGVector<float64_t> right(num_feats);
411  // final data distribution among children
412  SGVector<bool> left_final(num_vecs);
413  int32_t num_missing_final=0;
414  int32_t c_left=-1;
415  int32_t c_right=-1;
416  int32_t best_attribute;
417 
// When pre-sorted features are in use, the split search works on the
// globally sorted matrix restricted to the currently active subset rows.
418  SGVector<index_t> indices(num_vecs);
419  if (m_pre_sort)
420  {
421  CSubsetStack* subset_stack = data->get_subset_stack();
422  if (subset_stack->has_subsets())
423  indices=(subset_stack->get_last_subset())->get_subset_idx();
424  else
425  indices.range_fill();
426  SG_UNREF(subset_stack);
427  best_attribute=compute_best_attribute(m_sorted_features,weights,labels,left,right,left_final,num_missing_final,c_left,c_right,0,indices);
428  }
429  else
430  best_attribute=compute_best_attribute(mat,weights,labels,left,right,left_final,num_missing_final,c_left,c_right);
431 
// case 3 : no split produced the minimum required gain — make a leaf.
432  if (best_attribute==-1)
433  {
434  node->data.num_leaves=1;
435  node->data.weight_minus_branch=node->data.weight_minus_node;
436  return node;
437  }
438 
439  SGVector<float64_t> left_transit(c_left);
440  SGVector<float64_t> right_transit(c_right);
441  sg_memcpy(left_transit.vector,left.vector,c_left*sizeof(float64_t));
442  sg_memcpy(right_transit.vector,right.vector,c_right*sizeof(float64_t));
443 
// Vectors with a missing value in the chosen attribute are assigned to a
// child using surrogate splits.
444  if (num_missing_final>0)
445  {
446  SGVector<bool> is_left_final(num_vecs-num_missing_final);
447  int32_t ilf=0;
448  for (int32_t i=0;i<num_vecs;i++)
449  {
450  if (mat(best_attribute,i)!=MISSING)
451  is_left_final[ilf++]=left_final[i];
452  }
453 
454  left_final=surrogate_split(mat,weights,is_left_final,best_attribute);
455  }
456 
457  int32_t count_left=0;
458  for (int32_t c=0;c<num_vecs;c++)
459  count_left=(left_final[c])?count_left+1:count_left;
460 
// Build index subsets and weight vectors for the two children.
461  SGVector<index_t> subsetl(count_left);
462  SGVector<float64_t> weightsl(count_left);
463  SGVector<index_t> subsetr(num_vecs-count_left);
464  SGVector<float64_t> weightsr(num_vecs-count_left);
465  index_t l=0;
466  index_t r=0;
467  for (int32_t c=0;c<num_vecs;c++)
468  {
469  if (left_final[c])
470  {
471  subsetl[l]=c;
472  weightsl[l++]=weights[c];
473  }
474  else
475  {
476  subsetr[r]=c;
477  weightsr[r++]=weights[c];
478  }
479  }
480 
481  // left child
482  data->add_subset(subsetl);
483  labels->add_subset(subsetl);
484  bnode_t* left_child=CARTtrain(data,weightsl,labels,level+1);
485  data->remove_subset();
486  labels->remove_subset();
487 
488  // right child
489  data->add_subset(subsetr);
490  labels->add_subset(subsetr);
491  bnode_t* right_child=CARTtrain(data,weightsr,labels,level+1);
492  data->remove_subset();
493  labels->remove_subset();
494 
495  // set node parameters
496  node->data.attribute_id=best_attribute;
497  node->left(left_child);
498  node->right(right_child);
499  left_child->data.transit_into_values=left_transit;
500  right_child->data.transit_into_values=right_transit;
501  node->data.num_leaves=left_child->data.num_leaves+right_child->data.num_leaves;
502  node->data.weight_minus_branch=left_child->data.weight_minus_branch+right_child->data.weight_minus_branch;
503 
504  return node;
505 }
506 
508 {
509  float64_t delta=0;
510  if (m_mode==PT_REGRESSION)
511  delta=m_label_epsilon;
512 
513  SGVector<float64_t> ulabels(labels_vec.vlen);
514  SGVector<index_t> sidx=CMath::argsort(labels_vec);
515  ulabels[0]=labels_vec[sidx[0]];
516  n_ulabels=1;
517  int32_t start=0;
518  for (int32_t i=1;i<sidx.vlen;i++)
519  {
520  if (labels_vec[sidx[i]]<=labels_vec[sidx[start]]+delta)
521  continue;
522 
523  start=i;
524  ulabels[n_ulabels]=labels_vec[sidx[i]];
525  n_ulabels++;
526  }
527 
528  return ulabels;
529 }
530 
// Finds the attribute (and threshold / category partition) whose split of
// the current data maximizes the impurity gain. Returns the best attribute
// id, or -1 when no split beats MIN_SPLIT_GAIN (or all labels are equal).
// Outputs: left/right transit values, per-vector left/right assignment,
// number of vectors missing the chosen attribute, and the left/right
// transit-value counts.
// NOTE(review): extraction dropped the first signature line (original line
// 531) — presumably int32_t CCARTree::compute_best_attribute(
// const SGMatrix<float64_t>& mat, const SGVector<float64_t>& weights,
// CLabels* labels, ... (continuing below). Confirm upstream.
532  SGVector<float64_t>& left, SGVector<float64_t>& right, SGVector<bool>& is_left_final, int32_t &num_missing_final, int32_t &count_left,
533  int32_t &count_right, int32_t subset_size, const SGVector<index_t>& active_indices)
534 {
535  SGVector<float64_t> labels_vec=(dynamic_cast<CDenseLabels*>(labels))->get_labels();
536  int32_t num_vecs=labels->get_num_labels();
// In pre-sorted mode `mat` is feature-per-column; otherwise feature-per-row.
537  int32_t num_feats;
538  if (m_pre_sort)
539  num_feats=mat.num_cols;
540  else
541  num_feats=mat.num_rows;
542 
543  int32_t n_ulabels;
544  SGVector<float64_t> ulabels=get_unique_labels(labels_vec,n_ulabels);
545 
546  // if all labels same early stop
547  if (n_ulabels==1)
548  return -1;
549 
550  float64_t delta=0;
551  if (m_mode==PT_REGRESSION)
552  delta=m_label_epsilon;
553 
// Map every label to its unique-label index and accumulate per-class weight.
554  SGVector<float64_t> total_wclasses(n_ulabels);
555  total_wclasses.zero();
556 
557  SGVector<int32_t> simple_labels(num_vecs);
558  for (int32_t i=0;i<num_vecs;i++)
559  {
560  for (int32_t j=0;j<n_ulabels;j++)
561  {
562  if (CMath::abs(labels_vec[i]-ulabels[j])<=delta)
563  {
564  simple_labels[i]=j;
565  total_wclasses[j]+=weights[i];
566  break;
567  }
568  }
569  }
570 
// Optionally restrict the search to a random feature subset (used by
// callers such as random forests that pass subset_size > 0).
571  SGVector<index_t> idx(num_feats);
572  idx.range_fill();
573  if (subset_size)
574  {
575  num_feats=subset_size;
576  CMath::permute(idx);
577  }
578 
579  float64_t max_gain=MIN_SPLIT_GAIN;
580  int32_t best_attribute=-1;
581  float64_t best_threshold=0;
582 
// In pre-sorted mode, map each globally-indexed row to its position in the
// active subset; `dupes` records repeated occurrences of the same row.
583  SGVector<int64_t> indices_mask;
584  SGVector<int32_t> count_indices(mat.num_rows);
585  count_indices.zero();
586  SGVector<int32_t> dupes(num_vecs);
587  dupes.range_fill();
588  if (m_pre_sort)
589  {
590  indices_mask = SGVector<int64_t>(mat.num_rows);
591  indices_mask.set_const(-1);
592  for(int32_t j=0;j<active_indices.size();j++)
593  {
594  if (indices_mask[active_indices[j]]>=0)
595  dupes[indices_mask[active_indices[j]]]=j;
596 
597  indices_mask[active_indices[j]]=j;
598  count_indices[active_indices[j]]++;
599  }
600  }
601 
602  for (int32_t i=0;i<num_feats;i++)
603  {
// Gather this feature's values for the active vectors, sorted ascending,
// with sorted_args holding the corresponding local vector indices.
604  SGVector<float64_t> feats(num_vecs);
605  SGVector<index_t> sorted_args(num_vecs);
606  SGVector<int32_t> temp_count_indices(count_indices.size());
607  sg_memcpy(temp_count_indices.vector, count_indices.vector, sizeof(int32_t)*count_indices.size());
608 
609  if (m_pre_sort)
610  {
611  SGVector<float64_t> temp_col(mat.get_column_vector(idx[i]), mat.num_rows, false);
612  SGVector<index_t> sorted_indices(m_sorted_indices.get_column_vector(idx[i]), mat.num_rows, false);
613  int32_t count=0;
614  for(int32_t j=0;j<mat.num_rows;j++)
615  {
616  if (indices_mask[sorted_indices[j]]>=0)
617  {
618  int32_t count_index = count_indices[sorted_indices[j]];
619  while(count_index>0)
620  {
621  feats[count]=temp_col[j];
622  sorted_args[count]=indices_mask[sorted_indices[j]];
623  ++count;
624  --count_index;
625  }
626  if (count==num_vecs)
627  break;
628  }
629  }
630  }
631  else
632  {
633  for (int32_t j=0;j<num_vecs;j++)
634  feats[j]=mat(idx[i],j);
635 
636  // O(N*logN)
637  sorted_args.range_fill();
638  CMath::qsort_index(feats.vector, sorted_args.vector, feats.size());
639  }
// MISSING sorts last; drop trailing missing entries and remove their
// weight from the class totals for the duration of this feature.
640  int32_t n_nm_vecs=feats.vlen;
641  // number of non-missing vecs
642  while (feats[n_nm_vecs-1]==MISSING)
643  {
644  total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]]-=weights[sorted_args[n_nm_vecs-1]];
645  n_nm_vecs--;
646  }
647 
648  // if only one unique value - it cannot be used to split
649  if (feats[n_nm_vecs-1]<=feats[0]+EQ_DELTA)
650  continue;
651 
652  if (m_nominal[idx[i]])
653  {
// Nominal attribute: enumerate every non-trivial bipartition of its
// distinct categorical values.
654  SGVector<int32_t> simple_feats(num_vecs);
655  simple_feats.fill_vector(simple_feats.vector,simple_feats.vlen,-1);
656 
657  // convert to simple values
658  simple_feats[0]=0;
659  int32_t c=0;
660  for (int32_t j=1;j<n_nm_vecs;j++)
661  {
662  if (feats[j]==feats[j-1])
663  simple_feats[j]=c;
664  else
665  simple_feats[j]=(++c);
666  }
667 
668  SGVector<float64_t> ufeats(c+1);
669  ufeats[0]=feats[0];
670  int32_t u=0;
671  for (int32_t j=1;j<n_nm_vecs;j++)
672  {
673  if (feats[j]==feats[j-1])
674  continue;
675  else
676  ufeats[++u]=feats[j];
677  }
678 
679  // test all 2^(I-1)-1 possible division between two nodes
680  int32_t num_cases=CMath::pow(2,c);
681  for (int32_t k=1;k<num_cases;k++)
682  {
683  SGVector<float64_t> wleft(n_ulabels);
684  SGVector<float64_t> wright(n_ulabels);
685  wleft.zero();
686  wright.zero();
687 
688  // stores which vectors are assigned to left child
689  SGVector<bool> is_left(num_vecs);
690  is_left.fill_vector(is_left.vector,is_left.vlen,false);
691 
692  // stores which among the categorical values of chosen attribute are assigned left child
693  SGVector<bool> feats_left(c+1);
694 
695  // fill feats_left in a unique way corresponding to the case
696  for (int32_t p=0;p<c+1;p++)
697  feats_left[p]=((k/CMath::pow(2,p))%(CMath::pow(2,p+1))==1);
698 
699  // form is_left
700  for (int32_t j=0;j<n_nm_vecs;j++)
701  {
702  is_left[sorted_args[j]]=feats_left[simple_feats[j]];
703  if (is_left[sorted_args[j]])
704  wleft[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
705  else
706  wright[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
707  }
// Propagate the assignment to duplicate occurrences of the same row.
708  for (int32_t j=n_nm_vecs-1;j>=0;j--)
709  {
710  if(dupes[j]!=j)
711  is_left[j]=is_left[dupes[j]];
712  }
713 
714  float64_t g=0;
715  if (m_mode==PT_MULTICLASS)
716  g=gain(wleft,wright,total_wclasses);
717  else if (m_mode==PT_REGRESSION)
718  g=gain(wleft,wright,total_wclasses,ulabels);
719  else
720  SG_ERROR("Undefined problem statement\n");
721 
722  if (g>max_gain)
723  {
724  best_attribute=idx[i];
725  max_gain=g;
726  sg_memcpy(is_left_final.vector,is_left.vector,is_left.vlen*sizeof(bool));
727  num_missing_final=num_vecs-n_nm_vecs;
728 
729  count_left=0;
730  for (int32_t l=0;l<c+1;l++)
731  count_left=(feats_left[l])?count_left+1:count_left;
732 
733  count_right=c+1-count_left;
734 
// Record which categorical values transit to each child.
735  int32_t l=0;
736  int32_t r=0;
737  for (int32_t w=0;w<c+1;w++)
738  {
739  if (feats_left[w])
740  left[l++]=ufeats[w];
741  else
742  right[r++]=ufeats[w];
743  }
744  }
745  }
746  }
747  else
748  {
// Continuous attribute: sweep thresholds over the sorted values,
// maintaining left/right class-weight histograms incrementally.
749  // O(N)
750  SGVector<float64_t> right_wclasses=total_wclasses.clone();
751  SGVector<float64_t> left_wclasses(n_ulabels);
752  left_wclasses.zero();
753 
754  // O(N)
755  // find best split for non-nominal attribute - choose threshold (z)
756  float64_t z=feats[0];
757  right_wclasses[simple_labels[sorted_args[0]]]-=weights[sorted_args[0]];
758  left_wclasses[simple_labels[sorted_args[0]]]+=weights[sorted_args[0]];
759  for (int32_t j=1;j<n_nm_vecs;j++)
760  {
761  if (feats[j]<=z+EQ_DELTA)
762  {
763  right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]];
764  left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
765  continue;
766  }
767  // O(F)
768  float64_t g=0;
769  if (m_mode==PT_MULTICLASS)
770  g=gain(left_wclasses,right_wclasses,total_wclasses);
771  else if (m_mode==PT_REGRESSION)
772  g=gain(left_wclasses,right_wclasses,total_wclasses,ulabels);
773  else
774  SG_ERROR("Undefined problem statement\n");
775 
776  if (g>max_gain)
777  {
778  max_gain=g;
779  best_attribute=idx[i];
780  best_threshold=z;
781  num_missing_final=num_vecs-n_nm_vecs;
782  }
783 
784  z=feats[j];
// All remaining values equal the new threshold — no further splits.
785  if (feats[n_nm_vecs-1]<=z+EQ_DELTA)
786  break;
787  right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]];
788  left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]];
789  }
790  }
791 
792  // restore total_wclasses
793  while (n_nm_vecs<feats.vlen)
794  {
795  total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]]+=weights[sorted_args[n_nm_vecs-1]];
796  n_nm_vecs++;
797  }
798  }
799 
800  if (best_attribute==-1)
801  return -1;
802 
// For continuous winners, fill in the threshold and the per-vector
// left/right assignment (not tracked incrementally above).
803  if (!m_nominal[best_attribute])
804  {
805  left[0]=best_threshold;
806  right[0]=best_threshold;
807  count_left=1;
808  count_right=1;
809  if (m_pre_sort)
810  {
811  SGVector<float64_t> temp_vec(mat.get_column_vector(best_attribute), mat.num_rows, false);
812  SGVector<index_t> sorted_indices(m_sorted_indices.get_column_vector(best_attribute), mat.num_rows, false);
813  int32_t count=0;
814  for(int32_t i=0;i<mat.num_rows;i++)
815  {
816  if (indices_mask[sorted_indices[i]]>=0)
817  {
818  is_left_final[indices_mask[sorted_indices[i]]]=(temp_vec[i]<=best_threshold);
819  ++count;
820  if (count==num_vecs)
821  break;
822  }
823  }
824  for (int32_t i=num_vecs-1;i>=0;i--)
825  {
826  if(dupes[i]!=i)
827  is_left_final[i]=is_left_final[dupes[i]];
828  }
829 
830  }
831  else
832  {
833  for (int32_t i=0;i<num_vecs;i++)
834  is_left_final[i]=(mat(best_attribute,i)<=best_threshold);
835  }
836  }
837 
838  return best_attribute;
839 }
840 
// Assigns vectors whose best-split attribute is MISSING to a child using
// surrogate splits (Breiman et al.): every other attribute is scanned for
// the split most closely mimicking the primary split, and any vector still
// unassigned falls back to the majority direction.
// NOTE(review): extraction dropped original line 841 (signature — presumably
// SGVector<bool> CCARTree::surrogate_split(SGMatrix<float64_t> m,
// SGVector<float64_t> weights, SGVector<bool> nm_left, int32_t attr)) and
// line 851 (presumably the declaration of missing_vecs, a
// CDynamicArray<int32_t>*). Confirm upstream.
842 {
843  // return vector - left/right belongingness
844  SGVector<bool> ret(m.num_cols);
845 
846  // ditribute data with known attributes
847  int32_t l=0;
848  float64_t p_l=0.;
849  float64_t total=0.;
850  // stores indices of vectors with missing attribute
852  // stores lambda values corresponding to missing vectors - initialized all with 0
853  CDynamicArray<float64_t>* association_index=new CDynamicArray<float64_t>();
// Copy the known assignments; track total and left-going weight.
854  for (int32_t i=0;i<m.num_cols;i++)
855  {
856  if (!CMath::fequals(m(attr,i),MISSING,0))
857  {
858  ret[i]=nm_left[l];
859  total+=weights[i];
860  if (nm_left[l++])
861  p_l+=weights[i];
862  }
863  else
864  {
865  missing_vecs->push_back(i);
866  association_index->push_back(0.);
867  }
868  }
869 
870  // for lambda calculation
871  float64_t p_r=(total-p_l)/total;
872  p_l/=total;
873  float64_t p=CMath::min(p_r,p_l);
874 
875  // for each attribute (X') alternative to best split (X)
876  for (int32_t i=0;i<m.num_rows;i++)
877  {
878  if (i==attr)
879  continue;
880 
881  // find set of vectors with non-missing values for both X and X'
882  CDynamicArray<int32_t>* intersect_vecs=new CDynamicArray<int32_t>();
883  for (int32_t j=0;j<m.num_cols;j++)
884  {
885  if (!(CMath::fequals(m(i,j),MISSING,0) || CMath::fequals(m(attr,j),MISSING,0)))
886  intersect_vecs->push_back(j);
887  }
888 
889  if (intersect_vecs->get_num_elements()==0)
890  {
891  SG_UNREF(intersect_vecs);
892  continue;
893  }
894 
895 
// Dispatch by surrogate attribute type; both helpers update ret and
// association_index in place for any better-matching split found.
896  if (m_nominal[i])
897  handle_missing_vecs_for_nominal_surrogate(m,missing_vecs,association_index,intersect_vecs,ret,weights,p,i);
898  else
899  handle_missing_vecs_for_continuous_surrogate(m,missing_vecs,association_index,intersect_vecs,ret,weights,p,i);
900 
901  SG_UNREF(intersect_vecs);
902  }
903 
904  // if some missing attribute vectors are yet not addressed, use majority rule
905  for (int32_t i=0;i<association_index->get_num_elements();i++)
906  {
907  if (association_index->get_element(i)==0.)
908  ret[missing_vecs->get_element(i)]=(p_l>=p_r);
909  }
910 
911  SG_UNREF(missing_vecs);
912  SG_UNREF(association_index);
913  return ret;
914 }
915 
// For one continuous surrogate attribute, scans all threshold splits and,
// for every missing-value vector where this surrogate agrees with the
// primary split better (higher lambda) than the best surrogate so far,
// records the assignment and the new lambda.
// NOTE(review): extraction dropped the first signature line (original line
// 916) — presumably void CCARTree::handle_missing_vecs_for_continuous_
// surrogate(SGMatrix<float64_t> m, CDynamicArray<int32_t>* missing_vecs,
// ... (continuing below). Confirm upstream.
917  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
918  SGVector<float64_t> weights, float64_t p, int32_t attr)
919 {
920  // for lambda calculation - total weight of all vectors in X intersect X'
921  float64_t denom=0.;
922  SGVector<float64_t> feats(intersect_vecs->get_num_elements());
923  for (int32_t j=0;j<intersect_vecs->get_num_elements();j++)
924  {
925  feats[j]=m(attr,intersect_vecs->get_element(j));
926  denom+=weights[intersect_vecs->get_element(j)];
927  }
928 
929  // unique feature values for X'
930  int32_t num_unique=feats.unique(feats.vector,feats.vlen);
931 
932 
933  // all possible splits for chosen attribute
934  for (int32_t j=0;j<num_unique-1;j++)
935  {
936  float64_t z=feats[j];
// numer: weight agreeing with the primary split; numerc: weight agreeing
// with the complementary (mirrored) split.
937  float64_t numer=0.;
938  float64_t numerc=0.;
939  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
940  {
941  // if both go left or both go right
942  if ((m(attr,intersect_vecs->get_element(k))<=z) && is_left[intersect_vecs->get_element(k)])
943  numer+=weights[intersect_vecs->get_element(k)];
944  else if ((m(attr,intersect_vecs->get_element(k))>z) && !is_left[intersect_vecs->get_element(k)])
945  numer+=weights[intersect_vecs->get_element(k)];
946  // complementary split cases - one goes left other right
947  else if ((m(attr,intersect_vecs->get_element(k))<=z) && !is_left[intersect_vecs->get_element(k)])
948  numerc+=weights[intersect_vecs->get_element(k)];
949  else if ((m(attr,intersect_vecs->get_element(k))>z) && is_left[intersect_vecs->get_element(k)])
950  numerc+=weights[intersect_vecs->get_element(k)];
951  }
952 
// lambda measures how much better this surrogate is than the majority rule.
953  float64_t lambda=0.;
954  if (numer>=numerc)
955  lambda=(p-(1-numer/denom))/p;
956  else
957  lambda=(p-(1-numerc/denom))/p;
958  for (int32_t k=0;k<missing_vecs->get_num_elements();k++)
959  {
960  if ((lambda>association_index->get_element(k)) &&
961  (!CMath::fequals(m(attr,missing_vecs->get_element(k)),MISSING,0)))
962  {
963  association_index->set_element(lambda,k);
// Use the direct split when it agrees better, else the mirrored one.
964  if (numer>=numerc)
965  is_left[missing_vecs->get_element(k)]=(m(attr,missing_vecs->get_element(k))<=z);
966  else
967  is_left[missing_vecs->get_element(k)]=(m(attr,missing_vecs->get_element(k))>z);
968  }
969  }
970  }
971 }
972 
// For one nominal surrogate attribute, enumerates every bipartition of its
// unique values and, for missing-value vectors where this surrogate mimics
// the primary split better (higher lambda) than any seen so far, records
// the assignment and the new lambda.
// NOTE(review): extraction dropped the first signature line (original line
// 973) — presumably void CCARTree::handle_missing_vecs_for_nominal_
// surrogate(SGMatrix<float64_t> m, CDynamicArray<int32_t>* missing_vecs,
// ... (continuing below). Confirm upstream.
974  CDynamicArray<float64_t>* association_index, CDynamicArray<int32_t>* intersect_vecs, SGVector<bool> is_left,
975  SGVector<float64_t> weights, float64_t p, int32_t attr)
976 {
977  // for lambda calculation - total weight of all vectors in X intersect X'
978  float64_t denom=0.;
979  SGVector<float64_t> feats(intersect_vecs->get_num_elements());
980  for (int32_t j=0;j<intersect_vecs->get_num_elements();j++)
981  {
982  feats[j]=m(attr,intersect_vecs->get_element(j));
983  denom+=weights[intersect_vecs->get_element(j)];
984  }
985 
986  // unique feature values for X'
987  int32_t num_unique=feats.unique(feats.vector,feats.vlen);
988 
989  // scan all splits for chosen alternative attribute X'
990  int32_t num_cases=CMath::pow(2,(num_unique-1));
991  for (int32_t j=1;j<num_cases;j++)
992  {
// Encode case j as a left/right assignment of each unique value.
993  SGVector<bool> feats_left(num_unique);
994  for (int32_t k=0;k<num_unique;k++)
995  feats_left[k]=((j/CMath::pow(2,k))%(CMath::pow(2,k+1))==1);
996 
997  SGVector<bool> intersect_vecs_left(intersect_vecs->get_num_elements());
998  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
999  {
1000  for (int32_t q=0;q<num_unique;q++)
1001  {
1002  if (feats[q]==m(attr,intersect_vecs->get_element(k)))
1003  {
1004  intersect_vecs_left[k]=feats_left[q];
1005  break;
1006  }
1007  }
1008  }
1009 
// numer: weight agreeing with the primary split; numerc: disagreeing
// (i.e. agreeing with the complementary split).
1010  float64_t numer=0.;
1011  float64_t numerc=0.;
1012  for (int32_t k=0;k<intersect_vecs->get_num_elements();k++)
1013  {
1014  // if both go left or both go right
1015  if (intersect_vecs_left[k]==is_left[intersect_vecs->get_element(k)])
1016  numer+=weights[intersect_vecs->get_element(k)];
1017  else
1018  numerc+=weights[intersect_vecs->get_element(k)];
1019  }
1020 
1021  // lambda for this split (2 case identical split/complementary split)
1022  float64_t lambda=0.;
1023  if (numer>=numerc)
1024  lambda=(p-(1-numer/denom))/p;
1025  else
1026  lambda=(p-(1-numerc/denom))/p;
1027 
1028  // address missing value vectors not yet addressed or addressed using worse split
1029  for (int32_t k=0;k<missing_vecs->get_num_elements();k++)
1030  {
1031  if ((lambda>association_index->get_element(k)) &&
1032  (!CMath::fequals(m(attr,missing_vecs->get_element(k)),MISSING,0)))
1033  {
1034  association_index->set_element(lambda,k);
1035  // decide left/right based on which feature value the chosen data point has
1036  for (int32_t q=0;q<num_unique;q++)
1037  {
1038  if (feats[q]==m(attr,missing_vecs->get_element(k)))
1039  {
1040  if (numer>=numerc)
1041  is_left[missing_vecs->get_element(k)]=feats_left[q];
1042  else
1043  is_left[missing_vecs->get_element(k)]=!feats_left[q];
1044 
1045  break;
1046  }
1047  }
1048  }
1049  }
1050  }
1051 }
1052 
// Regression gain: reduction in weighted least-squares deviation achieved
// by splitting the node into the two given weighted partitions.
// NOTE(review): extraction dropped the first signature line (original line
// 1053) — presumably float64_t CCARTree::gain(SGVector<float64_t> wleft,
// SGVector<float64_t> wright, SGVector<float64_t> wtotal,
// ... (continuing below). Confirm upstream.
1054  SGVector<float64_t> feats)
1055 {
1056  float64_t total_lweight=0;
1057  float64_t total_rweight=0;
1058  float64_t total_weight=0;
1059 
// Each call also returns the partition's total weight by reference.
1060  float64_t lsd_n=least_squares_deviation(feats,wtotal,total_weight);
1061  float64_t lsd_l=least_squares_deviation(feats,wleft,total_lweight);
1062  float64_t lsd_r=least_squares_deviation(feats,wright,total_rweight);
1063 
1064  return lsd_n-(lsd_l*(total_lweight/total_weight))-(lsd_r*(total_rweight/total_weight));
1065 }
1066 
// Classification gain: reduction in Gini impurity achieved by splitting
// the node into the two given weighted class histograms.
// NOTE(review): extraction dropped the signature line (original line 1067)
// — presumably float64_t CCARTree::gain(const SGVector<float64_t>& wleft,
// const SGVector<float64_t>& wright, const SGVector<float64_t>& wtotal).
1068 {
1069  float64_t total_lweight=0;
1070  float64_t total_rweight=0;
1071  float64_t total_weight=0;
1072 
// Each call also returns the partition's total weight by reference.
1073  float64_t gini_n=gini_impurity_index(wtotal,total_weight);
1074  float64_t gini_l=gini_impurity_index(wleft,total_lweight);
1075  float64_t gini_r=gini_impurity_index(wright,total_rweight);
1076  return gini_n-(gini_l*(total_lweight/total_weight))-(gini_r*(total_rweight/total_weight));
1077 }
1078 
1079 float64_t CCARTree::gini_impurity_index(const SGVector<float64_t>& weighted_lab_classes, float64_t &total_weight)
1080 {
1081  Map<VectorXd> map_weighted_lab_classes(weighted_lab_classes.vector, weighted_lab_classes.size());
1082  total_weight=map_weighted_lab_classes.sum();
1083  float64_t gini=map_weighted_lab_classes.dot(map_weighted_lab_classes);
1084 
1085  gini=1.0-(gini/(total_weight*total_weight));
1086  return gini;
1087 }
1088 
// Weighted least-squares deviation of feats around its weighted mean;
// also returns the total weight by reference.
// NOTE(review): extraction dropped the signature line (original line 1089)
// — presumably float64_t CCARTree::least_squares_deviation(
// const SGVector<float64_t>& feats, const SGVector<float64_t>& weights,
// float64_t &total_weight).
1090 {
1091 
1092  Map<VectorXd> map_weights(weights.vector, weights.size());
1093  Map<VectorXd> map_feats(feats.vector, weights.size());
// Weighted sum of values; divided by total weight below to get the mean.
1094  float64_t mean=map_weights.dot(map_feats);
1095  total_weight=map_weights.sum();
1096 
1097  mean/=total_weight;
1098  float64_t dev=0;
1099  for (int32_t i=0;i<weights.vlen;i++)
1100  dev+=weights[i]*(feats[i]-mean)*(feats[i]-mean);
1101 
1102  return dev/total_weight;
1103 }
1104 
// Routes each feature vector from `current` down to a leaf — comparing
// against the child's transit values (set membership for nominal features,
// threshold for continuous ones) — and collects the leaf labels into a
// multiclass or regression label object according to m_mode.
// NOTE(review): extraction dropped the signature line (original line 1105)
// — presumably CLabels* CCARTree::apply_from_current_node(
// CDenseFeatures<float64_t>* feats, bnode_t* current). Confirm upstream.
1106 {
1107  int32_t num_vecs=feats->get_num_vectors();
1108  REQUIRE(num_vecs>0, "No data provided in apply\n");
1109 
1110  SGVector<float64_t> labels(num_vecs);
1111  for (int32_t i=0;i<num_vecs;i++)
1112  {
1113  SGVector<float64_t> sample=feats->get_feature_vector(i);
1114  bnode_t* node=current;
1115  SG_REF(node);
1116 
1117  // until leaf is reached
1118  while(node->data.num_leaves!=1)
1119  {
1120  bnode_t* leftchild=node->left();
1121 
1122  if (m_nominal[node->data.attribute_id])
1123  {
// Nominal: go left iff the sample's value is among the left child's
// transit values.
1124  SGVector<float64_t> comp=leftchild->data.transit_into_values;
1125  bool flag=false;
1126  for (int32_t k=0;k<comp.vlen;k++)
1127  {
1128  if (comp[k]==sample[node->data.attribute_id])
1129  {
1130  flag=true;
1131  break;
1132  }
1133  }
1134 
1135  if (flag)
1136  {
1137  SG_UNREF(node);
1138  node=leftchild;
1139  SG_REF(leftchild);
1140  }
1141  else
1142  {
// NOTE(review): node is dereferenced (node->right()) after SG_UNREF(node);
// this appears to rely on the tree itself keeping the node alive — confirm
// the reference-counting contract of right() upstream.
1143  SG_UNREF(node);
1144  node=node->right();
1145  }
1146  }
1147  else
1148  {
// Continuous: go left iff value <= stored threshold.
1149  if (sample[node->data.attribute_id]<=leftchild->data.transit_into_values[0])
1150  {
1151  SG_UNREF(node);
1152  node=leftchild;
1153  SG_REF(leftchild);
1154  }
1155  else
1156  {
1157  SG_UNREF(node);
1158  node=node->right();
1159  }
1160  }
1161 
1162  SG_UNREF(leftchild);
1163  }
1164 
1165  labels[i]=node->data.node_label;
1166  SG_UNREF(node);
1167  }
1168 
1169  switch(m_mode)
1170  {
1171  case PT_MULTICLASS:
1172  {
1173  CMulticlassLabels* mlabels=new CMulticlassLabels(labels);
1174  return mlabels;
1175  }
1176 
1177  case PT_REGRESSION:
1178  {
1179  CRegressionLabels* rlabels=new CRegressionLabels(labels);
1180  return rlabels;
1181  }
1182 
1183  default:
1184  SG_ERROR("mode should be either PT_MULTICLASS or PT_REGRESSION\n");
1185  }
1186 
1187  return NULL;
1188 }
1189 
{
	int32_t num_vecs=data->get_num_vectors();

	// divide data into V folds randomly
	SGVector<int32_t> subid(num_vecs);
	subid.random_vector(subid.vector,subid.vlen,0,folds-1);

	// for each fold subset
	// num_alphak[i] records how many alpha values fold i contributed.
	SGVector<int32_t> num_alphak(folds);
	for (int32_t i=0;i<folds;i++)
	{
		// for chosen fold, create subset for training parameters
		CDynamicArray<int32_t>* test_indices=new CDynamicArray<int32_t>();
		CDynamicArray<int32_t>* train_indices=new CDynamicArray<int32_t>();
		for (int32_t j=0;j<num_vecs;j++)
		{
			if (subid[j]==i)
				test_indices->push_back(j);
			else
				train_indices->push_back(j);
		}

		if (test_indices->get_num_elements()==0 || train_indices->get_num_elements()==0)
		{
			SG_ERROR("Unfortunately you have reached the very low probability event where atleast one of "
					"the subsets in cross-validation is not represented at all. Please re-run.")
		}

		// Restrict features/labels to the training fold (non-owning view of
		// the index array).
		SGVector<int32_t> subset(train_indices->get_array(),train_indices->get_num_elements(),false);
		data->add_subset(subset);
		m_labels->add_subset(subset);
		SGVector<float64_t> subset_weights(train_indices->get_num_elements());
		for (int32_t j=0;j<train_indices->get_num_elements();j++)
			subset_weights[j]=m_weights[train_indices->get_element(j)];

		// train with training subset
		bnode_t* root=CARTtrain(data,subset_weights,m_labels,0);

		// prune trained tree
		tmax->set_root(root);
		CDynamicObjectArray* pruned_trees=prune_tree(tmax);

		// Swap the subset from training fold to test fold.
		data->remove_subset();
		subset=SGVector<int32_t>(test_indices->get_array(),test_indices->get_num_elements(),false);
		data->add_subset(subset);
		m_labels->add_subset(subset);
		subset_weights=SGVector<float64_t>(test_indices->get_num_elements());
		for (int32_t j=0;j<test_indices->get_num_elements();j++)
			subset_weights[j]=m_weights[test_indices->get_element(j)];

		// calculate R_CV values for each alpha_k using test subset and store them
		num_alphak[i]=m_alphas->get_num_elements();
		for (int32_t j=0;j<m_alphas->get_num_elements();j++)
		{
			alphak->push_back(m_alphas->get_element(j));
			CSGObject* jth_element=pruned_trees->get_element(j);
			bnode_t* current_root=NULL;
			if (jth_element!=NULL)
				current_root=dynamic_cast<bnode_t*>(jth_element);
			else
				SG_ERROR("%d element is NULL which should not be",j);

			// Test-set error of the j-th pruned subtree.
			CLabels* labels=apply_from_current_node(data, current_root);
			float64_t error=compute_error(labels, m_labels, subset_weights);
			r_cv->push_back(error);
			SG_UNREF(labels);
			SG_UNREF(jth_element);
		}

		data->remove_subset();
		SG_UNREF(train_indices);
		SG_UNREF(test_indices);
		SG_UNREF(tmax);
		SG_UNREF(pruned_trees);
	}

	// prune the original T_max
	CDynamicObjectArray* pruned_trees=prune_tree(this);

	// find subtree with minimum R_cv
	int32_t min_index=-1;
	for (int32_t i=0;i<m_alphas->get_num_elements();i++)
	{
		float64_t alpha=0.;
		if (i==m_alphas->get_num_elements()-1)
			alpha=m_alphas->get_element(i)+1;
		else

		// Sum, over folds, the CV error of the subtree whose alpha interval
		// contains this alpha.
		float64_t rv=0.;
		int32_t base=0;
		for (int32_t j=0;j<folds;j++)
		{
			bool flag=false;
			for (int32_t k=base;k<num_alphak[j]+base-1;k++)
			{
				if (alphak->get_element(k)<=alpha && alphak->get_element(k+1)>alpha)
				{
					rv+=r_cv->get_element(k);
					flag=true;
					break;
				}
			}

			// Alpha beyond the fold's last breakpoint: use its last subtree.
			if (!flag)
				rv+=r_cv->get_element(num_alphak[j]+base-1);

			base+=num_alphak[j];
		}

		if (rv<min_r_cv)
		{
			min_index=i;
			min_r_cv=rv;
		}
	}

	// Install the best pruned subtree as this tree's root.
	CSGObject* element=pruned_trees->get_element(min_index);
	bnode_t* best_tree_root=NULL;
	if (element!=NULL)
		best_tree_root=dynamic_cast<bnode_t*>(element);
	else
		SG_ERROR("%d element is NULL which should not be",min_index);

	this->set_root(best_tree_root);

	SG_UNREF(element);
	SG_UNREF(pruned_trees);
	SG_UNREF(r_cv);
	SG_UNREF(alphak);
}
1328 
{
	REQUIRE(labels,"input labels cannot be NULL");
	REQUIRE(reference,"reference labels cannot be NULL")

	CDenseLabels* gnd_truth=dynamic_cast<CDenseLabels*>(reference);
	CDenseLabels* result=dynamic_cast<CDenseLabels*>(labels);

	// Normalizer: total sample weight.
	float64_t denom=weights.sum(weights);
	float64_t numer=0.;
	switch (m_mode)
	{
		case PT_MULTICLASS:
		{
			// Weighted misclassification rate.
			for (int32_t i=0;i<weights.vlen;i++)
			{
				if (gnd_truth->get_label(i)!=result->get_label(i))
					numer+=weights[i];
			}

			return numer/denom;
		}

		case PT_REGRESSION:
		{
			// Weighted mean squared error.
			for (int32_t i=0;i<weights.vlen;i++)
				numer+=weights[i]*CMath::pow((gnd_truth->get_label(i)-result->get_label(i)),2);

			return numer/denom;
		}

		default:
			SG_ERROR("Case not possible\n");
	}

	return 0.;
}
1366 
{
	REQUIRE(tree, "Tree not provided for pruning.\n");

	// Reset the alpha sequence before generating a new pruning path.
	SG_UNREF(m_alphas);
	SG_REF(m_alphas);

	// base tree alpha_k=0
	m_alphas->push_back(0);
	SG_REF(t1);
	node_t* t1root=t1->get_root();
	bnode_t* t1_root=NULL;
	if (t1root!=NULL)
		t1_root=dynamic_cast<bnode_t*>(t1root);
	else
		SG_ERROR("t1_root is NULL. This is not expected\n")

	// T1: collapse branches that do not reduce the training error at all.
	form_t1(t1_root);
	trees->push_back(t1_root);
	// Weakest-link pruning: repeatedly cut the link with the smallest alpha
	// until only the root remains, recording each intermediate subtree.
	while(t1_root->data.num_leaves>1)
	{
		SG_REF(t2);

		node_t* t2root=t2->get_root();
		bnode_t* t2_root=NULL;
		if (t2root!=NULL)
			t2_root=dynamic_cast<bnode_t*>(t2root);
		else
			// NOTE(review): message says t1_root but the check is on t2_root.
			SG_ERROR("t1_root is NULL. This is not expected\n")

		float64_t a_k=find_weakest_alpha(t2_root);
		m_alphas->push_back(a_k);
		cut_weakest_link(t2_root,a_k);
		trees->push_back(t2_root);

		SG_UNREF(t1);
		SG_UNREF(t1_root);
		t1=t2;
		t1_root=t2_root;
	}

	SG_UNREF(t1);
	SG_UNREF(t1_root);
	// Sequence of nested pruned subtrees, one per alpha in m_alphas.
	return trees;
}
1416 
{
	// Internal node: the weakest link is the minimum over this node's own
	// link strength and the weakest links in both subtrees.
	if (node->data.num_leaves!=1)
	{
		bnode_t* left=node->left();
		bnode_t* right=node->right();

		// g(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1): the per-leaf increase
		// in weighted error if the branch rooted here were collapsed.
		SGVector<float64_t> weak_links(3);
		weak_links[0]=find_weakest_alpha(left);
		weak_links[1]=find_weakest_alpha(right);
		weak_links[2]=(node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
		weak_links[2]/=(node->data.num_leaves-1.0);

		SG_UNREF(left);
		SG_UNREF(right);
		return CMath::min(weak_links.vector,weak_links.vlen);
	}

	// Leaf: nothing to cut — never the minimum.
	return CMath::MAX_REAL_NUMBER;
}
1437 
{
	// Leaf: nothing to cut below this node.
	if (node->data.num_leaves==1)
		return;

	// Recompute this node's link strength g(t) exactly as in
	// find_weakest_alpha.
	float64_t g=(node->data.weight_minus_node-node->data.weight_minus_branch)/node->data.total_weight;
	g/=(node->data.num_leaves-1.0);
	// NOTE(review): exact float equality — relies on g being recomputed
	// bit-identically to the value returned by find_weakest_alpha.
	if (alpha==g)
	{
		// Collapse this branch into a leaf: drop children and make the
		// branch error equal the node error.
		node->data.num_leaves=1;
		node->data.weight_minus_branch=node->data.weight_minus_node;
		CDynamicObjectArray* children=new CDynamicObjectArray();
		node->set_children(children);

		SG_UNREF(children);
	}
	else
	{
		// Recurse, then refresh this node's leaf count and branch error from
		// the (possibly pruned) children.
		bnode_t* left=node->left();
		bnode_t* right=node->right();
		cut_weakest_link(left,alpha);
		cut_weakest_link(right,alpha);
		node->data.num_leaves=left->data.num_leaves+right->data.num_leaves;
		node->data.weight_minus_branch=left->data.weight_minus_branch+right->data.weight_minus_branch;

		SG_UNREF(left);
		SG_UNREF(right);
	}
}
1467 
{
	// Bottom-up pass that collapses every branch whose training error equals
	// its node's own error (pruning it costs nothing), yielding tree T1 of
	// the cost-complexity pruning sequence.
	if (node->data.num_leaves!=1)
	{
		bnode_t* left=node->left();
		bnode_t* right=node->right();

		form_t1(left);
		form_t1(right);

		// Refresh aggregate stats from the (possibly collapsed) children.
		node->data.num_leaves=left->data.num_leaves+right->data.num_leaves;
		node->data.weight_minus_branch=left->data.weight_minus_branch+right->data.weight_minus_branch;
		if (node->data.weight_minus_node==node->data.weight_minus_branch)
		{
			// Branch gains nothing over the node itself: make it a leaf.
			node->data.num_leaves=1;
			CDynamicObjectArray* children=new CDynamicObjectArray();
			node->set_children(children);

			SG_UNREF(children);
		}

		SG_UNREF(left);
		SG_UNREF(right);
	}
}
1493 
{
	// Default configuration: no presorting, no feature types or weights set
	// yet, cross-validation pruning off.
	m_pre_sort=false;
	m_types_set=false;
	m_weights_set=false;
	m_apply_cv_pruning=false;
	m_folds=5;
	SG_REF(m_alphas);
	// 0 means "no limit" for depth and minimum node size.
	m_max_depth=0;
	m_min_node_size=0;
	// Tolerance used when comparing label values.
	m_label_epsilon=1e-7;

	// Register parameters for serialization/model selection.
	SG_ADD(&m_pre_sort, "m_pre_sort", "presort", MS_NOT_AVAILABLE);
	SG_ADD(&m_sorted_features, "m_sorted_features", "sorted feats", MS_NOT_AVAILABLE);
	SG_ADD(&m_sorted_indices, "m_sorted_indices", "sorted indices", MS_NOT_AVAILABLE);
	SG_ADD(&m_nominal, "m_nominal", "feature types", MS_NOT_AVAILABLE);
	SG_ADD(&m_weights, "m_weights", "weights", MS_NOT_AVAILABLE);
	SG_ADD(&m_weights_set, "m_weights_set", "weights set", MS_NOT_AVAILABLE);
	SG_ADD(&m_types_set, "m_types_set", "feature types set", MS_NOT_AVAILABLE);
	SG_ADD(&m_apply_cv_pruning, "m_apply_cv_pruning", "apply cross validation pruning", MS_NOT_AVAILABLE);
	SG_ADD(&m_folds, "m_folds", "number of subsets for cross validation", MS_NOT_AVAILABLE);
	SG_ADD(&m_max_depth, "m_max_depth", "max allowed tree depth", MS_NOT_AVAILABLE)
	SG_ADD(&m_min_node_size, "m_min_node_size", "min allowed node size", MS_NOT_AVAILABLE)
	SG_ADD(&m_label_epsilon, "m_label_epsilon", "epsilon for labels", MS_NOT_AVAILABLE)
	SG_ADD((machine_int_t*)&m_mode, "m_mode", "problem type (multiclass or regression)", MS_NOT_AVAILABLE)
}
void set_cv_pruning(bool cv_pruning)
Definition: CARTree.h:214
CLabels * apply_from_current_node(CDenseFeatures< float64_t > *feats, bnode_t *current)
Definition: CARTree.cpp:1105
bool m_types_set
Definition: CARTree.h:444
virtual int32_t compute_best_attribute(const SGMatrix< float64_t > &mat, const SGVector< float64_t > &weights, CLabels *labels, SGVector< float64_t > &left, SGVector< float64_t > &right, SGVector< bool > &is_left_final, int32_t &num_missing, int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector< int32_t > &active_indices=SGVector< index_t >())
Definition: CARTree.cpp:531
static void permute(SGVector< T > v, CRandom *rand=NULL)
Definition: Math.h:1165
bool set_element(T e, int32_t idx1, int32_t idx2=0, int32_t idx3=0)
Definition: DynamicArray.h:306
The node of the tree structure forming a TreeMachine The node contains pointer to its parent and poin...
static void random_vector(T *vec, int32_t len, T min_value, T max_value)
Definition: SGVector.cpp:605
virtual ELabelType get_label_type() const =0
void set_weights(SGVector< float64_t > w)
Definition: CARTree.cpp:174
Real Labels are real-valued labels.
void set_machine_problem_type(EProblemType mode)
Definition: CARTree.cpp:87
CDynamicObjectArray * prune_tree(CTreeMachine< CARTreeNodeData > *tree)
Definition: CARTree.cpp:1367
int32_t index_t
Definition: common.h:72
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
SGMatrix< index_t > m_sorted_indices
Definition: CARTree.h:438
virtual int32_t get_num_labels() const =0
real valued labels (e.g. for regression, classifier outputs)
Definition: LabelTypes.h:22
static void qsort_index(T1 *output, T2 *index, uint32_t size)
Definition: Math.h:2223
multi-class labels 0,1,...
Definition: LabelTypes.h:20
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: CARTree.cpp:102
float64_t find_weakest_alpha(bnode_t *node)
Definition: CARTree.cpp:1417
static T sum(T *vec, int32_t len)
Return sum(vec)
Definition: SGVector.h:392
void form_t1(bnode_t *node)
Definition: CARTree.cpp:1468
Definition: SGMatrix.h:24
virtual bool has_subsets() const
Definition: SubsetStack.h:89
CSGObject * get_element(int32_t index) const
CLabels * m_labels
Definition: Machine.h:365
#define SG_ERROR(...)
Definition: SGIO.h:128
#define REQUIRE(x,...)
Definition: SGIO.h:205
static const float64_t EQ_DELTA
Definition: CARTree.h:422
bool m_apply_cv_pruning
Definition: CARTree.h:450
virtual ~CCARTree()
Definition: CARTree.cpp:68
CTreeMachineNode< CARTreeNodeData > * get_root()
Definition: TreeMachine.h:88
SGVector< bool > surrogate_split(SGMatrix< float64_t > data, SGVector< float64_t > weights, SGVector< bool > nm_left, int32_t attr)
Definition: CARTree.cpp:841
virtual bool train_machine(CFeatures *data=NULL)
Definition: CARTree.cpp:247
float64_t get_label(int32_t idx)
int32_t m_max_depth
Definition: CARTree.h:462
std::enable_if<!std::is_same< T, complex128_t >::value, float64_t >::type mean(const Container< T > &a)
float64_t m_label_epsilon
Definition: CARTree.h:426
#define SG_REF(x)
Definition: SGObject.h:52
virtual void set_labels(CLabels *lab)
Definition: CARTree.cpp:73
void set_root(CTreeMachineNode< CARTreeNodeData > *root)
Definition: TreeMachine.h:78
ST * get_feature_vector(int32_t num, int32_t &len, bool &dofree)
class to add subset support to another class. A CSubsetStack instance should be added and wrappe...
Definition: SubsetStack.h:37
static void qsort(T *output, int32_t size)
Definition: Math.h:1334
int32_t get_max_depth() const
Definition: CARTree.cpp:219
Multiclass Labels for multi-class classification.
static bool fequals(const T &a, const T &b, const float64_t eps, bool tolerant=false)
Definition: Math.h:331
virtual void set_children(CDynamicObjectArray *children)
virtual CSubsetStack * get_subset_stack()
Definition: Features.cpp:334
void clear_feature_types()
Definition: CARTree.cpp:202
float64_t gain(SGVector< float64_t > wleft, SGVector< float64_t > wright, SGVector< float64_t > wtotal, SGVector< float64_t > labels)
Definition: CARTree.cpp:1053
T * get_column_vector(index_t col) const
Definition: SGMatrix.h:140
EProblemType
Definition: Machine.h:110
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:125
float64_t least_squares_deviation(const SGVector< float64_t > &labels, const SGVector< float64_t > &weights, float64_t &total_weight)
Definition: CARTree.cpp:1089
void set_min_node_size(int32_t nsize)
Definition: CARTree.cpp:235
void right(CBinaryTreeMachineNode *r)
void set_sorted_features(SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:290
void handle_missing_vecs_for_continuous_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:916
void set_num_folds(int32_t folds)
Definition: CARTree.cpp:213
float64_t gini_impurity_index(const SGVector< float64_t > &weighted_lab_classes, float64_t &total_weight)
Definition: CARTree.cpp:1079
double float64_t
Definition: common.h:60
int32_t m_min_node_size
Definition: CARTree.h:465
int32_t m_folds
Definition: CARTree.h:453
static SGVector< index_t > argsort(SGVector< T > vector)
Definition: Math.h:1620
bool m_weights_set
Definition: CARTree.h:447
virtual void remove_subset()
Definition: Labels.cpp:49
void range_fill(T start=0)
Definition: SGVector.cpp:208
index_t num_rows
Definition: SGMatrix.h:463
CTreeMachine * clone_tree()
Definition: TreeMachine.h:97
void clear_weights()
Definition: CARTree.cpp:185
virtual void add_subset(SGVector< index_t > subset)
Definition: Labels.cpp:39
virtual EFeatureClass get_feature_class() const =0
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:264
index_t num_cols
Definition: SGMatrix.h:465
int32_t get_min_node_size() const
Definition: CARTree.cpp:230
SGMatrix< float64_t > m_sorted_features
Definition: CARTree.h:435
virtual CBinaryTreeMachineNode< CARTreeNodeData > * CARTtrain(CFeatures *data, SGVector< float64_t > weights, CLabels *labels, int32_t level)
Definition: CARTree.cpp:317
int32_t get_num_elements() const
Definition: DynamicArray.h:200
void set_max_depth(int32_t depth)
Definition: CARTree.cpp:224
void set_const(T const_elem)
Definition: SGVector.cpp:184
void handle_missing_vecs_for_nominal_surrogate(SGMatrix< float64_t > m, CDynamicArray< int32_t > *missing_vecs, CDynamicArray< float64_t > *association_index, CDynamicArray< int32_t > *intersect_vecs, SGVector< bool > is_left, SGVector< float64_t > weights, float64_t p, int32_t attr)
Definition: CARTree.cpp:973
structure to store data of a node of CART. This can be used as a template type in TreeMachineNode cla...
void pre_sort_features(CFeatures *data, SGMatrix< float64_t > &sorted_feats, SGMatrix< index_t > &sorted_indices)
Definition: CARTree.cpp:297
#define SG_UNREF(x)
Definition: SGObject.h:53
virtual int32_t get_num_vectors() const
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
const T & get_element(int32_t idx1, int32_t idx2=0, int32_t idx3=0) const
Definition: DynamicArray.h:212
T sum(const Container< T > &a, bool no_diag=false)
virtual void remove_subset()
Definition: Features.cpp:322
virtual bool is_label_valid(CLabels *lab) const
Definition: CARTree.cpp:92
int machine_int_t
Definition: common.h:69
void set_feature_types(SGVector< bool > ft)
Definition: CARTree.cpp:191
CBinaryTreeMachineNode< CARTreeNodeData > bnode_t
Definition: TreeMachine.h:55
int32_t get_num_folds() const
Definition: CARTree.cpp:208
The class Features is the base class of all feature objects.
Definition: Features.h:68
static T min(T a, T b)
Definition: Math.h:153
SGVector< bool > m_nominal
Definition: CARTree.h:429
float64_t compute_error(CLabels *labels, CLabels *reference, SGVector< float64_t > weights)
Definition: CARTree.cpp:1329
void cut_weakest_link(bnode_t *node, float64_t alpha)
Definition: CARTree.cpp:1438
static const float64_t MIN_SPLIT_GAIN
Definition: CARTree.h:419
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: CARTree.cpp:116
static float base
Definition: JLCoverTree.h:89
void prune_using_test_dataset(CDenseFeatures< float64_t > *feats, CLabels *gnd_truth, SGVector< float64_t > weights=SGVector< float64_t >())
Definition: CARTree.cpp:128
class TreeMachine, a base class for tree based multiclass classifiers. This class is derived from CBa...
Definition: TreeMachine.h:48
#define SG_WARNING(...)
Definition: SGIO.h:127
#define SG_ADD(...)
Definition: SGObject.h:94
SGVector< float64_t > m_weights
Definition: CARTree.h:432
static float32_t sqrt(float32_t x)
Definition: Math.h:454
Dense integer or floating point labels.
Definition: DenseLabels.h:35
CDynamicArray< float64_t > * m_alphas
Definition: CARTree.h:459
T max(const Container< T > &a)
CSubset * get_last_subset() const
Definition: SubsetStack.h:98
SGVector< bool > get_feature_types() const
Definition: CARTree.cpp:197
static int32_t unique(T *output, int32_t size)
Definition: SGVector.cpp:827
SGVector< float64_t > get_weights() const
Definition: CARTree.cpp:180
void left(CBinaryTreeMachineNode *l)
SGVector< T > clone() const
Definition: SGVector.cpp:247
static int32_t pow(bool x, int32_t n)
Definition: Math.h:530
static const float64_t MAX_REAL_NUMBER
Definition: Math.h:2082
virtual void add_subset(SGVector< index_t > subset)
Definition: Features.cpp:310
index_t vlen
Definition: SGVector.h:545
SGVector< float64_t > get_unique_labels(SGVector< float64_t > labels_vec, int32_t &n_ulabels)
Definition: CARTree.cpp:507
static T abs(T a)
Definition: Math.h:175
void set_label_epsilon(float64_t epsilon)
Definition: CARTree.cpp:241
int32_t size() const
Definition: SGVector.h:136
void prune_by_cross_validation(CDenseFeatures< float64_t > *data, int32_t folds)
Definition: CARTree.cpp:1190
static const float64_t MISSING
Definition: CARTree.h:416
EProblemType m_mode
Definition: CARTree.h:456

SHOGUN Machine Learning Toolbox - Documentation