SHOGUN  6.0.0
MaxCrossValidation.cpp
/*
 * Copyright (c) The Shogun Machine Learning Toolbox
 * Written (w) 2012 - 2013 Heiko Strathmann
 * Written (w) 2014 - 2017 Soumyajit De
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * The views and conclusions contained in the software and documentation are those
 * of the authors and should not be interpreted as representing official policies,
 * either expressed or implied, of the Shogun Development Team.
 */

#include <algorithm>
#include <numeric>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/kernel/Kernel.h>
// Assumed additional headers for the distance, MMD and kernel-selection types used below.
#include <shogun/distance/Distance.h>
#include <shogun/statistical_testing/MMD.h>
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
#include <shogun/statistical_testing/internals/KernelManager.h>
#include <shogun/statistical_testing/internals/MaxCrossValidation.h>
#include <shogun/statistical_testing/internals/mmd/CrossValidationMMD.h>

using namespace shogun;
using namespace internal;
using namespace mmd;

MaxCrossValidation::MaxCrossValidation(KernelManager& km, CMMD* est, const index_t& M, const index_t& K, const float64_t& alp)
: KernelSelection(km, est), num_runs(M), num_folds(K), alpha(alp)
{
	REQUIRE(num_runs>0, "Number of runs (%d) must be positive!\n", num_runs);
	REQUIRE(num_folds>0, "Number of folds (%d) must be positive!\n", num_folds);
	REQUIRE(alpha>=0.0 && alpha<=1.0, "Threshold (%f) has to be in [0, 1]!\n", alpha);
}

MaxCrossValidation::~MaxCrossValidation()
{
}

SGVector<float64_t> MaxCrossValidation::get_measure_vector()
{
	return measures;
}

SGMatrix<float64_t> MaxCrossValidation::get_measure_matrix()
{
	return rejections;
}

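// Allocates (or re-allocates) and zeroes the per-fold results: `rejections` holds one
// row per fold of every run and one column per candidate kernel, storing 0/1 rejection
// indicators, while `measures` holds one aggregated rejection rate per kernel.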
void MaxCrossValidation::init_measures()
{
	const index_t num_kernels=kernel_mgr.num_kernels();
	if (rejections.num_rows!=num_folds*num_runs || rejections.num_cols!=num_kernels)
		rejections=SGMatrix<float64_t>(num_folds*num_runs, num_kernels);
	std::fill(rejections.data(), rejections.data()+rejections.size(), 0);
	if (measures.size()!=num_kernels)
		measures=SGVector<float64_t>(num_kernels);
	std::fill(measures.data(), measures.data()+measures.size(), 0);
}

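// Fills `rejections` and then `measures`. For a CQuadraticTimeMMD estimator the work is
// delegated to the CrossValidationMMD job (supported only with permutation-based null
// approximation); for any other estimator the statistic and p-value are recomputed per
// run, per fold and per kernel, recording a rejection whenever the p-value falls below
// alpha. Each kernel's measure is then the column mean of its rejection indicators.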
void MaxCrossValidation::compute_measures()
{
	SG_SDEBUG("Performing %d fold cross-validation!\n", num_folds);
	const auto num_kernels=kernel_mgr.num_kernels();

	CQuadraticTimeMMD* quadratic_time_mmd=dynamic_cast<CQuadraticTimeMMD*>(estimator);
	if (quadratic_time_mmd)
	{
		REQUIRE(estimator->get_null_approximation_method()==NAM_PERMUTATION,
			"Only supported with PERMUTATION method for null distribution approximation!\n");

		auto Nx=estimator->get_num_samples_p();
		auto Ny=estimator->get_num_samples_q();
		auto num_null_samples=estimator->get_num_null_samples();
		auto stype=estimator->get_statistic_type();
		CrossValidationMMD compute(Nx, Ny, num_folds, num_null_samples);
		compute.m_stype=stype;
		compute.m_alpha=alpha;
		compute.m_num_runs=num_runs;
		compute.m_rejections=rejections;

		if (kernel_mgr.same_distance_type())
		{
			CDistance* distance=kernel_mgr.get_distance_instance();
			auto precomputed_distance=estimator->compute_joint_distance(distance);
			kernel_mgr.set_precomputed_distance(precomputed_distance);
			SG_UNREF(distance);
			compute(kernel_mgr);
			kernel_mgr.unset_precomputed_distance();
			SG_UNREF(precomputed_distance);
		}
		else
		{
			auto samples_p_and_q=quadratic_time_mmd->get_p_and_q();
			SG_REF(samples_p_and_q);

			for (auto k=0; k<num_kernels; ++k)
			{
				CKernel* kernel=kernel_mgr.kernel_at(k);
				kernel->init(samples_p_and_q, samples_p_and_q);
			}

			compute(kernel_mgr);

			for (auto k=0; k<num_kernels; ++k)
			{
				CKernel* kernel=kernel_mgr.kernel_at(k);
				kernel->remove_lhs_and_rhs();
			}

			SG_UNREF(samples_p_and_q);
		}
	}
	else // TODO put check, this one assumes infinite data
	{
		auto existing_kernel=estimator->get_kernel();
		for (auto i=0; i<num_runs; ++i)
		{
			for (auto j=0; j<num_folds; ++j)
			{
				SG_SDEBUG("Running fold %d\n", j);
				for (auto k=0; k<num_kernels; ++k)
				{
					auto kernel=kernel_mgr.kernel_at(k);
					estimator->set_kernel(kernel);
					auto statistic=estimator->compute_statistic();
					rejections(i*num_folds+j, k)=estimator->compute_p_value(statistic)<alpha;
					estimator->cleanup();
				}
			}
		}
		if (existing_kernel)
			estimator->set_kernel(existing_kernel);
	}

	for (auto j=0; j<rejections.num_cols; ++j)
	{
		auto begin=rejections.get_column_vector(j);
		auto size=rejections.num_rows;
		measures[j]=std::accumulate(begin, begin+size, 0.0)/size;
	}
}

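// Runs the cross-validation and returns the candidate kernel with the highest
// rejection rate, i.e. the kernel that maximises the estimated test power.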
CKernel* MaxCrossValidation::select_kernel()
{
	init_measures();
	compute_measures();
	auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen);
	auto max_idx=std::distance(measures.vector, max_element);
	SG_SDEBUG("Selected kernel at %d position!\n", max_idx);
	return kernel_mgr.kernel_at(max_idx);
}
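
For context, the following is a minimal sketch of how this strategy might be exercised directly. It assumes two samples passed in as CFeatures, Gaussian kernel candidates registered via KernelManager::push_back, and a CQuadraticTimeMMD estimator configured for permutation-based null approximation; the helper function, kernel widths and setter calls shown here are assumptions for illustration, not part of this file.

#include <shogun/kernel/GaussianKernel.h>
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
#include <shogun/statistical_testing/internals/KernelManager.h>
#include <shogun/statistical_testing/internals/MaxCrossValidation.h>

using namespace shogun;
using namespace shogun::internal;

// Hypothetical helper: picks the kernel with the highest cross-validated
// rejection rate for the two given samples.
CKernel* pick_kernel(CFeatures* samples_p, CFeatures* samples_q)
{
	// The cross-validation path requires permutation-based null sampling.
	auto mmd=new CQuadraticTimeMMD(samples_p, samples_q);
	SG_REF(mmd);
	mmd->set_null_approximation_method(NAM_PERMUTATION);
	mmd->set_num_null_samples(200);

	// Candidate kernels; the Gaussian widths are illustrative only.
	KernelManager kernel_mgr;
	for (auto width : {0.5, 1.0, 2.0, 4.0})
		kernel_mgr.push_back(new CGaussianKernel(10, width));

	// 2 runs of 3-fold cross-validation at significance level 0.05.
	MaxCrossValidation selection(kernel_mgr, mmd, 2, 3, 0.05);
	CKernel* best=selection.select_kernel();

	SG_UNREF(mmd);
	return best;
}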