Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). More...
#include <gemm.hpp>
Static Public Member Functions

template<typename eT>
static arma_hot void apply(Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha = eT(1), const eT beta = eT(0))
Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).
Definition at line 27 of file gemm.hpp.
static arma_hot void gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply | ( | Mat< eT > & | C, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const eT | alpha = eT(1) , |
|||
const eT | beta = eT(0) | |||
) | [inline, static] |
Definition at line 37 of file gemm.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().
00044 { 00045 arma_extra_debug_sigprint(); 00046 00047 const u32 A_n_rows = A.n_rows; 00048 const u32 A_n_cols = A.n_cols; 00049 00050 const u32 B_n_rows = B.n_rows; 00051 const u32 B_n_cols = B.n_cols; 00052 00053 if( (do_trans_A == false) && (do_trans_B == false) ) 00054 { 00055 arma_aligned podarray<eT> tmp(A_n_cols); 00056 eT* A_rowdata = tmp.memptr(); 00057 00058 for(u32 row_A=0; row_A < A_n_rows; ++row_A) 00059 { 00060 00061 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00062 { 00063 A_rowdata[col_A] = A.at(row_A,col_A); 00064 } 00065 00066 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00067 { 00068 const eT* B_coldata = B.colptr(col_B); 00069 00070 eT acc = eT(0); 00071 for(u32 i=0; i < B_n_rows; ++i) 00072 { 00073 acc += A_rowdata[i] * B_coldata[i]; 00074 } 00075 00076 if( (use_alpha == false) && (use_beta == false) ) 00077 { 00078 C.at(row_A,col_B) = acc; 00079 } 00080 else 00081 if( (use_alpha == true) && (use_beta == false) ) 00082 { 00083 C.at(row_A,col_B) = alpha * acc; 00084 } 00085 else 00086 if( (use_alpha == false) && (use_beta == true) ) 00087 { 00088 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00089 } 00090 else 00091 if( (use_alpha == true) && (use_beta == true) ) 00092 { 00093 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00094 } 00095 00096 } 00097 } 00098 } 00099 else 00100 if( (do_trans_A == true) && (do_trans_B == false) ) 00101 { 00102 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00103 { 00104 // col_A is interpreted as row_A when storing the results in matrix C 00105 00106 const eT* A_coldata = A.colptr(col_A); 00107 00108 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00109 { 00110 const eT* B_coldata = B.colptr(col_B); 00111 00112 eT acc = eT(0); 00113 for(u32 i=0; i < B_n_rows; ++i) 00114 { 00115 acc += A_coldata[i] * B_coldata[i]; 00116 } 00117 00118 if( (use_alpha == false) && (use_beta == false) ) 00119 { 00120 C.at(col_A,col_B) = acc; 00121 } 00122 else 00123 if( (use_alpha == true) && (use_beta == false) ) 00124 { 00125 
C.at(col_A,col_B) = alpha * acc; 00126 } 00127 else 00128 if( (use_alpha == false) && (use_beta == true) ) 00129 { 00130 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00131 } 00132 else 00133 if( (use_alpha == true) && (use_beta == true) ) 00134 { 00135 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00136 } 00137 00138 } 00139 } 00140 } 00141 else 00142 if( (do_trans_A == false) && (do_trans_B == true) ) 00143 { 00144 Mat<eT> B_tmp = trans(B); 00145 gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00146 } 00147 else 00148 if( (do_trans_A == true) && (do_trans_B == true) ) 00149 { 00150 // mat B_tmp = trans(B); 00151 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00152 00153 00154 // By using the trans(A)*trans(B) = trans(B*A) equivalency, 00155 // transpose operations are not needed 00156 00157 arma_aligned podarray<eT> tmp(B.n_cols); 00158 eT* B_rowdata = tmp.memptr(); 00159 00160 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00161 { 00162 00163 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00164 { 00165 B_rowdata[col_B] = B.at(row_B,col_B); 00166 } 00167 00168 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00169 { 00170 const eT* A_coldata = A.colptr(col_A); 00171 00172 eT acc = eT(0); 00173 for(u32 i=0; i < A_n_rows; ++i) 00174 { 00175 acc += B_rowdata[i] * A_coldata[i]; 00176 } 00177 00178 if( (use_alpha == false) && (use_beta == false) ) 00179 { 00180 C.at(col_A,row_B) = acc; 00181 } 00182 else 00183 if( (use_alpha == true) && (use_beta == false) ) 00184 { 00185 C.at(col_A,row_B) = alpha * acc; 00186 } 00187 else 00188 if( (use_alpha == false) && (use_beta == true) ) 00189 { 00190 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00191 } 00192 else 00193 if( (use_alpha == true) && (use_beta == true) ) 00194 { 00195 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00196 } 00197 00198 } 00199 } 00200 00201 } 00202 }