{
"cells": [
{
"cell_type": "markdown",
"id": "b3339048-6d80-4eeb-befa-4bc1b3c468ec",
"metadata": {},
"source": [
"# How To: Setting ODE parameters with a Neural Network"
]
},
{
"cell_type": "markdown",
"id": "d54f42a2-a31d-4fad-926a-b562ba26cc11",
"metadata": {},
"source": [
"In this use case, the neural network is evaluated before the ODE is simulated, and its output sets the values of parameters in the ODE. This approach is sometimes referred to as a Deep Mechanistic Model (DMM). We assume some familiarity with the getting started tutorial, which walks through an entire PEtab SciML problem; this guide focuses only on the parts relevant to this use case.\n",
"\n",
"As a case study we will use the Lotka-Volterra ODE system:\n",
"\n",
"$$\\frac{\\mathrm{d} \\text{prey}}{\\mathrm{d} t} = \\alpha \\cdot \\text{prey} - \\beta \\cdot \\text{prey} \\cdot \\text{predator}$$\n",
"\n",
"$$\\frac{\\mathrm{d} \\text{predator}}{\\mathrm{d} t} = \\gamma \\cdot \\text{prey} \\cdot \\text{predator} - \\delta \\cdot \\text{predator}$$\n",
"\n",
"The observable models, which link the measurement data to the model output, are defined as follows, with $\\epsilon$ denoting measurement noise:\n",
"\n",
"$$y_{\\text{prey}} = \\text{prey} + \\epsilon$$\n",
"\n",
"$$y_{\\text{predator}} = \\text{predator} + \\epsilon$$\n",
"\n",
"In our formulation, the first output of the network, $\\text{NN}[0]$, replaces the parameter $\\gamma$, so the ODE system becomes\n",
"\n",
"$$\\frac{\\mathrm{d} \\text{prey}}{\\mathrm{d} t} = \\alpha \\cdot \\text{prey} - \\beta \\cdot \\text{prey} \\cdot \\text{predator}$$\n",
"\n",
"$$\\frac{\\mathrm{d} \\text{predator}}{\\mathrm{d} t} = \\text{NN}[0] \\cdot \\text{prey} \\cdot \\text{predator} - \\delta \\cdot \\text{predator}$$\n",
"\n",
"The input to the neural network can either be vector-based (see the other tutorials) or be provided via an array file, as demonstrated in this tutorial."
]
},
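{
"cell_type": "markdown",
"id": "dmm-numpy-sketch-md",
"metadata": {},
"source": [
"To make the evaluation order concrete, the cell below is a minimal NumPy sketch of a deep mechanistic model that does not use AMICI or PEtab. The toy network ``nn_forward``, its weights, its two-element input and the initial state are hypothetical placeholders, not the actual ``net3``; only $\\alpha$, $\\beta$ and $\\delta$ take the nominal values listed in the parameter table further down. The network is evaluated first, its output $\\text{NN}[0]$ is used as $\\gamma$, and the ODE is then integrated with a hand-written fixed-step RK4 scheme that stands in for the solver AMICI would use."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dmm-numpy-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy network (NOT net3): one linear layer followed by ReLU.\n",
"W = np.array([[0.4, 0.2]])\n",
"b = np.array([0.1])\n",
"\n",
"def nn_forward(x):\n",
"    return np.maximum(W @ x + b, 0.0)\n",
"\n",
"def lotka_volterra(y, alpha, beta, gamma, delta):\n",
"    prey, predator = y\n",
"    return np.array([\n",
"        alpha * prey - beta * prey * predator,\n",
"        gamma * prey * predator - delta * predator,\n",
"    ])\n",
"\n",
"def simulate_rk4(y0, t_end, n_steps, *theta):\n",
"    # Fixed-step classical Runge-Kutta; returns the state at every step.\n",
"    h = t_end / n_steps\n",
"    y = np.asarray(y0, dtype=float)\n",
"    trajectory = [y]\n",
"    for _ in range(n_steps):\n",
"        k1 = lotka_volterra(y, *theta)\n",
"        k2 = lotka_volterra(y + h / 2 * k1, *theta)\n",
"        k3 = lotka_volterra(y + h / 2 * k2, *theta)\n",
"        k4 = lotka_volterra(y + h * k3, *theta)\n",
"        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)\n",
"        trajectory.append(y)\n",
"    return np.array(trajectory)\n",
"\n",
"# 1. Evaluate the network BEFORE the simulation (hypothetical input).\n",
"gamma = nn_forward(np.array([1.0, 2.0]))[0]\n",
"\n",
"# 2. Simulate the ODE with gamma set by the network; alpha, beta and delta\n",
"#    are the nominal values, the initial state [1.0, 1.0] is hypothetical.\n",
"alpha, beta, delta = 1.3, 0.9, 1.8\n",
"trajectory = simulate_rk4([1.0, 1.0], 10.0, 1000, alpha, beta, gamma, delta)\n",
"print(gamma, trajectory.shape)"
]
},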
{
"cell_type": "markdown",
"id": "292488b8-24c6-43c9-9e88-5270fa0efb33",
"metadata": {},
"source": [
"## Loading the PEtab problem"
]
},
{
"cell_type": "markdown",
"id": "79c7c39c-e010-4bfd-971f-70bf26c82d93",
"metadata": {},
"source": [
"Let's load the PEtab problem, build our model and define the overall hybrid problem as a ``JAXProblem``."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9880e7f-a9dd-4a00-9ee1-5ed95eda9110",
"metadata": {},
"outputs": [],
"source": [
"from amici.petab import import_petab_problem\n",
"from amici.jax import (\n",
" JAXProblem,\n",
" run_simulations,\n",
")\n",
"from petab.v2 import Problem\n",
"\n",
"# Load the PEtab problem information from disk.\n",
"petab_problem = Problem.from_yaml(\"problem.yaml\")\n",
"\n",
"# Create a simulator for the ODE and NN models.\n",
"jax_model = import_petab_problem(\n",
" petab_problem,\n",
" model_output_dir=\"model\",\n",
" compile_=True,\n",
" jax=True\n",
")\n",
"\n",
"# Create a JAXProblem to handle additional simulation information\n",
"# (e.g. simulating multiple conditions).\n",
"jax_problem = JAXProblem(jax_model, petab_problem)"
]
},
{
"cell_type": "markdown",
"id": "da148931-f7ba-4b1b-850d-8c8754cd4ad9",
"metadata": {},
"source": [
"The hybridization, mapping and parameter tables show how this deep mechanistic modelling problem is defined."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7c1e0233-7e4b-411b-8ff7-ac57a2febf06",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" targetValue\n",
"targetId \n",
"gamma net3_output1"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jax_problem._petab_problem.hybridization_df"
]
},
{
"cell_type": "markdown",
"id": "38411614-bf14-4854-b655-5b29c3e345ec",
"metadata": {},
"source": [
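"The hybridization table assigns the model parameter ``gamma`` the value of the PEtab entity ``net3_output1``. The mapping table below then links ``net3_output1`` to ``net3.outputs[0][0]``, the first output of the neural network, and links ``input0`` and ``net3_ps`` to the network input and the network parameters, respectively."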
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "832c8deb-927d-4189-b278-0bc8e5dfdfbd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" modelEntityId\n",
"petabEntityId \n",
"input0 net3.inputs[0]\n",
"net3_output1 net3.outputs[0][0]\n",
"net3_ps net3.parameters"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"petab_problem.mapping_df"
]
},
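{
"cell_type": "markdown",
"id": "table-lookup-sketch-md",
"metadata": {},
"source": [
"The two-step lookup from model parameter to network output can be illustrated with plain Python dictionaries holding the rows of the two tables above. The helper ``resolve_target`` is hypothetical and is not part of AMICI or PEtab; it assumes that ``net3.outputs[i][j]`` denotes element ``j`` of network output ``i``, consistent with ``net3_output1`` being the first network output."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "table-lookup-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# Rows of the hybridization and mapping tables shown above.\n",
"hybridization = {'gamma': 'net3_output1'}  # targetId -> targetValue\n",
"mapping = {  # petabEntityId -> modelEntityId\n",
"    'input0': 'net3.inputs[0]',\n",
"    'net3_output1': 'net3.outputs[0][0]',\n",
"    'net3_ps': 'net3.parameters',\n",
"}\n",
"\n",
"def resolve_target(target_id):\n",
"    # Hypothetical helper: model parameter -> (network ID, output indices).\n",
"    petab_id = hybridization[target_id]\n",
"    net_id, entity = mapping[petab_id].split('.', 1)\n",
"    if not entity.startswith('outputs['):\n",
"        raise ValueError(f'{petab_id} is not mapped to a network output')\n",
"    indices = tuple(int(i) for i in entity[len('outputs['):-1].split(']['))\n",
"    return net_id, indices\n",
"\n",
"print(resolve_target('gamma'))  # ('net3', (0, 0))"
]
},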
{
"cell_type": "markdown",
"id": "c999332c-e901-4f2c-877a-7ca482f75fdd",
"metadata": {},
"source": [
"Finally, the parameter table lists the network parameters, ``net3_ps``, with infinite bounds and marks them for estimation. Their nominal value is ``NaN`` because their values are provided by the array file ``net3_ps.hdf5`` (see below). Note also that $\\gamma$ does not appear in the parameter table, because its value is set by the neural network."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3cb05448-ca2a-4fbc-92e8-9d6f1f8ff630",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" parameterScale lowerBound upperBound nominalValue estimate\n",
"parameterId \n",
"alpha lin 0.0 15.0 1.3 1\n",
"delta lin 0.0 15.0 1.8 1\n",
"beta lin 0.0 15.0 0.9 1\n",
"net3_ps lin -inf inf NaN 1"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"petab_problem.parameter_df"
]
},
{
"cell_type": "markdown",
"id": "bb0d08d2-088b-4850-96e2-d0c12e1605eb",
"metadata": {},
"source": [
"### Array input data\n",
"\n",
"The PEtab problem YAML file specifies, in the `sciml` section of its `extensions` field, two array data files to include in the problem:\n",
"- ``net3_ps.hdf5`` to set the values of the network parameters\n",
"- ``net3_input2.hdf5`` to provide the inputs into the network"
]
},
{
"cell_type": "markdown",
"id": "7202b239-6e80-4a19-bde7-b35e809dc3bb",
"metadata": {},
"source": [
"```yaml\n",
"...\n",
"extensions:\n",
"  sciml:\n",
"    array_files:\n",
"      - \"net3_ps.hdf5\"\n",
"      - \"net3_input2.hdf5\"\n",
"    hybridization_files:\n",
"      - \"hybridization.tsv\"\n",
"    neural_nets:\n",
"      net3:\n",
"        location: \"net3.yaml\"\n",
"        static: true\n",
"        format: \"YAML\"\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "3393b6c1-b193-4f64-8269-838c054bce97",
"metadata": {},
"source": [
"The nested structure of the network input file is shown below. The input data for the neural network are stored under ``inputs/input0``, with one entry for each of the two conditions, ``cond1`` and ``cond2``. These keys match the condition IDs defined in the condition table."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "87200584-39ec-4bb2-a4ee-919065598b65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" inputs\n",
"  input0\n",
"   cond1\n",
"   cond2\n",
" metadata\n",
"  perm\n"
]
}
],
"source": [
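"import h5py\n",
"\n",
"# Convenience function to print the nested structure of an HDF5 file,\n",
"# indenting each group or dataset name by its depth.\n",
"def show_h5_struct(h5_file):\n",
"    h5_file.visit(lambda x: print(\" \" * (len(x.split(\"/\")) - 1), x.split(\"/\")[-1]))\n",
"\n",
"# Open read-only; the context manager closes the file afterwards.\n",
"with h5py.File(\"net3_input2.hdf5\", \"r\") as file:\n",
"    show_h5_struct(file)"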
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "78615c62-63c7-4b53-ab69-1a0cff91482d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
"
\n",
" \n",
" | conditionId | \n",
"
\n",
" \n",
" \n",
" \n",
" | cond1 | \n",
"
\n",
" \n",
" | cond2 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: []\n",
"Index: [cond1, cond2]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"petab_problem.condition_df"
]
},
{
"cell_type": "markdown",
"id": "c26425d8-2b07-4262-9bbe-fb67f7dc24d0",
"metadata": {},
"source": [
"Note also the ``static: true`` setting in the ``net3`` entry of the extensions configuration. It declares that the input to the neural network does not depend on the model, which tells PEtab SciML importers that the network is evaluated before the ODE rather than inside it (as in a UDE) or in one of the observable formulae."
]
},
{
"cell_type": "markdown",
"id": "c6a998ed-e561-4025-8572-36d7d1eb350a",
"metadata": {},
"source": [
"### Network architecture\n",
"\n",
"The PyTorch snippet below shows how the network architecture can be defined and exported to YAML format with PEtab SciML. The predefined [YAML file](./net3.yaml) is also provided in the PEtab SciML repo for completeness. We use a convolutional architecture here to illustrate how the DMM setup allows information from high-dimensional inputs to be included in the mechanistic model."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "743ed691-5ceb-41b8-9255-6576bfcc3996",
"metadata": {},
"outputs": [],
"source": [
"from petab_sciml.standard.nn_model import Input, NNModel, NNModelStandard\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"class NeuralNetwork(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.layer1 = torch.nn.Conv2d(\n",
"            3, 1, (5, 5), stride=(1, 1), padding=(0, 0), dilation=(1, 1)\n",
"        )\n",
"        self.layer2 = torch.nn.Flatten()\n",
"        self.layer3 = torch.nn.Linear(36, 1)\n",
"\n",
"    def forward(self, net_input):\n",
"        x = self.layer1(net_input)\n",
"        x = self.layer2(x)\n",
"        x = self.layer3(x)\n",
"        x = F.relu(x)\n",
"        return x\n",
"\n",
"net1 = NeuralNetwork()\n",
"nn_model1 = NNModel.from_pytorch_module(\n",
"    module=net1, nn_model_id=\"net3\", inputs=[Input(input_id=\"input0\")]\n",
")\n",
"NNModelStandard.save_data(\n",
"    data=nn_model1, filename=\"net3_from_pytorch.yaml\"\n",
")"
]
},
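{
"cell_type": "markdown",
"id": "conv-shape-check-md",
"metadata": {},
"source": [
"The ``nn.Linear(36, 1)`` layer fixes how large the convolution output must be. As a worked check, the cell below applies the output-size formula from the ``torch.nn.Conv2d`` documentation, $H_{\\text{out}} = \\lfloor (H_{\\text{in}} + 2p - d(k-1) - 1)/s \\rfloor + 1$, to the settings of ``layer1`` ($k = 5$, $s = 1$, $p = 0$, $d = 1$, one output channel). Assuming a square input image (an assumption; the cell does not inspect ``net3_input2.hdf5``), only a $10 \\times 10$ image with 3 channels yields the $1 \\times 6 \\times 6 = 36$ features that ``layer3`` expects. With PyTorch available, ``net1(torch.zeros(1, 3, 10, 10))`` would accordingly return a tensor of shape ``(1, 1)``."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "conv-shape-check-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):\n",
"    # Output size of one spatial dimension, as in the torch.nn.Conv2d docs.\n",
"    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)\n",
"\n",
"out_channels = 1  # layer1 maps 3 input channels to 1 output channel\n",
"linear_in = 36    # in_features of layer3\n",
"\n",
"# Search square input sizes whose flattened conv output matches layer3.\n",
"sides = [\n",
"    side for side in range(5, 64)\n",
"    if out_channels * conv2d_out(side, kernel=5) ** 2 == linear_in\n",
"]\n",
"print(sides)  # [10]"
]
},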
{
"cell_type": "markdown",
"id": "8b04d5fe-42f0-443b-b02d-38832b6523f6",
"metadata": {},
"source": [
"### Parameterizing network inputs\n",
"\n",
"A straightforward alternative to defining network inputs as array data in files is to define them in the parameter table, with appropriate bounds and nominal values.\n",
"\n",
"| parameterId | parameterScale | lowerBound | upperBound | nominalValue | estimate |\n",
"|-----------------|----------------|------------|------------|--------------|----------|\n",
"| net1_input_pre1 | lin | -inf | inf | 1 | 0 |\n",
"| net1_input_pre2 | lin | -inf | inf | 1 | 0 |\n",
"\n",
"The network inputs can also be made condition-specific by assigning them different scalar values, or parameter IDs, in different conditions.\n",
"\n",
"| conditionId | net1_input1 | net1_input2 |\n",
"|-----------------|-------------------|-------------------|\n",
"| cond1 | 10.0 | 20.0 |\n",
"| cond2 | net1_input_pre1 | net1_input_pre2 |\n",
"\n",
"The mapping table then links the PEtab identifiers used in the condition table to the corresponding elements of the network input.\n",
"\n",
"| petabEntityId | modelEntityId |\n",
"|-----------------|-------------------|\n",
"| net1_input1 | net1.inputs[0][0] |\n",
"| net1_input2 | net1.inputs[0][1] |"
]
}
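,
{
"cell_type": "markdown",
"id": "condition-input-sketch-md",
"metadata": {},
"source": [
"How these three tables resolve into a network input vector for each condition can be sketched with plain dictionaries. The helper ``network_input`` is hypothetical, not an AMICI or PEtab function. It assumes that a parameter ID in the condition table is replaced by that parameter's value, which for these non-estimated parameters (``estimate`` = 0) is the nominal value; an estimated parameter would instead take the optimizer's current value."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "condition-input-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# The three tables above, as plain dictionaries.\n",
"nominal_values = {'net1_input_pre1': 1.0, 'net1_input_pre2': 1.0}\n",
"conditions = {\n",
"    'cond1': {'net1_input1': 10.0, 'net1_input2': 20.0},\n",
"    'cond2': {'net1_input1': 'net1_input_pre1', 'net1_input2': 'net1_input_pre2'},\n",
"}\n",
"mapping = {'net1_input1': 'net1.inputs[0][0]', 'net1_input2': 'net1.inputs[0][1]'}\n",
"\n",
"def network_input(condition_id):\n",
"    # Hypothetical helper: assemble net1.inputs[0] for one condition.\n",
"    row = conditions[condition_id]\n",
"    vector = [0.0] * len(mapping)\n",
"    for petab_id, model_entity in mapping.items():\n",
"        value = row[petab_id]\n",
"        if isinstance(value, str):  # a parameter ID, not a number\n",
"            value = nominal_values[value]\n",
"        # The last index in e.g. 'net1.inputs[0][1]' is the vector position.\n",
"        position = int(model_entity.rsplit('[', 1)[1].rstrip(']'))\n",
"        vector[position] = float(value)\n",
"    return vector\n",
"\n",
"print(network_input('cond1'))  # [10.0, 20.0]\n",
"print(network_input('cond2'))  # [1.0, 1.0]"
]
}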
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}