Skip to content

Add distance_matrix to rustworkx-core #1439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions releasenotes/notes/add-distance-matrix-8cbe417d6f4eaf6d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added a new function ``rustworkx_core::shortest_path::distance_matrix``
to rustworkx-core. This function is the equivalent of :func:`.distance_matrix`
for the Python library, but as a generic Rust function for rustworkx-core.
159 changes: 159 additions & 0 deletions rustworkx-core/src/shortest_path/distance_matrix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use std::hash::Hash;

use hashbrown::HashMap;

use fixedbitset::FixedBitSet;
use ndarray::prelude::*;
use petgraph::visit::{
GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, NodeIndexable,
};
use petgraph::{Incoming, Outgoing};
use rayon::prelude::*;

/// Compute the unweighted all-pairs distance matrix of a graph.
///
/// Every edge is treated as having a weight of 1.0, so entry `[i, j]` of the
/// result is the number of edges on a shortest path from node `i` to node
/// `j`, found by running a breadth-first search from every node. The diagonal
/// is always 0.0 and every pair without a connecting path keeps `null_value`.
///
/// Rows and columns are indexed with `graph.to_index()` when the node index
/// space is contiguous (`node_count() == node_bound()`). Otherwise they follow
/// the order in which `graph.node_identifiers()` yields the nodes.
///
/// This function is multithreaded and will run in parallel if the number of
/// nodes in the graph is at least `parallel_threshold`. When it runs in
/// parallel the env var `RAYON_NUM_THREADS` can be used to adjust how many
/// threads will be used.
///
/// # Arguments
///
/// * `graph` - The graph object to compute the distance matrix for.
/// * `parallel_threshold` - The threshold in number of nodes to run this
///   function in parallel. If `graph` has fewer nodes than this the algorithm
///   will run serially. A good default to use for this is 300.
/// * `as_undirected` - If set to `true`, both the incoming and the outgoing
///   neighbors of a node are followed during the traversal, i.e. every edge of
///   a directed input graph is treated as bidirectional when generating the
///   output matrix. If `false`, only `graph.neighbors()` is followed.
/// * `null_value` - The value to use for the absence of a path in the graph.
///
/// # Returns
///
/// A 2d ndarray [`Array`] of shape `(node_count, node_count)` holding the
/// distance matrix.
///
/// # Example
///
/// ```rust
/// use rustworkx_core::petgraph;
/// use rustworkx_core::shortest_path::distance_matrix;
/// use ndarray::{array, Array2};
///
/// let graph = petgraph::graph::UnGraph::<(), ()>::from_edges(&[
///     (0, 1), (0, 6), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)
/// ]);
/// let distance_matrix = distance_matrix(&graph, 300, false, 0.);
/// let expected: Array2<f64> = array![
///     [0.0, 1.0, 2.0, 3.0, 3.0, 2.0, 1.0],
///     [1.0, 0.0, 1.0, 2.0, 3.0, 3.0, 2.0],
///     [2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 3.0],
///     [3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 3.0],
///     [3.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0],
///     [2.0, 3.0, 3.0, 2.0, 1.0, 0.0, 1.0],
///     [1.0, 2.0, 3.0, 3.0, 2.0, 1.0, 0.0],
/// ];
/// assert_eq!(distance_matrix, expected)
/// ```
pub fn distance_matrix<G>(
    graph: G,
    parallel_threshold: usize,
    as_undirected: bool,
    null_value: f64,
) -> Array2<f64>
where
    G: Sync + IntoNeighborsDirected + NodeCount + NodeIndexable + IntoNodeIdentifiers + GraphProp,
    G::NodeId: Hash + Eq + Sync,
{
    let n = graph.node_count();
    // When the node bound differs from the node count the graph's own indices
    // cannot address an `n x n` matrix directly, so the nodes are renumbered
    // to `0..n` in `node_identifiers()` order. Both lookup tables stay empty
    // when the graph's indices are already usable as-is.
    let has_holes = n != graph.node_bound();
    let node_map_inv: Vec<G::NodeId> = if has_holes {
        graph.node_identifiers().collect()
    } else {
        Vec::new()
    };
    let node_map: HashMap<G::NodeId, usize> = node_map_inv
        .iter()
        .enumerate()
        .map(|(dense, &node)| (node, dense))
        .collect();
    // Node id -> matrix row/column.
    let to_dense = |node: G::NodeId| -> usize {
        if has_holes {
            node_map[&node]
        } else {
            graph.to_index(node)
        }
    };
    // Matrix row/column -> node id.
    let to_node = |index: usize| -> G::NodeId {
        if has_holes {
            node_map_inv[index]
        } else {
            graph.from_index(index)
        }
    };

    let mut matrix = Array2::<f64>::from_elem((n, n), null_value);

    // Precompute, for every matrix index, the set of adjacent matrix indices
    // as a bitset so that a whole BFS frontier can be expanded with bitwise
    // set operations.
    let neighbors: Vec<FixedBitSet> = (0..n)
        .map(|index| {
            let node = to_node(index);
            if as_undirected {
                graph
                    .neighbors_directed(node, Incoming)
                    .chain(graph.neighbors_directed(node, Outgoing))
                    .map(&to_dense)
                    .collect::<FixedBitSet>()
            } else {
                graph
                    .neighbors(node)
                    .map(&to_dense)
                    .collect::<FixedBitSet>()
            }
        })
        .collect();

    // Level-synchronous BFS from `start` that writes the level of every
    // reached node into `row`; unreached entries keep `null_value`.
    let bfs_traversal = |start: usize, mut row: ArrayViewMut1<f64>| {
        let mut distance = 0.0;
        let mut seen = FixedBitSet::with_capacity(n);
        let mut next = FixedBitSet::with_capacity(n);
        let mut cur = FixedBitSet::with_capacity(n);
        cur.put(start);
        while !cur.is_clear() {
            next.clear();
            for found in cur.ones() {
                row[[found]] = distance;
                next |= &neighbors[found];
            }
            // Drop everything already visited so each node is assigned the
            // distance of the first level it appears in.
            seen.union_with(&cur);
            next.difference_with(&seen);
            distance += 1.0;
            ::std::mem::swap(&mut cur, &mut next);
        }
    };

    if n < parallel_threshold {
        matrix
            .axis_iter_mut(Axis(0))
            .enumerate()
            .for_each(|(index, row)| bfs_traversal(index, row));
    } else {
        // Parallelize by row and iterate from each row index in BFS order
        matrix
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(index, row)| bfs_traversal(index, row));
    }
    matrix
}
2 changes: 2 additions & 0 deletions rustworkx-core/src/shortest_path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ mod all_shortest_paths;
mod astar;
mod bellman_ford;
mod dijkstra;
mod distance_matrix;
mod k_shortest_path;
mod single_source_all_shortest_paths;

pub use all_shortest_paths::all_shortest_paths;
pub use astar::astar;
pub use bellman_ford::{bellman_ford, negative_cycle_finder};
pub use dijkstra::dijkstra;
pub use distance_matrix::distance_matrix;
pub use k_shortest_path::k_shortest_path;
pub use single_source_all_shortest_paths::single_source_all_shortest_paths;
91 changes: 2 additions & 89 deletions src/shortest_path/distance_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,105 +10,18 @@
// License for the specific language governing permissions and limitations
// under the License.

use std::ops::Index;

use hashbrown::{HashMap, HashSet};

use ndarray::prelude::*;
use petgraph::prelude::*;
use petgraph::EdgeType;
use rayon::prelude::*;

use crate::NodesRemoved;
use crate::StablePyGraph;

/// Look up `x` in an optional indexable container.
///
/// Returns a copy of `map[x]` when `map_fn` holds a container and `default`
/// when it is `None`.
#[inline]
fn apply<I, M>(
    map_fn: &Option<M>,
    x: I,
    default: <M as Index<I>>::Output,
) -> <M as Index<I>>::Output
where
    M: Index<I>,
    <M as Index<I>>::Output: Sized + Copy,
{
    map_fn.as_ref().map_or(default, |lookup| lookup[x])
}
use rustworkx_core::shortest_path;

/// Compute the unweighted distance matrix of a [`StablePyGraph`].
///
/// Thin wrapper that forwards every argument unchanged to
/// `rustworkx_core::shortest_path::distance_matrix`, which holds the actual
/// breadth-first-search implementation.
///
/// # Arguments
///
/// * `graph` - The graph to compute the distance matrix for.
/// * `parallel_threshold` - Forwarded to the core function as its
///   `parallel_threshold` argument.
/// * `as_undirected` - Forwarded to the core function as its `as_undirected`
///   argument.
/// * `null_value` - Forwarded to the core function as its `null_value`
///   argument.
///
/// # Returns
///
/// The 2d array returned by the core function.
pub fn compute_distance_matrix<Ty: EdgeType + Sync>(
    graph: &StablePyGraph<Ty>,
    parallel_threshold: usize,
    as_undirected: bool,
    null_value: f64,
) -> Array2<f64> {
    shortest_path::distance_matrix(graph, parallel_threshold, as_undirected, null_value)
}