Source code for macrostat.causality.docstring_causality_analyzer

import ast
import inspect
import logging
import textwrap
from typing import Dict, List, Type

import pandas as pd

from macrostat.causality import CausalityAnalyzer
from macrostat.core import Model

# Module-level logger; the analyzer uses it to warn about an empty step()
# source and about called methods lacking a Dependency/Sets docstring section.
logger = logging.getLogger(__name__)


class DocstringCausalityAnalyzer(CausalityAnalyzer):
    """Build a causality graph from the docstrings of a model's behavior.

    The analyzer inspects the ``step()`` method of the model's behavior
    class, collects every ``self.<method>()`` call made there, and parses
    the ``Dependency`` and ``Sets`` sections of those methods' docstrings.
    The parsed relations are turned into an adjacency matrix in which a
    ``1`` marks "row entity is a dependency of column entity".
    """

    def __init__(self, model_class: Type[Model]):
        """Store the model class whose behavior docstrings are analyzed."""
        super().__init__(model_class=model_class)

    def analyze(self):
        """Analyze the model class and return its adjacency matrix.

        Returns
        -------
        pd.DataFrame
            The adjacency matrix produced by :meth:`build_adjacency_matrix`.
        """
        # Gather the docstrings of the methods called by step()
        self._parse_behavior_docstrings()

        # Parse each docstring into its Dependency / Sets relations
        self._relations = {
            name: self._parse_docstring(doc)
            for name, doc in self._docstrings.items()
        }

        self.build_adjacency_matrix()
        return self.adjacency_matrix

    def build_adjacency_matrix(self):
        """Build the adjacency matrix from the parsed relations.

        The matrix is square: both the rows (sources) and the columns
        (targets) are indexed by every ``(type, name)`` pair that appears
        either in a ``Dependency`` section or in a ``Sets`` section. An
        entry is ``1`` when the row entity is listed as a dependency of a
        method that sets the column state variable, and ``0`` otherwise.

        Returns
        -------
        pd.DataFrame
            The adjacency matrix, also stored as ``self.adjacency_matrix``.
        """
        # Gather all (type, name) pairs named in dependencies and sets
        entities = set()
        for components in self._relations.values():
            for type_name, names in components["Dependency"].items():
                entities.update((type_name, name) for name in names)
            entities.update(
                ("state", name) for name in components["Sets"]["state"]
            )

        # Sort once so rows and columns share one explicit, deterministic
        # order instead of relying on set iteration order.
        ordered = sorted(entities)

        self.adjacency_matrix = (
            pd.DataFrame(
                0.0,
                index=pd.MultiIndex.from_tuples(
                    ordered, names=["source_type", "source_name"]
                ),
                columns=pd.MultiIndex.from_tuples(
                    ordered, names=["target_type", "target_name"]
                ),
                dtype=float,
            )
            .sort_index(axis=0)
            .sort_index(axis=1)
        )

        # Fill adjacency matrix with weight 1 for all dependencies
        for components in self._relations.values():
            for target in components["Sets"]["state"]:
                for type_name, names in components["Dependency"].items():
                    for name in names:
                        self.adjacency_matrix.loc[
                            (type_name, name), ("state", target)
                        ] = 1.0

        return self.adjacency_matrix

    ###########################################################################
    # Docstring parsing methods
    ###########################################################################

    def _get_methods_called_by_step(self):
        """Collect the names of behavior methods called by ``step()``.

        The source of the behavior's ``step`` is parsed and every call of
        the form ``self.<name>(...)`` is recorded. Names are kept once, in
        the order of their first appearance in the source text, and only
        if the behavior class (including its parents) has such an attribute.

        Sets
        -------
        _called_methods : Tuple[str, ...]
            Names of the methods called by ``step()``, in source order.
        """
        behavior = self.model_class().behavior

        # The method source is indented inside its class; dedent it so that
        # ast.parse accepts it as a top-level definition.
        source = textwrap.dedent(inspect.getsource(behavior.step))
        if not source.strip():
            logger.warning(
                "Empty source code for step method in %s", behavior.__name__
            )
            # Always leave the attribute defined so callers can rely on it
            self._called_methods = ()
            return

        step_node = ast.parse(source)

        # Look for method calls of the form self.method_name(...)
        calls = [
            node
            for node in ast.walk(step_node)
            if isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "self"
        ]
        # ast.walk gives no ordering guarantee; restore source order
        calls.sort(key=lambda node: (node.lineno, node.col_offset))

        # dict.fromkeys de-duplicates while keeping first-seen order
        unique_names = dict.fromkeys(node.func.attr for node in calls)

        # hasattr on the class follows the MRO, so inherited implementations
        # are found and each name is recorded exactly once.
        self._called_methods = tuple(
            name for name in unique_names if hasattr(behavior, name)
        )

    def _parse_behavior_docstrings(self):
        """Collect docstrings of the public methods called by ``step()``.

        Only methods whose docstring mentions ``Dependency`` or ``Sets``
        are kept; for every other called public method a warning is logged.

        Sets
        -------
        _docstrings : Dict[str, str]
            Dictionary mapping method names to their docstrings.
        _called_methods : Tuple[str, ...]
            Set indirectly via :meth:`_get_methods_called_by_step`.
        """
        behavior = self.model_class().behavior
        self._docstrings = {}
        self._get_methods_called_by_step()
        called = set(self._called_methods)

        for name, method in inspect.getmembers(
            behavior, predicate=inspect.isfunction
        ):
            # Skip private methods and methods not called by step()
            if name.startswith("_") or name not in called:
                continue

            doc = method.__doc__
            if doc and ("Dependency" in doc or "Sets" in doc):
                self._docstrings[name] = doc
            else:
                logger.warning("No Dependency or Sets section in %s", name)

    def _parse_docstring(self, docstring: str) -> Dict[str, Dict[str, List[str]]]:
        """Parse a docstring into its dependencies and sets.

        Section titles are "underlined" with a variable number of "-"
        characters. Within the ``Dependency`` section each ``type: name``
        line adds ``name`` to the list for ``type``; lines without a colon
        are ignored. Within the ``Sets`` section every line is taken as the
        name of a state variable. Hyphens are removed from type names and
        from ``Sets`` entries.

        Parameters
        ----------
        docstring : str
            The docstring to parse.

        Returns
        -------
        Dict[str, Dict[str, List[str]]]
            ``{"Dependency": {type: [names...]}, "Sets": {"state": [names...]}}``
        """
        result = {"Dependency": {}, "Sets": {"state": []}}

        def is_underline(text):
            # A line made up solely of dashes (and whitespace)
            return not text.replace("-", "").strip()

        # Split docstring into lines and remove empty lines
        lines = [line.strip() for line in docstring.split("\n") if line.strip()]

        current_section = None
        for i, line in enumerate(lines):
            if is_underline(line):
                # The title is the line above the underline; a dash-only
                # first line has no title and must not wrap around to the
                # last line of the docstring.
                if i > 0:
                    current_section = lines[i - 1]
                continue

            if i + 1 < len(lines) and is_underline(lines[i + 1]):
                # This line is itself a section title; it is not content
                continue

            if current_section == "Dependency":
                if ":" not in line:
                    continue
                # For Dependency section, parse type:value pairs
                type_name, value = line.split(":", 1)
                type_name = type_name.replace("-", "").strip()
                result["Dependency"].setdefault(type_name, []).append(
                    value.strip()
                )
            elif current_section == "Sets":
                # For Sets section, just add the state variable name
                result["Sets"]["state"].append(line.replace("-", "").strip())

        return result