SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
statistics
KernelMeanMatching.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Copyright (W) 2012 Sergey Lisitsyn
8
*/
9
10
#include <shogun/statistics/KernelMeanMatching.h>

#include <cstddef>

#include <shogun/lib/external/libqp.h>
12
13
14
static
float64_t
*
kmm_K
= NULL;
15
static
int32_t
kmm_K_ld
= 0;
16
17
static
const
float64_t
*
kmm_get_col
(uint32_t i)
18
{
19
return
kmm_K
+
kmm_K_ld
*i;
20
}
21
22
namespace
shogun
23
{
24
CKernelMeanMatching::CKernelMeanMatching
() :
25
CSGObject
(), m_kernel(NULL)
26
{
27
}
28
29
CKernelMeanMatching::CKernelMeanMatching
(
CKernel
* kernel,
SGVector<index_t>
training_indices,
30
SGVector<index_t>
test_indices) :
31
CSGObject
(), m_kernel(NULL)
32
{
33
set_kernel
(kernel);
34
set_training_indices
(training_indices);
35
set_test_indices
(test_indices);
36
}
37
38
SGVector<float64_t>
CKernelMeanMatching::compute_weights
()
39
{
40
int32_t i,j;
41
ASSERT
(
m_kernel
)
42
ASSERT
(
m_training_indices
.
vlen
)
43
ASSERT
(
m_test_indices
.
vlen
)
44
45
int32_t n_tr =
m_training_indices
.
vlen
;
46
int32_t n_te =
m_test_indices
.
vlen
;
47
48
SGVector<float64_t>
weights(n_tr);
49
weights.
zero
();
50
51
kmm_K
= SG_MALLOC(
float64_t
, n_tr*n_tr);
52
kmm_K_ld
= n_tr;
53
float64_t
* diag_K = SG_MALLOC(
float64_t
, n_tr);
54
for
(i=0; i<n_tr; i++)
55
{
56
float64_t
d =
m_kernel
->
kernel
(
m_training_indices
[i],
m_training_indices
[i]);
57
diag_K[i] = d;
58
kmm_K
[i*n_tr+i] = d;
59
for
(j=i+1; j<n_tr; j++)
60
{
61
d =
m_kernel
->
kernel
(
m_training_indices
[i],
m_training_indices
[j]);
62
kmm_K
[i*n_tr+j] = d;
63
kmm_K
[j*n_tr+i] = d;
64
}
65
}
66
float64_t
* kappa = SG_MALLOC(
float64_t
, n_tr);
67
for
(i=0; i<n_tr; i++)
68
{
69
float64_t
avg = 0.0;
70
for
(j=0; j<n_te; j++)
71
avg+=
m_kernel
->
kernel
(
m_training_indices
[i],
m_test_indices
[j]);
72
73
avg *=
float64_t
(n_tr)/n_te;
74
kappa[i] = -avg;
75
}
76
float64_t
* a = SG_MALLOC(
float64_t
, n_tr);
77
for
(i=0; i<n_tr; i++) a[i] = 1.0;
78
float64_t
* LB = SG_MALLOC(
float64_t
, n_tr);
79
float64_t
* UB = SG_MALLOC(
float64_t
, n_tr);
80
float64_t
B = 2.0;
81
for
(i=0; i<n_tr; i++)
82
{
83
LB[i] = 0.0;
84
UB[i] = B;
85
}
86
for
(i=0; i<n_tr; i++)
87
weights[i] = 1.0/
float64_t
(n_tr);
88
89
libqp_state_T result =
90
libqp_gsmo_solver(&
kmm_get_col
,diag_K,kappa,a,1.0,LB,UB,weights,n_tr,1000,1e-9,NULL);
91
92
SG_DEBUG
(
"libqp exitflag=%d, %d iterations passed, primal objective=%f\n"
,
93
result.exitflag,result.nIter,result.QP);
94
95
SG_FREE(kappa);
96
SG_FREE(a);
97
SG_FREE(LB);
98
SG_FREE(UB);
99
SG_FREE(diag_K);
100
SG_FREE(
kmm_K
);
101
102
return
weights;
103
}
104
105
}
shogun::CKernelMeanMatching::m_training_indices
SGVector< index_t > m_training_indices
Definition:
KernelMeanMatching.h:53
kmm_K_ld
static int32_t kmm_K_ld
Definition:
KernelMeanMatching.cpp:15
kmm_get_col
static const float64_t * kmm_get_col(uint32_t i)
Definition:
KernelMeanMatching.cpp:17
shogun::CKernel::kernel
float64_t kernel(int32_t idx_a, int32_t idx_b)
Definition:
Kernel.h:198
shogun::SGVector::zero
void zero()
Definition:
SGVector.cpp:110
shogun::CKernelMeanMatching::compute_weights
SGVector< float64_t > compute_weights()
Definition:
KernelMeanMatching.cpp:38
shogun::CKernelMeanMatching::set_kernel
void set_kernel(CKernel *kernel)
Definition:
KernelMeanMatching.h:33
ASSERT
#define ASSERT(x)
Definition:
SGIO.h:203
shogun::CSGObject
Class SGObject is the base class of all shogun objects.
Definition:
SGObject.h:102
shogun::CKernelMeanMatching::set_training_indices
void set_training_indices(SGVector< index_t > training_indices)
Definition:
KernelMeanMatching.h:37
float64_t
double float64_t
Definition:
common.h:48
shogun::SGVector< index_t >
shogun::CKernelMeanMatching::m_test_indices
SGVector< index_t > m_test_indices
Definition:
KernelMeanMatching.h:55
shogun::CKernelMeanMatching::set_test_indices
void set_test_indices(SGVector< index_t > test_indices)
Definition:
KernelMeanMatching.h:41
KernelMeanMatching.h
SG_DEBUG
#define SG_DEBUG(...)
Definition:
SGIO.h:109
shogun::CKernel
The Kernel base class.
Definition:
Kernel.h:150
shogun::CKernelMeanMatching::m_kernel
CKernel * m_kernel
Definition:
KernelMeanMatching.h:51
kmm_K
static float64_t * kmm_K
Definition:
KernelMeanMatching.cpp:14
shogun::CKernelMeanMatching::CKernelMeanMatching
CKernelMeanMatching()
Definition:
KernelMeanMatching.cpp:24
shogun::SGVector::vlen
index_t vlen
Definition:
SGVector.h:706
SHOGUN
机器学习工具包 - 项目文档