Source code for macrostat.models.torchmodel

# -*- coding: utf-8 -*-
"""
Generic model class for models implemented using PyTorch
"""

__author__ = ["Karl Naumann-Woleske"]
__credits__ = ["Karl Naumann-Woleske"]
__license__ = "MIT"
__version__ = "0.1.0"
__maintainer__ = ["Karl Naumann-Woleske"]

import logging

logger = logging.getLogger(__name__)

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as torchfunc

import macrostat.models.model as model


[docs] class TorchModel(model.Model, torch.nn.Module): """Generic model class for models implemented using PyTorch This class provides a wrapper for users to write their underlying model behavior while maintaining a uniformly accessible interface. Specifically, the user is expected to adapt the model.forward() function to their needs. The use of PyTorch allows for automatic differentiation which has computational advantages compared to finite differences. Attributes ---------- name: str Name of the model, such as "model". Used for file and database names parameters : dict Dictionary of all parameters hyper_parameters : dict Dictionary of all hyperparameters output : pd.DataFrame None, or the latest simulation run for given parameters Example ------- A general workflow for a model might look like >>> model = Model(parameters, hyper_parameters) >>> output = model.simulate() >>> model.save() """ def __init__( self, parameters: dict = None, hyper_parameters: dict = None, name: str = "torchmodel", ): """Initialisation of the TorchModel class Parameters ---------- parameters : dict dictionary of the named parameters of the model hyper_parameters : dict dictionary of hyper-parameters related to the model name : str (default 'model') name of the model (for use in filenaming) """ # Check hyperparameters, adding defaults if necessary if hyper_parameters is None: hyper_parameters = self._default_hyper_parameters() else: for k, v in self._default_hyper_parameters().items(): if k not in hyper_parameters: hyper_parameters[k] = v # Initialize the parent classes model.Model.__init__( self, parameters=parameters, hyper_parameters=hyper_parameters, name=name ) torch.nn.Module.__init__(self) # For PyTorch we define the parameter order self.parameter_order = tuple(self.parameters.keys()) # Generate pytorch parameters self.tparam = {} self._update_tparam()
[docs] def simulate(self) -> pd.DataFrame: """Simulate a model run using the stored parameters This function is designed to be overwritten by the user's specific implementation of their model. Note that it is expected for the user to set the ''self.output'' attribute to the output generated. The function will run ''self.initialize_simulation'' to set up the model and then run the forward pass of the model. Furthermore, in the pure simulation case, we omit the gradient calculation. Returns ------- output : pd.DataFrame Output of the model. Generically it should have a "time"-like index and variables across the columns """ self._update_tparam() self.initialize_simulation() with torch.no_grad(): self.forward(**self.tparam) self.outputs = pd.DataFrame( {k: v.cpu().clone().detach().numpy() for k, v in self.outputs.items()} ) return self.outputs
[docs] def initialize_simulation(self): """Initialize the model's state variables before running the simulation For instance, if one wants to load in a model state to start from. """ pass
[docs] def forward(self, *args, **kwargs): """Run the model forward through time, e.g. the loop over timesteps goes here. This method is called by the simulation method and by the pytorch autograd system. Generally, it shouldn't be called directly by the user. """ return NotImplementedError
def _update_tparam(self): """Update the pytorch parameters from the parameters dictionary. Ensuring that the parameters are on the correct device and require a gradient """ for k, v in self.parameters.items(): self.tparam[k] = torch.tensor( v, dtype=torch.float64, requires_grad=True, device=self.hyper_parameters["device"], ) def _default_hyper_parameters(self) -> dict: """Return the default hyper parameters. This is primarily a dictionary of the PyTorch specific hyperparameters such as constants and the device. Returns ------- parameters : dict dictionary of the parameters """ return { "device": "cpu", "requires_grad": True, "diffwhere": True, "sigmoid_constant": 100, "tanh_constant": 100, "min_constant": 10, "max_constant": 10, } ### Some Differentiable PyTorch Alternatives
[docs] def diffwhere(self, condition, x1, x2): """Where condition that is differentiable with respect to the condition. Requires: self.hyper_parameters['diffwhere'] = True self.hyper_parameters['sigmoid_constant'] as a large number Note: For non-NaN/inf, where(x > eps, z, y) is (x - eps > 0) * (z - y) + y, so we can use the sigmoid function to approximate the where function. Parameters ---------- condition : torch.Tensor Condition to be evaluated expressed as x - eps x1 : torch.Tensor Value to be returned if condition is True x2 : torch.Tensor Value to be returned if condition is False """ if self.hyper_parameters["diffwhere"]: sig = torch.sigmoid( torch.mul(condition, self.hyper_parameters["sigmoid_constant"]) ) out = torch.add(torch.mul(sig, torch.sub(x1, x2)), x2) else: out = torch.where(condition > 0, x1, x2) return out
[docs] def tanhmask(self, x): """Convert a variable into 0 (x<0) and 1 (x>0)""" kwg = {"dtype": torch.float64, "requires_grad": True} return torch.div( torch.add( torch.ones(x.size(), **kwg), torch.tanh(torch.mul(x, self.hyper_parameters["tanh_constant"])), ), torch.tensor(2.0, **kwg), )
[docs] def diffmin(self, x1, x2): """Smooth approximation to the minimum B: https://mathoverflow.net/questions/35191/a-differentiable-approximation-to-the-minimum-function Requires: self.hyper_parameters['min_constant'] as a large number """ r = self.hyper_parameters["min_constant"] pt1 = torch.exp(torch.mul(x1, -1 * r)) pt2 = torch.exp(torch.mul(x2, -1 * r)) return torch.mul(-1 / r, torch.log(torch.add(pt1, pt2)))
[docs] def diffmax(self, x1, x2): """Smooth approximation to the minimum B: https://mathoverflow.net/questions/35191/a-differentiable-approximation-to-the-minimum-function Requires: self.hyper_parameters['max_constant'] as a large number """ r = self.hyper_parameters["max_constant"] pt1 = torch.exp(torch.mul(x1, r)) pt2 = torch.exp(torch.mul(x2, r)) return torch.mul(1 / r, torch.log(torch.add(pt1, pt2)))
[docs] def diffmin_v(self, x): """Smooth approximation to the minimum. See diffmin Requires: self.hyper_parameters['min_constant'] as a large number """ r = self.hyper_parameters["min_constant"] temp = torch.exp(torch.mul(x, -1 * r)) return torch.mul(-1 / r, torch.log(torch.sum(temp)))
[docs] def diffmax_v(self, x): """Smooth approximation to the maximum for a tensor. See diffmax Requires: self.hyper_parameters['max_constant'] as a large number """ r = self.hyper_parameters["max_constant"] temp = torch.exp(torch.mul(x, r)) return torch.mul(1 / r, torch.log(torch.sum(temp)))
if __name__ == "__main__": torch_model = TorchModel({}) x = torch.tensor([1.0, 2.0, 3.0]) torch_model.hyper_parameters["min_constant"] = 10.0 result = torch_model.diffmax_v(x) print(result) pass