35 CMulticlassOCAS::CMulticlassOCAS() :
38 register_parameters();
41 set_max_iter(1000000);
49 register_parameters();
51 set_max_iter(1000000);
56 void CMulticlassOCAS::register_parameters()
65 CMulticlassOCAS::~CMulticlassOCAS()
69 bool CMulticlassOCAS::train_machine(
CFeatures* data)
76 ASSERT(m_multiclass_strategy)
79 int32_t num_classes = m_multiclass_strategy->get_num_classes();
80 int32_t num_features =
m_features->get_dim_feature_space();
84 uint32_t nY = num_classes;
85 uint32_t nData = num_vectors;
90 uint32_t BufSize = m_buf_size;
91 uint8_t Method = m_method;
95 user_data.W = SG_CALLOC(
float64_t, (int64_t)num_features*num_classes);
96 user_data.oldW = SG_CALLOC(
float64_t, (int64_t)num_features*num_classes);
97 user_data.new_a = SG_CALLOC(
float64_t, (int64_t)num_features*num_classes);
98 user_data.full_A = SG_CALLOC(
float64_t, (int64_t)num_features*num_classes*m_buf_size);
99 user_data.output_values = SG_CALLOC(
float64_t, num_vectors);
100 user_data.data_y = labels.
vector;
101 user_data.nY = num_classes;
102 user_data.nDim = num_features;
103 user_data.nData = num_vectors;
105 ocas_return_value_T value =
106 msvm_ocas_solver(C, labels.
vector, nY, nData, TolRel, TolAbs,
107 QPBound, MaxTime, BufSize, Method,
108 &CMulticlassOCAS::msvm_full_compute_W,
109 &CMulticlassOCAS::msvm_update_W,
110 &CMulticlassOCAS::msvm_full_add_new_cut,
111 &CMulticlassOCAS::msvm_full_compute_output,
112 &CMulticlassOCAS::msvm_sort_data,
113 &CMulticlassOCAS::msvm_print,
116 SG_DEBUG(
"Number of iterations [nIter] = %d \n",value.nIter)
117 SG_DEBUG("Number of cutting planes [nCutPlanes] = %d \n",value.nCutPlanes)
118 SG_DEBUG("Number of non-zero alphas [nNZAlpha] = %d \n",value.nNZAlpha)
119 SG_DEBUG("Number of training errors [trn_err] = %d \n",value.trn_err)
120 SG_DEBUG("Primal objective value [Q_P] = %f \n",value.Q_P)
121 SG_DEBUG("Dual objective value [Q_D] = %f \n",value.Q_D)
122 SG_DEBUG("Output time [output_time] = %f \n",value.output_time)
123 SG_DEBUG("Sort time [sort_time] = %f \n",value.sort_time)
124 SG_DEBUG("Add time [add_time] = %f \n",value.add_time)
125 SG_DEBUG("W time [w_time] = %f \n",value.w_time)
126 SG_DEBUG("QP solver time [qp_solver_time] = %f \n",value.qp_solver_time)
127 SG_DEBUG("OCAS time [ocas_time] = %f \n",value.ocas_time)
128 SG_DEBUG("Print time [print_time] = %f \n",value.print_time)
129 SG_DEBUG("QP exit flag [qp_exitflag] = %d \n",value.qp_exitflag)
130 SG_DEBUG("Exit flag [exitflag] = %d \n",value.exitflag)
132 m_machines->reset_array();
133 for (int32_t i=0; i<num_classes; i++)
138 m_machines->push_back(machine);
141 SG_FREE(user_data.W);
142 SG_FREE(user_data.oldW);
143 SG_FREE(user_data.new_a);
144 SG_FREE(user_data.full_A);
145 SG_FREE(user_data.output_values);
152 float64_t* W = ((mocas_data*)user_data)->W;
153 float64_t* oldW = ((mocas_data*)user_data)->oldW;
154 uint32_t nY = ((mocas_data*)user_data)->nY;
155 uint32_t nDim = ((mocas_data*)user_data)->nDim;
157 for(uint32_t j=0; j < nY*nDim; j++)
158 W[j] = oldW[j]*(1-t) + t*W[j];
166 float64_t *alpha, uint32_t nSel,
void* user_data)
168 float64_t* W = ((mocas_data*)user_data)->W;
169 float64_t* oldW = ((mocas_data*)user_data)->oldW;
170 float64_t* full_A = ((mocas_data*)user_data)->full_A;
171 uint32_t nY = ((mocas_data*)user_data)->nY;
172 uint32_t nDim = ((mocas_data*)user_data)->nDim;
176 sg_memcpy(oldW, W,
sizeof(
float64_t)*nDim*nY);
179 for(i=0; i<nSel; i++)
183 for(j=0; j<nDim*nY; j++)
184 W[j] += alpha[i]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
194 int CMulticlassOCAS::msvm_full_add_new_cut(
float64_t *new_col_H, uint32_t *new_cut,
195 uint32_t nSel,
void* user_data)
197 float64_t* full_A = ((mocas_data*)user_data)->full_A;
198 float64_t* new_a = ((mocas_data*)user_data)->new_a;
199 float64_t* data_y = ((mocas_data*)user_data)->data_y;
200 uint32_t nY = ((mocas_data*)user_data)->nY;
201 uint32_t nDim = ((mocas_data*)user_data)->nDim;
202 uint32_t nData = ((mocas_data*)user_data)->nData;
203 CDotFeatures* features = ((mocas_data*)user_data)->features;
206 uint32_t i, j, y, y2;
208 memset(new_a, 0,
sizeof(
float64_t)*nDim*nY);
210 for(i=0; i < nData; i++)
212 y = (uint32_t)(data_y[i]);
213 y2 = (uint32_t)new_cut[i];
223 for(j=0; j < nDim*nY; j++ )
224 full_A[LIBOCAS_INDEX(j,nSel,nDim*nY)] = new_a[j];
226 new_col_H[nSel] = sq_norm_a;
227 for(i=0; i < nSel; i++)
231 for(j=0; j < nDim*nY; j++ )
232 tmp += new_a[j]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)];
240 int CMulticlassOCAS::msvm_full_compute_output(
float64_t *output,
void* user_data)
242 float64_t* W = ((mocas_data*)user_data)->W;
243 uint32_t nY = ((mocas_data*)user_data)->nY;
244 uint32_t nDim = ((mocas_data*)user_data)->nDim;
245 uint32_t nData = ((mocas_data*)user_data)->nData;
246 float64_t* output_values = ((mocas_data*)user_data)->output_values;
247 CDotFeatures* features = ((mocas_data*)user_data)->features;
253 features->
dense_dot_range(output_values,0,nData,NULL,&W[nDim*y],nDim,0.0);
254 for (i=0; i<nData; i++)
255 output[LIBOCAS_INDEX(y,i,nY)] = output_values[i];
267 void CMulticlassOCAS::msvm_print(ocas_return_value_T value)
271 #endif //USE_GPL_SHOGUN
virtual void dense_dot_range(float64_t *output, int32_t start, int32_t stop, float64_t *alphas, float64_t *vec, int32_t dim, float64_t b)
virtual void set_w(const SGVector< float64_t > src_w)
The class Labels models labels, i.e. class assignments of objects.
static void qsort_index(T1 *output, T2 *index, uint32_t size)
virtual int32_t get_num_vectors() const =0
virtual void add_to_dense_vec(float64_t alpha, int32_t vec_idx1, float64_t *vec2, int32_t vec2_len, bool abs_val=false)=0
Features that support dot products among other operations.
Multiclass Labels for multi-class classification.
void set_features(CFeatures *feats)
generic linear multiclass machine
Class LinearMachine is a generic interface for all kinds of linear machines like classifiers.
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
Compute dot product between v1 and v2 (blas optimized)
CStructuredLabels * m_labels
all of classes and functions are contained in the shogun namespace
The class Features is the base class of all feature objects.
void set_epsilon(float *begin, float max)
multiclass one vs rest strategy used to train generic multiclass machines for K-class problems with b...
SGVector< T > clone() const