How To: Neural ODE
This guide covers how to set up a Neural ODE problem using the PEtab SciML format and the utility functions from the petab_sciml Python library. It assumes some familiarity with the getting started tutorial, which walks through an entire PEtab SciML problem. This guide focuses only on the parts relevant to the Neural ODE use case.
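In the Neural ODE case, the whole right-hand side of the ODE model is replaced with a neural network. Take the Lotka-Volterra system as an example,

$$
\begin{aligned}
\frac{d\,\text{prey}}{dt} &= \alpha \cdot \text{prey} - \beta \cdot \text{prey} \cdot \text{predator}, \\
\frac{d\,\text{predator}}{dt} &= \delta \cdot \text{prey} \cdot \text{predator} - \gamma \cdot \text{predator},
\end{aligned}
$$

where $\alpha$, $\beta$, $\gamma$ and $\delta$ are positive rate constants. The observable models, which link measurement data to the model output, are defined in the observables table supplied by the user. The problem can be configured as a Neural ODE with

$$
\begin{aligned}
\frac{d\,\text{prey}}{dt} &= \mathrm{NN}_1(\text{prey}, \text{predator}; \theta), \\
\frac{d\,\text{predator}}{dt} &= \mathrm{NN}_2(\text{prey}, \text{predator}; \theta),
\end{aligned}
$$

where $\theta$ denotes the network parameters, and $\mathrm{NN}_1$ and $\mathrm{NN}_2$ are the first and second network outputs.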
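The sketch below shows what replacing the right-hand side with a network means in practice. It integrates dy/dt = NN(y) for the state y = (prey, predator) using only the Python standard library. It is a minimal, hypothetical illustration, not how AMICI or PEtab SciML simulate the problem:

- The helper names are made up.
- The network is an untrained tanh MLP with random weights and the same 2-10-10-2 layer sizes as the network defined in the next section.
- A fixed-step fourth-order Runge-Kutta scheme stands in for a proper adaptive ODE solver.

```python
import math
import random


def make_mlp(layer_sizes, seed=0):
    """Randomly initialised (untrained) MLP as a list of (weights, biases) per layer."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        weights = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
        biases = [0.0] * n_out
        layers.append((weights, biases))
    return layers


def mlp_forward(layers, x):
    """Linear layers with tanh between them and a linear output layer."""
    for i, (weights, biases) in enumerate(layers):
        x = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]
        if i < len(layers) - 1:
            x = [math.tanh(v) for v in x]
    return x


def rk4_step(rhs, y, dt):
    """One classical fourth-order Runge-Kutta step for dy/dt = rhs(y)."""
    k1 = rhs(y)
    k2 = rhs([yi + 0.5 * dt * k for yi, k in zip(y, k1)])
    k3 = rhs([yi + 0.5 * dt * k for yi, k in zip(y, k2)])
    k4 = rhs([yi + dt * k for yi, k in zip(y, k3)])
    return [
        yi + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
        for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
    ]


def simulate_neural_ode(layers, y0, t_end, n_steps):
    """Integrate dy/dt = NN(y) from t = 0 to t_end; returns the state at every step."""
    dt = t_end / n_steps

    def rhs(y):
        # The whole right-hand side is the network output.
        return mlp_forward(layers, y)

    trajectory = [list(y0)]
    for _ in range(n_steps):
        trajectory.append(rk4_step(rhs, trajectory[-1], dt))
    return trajectory


# Example: untrained network, initial state (prey, predator) = (1.0, 0.5).
network = make_mlp([2, 10, 10, 2])
trajectory = simulate_neural_ode(network, y0=[1.0, 0.5], t_end=1.0, n_steps=100)
print(trajectory[-1])
```

In the actual problem, the network parameters are estimated so that the simulated trajectories match the measurements.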
Defining the network architecture
```python
from petab_sciml.standard.nn_model import Input, NNModel, NNModelStandard
import torch
from torch import nn
import torch.nn.functional as F


class NeuralNetwork(nn.Module):
    """Fully connected network mapping the state (prey, predator) to its time derivative."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(2, 10)
        self.layer2 = torch.nn.Linear(10, 10)
        self.layer3 = torch.nn.Linear(10, 2)

    def forward(self, net_input):
        x = self.layer1(net_input)
        x = F.tanh(x)
        x = self.layer2(x)
        x = F.tanh(x)
        x = self.layer3(x)
        return x


# Export the PyTorch module to the PEtab SciML neural network YAML format.
net1 = NeuralNetwork()
nn_model1 = NNModel.from_pytorch_module(
    module=net1, nn_model_id="net1", inputs=[Input(input_id="input0")]
)
NNModelStandard.save_data(
    data=nn_model1, filename="net1.yaml"
)
```
The network architecture in this example is kept simple for demonstration purposes. Refer to the page on supported layers and activation functions for other options. Note, however, that PEtab SciML and its importers currently support Neural ODEs only when the neural network has vector inputs and outputs.
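The network above contains only Linear layers, so the number of scalar values behind its parameters follows directly from the layer sizes. These are the parameters the PEtab problem later estimates. Each Linear(n_in, n_out) layer has n_in * n_out weights plus n_out biases. The small helper below (hypothetical, not part of petab_sciml) makes the arithmetic explicit for the 2-10-10-2 network:

```python
def count_linear_parameters(layer_sizes):
    """Total weights plus biases for consecutive fully connected layers."""
    return sum(
        n_in * n_out + n_out
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )


# Linear(2, 10), Linear(10, 10), Linear(10, 2): 30 + 110 + 22 parameters.
n_params = count_linear_parameters([2, 10, 10, 2])
print(n_params)  # 162
```

With PyTorch available, `sum(p.numel() for p in net1.parameters())` should give the same count.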
Generating the PEtab files
The petab_sciml Python package provides utility functions that generate the model and PEtab files for a Neural ODE problem. Generating the model requires only the names of the species in the ODE system. The utility functions also generate the hybridization, mapping and parameter tables.
```python
from petab_sciml.problem_utils.neural_ode import (
    create_neural_ode,
    create_neural_ode_problem,
)

create_neural_ode(["prey", "predator"], model_filename="lv.xml")
```
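The hybridization table shown later in this guide suggests a naming convention behind the generated files. The i-th species in the list passed to create_neural_ode feeds network input element i. Network output element i is assigned to a parameter named after that species. The sketch below rebuilds those rows from a species list to make the role of the list order explicit. It is a hypothetical reconstruction inferred from that table, not the petab_sciml implementation.

```python
def neural_ode_hybridization(species, nn_model_id="net1"):
    """Hypothetical reconstruction of Neural ODE hybridization rows (targetId -> targetValue).

    Inferred from this guide's hybridization table; the real utility may differ.
    """
    rows = {}
    # Species i is passed to network input element i.
    for i, name in enumerate(species):
        rows[f"{nn_model_id}_input{i}"] = name
    # Network output element i sets the parameter that drives d<species_i>/dt.
    for i, name in enumerate(species):
        rows[f"{name}_param"] = f"{nn_model_id}_output{i}"
    return rows


for target_id, target_value in neural_ode_hybridization(["prey", "predator"]).items():
    print(target_id, target_value)
```

Under this convention, reversing the species list would also reverse which network output drives which species.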
To fully define the PEtab problem, the user must supply the measurement, observable and array input files. A second utility function then generates the problem.yaml file, which references all of the PEtab files. Example files are included in the docs for demonstration.
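```python
create_neural_ode_problem(
    "lv.xml",            # model generated by create_neural_ode
    "measurements.tsv",  # user-supplied measurement table
    "observables.tsv",   # user-supplied observable table
    "net1.yaml",         # neural network model saved earlier
    ["net1_ps.hdf5"]     # user-supplied array input file(s)
)
```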
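The measurement and observable tables are written by hand, so a quick consistency check can catch typos before the problem is used. The check confirms that every observableId referenced in the measurement table is defined in the observable table. This minimal sketch uses only the standard library and assumes both files are tab-separated with an observableId column, as in PEtab. The function names are hypothetical, and this is not a replacement for full PEtab validation.

```python
import csv
from pathlib import Path


def read_observable_ids(path):
    """Collect the values of the observableId column from a tab-separated PEtab table."""
    with Path(path).open(newline="") as handle:
        return {row["observableId"] for row in csv.DictReader(handle, delimiter="\t")}


def missing_observables(measurements_path, observables_path):
    """observableIds referenced by measurements but not defined in the observable table."""
    used = read_observable_ids(measurements_path)
    defined = read_observable_ids(observables_path)
    return sorted(used - defined)
```

For a consistent pair of files, `missing_observables("measurements.tsv", "observables.tsv")` should return an empty list.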
Loading the PEtab problem
```python
from amici.petab import import_petab_problem
from amici.jax import (
    JAXProblem,
    run_simulations,
)
from petab.v2 import Problem

# Load the PEtab problem information from disk.
petab_problem = Problem.from_yaml("problem.yaml")

# Create a simulator for the ODE and NN models.
jax_model = import_petab_problem(
    petab_problem,
    model_output_dir="model",
    compile_=True,
    jax=True
)

# Create a JAXProblem to handle additional simulation information
# (e.g. simulating multiple conditions).
jax_problem = JAXProblem(jax_model, petab_problem)
```
The hybridization and mapping tables show us how the neural network inputs and outputs are mapped to the model.
```python
jax_problem._petab_problem.hybridization_df
```
| targetId | targetValue |
|---|---|
| net1_input0 | prey |
| net1_input1 | predator |
| prey_param | net1_output0 |
| predator_param | net1_output1 |
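Read as a recipe, the hybridization table above describes one evaluation of the right-hand side. The species amounts are copied into the network inputs, the network is evaluated, and its outputs are assigned to prey_param and predator_param. The sketch below follows those rows literally in plain Python. It assumes that the generated model sets d<species>/dt = <species>_param. That assumption rests on this table and on this guide's description of the outputs as rates of change. The function is hypothetical. The toy network, a Lotka-Volterra right-hand side with all rate constants equal to 1, merely stands in for the trained net1.

```python
import re


def neural_ode_rhs(state, hybridization, network, nn_model_id="net1"):
    """Evaluate d(state)/dt by following hybridization rows (targetId -> targetValue)."""
    net = re.escape(nn_model_id)
    input_pattern = re.compile(rf"{net}_input(\d+)$")
    output_pattern = re.compile(rf"{net}_output(\d+)$")

    # Rows '<net>_input<k> -> <species>': input element k is that species' amount.
    indexed_inputs = []
    for target_id, target_value in hybridization.items():
        match = input_pattern.match(target_id)
        if match:
            indexed_inputs.append((int(match.group(1)), state[target_value]))
    net_output = network([value for _, value in sorted(indexed_inputs)])

    # Rows '<parameter> -> <net>_output<k>': the parameter takes output element k.
    parameters = {}
    for target_id, target_value in hybridization.items():
        match = output_pattern.match(target_value)
        if match:
            parameters[target_id] = net_output[int(match.group(1))]

    # Assumed model structure: d<species>/dt = <species>_param.
    return {species: parameters[f"{species}_param"] for species in state}


hybridization = {
    "net1_input0": "prey",
    "net1_input1": "predator",
    "prey_param": "net1_output0",
    "predator_param": "net1_output1",
}


def toy_network(x):
    """Stand-in for the trained network: Lotka-Volterra with all rate constants 1."""
    prey, predator = x
    return [prey - prey * predator, prey * predator - predator]


print(neural_ode_rhs({"prey": 2.0, "predator": 1.0}, hybridization, toy_network))
# {'prey': 0.0, 'predator': 1.0}
```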
```python
petab_problem.mapping_df
```
| petabEntityId | modelEntityId |
|---|---|
| net1_input0 | net1.inputs[0][0] |
| net1_input1 | net1.inputs[0][1] |
| net1_output0 | net1.outputs[0][0] |
| net1_output1 | net1.outputs[0][1] |
| net1_ps | net1.parameters |
The input to the neural network is the state, i.e. the amounts of prey and predator. The first and second network outputs give the rates of change of prey and predator, respectively.
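The modelEntityId column in the mapping table addresses parts of the network with an index notation. Our reading is that the first index selects the input or output argument and the second selects an element of that vector. This is consistent with the single input declared via Input(input_id="input0") when exporting net1. net1.parameters refers to all network parameters at once. The hypothetical helper below splits such identifiers into their parts. It only handles the forms that appear in the mapping table and is not a full parser for the PEtab SciML mapping syntax.

```python
import re

_ENTITY_PATTERN = re.compile(
    r"^(?P<model>[A-Za-z_]\w*)\.(?P<kind>inputs|outputs|parameters)(?P<indices>(?:\[\d+\])*)$"
)


def parse_model_entity(entity_id):
    """Split e.g. 'net1.inputs[0][1]' into ('net1', 'inputs', (0, 1))."""
    match = _ENTITY_PATTERN.match(entity_id)
    if match is None:
        raise ValueError(f"unsupported model entity id: {entity_id!r}")
    indices = tuple(int(i) for i in re.findall(r"\[(\d+)\]", match.group("indices")))
    return match.group("model"), match.group("kind"), indices


for entity in ["net1.inputs[0][0]", "net1.inputs[0][1]", "net1.outputs[0][1]", "net1.parameters"]:
    print(entity, "->", parse_model_entity(entity))
```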
It is also worth noting that the parameter table contains only the network parameters. Unlike previous examples, there are no other parameters to estimate in this problem. This PEtab SciML problem specifies that the neural network parameters are to be optimised and that the network outputs provide the time derivative of the solution.
```python
petab_problem.parameter_df
```
| parameterId | parameterScale | lowerBound | upperBound | nominalValue | estimate |
|---|---|---|---|---|---|
| net1_ps | lin | -inf | inf | NaN | 1 |
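The same information can be checked programmatically. The sketch below parses a PEtab-style parameter table with the standard csv module and lists the estimated parameters with their bounds. Python's float() accepts the -inf, inf and NaN entries used here. The helper name is hypothetical; in practice the table is already available as petab_problem.parameter_df.

```python
import csv
import io
import math


def estimated_parameter_bounds(parameter_tsv):
    """Map parameterId -> (lowerBound, upperBound) for rows with estimate == 1."""
    reader = csv.DictReader(io.StringIO(parameter_tsv), delimiter="\t")
    return {
        row["parameterId"]: (float(row["lowerBound"]), float(row["upperBound"]))
        for row in reader
        if row["estimate"].strip() == "1"
    }


# The parameter table above, written out as tab-separated text.
PARAMETER_TABLE = (
    "parameterId\tparameterScale\tlowerBound\tupperBound\tnominalValue\testimate\n"
    "net1_ps\tlin\t-inf\tinf\tNaN\t1\n"
)

bounds = estimated_parameter_bounds(PARAMETER_TABLE)
print(bounds)  # {'net1_ps': (-inf, inf)}
# Both bounds are infinite, so the network parameters are unconstrained.
print(all(math.isinf(low) and math.isinf(high) for low, high in bounds.values()))  # True
```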