// Include guard for this header, the DOXYGEN_SHOULD_SKIP_THIS guard, and the opening of
// struct PermutationMMD, which derives from ComputeMMD.
// NOTE(review): this listing carries the original file's line numbers inline and omits
// brace-only lines, so the struct's opening brace is not shown here.
32 #ifndef PERMUTATION_MMD_H_ 33 #define PERMUTATION_MMD_H_ 50 #ifndef DOXYGEN_SHOULD_SKIP_THIS 51 struct PermutationMMD : ComputeMMD
// Default constructor: initializes m_save_inds to false. With that setting,
// allocate_permutation_inds() does not (re)allocate m_all_inds.
53 PermutationMMD() : m_save_inds(false)
// Computes one null sample per permutation for a single kernel.
//
// Kernel is any type callable as kernel(i, j); the value is forwarded to add_term_lower().
// The result vector has m_num_null_samples entries; entry n is compute(terms) for the
// n-th permutation stored in m_inverted_permuted_inds.
//
// NOTE(review): the declarations of `size` and `terms`, the `else` between the two
// add_term_lower() calls, and the return statement are on lines missing from this listing
// (the embedded numbering has gaps). Presumably size is m_n_x+m_n_y, as in the sibling
// methods, and null_samples is returned - confirm against the original header.
57 template <
class Kernel>
58 SGVector<float32_t> operator()(
const Kernel& kernel)
// Preconditions: both sample counts and the number of null samples must be positive.
60 ASSERT(m_n_x>0 && m_n_y>0);
61 ASSERT(m_num_null_samples>0);
// Fills m_inverted_permuted_inds with one column per null sample.
62 precompute_permutation_inds();
65 SGVector<float32_t> null_samples(m_num_null_samples);
// Each iteration writes only its own null_samples[n].
66 #pragma omp parallel for 67 for (
auto n=0; n<m_num_null_samples; ++n)
// Walk the lower triangle including the diagonal: column j, rows i >= j.
70 for (
auto j=0; j<size; ++j)
// Index that column j maps to under permutation n.
72 auto inverted_col=m_inverted_permuted_inds(j, n);
73 for (
auto i=j; i<size; ++i)
// Index that row i maps to under permutation n.
75 auto inverted_row=m_inverted_permuted_inds(i, n);
// The larger of the two mapped indices is always passed first to add_term_lower().
77 if (inverted_row>=inverted_col)
78 add_term_lower(terms, kernel(i, j), inverted_row, inverted_col);
// Presumably the else-branch of the test above (the `else` line is not shown).
80 add_term_lower(terms, kernel(i, j), inverted_col, inverted_row);
// Reduce the accumulated terms to the null sample for permutation n.
83 null_samples[n]=compute(terms);
84 SG_SDEBUG(
"null_samples[%d] = %f!\n", n, null_samples[n]);
// Computes null samples for every kernel held by kernel_mgr.
//
// The result is an SGMatrix<float32_t> constructed as
// (m_num_null_samples, kernel_mgr.num_kernels()); entry (n, k) is compute(null_terms)
// for permutation n and kernel k.
//
// NOTE(review): the declarations of `size` and `null_terms`, the `else` between the two
// add_term_upper() calls, and the return statement are on lines missing from this listing.
// Presumably size is m_n_x+m_n_y and null_samples is returned - confirm.
89 SGMatrix<float32_t> operator()(
const KernelManager& kernel_mgr)
// Preconditions: both sample counts and the number of null samples must be positive.
91 ASSERT(m_n_x>0 && m_n_y>0);
92 ASSERT(m_num_null_samples>0);
// Fills m_inverted_permuted_inds with one column per null sample.
93 precompute_permutation_inds();
96 SGMatrix<float32_t> null_samples(m_num_null_samples, kernel_mgr.num_kernels());
// Packed buffer for the upper triangle (diagonal included) of the kernel matrix;
// it is allocated once and refilled for each kernel.
// NOTE(review): size*(size+1)/2 is evaluated in the type of `size`; presumably index_t -
// confirm it cannot overflow for the sample sizes in use.
97 SGVector<float32_t> km(size*(size+1)/2);
98 for (
auto k=0; k<kernel_mgr.num_kernels(); ++k)
100 auto kernel=kernel_mgr.kernel_at(k);
// Evaluate the kernel once per (i, j) with j >= i and cache it in km.
102 for (
auto i=0; i<size; ++i)
104 for (
auto j=i; j<size; ++j)
// Row-major offset of (i, j) in the packed upper triangle:
// row i starts at i*size-i*(i-1)/2 and (i, j) sits j-i entries further on.
106 auto index=i*size-i*(i+1)/2+j;
107 km[index]=kernel->kernel(i, j);
// For the current kernel k, each iteration writes only its own null_samples(n, k).
111 #pragma omp parallel for 112 for (
auto n=0; n<m_num_null_samples; ++n)
115 for (
auto i=0; i<size; ++i)
// Index that row i maps to under permutation n.
117 auto inverted_row=m_inverted_permuted_inds(i, n);
// Start of row i in the packed buffer, minus i, so that adding j yields the offset of (i, j).
118 auto index_base=i*size-i*(i+1)/2;
119 for (
auto j=i; j<size; ++j)
121 auto index=index_base+j;
// Index that column j maps to under permutation n.
122 auto inverted_col=m_inverted_permuted_inds(j, n);
// The smaller of the two mapped indices is always passed first to add_term_upper().
124 if (inverted_row<=inverted_col)
125 add_term_upper(null_terms, km[index], inverted_row, inverted_col);
// Presumably the else-branch of the test above (the `else` line is not shown).
127 add_term_upper(null_terms, km[index], inverted_col, inverted_row);
// Reduce the accumulated terms to the null sample for permutation n and kernel k.
130 null_samples(n, k)=compute(null_terms);
// Single-kernel p-value: combines the observed statistic with the permutation null samples.
// NOTE(review): the signature line is missing from this listing (the embedded numbering
// jumps from 136 to 139). Presumably this is the p_value overload taking `const Kernel& kernel`
// - confirm against the original header.
136 template <
class Kernel>
// Observed statistic, taken from the base-class call operator.
139 auto statistic=ComputeMMD::operator()(kernel);
// Null samples, from the unqualified operator() call (the overloads declared in this struct).
140 auto null_samples=operator()(kernel);
// compute_p_value() sorts null_samples in place before locating the statistic in it.
141 return compute_p_value(null_samples, statistic);
// Computes one p-value per kernel held by kernel_mgr.
//
// For each kernel k the method caches the packed upper triangle of the kernel matrix,
// accumulates the terms for the unpermuted indices, then computes m_num_null_samples
// null samples from the precomputed permutations and stores
// compute_p_value(null_samples, statistic) in result[k].
//
// NOTE(review): the declarations of `terms`, `statistic` and `null_terms`, the `else`
// between the two permuted add_term_upper() calls, and the return statement are on lines
// missing from this listing. Presumably statistic is compute(terms) and result is
// returned - confirm against the original header.
144 SGVector<float64_t> p_value(
const KernelManager& kernel_mgr)
// Preconditions: both sample counts and the number of null samples must be positive.
146 ASSERT(m_n_x>0 && m_n_y>0);
147 ASSERT(m_num_null_samples>0);
// Fills m_inverted_permuted_inds with one column per null sample.
148 precompute_permutation_inds();
150 const index_t size=m_n_x+m_n_y;
// One buffer reused for every kernel: compute_p_value() sorts it in place, and the next
// kernel's parallel loop overwrites every entry again.
151 SGVector<float32_t> null_samples(m_num_null_samples);
152 SGVector<float64_t> result(kernel_mgr.num_kernels());
// Packed buffer for the upper triangle (diagonal included) of the kernel matrix.
// NOTE(review): size*(size+1)/2 is evaluated in index_t - confirm it cannot overflow
// for the sample sizes in use.
154 SGVector<float32_t> km(size*(size+1)/2);
155 for (
auto k=0; k<kernel_mgr.num_kernels(); ++k)
157 auto kernel=kernel_mgr.kernel_at(k);
// Evaluate the kernel once per (i, j) with j >= i, cache it, and feed it to the
// accumulation for the unpermuted indices in the same pass.
159 for (
auto i=0; i<size; ++i)
161 for (
auto j=i; j<size; ++j)
// Row-major offset of (i, j) in the packed upper triangle.
163 auto index=i*size-i*(i+1)/2+j;
164 km[index]=kernel->kernel(i, j);
165 add_term_upper(terms, km[index], i, j);
169 SG_SDEBUG(
"Kernel(%d): statistic=%f\n", k, statistic);
// Each iteration writes only its own null_samples[n].
171 #pragma omp parallel for 172 for (
auto n=0; n<m_num_null_samples; ++n)
175 for (
auto i=0; i<size; ++i)
// Index that row i maps to under permutation n.
177 auto inverted_row=m_inverted_permuted_inds(i, n);
// Start of row i in the packed buffer, minus i, so that adding j yields the offset of (i, j).
178 auto index_base=i*size-i*(i+1)/2;
179 for (
auto j=i; j<size; ++j)
181 auto index=index_base+j;
// Index that column j maps to under permutation n.
182 auto inverted_col=m_inverted_permuted_inds(j, n);
// The smaller of the two mapped indices is always passed first to add_term_upper().
184 if (inverted_row<=inverted_col)
185 add_term_upper(null_terms, km[index], inverted_row, inverted_col);
// Presumably the else-branch of the test above (the `else` line is not shown).
187 add_term_upper(null_terms, km[index], inverted_col, inverted_row);
// Reduce the accumulated terms to the null sample for permutation n.
190 null_samples[n]=compute(null_terms);
// Sorts null_samples in place and converts the statistic's position into a p-value.
192 result[k]=compute_p_value(null_samples, statistic);
193 SG_SDEBUG(
"Kernel(%d): p_value=%f\n", k, result[k]);
// Builds the index tables used by the null-sample loops: for every null sample n it
// writes column n of m_inverted_permuted_inds from the current contents of m_permuted_inds.
//
// NOTE(review): the embedded numbering jumps from 205 to 209, so statements between the
// std::iota and the copy into m_all_inds are missing from this listing. Presumably they
// shuffle m_permuted_inds and guard the copy with m_save_inds (allocate_permutation_inds()
// only sizes m_all_inds when m_save_inds is set) - confirm against the original header.
199 inline void precompute_permutation_inds()
201 ASSERT(m_num_null_samples>0);
// Ensures the index buffers have the dimensions required for the current sizes.
202 allocate_permutation_inds();
203 for (
auto n=0; n<m_num_null_samples; ++n)
// Reset the scratch vector to 0, 1, ..., size-1 at the start of every iteration.
205 std::iota(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), 0);
// Copy the current contents of m_permuted_inds into m_all_inds, starting at flat offset
// n*size of its underlying array.
209 auto offset=n*m_permuted_inds.size();
210 std::copy(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), &m_all_inds.matrix[offset]);
// Store the inverse mapping: row m_permuted_inds[i] of column n receives i.
212 for (
index_t i=0; i<m_permuted_inds.size(); ++i)
213 m_inverted_permuted_inds(m_permuted_inds[i], n)=i;
217 inline float64_t compute_p_value(SGVector<float32_t>& null_samples,
float32_t statistic)
const 219 std::sort(null_samples.data(), null_samples.data()+null_samples.size());
220 float64_t idx=null_samples.find_position_to_insert(statistic);
221 return 1.0-idx/null_samples.size();
224 inline void allocate_permutation_inds()
226 const index_t size=m_n_x+m_n_y;
227 if (m_permuted_inds.size()!=size)
228 m_permuted_inds=SGVector<index_t>(size);
230 if (m_inverted_permuted_inds.num_cols!=m_num_null_samples || m_inverted_permuted_inds.num_rows!=size)
231 m_inverted_permuted_inds=SGMatrix<index_t>(size, m_num_null_samples);
233 if (m_save_inds && (m_all_inds.num_cols!=m_num_null_samples || m_all_inds.num_rows!=size))
234 m_all_inds=SGMatrix<index_t>(size, m_num_null_samples);
// Scratch vector of m_n_x+m_n_y indices; reset to 0..size-1 for every null sample
// in precompute_permutation_inds().
239 SGVector<index_t> m_permuted_inds;
// (m_n_x+m_n_y) x m_num_null_samples matrix; entry (m_permuted_inds[i], n) holds i,
// i.e. column n is the inverse of the index vector used for null sample n.
240 SGMatrix<index_t> m_inverted_permuted_inds;
// (m_n_x+m_n_y) x m_num_null_samples matrix, sized only when m_save_inds is set;
// receives a copy of m_permuted_inds at flat offset n*size for null sample n.
241 SGMatrix<index_t> m_all_inds;
243 #endif // DOXYGEN_SHOULD_SKIP_THIS 250 #endif // PERMUTATION_MMD_H_ static void permute(SGVector< T > v, CRandom *rand=NULL)
All classes and functions are contained in the shogun namespace.