Alpha Expansion Library
C++ library for the Alpha-Expansion graph-cut algorithm with Python bindings
Loading...
Searching...
No Matches
EnergyModel.hpp
Go to the documentation of this file.
#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

/// @brief Stores the graph and energy costs for the Alpha-Expansion algorithm.
///
/// The energy of a labelling is the sum of every node's unary cost plus the
/// pairwise cost of every neighbour edge. Costs can be supplied as dense
/// arrays, per-edge Potts weights, or callbacks; see get_unary_cost() and
/// get_pairwise_cost() for the lookup precedence.
///
/// @tparam T Numeric cost type.
template <typename T>
class EnergyModel {
public:
    /// Callable that returns the unary cost for assigning `label` to `node`.
    using UnaryCostFn = std::function<T(int node, int label)>;

    /// Callable that returns the pairwise cost for assigning `label1` to
    /// `node1` and `label2` to `node2`.
    using PairwiseCostFn = std::function<T(int node1, int node2, int label1, int label2)>;

    /// @brief Constructs an energy model with all labels initialized to 0.
    /// @param num_nodes  Number of nodes in the graph.
    /// @param num_labels Number of candidate labels per node.
    /// @throws std::invalid_argument if either count is negative.
    EnergyModel(int num_nodes, int num_labels)
        : num_nodes_(check_non_negative(num_nodes, "num_nodes must be non-negative")),
          num_labels_(check_non_negative(num_labels, "num_labels must be non-negative")),
          labels_(static_cast<std::size_t>(num_nodes_), 0),
          neighbors_(static_cast<std::size_t>(num_nodes_)) {}

    /// Returns the total number of nodes.
    [[nodiscard]] int num_nodes() const { return num_nodes_; }

    /// Returns the total number of labels.
    [[nodiscard]] int num_labels() const { return num_labels_; }

    /// Returns the current label assigned to `node` (unchecked index).
    [[nodiscard]] int get_label(int node) const { return labels_[node]; }

    /// Assigns `label` to `node` (unchecked index).
    void set_label(int node, int label) { labels_[node] = label; }

    /// Returns the full label vector (one entry per node).
    [[nodiscard]] const std::vector<int>& get_labels() const { return labels_; }

    /// @brief Replaces the full label vector.
    /// @throws std::invalid_argument if `labels.size() != num_nodes()`; the
    ///         current labels are left unchanged in that case.
    void set_labels(const std::vector<int>& labels) {
        if (labels.size() != static_cast<std::size_t>(num_nodes_)) {
            throw std::invalid_argument("Labels array must have size num_nodes");
        }
        labels_ = labels;
    }

    /// Sets a callback for unary costs. It is only consulted while no dense
    /// unary array has been set via set_unary_costs().
    void set_unary_cost_fn(UnaryCostFn fn) { unary_cost_fn_ = std::move(fn); }

    /// Sets a callback for pairwise costs. It is only consulted when neither a
    /// per-edge weight nor a dense pairwise matrix applies.
    void set_pairwise_cost_fn(PairwiseCostFn fn) { pairwise_cost_fn_ = std::move(fn); }

    /// @brief Returns the unary cost for assigning `label` to `node`.
    ///
    /// Precedence: dense array from set_unary_costs(), then the callback,
    /// otherwise 0.
    [[nodiscard]] T get_unary_cost(int node, int label) const {
        if (!unary_costs_.empty()) {
            return unary_costs_[flat_index(node, label, num_labels_)];
        }
        return unary_cost_fn_ ? unary_cost_fn_(node, label) : T{0};
    }

    /// @brief Returns the pairwise cost of assigning `label1` to `node1` and
    /// `label2` to `node2`.
    ///
    /// Precedence: a per-edge Potts weight for {node1, node2} (0 when the labels
    /// match, the weight otherwise), then the dense label matrix, then the
    /// callback, otherwise 0.
    [[nodiscard]] T get_pairwise_cost(int node1, int node2, int label1, int label2) const {
        if (!edge_weights_.empty()) {
            const auto it = edge_weights_.find(edge_key(node1, node2));
            if (it != edge_weights_.end()) {
                return label1 == label2 ? T{0} : it->second;
            }
        }
        if (!pairwise_costs_.empty()) {
            return pairwise_costs_[flat_index(label1, label2, num_labels_)];
        }
        return pairwise_cost_fn_ ? pairwise_cost_fn_(node1, node2, label1, label2) : T{0};
    }

    /// @brief Sets unary costs from a flat row-major array: entry
    /// `node * num_labels + label`.
    /// @throws std::invalid_argument if the size is not num_nodes * num_labels.
    void set_unary_costs(const std::vector<T>& costs) {
        // Multiply in size_t: the int product overflows for large models.
        const std::size_t expected =
            static_cast<std::size_t>(num_nodes_) * static_cast<std::size_t>(num_labels_);
        if (costs.size() != expected) {
            throw std::invalid_argument("Unary costs array must have size num_nodes * num_labels");
        }
        unary_costs_ = costs;
    }

    /// @brief Sets a global pairwise cost matrix, row-major: entry
    /// `label1 * num_labels + label2`.
    /// @throws std::invalid_argument if the size is not num_labels * num_labels.
    void set_pairwise_costs(const std::vector<T>& costs) {
        const std::size_t expected =
            static_cast<std::size_t>(num_labels_) * static_cast<std::size_t>(num_labels_);
        if (costs.size() != expected) {
            throw std::invalid_argument("Pairwise costs array must have size num_labels * num_labels");
        }
        pairwise_costs_ = costs;
    }

    /// @brief Sets per-edge smoothness weights (Potts model).
    ///
    /// Edge `{n1s[i], n2s[i]}` costs `weights[i]` when its endpoints have
    /// different labels and 0 otherwise. Order of the two endpoints is irrelevant.
    /// @throws std::invalid_argument if the arrays differ in size or any node
    ///         index is outside [0, num_nodes); no weight is stored in that case.
    void set_edge_weights(const std::vector<int>& n1s, const std::vector<int>& n2s, const std::vector<T>& weights) {
        if (n1s.size() != n2s.size() || n1s.size() != weights.size()) {
            throw std::invalid_argument("n1s, n2s, and weights must have the same size");
        }
        // Validate everything first: out-of-range ids would make edge_key()
        // collide with unrelated edges, and a failed call must not half-apply.
        for (std::size_t i = 0; i < n1s.size(); ++i) {
            check_node(n1s[i]);
            check_node(n2s[i]);
        }
        for (std::size_t i = 0; i < n1s.size(); ++i) {
            edge_weights_[edge_key(n1s[i], n2s[i])] = weights[i];
        }
    }

    /// @brief Adds an undirected edge between `node1` and `node2`.
    ///
    /// The edge is recorded in both adjacency lists; repeated calls add
    /// repeated entries.
    /// @throws std::invalid_argument if either node index is out of range.
    void add_neighbor(int node1, int node2) {
        check_node(node1);
        check_node(node2);
        neighbors_[node1].push_back(node2);
        neighbors_[node2].push_back(node1);
    }

    /// @brief Populates a 4-connected grid neighbourhood for an image of size
    /// `width` x `height`, where pixel (x, y) is node `y * width + x`.
    /// @throws std::invalid_argument if a dimension is negative or
    ///         `width * height != num_nodes()`.
    void add_grid_edges(int width, int height) {
        // 64-bit product: width * height in int can overflow.
        if (width < 0 || height < 0 ||
            static_cast<std::int64_t>(width) * height != num_nodes_) {
            throw std::invalid_argument("Grid dimensions do not match the number of nodes");
        }
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                const int node = y * width + x;
                if (x + 1 < width) add_neighbor(node, node + 1);
                if (y + 1 < height) add_neighbor(node, node + width);
            }
        }
    }

    /// Returns the neighbours of `node` (unchecked index).
    [[nodiscard]] const std::vector<int>& get_neighbors(int node) const {
        return neighbors_[node];
    }

    /// Returns, in ascending order, the indices of all nodes whose current
    /// label differs from `alpha_label`.
    [[nodiscard]] std::vector<int> get_active_nodes(int alpha_label) const {
        std::vector<int> active_nodes;
        active_nodes.reserve(static_cast<std::size_t>(num_nodes_));
        for (int i = 0; i < num_nodes_; ++i) {
            if (labels_[i] != alpha_label) active_nodes.push_back(i);
        }
        return active_nodes;
    }

    /// @brief Evaluates the total energy for the given label assignment.
    /// @throws std::invalid_argument if `eval_labels.size() != num_nodes()`.
    [[nodiscard]] T evaluate_total_energy(const std::vector<int>& eval_labels) const {
        if (eval_labels.size() != static_cast<std::size_t>(num_nodes_)) {
            throw std::invalid_argument("Labels array must have size num_nodes");
        }
        T total{0};
        for (int i = 0; i < num_nodes_; ++i) {
            total += get_unary_cost(i, eval_labels[i]);
            for (const int neighbor : neighbors_[i]) {
                // Every edge sits in both endpoints' lists; count it from the
                // lower-indexed side only.
                if (i < neighbor) {
                    total += get_pairwise_cost(i, neighbor, eval_labels[i], eval_labels[neighbor]);
                }
            }
        }
        return total;
    }

    /// Evaluates the total energy for the model's current label assignment.
    [[nodiscard]] T evaluate_total_energy() const {
        return evaluate_total_energy(labels_);
    }

private:
    /// Returns `value`, or throws std::invalid_argument(message) if negative.
    /// Used in the constructor before any vector is sized from the value.
    static int check_non_negative(int value, const char* message) {
        if (value < 0) {
            throw std::invalid_argument(message);
        }
        return value;
    }

    /// Throws std::invalid_argument unless 0 <= node < num_nodes().
    void check_node(int node) const {
        if (node < 0 || node >= num_nodes_) {
            throw std::invalid_argument("Node index out of range");
        }
    }

    /// Row-major index `row * stride + col`, computed in size_t so large
    /// models do not overflow int arithmetic.
    static std::size_t flat_index(int row, int col, int stride) noexcept {
        return static_cast<std::size_t>(row) * static_cast<std::size_t>(stride) +
               static_cast<std::size_t>(col);
    }

    /// Order-independent key for the undirected edge {n1, n2}; unique for
    /// node indices in [0, num_nodes).
    [[nodiscard]] std::int64_t edge_key(int n1, int n2) const noexcept {
        const int lo = std::min(n1, n2);
        const int hi = std::max(n1, n2);
        return static_cast<std::int64_t>(lo) * num_nodes_ + hi;
    }

    int num_nodes_;
    int num_labels_;
    std::vector<int> labels_;
    std::vector<std::vector<int>> neighbors_;
    UnaryCostFn unary_cost_fn_;
    PairwiseCostFn pairwise_cost_fn_;
    std::vector<T> unary_costs_;
    std::vector<T> pairwise_costs_;
    std::unordered_map<std::int64_t, T> edge_weights_;
};
Stores the graph and energy costs for the Alpha-Expansion algorithm.
Definition EnergyModel.hpp:17
int num_labels() const
Returns the total number of labels.
Definition EnergyModel.hpp:35
const std::vector< int > & get_neighbors(int node) const
Returns the neighbours of node.
Definition EnergyModel.hpp:148
T get_pairwise_cost(int node1, int node2, int label1, int label2) const
Returns the pairwise cost of assigning label1 to node1 and label2 to node2.
Definition EnergyModel.hpp:73
void set_label(int node, int label)
Assigns label to node.
Definition EnergyModel.hpp:41
EnergyModel(int num_nodes, int num_labels)
Constructs an energy model with all labels initialized to 0.
Definition EnergyModel.hpp:28
void set_unary_cost_fn(UnaryCostFn fn)
Sets a callback function for unary costs.
Definition EnergyModel.hpp:53
void set_edge_weights(const std::vector< int > &n1s, const std::vector< int > &n2s, const std::vector< T > &weights)
Sets per-edge smoothness weights (Potts model): an edge's weight is charged only when its two endpoints have different labels.
Definition EnergyModel.hpp:117
std::function< T(int node1, int node2, int label1, int label2)> PairwiseCostFn
Callable that returns the pairwise cost for assigning label1 to node1 and label2 to node2.
Definition EnergyModel.hpp:23
void set_unary_costs(const std::vector< T > &costs)
Sets unary costs from a flat row-major array of size num_nodes * num_labels.
Definition EnergyModel.hpp:91
void set_pairwise_cost_fn(PairwiseCostFn fn)
Sets a callback function for pairwise costs.
Definition EnergyModel.hpp:58
int get_label(int node) const
Returns the current label assigned to node.
Definition EnergyModel.hpp:38
void set_labels(const std::vector< int > &labels)
Replaces the full label vector.
Definition EnergyModel.hpp:48
int num_nodes() const
Returns the total number of nodes.
Definition EnergyModel.hpp:32
const std::vector< int > & get_labels() const
Returns the full label vector (one entry per node).
Definition EnergyModel.hpp:44
std::vector< int > get_active_nodes(int alpha_label) const
Returns the indices of all nodes that do not currently have alpha_label.
Definition EnergyModel.hpp:155
std::function< T(int node, int label)> UnaryCostFn
Callable that returns the unary cost for assigning label to node.
Definition EnergyModel.hpp:20
T get_unary_cost(int node, int label) const
Returns the unary cost for assigning label to node.
Definition EnergyModel.hpp:63
void add_grid_edges(int width, int height)
Populates a 4-connected grid neighbourhood for an image of size width × height.
Definition EnergyModel.hpp:134
T evaluate_total_energy() const
Evaluates the total energy for the model's current label assignment.
Definition EnergyModel.hpp:180
void add_neighbor(int node1, int node2)
Adds an undirected edge between node1 and node2.
Definition EnergyModel.hpp:127
T evaluate_total_energy(const std::vector< int > &eval_labels) const
Evaluates the total energy for a given label assignment.
Definition EnergyModel.hpp:166
void set_pairwise_costs(const std::vector< T > &costs)
Sets a global pairwise cost matrix of size num_labels * num_labels.
Definition EnergyModel.hpp:104