45 using namespace internal;
47 using std::unique_ptr;
64 m_pairwise_distance(nullptr), m_dtype(
D_UNKNOWN)
84 CMultiKernelQuadraticTimeMMD::CMultiKernelQuadraticTimeMMD() :
CSGObject()
86 self=unique_ptr<Self>(
new Self(
nullptr));
91 self=unique_ptr<Self>(
new Self(owner));
102 REQUIRE(kernel,
"Kernel instance cannot be NULL!\n");
103 self->m_kernel_mgr.push_back(kernel);
108 self->m_kernel_mgr.clear();
109 invalidate_precomputed_distance();
112 void CMultiKernelQuadraticTimeMMD::invalidate_precomputed_distance()
114 self->m_pairwise_distance=
nullptr;
121 return statistic(self->m_kernel_mgr);
134 return variance_h1(self->m_kernel_mgr);
140 return test_power(self->m_kernel_mgr);
152 return p_values(self->m_kernel_mgr);
159 for (
auto i=0; i<pvalues.
size(); ++i)
161 rejections[i]=pvalues[i]<alpha;
166 SGVector<float64_t> CMultiKernelQuadraticTimeMMD::statistic(
const KernelManager& kernel_mgr)
169 REQUIRE(kernel_mgr.num_kernels()>0,
"Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels());
171 const auto nx=
self->m_owner->get_num_samples_p();
172 const auto ny=
self->m_owner->get_num_samples_q();
173 const auto stype =
self->m_owner->get_statistic_type();
176 self->update_pairwise_distance(distance);
177 kernel_mgr.set_precomputed_distance(self->m_pairwise_distance.get());
180 self->statistic_job.m_n_x=nx;
181 self->statistic_job.m_n_y=ny;
182 self->statistic_job.m_stype=stype;
185 kernel_mgr.unset_precomputed_distance();
187 for (
auto i=0; i<result.
vlen; ++i)
188 result[i]=self->m_owner->normalize_statistic(result[i]);
194 SGVector<float64_t> CMultiKernelQuadraticTimeMMD::variance_h1(
const KernelManager& kernel_mgr)
197 REQUIRE(kernel_mgr.num_kernels()>0,
"Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels());
199 const auto nx=
self->m_owner->get_num_samples_p();
200 const auto ny=
self->m_owner->get_num_samples_q();
203 self->update_pairwise_distance(distance);
204 kernel_mgr.set_precomputed_distance(self->m_pairwise_distance.get());
207 self->variance_h1_job.m_n_x=nx;
208 self->variance_h1_job.m_n_y=ny;
211 kernel_mgr.unset_precomputed_distance();
217 SGVector<float64_t> CMultiKernelQuadraticTimeMMD::test_power(
const KernelManager& kernel_mgr)
220 REQUIRE(kernel_mgr.num_kernels()>0,
"Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels());
223 const auto nx=
self->m_owner->get_num_samples_p();
224 const auto ny=
self->m_owner->get_num_samples_q();
227 self->update_pairwise_distance(distance);
228 kernel_mgr.set_precomputed_distance(self->m_pairwise_distance.get());
231 self->variance_h1_job.m_n_x=nx;
232 self->variance_h1_job.m_n_y=ny;
235 kernel_mgr.unset_precomputed_distance();
245 "Multi-kernel tests requires the H0 approximation method to be PERMUTATION!\n");
247 REQUIRE(kernel_mgr.num_kernels()>0,
"Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels());
249 const auto nx=
self->m_owner->get_num_samples_p();
250 const auto ny=
self->m_owner->get_num_samples_q();
251 const auto stype =
self->m_owner->get_statistic_type();
252 const auto num_null_samples =
self->m_owner->get_num_null_samples();
255 self->update_pairwise_distance(distance);
256 kernel_mgr.set_precomputed_distance(self->m_pairwise_distance.get());
259 self->permutation_job.m_n_x=nx;
260 self->permutation_job.m_n_y=ny;
261 self->permutation_job.m_num_null_samples=num_null_samples;
262 self->permutation_job.m_stype=stype;
265 kernel_mgr.unset_precomputed_distance();
267 for (
size_t i=0; i<result.
size(); ++i)
268 result.
matrix[i]=self->m_owner->normalize_statistic(result.
matrix[i]);
274 SGVector<float64_t> CMultiKernelQuadraticTimeMMD::p_values(
const KernelManager& kernel_mgr)
278 "Multi-kernel tests requires the H0 approximation method to be PERMUTATION!\n");
280 REQUIRE(kernel_mgr.num_kernels()>0,
"Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels());
282 const auto nx=
self->m_owner->get_num_samples_p();
283 const auto ny=
self->m_owner->get_num_samples_q();
284 const auto stype =
self->m_owner->get_statistic_type();
285 const auto num_null_samples =
self->m_owner->get_num_null_samples();
288 self->update_pairwise_distance(distance);
289 kernel_mgr.set_precomputed_distance(self->m_pairwise_distance.get());
292 self->permutation_job.m_n_x=nx;
293 self->permutation_job.m_n_y=ny;
294 self->permutation_job.m_num_null_samples=num_null_samples;
295 self->permutation_job.m_stype=stype;
298 kernel_mgr.unset_precomputed_distance();
306 return "MultiKernelQuadraticTimeMMD";
virtual const char * get_name() const =0
float distance(CJLCoverTreePoint p1, CJLCoverTreePoint p2, float64_t upper_bound)
CQuadraticTimeMMD * m_owner
Base class for the family of kernel functions that only depend on the difference of the inputs...
Class Distance, a base class for all the distances used in the Shogun toolbox.
virtual EDistanceType get_distance_type()=0
#define SG_NOTIMPLEMENTED
unique_ptr< CCustomDistance > m_pairwise_distance
Self(CQuadraticTimeMMD *owner)
VarianceH1 variance_h1_job
This class implements the quadratic time Maximum Mean Discrepancy (MMD) statistic as described in [1]. The MMD is the distance between two probability distributions $p$ and $q$ in an RKHS $\mathcal{F}$, which we denote by $\mathrm{MMD}[\mathcal{F},p,q]$.
void update_pairwise_distance(CDistance *distance)
Class SGObject is the base class of all shogun objects.
void add_kernel(CShiftInvariantKernel *kernel)
CCustomDistance * compute_joint_distance(CDistance *distance)
KernelManager m_kernel_mgr
SGVector< float64_t > compute_statistic()
SGVector< float64_t > compute_variance_h0()
SGVector< bool > perform_test(float64_t alpha)
CMultiKernelQuadraticTimeMMD()
SGMatrix< float32_t > sample_null()
virtual const char * get_name() const
SGVector< float64_t > compute_p_value()
All classes and functions are contained in the shogun namespace.
virtual ~CMultiKernelQuadraticTimeMMD()
SGVector< float64_t > compute_variance_h1()
PermutationMMD permutation_job
SGVector< float64_t > compute_test_power()