SHOGUN  3.2.1
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
KernelMulticlassMachine.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Written (W) 2012 Heiko Strathmann
9  * Copyright (C) 2012 Chiyuan Zhang
10  */
11 
12 #include <shogun/lib/Set.h>
14 
15 using namespace shogun;
16 
18 {
 /* Compact the kernel's lhs features down to the vectors that are support
  * vectors of at least one sub-machine, re-initialise the kernel with them,
  * then remap every sub-machine's SV indices so they refer to positions in
  * the compacted feature object.
  *
  * NOTE(review): this listing is missing source line 17 (the signature)
  * and lines 36 and 61, which presumably declare `machine` for each loop
  * iteration (its uses -- get_num_support_vectors(), get_support_vector(),
  * set_support_vector() -- suggest a CKernelMachine*); confirm against the
  * real source. */
19  CKernel *kernel= m_kernel;
20  if (!kernel)
21  SG_ERROR("%s::store_model_features(): kernel is needed to store SV "
22  "features.\n", get_name());
23 
 /* Judging by the SG_UNREF(lhs)/SG_UNREF(rhs) at the end of this function,
  * get_lhs()/get_rhs() hand back references the caller must release. */
24  CFeatures* lhs=kernel->get_lhs();
25  CFeatures* rhs=kernel->get_rhs();
26  if (!lhs)
27  {
 /* NOTE(review): if SG_ERROR does not return, the reference on rhs taken
  * just above is never released. */
28  SG_ERROR("%s::store_model_features(): kernel lhs is needed to store "
29  "SV features.\n", get_name());
30  }
31 
32  /* this set is used as an index map: every SV index of every sub-machine
 * is stored once, and (as relied on below) an element's position in the
 * set becomes that vector's index in the new, compacted features */
33  CSet<index_t> all_sv;
34  for (index_t i=0; i<m_machines->get_num_elements(); ++i)
35  {
37  for (index_t j=0; j<machine->get_num_support_vectors(); ++j)
38  all_sv.add(machine->get_support_vector(j));
39 
40  SG_UNREF(machine);
41  }
42 
43  /* convert map to vector of SV */
 /* NOTE(review): this relies on get_element_ptr(i) for i in
  * [0, get_num_elements()) enumerating the same positions that index_of()
  * reports further down; confirm in Set.h. */
44  SGVector<index_t> sv_idx(all_sv.get_num_elements());
45  for (index_t i=0; i<sv_idx.vlen; ++i)
46  sv_idx[i]=*all_sv.get_element_ptr(i);
47 
48  CFeatures* sv_features=lhs->copy_subset(sv_idx);
49 
50  /* now, features are replaced by concatenated SV features */
51  kernel->init(sv_features, rhs);
52 
53  /* was SG_REF'ed by copy_subset */
54  SG_UNREF(sv_features);
55 
56  /* now the old SV indices have to be mapped to the new features */
57 
58  /* update SV of all machines */
59  for (int32_t i=0; i<m_machines->get_num_elements(); ++i)
60  {
62 
63  /* for each machine, replace SV by index in sv_idx array */
64  for (int32_t j=0; j<machine->get_num_support_vectors(); ++j)
65  {
66  /* get index of SV in old features */
67  index_t current_sv_idx=machine->get_support_vector(j);
68 
69  /* the position of this old index in the map is the position of
70  * the SV in the new features */
71  index_t new_sv_idx=all_sv.index_of(current_sv_idx);
72 
73  machine->set_support_vector(j, new_sv_idx);
74  }
75 
76  SG_UNREF(machine);
77  }
78 
 /* release the references obtained from get_lhs()/get_rhs() above */
79  SG_UNREF(lhs);
80  SG_UNREF(rhs);
81 }
82 
84 {
 /* Body of the constructor whose signature (source line 83) is not shown
  * in this listing -- presumably the default constructor. It registers
  * m_kernel under the parameter name "kernel" via SG_ADD.
  * NOTE(review): the initialiser list is not visible here, so whether
  * m_kernel is NULL-initialised cannot be checked from this listing. */
85  SG_ADD((CSGObject**)&m_kernel,"kernel", "The kernel to be used", MS_AVAILABLE);
86 }
87 
95  CMulticlassMachine(strategy,(CMachine*)machine,labs), m_kernel(NULL)
96 {
 /* Forwards strategy, machine and labels to the CMulticlassMachine base,
  * NULL-initialises m_kernel, then stores the given kernel through
  * set_kernel() and registers m_kernel as the "kernel" parameter.
  * NOTE(review): source lines 88-94 (presumably the doc comment and the
  * signature) are missing from this listing. */
97  set_kernel(kernel);
98  SG_ADD((CSGObject**)&m_kernel,"kernel", "The kernel to be used", MS_AVAILABLE);
99 }
100 
103 {
 /* NOTE(review): source lines 101-102 (signature) and 104 (body) are
  * missing from this listing, so the body appears empty here. By position
  * this is presumably the destructor; since set_kernel()/get_kernel() take
  * references with SG_REF, the missing line presumably releases m_kernel.
  * Confirm against the real source. */
105 }
106 
112 {
 /* Install kernel k into m_machine and keep a counted reference to it in
  * m_kernel (k is SG_REF'ed before being stored).
  * NOTE(review): source line 115 is missing from this listing. Unless it
  * releases the previously held m_kernel (SG_UNREF), the old kernel's
  * reference leaks; confirm. If it is that SG_UNREF, taking the reference
  * on k first keeps set_kernel(m_kernel) safe. */
113  ((CKernelMachine*)m_machine)->set_kernel(k);
114  SG_REF(k);
116  m_kernel=k;
117 }
118 
120 {
 /* Returns m_kernel after taking a reference on it with SG_REF. The caller
  * is expected to SG_UNREF the result, as this file does with the features
  * returned by get_lhs(). m_kernel may be NULL if no kernel was set. */
121  SG_REF(m_kernel);
122  return m_kernel;
123 }
124 
126 {
 /* Prepare for training. If data is given, initialise the kernel with data
  * as both lhs and rhs. Then install m_kernel into m_machine (presumably
  * the prototype machine handed to the constructor). Always returns true.
  * NOTE(review): when data is non-NULL, m_kernel is dereferenced without a
  * NULL check. The signature (source line 125) is missing from this
  * listing. */
127  if (data)
128  m_kernel->init(data,data);
129 
130  ((CKernelMachine*)m_machine)->set_kernel(m_kernel);
131 
132  return true;
133 }
134 
136 {
 /* Prepare for prediction. If data is given, keep the kernel's current lhs
  * and install data as its rhs. Then hand m_kernel to every sub-machine
  * stored in m_machines. Always returns true.
  * NOTE(review): when data is non-NULL, m_kernel is dereferenced without a
  * NULL check. */
137  if (data)
138  {
139  /* set data to rhs for this kernel */
140  CFeatures* lhs=m_kernel->get_lhs();
141  m_kernel->init(lhs, data);
142  SG_UNREF(lhs);
143  }
144 
145  /* set kernel to all sub-machines */
146  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
147  {
 /* NOTE(review): source line 149 (the right-hand side of this
  * initialisation) is missing from this listing. The SG_UNREF(machine)
  * below implies it yields a referenced element -- presumably
  * m_machines->get_element(i) cast to CKernelMachine*; confirm. */
148  CKernelMachine *machine=
150  machine->set_kernel(m_kernel);
151  SG_UNREF(machine);
152  }
153 
154  return true;
155 }
156 
158 {
 /* NOTE(review): source line 159 is missing from this listing, so as shown
  * `return true` is unconditional and `return false` is unreachable. The
  * missing line is presumably the condition guarding `return true` (the
  * cross-reference list mentions CKernel::get_num_vec_lhs()). Confirm
  * against the real source. */
160  return true;
161 
162  return false;
163 }
164 
166 {
 /* Build a new CKernelMachine from the given trained machine via the
  * CKernelMachine(CKernelMachine*) constructor and return it as a CMachine.
  * NOTE(review): the C-style cast assumes machine really is a
  * CKernelMachine; no type check is performed here. */
167  return new CKernelMachine((CKernelMachine*)machine);
168 }
169 
171 {
 /* Number of rhs vectors of m_kernel, i.e. the data that
  * init_machines_for_apply() installs as the kernel's rhs.
  * NOTE(review): m_kernel is dereferenced without a NULL check. */
172  return m_kernel->get_num_vec_rhs();
173 }
174 
176 {
 /* NOTE(review): source lines 175 (signature) and 177 (body) are missing
  * from this listing. By position this is presumably
  * add_machine_subset(SGVector<index_t> subset) from the cross-reference
  * list; the SG_NOTIMPLEMENTED entry there suggests the body is a stub.
  * Confirm. */
178 }
179 
181 {
 /* NOTE(review): source lines 180 (signature) and 182 (body) are missing
  * from this listing, so neither this function's name nor its behaviour
  * can be confirmed from here. */
183 }
184 
185 
virtual void add_machine_subset(SGVector< index_t > subset)
virtual bool init(CFeatures *lhs, CFeatures *rhs)
Definition: Kernel.cpp:83
int32_t index_of(const T &element)
Definition: Set.h:151
CMachine * get_machine(int32_t num) const
int32_t index_t
Definition: common.h:60
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:35
#define SG_UNREF(x)
Definition: SGRefObject.h:35
CFeatures * get_rhs()
Definition: Kernel.h:349
#define SG_ERROR(...)
Definition: SGIO.h:131
#define SG_NOTIMPLEMENTED
Definition: SGIO.h:141
A generic KernelMachine interface.
Definition: KernelMachine.h:50
virtual int32_t get_num_vec_lhs()
Definition: Kernel.h:355
A generic learning machine interface.
Definition: Machine.h:138
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:102
The class CSet, a set based on a hash table (see http://en.wikipedia.org/wiki/Hash_table).
Definition: Set.h:49
experimental abstract generic multiclass machine class
#define SG_REF(x)
Definition: SGRefObject.h:34
bool set_support_vector(int32_t idx, int32_t val)
virtual CMachine * get_machine_from_trained(CMachine *machine)
virtual const char * get_name() const
int32_t get_support_vector(int32_t idx)
virtual int32_t get_num_vec_rhs()
Definition: Kernel.h:364
T * get_element_ptr(int32_t index)
Definition: Set.h:187
virtual bool init_machine_for_train(CFeatures *data)
virtual CFeatures * copy_subset(SGVector< index_t > indices)
Definition: Features.cpp:330
The class Features is the base class of all feature objects.
Definition: Features.h:62
virtual bool init_machines_for_apply(CFeatures *data)
int32_t get_num_elements() const
Definition: Set.h:166
The Kernel base class.
Definition: Kernel.h:150
void add(const T &element)
Definition: Set.h:107
CSGObject * get_element(int32_t index) const
Class MulticlassStrategy, used to construct generic multiclass classifiers from ensembles of binary classifiers.
void set_kernel(CKernel *k)
#define SG_ADD(...)
Definition: SGObject.h:71
CFeatures * get_lhs()
Definition: Kernel.h:343

SHOGUN 机器学习工具包 - 项目文档