43 using namespace internal;
45 KernelManager::KernelManager()
47 SG_SDEBUG(
"Kernel manager instance initialized!\n");
50 KernelManager::KernelManager(
index_t num_kernels)
52 SG_SDEBUG(
"Kernel manager instance initialized with %d kernels!\n", num_kernels);
53 m_kernels.resize(num_kernels);
54 m_precomputed_kernels.resize(num_kernels);
55 std::fill(m_kernels.begin(), m_kernels.end(),
nullptr);
56 std::fill(m_precomputed_kernels.begin(), m_precomputed_kernels.end(),
nullptr);
59 KernelManager::~KernelManager()
64 void KernelManager::clear()
67 m_precomputed_kernels.resize(0);
70 InitPerKernel KernelManager::kernel_at(
index_t i)
74 "Value of i (%d) should be between 0 and %d, inclusive!",
77 return InitPerKernel(m_kernels[i]);
84 "Value of i (%d) should be between 0 and %d, inclusive!",
86 if (m_precomputed_kernels[i]==
nullptr)
89 return m_kernels[i].get();
91 SG_SDEBUG(
"Precomputed kernel exists!\n");
93 return m_precomputed_kernels[i].get();
96 void KernelManager::push_back(
CKernel* kernel)
100 m_kernels.push_back(std::shared_ptr<CKernel>(kernel, [](
CKernel* ptr) {
SG_UNREF(ptr); }));
101 m_precomputed_kernels.push_back(
nullptr);
105 const index_t KernelManager::num_kernels()
const 110 return (
index_t)m_kernels.size();
113 void KernelManager::precompute_kernel_at(
index_t i)
117 "Value of i (%d) should be between 0 and %d, inclusive!",
119 auto kernel=m_kernels[i].
get();
126 m_precomputed_kernels[i]=std::shared_ptr<CCustomKernel>(
new CCustomKernel(kernel_matrix));
127 SG_SDEBUG(
"Kernel type %s is precomputed and replaced internally with %s!\n",
128 kernel->
get_name(), m_precomputed_kernels[i]->get_name());
133 void KernelManager::restore_kernel_at(
index_t i)
137 "Value of i (%d) should be between 0 and %d, inclusive!",
139 m_precomputed_kernels[i]=
nullptr;
140 SG_SDEBUG(
"Precomputed kernel (if any) was deleted!\n");
144 bool KernelManager::same_distance_type()
const 149 for (
auto i=0; i<num_kernels(); ++i)
152 if (shift_invariant_kernel!=
nullptr)
157 distance_type=current_distance_type;
160 else if (distance_type==current_distance_type)
171 SG_SINFO(
"Kernel at location %d is not of CShiftInvariantKernel type (was of %s type)!\n",
172 i, kernel_at(i)->get_name());
179 CDistance* KernelManager::get_distance_instance()
const 181 REQUIRE(same_distance_type(),
"Distance types for all the kernels are not the same!\n");
185 REQUIRE(kernel_0,
"Kernel (%s) must be of CShiftInvariantKernel type!\n", kernel_at(0)->get_name());
189 euclidean_distance->set_disable_sqrt(
true);
190 distance=euclidean_distance;
195 distance=manhattan_distance;
199 SG_SERROR(
"Unsupported distance type!\n");
204 void KernelManager::set_precomputed_distance(
CCustomDistance* distance)
const 206 REQUIRE(distance!=
nullptr,
"Distance instance cannot be null!\n");
207 for (
auto i=0; i<num_kernels(); ++i)
211 REQUIRE(shift_inv_kernel!=
nullptr,
"Kernel instance (was %s) must be of CShiftInvarintKernel type!\n", kernel->
get_name());
212 shift_inv_kernel->m_precomputed_distance=
distance;
218 void KernelManager::unset_precomputed_distance()
const 220 for (
auto i=0; i<num_kernels(); ++i)
224 REQUIRE(shift_inv_kernel!=
nullptr,
"Kernel instance (was %s) must be of CShiftInvarintKernel type!\n", kernel->
get_name());
225 shift_inv_kernel->m_precomputed_distance=
nullptr;
virtual const char * get_name() const =0
float distance(CJLCoverTreePoint p1, CJLCoverTreePoint p2, float64_t upper_bound)
Base class for the family of kernel functions that only depend on the difference of the inputs...
Class Distance, a base class for all the distances used in the Shogun toolbox.
int32_t num_rhs
number of feature vectors on right hand side
virtual int32_t get_num_vec_lhs()
The Custom Kernel allows for custom user provided kernel matrices.
SGMatrix< float64_t > get_kernel_matrix()
virtual int32_t get_num_vec_rhs()
int32_t num_lhs
number of feature vectors on left hand side
all of classes and functions are contained in the shogun namespace
virtual EKernelType get_kernel_type()=0
T get(const Tag< T > &_tag) const
The Custom Distance allows for custom user provided distance matrices.
virtual EDistanceType get_distance_type() const