How To: Machine Learning models in the observables
This guide covers how to include a machine learning (ML) model in the observable formula, which links the model output to the observed measurement data. We assume some familiarity with the getting started tutorial, which examines an entire PEtab SciML problem, while this guide focuses on the parts that are relevant to the observable use case. As a case study we will use the Lotka-Volterra ODE system:
The observable models are defined as,
There are two observables in this system, prey and predator, each with associated measurements. The ML model is not introduced by substituting the species in these expressions within the SBML model file. Instead, it is included by correctly formulating the observable, hybridization, and mapping PEtab tables. The observable model for prey is formulated as,
The predator observable model is unchanged in this example.
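As a rough illustration of what "an ML model in the observable" means, here is a minimal sketch in plain Python. It is not the AMICI implementation. It assumes the network receives the two species as its inputs (in the real problem the inputs are set through the PEtab tables), and the names `tiny_net` and `observables` and all weights are hypothetical.

```python
import math


def tiny_net(inputs, weights, biases):
    """Hypothetical network: one tanh hidden layer, then a linear output layer."""
    w_hidden, w_out = weights
    b_hidden, b_out = biases
    hidden = [
        math.tanh(sum(w * x for w, x in zip(row, inputs)) + b)
        for row, b in zip(w_hidden, b_hidden)
    ]
    return [
        sum(w * h for w, h in zip(row, hidden)) + b
        for row, b in zip(w_out, b_out)
    ]


def observables(prey, predator, weights, biases):
    """Map the ODE state to the two observables of this guide."""
    # Assumption: both species are fed to the network.
    net_output = tiny_net([prey, predator], weights, biases)
    return {
        "prey_o": net_output[0],   # first network output replaces the prey observable
        "predator_o": predator,    # predator observable is the species itself
    }


# Made-up weights: 2 inputs -> 2 hidden units -> 1 output.
weights = ([[0.5, -0.2], [0.1, 0.3]], [[1.0, -1.0]])
biases = ([0.0, 0.0], [0.0])
obs = observables(1.0, 2.0, weights, biases)
```

The point of the sketch is only the structure: the prey observable is read from the network output, while the predator observable passes the state through unchanged.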
Loading the PEtab problem
Let’s load the PEtab problem so that we can examine the contents of the relevant PEtab tables.
[ ]:
from amici.petab import import_petab_problem
from amici.jax import (
    JAXProblem,
    run_simulations,
)
from petab.v2 import Problem

# Load the PEtab problem information from disk.
petab_problem = Problem.from_yaml("problem.yaml")

# Create a simulator for the ODE and NN models.
jax_model = import_petab_problem(
    petab_problem,
    model_output_dir="model",
    compile_=True,
    jax=True,
)

# Create a JAXProblem to handle additional simulation information
# (e.g. simulating multiple conditions).
jax_problem = JAXProblem(jax_model, petab_problem)
[2]:
petab_problem.observable_df
[2]:
| observableId | observableFormula | noiseFormula | observableTransformation | noiseDistribution |
|---|---|---|---|---|
| prey_o | net1_output1 | 0.05 | lin | normal |
| predator_o | predator | 0.05 | lin | normal |
Here the formula for the prey observable is a PEtab identifier, net1_output1, that refers to a network output. The model entity behind that PEtab ID is defined in the mapping table: it is the first output of the neural network.
[3]:
petab_problem.mapping_df
[3]:
| petabEntityId | modelEntityId |
|---|---|
| net1_input1 | net1.inputs[0][0] |
| net1_input2 | net1.inputs[0][1] |
| net1_output1 | net1.outputs[0][0] |
| net1_ps | net1.parameters |
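The lookup from observable formula to model entity can be sketched with plain dictionaries that mirror the two tables above. This is only an illustration of the indirection, not how a PEtab importer is implemented. The helper `resolve` is hypothetical and only handles a formula that consists of a single identifier.

```python
# Contents copied from the observable and mapping tables shown above.
observable_formulas = {
    "prey_o": "net1_output1",
    "predator_o": "predator",
}
mapping = {
    "net1_input1": "net1.inputs[0][0]",
    "net1_input2": "net1.inputs[0][1]",
    "net1_output1": "net1.outputs[0][0]",
    "net1_ps": "net1.parameters",
}


def resolve(observable_id):
    """Return the model entity an observable formula points to.

    Identifiers without a mapping entry (e.g. a species such as
    ``predator``) are returned unchanged.
    """
    formula = observable_formulas[observable_id]
    return mapping.get(formula, formula)


resolved = {obs_id: resolve(obs_id) for obs_id in observable_formulas}
```

With these tables, `prey_o` resolves to the first network output and `predator_o` stays the plain species identifier.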
Wherever the problem formulation refers to the prey observable (prey_o), for instance in the measurements table, the neural network is evaluated and its output is used as the value of that observable.
[4]:
petab_problem.measurement_df
[4]:
| | observableId | simulationConditionId | measurement | time |
|---|---|---|---|---|
| 0 | prey_o | cond1 | 0.173017 | 1.0 |
| 1 | prey_o | cond1 | 0.489177 | 2.0 |
| 2 | prey_o | cond1 | 1.643996 | 3.0 |
| 3 | prey_o | cond1 | 5.451963 | 4.0 |
| 4 | prey_o | cond1 | 2.977522 | 5.0 |
| 5 | prey_o | cond1 | 0.181663 | 6.0 |
| 6 | prey_o | cond1 | 0.348112 | 7.0 |
| 7 | prey_o | cond1 | 0.937919 | 8.0 |
| 8 | prey_o | cond1 | 3.113240 | 9.0 |
| 9 | prey_o | cond1 | 8.863933 | 10.0 |
| 10 | predator_o | cond1 | 0.847416 | 1.0 |
| 11 | predator_o | cond1 | 0.211135 | 2.0 |
| 12 | predator_o | cond1 | -0.025054 | 3.0 |
| 13 | predator_o | cond1 | 0.125010 | 4.0 |
| 14 | predator_o | cond1 | 6.700455 | 5.0 |
| 15 | predator_o | cond1 | 2.007158 | 6.0 |
| 16 | predator_o | cond1 | 0.420092 | 7.0 |
| 17 | predator_o | cond1 | 0.048032 | 8.0 |
| 18 | predator_o | cond1 | 0.128669 | 9.0 |
| 19 | predator_o | cond1 | 1.192784 | 10.0 |
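To show how these measurements meet the observable output, the sketch below computes a Gaussian negative log-likelihood, matching the normal noise distribution, linear scale, and noise value of 0.05 in the observable table. It is a plain-Python illustration rather than the objective AMICI evaluates internally. The function `gaussian_nll` and the simulated values are hypothetical stand-ins for the network output at the measurement times.

```python
import math


def gaussian_nll(measured, simulated, sigma):
    """Negative log-likelihood for independent normal noise on a linear scale."""
    return sum(
        0.5 * math.log(2 * math.pi * sigma**2) + 0.5 * ((m - s) / sigma) ** 2
        for m, s in zip(measured, simulated)
    )


# First three prey_o measurements from the table above (times 1.0 to 3.0).
prey_measurements = [0.173017, 0.489177, 1.643996]

# Hypothetical observable values, i.e. what net1_output1 might return.
hypothetical_simulation = [0.20, 0.50, 1.60]

nll = gaussian_nll(prey_measurements, hypothetical_simulation, sigma=0.05)
```

A smaller value means the observable output lies closer to the data, so training the network parameters amounts to reducing this quantity over all measurements.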
It is worth noting that the PEtab standard supports mathematical expressions in the observables table. For instance, instead of net1_output1, we could write net1_output1^2 + 2, and a PEtab importer would evaluate that expression on the network output.
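The arithmetic in that example can be checked by hand with a small sketch. This is not the PEtab math parser: it simply rewrites `^` as Python's `**` and evaluates the result, which is only adequate for simple formulas like this one. The helper `evaluate_formula` and the network output of 0.5 are hypothetical.

```python
def evaluate_formula(formula, values):
    """Evaluate a simple arithmetic formula for illustration only.

    Assumption: ``^`` denotes exponentiation, as in the example formula.
    """
    python_expr = formula.replace("^", "**")
    # Restrict the namespace to the supplied values; no builtins are exposed.
    return eval(python_expr, {"__builtins__": {}}, dict(values))


# Hypothetical network output of 0.5 for net1_output1.
result = evaluate_formula("net1_output1^2 + 2", {"net1_output1": 0.5})
```

Here `result` is the observable value that would be compared against the measurement, rather than the raw network output.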