43 using namespace internal;
51 set_p(samples_from_p);
52 set_q(samples_from_q);
62 auto min_blocksize=data_mgr.get_min_blocksize();
66 auto N=data_mgr.num_samples_at(0);
67 for (
auto i=2; i<N; ++i)
76 data_mgr.set_blocksize(min_blocksize);
77 data_mgr.set_num_blocks_per_burst(num_blocks_per_burst);
78 SG_SDEBUG(
"Block contains %d and %d samples, from P and Q respectively!\n", data_mgr.blocksize_at(0), data_mgr.blocksize_at(1));
81 const std::function<float32_t(SGMatrix<float32_t>)> CLinearTimeMMD::get_direct_estimation_method()
const 83 return mmd::WithinBlockDirect();
102 return variance * B * (B - 2) / 16;
104 return variance * Bx * By * (Bx - 1) * (By - 1) / (B - 1) / (B - 2);
115 return variance * 4 / (B - 2);
117 return variance * (B - 1) * (B - 2) / (Bx - 1) / (By - 1) / B;
134 SG_SERROR(
"Null approximation via permutation does not make sense " 135 "for linear time MMD. Use the Gaussian approximation instead.\n");
170 return "LinearTimeMMD";
index_t & num_samples_at(index_t i)
virtual float64_t compute_p_value(float64_t statistic)
void set_num_blocks_per_burst(index_t num_blocks_per_burst)
virtual float64_t compute_variance()
virtual void set_p(CFeatures *samples_from_p)
virtual float64_t compute_threshold(float64_t alpha)
virtual void set_q(CFeatures *samples_from_q)
virtual const char * get_name() const
const index_t blocksize_at(index_t i) const
virtual ~CLinearTimeMMD()
virtual float64_t compute_threshold(float64_t alpha)
static float64_t inverse_normal_cdf(float64_t y0, float64_t mean=0, float64_t std_dev=1)
internal::DataManager & get_data_mgr()
const EStatisticType get_statistic_type() const
all of classes and functions are contained in the shogun namespace
The class Features is the base class of all feature objects.
const ENullApproximationMethod get_null_approximation_method() const
static float64_t normal_cdf(float64_t x, float64_t std_dev=1)
Class DataManager for fetching/streaming test data block-wise. It can handle data coming from multipl...
static float32_t sqrt(float32_t x)
virtual float64_t compute_p_value(float64_t statistic)