31 #ifndef CROSS_VALIDATION_MMD_H_ 32 #define CROSS_VALIDATION_MMD_H_ 44 using std::unique_ptr;
/**
 * CrossValidationMMD: for every kernel provided by a KernelManager, repeats a
 * permutation-based MMD test m_num_runs times over m_num_folds folds of X and Y.
 * Each (run, fold) pair's accept/reject decision is recorded in m_rejections.
 *
 * NOTE(review): this listing is an extract with many original lines missing
 * (braces, the constructor/init signatures, and the declarations of `size`,
 * `terms`, `null_terms` and the inner-loop `idx`). Several members used below
 * (m_num_folds, m_num_null_samples, m_num_runs, m_alpha, m_n_x, m_n_y,
 * m_permuted_inds, m_inverted_permuted_inds) are not declared in the visible
 * lines; presumably they come from PermutationMMD or the missing lines - confirm.
 */
54 #ifndef DOXYGEN_SHOULD_SKIP_THIS 55 struct CrossValidationMMD : PermutationMMD
// Constructor body (signature not visible): validates the null-sample count,
// stores the fold/null-sample configuration and applies the defaults
// DEFAULT_NUM_RUNS (10) and DEFAULT_ALPHA (0.05).
61 ASSERT(num_null_samples>0);
65 m_num_folds=num_folds;
66 m_num_null_samples=num_null_samples;
67 m_num_runs=DEFAULT_NUM_RUNS;
68 m_alpha=DEFAULT_ALPHA;
/**
 * Runs the cross-validated permutation test for every kernel in kernel_mgr.
 *
 * Each p-value is compared with m_alpha. The decision (0/1) is written to
 * m_rejections(current_run*m_num_folds+current_fold, k).
 *
 * @param kernel_mgr provides the kernels; m_rejections must already be
 *        sized (m_num_runs*m_num_folds) x kernel_mgr.num_kernels().
 */
73 void operator()(
const KernelManager& kernel_mgr)
// NOTE(review): the check below is an exact equality, but its message says
// "has to be >=". The message should probably say "has to be".
75 REQUIRE(m_rejections.num_rows==m_num_runs*m_num_folds,
76 "Number of rows in the measure matrix (was %d), has to be >= %d*%d = %d!\n",
77 m_rejections.num_rows, m_num_runs, m_num_folds, m_num_runs*m_num_folds);
// NOTE(review): message typo "nunber" and grammar "has to equal to".
// The intended wording is probably "has to be equal to the number of kernels".
78 REQUIRE(m_rejections.num_cols==kernel_mgr.num_kernels(),
79 "Number of columns in the measure matrix (was %d), has to equal to the nunber of kernels (%d)!\n",
80 m_rejections.num_cols, kernel_mgr.num_kernels());
// Buffer for the permutation null distribution, reused across kernels/runs/folds.
85 SGVector<float64_t> null_samples(m_num_null_samples);
// Packed upper triangle (including the diagonal) of the size x size kernel
// matrix. Entry (i, j) with j >= i is at i*size - i*(i+1)/2 + j.
// Values are stored as float32_t, so kernel values lose precision here.
// NOTE(review): `size` is declared in a line not visible here; presumably
// m_n_x+m_n_y, matching the size used for m_inverted_inds - confirm.
86 SGVector<float32_t> precomputed_km(size*(size+1)/2);
// Outer loop: one full cross-validated test per kernel (column k of m_rejections).
88 for (
auto k=0; k<kernel_mgr.num_kernels(); ++k)
90 auto kernel=kernel_mgr.kernel_at(k);
// Precompute the upper triangle of kernel k once. Every run, fold and
// permutation below reads from this cache instead of re-evaluating the kernel.
91 for (
auto i=0; i<size; ++i)
93 for (
auto j=i; j<size; ++j)
95 auto index=i*size-i*(i+1)/2+j;
96 precomputed_km[index]=kernel->kernel(i, j);
// Each run rebuilds the fold assignment of both splittings via build_subsets().
100 for (
auto current_run=0; current_run<m_num_runs; ++current_run)
102 m_kfold_x->build_subsets();
103 m_kfold_y->build_subsets();
104 for (
auto current_fold=0; current_fold<m_num_folds; ++current_fold)
// Select this fold's joint X/Y sample indices into m_xy_inds.
106 generate_inds(current_fold);
// m_inverted_inds maps an original sample index to its position in
// m_xy_inds, or -1 when the sample is not part of this fold's selection.
107 std::fill(m_inverted_inds.data(), m_inverted_inds.data()+m_inverted_inds.size(), -1);
108 for (
index_t idx=0; idx<m_xy_inds.size(); ++idx)
109 m_inverted_inds[m_xy_inds[idx]]=idx;
// Push the fold selection so that permutation subsets pushed on top of it
// are resolved relative to it.
111 m_stack->add_subset(m_xy_inds);
113 if (m_permuted_inds.size()!=m_xy_inds.size())
114 m_permuted_inds=SGVector<index_t>(m_xy_inds.size());
// Column n of m_inverted_permuted_inds will hold, for permutation n, the
// position of each original sample index (-1 when not selected).
116 m_inverted_permuted_inds.set_const(-1);
// Build all m_num_null_samples permutations up front. They are stored as
// inverse index maps so that the OpenMP loop below only reads shared state.
118 for (
auto n=0; n<m_num_null_samples; ++n)
// Start from the identity 0..len-1.
// NOTE(review): the shuffling step (original lines 121-122) is not visible here;
// presumably m_permuted_inds is permuted in place before being pushed - confirm.
120 std::iota(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), 0);
// Push the permutation on top of the fold subset and read back the resulting
// index vector from get_last_subset() (presumably the composed mapping into
// original sample indices - confirm). Then pop the permutation again.
123 m_stack->add_subset(m_permuted_inds);
124 SGVector<index_t> inds=m_stack->get_last_subset()->get_subset_idx();
125 m_stack->remove_subset();
// NOTE(review): `int` here vs `index_t` for the analogous loop above.
127 for (
int idx=0; idx<inds.size(); ++idx)
128 m_inverted_permuted_inds(inds[idx], n)=idx;
// Pop the fold subset pushed before the permutation loop.
130 m_stack->remove_subset();
// Unpermuted statistic: accumulate every cached upper-triangle entry whose row
// and column samples are both in this fold's selection.
// NOTE(review): unlike the null-sample loop below, inverted_row/inverted_col are
// not reordered before add_term_upper. This is only correct if m_xy_inds is
// increasing (so i<j implies inverted_row<inverted_col). Confirm that
// generate_subset_inverse returns sorted indices.
// NOTE(review): `terms` and `idx` are declared in lines not visible here;
// presumably idx = idx_base + j.
133 for (
auto i=0; i<size; ++i)
135 auto inverted_row=m_inverted_inds[i];
136 auto idx_base=i*size-i*(i+1)/2;
137 for (
auto j=i; j<size; ++j)
139 auto inverted_col=m_inverted_inds[j];
140 if (inverted_row!=-1 && inverted_col!=-1)
143 add_term_upper(terms, precomputed_km[idx], inverted_row, inverted_col);
147 auto statistic=compute(terms);
// Null distribution, one permutation per OpenMP iteration. Each iteration reads
// only precomputed_km and column n of m_inverted_permuted_inds, and writes only
// null_samples[n].
// NOTE(review): `null_terms` is declared in a line not visible here. It must be
// local to each iteration for this parallel loop to be race-free - confirm.
149 #pragma omp parallel for 150 for (
auto n=0; n<m_num_null_samples; ++n)
153 for (
auto i=0; i<size; ++i)
155 auto inverted_row=m_inverted_permuted_inds(i, n);
156 auto idx_base=i*size-i*(i+1)/2;
157 for (
auto j=i; j<size; ++j)
159 auto inverted_col=m_inverted_permuted_inds(j, n);
160 if (inverted_row!=-1 && inverted_col!=-1)
// A permutation can invert the order of (i, j). Pass the smaller
// permuted position first so the term lands in the upper triangle.
163 if (inverted_row<=inverted_col)
164 add_term_upper(null_terms, precomputed_km[idx], inverted_row, inverted_col);
166 add_term_upper(null_terms, precomputed_km[idx], inverted_col, inverted_row);
170 null_samples[n]=compute(null_terms);
// p-value = 1 - (insertion position of the statistic in the sorted null samples)
// / m_num_null_samples. This presumably gives the fraction of null samples at
// or above the statistic; tie handling depends on find_position_to_insert -
// confirm.
173 std::sort(null_samples.data(), null_samples.data()+null_samples.size());
175 float64_t idx=null_samples.find_position_to_insert(statistic);
177 auto p_value=1.0-idx/m_num_null_samples;
178 bool rejected=p_value<m_alpha;
179 SG_SDEBUG(
"p-value=%f, alpha=%f, rejected=%d\n", p_value, m_alpha, rejected);
// Row = flattened (run, fold) pair, column = kernel. Stores 0/1.
180 m_rejections(current_run*m_num_folds+current_fold, k)=rejected;
// Initialisation body (signature not visible).
// Creates one CCrossValidationSplitting per sample. Each is built from a dummy
// label vector of that sample's size (m_n_x / m_n_y) and uses m_num_folds folds.
// The splittings are owned by unique_ptr members.
// Also creates the subset stack and allocates the inverse-index buffers for
// m_n_x+m_n_y samples (and m_num_null_samples permutation columns).
// NOTE(review): the label contents are never set here; presumably only their
// count matters to the splitting - confirm.
191 SGVector<int64_t> dummy_labels_x(m_n_x);
192 SGVector<int64_t> dummy_labels_y(m_n_y);
194 auto instance_x=
new CCrossValidationSplitting(
new CBinaryLabels(dummy_labels_x), m_num_folds);
195 auto instance_y=
new CCrossValidationSplitting(
new CBinaryLabels(dummy_labels_y), m_num_folds);
196 m_kfold_x=unique_ptr<CCrossValidationSplitting>(instance_x);
197 m_kfold_y=unique_ptr<CCrossValidationSplitting>(instance_y);
199 m_stack=unique_ptr<CSubsetStack>(
new CSubsetStack());
201 const index_t size=m_n_x+m_n_y;
202 m_inverted_inds=SGVector<index_t>(size);
203 m_inverted_permuted_inds=SGMatrix<index_t>(size, m_num_null_samples);
/**
 * Fills m_xy_inds with the joint sample indices for fold current_fold.
 *
 * The X indices come first, followed by the Y indices shifted by m_n_x so that
 * they address the concatenated X|Y sample.
 * Indices come from generate_subset_inverse(current_fold) on each splitting;
 * presumably these are the samples outside the fold - confirm.
 *
 * NOTE(review): original lines 211-214 are not visible. m_xy_inds is sized
 * m_n_x+m_n_y, but only x_inds.size()+y_inds.size() entries are written. Unless
 * the hidden lines make these equal, trailing entries keep stale values that
 * operator() then uses as indices. Confirm.
 *
 * @param current_fold index of the fold to select
 */
206 void generate_inds(
index_t current_fold)
208 SGVector<index_t> x_inds=m_kfold_x->generate_subset_inverse(current_fold);
209 SGVector<index_t> y_inds=m_kfold_y->generate_subset_inverse(current_fold);
210 std::for_each(y_inds.data(), y_inds.data()+y_inds.size(), [
this](
index_t& val) { val += m_n_x; });
215 if (m_xy_inds.size()!=m_n_x+m_n_y)
216 m_xy_inds=SGVector<index_t>(m_n_x+m_n_y);
218 std::copy(x_inds.data(), x_inds.data()+x_inds.size(), m_xy_inds.data());
219 std::copy(y_inds.data(), y_inds.data()+y_inds.size(), m_xy_inds.data()+x_inds.size());
// Default number of repetitions of the full k-fold procedure (assigned to m_num_runs).
224 static constexpr
index_t DEFAULT_NUM_RUNS=10;
// Default significance level (assigned to m_alpha).
227 static constexpr
float64_t DEFAULT_ALPHA=0.05;
// Fold splitters for the X and Y samples.
229 unique_ptr<CCrossValidationSplitting> m_kfold_x;
230 unique_ptr<CCrossValidationSplitting> m_kfold_y;
// Stack used to compose the fold subset with each permutation subset.
231 unique_ptr<CSubsetStack> m_stack;
// Joint X|Y sample indices of the current fold (Y offset by m_n_x).
233 SGVector<index_t> m_xy_inds;
// Original sample index -> position in m_xy_inds, or -1 if not selected.
234 SGVector<index_t> m_inverted_inds;
// Decisions: row = run*m_num_folds+fold, column = kernel; entries 0 or 1.
235 SGMatrix<float64_t> m_rejections;
238 #endif // DOXYGEN_SHOULD_SKIP_THIS 244 #endif // CROSS_VALIDATION_MMD_H_ static void permute(SGVector< T > v, CRandom *rand=NULL)
All classes and functions are contained in the shogun namespace.