How To: Setting ODE parameters with a Neural Network
This is the use case where the neural network is evaluated before the simulation and its output sets the values of parameters in the ODE. This approach can be referred to as a Deep Mechanistic Model (DMM). We assume some familiarity with the getting started tutorial, which works through an entire PEtab SciML problem, while this guide focuses on the parts that are relevant to this use case.
As a case study we will use the Lotka-Volterra ODE system,

\[
\begin{aligned}
\frac{\mathrm{d}x}{\mathrm{d}t} &= \alpha x - \beta x y, \\
\frac{\mathrm{d}y}{\mathrm{d}t} &= \delta x y - \gamma y,
\end{aligned}
\]

where \(x\) denotes the prey and \(y\) the predator abundance. The observable models, which link the measurement data to the model output, are defined over these two states. In our formulation the output from the network replaces the parameter \(\gamma\), so that the predator equation becomes

\[
\frac{\mathrm{d}y}{\mathrm{d}t} = \delta x y - f_{\text{net}}(\text{input0}) \, y,
\]

where \(f_{\text{net}}\) denotes the neural network.
The input to the neural network can either be vector-based (see other tutorials) or provided via an array file, as demonstrated in this tutorial.
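To make the workflow concrete before diving into the PEtab machinery, the following minimal sketch illustrates the DMM pattern with a toy stand-in for the network and scipy's solve_ivp as the ODE solver. All names here are illustrative; the rest of this guide uses the PEtab SciML and AMICI tooling instead.

[ ]:
import numpy as np
from scipy.integrate import solve_ivp

def toy_net(net_input):
    # Stand-in for the trained neural network: maps an input array to gamma.
    return float(np.tanh(net_input.mean()))

# The network is evaluated once, before the simulation starts ...
gamma = toy_net(np.random.rand(3, 10, 10))

# ... and its output sets a parameter of the ODE right-hand side.
def lotka_volterra(t, z, alpha=1.3, beta=0.9, delta=1.8):
    x, y = z
    return [alpha * x - beta * x * y, delta * x * y - gamma * y]

sol = solve_ivp(lotka_volterra, (0.0, 10.0), y0=[1.0, 1.0])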
Loading the PEtab problem
Let’s load the PEtab problem, build our model and define the overall hybrid problem as a JAXProblem.
[ ]:
from amici.petab import import_petab_problem
from amici.jax import (
JAXProblem,
run_simulations,
)
from petab.v2 import Problem
# Load the PEtab problem information from disk.
petab_problem = Problem.from_yaml("problem.yaml")
# Create a simulator for the ODE and NN models.
jax_model = import_petab_problem(
petab_problem,
model_output_dir="model",
compile_=True,
jax=True
)
# Create a JAXProblem to handle additional simulation information
# (e.g. simulating multiple conditions).
jax_problem = JAXProblem(jax_model, petab_problem)
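The run_simulations helper imported above can then be used to simulate all conditions at the nominal parameter values. A minimal usage sketch (we assume it returns the overall log-likelihood together with per-condition results, as in the AMICI JAX example notebooks; check the AMICI documentation for your version):

[ ]:
# Simulate all conditions defined in the PEtab problem (sketch; the exact
# return values may differ between AMICI versions).
llh, results = run_simulations(jax_problem)
print(f"Log-likelihood at nominal parameter values: {llh}")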
By looking at the hybridization, parameter and mapping tables we can see how this deep mechanistic modelling problem has been defined.
[3]:
jax_problem._petab_problem.hybridization_df
[3]:
| targetId | targetValue |
|---|---|
| gamma | net3_output1 |
The gamma parameter from the model is mapped to the first output of the neural network, net3_output1, via the hybridization and mapping tables.
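For reference, this one-row hybridization table could be produced with pandas along the following lines (a sketch; the file name matches the hybridization_files entry of the YAML configuration shown later):

[ ]:
import pandas as pd

# Map the ODE parameter gamma to the first network output (sketch).
hybridization_df = pd.DataFrame(
    {"targetId": ["gamma"], "targetValue": ["net3_output1"]}
).set_index("targetId")
hybridization_df.to_csv("hybridization.tsv", sep="\t")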
[4]:
petab_problem.mapping_df
[4]:
| petabEntityId | modelEntityId |
|---|---|
| input0 | net3.inputs[0] |
| net3_output1 | net3.outputs[0][0] |
| net3_ps | net3.parameters |
Finally, in the parameter table, the network parameters are listed with infinite bounds and specified to be estimated. Also note that \(\gamma\) does not appear in the parameter table, as its value is set by the neural network model.
[5]:
petab_problem.parameter_df
[5]:
| parameterId | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|---|---|---|---|---|---|
| alpha | lin | 0.0 | 15.0 | 1.3 | 1 |
| delta | lin | 0.0 | 15.0 | 1.8 | 1 |
| beta | lin | 0.0 | 15.0 | 0.9 | 1 |
| net3_ps | lin | -inf | inf | NaN | 1 |
Array input data
The PEtab problem YAML file, in the sciml section of the extensions field, specifies two files of array data to be included in the problem:

- net3_ps.hdf5 to set the values of the network parameters
- net3_input2.hdf5 to provide the inputs to the network
...
extensions:
  sciml:
    array_files:
      - "net3_ps.hdf5"
      - "net3_input2.hdf5"
    hybridization_files:
      - "hybridization.tsv"
    neural_nets:
      net3:
        location: "net3.yaml"
        static: true
        format: "YAML"
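Since the problem configuration is plain YAML, the extension entries can also be inspected programmatically, for example with PyYAML (a sketch; it assumes the file is named problem.yaml, as loaded above):

[ ]:
import yaml

# Load the problem YAML and list the array files of the sciml extension.
with open("problem.yaml") as f:
    config = yaml.safe_load(f)
print(config["extensions"]["sciml"]["array_files"])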
The nested structure of the neural network input file is shown below. The input data to the neural network is stored under inputs/input0, with data specified for two conditions. The keys for the different conditions match those defined in the conditions table.
[7]:
import h5py

# Convenience function to show the nested structure of an HDF5 file
def show_h5_struct(file):
    file.visit(lambda x: print(" " * (len(x.split("/")) - 1), x.split("/")[-1]))

file = h5py.File("net3_input2.hdf5", "r")
show_h5_struct(file)
 inputs
  input0
   cond1
   cond2
 metadata
  perm
[9]:
petab_problem.condition_df
[9]:
| conditionId |
|---|
| cond1 |
| cond2 |
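For reference, an input file with this layout could be generated with h5py along the following lines. This is a sketch: the array shapes are illustrative, chosen to match the 3-channel convolutional network defined below, and we omit the metadata/perm entry, whose contents are defined by the PEtab SciML standard.

[ ]:
import h5py
import numpy as np

# Write condition-specific network inputs in the layout shown above (sketch;
# the file name is illustrative to avoid overwriting the provided file).
with h5py.File("net3_input2_example.hdf5", "w") as f:
    grp = f.create_group("inputs/input0")
    grp.create_dataset("cond1", data=np.random.rand(3, 10, 10))
    grp.create_dataset("cond2", data=np.random.rand(3, 10, 10))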
Also note the static: true setting in the neural network definition of the extensions config. This means the input to the neural network does not depend on the model (e.g. on its states or time). This keyword indicates to PEtab SciML importers that the network precedes the ODE, as opposed to being inside it (i.e. a universal differential equation, UDE) or in one of the observable formulae.
Network Architecture
The PyTorch snippet below shows how a network architecture can be defined and exported to the YAML format used by PEtab SciML. The predefined YAML file is also provided in the PEtab SciML repository for completeness. We have used a convolutional architecture here to indicate how the DMM problem setup can enable the inclusion of information from high-dimensional inputs in the mechanistic model.
[15]:
from petab_sciml.standard.nn_model import Input, NNModel, NNModelStandard
import torch
from torch import nn
import torch.nn.functional as F
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 input channels -> 1 output channel, 5x5 kernel, no padding
        self.layer1 = torch.nn.Conv2d(
            3, 1, (5, 5), stride=(1, 1), padding=(0, 0), dilation=(1, 1)
        )
        self.layer2 = torch.nn.Flatten()
        # 36 = 1 * 6 * 6 flattened features for a 3x10x10 input
        self.layer3 = torch.nn.Linear(36, 1)

    def forward(self, net_input):
        x = self.layer1(net_input)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.relu(x)
        return x

net1 = NeuralNetwork()
# Convert the PyTorch module to the PEtab SciML network representation ...
nn_model1 = NNModel.from_pytorch_module(
    module=net1, nn_model_id="net3", inputs=[Input(input_id="input0")]
)
# ... and export it to the standard YAML format.
NNModelStandard.save_data(
    data=nn_model1, filename="net3_from_pytorch.yaml"
)
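As a quick sanity check of the architecture, note that a 3×10×10 input passed through the 5×5 convolution yields a 1×6×6 feature map, i.e. exactly the 36 features expected by the linear layer. A hypothetical dummy input confirms the output is a single scalar per sample:

[ ]:
# Illustrative dummy input: batch of one 3-channel 10x10 array.
dummy = torch.rand(1, 3, 10, 10)
assert net1(dummy).shape == (1, 1)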
Parameterizing network inputs
A straightforward alternative to defining network inputs as array data in files is to define them in the parameter table, with appropriate bounds and nominal values.
| parameterId | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|---|---|---|---|---|---|
| net1_input_pre1 | lin | -inf | inf | 1 | 0 |
| net1_input_pre2 | lin | -inf | inf | 1 | 0 |
The network inputs can also be defined as condition-specific, with different scalar values for different conditions.
| conditionId | net1_input1 | net1_input2 |
|---|---|---|
| cond1 | 10.0 | 20.0 |
| cond2 | net1_input_pre1 | net1_input_pre2 |
The corresponding mapping table then defines the model entities for these condition-specific PEtab identifiers.
| petabEntityId | modelEntityId |
|---|---|
| net1_input1 | net1.inputs[0][0] |
| net1_input2 | net1.inputs[0][1] |
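For completeness, these two tables could be written programmatically, for instance with pandas (a sketch; the file names are illustrative):

[ ]:
import pandas as pd

# Condition-specific network inputs: literal values for cond1, parameter
# references for cond2 (sketch; file names are illustrative).
conditions = pd.DataFrame(
    {
        "conditionId": ["cond1", "cond2"],
        "net1_input1": [10.0, "net1_input_pre1"],
        "net1_input2": [20.0, "net1_input_pre2"],
    }
)
conditions.to_csv("conditions.tsv", sep="\t", index=False)

# Map the PEtab identifiers to the corresponding network input entries.
mapping = pd.DataFrame(
    {
        "petabEntityId": ["net1_input1", "net1_input2"],
        "modelEntityId": ["net1.inputs[0][0]", "net1.inputs[0][1]"],
    }
)
mapping.to_csv("mapping.tsv", sep="\t", index=False)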