Skip to content

Commit a2a9c88

Browse files
committed
bug fix: GridSearch handles ND grid points, N > 1 (grid mapping is compliant to Eigen storage order)
1 parent 7ed083b commit a2a9c88

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

fdaPDE/src/optimization/grid_search.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ namespace fdapde {
2424
template <int N> class GridSearch {
2525
private:
2626
using vector_t = std::conditional_t<N == Dynamic, Eigen::Matrix<double, Dynamic, 1>, Eigen::Matrix<double, N, 1>>;
27-
using grid_t = MdMap<const double, MdExtents<Dynamic, Dynamic>>;
2827

2928
vector_t optimum_;
3029
double value_; // objective value at optimum
@@ -49,7 +48,16 @@ template <int N> class GridSearch {
4948
fdapde_static_assert(
5049
std::is_same<decltype(std::declval<ObjectiveT>().operator()(vector_t())) FDAPDE_COMMA double>::value,
5150
INVALID_CALL_TO_OPTIMIZE__OBJECTIVE_FUNCTOR_NOT_CALLABLE_AT_VECTOR_TYPE);
51+
using layout_policy = decltype([]() {
52+
if constexpr (internals::is_eigen_dense_xpr_v<GridT>) {
53+
return std::conditional_t<GridT::IsRowMajor, internals::layout_right, internals::layout_left> {};
54+
} else {
55+
return internals::layout_right {};
56+
}
57+
}());
58+
using grid_t = MdMap<const double, MdExtents<Dynamic, Dynamic>, layout_policy>;
5259
constexpr double NaN = std::numeric_limits<double>::quiet_NaN();
60+
5361
std::tuple<Callbacks...> callbacks_ {callbacks...};
5462
grid_t grid_;
5563
value_ = std::numeric_limits<double>::max();
@@ -61,7 +69,7 @@ template <int N> class GridSearch {
6169
grid_ = grid_t(grid.data(), grid.rows(), size_);
6270
}
6371
bool stop = false; // asserted true in case of forced stop
64-
grid_.row(0).assign_to(x_curr);
72+
grid_.row(0).assign_to(x_curr.transpose());
6573
obj_curr = objective(x_curr);
6674
stop |= internals::exec_eval_hooks(*this, objective, callbacks_);
6775
values_.clear();
@@ -72,7 +80,7 @@ template <int N> class GridSearch {
7280
}
7381
// optimize field over supplied grid
7482
for (std::size_t i = 1; i < grid_.rows() && !stop; ++i) {
75-
grid_.row(i).assign_to(x_curr);
83+
grid_.row(i).assign_to(x_curr.transpose());
7684
obj_curr = objective(x_curr);
7785
stop |= internals::exec_eval_hooks(*this, objective, callbacks_);
7886
values_.push_back(obj_curr);

0 commit comments

Comments
 (0)