{
"cells": [
{
"cell_type": "markdown",
"id": "81cd17cb-2d7f-4d74-a63f-e6ff16ec8319",
"metadata": {},
"source": [
"# How To: Machine Learning models in the observables\n",
"\n",
"This guide shows how to include a machine learning (ML) model in the observable formula, which links the model output to the measured data. We assume some familiarity with the getting started tutorial, which walks through an entire PEtab SciML problem; this guide focuses only on the parts relevant to the observable use case. As a case study, we use the Lotka-Volterra ODE system:\n",
"\n",
"$$\\frac{\\mathrm{d} \\text{prey}}{\\mathrm{d} t} = \\alpha \\cdot \\text{prey} - \\beta \\cdot \\text{prey} \\cdot \\text{predator}$$\n",
"\n",
"$$\\frac{\\mathrm{d} \\text{predator}}{\\mathrm{d} t} = \\gamma \\cdot \\text{prey} \\cdot \\text{predator} - \\delta \\cdot \\text{predator}$$\n",
"\n",
"The observable models are defined as\n",
"\n",
"$$y_{\\text{prey}} = \\text{prey} + \\epsilon$$\n",
"\n",
"$$y_{\\text{predator}} = \\text{predator} + \\epsilon$$\n",
"\n",
"The system has two observables, `prey` and `predator`, each with associated measurements. The SBML model file is left unchanged; instead, the ML model is included by formulating the PEtab observable, hybridization, and mapping tables accordingly. In this example, the observable model for `prey` is formulated as\n",
"\n",
"$$y_{\\text{prey}} = \\text{NN}(\\text{prey}, \\text{predator})[0] + \\epsilon$$\n",
"\n",
"The `predator` observable model is unchanged in this example."
]
},
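{
"cell_type": "markdown",
"id": "hybrid-observable-sketch",
"metadata": {},
"source": [
"To make the hybrid observable concrete, here is a minimal NumPy sketch that integrates the Lotka-Volterra system with a fixed-step Runge-Kutta scheme and then evaluates a toy network on the simulated states. It is not how AMICI simulates the problem: the parameter values, initial conditions, network architecture, and random weights are hypothetical placeholders, and the noise term is omitted.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical parameter values; the real ones are defined by the PEtab problem.\n",
"ALPHA, BETA, GAMMA, DELTA = 1.3, 0.9, 0.8, 1.8\n",
"\n",
"\n",
"def lotka_volterra(x):\n",
"    # Right-hand side of the ODE system, with x = [prey, predator].\n",
"    prey, predator = x\n",
"    return np.array([\n",
"        ALPHA * prey - BETA * prey * predator,\n",
"        GAMMA * prey * predator - DELTA * predator,\n",
"    ])\n",
"\n",
"\n",
"def rk4_solve(x0, t_grid):\n",
"    # Fixed-step fourth-order Runge-Kutta (AMICI uses adaptive solvers instead).\n",
"    xs = [np.asarray(x0, dtype=float)]\n",
"    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):\n",
"        h = t1 - t0\n",
"        x = xs[-1]\n",
"        k1 = lotka_volterra(x)\n",
"        k2 = lotka_volterra(x + 0.5 * h * k1)\n",
"        k3 = lotka_volterra(x + 0.5 * h * k2)\n",
"        k4 = lotka_volterra(x + h * k3)\n",
"        xs.append(x + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4))\n",
"    return np.stack(xs)\n",
"\n",
"\n",
"def nn(inputs, params):\n",
"    # Toy one-hidden-layer MLP standing in for net1 (hypothetical architecture).\n",
"    w1, b1, w2, b2 = params\n",
"    hidden = np.tanh(inputs @ w1 + b1)\n",
"    return hidden @ w2 + b2\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"params = (rng.normal(size=(2, 8)), np.zeros(8), rng.normal(size=(8, 1)), np.zeros(1))\n",
"\n",
"t_grid = np.linspace(0.0, 10.0, 1001)\n",
"states = rk4_solve([0.44, 4.6], t_grid)  # columns: prey, predator\n",
"\n",
"# Noise-free observables: y_prey = NN(prey, predator)[0], y_predator = predator.\n",
"y_prey = nn(states, params)[:, 0]\n",
"y_predator = states[:, 1]\n",
"```\n",
"\n",
"Only the prey observable passes through the network; the predator observable is read directly from the ODE state, mirroring the two formulas above."
]
},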
{
"cell_type": "markdown",
"id": "51c1a0d5-102f-46a4-821d-28f58bc4be48",
"metadata": {},
"source": [
"## Loading the PEtab problem"
]
},
{
"cell_type": "markdown",
"id": "ad529fe5-a91b-41b3-af22-6a8b855c50cf",
"metadata": {},
"source": [
"Let's load the PEtab problem so that we can examine the contents of the relevant PEtab tables."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d810c818-7914-4e3f-acbf-9596fe4cbfda",
"metadata": {},
"outputs": [],
"source": [
"from amici.petab import import_petab_problem\n",
"from amici.jax import (\n",
" JAXProblem,\n",
" run_simulations,\n",
")\n",
"from petab.v2 import Problem\n",
"\n",
"# Load the PEtab problem information from disk.\n",
"petab_problem = Problem.from_yaml(\"problem.yaml\")\n",
"\n",
"# Create a simulator for the ODE and NN models.\n",
"jax_model = import_petab_problem(\n",
" petab_problem,\n",
" model_output_dir=\"model\",\n",
" compile_=True,\n",
" jax=True\n",
")\n",
"\n",
"# Create a JAXProblem to handle additional simulation information\n",
"# (e.g., simulating multiple conditions).\n",
"jax_problem = JAXProblem(jax_model, petab_problem)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d44cbd06-a3fe-4984-86dc-d5a48c7718c5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" observableFormula noiseFormula observableTransformation \\\n",
"observableId \n",
"prey_o net1_output1 0.05 lin \n",
"predator_o predator 0.05 lin \n",
"\n",
" noiseDistribution \n",
"observableId \n",
"prey_o normal \n",
"predator_o normal "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"petab_problem.observable_df"
]
},
{
"cell_type": "markdown",
"id": "618ca834-e138-4189-bf93-16c46c5daa7a",
"metadata": {},
"source": [
"The `observableFormula` of the prey observable (`prey_o`) is not a model species but the PEtab identifier `net1_output1`, which refers to a network output. The mapping table defines what this identifier corresponds to in the model: `net1.outputs[0][0]`, the first output of the neural network `net1`."
]
},
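{
"cell_type": "markdown",
"id": "resolve-observable-sketch",
"metadata": {},
"source": [
"The lookup from observable formula to model entity can be sketched with pandas. The two tables are re-created by hand from the observable and mapping tables of this problem (in the notebook itself, `petab_problem.observable_df` and `petab_problem.mapping_df` could be used instead). The helper `resolve` is hypothetical, not part of the PEtab library, and only handles formulas that consist of a single identifier.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Tables re-created by hand from this PEtab problem.\n",
"observable_df = pd.DataFrame(\n",
"    {'observableFormula': ['net1_output1', 'predator']},\n",
"    index=pd.Index(['prey_o', 'predator_o'], name='observableId'),\n",
")\n",
"mapping_df = pd.DataFrame(\n",
"    {'modelEntityId': ['net1.inputs[0][0]', 'net1.inputs[0][1]',\n",
"                       'net1.outputs[0][0]', 'net1.parameters']},\n",
"    index=pd.Index(['net1_input1', 'net1_input2', 'net1_output1', 'net1_ps'],\n",
"                   name='petabEntityId'),\n",
")\n",
"\n",
"\n",
"def resolve(formula):\n",
"    # A formula that is exactly a mapped PEtab id is replaced by its model entity;\n",
"    # anything else (e.g. a model species such as 'predator') is returned unchanged.\n",
"    return mapping_df['modelEntityId'].get(formula, formula)\n",
"\n",
"\n",
"resolved = observable_df['observableFormula'].map(resolve)\n",
"print(resolved.to_dict())\n",
"# {'prey_o': 'net1.outputs[0][0]', 'predator_o': 'predator'}\n",
"```"
]
},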
{
"cell_type": "code",
"execution_count": 3,
"id": "217954a4-3488-4f23-9a8e-15ca29ca5a08",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" modelEntityId\n",
"petabEntityId \n",
"net1_input1 net1.inputs[0][0]\n",
"net1_input2 net1.inputs[0][1]\n",
"net1_output1 net1.outputs[0][0]\n",
"net1_ps net1.parameters"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"petab_problem.mapping_df"
]
},
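{
"cell_type": "markdown",
"id": "parse-entity-sketch",
"metadata": {},
"source": [
"Each `modelEntityId` in this table names a network (`net1`), a part of it (`inputs`, `outputs`, or `parameters`), and optionally a sequence of indices. The sketch below splits such strings with plain string operations and uses the indices to pick a value out of a nested output structure. Reading the first index as the input or output argument and the second as the element within it is an assumption about the convention; the function `parse_entity` and the `outputs` values are hypothetical, and ids of any other shape are not handled.\n",
"\n",
"```python\n",
"from functools import reduce\n",
"\n",
"\n",
"def parse_entity(entity_id):\n",
"    # Split e.g. 'net1.outputs[0][0]' into ('net1', 'outputs', (0, 0)).\n",
"    net, _, rest = entity_id.partition('.')\n",
"    kind, _, index_part = rest.partition('[')\n",
"    # index_part is '0][0]' here, or '' when no indices follow ('net1.parameters').\n",
"    indices = tuple(int(i) for i in index_part.rstrip(']').split('][')) if index_part else ()\n",
"    return net, kind, indices\n",
"\n",
"\n",
"mapping = {\n",
"    'net1_input1': 'net1.inputs[0][0]',\n",
"    'net1_input2': 'net1.inputs[0][1]',\n",
"    'net1_output1': 'net1.outputs[0][0]',\n",
"    'net1_ps': 'net1.parameters',\n",
"}\n",
"for petab_id, entity_id in mapping.items():\n",
"    print(petab_id, parse_entity(entity_id))\n",
"\n",
"# Hypothetical network result: one output argument holding a single value.\n",
"outputs = [[0.7]]\n",
"_, _, idx = parse_entity(mapping['net1_output1'])\n",
"# Apply the indices one after another: outputs[0][0].\n",
"net1_output1 = reduce(lambda value, i: value[i], idx, outputs)\n",
"print(net1_output1)  # 0.7\n",
"```"
]
},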
{
"cell_type": "markdown",
"id": "8b8e370f-49ea-40e3-93b2-b3c7c4a74f48",
"metadata": {},
"source": [
"Wherever the problem refers to the prey observable (`prey_o`), for instance in the measurement table, the neural network is evaluated and its output is used as the simulated value of that observable."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e1fac571-1ad7-4acc-9fb6-c43b06e6d53d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" observableId simulationConditionId measurement time\n",
"0 prey_o cond1 0.173017 1.0\n",
"1 prey_o cond1 0.489177 2.0\n",
"2 prey_o cond1 1.643996 3.0\n",
"3 prey_o cond1 5.451963 4.0\n",
"4 prey_o cond1 2.977522 5.0\n",
"5 prey_o cond1 0.181663 6.0\n",
"6 prey_o cond1 0.348112 7.0\n",
"7 prey_o cond1 0.937919 8.0\n",
"8 prey_o cond1 3.113240 9.0\n",
"9 prey_o cond1 8.863933 10.0\n",
"10 predator_o cond1 0.847416 1.0\n",
"11 predator_o cond1 0.211135 2.0\n",
"12 predator_o cond1 -0.025054 3.0\n",
"13 predator_o cond1 0.125010 4.0\n",
"14 predator_o cond1 6.700455 5.0\n",
"15 predator_o cond1 2.007158 6.0\n",
"16 predator_o cond1 0.420092 7.0\n",
"17 predator_o cond1 0.048032 8.0\n",
"18 predator_o cond1 0.128669 9.0\n",
"19 predator_o cond1 1.192784 10.0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"petab_problem.measurement_df"
]
},
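{
"cell_type": "markdown",
"id": "normal-nll-sketch",
"metadata": {},
"source": [
"With `noiseFormula` 0.05, `observableTransformation` `lin`, and `noiseDistribution` `normal`, each measurement $m$ with simulated observable value $y$ contributes the Gaussian negative log-likelihood term $\\frac{1}{2}\\log(2\\pi\\sigma^2) + \\frac{1}{2}\\left(\\frac{m - y}{\\sigma}\\right)^2$ with $\\sigma = 0.05$. For `prey_o`, $y$ is the network output. The sketch below sums these terms for the `prey_o` measurements shown above; `normal_nll` is a hypothetical helper, not a PEtab or AMICI function, and the simulated values are made-up placeholders rather than actual network outputs.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def normal_nll(measurement, simulation, sigma):\n",
"    # Negative log-likelihood of additive Gaussian noise on a linear scale,\n",
"    # summed over all data points.\n",
"    measurement = np.asarray(measurement, dtype=float)\n",
"    simulation = np.asarray(simulation, dtype=float)\n",
"    residual = (measurement - simulation) / sigma\n",
"    return float(np.sum(0.5 * np.log(2 * np.pi * sigma**2) + 0.5 * residual**2))\n",
"\n",
"\n",
"# prey_o measurements at t = 1, ..., 10 from the table above.\n",
"prey_measured = [0.173017, 0.489177, 1.643996, 5.451963, 2.977522,\n",
"                 0.181663, 0.348112, 0.937919, 3.113240, 8.863933]\n",
"# Placeholder values; in the real problem they are net1 outputs evaluated on the\n",
"# simulated prey and predator states at the measurement times.\n",
"prey_simulated = [0.2, 0.5, 1.6, 5.4, 3.0, 0.2, 0.35, 0.9, 3.1, 8.9]\n",
"print(normal_nll(prey_measured, prey_simulated, sigma=0.05))\n",
"```\n",
"\n",
"Because $\\sigma$ is small, even modest deviations between the network output and the data are penalized heavily."
]
},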
{
"cell_type": "markdown",
"id": "261867c9-48a2-4310-ac0a-e52ae2012e85",
"metadata": {},
"source": [
"The PEtab standard supports [mathematical expressions](https://petab.readthedocs.io/en/latest/v2/documentation_data_format.html#math-expressions-syntax) in the observable table. For instance, `net1_output1` could be replaced by `net1_output1^2 + 2`, and PEtab importers would evaluate that expression when computing the observable."
]
}
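,
{
"cell_type": "markdown",
"id": "math-expression-sketch",
"metadata": {},
"source": [
"As an illustration of turning such an expression into a function of the network output, the sketch below parses it with SymPy. SymPy's generic parser is only a stand-in here: PEtab specifies its own expression grammar (in which `^` denotes exponentiation), and importers may parse expressions differently.\n",
"\n",
"```python\n",
"import sympy as sp\n",
"\n",
"# sympify converts '^' to '**' by default (convert_xor=True).\n",
"expr = sp.sympify('net1_output1^2 + 2')\n",
"net1_output1 = sp.Symbol('net1_output1')\n",
"print(expr)  # net1_output1**2 + 2\n",
"\n",
"# Numeric function of the network output value.\n",
"f = sp.lambdify(net1_output1, expr)\n",
"print(f(3.0))  # 11.0\n",
"```"
]
}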
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}