ViennaCL - The Vienna Computing Library
1.5.0
|
00001 #ifndef VIENNACL_LINALG_NMF_HPP 00002 #define VIENNACL_LINALG_NMF_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2013, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00028 #include "viennacl/vector.hpp" 00029 #include "viennacl/matrix.hpp" 00030 #include "viennacl/linalg/prod.hpp" 00031 #include "viennacl/linalg/norm_2.hpp" 00032 #include "viennacl/linalg/norm_frobenius.hpp" 00033 #include "viennacl/linalg/opencl/kernels/nmf.hpp" 00034 00035 namespace viennacl 00036 { 00037 namespace linalg 00038 { 00040 class nmf_config 00041 { 00042 public: 00043 nmf_config(double val_epsilon = 1e-4, 00044 double val_epsilon_stagnation = 1e-5, 00045 vcl_size_t num_max_iters = 10000, 00046 vcl_size_t num_check_iters = 100) 00047 : eps_(val_epsilon), stagnation_eps_(val_epsilon_stagnation), 00048 max_iters_(num_max_iters), 00049 check_after_steps_( (num_check_iters > 0) ? num_check_iters : 1), 00050 print_relative_error_(false), 00051 iters_(0) {} 00052 00054 double tolerance() const { return eps_; } 00055 00057 void tolerance(double e) { eps_ = e; } 00058 00060 double stagnation_tolerance() const { return stagnation_eps_; } 00061 00063 void stagnation_tolerance(double e) { stagnation_eps_ = e; } 00064 00066 vcl_size_t max_iterations() const { return max_iters_; } 00068 void max_iterations(vcl_size_t m) { max_iters_ = m; } 00069 00071 vcl_size_t iters() const { return iters_; } 00072 00073 00075 vcl_size_t check_after_steps() const { return check_after_steps_; } 00077 void check_after_steps(vcl_size_t c) { if (c > 0) check_after_steps_ = c; } 00078 00080 bool print_relative_error() const { return print_relative_error_; } 00082 void print_relative_error(bool b) { print_relative_error_ = b; } 00083 00084 template <typename ScalarType> 00085 friend void nmf(viennacl::matrix<ScalarType> const & V, 00086 viennacl::matrix<ScalarType> & W, 00087 viennacl::matrix<ScalarType> & H, 00088 nmf_config const & conf); 00089 00090 private: 00091 double eps_; 00092 double stagnation_eps_; 00093 vcl_size_t max_iters_; 00094 vcl_size_t check_after_steps_; 00095 bool print_relative_error_; 00096 mutable vcl_size_t iters_; 00097 }; 00098 00099 00107 template <typename ScalarType> 00108 void nmf(viennacl::matrix<ScalarType> const & V, 00109 viennacl::matrix<ScalarType> & W, 00110 viennacl::matrix<ScalarType> & H, 00111 nmf_config const & conf) 00112 { 00113 viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(V).context()); 00114 00115 const std::string NMF_MUL_DIV_KERNEL = "el_wise_mul_div"; 00116 00117 viennacl::linalg::opencl::kernels::nmf<ScalarType>::init(ctx); 00118 00119 assert(V.size1() == W.size1() && V.size2() == H.size2() && bool("Dimensions of W and H don't allow for V = W * H")); 00120 assert(W.size2() == H.size1() && bool("Dimensions of W and H don't match, prod(W, H) impossible")); 00121 00122 vcl_size_t k = W.size2(); 00123 conf.iters_ = 0; 00124 00125 viennacl::matrix<ScalarType> wn(V.size1(), k); 00126 viennacl::matrix<ScalarType> wd(V.size1(), k); 00127 viennacl::matrix<ScalarType> wtmp(V.size1(), V.size2()); 00128 00129 viennacl::matrix<ScalarType> hn(k, V.size2()); 00130 viennacl::matrix<ScalarType> hd(k, V.size2()); 00131 viennacl::matrix<ScalarType> htmp(k, k); 00132 00133 viennacl::matrix<ScalarType> appr(V.size1(), V.size2()); 00134 viennacl::vector<ScalarType> diff(V.size1() * V.size2()); 00135 00136 ScalarType last_diff = 0; 00137 ScalarType diff_init = 0; 00138 bool stagnation_flag = false; 00139 00140 00141 for (vcl_size_t i = 0; i < conf.max_iterations(); i++) 00142 { 00143 conf.iters_ = i + 1; 00144 { 00145 hn = viennacl::linalg::prod(trans(W), V); 00146 htmp = viennacl::linalg::prod(trans(W), W); 00147 hd = viennacl::linalg::prod(htmp, H); 00148 00149 viennacl::ocl::kernel & mul_div_kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::nmf<ScalarType>::program_name(), NMF_MUL_DIV_KERNEL); 00150 viennacl::ocl::enqueue(mul_div_kernel(H, hn, hd, cl_uint(H.internal_size1() * H.internal_size2()))); 00151 } 00152 { 00153 wn = viennacl::linalg::prod(V, trans(H)); 00154 wtmp = viennacl::linalg::prod(W, H); 00155 wd = viennacl::linalg::prod(wtmp, trans(H)); 00156 00157 viennacl::ocl::kernel & mul_div_kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::nmf<ScalarType>::program_name(), NMF_MUL_DIV_KERNEL); 00158 00159 viennacl::ocl::enqueue(mul_div_kernel(W, wn, wd, cl_uint(W.internal_size1() * W.internal_size2()))); 00160 } 00161 00162 if (i % conf.check_after_steps() == 0) //check for convergence 00163 { 00164 appr = viennacl::linalg::prod(W, H); 00165 00166 appr -= V; 00167 ScalarType diff_val = viennacl::linalg::norm_frobenius(appr); 00168 00169 if (i == 0) 00170 diff_init = diff_val; 00171 00172 if (conf.print_relative_error()) 00173 std::cout << diff_val / diff_init << std::endl; 00174 00175 // Approximation check 00176 if (diff_val / diff_init < conf.tolerance()) 00177 break; 00178 00179 // Stagnation check 00180 if (std::fabs(diff_val - last_diff) / (diff_val * conf.check_after_steps()) < conf.stagnation_tolerance()) //avoid situations where convergence stagnates 00181 { 00182 if (stagnation_flag) // iteration stagnates (two iterates with no notable progress) 00183 break; 00184 else // record stagnation in this iteration 00185 stagnation_flag = true; 00186 } 00187 else // good progress in this iteration, so unset stagnation flag 00188 stagnation_flag = false; 00189 00190 // prepare for next iterate: 00191 last_diff = diff_val; 00192 } 00193 } 00194 00195 00196 } 00197 } 00198 } 00199 00200 #endif