25 using SolverFactory = std::function<std::unique_ptr<MaxFlowSolver<T>>(
int num_vars,
int num_edges)>;
31 : model_(model), solver_factory_(std::move(solver_factory)) {
42 const std::vector<int> active_nodes = model_.get_active_nodes(alpha_label);
43 if (active_nodes.empty())
return false;
45 std::vector<typename MaxFlowSolver<T>::Var> node_var_ids;
46 const auto solver = build_expansion_graph(alpha_label, active_nodes, node_var_ids);
51 std::vector<int> proposed_labels = model_.get_labels();
53 for (
const int node: active_nodes) {
55 proposed_labels[node] = alpha_label;
61 T old_energy = model_.evaluate_total_energy();
62 T new_energy = model_.evaluate_total_energy(proposed_labels);
63 bool improved =
false;
64 if constexpr (std::is_floating_point_v<T>) {
65 improved = (old_energy - new_energy >
static_cast<T
>(1e-5));
67 improved = (new_energy < old_energy);
70 model_.set_labels(proposed_labels);
80 std::unique_ptr<MaxFlowSolver<T>> build_expansion_graph(
const int alpha_label,
const std::vector<int> &active_nodes,
82 const int num_active = active_nodes.size();
83 if (num_active == 0)
return nullptr;
85 const int estimated_edges = num_active * 4;
86 auto solver = solver_factory_(num_active, estimated_edges);
88 node_var_ids.assign(model_.num_nodes(), -1);
90 for (
int i = 0; i < num_active; ++i) {
91 int node = active_nodes[i];
92 node_var_ids[node] = solver->add_variable();
95 for (
const int node: active_nodes) {
97 const int current_label = model_.get_label(node);
98 const T e0 = model_.get_unary_cost(node, alpha_label);
99 const T e1 = model_.get_unary_cost(node, current_label);
100 solver->add_term1(var, e0, e1);
103 for (
const int node_i: active_nodes) {
104 const int current_label_i = model_.get_label(node_i);
107 for (
const int node_j: model_.get_neighbors(node_i)) {
108 if (node_var_ids[node_j] != -1) {
109 if (node_i < node_j) {
111 const int current_label_j = model_.get_label(node_j);
112 const T e00 = model_.get_pairwise_cost(node_i, node_j, alpha_label, alpha_label);
113 const T e01 = model_.get_pairwise_cost(node_i, node_j, alpha_label, current_label_j);
114 const T e10 = model_.get_pairwise_cost(node_i, node_j, current_label_i, alpha_label);
115 const T e11 = model_.get_pairwise_cost(node_i, node_j, current_label_i, current_label_j);
116 solver->add_term2(var_i, var_j, e00, e01, e10, e11);
119 const int current_label_j = model_.get_label(node_j);
120 const T e0 = model_.get_pairwise_cost(node_i, node_j, alpha_label, current_label_j);
121 const T e1 = model_.get_pairwise_cost(node_i, node_j, current_label_i, current_label_j);
122 solver->add_term1(var_i, e0, e1);
AlphaExpansion(EnergyModel< T > &model, SolverFactory solver_factory)
Constructs the optimizer.
Definition AlphaExpansion.hpp:30