00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 template<u32 N>
00023 template<typename T1, typename T2>
00024 inline
00025 void
00026 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00027 {
00028 arma_extra_debug_sigprint();
00029
00030 typedef typename T1::elem_type eT;
00031
00032 const partial_unwrap_check<T1> tmp1(X.A, out);
00033 const partial_unwrap_check<T2> tmp2(X.B, out);
00034
00035 const Mat<eT>& A = tmp1.M;
00036 const Mat<eT>& B = tmp2.M;
00037
00038 const bool do_trans_A = tmp1.do_trans;
00039 const bool do_trans_B = tmp2.do_trans;
00040
00041 const bool use_alpha = tmp1.do_times | tmp2.do_times;
00042 const eT alpha = use_alpha ? (tmp1.val * tmp2.val) : eT(0);
00043
00044 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00045 }
00046
00047
00048
00049 template<typename T1, typename T2, typename T3>
00050 inline
00051 void
00052 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
00053 {
00054 arma_extra_debug_sigprint();
00055
00056 typedef typename T1::elem_type eT;
00057
00058
00059
00060
00061 const partial_unwrap_check<T1> tmp1(X.A.A, out);
00062 const partial_unwrap_check<T2> tmp2(X.A.B, out);
00063 const partial_unwrap_check<T3> tmp3(X.B, out);
00064
00065 const Mat<eT>& A = tmp1.M;
00066 const Mat<eT>& B = tmp2.M;
00067 const Mat<eT>& C = tmp3.M;
00068
00069 const bool do_trans_A = tmp1.do_trans;
00070 const bool do_trans_B = tmp2.do_trans;
00071 const bool do_trans_C = tmp3.do_trans;
00072
00073 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times;
00074 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val) : eT(0);
00075
00076 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00077 }
00078
00079
00080
00081 template<typename T1, typename T2, typename T3, typename T4>
00082 inline
00083 void
00084 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
00085 {
00086 arma_extra_debug_sigprint();
00087
00088 typedef typename T1::elem_type eT;
00089
00090
00091
00092
00093 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
00094 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
00095 const partial_unwrap_check<T3> tmp3(X.A.B, out);
00096 const partial_unwrap_check<T4> tmp4(X.B, out);
00097
00098 const Mat<eT>& A = tmp1.M;
00099 const Mat<eT>& B = tmp2.M;
00100 const Mat<eT>& C = tmp3.M;
00101 const Mat<eT>& D = tmp4.M;
00102
00103 const bool do_trans_A = tmp1.do_trans;
00104 const bool do_trans_B = tmp2.do_trans;
00105 const bool do_trans_C = tmp3.do_trans;
00106 const bool do_trans_D = tmp4.do_trans;
00107
00108 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times | tmp4.do_times;
00109 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val * tmp4.val) : eT(0);
00110
00111 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00112 }
00113
00114
00115
00116 template<typename T1, typename T2>
00117 inline
00118 void
00119 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00120 {
00121 arma_extra_debug_sigprint();
00122
00123 typedef typename T1::elem_type eT;
00124
00125 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00126
00127 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00128
00129 glue_times_redirect<N_mat>::apply(out, X);
00130 }
00131
00132
00133
00134 template<typename T1>
00135 inline
00136 void
00137 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
00138 {
00139 arma_extra_debug_sigprint();
00140
00141 typedef typename T1::elem_type eT;
00142
00143 const unwrap_check<T1> tmp(X, out);
00144 const Mat<eT>& B = tmp.M;
00145
00146 arma_debug_assert_mul_size(out, B, "matrix multiply");
00147
00148 if(out.n_cols == B.n_cols)
00149 {
00150 podarray<eT> tmp(out.n_cols);
00151 eT* tmp_rowdata = tmp.memptr();
00152
00153 for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00154 {
00155 for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00156 {
00157 tmp_rowdata[out_col] = out.at(out_row,out_col);
00158 }
00159
00160 for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00161 {
00162 const eT* B_coldata = B.colptr(B_col);
00163
00164 eT val = eT(0);
00165 for(u32 i=0; i < B.n_rows; ++i)
00166 {
00167 val += tmp_rowdata[i] * B_coldata[i];
00168 }
00169
00170 out.at(out_row,B_col) = val;
00171 }
00172 }
00173
00174 }
00175 else
00176 {
00177 const Mat<eT> tmp(out);
00178 glue_times::apply(out, tmp, B, eT(1), false, false, false);
00179 }
00180
00181 }
00182
00183
00184
00185 template<typename T1, typename T2>
00186 arma_hot
00187 inline
00188 void
00189 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const s32 sign)
00190 {
00191 arma_extra_debug_sigprint();
00192
00193 typedef typename T1::elem_type eT;
00194
00195 const partial_unwrap_check<T1> tmp1(X.A, out);
00196 const partial_unwrap_check<T2> tmp2(X.B, out);
00197
00198 const Mat<eT>& A = tmp1.M;
00199 const Mat<eT>& B = tmp2.M;
00200 const eT alpha = tmp1.val * tmp2.val * ( (sign > s32(0)) ? eT(1) : eT(-1) );
00201
00202 const bool do_trans_A = tmp1.do_trans;
00203 const bool do_trans_B = tmp2.do_trans;
00204 const bool use_alpha = tmp1.do_times | tmp2.do_times | (sign < s32(0));
00205
00206 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00207
00208 const u32 result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00209 const u32 result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00210
00211 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "matrix addition");
00212
00213 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00214 {
00215 if(A.n_rows == 1)
00216 {
00217 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00218 }
00219 if(B.n_cols == 1)
00220 {
00221 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00222 }
00223 else
00224 {
00225 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
00226 }
00227 }
00228 else
00229 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00230 {
00231 if(A.n_rows == 1)
00232 {
00233 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00234 }
00235 if(B.n_cols == 1)
00236 {
00237 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00238 }
00239 else
00240 {
00241 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
00242 }
00243 }
00244 else
00245 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00246 {
00247 if(A.n_cols == 1)
00248 {
00249 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00250 }
00251 if(B.n_cols == 1)
00252 {
00253 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00254 }
00255 else
00256 {
00257 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
00258 }
00259 }
00260 else
00261 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00262 {
00263 if(A.n_cols == 1)
00264 {
00265 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00266 }
00267 if(B.n_cols == 1)
00268 {
00269 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00270 }
00271 else
00272 {
00273 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
00274 }
00275 }
00276 else
00277 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00278 {
00279 if(A.n_rows == 1)
00280 {
00281 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00282 }
00283 if(B.n_rows == 1)
00284 {
00285 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00286 }
00287 else
00288 {
00289 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
00290 }
00291 }
00292 else
00293 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00294 {
00295 if(A.n_rows == 1)
00296 {
00297 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00298 }
00299 if(B.n_rows == 1)
00300 {
00301 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00302 }
00303 else
00304 {
00305 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
00306 }
00307 }
00308 else
00309 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00310 {
00311 if(A.n_cols == 1)
00312 {
00313 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00314 }
00315 if(B.n_rows == 1)
00316 {
00317 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00318 }
00319 else
00320 {
00321 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
00322 }
00323 }
00324 else
00325 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00326 {
00327 if(A.n_cols == 1)
00328 {
00329 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00330 }
00331 if(B.n_rows == 1)
00332 {
00333 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00334 }
00335 else
00336 {
00337 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
00338 }
00339 }
00340
00341
00342 }
00343
00344
00345
00346
00347 template<typename eT1, typename eT2>
00348 inline
00349 void
00350 glue_times::apply_mixed(Mat<typename promote_type<eT1,eT2>::result>& out, const Mat<eT1>& X, const Mat<eT2>& Y)
00351 {
00352 arma_extra_debug_sigprint();
00353
00354 typedef typename promote_type<eT1,eT2>::result out_eT;
00355
00356 arma_debug_assert_mul_size(X,Y, "matrix multiply");
00357
00358 out.set_size(X.n_rows,Y.n_cols);
00359 gemm_mixed<>::apply(out, X, Y);
00360 }
00361
00362
00363
00364 template<typename eT>
00365 arma_inline
00366 u32
00367 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B)
00368 {
00369 const u32 final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00370 const u32 final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00371
00372 return final_A_n_rows * final_B_n_cols;
00373 }
00374
00375
00376
00377 template<typename eT>
00378 arma_hot
00379 inline
00380 void
00381 glue_times::apply
00382 (
00383 Mat<eT>& out,
00384 const Mat<eT>& A,
00385 const Mat<eT>& B,
00386 const eT alpha,
00387 const bool do_trans_A,
00388 const bool do_trans_B,
00389 const bool use_alpha
00390 )
00391 {
00392 arma_extra_debug_sigprint();
00393
00394 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00395
00396 const u32 final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00397 const u32 final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00398
00399 out.set_size(final_n_rows, final_n_cols);
00400
00401
00402
00403 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00404 {
00405 if(A.n_rows == 1)
00406 {
00407 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00408 }
00409 else
00410 if(B.n_cols == 1)
00411 {
00412 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00413 }
00414 else
00415 {
00416 gemm<false, false, false, false>::apply(out, A, B);
00417 }
00418 }
00419 else
00420 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00421 {
00422 if(A.n_rows == 1)
00423 {
00424 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00425 }
00426 else
00427 if(B.n_cols == 1)
00428 {
00429 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00430 }
00431 else
00432 {
00433 gemm<false, false, true, false>::apply(out, A, B, alpha);
00434 }
00435 }
00436 else
00437 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00438 {
00439 if(A.n_cols == 1)
00440 {
00441 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00442 }
00443 else
00444 if(B.n_cols == 1)
00445 {
00446 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00447 }
00448 else
00449 {
00450 gemm<true, false, false, false>::apply(out, A, B);
00451 }
00452 }
00453 else
00454 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00455 {
00456 if(A.n_cols == 1)
00457 {
00458 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00459 }
00460 if(B.n_cols == 1)
00461 {
00462 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00463 }
00464 else
00465 {
00466 gemm<true, false, true, false>::apply(out, A, B, alpha);
00467 }
00468 }
00469 else
00470 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00471 {
00472 if(A.n_rows == 1)
00473 {
00474 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00475 }
00476 if(B.n_rows == 1)
00477 {
00478 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00479 }
00480 else
00481 {
00482 gemm<false, true, false, false>::apply(out, A, B);
00483 }
00484 }
00485 else
00486 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00487 {
00488 if(A.n_rows == 1)
00489 {
00490 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00491 }
00492 if(B.n_rows == 1)
00493 {
00494 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00495 }
00496 else
00497 {
00498 gemm<false, true, true, false>::apply(out, A, B, alpha);
00499 }
00500 }
00501 else
00502 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00503 {
00504 if(A.n_cols == 1)
00505 {
00506 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00507 }
00508 if(B.n_rows == 1)
00509 {
00510 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00511 }
00512 else
00513 {
00514 gemm<true, true, false, false>::apply(out, A, B);
00515 }
00516 }
00517 else
00518 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00519 {
00520 if(A.n_cols == 1)
00521 {
00522 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00523 }
00524 if(B.n_rows == 1)
00525 {
00526 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00527 }
00528 else
00529 {
00530 gemm<true, true, true, false>::apply(out, A, B, alpha);
00531 }
00532 }
00533 }
00534
00535
00536
00537 template<typename eT>
00538 inline
00539 void
00540 glue_times::apply
00541 (
00542 Mat<eT>& out,
00543 const Mat<eT>& A,
00544 const Mat<eT>& B,
00545 const Mat<eT>& C,
00546 const eT alpha,
00547 const bool do_trans_A,
00548 const bool do_trans_B,
00549 const bool do_trans_C,
00550 const bool use_alpha
00551 )
00552 {
00553 arma_extra_debug_sigprint();
00554
00555 Mat<eT> tmp;
00556
00557 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
00558 {
00559
00560 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00561 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false );
00562 }
00563 else
00564 {
00565
00566 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha);
00567 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false );
00568 }
00569 }
00570
00571
00572
00573 template<typename eT>
00574 inline
00575 void
00576 glue_times::apply
00577 (
00578 Mat<eT>& out,
00579 const Mat<eT>& A,
00580 const Mat<eT>& B,
00581 const Mat<eT>& C,
00582 const Mat<eT>& D,
00583 const eT alpha,
00584 const bool do_trans_A,
00585 const bool do_trans_B,
00586 const bool do_trans_C,
00587 const bool do_trans_D,
00588 const bool use_alpha
00589 )
00590 {
00591 arma_extra_debug_sigprint();
00592
00593 Mat<eT> tmp;
00594
00595 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
00596 {
00597
00598 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00599
00600 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
00601 }
00602 else
00603 {
00604
00605 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00606
00607 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
00608 }
00609 }
00610
00611
00612
00613
00614
00615
00616
00617 template<typename T1, typename T2>
00618 arma_hot
00619 inline
00620 void
00621 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
00622 {
00623 arma_extra_debug_sigprint();
00624
00625 typedef typename T1::elem_type eT;
00626
00627 const strip_diagmat<T1> S1(X.A);
00628 const strip_diagmat<T2> S2(X.B);
00629
00630 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
00631 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00632
00633 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
00634 {
00635 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00636
00637 const unwrap_check<T2> tmp(X.B, out);
00638 const Mat<eT>& B = tmp.M;
00639
00640 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiply");
00641
00642 out.set_size(A.n_elem, B.n_cols);
00643
00644 for(u32 col=0; col<B.n_cols; ++col)
00645 {
00646 eT* out_coldata = out.colptr(col);
00647 const eT* B_coldata = B.colptr(col);
00648
00649 for(u32 row=0; row<B.n_rows; ++row)
00650 {
00651 out_coldata[row] = A[row] * B_coldata[row];
00652 }
00653 }
00654 }
00655 else
00656 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
00657 {
00658 const unwrap_check<T1> tmp(X.A, out);
00659 const Mat<eT>& A = tmp.M;
00660
00661 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00662
00663 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiply");
00664
00665 out.set_size(A.n_rows, B.n_elem);
00666
00667 for(u32 col=0; col<A.n_cols; ++col)
00668 {
00669 const eT val = B[col];
00670
00671 eT* out_coldata = out.colptr(col);
00672 const eT* A_coldata = A.colptr(col);
00673
00674 for(u32 row=0; row<A.n_rows; ++row)
00675 {
00676 out_coldata[row] = A_coldata[row] * val;
00677 }
00678 }
00679 }
00680 else
00681 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
00682 {
00683 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00684 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00685
00686 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiply");
00687
00688 out.zeros(A.n_elem, A.n_elem);
00689
00690 for(u32 i=0; i<A.n_elem; ++i)
00691 {
00692 out.at(i,i) = A[i] * B[i];
00693 }
00694 }
00695 }
00696
00697
00698
00699