15 using namespace shogun;
23 SG_ERROR(
"Expected StreamingVwFeatures\n")
24 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
27 vector<int32_t> predicts;
29 m_feats->start_parser();
30 while (m_feats->get_next_example())
32 predicts.push_back(apply_multiclass_example(m_feats->get_example()));
33 m_feats->release_example();
35 m_feats->end_parser();
38 for (
size_t i=0; i < predicts.size(); ++i)
47 compute_conditional_probabilities(ex);
49 for (map<int32_t,bnode_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
51 probs[it->first] = accumulate_conditional_probability(it->second);
58 stack<bnode_t *> nodes;
61 while (!nodes.empty())
67 nodes.push(node->
left());
68 nodes.push(node->
right());
71 node->
data.p_right = train_node(ex, node);
82 if (leaf == par->
left())
83 prob *= (1-par->
data.p_right);
85 prob *= par->
data.p_right;
99 SG_ERROR(
"Expected StreamingVwFeatures\n")
100 set_features(dynamic_cast<CStreamingVwFeatures*>(data));
105 SG_ERROR(
"No data features provided\n")
108 m_machines->reset_array();
114 m_feats->start_parser();
115 for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
117 while (m_feats->get_next_example())
119 train_example(m_feats->get_example());
120 m_feats->release_example();
123 if (ipass < m_num_passes-1)
124 m_feats->reset_stream();
126 m_feats->end_parser();
133 int32_t label =
static_cast<int32_t
>(ex->
ld->
label);
138 m_root->data.label = label;
139 printf(
" insert %d %p\n", label, m_root);
140 m_leaves.insert(make_pair(label,(
bnode_t*) m_root));
141 m_root->machine(create_machine(ex));
145 if (m_leaves.find(label) != m_leaves.end())
147 train_path(ex, m_leaves[label]);
152 while (node->
left() != NULL)
155 bool is_left = which_subtree(node, ex);
160 train_node(ex, node);
165 node = node->
right();
168 printf(
" remove %d %p\n", node->
data.label, m_leaves[node->
data.label]);
169 m_leaves.erase(node->
data.label);
172 left_node->
data.label = node->
data.label;
173 node->
data.label = -1;
178 m_machines->push_back(vw);
179 left_node->
machine(m_machines->get_num_elements()-1);
180 printf(
" insert %d %p\n", left_node->
data.label, left_node);
181 m_leaves.insert(make_pair(left_node->
data.label, left_node));
182 node->
left(left_node);
185 right_node->
data.label = label;
186 right_node->
machine(create_machine(ex));
187 printf(
" insert %d %p\n", label, right_node);
188 m_leaves.insert(make_pair(label, right_node));
189 node->
right(right_node);
196 train_node(ex, node);
201 if (par->
left() == node)
229 m_machines->push_back(vw);
230 return m_machines->get_num_elements()-1;
void parent(CTreeMachineNode *par)
The node of the tree structure forming a TreeMachine The node contains pointer to its parent and poin...
void machine(int32_t idx)
float64_t train_node(VwExample *ex, bnode_t *node)
void train_example(VwExample *ex)
CVwLearner * get_learner()
static int32_t arg_max(T *vec, int32_t inc, int32_t len, T *maxv_ptr=NULL)
return arg_max(vec)
float64_t accumulate_conditional_probability(bnode_t *leaf)
float32_t label
Label value.
Multiclass Labels for multi-class classification.
virtual void set_learner()
void right(CBinaryTreeMachineNode *r)
virtual bool train_machine(CFeatures *data)
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
virtual void train(VwExample *&ex, float32_t update)=0
virtual EFeatureClass get_feature_class() const =0
int32_t create_machine(VwExample *ex)
virtual float32_t predict_and_finalize(VwExample *ex)
void train_path(VwExample *ex, bnode_t *node)
CBinaryTreeMachineNode< VwConditionalProbabilityTreeNodeData > bnode_t
bool set_int_label(int32_t idx, int32_t label)
The class Features is the base class of all feature objects.
VwLabel * ld
Label object.
float32_t eta_round
Learning rate for this round.
Class CVowpalWabbit is the implementation of the online learning algorithm used in Vowpal Wabbit...
virtual int32_t apply_multiclass_example(VwExample *ex)
void left(CBinaryTreeMachineNode *l)
void compute_conditional_probabilities(VwExample *ex)