from typing import Type
import dash
import dash_cytoscape as cyto
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from dash import html
from macrostat.core import Model
# Register the extra layouts
cyto.load_extra_layouts()
[docs]
class CausalityAnalyzer:
def __init__(self, model_class: Type[Model]):
self.model_class = model_class
self.adjacency_matrix = None
self._dependency = {}
[docs]
def analyze(self):
"""Analyze a model class and return dependency dictionary"""
raise NotImplementedError("Subclasses must implement this method")
[docs]
def build_adjacency_matrix(self) -> pd.DataFrame:
"""Build adjacency matrix from dependencies"""
raise NotImplementedError("Subclasses must implement this method")
[docs]
def check_for_cycles(self):
"""Check for cycles in the model's dependency graph.
While cycles are not necessarily a problem, cycling logic in a given
step may indicate indeterminism. This method returns a list of cycles
found in the created adjacency matrix.
Returns
-------
list
List of cycles found in the graph. Each cycle is a list of nodes.
Returns empty list if no cycles are found.
"""
if self.adjacency_matrix is None:
self.adjacency_matrix = self.analyze()
# Create edgelist from adjacency matrix
edgelist = self._create_edgelist()
# Create directed graph
G = nx.DiGraph()
G.add_edges_from(
[(row["source"], row["target"]) for _, row in edgelist.iterrows()]
)
# Find cycles using simple_cycles
cycles = list(nx.recursive_simple_cycles(G))
return cycles
[docs]
def plot_heatmap(self): # pragma: no cover
"""Visualize the adjacency matrix organized by trophic levels.
Parameters
----------
figsize : tuple, optional
Figure size in inches (width, height), by default (12, 10)
cmap : str, optional
Colormap to use, by default 'viridis'
Returns
-------
matplotlib.figure.Figure
The figure object containing the plot
"""
if self.adjacency_matrix is None:
self.analyze()
# Get the matrix data
matrix = self.adjacency_matrix.values
# Calculate trophic levels using the adjacency matrix
# A variable's trophic level is 1 + max(trophic level of its dependencies)
n = len(matrix)
trophic_levels = np.zeros(n)
max_iter = n # Maximum number of iterations to prevent infinite loops
for _ in range(max_iter):
new_levels = np.zeros(n)
for i in range(n):
# Find all variables that this variable depends on
dependencies = np.where(matrix[i, :] > 0)[0]
if len(dependencies) > 0:
new_levels[i] = 1 + np.max(trophic_levels[dependencies])
else:
new_levels[i] = 1 # Base level for variables with no dependencies
if np.allclose(new_levels, trophic_levels):
break
trophic_levels = new_levels
# Sort indices by trophic level
sorted_indices = np.argsort(trophic_levels)[::-1]
# Create a new DataFrame with reordered indices
reordered_matrix = self.adjacency_matrix.iloc[
sorted_indices, sorted_indices
].copy()
# Replace zeros with NaN for better visualization
reordered_matrix = reordered_matrix.replace(0, np.nan)
# Remove rows and columns that are all NaN
reordered_matrix = reordered_matrix.dropna(axis=1, how="all")
reordered_matrix = reordered_matrix.dropna(axis=0, how="all")
# Create the figure
# Calculate figure size based on number of variables
n_rows, n_cols = reordered_matrix.shape
figsize = (max(8, n_cols * 0.5), max(6, n_rows * 0.5))
fig, ax = plt.subplots(figsize=figsize)
ax.imshow(reordered_matrix.values)
ax.grid(which="major", color="lightgray", linewidth=0.5)
# Set the labels
ax.set_xticks(range(len(reordered_matrix.columns)))
ax.set_yticks(range(len(reordered_matrix.index)))
# Format the labels to show type and name
x_labels = [f"{label[0]}\n{label[1]}" for label in reordered_matrix.columns]
y_labels = [f"{label[0]}\n{label[1]}" for label in reordered_matrix.index]
ax.set_xticklabels(x_labels, rotation=90, ha="center", va="top")
ax.set_yticklabels(y_labels)
# Add source/target labels
ax.set_xlabel("Target Variables", labelpad=10)
ax.set_ylabel("Source Variables", labelpad=10)
# Adjust layout to prevent label cutoff
plt.tight_layout()
return fig, ax
[docs]
def plot_with_cytoscape(
self,
port: int = 8050,
node_styles: dict[str, dict[str, str]] = {
"state": {"background-color": "lightblue", "border-color": "blue"},
"parameters": {"background-color": "lightgreen", "border-color": "green"},
"hyper": {"background-color": "lavender", "border-color": "purple"},
"scenario": {"background-color": "lightyellow", "border-color": "yellow"},
"prior": {"background-color": "lightpink", "border-color": "red"},
"history": {"background-color": "lightgray", "border-color": "black"},
},
): # pragma: no cover
"""Create an interactive flowchart visualization using Dash and Cytoscape.
This method creates a web-based interactive visualization of the model's structure
using Dash and Cytoscape. The visualization allows for:
- Zooming and panning
- Node selection and highlighting
- Edge highlighting
- Node dragging and repositioning
- Export to various formats
Parameters
----------
port : int, optional
The port number to run the Dash server on, by default 8050
Returns
-------
None
Opens a web browser with the interactive visualization
"""
if self.adjacency_matrix is None:
self.adjacency_matrix = self.analyze()
# Create nodes and edges for Cytoscape
nodes = []
edges = []
# Create edgelist using the same method as NetworkX version
edgelist = self._create_edgelist()
edgelist = edgelist[
edgelist["source_type"].isin(node_styles.keys())
& edgelist["target_type"].isin(node_styles.keys())
]
# Get unique nodes from both source and target
all_nodes = pd.concat([edgelist["source"], edgelist["target"]]).unique()
# Add nodes
for node in all_nodes:
node_type = node.split(":")[0]
node_name = node.split(":")[1]
node_style = node_styles.get(node_type, {})
nodes.append(
{
"data": {
"id": str(node), # Ensure ID is string
"label": node_name,
"type": node_type,
},
"style": node_style,
}
)
# Add edges from edgelist
for _, row in edgelist.iterrows():
edges.append(
{
"data": {
"source": str(row["source"]),
"target": str(row["target"]),
"weight": float(row["weight"]),
}
}
)
# Create Dash app
app = dash.Dash(__name__)
# Define the layout
app.layout = html.Div(
[
html.H1("Model Structure", style={"textAlign": "center"}),
cyto.Cytoscape(
id="model-graph",
layout={
"name": "dagre",
"rankDir": "LR",
"nodeSep": 50,
"rankSep": 200,
"spacingFactor": 2.0,
},
style={"width": "100%", "height": "1200px"},
elements=nodes + edges,
stylesheet=[
# Node styles
{
"selector": "node",
"style": {
"content": "data(label)",
"text-valign": "center",
"text-halign": "center",
"text-wrap": "wrap",
"text-max-width": "160px",
"font-size": "28px",
"font-weight": "normal",
"background-color": "data(background-color)",
"border-color": "data(border-color)",
"border-width": 3,
"width": "160px",
"height": "80px",
"padding": "25px",
},
},
# Edge styles
{
"selector": "edge",
"style": {
"width": 4,
"line-color": "#666",
"target-arrow-color": "#666",
"target-arrow-shape": "triangle",
"target-arrow-scale": 3,
"curve-style": "bezier",
},
},
# Hover effects
{
"selector": "node:hover",
"style": {
"background-color": "#BEE",
"line-color": "#000",
"target-arrow-color": "#000",
"source-arrow-color": "#000",
"text-outline-color": "#000",
"text-outline-width": 2,
},
},
{
"selector": "edge:hover",
"style": {
"width": 3,
"line-color": "#000",
"target-arrow-color": "#000",
},
},
],
),
html.Div(
# Create a legend entry for each node typeå
[
html.Div(
[
html.Div(
style={
"width": "20px",
"height": "20px",
"backgroundColor": style.get(
"background-color", "#FFF"
),
"border": f"3px solid {style.get('border-color', '#000')}",
"marginRight": "10px",
"display": "inline-block",
}
),
html.Span(
node_type.title(),
style={"fontSize": "16px"},
),
],
style={
"marginRight": (
"20px" if i < len(node_styles) - 1 else ""
),
"display": "inline-block",
},
)
for i, (node_type, style) in enumerate(node_styles.items())
],
style={
"padding": "10px",
"backgroundColor": "#f8f9fa",
"borderRadius": "5px",
"marginBottom": "20px",
"display": "flex",
"justifyContent": "center",
},
),
],
)
# Run the app
app.run(debug=True, port=port)
def _create_edgelist(self):
"""Create edgelist from adjacency matrix"""
edgelist = self.adjacency_matrix.copy(deep=True)
edgelist = edgelist.stack([0, 1], future_stack=True)
edgelist.name = "weight"
edgelist = edgelist[edgelist != 0]
edgelist = edgelist.reset_index()
edgelist["source"] = edgelist["source_type"] + ":" + edgelist["source_name"]
edgelist["target"] = edgelist["target_type"] + ":" + edgelist["target_name"]
return edgelist