Alpha Expansion Library
C++ library for the Alpha-Expansion graph-cut algorithm with Python bindings
Loading...
Searching...
No Matches
AlphaExpansion.hpp
Go to the documentation of this file.
1#pragma once
2
5#include <memory>
6#include <functional>
7#include <type_traits>
8
19template <typename T>
21public:
25 using SolverFactory = std::function<std::unique_ptr<MaxFlowSolver<T>>(int num_vars, int num_edges)>;
26
31 : model_(model), solver_factory_(std::move(solver_factory)) {
32 }
33
41 [[nodiscard]] bool perform_expansion_move(const int alpha_label) const {
42 const std::vector<int> active_nodes = model_.get_active_nodes(alpha_label);
43 if (active_nodes.empty()) return false;
44
45 std::vector<typename MaxFlowSolver<T>::Var> node_var_ids;
46 const auto solver = build_expansion_graph(alpha_label, active_nodes, node_var_ids);
47
48 solver->minimize();
49
50 bool changed = false;
51 std::vector<int> proposed_labels = model_.get_labels();
52
53 for (const int node: active_nodes) {
54 if (const typename MaxFlowSolver<T>::Var var = node_var_ids[node]; solver->get_var(var) == 0) {
55 proposed_labels[node] = alpha_label;
56 changed = true;
57 }
58 }
59
60 if (changed) {
61 T old_energy = model_.evaluate_total_energy();
62 T new_energy = model_.evaluate_total_energy(proposed_labels);
63 bool improved = false;
64 if constexpr (std::is_floating_point_v<T>) {
65 improved = (old_energy - new_energy > static_cast<T>(1e-5));
66 } else {
67 improved = (new_energy < old_energy);
68 }
69 if (improved) {
70 model_.set_labels(proposed_labels);
71 return true;
72 }
73 }
74
75 return false;
76 }
77
78private:
80 std::unique_ptr<MaxFlowSolver<T>> build_expansion_graph(const int alpha_label, const std::vector<int> &active_nodes,
81 std::vector<typename MaxFlowSolver<T>::Var> &node_var_ids) const {
82 const int num_active = active_nodes.size();
83 if (num_active == 0) return nullptr;
84
85 const int estimated_edges = num_active * 4;
86 auto solver = solver_factory_(num_active, estimated_edges);
87
88 node_var_ids.assign(model_.num_nodes(), -1);
89
90 for (int i = 0; i < num_active; ++i) {
91 int node = active_nodes[i];
92 node_var_ids[node] = solver->add_variable();
93 }
94
95 for (const int node: active_nodes) {
96 const typename MaxFlowSolver<T>::Var var = node_var_ids[node];
97 const int current_label = model_.get_label(node);
98 const T e0 = model_.get_unary_cost(node, alpha_label);
99 const T e1 = model_.get_unary_cost(node, current_label);
100 solver->add_term1(var, e0, e1);
101 }
102
103 for (const int node_i: active_nodes) {
104 const int current_label_i = model_.get_label(node_i);
105 const typename MaxFlowSolver<T>::Var var_i = node_var_ids[node_i];
106
107 for (const int node_j: model_.get_neighbors(node_i)) {
108 if (node_var_ids[node_j] != -1) {
109 if (node_i < node_j) {
110 const typename MaxFlowSolver<T>::Var var_j = node_var_ids[node_j];
111 const int current_label_j = model_.get_label(node_j);
112 const T e00 = model_.get_pairwise_cost(node_i, node_j, alpha_label, alpha_label);
113 const T e01 = model_.get_pairwise_cost(node_i, node_j, alpha_label, current_label_j);
114 const T e10 = model_.get_pairwise_cost(node_i, node_j, current_label_i, alpha_label);
115 const T e11 = model_.get_pairwise_cost(node_i, node_j, current_label_i, current_label_j);
116 solver->add_term2(var_i, var_j, e00, e01, e10, e11);
117 }
118 } else {
119 const int current_label_j = model_.get_label(node_j);
120 const T e0 = model_.get_pairwise_cost(node_i, node_j, alpha_label, current_label_j);
121 const T e1 = model_.get_pairwise_cost(node_i, node_j, current_label_i, current_label_j);
122 solver->add_term1(var_i, e0, e1);
123 }
124 }
125 }
126
127 return solver;
128 }
129
130 EnergyModel<T> &model_;
131 SolverFactory solver_factory_;
132};
Performs alpha-expansion moves on an EnergyModel using a pluggable max-flow solver.
Definition AlphaExpansion.hpp:20
bool perform_expansion_move(const int alpha_label) const
Attempts a single alpha-expansion move for alpha_label.
Definition AlphaExpansion.hpp:41
std::function< std::unique_ptr< MaxFlowSolver< T > >(int num_vars, int num_edges)> SolverFactory
Definition AlphaExpansion.hpp:25
AlphaExpansion(EnergyModel< T > &model, SolverFactory solver_factory)
Constructs the optimizer.
Definition AlphaExpansion.hpp:30
Stores the graph and energy costs for the Alpha-Expansion algorithm.
Definition EnergyModel.hpp:17
int Var
Integer handle identifying a binary variable.
Definition MaxFlowSolver.hpp:20