SHOGUN  6.0.0
KernelSelectionStrategy.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (w) 2012 - 2013 Heiko Strathmann
4  * Written (w) 2014 - 2017 Soumyajit De
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright notice, this
11  * list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  * this list of conditions and the following disclaimer in the documentation
14  * and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  * The views and conclusions contained in the software and documentation are those
28  * of the authors and should not be interpreted as representing official policies,
29  * either expressed or implied, of the Shogun Development Team.
30  */
31 
32 #include <shogun/io/SGIO.h>
33 #include <shogun/lib/SGVector.h>
34 #include <shogun/lib/SGMatrix.h>
48 
49 using namespace shogun;
50 using namespace internal;
51 
53 {
54  Self();
55 
56  KernelManager kernel_mgr;
57  std::unique_ptr<KernelSelection> policy;
58 
60  bool weighted;
64 
65  void init_policy(CMMD* estimator);
66 
68  const static bool default_weighted;
69  const static index_t default_num_runs;
70  const static index_t default_num_folds;
71  const static float64_t default_alpha;
72 };
73 
79 
80 CKernelSelectionStrategy::Self::Self() : policy(nullptr), method(default_method),
81  weighted(default_weighted), num_runs(default_num_runs), num_folds(default_num_folds), alpha(default_alpha)
82 {
83 }
84 
86 {
87  switch (method)
88  {
90  {
91  REQUIRE(!weighted, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n");
92  policy=std::unique_ptr<MedianHeuristic>(new MedianHeuristic(kernel_mgr, estimator));
93  }
94  break;
96  {
97  REQUIRE(!weighted, "Weighted kernel selection is not possible with CROSS_VALIDATION!\n");
98  policy=std::unique_ptr<MaxCrossValidation>(new MaxCrossValidation(kernel_mgr, estimator,
100  }
101  break;
102  case KSM_MAXIMIZE_MMD:
103  {
104  if (weighted)
105  policy=std::unique_ptr<WeightedMaxMeasure>(new WeightedMaxMeasure(kernel_mgr, estimator));
106  else
107  policy=std::unique_ptr<MaxMeasure>(new MaxMeasure(kernel_mgr, estimator));
108  }
109  break;
110  case KSM_MAXIMIZE_POWER:
111  {
112  if (weighted)
113  {
114  auto casted_estimator=dynamic_cast<CStreamingMMD*>(estimator);
115  REQUIRE(casted_estimator, "Weighted kernel selection is not possible with MAXIMIZE_POWER!\n");
116  policy=std::unique_ptr<WeightedMaxTestPower>(new WeightedMaxTestPower(kernel_mgr, estimator));
117  }
118  else
119  policy=std::unique_ptr<MaxTestPower>(new MaxTestPower(kernel_mgr, estimator));
120  }
121  break;
122  default:
123  {
124  SG_SERROR("Unsupported kernel selection method specified! Accepted strategies are "
125  "MAXIMIZE_MMD (single, weighted), "
126  "MAXIMIZE_POWER (single, weighted), "
127  "CROSS_VALIDATION (single) and "
128  "MEDIAN_HEURISTIC (single)!\n");
129  }
130  break;
131  }
132 }
133 
134 CKernelSelectionStrategy::CKernelSelectionStrategy()
135 {
136  init();
137 }
138 
139 CKernelSelectionStrategy::CKernelSelectionStrategy(EKernelSelectionMethod method, bool weighted)
140 {
141  init();
142  self->method=method;
143  self->weighted=weighted;
144 }
145 
146 CKernelSelectionStrategy::CKernelSelectionStrategy(EKernelSelectionMethod method, index_t num_runs,
148 {
149  init();
150  self->method=method;
151  self->num_runs=num_runs;
152  self->num_folds=num_folds;
153  self->alpha=alpha;
154 }
155 
156 void CKernelSelectionStrategy::init()
157 {
158  self=std::unique_ptr<Self>(new Self());
159 }
160 
161 CKernelSelectionStrategy::~CKernelSelectionStrategy()
162 {
163  self->kernel_mgr.clear();
164 }
165 
166 CKernelSelectionStrategy& CKernelSelectionStrategy::use_method(EKernelSelectionMethod method)
167 {
168  self->method=method;
169  return *this;
170 }
171 
172 CKernelSelectionStrategy& CKernelSelectionStrategy::use_num_runs(index_t num_runs)
173 {
174  self->num_runs=num_runs;
175  return *this;
176 }
177 
178 CKernelSelectionStrategy& CKernelSelectionStrategy::use_num_folds(index_t num_folds)
179 {
180  self->num_folds=num_folds;
181  return *this;
182 }
183 
184 CKernelSelectionStrategy& CKernelSelectionStrategy::use_alpha(float64_t alpha)
185 {
186  self->alpha=alpha;
187  return *this;
188 }
189 
190 CKernelSelectionStrategy& CKernelSelectionStrategy::use_weighted(bool weighted)
191 {
192  self->weighted=weighted;
193  return *this;
194 }
195 
196 EKernelSelectionMethod CKernelSelectionStrategy::get_method() const
197 {
198  return self->method;
199 }
200 
201 index_t CKernelSelectionStrategy::get_num_runs() const
202 {
203  return self->num_runs;
204 }
205 
206 index_t CKernelSelectionStrategy::get_num_folds() const
207 {
208  return self->num_folds;
209 }
210 
211 float64_t CKernelSelectionStrategy::get_alpha() const
212 {
213  return self->alpha;
214 }
215 
216 bool CKernelSelectionStrategy::get_weighted() const
217 {
218  return self->weighted;
219 }
220 
221 void CKernelSelectionStrategy::add_kernel(CKernel* kernel)
222 {
223  self->kernel_mgr.push_back(kernel);
224 }
225 
226 CKernel* CKernelSelectionStrategy::select_kernel(CMMD* estimator)
227 {
228  auto num_kernels=self->kernel_mgr.num_kernels();
229  REQUIRE(num_kernels>0, "Number of kernels is 0. Please add kernels using add_kernel method!\n");
230  SG_DEBUG("Selecting kernels from a total of %d kernels!\n", num_kernels);
231 
232  self->init_policy(estimator);
233  ASSERT(self->policy!=nullptr);
234 
235  return self->policy->select_kernel();
236 }
237 
238 // TODO call this method when test train mode is turned off
239 void CKernelSelectionStrategy::erase_intermediate_results()
240 {
241  self->policy=nullptr;
242  self->kernel_mgr.clear();
243 }
244 
245 SGMatrix<float64_t> CKernelSelectionStrategy::get_measure_matrix()
246 {
247  REQUIRE(self->policy!=nullptr, "The kernel selection policy is not initialized!\n");
248  return self->policy->get_measure_matrix();
249 }
250 
251 SGVector<float64_t> CKernelSelectionStrategy::get_measure_vector()
252 {
253  REQUIRE(self->policy!=nullptr, "The kernel selection policy is not initialized!\n");
254  return self->policy->get_measure_vector();
255 }
256 
257 const char* CKernelSelectionStrategy::get_name() const
258 {
259  return "KernelSelectionStrategy";
260 }
261 
262 const KernelManager& CKernelSelectionStrategy::get_kernel_mgr() const
263 {
264  return self->kernel_mgr;
265 }
int32_t index_t
Definition: common.h:72
EKernelSelectionMethod
Definition: TestEnums.h:61
#define REQUIRE(x,...)
Definition: SGIO.h:205
#define ASSERT(x)
Definition: SGIO.h:200
static const EKernelSelectionMethod default_method
double float64_t
Definition: common.h:60
std::unique_ptr< KernelSelection > policy
#define SG_DEBUG(...)
Definition: SGIO.h:106
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
#define SG_SERROR(...)
Definition: SGIO.h:178
Abstract base class that provides an interface for performing kernel two-sample tests using Maximum Mean Discrepancy (MMD).
Definition: MMD.h:120
The Kernel base class.

SHOGUN Machine Learning Toolbox - Documentation