37 using namespace internal;
39 DataFetcher::DataFetcher() : m_num_samples(0), train_test_mode(false),
40 train_mode(false), m_samples(nullptr), features_shuffled(false)
// Constructs a fetcher around the given samples: `samples` is stored in
// m_samples, and the train/test, train and shuffled flags all start as false.
44 DataFetcher::DataFetcher(
CFeatures* samples) : train_test_mode(false),
45 train_mode(false), m_samples(samples), features_shuffled(false)
// Precondition check: the sample pointer must be non-null, since it is
// dereferenced below.
47 REQUIRE(m_samples!=
nullptr,
"Samples cannot be null!\n");
// NOTE(review): listing line 48 is absent from this extract; confirm against
// the full source whether another statement belongs here.
// Cache the vector count reported by the samples object.
49 m_num_samples=m_samples->get_num_vectors();
// Destructor. NOTE(review): the body (listing lines 53-56) is not present in
// this extract, so its cleanup behaviour cannot be described from here.
52 DataFetcher::~DataFetcher()
// Toggles blockwise fetching. Two groups of statements are visible: one
// restores m_block_details from last_blockwise_details, the other saves the
// current details into last_blockwise_details.
// NOTE(review): the branch condition is missing from this extract; presumably
// the restore path runs when `blockwise` is true - confirm against the source.
57 void DataFetcher::set_blockwise(
bool blockwise)
// Restore path: bring back the previously saved details and clear the
// full-data flag.
61 m_block_details=last_blockwise_details;
62 SG_SDEBUG(
"Restoring the blockwise details!\n");
63 m_block_details.m_full_data=
false;
// Save path: remember the current details.
// NOTE(review): listing line 69 (after the log call) is absent here; verify
// what m_block_details is set to once it has been saved.
67 last_blockwise_details=m_block_details;
68 SG_SDEBUG(
"Saving the blockwise details!\n");
// NOTE(review): only the signature survives in this extract; presumably the
// body assigns `on` to train_test_mode - confirm against the full source.
73 void DataFetcher::set_train_test_mode(
bool on)
78 bool DataFetcher::is_train_test_mode()
const 80 return train_test_mode;
// NOTE(review): only the signature survives in this extract; presumably the
// body assigns `on` to train_mode - confirm against the full source.
83 void DataFetcher::set_train_mode(
bool on)
// Two definitions are fused here by the extraction.
// is_train_mode() const: NOTE(review): its body (listing lines 89-92) is
// absent; presumably it returns train_mode - confirm against the full source.
88 bool DataFetcher::is_train_mode()
const 93 void DataFetcher::set_train_test_ratio(
float64_t ratio)
// set_train_test_ratio: stores `ratio` in train_test_ratio. No validation of
// the value is visible in this extract.
95 train_test_ratio=ratio;
98 float64_t DataFetcher::get_train_test_ratio()
const 100 return train_test_ratio;
// Applies an index subset (shuffle_subset) to the samples and records that
// the features are shuffled. Requires train/test mode to be active.
103 void DataFetcher::shuffle_features()
105 REQUIRE(train_test_mode,
"This method is allowed only when Train/Test method is active!\n");
// Already shuffled: only a warning is emitted.
// NOTE(review): the two adjacent literals below concatenate without a
// separating space ("...has no effect.If you want..."); the message text
// probably wants a trailing space on the first literal.
106 if (features_shuffled)
108 SG_SWARNING(
"Features are already shuffled! Call to shuffle_features() has no effect." 109 "If you want to reshuffle, please call unshuffle_features() first and then call this method!\n");
// NOTE(review): braces/else are missing from this extract; the statements
// below presumably form the not-yet-shuffled branch.
113 const index_t size=m_samples->get_num_vectors();
114 SG_SDEBUG(
"Current number of feature vectors = %d\n", size);
// The index vector is too small for the current number of vectors.
// NOTE(review): listing line 118 (presumably the actual resize) is absent.
115 if (shuffle_subset.size()<size)
117 SG_SDEBUG(
"Resizing the shuffle indices vector (from %d to %d)\n", shuffle_subset.size(), size);
// Fill the index vector with 0, 1, 2, ...
// NOTE(review): listing line 121 is absent; the step that randomises these
// indices is not visible here - confirm against the full source.
120 std::iota(shuffle_subset.data(), shuffle_subset.data()+shuffle_subset.size(), 0);
124 SG_SDEBUG(
"Shuffling %d feature vectors\n", size);
// Push the index vector onto the samples as a subset and remember that fact,
// so that unshuffle_features() knows there is a subset to remove.
125 m_samples->add_subset(shuffle_subset);
127 features_shuffled=
true;
131 void DataFetcher::unshuffle_features()
133 REQUIRE(train_test_mode,
"This method is allowed only when Train/Test method is active!\n");
134 if (features_shuffled)
136 m_samples->remove_subset();
137 features_shuffled=
false;
141 SG_SWARNING(
"Features are NOT shuffled! Call to unshuffle_features() has no effect." 142 "If you want to reshuffle, please call shuffle_features() instead!\n");
// Fills active_subset with the sample indices to use for fold number `idx`.
146 void DataFetcher::use_fold(
index_t idx)
148 allocate_active_subset();
// Width of one fold and the first index belonging to fold `idx`.
149 auto num_samples_per_fold=get_num_samples()/get_num_folds();
150 auto start_idx=idx*num_samples_per_fold;
// NOTE(review): listing lines 151-152 are absent; presumably a condition
// selects between the fill below and the alternative fill at the end.
// Fill with 0, 1, 2, ... and then, from position start_idx onwards, add one
// fold width to every entry, so the resulting indices skip the range
// [start_idx, start_idx+num_samples_per_fold).
153 std::iota(active_subset.data(), active_subset.data()+active_subset.size(), 0);
154 if (start_idx<active_subset.size())
156 std::for_each(active_subset.data()+start_idx, active_subset.data()+active_subset.size(),
157 [&num_samples_per_fold](
index_t& val)
159 val+=num_samples_per_fold;
// Alternative fill: consecutive indices beginning at start_idx.
164 std::iota(active_subset.data(), active_subset.data()+active_subset.size(), start_idx);
// Allocates active_subset and fills it with consecutive indices starting at
// start_index.
168 void DataFetcher::init_active_subset()
170 allocate_active_subset();
// NOTE(review): listing lines 171-172 are absent; the declaration of
// start_index (and any condition guarding this assignment) is not visible.
// Offset equal to the ratio/(ratio+1) share of all vectors.
173 start_index=m_samples->get_num_vectors()*train_test_ratio/(train_test_ratio+1);
174 std::iota(active_subset.data(), active_subset.data()+active_subset.size(), start_index);
// Prepares the fetcher for a pass over the data: checks that there are
// samples, applies active_subset to them, and settles the block size and
// total block count.
178 void DataFetcher::start()
180 REQUIRE(get_num_samples()>0,
"Number of samples is 0!\n");
// NOTE(review): listing lines 181-182 are absent; this call is presumably
// guarded by a condition (e.g. train/test mode) - confirm.
183 m_samples->add_subset(active_subset);
185 SG_SINFO(
"Currently active number of samples is %d\n", get_num_samples());
// When the whole data set is requested, or the configured block size exceeds
// the number of samples, use a single block covering all samples.
188 if (m_block_details.m_full_data || m_block_details.m_blocksize>get_num_samples())
190 SG_SINFO(
"Fetching entire data (%d samples)!\n", get_num_samples());
191 m_block_details.with_blocksize(get_num_samples());
// Integer number of whole blocks that fit into the sample count.
193 m_block_details.m_total_num_blocks=get_num_samples()/m_block_details.m_blocksize;
// NOTE(review): listing lines 194-200 are absent, and `next_samples` and
// `inds` are used below without a visible declaration, so the statements that
// follow most likely belong to a separate routine whose header is missing.
// Samples handed out so far, and how many remain.
201 auto num_already_fetched=m_block_details.m_next_block_index*m_block_details.m_blocksize;
202 auto num_more_samples=get_num_samples()-num_already_fetched;
203 if (num_more_samples>0)
// Cap this burst at the configured maximum per burst.
207 auto num_samples_this_burst=std::min(m_block_details.m_max_num_samples_per_burst, num_more_samples);
208 if (num_samples_this_burst<next_samples->get_num_vectors())
// Index vector of consecutive positions starting at the first unfetched sample.
211 std::iota(inds.vector, inds.vector+inds.vlen, num_already_fetched);
// Advance the block cursor by one burst worth of blocks.
214 m_block_details.m_next_block_index+=m_block_details.m_num_blocks_per_burst;
219 void DataFetcher::reset()
221 m_block_details.m_next_block_index=0;
// Finishes a pass over the data.
// NOTE(review): listing lines 225-229 are absent from this extract, so the
// actual cleanup steps are not visible; only the final log call survives.
224 void DataFetcher::end()
230 SG_SINFO(
"Currently active number of samples is %d\n", get_num_samples());
// Returns the number of currently usable samples. Three return paths are
// visible:
//   - m_num_samples * ratio/(ratio+1)
//   - m_num_samples / (ratio+1)
//   - the vector count reported by m_samples
// NOTE(review): the conditions selecting between them (listing lines 236-238,
// 240, 242) are absent from this extract; presumably they depend on
// train_test_mode and train_mode - confirm against the full source.
234 index_t DataFetcher::get_num_samples()
const 239 return m_num_samples*train_test_ratio/(train_test_ratio+1);
241 return m_num_samples/(train_test_ratio+1);
243 return m_samples->get_num_vectors();
246 index_t DataFetcher::get_num_folds()
const 248 return 1+ceil(get_train_test_ratio());
251 index_t DataFetcher::get_num_training_samples()
const 253 return get_num_samples()*get_train_test_ratio()/(get_train_test_ratio()+1);
256 index_t DataFetcher::get_num_testing_samples()
const 258 return get_num_samples()/(get_train_test_ratio()+1);
// NOTE(review): the header of this routine (listing lines 261-262) is absent
// from this extract. The visible body clears the full-data flag on
// m_block_details and then returns m_block_details to the caller.
263 m_block_details.m_full_data=
false;
264 return m_block_details;
// Works out how many samples the active subset must hold and makes sure
// active_subset has that size. Requires train/test mode to be active.
267 void DataFetcher::allocate_active_subset()
269 REQUIRE(train_test_mode,
"This method is allowed only when Train/Test method is active!\n");
// NOTE(review): the condition choosing between the two computations below is
// absent from this extract; going by the log texts, the first is the training
// case and the second the testing case.
// Training share: ratio/(ratio+1) of all vectors.
273 num_active_samples=m_samples->get_num_vectors()*train_test_ratio/(train_test_ratio+1);
274 SG_SINFO(
"Using %d number of samples for this fold as training samples!\n", num_active_samples);
// Testing share: 1/(ratio+1) of all vectors.
278 num_active_samples=m_samples->get_num_vectors()/(train_test_ratio+1);
279 SG_SINFO(
"Using %d number of samples for this fold as testing samples!\n", num_active_samples);
// An empty active subset is treated as a programming error.
282 ASSERT(num_active_samples>0);
// Size mismatch: the subset has to be resized.
// NOTE(review): the resize statement itself (after listing line 285) is not
// visible in this extract.
283 if (active_subset.size()!=num_active_samples)
285 SG_SDEBUG(
"Resizing the active subset from %d to %d\n", active_subset.size(), num_active_samples);
virtual CFeatures * shallow_subset_copy()
static void permute(SGVector< T > v, CRandom *rand=NULL)
All classes and functions are contained in the shogun namespace.
virtual void remove_subset()
The class Features is the base class of all feature objects.
Class that holds block-details for the data-fetchers. There is one instance of this class per fetche...
virtual void add_subset(SGVector< index_t > subset)