How To: Setting ODE parameters with a Neural Network

This guide covers the case where a neural network is evaluated before the simulation and its output sets the values of parameters in the ODE. This approach is referred to as a Deep Mechanistic Model (DMM). We assume some familiarity with the getting started tutorial, which walks through an entire PEtab SciML problem; this guide focuses only on the parts relevant to this use case.

As a case study we will use the Lotka-Volterra ODE system:

\[\frac{\mathrm{d} \text{prey}}{\mathrm{d} t} = \alpha \cdot \text{prey} - \beta \cdot \text{prey} \cdot \text{predator}\]
\[\frac{\mathrm{d} \text{predator}}{\mathrm{d} t} = \gamma \cdot \text{prey} \cdot \text{predator} - \delta \cdot \text{predator}\]

The observable models, which link measurement data to the model output, are then defined as,

\[y_{\text{prey}} = \text{prey} + \epsilon\]
\[y_{\text{predator}} = \text{predator} + \epsilon\]
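The noise term \(\epsilon\) is not specified further here. Assuming it is zero-mean Gaussian with a known standard deviation \(\sigma\) (a common choice, but an assumption for this sketch), the negative log-likelihood that links measurements to simulated observables can be computed as below. The function name and the data values are hypothetical.

```python
import math


def gaussian_nll(measured, simulated, sigma):
    """Negative log-likelihood of y = x + eps with eps ~ N(0, sigma^2).

    Assumes additive Gaussian noise with a known, constant sigma
    (illustrative choice; PEtab noise models can be more general).
    """
    if len(measured) != len(simulated):
        raise ValueError("measured and simulated must have the same length")
    nll = 0.0
    for y, x in zip(measured, simulated):
        residual = y - x
        nll += 0.5 * math.log(2 * math.pi * sigma**2) + 0.5 * (residual / sigma) ** 2
    return nll


# Hypothetical prey measurements and simulated prey values at the same time points.
prey_measured = [1.0, 1.2, 0.9]
prey_simulated = [1.0, 1.1, 1.0]
print(gaussian_nll(prey_measured, prey_simulated, sigma=0.1))
```

The same function applies to the predator observable; with independent noise, the total negative log-likelihood is the sum over both observables.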

In our formulation the output from the network replaces the parameter \(\gamma\), so that the ODE system becomes,

\[\frac{\mathrm{d} \text{prey}}{\mathrm{d} t} = \alpha \cdot \text{prey} - \beta \cdot \text{prey} \cdot \text{predator}\]
\[\frac{\mathrm{d} \text{predator}}{\mathrm{d} t} = \text{NN}[0] \cdot \text{prey} \cdot \text{predator} - \delta \cdot \text{predator}\]
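In code, the hybrid right-hand side differs from the original system only in where \(\gamma\) comes from. This minimal sketch assumes the network has already been evaluated and that its output vector is passed in as nn_output. The function name and the network output are hypothetical. alpha, beta and delta use the nominal values from the parameter table in this guide.

```python
import numpy as np


def lotka_volterra_dmm_rhs(state, alpha, beta, delta, nn_output):
    """Lotka-Volterra right-hand side with gamma replaced by NN[0].

    state: array [prey, predator]
    nn_output: network output computed before the simulation
    """
    prey, predator = state
    gamma = nn_output[0]  # the first network output takes the place of gamma
    d_prey = alpha * prey - beta * prey * predator
    d_predator = gamma * prey * predator - delta * predator
    return np.array([d_prey, d_predator])


# Hypothetical network output; alpha, beta, delta are the nominal values.
nn_output = np.array([0.8])
print(lotka_volterra_dmm_rhs(np.array([2.0, 1.0]), 1.3, 0.9, 1.8, nn_output))
```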

The input to the neural network can be either vector-based (see other tutorials) or provided via an array file, as in this tutorial.

Loading the PEtab problem

Let’s load the PEtab problem, build our model and define the overall hybrid problem as a JAXProblem.

[ ]:
from amici.petab import import_petab_problem
from amici.jax import (
    JAXProblem,
    run_simulations,
)
from petab.v2 import Problem

# Load the PEtab problem information from disk.
petab_problem = Problem.from_yaml("problem.yaml")

# Create a simulator for the ODE and NN models.
jax_model = import_petab_problem(
    petab_problem,
    model_output_dir="model",
    compile_=True,
    jax=True
)

# Create a JAXProblem to handle additional simulation information
# (e.g. simulate multiple conditions).
jax_problem = JAXProblem(jax_model, petab_problem)

The hybridization, mapping and parameter tables show how this deep mechanistic modelling problem is defined.

[3]:
jax_problem._petab_problem.hybridization_df
[3]:
targetId  targetValue
gamma     net3_output1

The gamma parameter from the model is mapped to the first output of the neural network, net3_output1, via the hybridization and mapping tables.

[4]:
petab_problem.mapping_df
[4]:
petabEntityId  modelEntityId
input0         net3.inputs[0]
net3_output1   net3.outputs[0][0]
net3_ps        net3.parameters
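Tracing the references by hand makes the two tables concrete. The sketch below copies the hybridization and mapping tables above into dictionaries. It then resolves the target gamma to a network output index, reading net3.outputs[0][0] as element 0 of output argument 0. The helper name and the regular expression are illustrative assumptions, not part of a PEtab library.

```python
import re

# Hybridization table: targetId -> targetValue.
hybridization = {"gamma": "net3_output1"}

# Mapping table: petabEntityId -> modelEntityId.
mapping = {
    "input0": "net3.inputs[0]",
    "net3_output1": "net3.outputs[0][0]",
    "net3_ps": "net3.parameters",
}

# Hypothetical pattern for entries such as "net3.outputs[0][0]".
_OUTPUT_REF = re.compile(r"^(?P<net>\w+)\.outputs\[(?P<arg>\d+)\]\[(?P<idx>\d+)\]$")


def resolve_target(target_id):
    """Return (network id, output argument, element index) feeding target_id."""
    petab_id = hybridization[target_id]
    model_entity = mapping[petab_id]
    match = _OUTPUT_REF.match(model_entity)
    if match is None:
        raise ValueError(f"{model_entity!r} is not a network output reference")
    return match["net"], int(match["arg"]), int(match["idx"])


print(resolve_target("gamma"))  # ('net3', 0, 0)
```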

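Finally, the parameter table lists the network parameters (net3_ps) with infinite bounds and marks them for estimation. Their nominal value is NaN because their values are instead set by the array file net3_ps.hdf5, which is described in the next section. Also note that \(\gamma\) does not appear in the parameter table, as its value is set by the neural network model.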

[5]:
petab_problem.parameter_df
[5]:
parameterId  parameterScale  lowerBound  upperBound  nominalValue  estimate
alpha        lin             0.0         15.0        1.3           1
delta        lin             0.0         15.0        1.8           1
beta         lin             0.0         15.0        0.9           1
net3_ps      lin             -inf        inf         NaN           1

Array input data

The PEtab problem YAML file specifies two array data files in the sciml section of the extensions field. Both files are included in the problem:

  • net3_ps.hdf5 to set the values of the network parameters

  • net3_input2.hdf5 to provide the inputs into the network

...
extensions:
  sciml:
    array_files:
      - "net3_ps.hdf5"
      - "net3_input2.hdf5"
    hybridization_files:
      - "hybridization.tsv"
    neural_nets:
      net3:
        location: "net3.yaml"
        static: true
        format: "YAML"

The nested structure of the neural network input file is shown below. The input data for the network is stored under inputs/input0, with one entry for each of the two conditions. The condition keys match the condition IDs defined in the condition table.

[7]:
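import h5py

# Convenience function to show the nested structure of an HDF5 file
def show_h5_struct(file):
    # Indent each object name by its depth in the hierarchy.
    file.visit(lambda x: print("  " * (len(x.split("/")) - 1), x.split("/")[-1]))

# Open read-only; the context manager closes the file afterwards.
with h5py.File("net3_input2.hdf5", "r") as file:
    show_h5_struct(file)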
 inputs
   input0
     cond1
     cond2
 metadata
   perm
[9]:
petab_problem.condition_df
[9]:
conditionId
cond1
cond2
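To work with these arrays directly, the per-condition datasets under inputs/input0 can be read into a dictionary keyed by condition ID. The sketch below first writes a small stand-in file with hypothetical zero values so that it runs on its own. The array shape (3, 10, 10) is an assumption chosen to fit the convolutional network in the next section. The metadata group is omitted because its expected contents are not covered here. The helper name is hypothetical.

```python
import os
import tempfile

import h5py
import numpy as np


def load_condition_inputs(path, input_id="input0"):
    """Read the inputs/<input_id>/<conditionId> datasets into {conditionId: array}."""
    with h5py.File(path, "r") as f:
        group = f["inputs"][input_id]
        # Iterating a group yields member names; [()] reads the whole dataset.
        return {cond_id: group[cond_id][()] for cond_id in group}


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example_input.hdf5")
    # Stand-in file with hypothetical zero-valued inputs of an assumed shape.
    with h5py.File(path, "w") as f:
        for cond_id in ("cond1", "cond2"):
            f.create_dataset(f"inputs/input0/{cond_id}", data=np.zeros((3, 10, 10)))
    inputs = load_condition_inputs(path)

# conditionId values from the condition table above.
condition_ids = {"cond1", "cond2"}
missing = condition_ids - set(inputs)
print("missing conditions:", sorted(missing) or "none")
print({cond_id: arr.shape for cond_id, arr in inputs.items()})
```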

Also note the static: true setting in the neural network definition of the extensions config. It states that the input to the neural network is not expected to depend on the model, so the network can be evaluated once before the ODE is solved. This keyword tells PEtab SciML importers that the network precedes the ODE. The alternatives are a network inside the ODE (i.e. a UDE) or in one of the observable formulae.
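A toy explicit-Euler integrator that counts network evaluations illustrates the practical difference. In the static (DMM) case, the network is called once and its output is held fixed during integration. In a UDE-style setup, the network sits inside the right-hand side and is called at every step. The network, the step size and its input are hypothetical. Real importers use adaptive ODE solvers rather than this fixed-step scheme.

```python
import numpy as np


class CountingNet:
    """Hypothetical stand-in network that records how often it is called."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        # Arbitrary smooth function of the input; the output plays the role of NN[0].
        return np.array([0.5 + 0.1 * np.tanh(np.sum(x))])


def euler(rhs, state0, dt, n_steps):
    """Fixed-step explicit Euler integration (illustration only)."""
    state = np.asarray(state0, dtype=float)
    for _ in range(n_steps):
        state = state + dt * rhs(state)
    return state


alpha, beta, delta = 1.3, 0.9, 1.8  # nominal values from the parameter table

# Static network (DMM): its input does not depend on the ODE state,
# so it is evaluated once before the simulation.
static_net = CountingNet()
gamma = static_net(np.ones(4))[0]  # hypothetical static input


def dmm_rhs(s):
    prey, predator = s
    return np.array([alpha * prey - beta * prey * predator,
                     gamma * prey * predator - delta * predator])


# UDE-style: the network takes the current state as input,
# so it is evaluated inside the right-hand side at every step.
ude_net = CountingNet()


def ude_rhs(s):
    prey, predator = s
    gamma_t = ude_net(s)[0]
    return np.array([alpha * prey - beta * prey * predator,
                     gamma_t * prey * predator - delta * predator])


euler(dmm_rhs, [2.0, 1.0], dt=0.01, n_steps=100)
euler(ude_rhs, [2.0, 1.0], dt=0.01, n_steps=100)
print(static_net.calls, ude_net.calls)  # 1 100
```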

Network Architecture

The PyTorch snippet below shows how the network architecture can be defined and exported to YAML format using PEtab SciML. The predefined YAML file is also provided in the PEtab SciML repo for completeness. We use a convolutional architecture here to illustrate how the DMM setup can bring information from high-dimensional inputs into the mechanistic model.

[15]:
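from petab_sciml.standard.nn_model import Input, NNModel, NNModelStandard
import torch
from torch import nn
import torch.nn.functional as F

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 input channels -> 1 output channel, 5x5 kernel, no padding.
        self.layer1 = torch.nn.Conv2d(
            3, 1, (5, 5), stride=(1, 1), padding=(0, 0), dilation=(1, 1)
        )
        # Flatten the single-channel feature map into a vector.
        self.layer2 = torch.nn.Flatten()
        # 36 features (e.g. a 6x6 map from a 10x10 input) -> one output.
        self.layer3 = torch.nn.Linear(36, 1)

    def forward(self, net_input):
        x = self.layer1(net_input)
        x = self.layer2(x)
        x = self.layer3(x)
        # ReLU keeps the output, and hence gamma, non-negative.
        x = F.relu(x)
        return x

# Instantiate the network and convert it to the PEtab SciML format as "net3".
net1 = NeuralNetwork()
nn_model1 = NNModel.from_pytorch_module(
    module=net1, nn_model_id="net3", inputs=[Input(input_id="input0")]
)
# Write the architecture to a YAML file.
NNModelStandard.save_data(
    data=nn_model1, filename="net3_from_pytorch.yaml"
)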
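The input size of layer3 follows from the standard Conv2d output-size formula, out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. With a 5x5 kernel, no padding, and unit stride and dilation, a 10x10 input gives a single-channel 6x6 feature map, which is the 36 features that layer3 expects. The 10x10 input size is inferred from this arithmetic and is not stated in the problem files. Non-square inputs such as 13x8 would also give 36 features. The helper names below are illustrative.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis of a Conv2d layer (PyTorch's formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def flattened_features(height, width, out_channels=1, kernel=5):
    """Feature count after a Conv2d with the given kernel followed by Flatten."""
    return out_channels * conv2d_out(height, kernel) * conv2d_out(width, kernel)


print(conv2d_out(10, 5))           # 6
print(flattened_features(10, 10))  # 36
```

With PyTorch installed, NeuralNetwork()(torch.zeros(1, 3, 10, 10)) should return a tensor of shape (1, 1), which is consistent with this arithmetic.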

Parameterizing network inputs

A straightforward alternative to providing network inputs as array data in files is to define them in the parameter table, with appropriate bounds and nominal values.

parameterId      parameterScale  lowerBound  upperBound  nominalValue  estimate
net1_input_pre1  lin             -inf        inf         1             0
net1_input_pre2  lin             -inf        inf         1             0

The network inputs can also be made condition-specific, with different scalar values for different conditions. Each entry can be a numeric value or a parameter ID from the parameter table.

conditionId  net1_input1      net1_input2
cond1        10.0             20.0
cond2        net1_input_pre1  net1_input_pre2
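A condition entry can be resolved in a few lines. Numeric entries are used directly, and parameter IDs are replaced by values from the parameter table. Both parameters have estimate = 0, so this sketch simply uses their nominal values. The dictionaries copy the two tables above, and the helper name is hypothetical.

```python
# Parameter table above (nominal values; both have estimate = 0).
nominal_values = {"net1_input_pre1": 1.0, "net1_input_pre2": 1.0}

# Condition table above: conditionId -> {target: entry as written in the TSV}.
conditions = {
    "cond1": {"net1_input1": "10.0", "net1_input2": "20.0"},
    "cond2": {"net1_input1": "net1_input_pre1", "net1_input2": "net1_input_pre2"},
}


def resolve_entry(entry):
    """Return a float for a numeric entry, else the referenced parameter's value."""
    try:
        return float(entry)
    except ValueError:
        return nominal_values[entry]


resolved = {
    cond_id: {target: resolve_entry(entry) for target, entry in row.items()}
    for cond_id, row in conditions.items()
}
print(resolved)
```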

The corresponding mapping table then links these PEtab identifiers from the condition table to model entities, here elements of the first input of net1.

petabEntityId  modelEntityId
net1_input1    net1.inputs[0][0]
net1_input2    net1.inputs[0][1]
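The index syntax in the mapping table places each scalar into the network's input array. This sketch reads net1.inputs[0][k] as element k of the first input argument, which is an interpretation assumed here, and assembles the input vector for one condition. The regular expression and helper name are illustrative, not part of a PEtab library. The values are those of cond1 from the condition table above.

```python
import re

import numpy as np

# Mapping table above: petabEntityId -> modelEntityId.
mapping = {"net1_input1": "net1.inputs[0][0]", "net1_input2": "net1.inputs[0][1]"}

# Hypothetical pattern for "<net>.inputs[<argument>][<element>]".
_INPUT_REF = re.compile(r"^(?P<net>\w+)\.inputs\[(?P<arg>\d+)\]\[(?P<idx>\d+)\]$")


def assemble_input(values, mapping, net_id="net1", arg=0):
    """Place scalar values into a 1-D array for input argument `arg` of `net_id`."""
    slots = {}
    for petab_id, value in values.items():
        match = _INPUT_REF.match(mapping[petab_id])
        if match is None or match["net"] != net_id or int(match["arg"]) != arg:
            continue  # entry targets another network or input argument
        slots[int(match["idx"])] = value
    if not slots:
        raise ValueError(f"no mapped values for {net_id}.inputs[{arg}]")
    vector = np.zeros(max(slots) + 1)
    for idx, value in slots.items():
        vector[idx] = value
    return vector


# Values for cond1 from the condition table above.
print(assemble_input({"net1_input1": 10.0, "net1_input2": 20.0}, mapping))  # [10. 20.]
```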