32 #ifndef COMPUTE_MMD_H_ 33 #define COMPUTE_MMD_H_ 37 #include <shogun/lib/config.h> 58 std::array<float64_t, 3>
term{};
59 std::array<float64_t, 3>
diag{};
// Computes the statistic from a kernel functor: evaluates kernel(i, j) for
// every index pair with i <= j < size, accumulates each value through
// add_term_upper(), and returns compute(terms).
//
// kernel: any callable invoked as kernel(i, j).
//
// NOTE(review): the declarations of `size` and `terms` are not present in this
// listing (the embedded numbering skips lines 75-76 and 78) — confirm against
// the original file. The manager overload declares `size` as m_n_x+m_n_y.
// NOTE(review): compute() is declared to return float64_t while this operator
// returns float32_t, so the result is narrowed on return.
61 #ifndef DOXYGEN_SHOULD_SKIP_THIS 71 template <
class Kernel>
72 float32_t operator()(
const Kernel& kernel)
const 74 ASSERT(m_n_x>0 && m_n_y>0);
// Outer index i walks all rows; inner index j starts at i, so only pairs on or
// above the diagonal are evaluated.
77 for (
auto i=0; i<size; ++i)
79 for (
auto j=i; j<size; ++j)
80 add_term_upper(terms, kernel(i, j), i, j);
// Turn the accumulated sums into the final value.
82 return compute(terms);
// Computes the statistic from a precomputed kernel matrix, filling the three
// term/diag accumulators from matrix blocks instead of looping over entries.
//
// kernel_matrix: must have num_rows == num_cols == size (asserted below).
//
// NOTE(review): the declarations of `size`, `map`, `terms`, `MatrixXt` and the
// template parameter T are not present in this listing — confirm against the
// original file.
// NOTE(review): compute() is declared to return float64_t while this operator
// returns float32_t, so the result is narrowed on return.
86 float32_t operator()(
const SGMatrix<T>& kernel_matrix)
const 88 ASSERT(m_n_x>0 && m_n_y>0);
90 ASSERT(kernel_matrix.num_rows==size && kernel_matrix.num_cols==size);
93 typedef Eigen::Block<Eigen::Map<const MatrixXt> > BlockXt;
// Three views into the mapped matrix; arguments are (start row, start column,
// row count, column count):
//   b_x  : m_n_x x m_n_x block starting at (0, 0)
//   b_y  : m_n_y x m_n_y block starting at (m_n_x, m_n_x)
//   b_xy : m_n_y x m_n_x block starting at (m_n_x, 0)
97 const BlockXt& b_x=map.block(0, 0, m_n_x, m_n_x);
98 const BlockXt& b_y=map.block(m_n_x, m_n_x, m_n_y, m_n_y);
99 const BlockXt& b_xy=map.block(m_n_x, 0, m_n_y, m_n_x);
// Diagonal sums of each block.
102 terms.diag[0]=b_x.diagonal().sum();
103 terms.diag[1]=b_y.diagonal().sum();
104 terms.diag[2]=b_xy.diagonal().sum();
// For the two square blocks: half of the off-diagonal sum plus the diagonal.
// This equals a one-triangle sum (diagonal included) only if the block is
// symmetric — presumably guaranteed for a kernel matrix; verify.
106 terms.term[0]=(b_x.sum()-terms.diag[0])/2+terms.diag[0];
107 terms.term[1]=(b_y.sum()-terms.diag[1])/2+terms.diag[1];
// The cross block is summed in full.
108 terms.term[2]=b_xy.sum();
110 return compute(terms);
// Computes one value per kernel held by the manager. For every index pair with
// j <= i < size, each kernel is evaluated at (i, j) and accumulated into that
// kernel's own terms_t via add_term_lower(); compute() is then applied per
// kernel and stored in the result vector.
//
// kernel_mgr: supplies num_kernels() and kernel_at(k); each kernel is invoked
//             as ->kernel(i, j).
//
// NOTE(review): the statement returning `result` is not present in this
// listing — confirm against the original file.
113 SGVector<float64_t> operator()(
const KernelManager& kernel_mgr)
const 115 ASSERT(m_n_x>0 && m_n_y>0);
// One accumulator per kernel.
116 std::vector<terms_t> terms(kernel_mgr.num_kernels());
117 const index_t size=m_n_x+m_n_y;
// Outer index j is the column; inner index i starts at j, so only pairs on or
// below the diagonal are evaluated.
118 for (
auto j=0; j<size; ++j)
120 for (
auto i=j; i<size; ++i)
122 for (
auto k=0; k<kernel_mgr.num_kernels(); ++k)
124 auto kernel=kernel_mgr.kernel_at(k)->kernel(i, j);
125 add_term_lower(terms[k], kernel, i, j);
// Reduce each kernel's accumulated sums to its final value.
130 SGVector<float64_t> result(kernel_mgr.num_kernels());
131 for (
auto k=0; k<kernel_mgr.num_kernels(); ++k)
133 result[k]=compute(terms[k]);
134 SG_SDEBUG(
"result[%d] = %f!\n", k, result[k]);
// Adds one kernel value, taken from the lower triangle (i >= j), to the
// accumulator selected by where (i, j) falls relative to m_n_x:
//   i <  m_n_x and j <  m_n_x  -> term[0] / diag[0]
//   i >= m_n_x and j >= m_n_x  -> term[1] / diag[1]
//   i >= m_n_x and j <  m_n_x  -> term[2] / diag[2]
// Pairs matching none of the three conditions are ignored.
//
// terms:        accumulator updated in place.
// kernel_value: value to add.
// i, j:         row and column index of the value.
//
// NOTE(review): in each branch the embedded numbering skips the line between
// the term update and the diag update (157, 164, 171). As listed here the diag
// update is unconditional; the original file presumably guards it with a
// diagonal test — confirm before relying on this listing.
149 template <
typename T>
150 inline void add_term_lower(terms_t& terms, T kernel_value,
index_t i,
index_t j)
const 152 ASSERT(m_n_x>0 && m_n_y>0);
// Both indices below m_n_x.
153 if (i<m_n_x && j<m_n_x && i>=j)
155 SG_SDEBUG(
"Adding Kernel(%d, %d)=%f to term_0!\n", i, j, kernel_value);
156 terms.term[0]+=kernel_value;
158 terms.diag[0]+=kernel_value;
// Both indices at or above m_n_x.
160 else if (i>=m_n_x && j>=m_n_x && i>=j)
162 SG_SDEBUG(
"Adding Kernel(%d, %d)=%f to term_1!\n", i, j, kernel_value);
163 terms.term[1]+=kernel_value;
165 terms.diag[1]+=kernel_value;
// Row at or above m_n_x, column below it (cross block).
167 else if (i>=m_n_x && j<m_n_x)
169 SG_SDEBUG(
"Adding Kernel(%d, %d)=%f to term_2!\n", i, j, kernel_value);
170 terms.term[2]+=kernel_value;
172 terms.diag[2]+=kernel_value;
// Adds one kernel value, taken from the upper triangle (i <= j), to the
// accumulator selected by where (i, j) falls relative to m_n_x:
//   i <  m_n_x and j <  m_n_x  -> term[0] / diag[0]
//   i >= m_n_x and j >= m_n_x  -> term[1] / diag[1]
//   i <  m_n_x and j >= m_n_x  -> term[2] / diag[2]
// Pairs matching none of the three conditions are ignored.
//
// terms:        accumulator updated in place.
// kernel_value: value to add.
// i, j:         row and column index of the value.
//
// NOTE(review): in each branch the embedded numbering skips the line between
// the term update and the diag update (193, 200, 207). As listed here the diag
// update is unconditional; the original file presumably guards it with a
// diagonal test — confirm before relying on this listing.
185 template <
typename T>
186 inline void add_term_upper(terms_t& terms, T kernel_value,
index_t i,
index_t j)
const 188 ASSERT(m_n_x>0 && m_n_y>0);
// Both indices below m_n_x.
189 if (i<m_n_x && j<m_n_x && i<=j)
191 SG_SDEBUG(
"Adding Kernel(%d, %d)=%f to term_0!\n", i, j, kernel_value);
192 terms.term[0]+=kernel_value;
194 terms.diag[0]+=kernel_value;
// Both indices at or above m_n_x.
196 else if (i>=m_n_x && j>=m_n_x && i<=j)
198 SG_SDEBUG(
"Adding Kernel(%d, %d)=%f to term_1!\n", i, j, kernel_value);
199 terms.term[1]+=kernel_value;
201 terms.diag[1]+=kernel_value;
// Row below m_n_x, column at or above it (cross block).
203 else if (i<m_n_x && j>=m_n_x)
205 SG_SDEBUG(
"Adding Kernel(%d, %d)=%f to term_2!\n", i, j, kernel_value);
206 terms.term[2]+=kernel_value;
208 terms.diag[2]+=kernel_value;
// Combines the accumulated sums into a single value:
// term[0] + term[1] - 2*term[2], after normalising each term.
//
// terms: accumulator produced by add_term_lower()/add_term_upper() or filled
//        from matrix blocks; it is modified in place, so it cannot be reused
//        after this call.
//
// NOTE(review): this listing shows two different normalisations of term[0] and
// term[1] back to back, and two of term[2]. The embedded numbering skips lines
// 219-220, 223-225, 237-238 and 242-243, so these are presumably alternative
// branches of conditionals that the listing lost — confirm against the
// original file. The statements after the `result` computation (including the
// return) are not visible either.
212 inline float64_t compute(terms_t& terms)
const 214 ASSERT(m_n_x>0 && m_n_y>0);
// Remove the diagonal contribution and double what remains.
215 terms.term[0]=2*(terms.term[0]-terms.diag[0]);
216 terms.term[1]=2*(terms.term[1]-terms.diag[1]);
217 SG_SDEBUG(
"term_0 sum (without diagonal) = %f!\n", terms.term[0]);
218 SG_SDEBUG(
"term_1 sum (without diagonal) = %f!\n", terms.term[1]);
// Normalisation by n*(n-1).
// NOTE(review): the divisor is zero when the count is 1 (the assertion above
// only requires > 0), and the product is formed in the counts' own type — if
// that is a 32-bit integer it can overflow for large counts; verify.
221 terms.term[0]/=m_n_x*(m_n_x-1);
222 terms.term[1]/=m_n_y*(m_n_y-1);
// Normalisation with the diagonal added back, dividing by n*n.
226 terms.term[0]+=terms.diag[0];
227 terms.term[1]+=terms.diag[1];
228 SG_SDEBUG(
"term_0 sum (with diagonal) = %f!\n", terms.term[0]);
229 SG_SDEBUG(
"term_1 sum (with diagonal) = %f!\n", terms.term[1]);
230 terms.term[0]/=m_n_x*m_n_x;
231 terms.term[1]/=m_n_y*m_n_y;
233 SG_SDEBUG(
"term_0 (normalized) = %f!\n", terms.term[0]);
234 SG_SDEBUG(
"term_1 (normalized) = %f!\n", terms.term[1]);
236 SG_SDEBUG(
"term_2 sum (with diagonal) = %f!\n", terms.term[2]);
// Cross term with its diagonal removed, divided by m_n_x*(m_n_x-1).
239 terms.term[2]-=terms.diag[2];
240 SG_SDEBUG(
"term_2 sum (without diagonal) = %f!\n", terms.term[2]);
241 terms.term[2]/=m_n_x*(m_n_x-1);
// Cross term divided by m_n_x*m_n_y.
244 terms.term[2]/=m_n_x*m_n_y;
245 SG_SDEBUG(
"term_2 (normalized) = %f!\n", terms.term[2]);
// Final combination of the three normalised terms.
247 auto result=terms.term[0]+terms.term[1]-2*terms.term[2];
256 #endif // DOXYGEN_SHOULD_SKIP_THIS 262 #endif // COMPUTE_MMD_H_
std::array< float64_t, 3 > term
All classes and functions are contained in the shogun namespace.
std::array< float64_t, 3 > diag