namespace fdapde {

namespace internals {
2426// implementation of Conjugate Gradient method for unconstrained nonlinear optimization
25- template <int N, typename ... Args> class ConjugateGradient {
27+ template <int N, typename DirectionMethod, typename ... Args> class conjugate_gradient_impl {
2628 private:
27- using vector_t =
28- std::conditional_t <N == Dynamic, Eigen::Matrix<double , Dynamic, 1 >, Eigen::Matrix<double , N, 1 >>;
29+ using vector_t = std::conditional_t <N == Dynamic, Eigen::Matrix<double , Dynamic, 1 >, Eigen::Matrix<double , N, 1 >>;
2930 using matrix_t =
30- std::conditional_t <N == Dynamic, Eigen::Matrix<double , Dynamic, Dynamic>, Eigen::Matrix<double , N, N>>;
31+ std::conditional_t <N == Dynamic, Eigen::Matrix<double , Dynamic, Dynamic>, Eigen::Matrix<double , N, N>>;
3132
3233 std::tuple<Args...> callbacks_ {};
3334 vector_t optimum_;
@@ -36,89 +37,106 @@ template <int N, typename... Args> class ConjugateGradient {
3637 double tol_; // tolerance on error before forced stop
3738 double step_; // update step
3839 int n_iter_ = 0 ; // current iteration number
39- bool use_polak_ribiere_ = true ; // If set to false, use Fletcher-Reave, else use Polak-Ribière to calculate beta
40-
4140 public:
42- vector_t x_old, x_new, update, dir, grad_old, grad_new;
41+ vector_t x_old, x_new, update, grad_old, grad_new;
4342 double h;
44-
4543 // constructor
46- ConjugateGradient () = default ;
47-
48- ConjugateGradient (int max_iter, double tol, double step, bool use_polak_ribiere = true )
44+ conjugate_gradient_impl () = default ;
45+ conjugate_gradient_impl (int max_iter, double tol, double step)
4946 requires (sizeof ...(Args) != 0 )
50- : max_iter_(max_iter), tol_(tol), step_(step), use_polak_ribiere_(use_polak_ribiere) { }
51-
52- ConjugateGradient (int max_iter, double tol, double step, bool use_polak_ribiere, Args&&... callbacks) :
53- callbacks_ (std::make_tuple(std::forward<Args>(callbacks)...)), max_iter_(max_iter), tol_(tol), step_(step), use_polak_ribiere_(use_polak_ribiere) { }
54-
47+ : max_iter_(max_iter), tol_(tol), step_(step) { }
48+ conjugate_gradient_impl (int max_iter, double tol, double step, Args&&... callbacks) :
49+ callbacks_ (std::make_tuple(std::forward<Args>(callbacks)...)), max_iter_(max_iter), tol_(tol), step_(step) { }
5550 // copy semantic
56- ConjugateGradient (const ConjugateGradient& other) :
57- callbacks_ (other.callbacks_), max_iter_(other.max_iter_), tol_(other.tol_), step_(other.step_), use_polak_ribiere_(other.use_polak_ribiere_) { }
58-
59- ConjugateGradient& operator =(const ConjugateGradient& other) {
51+ conjugate_gradient_impl (const conjugate_gradient_impl& other) :
52+ callbacks_ (other.callbacks_), max_iter_(other.max_iter_), tol_(other.tol_), step_(other.step_) { }
53+ conjugate_gradient_impl& operator =(const conjugate_gradient_impl& other) {
6054 max_iter_ = other.max_iter_ ;
6155 tol_ = other.tol_ ;
6256 step_ = other.step_ ;
6357 callbacks_ = other.callbacks_ ;
64- use_polak_ribiere_ = other.use_polak_ribiere_ ;
6558 return *this ;
6659 }
67-
6860 template <typename ObjectiveT, typename ... Functor>
6961 requires (sizeof ...(Functor) < 2 ) && ((requires (Functor f, double value) { f (value); }) && ...)
7062 vector_t optimize (ObjectiveT&& objective, const vector_t & x0, Functor&&... func) {
7163 fdapde_static_assert (
72- std::is_same<decltype (std::declval<ObjectiveT>().operator ()(vector_t ())) FDAPDE_COMMA double >::value,
73- INVALID_CALL_TO_OPTIMIZE__OBJECTIVE_FUNCTOR_NOT_ACCEPTING_VECTORTYPE
74- );
75-
64+ std::is_same<decltype (std::declval<ObjectiveT>().operator ()(vector_t ())) FDAPDE_COMMA double >::value,
65+ INVALID_CALL_TO_OPTIMIZE__OBJECTIVE_FUNCTOR_NOT_ACCEPTING_VECTORTYPE);
7666 bool stop = false ; // asserted true in case of forced stop
7767 double error = std::numeric_limits<double >::max ();
78- double beta = 0.0 ;
79- h = step_ ;
68+ DirectionMethod beta;
69+ auto grad = objective. gradient () ;
8070 n_iter_ = 0 ;
71+ h = step_;
8172 x_old = x0, x_new = x0;
82- auto grad = objective.gradient ();
8373 grad_old = grad (x_old);
84- dir = -grad_old;
74+ update = -grad_old;
8575
8676 while (n_iter_ < max_iter_ && error > tol_ && !stop) {
87- update = dir;
8877 stop |= execute_pre_update_step (*this , objective, callbacks_);
89-
9078 // update along descent direction
9179 x_new = x_old + h * update;
9280 grad_new = grad (x_new);
9381 if constexpr (sizeof ...(Functor) == 1 ) { (func (objective (x_old)), ...); }
94- if ( use_polak_ribiere_ )
95- beta = grad_new.dot (grad_new - grad_old) / grad_old.dot (grad_old); // Polak Ribière
96- else
97- beta = grad_new.dot (grad_new) / grad_old.dot (grad_old); // Fletcher-Reeves
98-
99- dir = -grad_new + beta * dir;
100-
101- // prepare next iteration
82+ // prepare next iteration
83+ update = -grad_new + std::max (0.0 , beta (*this )) * update; // update conjugate direction
10284 error = grad_new.norm ();
10385 stop |=
10486 (execute_post_update_step (*this , objective, callbacks_) || execute_stopping_criterion (*this , objective));
10587 x_old = x_new;
10688 grad_old = grad_new;
107-
10889 n_iter_++;
10990 }
11091 optimum_ = x_old;
11192 value_ = objective (optimum_);
112- if constexpr (sizeof ...(Functor) == 1 ) { (func (value_), ...); }
93+ if constexpr (sizeof ...(Functor) == 1 ) { (func (value_), ...); }
11394 return optimum_;
11495 }
115-
11696 // getters
11797 vector_t optimum () const { return optimum_; }
11898 double value () const { return value_; }
11999 int n_iter () const { return n_iter_; }
120100};
121101
// Fletcher-Reeves beta coefficient: beta = ||g_new||^2 / ||g_old||^2.
struct fletcher_reeves_impl {
    // Fix: the ratio must use squared norms. The previous norm()/norm() computed
    // ||g_new|| / ||g_old||, the square root of the correct coefficient, which
    // destroys the conjugacy property of the generated directions.
    template <typename Opt> double operator()(const Opt& opt) {
        return opt.grad_new.squaredNorm() / opt.grad_old.squaredNorm();
    }
};
// Polak-Ribiere beta coefficient: beta = g_new . (g_new - g_old) / ||g_old||^2.
struct polak_ribiere_impl {
    // Fix: the denominator must be the squared norm g_old . g_old (as in the classical
    // Polak-Ribiere formula). The previous grad_old.norm() divided by ||g_old|| only,
    // overestimating beta whenever ||g_old|| < 1 and underestimating it otherwise.
    template <typename Opt> double operator()(const Opt& opt) {
        return opt.grad_new.dot(opt.grad_new - opt.grad_old) / opt.grad_old.squaredNorm();
    }
};
}   // namespace internals

// public CG optimizers
114+ template <int N, typename ... Args>
115+ class FletcherReevesCG : public internals ::conjugate_gradient_impl<N, internals::fletcher_reeves_impl, Args...> {
116+ private:
117+ using Base = internals::conjugate_gradient_impl<N, internals::fletcher_reeves_impl, Args...>;
118+ public:
119+ FletcherReevesCG () = default ;
120+ FletcherReevesCG (int max_iter, double tol, double step)
121+ requires (sizeof ...(Args) != 0 )
122+ : Base(max_iter, tol, step) { }
123+ FletcherReevesCG (int max_iter, double tol, double step, Args&&... callbacks) :
124+ Base (max_iter, tol, step, callbacks...) { }
125+ };
126+
127+ template <int N, typename ... Args>
128+ class PolakRibiereCG : public internals ::conjugate_gradient_impl<N, internals::polak_ribiere_impl, Args...> {
129+ private:
130+ using Base = internals::conjugate_gradient_impl<N, internals::polak_ribiere_impl, Args...>;
131+ public:
132+ PolakRibiereCG () = default ;
133+ PolakRibiereCG (int max_iter, double tol, double step)
134+ requires (sizeof ...(Args) != 0 )
135+ : Base(max_iter, tol, step) { }
136+ PolakRibiereCG (int max_iter, double tol, double step, Args&&... callbacks) :
137+ Base (max_iter, tol, step, callbacks...) { }
138+ };
139+
}   // namespace fdapde

#endif   // __FDAPDE_CONJUGATE_GRADIENT_H__