5#include <unordered_map>
23 using PairwiseCostFn = std::function<T(
int node1,
int node2,
int label1,
int label2)>;
32 [[nodiscard]]
int num_nodes()
const {
return num_nodes_; }
35 [[nodiscard]]
int num_labels()
const {
return num_labels_; }
38 [[nodiscard]]
int get_label(
int node)
const {
return labels_[node]; }
41 void set_label(
int node,
int label) { labels_[node] = label; }
44 [[nodiscard]]
const std::vector<int>&
get_labels()
const {
return labels_; }
48 void set_labels(
const std::vector<int>& labels) { labels_ = labels; }
64 if (!unary_costs_.empty()) {
65 return unary_costs_[node * num_labels_ + label];
67 return unary_cost_fn_ ? unary_cost_fn_(node, label) : 0;
74 if (!edge_weights_.empty()) {
75 auto it = edge_weights_.find(edge_key(node1, node2));
76 if (it != edge_weights_.end()) {
77 return label1 == label2 ? T{0} : it->second;
80 if (!pairwise_costs_.empty()) {
81 return pairwise_costs_[label1 * num_labels_ + label2];
83 return pairwise_cost_fn_ ? pairwise_cost_fn_(node1, node2, label1, label2) : 0;
92 if (costs.size() !=
static_cast<size_t>(num_nodes_ * num_labels_)) {
93 throw std::invalid_argument(
"Unary costs array must have size num_nodes * num_labels");
105 if (costs.size() !=
static_cast<size_t>(num_labels_ * num_labels_)) {
106 throw std::invalid_argument(
"Pairwise costs array must have size num_labels * num_labels");
108 pairwise_costs_ = costs;
117 void set_edge_weights(
const std::vector<int>& n1s,
const std::vector<int>& n2s,
const std::vector<T>& weights) {
118 if (n1s.size() != n2s.size() || n1s.size() != weights.size()) {
119 throw std::invalid_argument(
"n1s, n2s, and weights must have the same size");
121 for (
size_t i = 0; i < n1s.size(); ++i) {
122 edge_weights_[edge_key(n1s[i], n2s[i])] = weights[i];
128 neighbors_[node1].push_back(node2);
129 neighbors_[node2].push_back(node1);
135 if (width * height != num_nodes_) {
136 throw std::invalid_argument(
"Grid dimensions do not match the number of nodes");
138 for (
int y = 0; y < height; ++y) {
139 for (
int x = 0; x < width; ++x) {
140 int node = y * width + x;
149 return neighbors_[node];
156 std::vector<int> active_nodes;
157 active_nodes.reserve(num_nodes_);
158 for (
int i = 0; i < num_nodes_; ++i) {
159 if (labels_[i] != alpha_label) active_nodes.push_back(i);
168 for (
int i = 0; i < num_nodes_; ++i) {
170 for (
int neighbor : neighbors_[i]) {
185 int64_t edge_key(
int n1,
int n2)
const {
186 int a = std::min(n1, n2), b = std::max(n1, n2);
187 return (int64_t)a * num_nodes_ + b;
192 std::vector<int> labels_;
193 std::vector<std::vector<int>> neighbors_;
196 std::vector<T> unary_costs_;
197 std::vector<T> pairwise_costs_;
198 std::unordered_map<int64_t, T> edge_weights_;
Stores the graph and energy costs for the Alpha-Expansion algorithm.
Definition EnergyModel.hpp:17
int num_labels() const
Returns the total number of labels.
Definition EnergyModel.hpp:35
const std::vector< int > & get_neighbors(int node) const
Returns the neighbours of node.
Definition EnergyModel.hpp:148
T get_pairwise_cost(int node1, int node2, int label1, int label2) const
Returns the pairwise cost for assigning label1 to node1 and label2 to node2.
Definition EnergyModel.hpp:73
void set_label(int node, int label)
Assigns label to node.
Definition EnergyModel.hpp:41
EnergyModel(int num_nodes, int num_labels)
Constructs an energy model with all labels initialized to 0.
Definition EnergyModel.hpp:28
void set_unary_cost_fn(UnaryCostFn fn)
Sets a callback function for unary costs.
Definition EnergyModel.hpp:53
void set_edge_weights(const std::vector< int > &n1s, const std::vector< int > &n2s, const std::vector< T > &weights)
Sets per-edge smoothness weights (Potts model).
Definition EnergyModel.hpp:117
std::function< T(int node1, int node2, int label1, int label2)> PairwiseCostFn
Definition EnergyModel.hpp:23
void set_unary_costs(const std::vector< T > &costs)
Sets unary costs from a flat row-major array of size num_nodes * num_labels.
Definition EnergyModel.hpp:91
void set_pairwise_cost_fn(PairwiseCostFn fn)
Sets a callback function for pairwise costs.
Definition EnergyModel.hpp:58
int get_label(int node) const
Returns the current label assigned to node.
Definition EnergyModel.hpp:38
void set_labels(const std::vector< int > &labels)
Replaces the full label vector.
Definition EnergyModel.hpp:48
int num_nodes() const
Returns the total number of nodes.
Definition EnergyModel.hpp:32
const std::vector< int > & get_labels() const
Returns the full label vector (one entry per node).
Definition EnergyModel.hpp:44
std::vector< int > get_active_nodes(int alpha_label) const
Returns the indices of all nodes that do not currently have alpha_label.
Definition EnergyModel.hpp:155
std::function< T(int node, int label)> UnaryCostFn
Callable that returns the unary cost for assigning label to node.
Definition EnergyModel.hpp:20
T get_unary_cost(int node, int label) const
Returns the unary cost for assigning label to node.
Definition EnergyModel.hpp:63
void add_grid_edges(int width, int height)
Populates a 4-connected grid neighbourhood for an image of size width × height.
Definition EnergyModel.hpp:134
T evaluate_total_energy() const
Evaluates the total energy for the model's current label assignment.
Definition EnergyModel.hpp:180
void add_neighbor(int node1, int node2)
Adds an undirected edge between node1 and node2.
Definition EnergyModel.hpp:127
T evaluate_total_energy(const std::vector< int > &eval_labels) const
Evaluates the total energy for a given label assignment.
Definition EnergyModel.hpp:166
void set_pairwise_costs(const std::vector< T > &costs)
Sets a global pairwise cost matrix of size num_labels * num_labels.
Definition EnergyModel.hpp:104