Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).
#include <gemm_mixed.hpp>
Static Public Member Functions

template<typename out_eT, typename in_eT1, typename in_eT2>
static arma_hot void apply(Mat<out_eT>& C, const Mat<in_eT1>& A, const Mat<in_eT2>& B, const out_eT alpha = out_eT(1), const out_eT beta = out_eT(0))
Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). In effect, apply() computes C = alpha*op(A)*op(B) + beta*C, where op() optionally transposes its argument according to the do_trans_A and do_trans_B template parameters, and alpha and beta take part only when use_alpha and use_beta are set.
Definition at line 28 of file gemm_mixed.hpp.
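The "caching" in the name refers to copying one row of A into a contiguous buffer (a podarray) before forming inner products: Armadillo stores matrices in column-major order, so row accesses are otherwise strided. Below is a minimal standalone sketch of that idea for the no-transpose case; it is not the library code, the function name and raw-pointer interface are illustrative only, and a plain cast to the output type stands in for the class's upgrade_val promotion machinery.

  #include <cstddef>
  #include <vector>

  // Hypothetical helper illustrating the row-caching idea; not part of Armadillo.
  template<typename out_eT, typename in_eT1, typename in_eT2>
  void
  naive_mixed_gemm(out_eT* C, const in_eT1* A, const in_eT2* B,
                   const std::size_t A_n_rows, const std::size_t A_n_cols, const std::size_t B_n_cols)
    {
    // A is A_n_rows x A_n_cols, B is A_n_cols x B_n_cols, C is A_n_rows x B_n_cols;
    // all are column-major, and C is assumed to be pre-sized, as in the class above
    std::vector<in_eT1> A_row(A_n_cols);   // contiguous cache for one row of A

    for(std::size_t r = 0; r < A_n_rows; ++r)
      {
      for(std::size_t k = 0; k < A_n_cols; ++k)
        {
        A_row[k] = A[r + k*A_n_rows];      // strided gather, done once per row
        }

      for(std::size_t c = 0; c < B_n_cols; ++c)
        {
        const in_eT2* B_col = &B[c*A_n_cols];   // columns of B are already contiguous

        out_eT acc = out_eT(0);
        for(std::size_t k = 0; k < A_n_cols; ++k)
          {
          acc += out_eT(A_row[k]) * out_eT(B_col[k]);   // promote both inputs
          }
        C[r + c*A_n_rows] = acc;
        }
      }
    }

Each row of A is gathered exactly once and then reused across all columns of B, so the cost of the strided accesses is amortised over B_n_cols inner products.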
static arma_hot void gemm_mixed_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(
        Mat<out_eT>&        C,
        const Mat<in_eT1>&  A,
        const Mat<in_eT2>&  B,
        const out_eT        alpha = out_eT(1),
        const out_eT        beta  = out_eT(0)
        )  [inline, static]
Definition at line 38 of file gemm_mixed.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), podarray< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().
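This class is an internal implementation detail, normally reached through Armadillo's expression machinery rather than called directly. A hedged usage sketch follows; it assumes the internal gemm_mixed_cache template is visible through <armadillo>, that upgrade_val supports the float/double pair, and it pre-sizes C by hand because apply() does not resize it.

  #include <armadillo>

  int main()
    {
    arma::Mat<float>  A(4, 3);  A.randu();
    arma::Mat<double> B(3, 5);  B.randu();
    arma::Mat<double> C(4, 5);  // caller must size C correctly (no transposes here)

    // do_trans_A=false, do_trans_B=false, use_alpha=false, use_beta=false: C = A*B
    arma::gemm_mixed_cache<false, false, false, false>::apply(C, A, B);

    C.print("C = A*B:");
    return 0;
    }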
  {
  arma_extra_debug_sigprint();
  
  const u32 A_n_rows = A.n_rows;
  const u32 A_n_cols = A.n_cols;
  
  const u32 B_n_rows = B.n_rows;
  const u32 B_n_cols = B.n_cols;
  
  if( (do_trans_A == false) && (do_trans_B == false) )
    {
    // C = A*B: cache each row of A in a contiguous buffer,
    // as rows of a column-major matrix are otherwise strided
    podarray<in_eT1> tmp(A_n_cols);
    in_eT1* A_rowdata = tmp.memptr();
    
    for(u32 row_A=0; row_A < A_n_rows; ++row_A)
      {
      for(u32 col_A=0; col_A < A_n_cols; ++col_A)
        {
        A_rowdata[col_A] = A.at(row_A,col_A);
        }
      
      for(u32 col_B=0; col_B < B_n_cols; ++col_B)
        {
        const in_eT2* B_coldata = B.colptr(col_B);
        
        out_eT acc = out_eT(0);
        for(u32 i=0; i < B_n_rows; ++i)
          {
          acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
          }
        
        if( (use_alpha == false) && (use_beta == false) )
          {
          C.at(row_A,col_B) = acc;
          }
        else if( (use_alpha == true) && (use_beta == false) )
          {
          C.at(row_A,col_B) = alpha * acc;
          }
        else if( (use_alpha == false) && (use_beta == true) )
          {
          C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
          }
        else if( (use_alpha == true) && (use_beta == true) )
          {
          C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
          }
        }
      }
    }
  else if( (do_trans_A == true) && (do_trans_B == false) )
    {
    // C = trans(A)*B: columns of A are the rows of trans(A),
    // so both operands are already contiguous
    for(u32 col_A=0; col_A < A_n_cols; ++col_A)
      {
      // col_A is interpreted as row_A when storing the results in matrix C
      
      const in_eT1* A_coldata = A.colptr(col_A);
      
      for(u32 col_B=0; col_B < B_n_cols; ++col_B)
        {
        const in_eT2* B_coldata = B.colptr(col_B);
        
        out_eT acc = out_eT(0);
        for(u32 i=0; i < B_n_rows; ++i)
          {
          acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
          }
        
        if( (use_alpha == false) && (use_beta == false) )
          {
          C.at(col_A,col_B) = acc;
          }
        else if( (use_alpha == true) && (use_beta == false) )
          {
          C.at(col_A,col_B) = alpha * acc;
          }
        else if( (use_alpha == false) && (use_beta == true) )
          {
          C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
          }
        else if( (use_alpha == true) && (use_beta == true) )
          {
          C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
          }
        }
      }
    }
  else if( (do_trans_A == false) && (do_trans_B == true) )
    {
    // C = A*trans(B): explicitly transpose B, then
    // delegate to the no-transpose case above
    Mat<in_eT2> B_tmp = trans(B);
    gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
    }
  else if( (do_trans_A == true) && (do_trans_B == true) )
    {
    // mat B_tmp = trans(B);
    // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
    
    // By using the trans(A)*trans(B) = trans(B*A) equivalency,
    // transpose operations are not needed
    
    podarray<in_eT2> tmp(B.n_cols);
    in_eT2* B_rowdata = tmp.memptr();
    
    for(u32 row_B=0; row_B < B_n_rows; ++row_B)
      {
      for(u32 col_B=0; col_B < B_n_cols; ++col_B)
        {
        B_rowdata[col_B] = B.at(row_B,col_B);
        }
      
      for(u32 col_A=0; col_A < A_n_cols; ++col_A)
        {
        const in_eT1* A_coldata = A.colptr(col_A);
        
        out_eT acc = out_eT(0);
        for(u32 i=0; i < A_n_rows; ++i)
          {
          acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
          }
        
        if( (use_alpha == false) && (use_beta == false) )
          {
          C.at(col_A,row_B) = acc;
          }
        else if( (use_alpha == true) && (use_beta == false) )
          {
          C.at(col_A,row_B) = alpha * acc;
          }
        else if( (use_alpha == false) && (use_beta == true) )
          {
          C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
          }
        else if( (use_alpha == true) && (use_beta == true) )
          {
          C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
          }
        }
      }
    }
  }
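The final branch relies on the equivalency trans(A)*trans(B) = trans(B*A): entry (col_A, row_B) of the result is the dot product of row row_B of B with column col_A of A, so neither operand needs an explicit transpose. The identity itself can be checked with the public Armadillo API; the dimensions below are arbitrary.

  #include <armadillo>
  #include <iostream>

  int main()
    {
    arma::mat A(3, 4);  A.randu();
    arma::mat B(5, 3);  B.randu();

    arma::mat C1 = arma::trans(A) * arma::trans(B);  // 4x5
    arma::mat C2 = arma::trans(B * A);               // trans of a 5x4 result, also 4x5

    // difference should be zero up to floating-point rounding
    std::cout << arma::norm(C1 - C2, "fro") << std::endl;
    return 0;
    }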