mlpack  2.0.1
gini_impurity.hpp
Go to the documentation of this file.
1 
15 #ifndef __MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
16 #define __MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
17 
18 #include <mlpack/core.hpp>
19 
20 namespace mlpack {
21 namespace tree {
22 
24 {
25  public:
26  static double Evaluate(const arma::Mat<size_t>& counts)
27  {
28  // We need to sum over the difference between the un-split node and the
29  // split nodes. First we'll calculate the number of elements in each split
30  // and total.
31  size_t numElem = 0;
32  arma::vec splitCounts(counts.n_cols);
33  for (size_t i = 0; i < counts.n_cols; ++i)
34  {
35  splitCounts[i] = arma::accu(counts.col(i));
36  numElem += splitCounts[i];
37  }
38 
39  // Corner case: if there are no elements, the impurity is zero.
40  if (numElem == 0)
41  return 0.0;
42 
43  arma::Col<size_t> classCounts = arma::sum(counts, 1);
44 
45  // Calculate the Gini impurity of the un-split node.
46  double impurity = 0.0;
47  for (size_t i = 0; i < classCounts.n_elem; ++i)
48  {
49  const double f = ((double) classCounts[i] / (double) numElem);
50  impurity += f * (1.0 - f);
51  }
52 
53  // Now calculate the impurity of the split nodes and subtract them from the
54  // overall impurity.
55  for (size_t i = 0; i < counts.n_cols; ++i)
56  {
57  if (splitCounts[i] > 0)
58  {
59  double splitImpurity = 0.0;
60  for (size_t j = 0; j < counts.n_rows; ++j)
61  {
62  const double f = ((double) counts(j, i) / (double) splitCounts[i]);
63  splitImpurity += f * (1.0 - f);
64  }
65 
66  impurity -= ((double) splitCounts[i] / (double) numElem) *
67  splitImpurity;
68  }
69  }
70 
71  return impurity;
72  }
73 
79  static double Range(const size_t numClasses)
80  {
81  // The best possible case is that only one class exists, which gives a Gini
82  // impurity of 0. The worst possible case is that the classes are evenly
83  // distributed, which gives n * (1/n * (1 - 1/n)) = 1 - 1/n.
84  return 1.0 - (1.0 / double(numClasses));
85  }
86 };
87 
88 } // namespace tree
89 } // namespace mlpack
90 
91 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
static double Evaluate(const arma::Mat< size_t > &counts)
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.