34 #include <type_traits> 54 using namespace internal;
// Job-setup member functions (declarations only; bodies are not visible in
// this view). NOTE(review): presumably these populate the std::function
// members statistic_job / variance_job declared further down — confirm
// against the definitions.
60 void create_statistic_job();
61 void create_variance_job();
62 void create_computation_jobs();
// Const helper taking the current burst (NextSamples&) and a vector of
// CFeatures* by reference. NOTE(review): the vector looks like an output
// parameter that receives one merged feature block per block of the burst
// (see the `blocks[i]=block_p_and_q;` loop in this file) — confirm.
64 void merge_samples(
NextSamples&, std::vector<CFeatures*>&)
const;
// Const helper taking a ComputationManager, a vector of feature blocks and a
// kernel. NOTE(review): elsewhere in this file a per-block clone of the kernel
// is initialised on (blocks[i], blocks[i]) and its float32 kernel matrix is
// stored in cm.data(i); that is presumably this function's body — confirm.
65 void compute_kernel(ComputationManager&, std::vector<CFeatures*>&,
CKernel*)
const;
// Const helper that runs the jobs registered on the given ComputationManager.
// NOTE(review): the calls cm.use_gpu().compute_data_parallel_jobs() and
// cm.use_cpu().compute_data_parallel_jobs() seen in this file presumably form
// its body, selected by the use_gpu flag — confirm.
66 void compute_jobs(ComputationManager&)
const;
// Returns a (statistic, variance) pair; the definition in this file ends with
// `return std::make_pair(statistic, variance);`, and callers read `.first` and
// `.second` from the result.
68 std::pair<float64_t, float64_t> compute_statistic_variance();
// Returns a pair of (SGVector<float64_t>, SGMatrix<float64_t>) computed for
// the kernels held by the KernelManager argument.
// NOTE(review): the body visible in this file requires num_kernels()>0 and
// ends with `return std::make_pair(statistic, Q);` — presumably one statistic
// entry per kernel and a num_kernels x num_kernels Q; confirm.
69 std::pair<SGVector<float64_t>,
SGMatrix<float64_t>> compute_statistic_and_Q(
const KernelManager&);
// Type-erased callables, each mapping a float32 matrix to a float32 scalar.
// statistic_job and variance_job are set to nullptr in the constructor
// initializer list visible in this file; permutation_job's initial value is
// not visible here.
81 std::function<float32_t(const SGMatrix<float32_t>&)> statistic_job;
82 std::function<float32_t(const SGMatrix<float32_t>&)> permutation_job;
83 std::function<float32_t(const SGMatrix<float32_t>&)> variance_job;
87 use_gpu(false), num_null_samples(250),
91 statistic_job(nullptr), variance_job(nullptr)
108 REQUIRE(Bx>0,
"Blocksize for samples from P cannot be 0!\n");
109 REQUIRE(By>0,
"Blocksize for samples from Q cannot be 0!\n");
111 auto mmd=mmd::ComputeMMD();
137 #pragma omp parallel for 138 for (int64_t i=0; i<(int64_t)blocks.size(); ++i)
143 blocks[i]=block_p_and_q;
151 cm.num_data(blocks.size());
152 #pragma omp parallel for 153 for (int64_t i=0; i<(int64_t)blocks.size(); ++i)
157 auto kernel_clone=std::unique_ptr<CKernel>(
static_cast<CKernel*
>(kernel->
clone()));
158 kernel_clone->init(blocks[i], blocks[i]);
159 cm.data(i)=kernel_clone->get_kernel_matrix<
float32_t>();
160 kernel_clone->remove_lhs_and_rhs();
172 cm.use_gpu().compute_data_parallel_jobs();
174 cm.use_cpu().compute_data_parallel_jobs();
180 auto kernel=kernel_mgr.kernel_at(0);
181 REQUIRE(kernel !=
nullptr,
"Kernel is not set!\n");
186 index_t statistic_term_counter=1;
187 index_t variance_term_counter=1;
191 auto next_burst=data_mgr.next();
192 if (!next_burst.empty())
194 ComputationManager cm;
199 std::vector<CFeatures*> blocks;
201 while (!next_burst.empty())
208 auto mmds=cm.result(0);
209 auto vars=cm.result(1);
211 for (
size_t i=0; i<mmds.size(); ++i)
213 auto delta=mmds[i]-statistic;
214 statistic+=delta/statistic_term_counter;
215 statistic_term_counter++;
220 for (
size_t i=0; i<mmds.size(); ++i)
222 auto delta=vars[i]-variance;
223 variance+=delta/variance_term_counter;
224 variance_term_counter++;
229 for (
size_t i=0; i<vars.size(); ++i)
231 auto delta=vars[i]-permuted_samples_statistic;
232 permuted_samples_statistic+=delta/variance_term_counter;
233 variance+=delta*(vars[i]-permuted_samples_statistic);
234 variance_term_counter++;
237 next_burst=data_mgr.next();
248 return std::make_pair(statistic, variance);
257 REQUIRE(kernel_selection_mgr.num_kernels()>0,
"No kernels specified for kernel learning! " 258 "Please add kernels using add_kernel() method!\n");
260 const auto num_kernels=kernel_selection_mgr.num_kernels();
264 std::fill(statistic.
data(), statistic.
data()+statistic.
size(), 0);
268 term_counters_statistic.set_const(1);
270 std::fill(term_counters_Q.
data(), term_counters_Q.
data()+term_counters_Q.
size(), 1);
273 ComputationManager cm;
278 auto next_burst=data_mgr.next();
279 std::vector<CFeatures*> blocks;
280 std::vector<std::vector<float32_t> > mmds(num_kernels);
281 while (!next_burst.empty())
283 const auto num_blocks=next_burst.num_blocks();
285 "The number of blocks per burst (%d this burst) has to be even!\n",
288 std::for_each(blocks.begin(), blocks.end(), [](
CFeatures* ptr) {
SG_REF(ptr); });
289 for (
auto k=0; k<num_kernels; ++k)
291 CKernel* kernel=kernel_selection_mgr.kernel_at(k);
294 mmds[k]=cm.result(0);
295 for (
auto i=0; i<num_blocks; ++i)
297 auto delta=mmds[k][i]-statistic[k];
298 statistic[k]+=delta/term_counters_statistic[k]++;
301 std::for_each(blocks.begin(), blocks.end(), [](
CFeatures* ptr) {
SG_UNREF(ptr); });
303 for (
auto i=0; i<num_kernels; ++i)
305 for (
auto j=0; j<=i; ++j)
307 for (
auto k=0; k<num_blocks-1; k+=2)
309 auto term=(mmds[i][k]-mmds[i][k+1])*(mmds[j][k]-mmds[j][k+1]);
310 Q(i, j)+=(term-Q(i, j))/term_counters_Q(i, j)++;
315 next_burst=data_mgr.next();
326 return std::make_pair(statistic, Q);
332 auto kernel=kernel_mgr.kernel_at(0);
333 REQUIRE(kernel !=
nullptr,
"Kernel is not set!\n");
339 std::fill(term_counters.
data(), term_counters.
data()+term_counters.
size(), 1);
342 ComputationManager cm;
347 std::vector<CFeatures*> blocks;
350 auto next_burst=data_mgr.next();
352 while (!next_burst.empty())
361 auto mmds=cm.result(0);
362 for (
size_t i=0; i<mmds.size(); ++i)
364 auto delta=mmds[i]-statistic[j];
365 statistic[j]+=delta/term_counters[j];
369 next_burst=data_mgr.next();
378 value=owner.normalize_statistic(value);
384 CStreamingMMD::CStreamingMMD() :
CMMD()
386 #if EIGEN_VERSION_AT_LEAST(3,1,0) 387 Eigen::initParallel();
389 self=std::unique_ptr<Self>(
new Self(*
this));
398 return self->compute_statistic_variance().first;
403 return self->compute_statistic_variance().second;
// CStreamingMMD-level wrapper: forwards to the implementation object `self`
// and returns its (statistic, variance) pair unchanged.
411 std::pair<float64_t, float64_t> CStreamingMMD::compute_statistic_variance()
413 return self->compute_statistic_variance();
// CStreamingMMD-level wrapper: forwards kernel_selection_mgr to the
// implementation object `self` and returns its (vector, matrix) pair unchanged.
416 std::pair<SGVector<float64_t>,
SGMatrix<float64_t> > CStreamingMMD::compute_statistic_and_Q(
const KernelManager& kernel_selection_mgr)
418 return self->compute_statistic_and_Q(kernel_selection_mgr);
423 return self->sample_null();
428 self->num_null_samples=null_samples;
433 return self->num_null_samples;
443 return self->use_gpu;
454 self->statistic_type=stype;
459 return self->statistic_type;
469 self->variance_estimation_method=vmethod;
474 return self->variance_estimation_method;
489 self->null_approximation_method=nmethod;
494 return self->null_approximation_method;
499 return "StreamingMMD";
void set_statistic_type(EStatisticType stype)
void set_variance_estimation_method(EVarianceEstimationMethod vmethod)
virtual SGVector< float64_t > compute_multiple()
void compute_jobs(ComputationManager &) const
virtual CSGObject * clone()
Class ShogunException defines an exception which is thrown whenever an error inside of shogun occurs...
virtual float64_t compute_variance()
const EVarianceEstimationMethod get_variance_estimation_method() const
void create_variance_job()
std::function< float32_t(const SGMatrix< float32_t > &)> statistic_job
virtual SGVector< float64_t > sample_null()
SGVector< float64_t > sample_null()
void merge_samples(NextSamples &, std::vector< CFeatures *> &) const
std::function< float32_t(const SGMatrix< float32_t > &)> variance_job
virtual CFeatures * create_merged_copy(CList *others)
ENullApproximationMethod null_approximation_method
virtual const operation get_direct_estimation_method() const =0
std::function< float32_t(const SGMatrix< float32_t > &)> permutation_job
const index_t blocksize_at(index_t i) const
void compute_kernel(ComputationManager &, std::vector< CFeatures *> &, CKernel *) const
CKernelSelectionStrategy const * get_kernel_selection_strategy() const
void set_null_approximation_method(ENullApproximationMethod nmethod)
virtual const float64_t normalize_variance(float64_t variance) const =0
Self(CStreamingMMD &cmmd)
internal::DataManager & get_data_mgr()
std::pair< SGVector< float64_t >, SGMatrix< float64_t > > compute_statistic_and_Q(const KernelManager &)
void create_statistic_job()
const index_t get_num_null_samples() const
const EStatisticType get_statistic_type() const
EVarianceEstimationMethod
internal::KernelManager & get_kernel_mgr()
std::pair< float64_t, float64_t > compute_statistic_variance()
const char * get_exception_string()
All classes and functions are contained in the shogun namespace.
virtual EKernelType get_kernel_type()=0
The class Features is the base class of all feature objects.
virtual float64_t normalize_statistic(float64_t statistic) const =0
void create_computation_jobs()
const ENullApproximationMethod get_null_approximation_method() const
Class DataManager for fetching/streaming test data block-wise. It can handle data coming from multiple...
Abstract base class that provides an interface for performing a kernel two-sample test using Maximum Mean Discrepancy (MMD)...
void set_num_null_samples(index_t null_samples)
EVarianceEstimationMethod variance_estimation_method
EStatisticType statistic_type
Class NextSamples is the return type for the next() call in DataManager. If there are no more samples (fr...
const index_t num_blocks() const
virtual float64_t compute_statistic()
virtual const char * get_name() const