Visualization
Source Code
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from hina.dyad import prune_edges
from hina.mesoscale import hina_communities
import networkx as nx
def plot_hina(B, layout='bipartite', group_name = [None, None], pruning_kwargs=None, NetworkX_kwargs=None):
"""
Visualizes a bipartite network with customizable layout, node grouping, and edge pruning.
This function creates a visualization of a bipartite graph `B` using a specified layout. It supports
pruning edges based on statistical significance, grouping nodes by a selected group information, and customizing
the appearance of the graph using NetworkX visualization parameters.
Parameters:
-----------
B : networkx.Graph
A bipartite graph to visualize. Nodes must have a 'bipartite' attribute indicating their partition.
If using HINA to analyze tripartite networks, it is recommended to visualize the object-object graphs for
each community after detecting communities with hina_communities(). These projected graphs represent
relationships between objects within each community and provide additional insights into the
structure and interactions of the network.
layout : str, optional
The layout to use for node positioning. Supported layouts are:
- 'bipartite': Nodes are positioned in two vertical columns (default).
- 'spring': Force-directed layout for a visually appealing arrangement.
- 'circular': Nodes are arranged in a circle.
group_name : list, optional
AA list of two elements specifying the node attribute and its corresponding value to filter nodes
for visualization. The first element is the name of the node attribute (e.g., 'group'),
which corresponds to a column in the input dataframe. The second element is the specific value of the
attribute (e.g., 'A'), used to filter nodes. Only nodes with this attribute value will be included
in the visualization. For example, ['group', 'A'] will include only nodes where the 'group' attribute is 'A'.
Default is `[None, None]`, which includes all nodes.
pruning_kwargs : dict, optional
A dictionary of parameters for pruning edges based on statistical significance.
If provided, the prune_edges function is called to remove edges that are not
statistically significant before visualization.For example, {'fix_deg': 'students'}
specifies that the degrees of the 'students' node set should be fixed during pruning.
Default is `None`, which skips pruning.
NetworkX_kwargs : dict, optional
Additional keyword arguments for customizing the NetworkX visualization (e.g., node size, edge color).
Default is `None`.
Returns:
--------
Displays a plot of the bipartite network.
"""
if NetworkX_kwargs is None:
NetworkX_kwargs = {}
# Prune edges if pruning_kwargs is provided
if pruning_kwargs is not None:
B = prune_edges(B, **pruning_kwargs)['pruned network']
if group_name is not None:
G_sub = nx.Graph()
u_nodes = [i for i, j in B.nodes(data=True) if j.get(group_name[0]) == group_name[1]]
for u_node in u_nodes:
G_sub.add_node(u_node, **B.nodes[u_node])
v_nodes = set()
for u_node in u_nodes:
for v_node in B.neighbors(u_node):
v_nodes.add(v_node)
G_sub.add_node(v_node, **B.nodes[v_node])
for u_node in u_nodes:
for v_node in B.neighbors(u_node):
if B.has_edge(u_node, v_node):
G_sub.add_edge(u_node, v_node, **B.edges[u_node, v_node])
B = G_sub
v = set()
bipartite_top = [x for x in [data['bipartite'] for n, data in B.nodes(data=True)] \
if not (x in v or v.add(x))][0]
# Create a color dictionary
color_dict = {n: 'red' if data['bipartite'] == bipartite_top else 'blue'
for n, data in B.nodes(data=True)}
# Get the list of nodes in the top partition for bipartite layout
top_nodes = [n for n, data in B.nodes(data=True) if data['bipartite'] == bipartite_top]
# Set the layout
if layout == 'bipartite':
pos = nx.bipartite_layout(B, top_nodes, align='vertical', scale=2, aspect_ratio=4)
elif layout == 'spring':
pos = nx.spring_layout(B, k=0.2)
elif layout == 'circular':
pos = nx.circular_layout(B)
else:
raise ValueError(f"Unsupported layout: {layout}")
# Calculate label offset
max_y = max(abs(y) for _, y in pos.values())
label_offset = max_y * 0.03
# Set node colors and edge widths
node_colors = [color_dict[n] for n in B.nodes()]
edge_widths = [d.get('weight', 1) / 15 for _, _, d in B.edges(data=True)] # Default weight = 1 if not present
# Plot the graph
plt.figure(figsize=(12, 12))
nx.draw(
B,
pos,
with_labels=False,
node_color=node_colors,
width=edge_widths,
node_size=200,
**NetworkX_kwargs
)
# Add labels
for node, (x, y) in pos.items():
label = str(node)
plt.text(
x, y + label_offset,
label,
fontsize=9,
ha='center',
va='center',
color='black'
)
plt.title("HINA Network Visualization")
plt.show()
def plot_bipartite_clusters(G, noise_scale=3, radius=20., encode_labels=False,
node_labels=True, edge_labels=False,
scale_nodes_by_degree=False, node_scale=2000.,
node_kwargs={'edgecolors': 'black'}, edge_kwargs={'edge_color': 'black'}):
"""
Visualizes a bipartite graph with nodes grouped into communities, highlighting the community structure.
This function plots a bipartite graph `G` with nodes arranged in a circular layout. Nodes from the first set
are positioned around the circumference, grouped by their community labels, while nodes from the second set
are positioned inside the circle. The visualization supports customizing node sizes, colors, and labels.
Parameters:
-----------
G : networkx.Graph
A bipartite graph with weighted edges. Nodes must belong to one of two sets (e.g., 'set1' and 'set2').
noise_scale : float, optional
Controls the dispersion of nodes in the first set around their community centroids. Higher values increase
randomness in node positions. Default is 3.
radius : float, optional
Controls the radius of the circle on which community centers are placed. Default is 20.
encode_labels : bool, optional
If True, encodes node labels as unique integers and prints the encoding map. Default is False.
node_labels : bool, optional
If True, displays labels for all nodes. If False, only displays labels for nodes in the second set.
Default is True.
edge_labels : bool, optional
If True, displays edge weights as labels. Default is False.
scale_nodes_by_degree : bool, optional
If True, scales node sizes proportionally to their weighted degree. Default is False.
node_scale : float, optional
Controls the average size of nodes. Default is 2000.
node_kwargs : dict, optional
Additional keyword arguments for customizing node appearance in `nx.draw_networkx_nodes`.
Default is `{'edgecolors': 'black'}`.
edge_kwargs : dict, optional
Additional keyword arguments for customizing edge appearance in `nx.draw_networkx_edges`.
Default is `{'edge_color': 'black'}`.
Returns:
--------
None
Displays a plot of the bipartite graph with nodes grouped by communities.
"""
community_labels = hina_communities(G)['node communities']
G_info = set([(i, j, w['weight']) for i, j, w in G.edges(data=True)])
set1 = set([str(e[0]) for e in G_info])
set2 = set([str(e[1]) for e in G_info])
offset = np.random.rand() * np.pi
B = len(set(community_labels.values()))
comm2ind = dict(zip(list(set(community_labels.values())), range(B)))
set1_pos = {}
for node in set1:
c = comm2ind[community_labels[node]]
angle = 2 * np.pi * c / B + offset
x = radius * np.cos(angle) + (2. * np.random.rand() - 1.) * noise_scale
y = radius * np.sin(angle) + (2. * np.random.rand() - 1.) * noise_scale
set1_pos[node] = (x, y)
set2_pos = {}
num_s2 = len(set2)
for c, node in enumerate(set2):
angle = 2 * np.pi * c / num_s2 + offset
x = 0.5 * radius * np.cos(angle)
y = 0.5 * radius * np.sin(angle)
set2_pos[node] = (x, y)
pos = {**set1_pos, **set2_pos}
comm_colors = dict(zip(list(set(community_labels.values())), list(mcolors.TABLEAU_COLORS.values())))
color_dict = {node: comm_colors[community_labels[node]] for node in set1} | {node: 'Gray' for node in set2}
node_colors = {node: color_dict[node] for node in G.nodes()}
edge_weights = [G[u][v]['weight'] for u, v in G.edges()]
max_weight = max(edge_weights)
edge_widths = [weight / max_weight * 5 for weight in edge_weights]
weighted_degrees = {node: sum(weight for _, _, weight in G.edges(node, data='weight'))
for node in G.nodes()}
if scale_nodes_by_degree:
avg = np.mean(list(weighted_degrees.values()))
node_sizes = {node: weighted_degrees[node] / avg * node_scale for node in G.nodes()}
else:
node_sizes = {node: node_scale for node in G.nodes()}
plt.figure(figsize=(20, 20))
nodes = [str(n) for n in set1] + [str(n) for n in set2]
if encode_labels:
codes = [i for i in range(len(nodes))]
labelmap = dict(zip(nodes, codes))
for node in set1:
print('Original Label (Set 1):', node, '| Encoded Label:', labelmap[node])
for node in set2:
print('Original Label (Set 2):', node, '| Encoded Label:', labelmap[node])
else:
labelmap = dict(zip(nodes, nodes))
shapes = {node: 'o' for node in set1} | {node: '^' for node in set2}
for node, shape in shapes.items():
nx.draw_networkx_nodes(G, pos, nodelist=[node], node_shape=shape,
node_color=node_colors[node], node_size=node_sizes[node], **node_kwargs)
nx.draw_networkx_edges(G, pos, width=edge_widths, **edge_kwargs)
label_options = {'bbox': {'facecolor': 'white', 'alpha': 1, 'edgecolor': 'black'}}
if node_labels:
# Show all node labels (both set1 and set2)
nx.draw_networkx_labels(G, pos, labels=labelmap, **label_options)
else:
# Show only set2 labels
set2_labelmap = {node: labelmap[node] for node in set2}
nx.draw_networkx_labels(G, pos, labels=set2_labelmap, **label_options)
if edge_labels:
edge_labels = nx.get_edge_attributes(G, 'weight')
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
plt.show()