41 using namespace internal;
45 SG_SDEBUG(
"Data manager instance initialized with %d data sources!\n", num_distributions);
46 fetchers.resize(num_distributions);
47 std::fill(fetchers.begin(), fetchers.end(),
nullptr);
49 train_test_mode=default_train_test_mode;
50 train_mode=default_train_mode;
51 train_test_ratio=default_train_test_ratio;
52 cross_validation_mode=default_cross_validation_mode;
63 typedef const std::unique_ptr<DataFetcher> fetcher_type;
64 if (std::any_of(fetchers.begin(), fetchers.end(), [](fetcher_type& f) {
return f->m_num_samples==0; }))
65 SG_SERROR(
"number of samples from all the distributions are not set!")
67 std::for_each(fetchers.begin(), fetchers.end(), [&n](fetcher_type& f) { n+=f->m_num_samples; });
76 typedef const std::unique_ptr<DataFetcher> fetcher_type;
77 if (std::any_of(fetchers.begin(), fetchers.end(), [](fetcher_type& f) {
return f->m_num_samples==0; }))
78 SG_SERROR(
"number of samples from all the distributions are not set!")
82 for (
size_t i=0; i<fetchers.size(); ++i)
83 divisor=
CMath::gcd(divisor, fetchers[i]->m_num_samples);
86 SG_SDEBUG(
"min blocksize is %d!", min_blocksize);
97 "Total number of samples is 0! Please set the number of samples!\n");
98 REQUIRE(blocksize>0 && blocksize<=n,
99 "The blocksize has to be within [0, %d], given = %d!\n",
102 "Total number of samples (%d) has to be divisble by the blocksize (%d)!\n",
105 for (
size_t i=0; i<fetchers.size(); ++i)
107 index_t m=fetchers[i]->m_num_samples;
109 "Blocksize (%d) cannot be even distributed with a ratio of %f!\n",
111 fetchers[i]->fetch_blockwise().with_blocksize(blocksize*m/n);
112 SG_SDEBUG(
"block[%d].size = ", i, blocksize*m/n);
120 REQUIRE(num_blocks_per_burst>0,
121 "Number of blocks per burst (%d) has to be greater than 0!\n",
122 num_blocks_per_burst);
125 typedef std::unique_ptr<DataFetcher> fetcher_type;
126 std::for_each(fetchers.begin(), fetchers.end(), [&blocksize](fetcher_type& f)
128 blocksize+=f->m_block_details.m_blocksize;
131 "Blocksizes are not set!\n");
134 if (num_blocks_per_burst>max_num_blocks_per_burst)
136 SG_SINFO(
"There can only be %d blocks per burst given the blocksize (%d)!\n", max_num_blocks_per_burst, blocksize);
137 SG_SINFO(
"Setting num blocks per burst to be %d instead!\n", max_num_blocks_per_burst);
138 num_blocks_per_burst=max_num_blocks_per_burst;
141 for (
size_t i=0; i<fetchers.size(); ++i)
142 fetchers[i]->fetch_blockwise().with_num_blocks_per_burst(num_blocks_per_burst);
149 REQUIRE(i<(int64_t)fetchers.size(),
150 "Value of i (%d) should be between 0 and %d, inclusive!",
151 i, fetchers.size()-1);
153 return InitPerFeature(fetchers[i]);
159 REQUIRE(i<(int64_t)fetchers.size(),
160 "Value of i (%d) should be between 0 and %d, inclusive!",
161 i, fetchers.size()-1);
163 if (fetchers[i]!=
nullptr)
164 return fetchers[i]->m_samples;
172 REQUIRE(i<(int64_t)fetchers.size(),
173 "Value of i (%d) should be between 0 and %d, inclusive!",
174 i, fetchers.size()-1);
176 return fetchers[i]->m_num_samples;
182 REQUIRE(i<(int64_t)fetchers.size(),
183 "Value of i (%d) should be between 0 and %d, inclusive!",
184 i, fetchers.size()-1);
186 if (fetchers[i]!=
nullptr)
187 return fetchers[i]->get_num_samples();
195 REQUIRE(i<(int64_t)fetchers.size(),
196 "Value of i (%d) should be between 0 and %d, inclusive!",
197 i, fetchers.size()-1);
199 if (fetchers[i]!=
nullptr)
200 return fetchers[i]->m_block_details.m_blocksize;
/**
 * Forwards the blockwise flag to every fetcher by calling
 * set_blockwise(blockwise) on each entry of `fetchers`.
 *
 * NOTE(review): the entries are dereferenced without a null check here.
 *
 * @param blockwise flag handed unchanged to each fetcher
 */
205 void DataManager::set_blockwise(
bool blockwise)
208 for (
size_t i=0; i<fetchers.size(); ++i)
209 fetchers[i]->set_blockwise(blockwise);
/**
 * Accumulates, over all fetchers, the negation of
 * `m_block_details.m_full_data` into `blockwise` with a bitwise AND, so a
 * single fetcher that uses full data clears the flag.
 *
 * NOTE(review): the initialisation of `blockwise` and the return statement
 * are not visible in this listing - presumably it starts as true and is
 * returned; confirm against the full source.
 */
213 const bool DataManager::is_blockwise()
const 217 for (
size_t i=0; i<fetchers.size(); ++i)
218 blockwise&=!fetchers[i]->m_block_details.m_full_data;
/**
 * Switches train/test mode and propagates the related settings to every
 * fetcher.
 *
 * When `train_test_mode` is false, `train_mode`, `train_test_ratio` and
 * `cross_validation_mode` are put back to their default values. At least one
 * fetcher must exist ("Features are not set!" otherwise). Each fetcher then
 * receives set_train_test_mode(on), and set_train_mode / set_train_test_ratio
 * with the manager's current values.
 *
 * NOTE(review): the statement that stores `on` into `train_test_mode` and the
 * condition guarding the last two per-fetcher calls are not visible in this
 * listing - confirm against the full source.
 *
 * @param on flag passed to each fetcher's set_train_test_mode
 */
223 void DataManager::set_train_test_mode(
bool on)
226 if (!train_test_mode)
228 train_mode=default_train_mode;
229 train_test_ratio=default_train_test_ratio;
230 cross_validation_mode=default_cross_validation_mode;
232 REQUIRE(fetchers.size()>0,
"Features are not set!");
233 typedef std::unique_ptr<DataFetcher> fetcher_type;
234 std::for_each(fetchers.begin(), fetchers.end(), [
this, on](fetcher_type& f)
236 f->set_train_test_mode(on);
239 f->set_train_mode(train_mode);
240 f->set_train_test_ratio(train_test_ratio);
/** @return the current value of the `train_test_mode` flag */
245 bool DataManager::is_train_test_mode()
const 247 return train_test_mode;
250 void DataManager::set_train_mode(
bool on)
256 SG_SERROR(
"Train mode cannot be used without turning on Train/Test mode first!" 257 "Please call set_train_test_mode(True) before using this method!\n");
261 bool DataManager::is_train_mode()
const 266 void DataManager::set_cross_validation_mode(
bool on)
269 cross_validation_mode=on;
272 SG_SERROR(
"Cross-validation mode cannot be used without turning on Train/Test mode first!" 273 "Please call set_train_test_mode(True) before using this method!\n");
/** @return the current value of the `cross_validation_mode` flag */
277 bool DataManager::is_cross_validation_mode()
const 279 return cross_validation_mode;
282 void DataManager::set_train_test_ratio(
float64_t ratio)
285 train_test_ratio=ratio;
288 SG_SERROR(
"Train-test ratio cannot be set without turning on Train/Test mode first!" 289 "Please call set_train_test_mode(True) before using this method!\n");
/** @return the currently stored `train_test_ratio` */
293 float64_t DataManager::get_train_test_ratio()
const 295 return train_test_ratio;
/**
 * Derives the number of folds from the train/test ratio.
 *
 * @return the ratio from get_train_test_ratio(), rounded up, plus one; the
 * floating-point result is converted implicitly to index_t
 */
298 index_t DataManager::get_num_folds()
const 300 return ceil(get_train_test_ratio())+1;
/**
 * Calls shuffle_features() on every fetcher.
 *
 * Fails with "Features are not set!" when `fetchers` is empty.
 */
303 void DataManager::shuffle_features()
306 REQUIRE(fetchers.size()>0,
"Features are not set!");
307 typedef std::unique_ptr<DataFetcher> fetcher_type;
308 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->shuffle_features(); });
/**
 * Calls unshuffle_features() on every fetcher.
 *
 * Fails with "Features are not set!" when `fetchers` is empty.
 */
312 void DataManager::unshuffle_features()
315 REQUIRE(fetchers.size()>0,
"Features are not set!");
316 typedef std::unique_ptr<DataFetcher> fetcher_type;
317 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->unshuffle_features(); });
321 void DataManager::init_active_subset()
326 "Train-test subset cannot be used without turning on Train/Test mode first!" 327 "Please call set_train_test_mode(True) before using this method!\n");
328 REQUIRE(fetchers.size()>0,
"Features are not set!");
330 typedef std::unique_ptr<DataFetcher> fetcher_type;
331 std::for_each(fetchers.begin(), fetchers.end(), [
this](fetcher_type& f)
333 f->set_train_mode(train_mode);
334 f->set_train_test_ratio(train_test_ratio);
335 f->init_active_subset();
340 void DataManager::use_fold(
index_t idx)
345 "Fold subset cannot be used without turning on Train/Test mode first!" 346 "Please call set_train_test_mode(True) before using this method!\n");
347 REQUIRE(fetchers.size()>0,
"Features are not set!");
348 REQUIRE(idx>=0,
"Fold index has to be in [0, %d]!", get_num_folds()-1);
349 REQUIRE(idx<get_num_folds(),
"Fold index has to be in [0, %d]!", get_num_folds()-1);
351 typedef std::unique_ptr<DataFetcher> fetcher_type;
352 std::for_each(fetchers.begin(), fetchers.end(), [
this, idx](fetcher_type& f)
354 f->set_train_mode(train_mode);
355 f->set_train_test_ratio(train_test_ratio);
/**
 * Starts all fetchers.
 *
 * Fails with "Features are not set!" when `fetchers` is empty. If train/test
 * mode is on and cross-validation mode is off, init_active_subset() is run
 * first; afterwards start() is called on every fetcher.
 */
361 void DataManager::start()
364 REQUIRE(fetchers.size()>0,
"Features are not set!");
366 if (train_test_mode && !cross_validation_mode)
367 init_active_subset();
369 typedef std::unique_ptr<DataFetcher> fetcher_type;
370 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->start(); });
382 for (
size_t i=0; i<fetchers.size(); ++i)
384 auto feats=fetchers[i]->next();
387 auto blocksize=fetchers[i]->m_block_details.m_blocksize;
388 auto num_blocks_curr_burst=feats->get_num_vectors()/blocksize;
391 if (next_samples.m_num_blocks==0)
392 next_samples.m_num_blocks=num_blocks_curr_burst;
394 ASSERT(next_samples.m_num_blocks==num_blocks_curr_burst);
/**
 * Calls end() on every fetcher.
 *
 * Fails with "Features are not set!" when `fetchers` is empty.
 */
404 void DataManager::end()
407 REQUIRE(fetchers.size()>0,
"Features are not set!");
408 typedef std::unique_ptr<DataFetcher> fetcher_type;
409 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->end(); });
/**
 * Calls reset() on every fetcher.
 *
 * Fails with "Features are not set!" when `fetchers` is empty.
 */
413 void DataManager::reset()
416 REQUIRE(fetchers.size()>0,
"Features are not set!");
417 typedef std::unique_ptr<DataFetcher> fetcher_type;
418 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->reset(); });
static int32_t gcd(int32_t a, int32_t b)
static std::vector< Block > create_blocks(CFeatures *feats, index_t num_blocks, index_t size)
DataManager(index_t num_distributions)
index_t & num_samples_at(index_t i)
void set_blocksize(index_t blocksize)
InitPerFeature samples_at(index_t i)
const index_t blocksize_at(index_t i) const
index_t get_num_samples() const
void set_num_blocks_per_burst(index_t num_blocks_per_burst)
All classes and functions are contained in the shogun namespace.
The class Features is the base class of all feature objects.
class NextSamples is the return type for next() call in DataManager. If there are no more samples (fr...
index_t get_min_blocksize() const