using namespace internal;
MaxCrossValidation::MaxCrossValidation(KernelManager& km, CMMD* est, const index_t& M,
	const index_t& K, const float64_t& alp)
	: KernelSelection(km, est), num_runs(M), num_folds(K), alpha(alp)
{
	REQUIRE(num_runs>0, "Number of runs (%d) must be positive!\n", num_runs);
	REQUIRE(num_folds>0, "Number of folds (%d) must be positive!\n", num_folds);
	REQUIRE(alpha>=0.0 && alpha<=1.0, "Threshold (%f) has to be in [0, 1]!\n", alpha);
}

MaxCrossValidation::~MaxCrossValidation()
{
}
void MaxCrossValidation::init_measures()
{
	const index_t num_kernels=kernel_mgr.num_kernels();
	// (Re)allocate the rejection matrix (one row per run-fold pair, one column per
	// kernel) and the per-kernel measure vector, then zero both.
	if (rejections.num_rows!=num_folds*num_runs || rejections.num_cols!=num_kernels)
		rejections=SGMatrix<float64_t>(num_folds*num_runs, num_kernels);
	std::fill(rejections.data(), rejections.data()+rejections.size(), 0);
	if (measures.size()!=num_kernels)
		measures=SGVector<float64_t>(num_kernels);
	std::fill(measures.data(), measures.data()+measures.size(), 0);
}
void MaxCrossValidation::compute_measures()
{
	SG_SDEBUG("Performing %d fold cross-validation!\n", num_folds);
	const auto num_kernels=kernel_mgr.num_kernels();

	// The quadratic-time MMD can reuse precomputed kernel/distance matrices across folds.
	auto quadratic_time_mmd=dynamic_cast<CQuadraticTimeMMD*>(estimator);
	if (quadratic_time_mmd)
	{
		REQUIRE(estimator->get_null_approximation_method()==NAM_PERMUTATION,
			"Only supported with PERMUTATION method for null distribution approximation!\n");

		auto Nx=estimator->get_num_samples_p();
		auto Ny=estimator->get_num_samples_q();
		auto num_null_samples=estimator->get_num_null_samples();
		auto stype=estimator->get_statistic_type();
		CrossValidationMMD compute(Nx, Ny, num_folds, num_null_samples);
		compute.m_stype=stype;
		compute.m_alpha=alpha;
		compute.m_num_runs=num_runs;
		compute.m_rejections=rejections;

		if (kernel_mgr.same_distance_type())
		{
			// All kernels share the same distance type: precompute the joint
			// distance once and let every kernel reuse it.
			CDistance* distance=kernel_mgr.get_distance_instance();
			auto precomputed_distance=estimator->compute_joint_distance(distance);
			kernel_mgr.set_precomputed_distance(precomputed_distance);
			compute(kernel_mgr);
			kernel_mgr.unset_precomputed_distance();
		}
		else
		{
			// Otherwise, initialize every kernel on the joint samples explicitly.
			auto samples_p_and_q=quadratic_time_mmd->get_p_and_q();

			for (auto k=0; k<num_kernels; ++k)
			{
				CKernel* kernel=kernel_mgr.kernel_at(k);
				kernel->init(samples_p_and_q, samples_p_and_q);
			}

			compute(kernel_mgr);

			for (auto k=0; k<num_kernels; ++k)
			{
				CKernel* kernel=kernel_mgr.kernel_at(k);
				kernel->remove_lhs_and_rhs();
			}
		}
	}
	else
	{
		// Other estimator types: recompute the statistic for every run, fold and kernel.
		auto existing_kernel=estimator->get_kernel();
		for (auto i=0; i<num_runs; ++i)
		{
			for (auto j=0; j<num_folds; ++j)
			{
				for (auto k=0; k<num_kernels; ++k)
				{
					auto kernel=kernel_mgr.kernel_at(k);
					estimator->set_kernel(kernel);
					auto statistic=estimator->compute_statistic();
					// Mark a rejection whenever the p-value falls below the threshold.
					rejections(i*num_folds+j, k)=estimator->compute_p_value(statistic)<alpha;
				}
			}
		}
		// Restore whatever kernel the estimator had before the search.
		estimator->set_kernel(existing_kernel);
	}
	// The measure for each kernel is its rejection rate across all runs and folds.
	for (auto j=0; j<rejections.num_cols; ++j)
	{
		auto begin=rejections.get_column_vector(j);
		auto size=rejections.num_rows;
		measures[j]=std::accumulate(begin, begin+size, 0.0)/size;
	}
}
CKernel* MaxCrossValidation::select_kernel()
{
	init_measures();
	compute_measures();
	// Pick the kernel with the highest rejection rate.
	auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen);
	auto max_idx=std::distance(measures.vector, max_element);
	SG_SDEBUG("Selected kernel at %d position!\n", max_idx);
	return kernel_mgr.kernel_at(max_idx);
}
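// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): the selection criterion
// reduces to "for each kernel, the fraction of (run, fold) trials whose
// p-value fell below alpha", and select_kernel() returns the kernel with the
// highest such rejection rate. A minimal standalone version of that
// aggregation, using plain standard-library containers and hypothetical
// names, could look like this:

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <vector>

// rejections[k][t] is 1.0 if kernel k rejected H0 in trial t (t = run*num_folds + fold),
// 0.0 otherwise. Assumes at least one kernel and one trial per kernel.
// Returns the index of the kernel with the highest rejection rate.
std::size_t select_best_kernel(const std::vector<std::vector<double>>& rejections)
{
	std::vector<double> rejection_rate(rejections.size(), 0.0);
	for (std::size_t k=0; k<rejections.size(); ++k)
	{
		const auto& trials=rejections[k];
		rejection_rate[k]=std::accumulate(trials.begin(), trials.end(), 0.0)/trials.size();
	}
	auto best=std::max_element(rejection_rate.begin(), rejection_rate.end());
	return static_cast<std::size_t>(std::distance(rejection_rate.begin(), best));
}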