Differential Equations as a Pytorch Neural Network Layer
What is the problem we are trying to solve?
Let’s say we have some time series data y(t) that we want to model with a differential equation. The data takes the form of a set of observations yᵢ at times tᵢ. Based on some domain knowledge of the underlying system we can write down a differential equation to approximate the system.
In the most general case this takes the form:
\[\begin{align} \frac{dy}{dt} = f(y,t;\theta) \\ y(t_0) = y_0 \end{align}\]
where \(y\) is the state of the system, \(t\) is time, and \(\theta\) are the parameters of the model. In this post we will assume that the parameters \(\theta\) are unknown and we want to learn them from the data.
Let’s import the libraries we will need for this post. The only non-standard machine learning library we will use is the torchdiffeq library, which we use to solve the differential equations. This library implements numerical differential equation solvers in pytorch.
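Here is a minimal set of imports consistent with the names used in the rest of the post (torch, nn, odeint, plt, Dataset, DataLoader, and the typing helpers):

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint  # numerical ODE solvers implemented in pytorch
import matplotlib.pyplot as plt
from typing import List, Tuple, Union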
Models
The first step of our modeling process is to define the model. For differential equations this means we must choose a form for the function \(f(y,t;\theta)\) and a way to represent the parameters \(\theta\). We also need to do this in a way that is compatible with pytorch.
This means we need to encode our function as a torch.nn.Module class. As you will see, this is pretty easy and only requires defining two methods. Let’s get started with the first of our three example models.
Van der Pol Oscillator (VDP)
We can define a differential equation system using the torch.nn.Module class where the parameters are created using the torch.nn.Parameter declaration. This lets pytorch know that we want to accumulate gradients for those parameters. We can also include fixed parameters (ones we don’t want to fit) by simply not wrapping them with this declaration.
The first example we will use is the classic VDP oscillator which is a nonlinear oscillator with a single parameter \(\mu\). The differential equations for this system are:
\[\begin{align} \frac{dx}{dt} &= \mu(x-\frac{1}{3}x^3-y) \\ \frac{dy}{dt} &= \frac{x}{\mu} \end{align}\]
where \(x\) and \(y\) are the state variables. The VDP model is used to model everything from electronic circuits to cardiac arrhythmias and circadian rhythms. We can define this system in pytorch as follows:

class VDP(nn.Module):
    """
    Define the Van der Pol oscillator as a PyTorch module.
    """
    def __init__(self,
                 mu: float, # Stiffness parameter of the VDP oscillator
                ):
        super().__init__()
        self.mu = torch.nn.Parameter(torch.tensor(mu)) # make mu a learnable parameter

    def forward(self,
                t: float, # time index
                state: torch.Tensor, # state of the system; first dimension is the batch size
               ) -> torch.Tensor: # return the derivative of the state
        """
        Define the right hand side of the VDP oscillator.
        """
        x = state[..., 0] # first dimension is the batch size
        y = state[..., 1]
        dX = self.mu*(x - 1/3*x**3 - y)
        dY = 1/self.mu*x
        # trick to make sure our return value has the same shape as the input
        dfunc = torch.zeros_like(state)
        dfunc[..., 0] = dX
        dfunc[..., 1] = dY
        return dfunc

    def __repr__(self):
        """Print the parameters of the model."""
        return f" mu: {self.mu.item()}"
VDP
VDP (mu:float)
Define the Van der Pol oscillator as a PyTorch module.
| | Type | Details |
|---|---|---|
| mu | float | Stiffness parameter of the VDP oscillator |
You only need to define the dunder init method (__init__) and the forward method. I also added a __repr__ method to pretty-print the parameter. The key point here is how we translate from the differential equation to torch code in the forward method. This method needs to define the right-hand side of the differential equation.
Let’s see how we can integrate this model using the odeint method from torchdiffeq:
vdp_model = VDP(mu=0.5)

# Create a time vector, this is the time axis of the ODE
ts = torch.linspace(0, 30.0, 1000)
# Create a batch of initial conditions
batch_size = 30
# Creates some random initial conditions
initial_conditions = torch.tensor([0.01, 0.01]) + 0.2*torch.randn((batch_size, 2))

# Solve the ODE, odeint comes from torchdiffeq
sol = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()

plt.plot(ts, sol[:, :, 0], lw=0.5);
plt.title("Time series of the VDP oscillator");
plt.xlabel("time");
plt.ylabel("x");
Here is a phase plane plot of the solution (a phase plane plot is a parametric plot of the dynamical state).
# Check the solution
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5);
plt.title("Phase plot of the VDP oscillator");
plt.xlabel("x");
plt.ylabel("y");
The colors indicate the 30 separate trajectories in our batch. The solution comes back as a torch tensor with dimensions (time_points, batch_size, dynamical_dimension).
sol.shape
(1000, 30, 2)
Lotka Volterra Predator Prey equations
As another example we create a module for the Lotka Volterra predator-prey equations. In the Lotka-Volterra (LV) predator-prey model, there are two primary variables: the population of prey (\(x\)) and the population of predators (\(y\)). The model is defined by the following equations:
\[\begin{align} \frac{dx}{dt} &= \alpha x - \beta xy \\ \frac{dy}{dt} &= -\delta y + \gamma xy \\ \end{align}\]
The population of prey (\(x\)) represents the number of individuals of the prey species present in the ecosystem at any given time. The population of predators (\(y\)) represents the number of individuals of the predator species present in the ecosystem at any given time.
In addition to the primary variables, there are also four parameters that are used to describe various ecological factors in the model:
- \(\alpha\) represents the intrinsic growth rate of the prey population in the absence of predators.
- \(\beta\) represents the predation rate of the predators on the prey.
- \(\delta\) represents the death rate of the predator population in the absence of prey.
- \(\gamma\) represents the efficiency with which the predators convert the consumed prey into new predator biomass.
Together, these variables and parameters describe the dynamics of predator-prey interactions in an ecosystem and are used to mathematically model the changes in the populations of prey and predators over time.
class LotkaVolterra(nn.Module):
    """
    The Lotka-Volterra equations are a pair of first-order, non-linear, differential equations
    describing the dynamics of two species interacting in a predator-prey relationship.
    """
    def __init__(self,
                 alpha: float = 1.5, # The alpha parameter of the Lotka-Volterra system
                 beta: float = 1.0,  # The beta parameter of the Lotka-Volterra system
                 delta: float = 3.0, # The delta parameter of the Lotka-Volterra system
                 gamma: float = 1.0  # The gamma parameter of the Lotka-Volterra system
                ) -> None:
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([alpha, beta, delta, gamma]))

    def forward(self, t, state):
        x = state[..., 0] # variables are part of vector array u
        y = state[..., 1]
        sol = torch.zeros_like(state)

        # coefficients are part of tensor model_params
        alpha, beta, delta, gamma = self.model_params
        sol[..., 0] = alpha*x - beta*x*y
        sol[..., 1] = -delta*y + gamma*x*y
        return sol

    def __repr__(self):
        return f" alpha: {self.model_params[0].item()}, \
            beta: {self.model_params[1].item()}, \
            delta: {self.model_params[2].item()}, \
            gamma: {self.model_params[3].item()}"
LotkaVolterra
LotkaVolterra (alpha:float=1.5, beta:float=1.0, delta:float=3.0, gamma:float=1.0)
The Lotka-Volterra equations are a pair of first-order, non-linear, differential equations describing the dynamics of two species interacting in a predator-prey relationship.
| | Type | Default | Details |
|---|---|---|---|
| alpha | float | 1.5 | The alpha parameter of the Lotka-Volterra system |
| beta | float | 1.0 | The beta parameter of the Lotka-Volterra system |
| delta | float | 3.0 | The delta parameter of the Lotka-Volterra system |
| gamma | float | 1.0 | The gamma parameter of the Lotka-Volterra system |
| Returns | None | | |
This follows the same pattern as the first example; the main difference is that we now have four parameters and store them as a model_params tensor. Here is the integration and plotting code for the predator-prey equations.
lv_model = LotkaVolterra() # use default parameters
ts = torch.linspace(0, 30.0, 1000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[3, 3]]) + 0.50*torch.randn((batch_size, 2))
sol = odeint(lv_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Check the solution
plt.plot(ts, sol[:, :, 0], lw=0.5);
plt.title("Time series of the Lotka-Volterra system");
plt.xlabel("time");
plt.ylabel("x");
Now a phase plane plot of the system:
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5);
plt.title("Phase plot of the Lotka-Volterra system");
plt.xlabel("x");
plt.ylabel("y");
Lorenz system
The last example we will use is the Lorenz equations, which are famous for their beautiful plots illustrating chaotic dynamics. They originally came from a reduced model of fluid dynamics and take the form:
\[\begin{align} \frac{dx}{dt} &= \sigma(y - x) \\ \frac{dy}{dt} &= x(\rho - z) - y \\ \frac{dz}{dt} &= xy - \beta z \end{align}\]
where \(x\), \(y\), and \(z\) are the state variables, and \(\sigma\), \(\rho\), and \(\beta\) are the system parameters.
class Lorenz(nn.Module):
    """
    Define the Lorenz system as a PyTorch module.
    """
    def __init__(self,
                 sigma: float = 10.0, # The sigma parameter of the Lorenz system
                 rho: float = 28.0,   # The rho parameter of the Lorenz system
                 beta: float = 8.0/3, # The beta parameter of the Lorenz system
                ):
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([sigma, rho, beta]))

    def forward(self, t, state):
        x = state[..., 0] # variables are part of vector array u
        y = state[..., 1]
        z = state[..., 2]
        sol = torch.zeros_like(state)

        sigma, rho, beta = self.model_params # coefficients are part of vector array p
        sol[..., 0] = sigma*(y - x)
        sol[..., 1] = x*(rho - z) - y
        sol[..., 2] = x*y - beta*z
        return sol

    def __repr__(self):
        return f" sigma: {self.model_params[0].item()}, \
            rho: {self.model_params[1].item()}, \
            beta: {self.model_params[2].item()}"
Lorenz
Lorenz (sigma:float=10.0, rho:float=28.0, beta:float=2.6666666666666665)
Define the Lorenz system as a PyTorch module.
| | Type | Default | Details |
|---|---|---|---|
| sigma | float | 10.0 | The sigma parameter of the Lorenz system |
| rho | float | 28.0 | The rho parameter of the Lorenz system |
| beta | float | 2.6666666666666665 | The beta parameter of the Lorenz system |
This shows how to integrate this system and plot the results. This system (at these parameter values) shows chaotic dynamics, so initial conditions that start off close together diverge from one another exponentially.
lorenz_model = Lorenz()
ts = torch.linspace(0, 50.0, 3000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[1.0, 0.0, 0.0]]) + 0.10*torch.randn((batch_size, 3))
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()

# Check the solution
plt.plot(ts[:2000], sol[:2000, :, 0], lw=0.5);
plt.title("Time series of the Lorenz system");
plt.xlabel("time");
plt.ylabel("x");
Here we show the famous butterfly plot (phase plane plot) for the first set of initial conditions in the batch.
plt.plot(sol[:, 0, 0], sol[:, 0, 1], color='black', lw=0.5);
plt.title("Phase plot of the Lorenz system");
plt.xlabel("x");
plt.ylabel("y");
Data
Now that we can define differential equation models in pytorch, we need to create some data to use in training. This is where things start to get really neat, as we get our first glimpse of being able to hijack deep learning machinery to fit the parameters. Really we could just use a tensor of data directly, but this is a nice way to organize the data. It will also be useful if you have some experimental data that you want to use.
Torch provides the Dataset class for loading in data. To use it you just need to create a subclass and define two methods: a __len__ method that returns the number of data points and a __getitem__ method that returns the data point at a given index. If you are wondering, these methods are what underlie the len(array) and array[0] subscript access in Python lists.
The rest of the boilerplate code needed is defined in the parent class torch.utils.data.Dataset. We will see the power of these methods when we go to define a training loop.
class SimODEData(Dataset):
    """
    A very simple dataset class for simulating ODEs
    """
    def __init__(self,
                 ts: List[torch.Tensor],     # List of time points as tensors
                 values: List[torch.Tensor], # List of dynamical state values (tensor) at each time point
                 true_model: Union[torch.nn.Module, None] = None,
                ) -> None:
        self.ts = ts
        self.values = values
        self.true_model = true_model

    def __len__(self) -> int:
        return len(self.ts)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.ts[index], self.values[index]
SimODEData
SimODEData (ts:List[torch.Tensor], values:List[torch.Tensor], true_model:Optional[torch.nn.modules.module.Module]=None)
A very simple dataset class for simulating ODEs
| | Type | Default | Details |
|---|---|---|---|
| ts | typing.List[torch.Tensor] | | List of time points as tensors |
| values | typing.List[torch.Tensor] | | List of dynamical state values (tensor) at each time point |
| true_model | typing.Optional[torch.nn.modules.module.Module] | None | |
| Returns | None | | |
Next let’s create a quick generator function to generate some simulated data to test the algorithms on. In a real use case the data would be loaded from a file or database, but for this example we will just generate some data. In fact, I recommend that you always start with generated data to make sure your code is working before you try to load real data.
create_sim_dataset
create_sim_dataset (model:torch.nn.modules.module.Module, ts:torch.Tensor, num_samples:int=10, sigma_noise:float=0.1, initial_conditions_default:torch.Tensor=tensor([0., 0.]), sigma_initial_conditions:float=0.1)
| | Type | Default | Details |
|---|---|---|---|
| model | Module | | model to simulate from |
| ts | Tensor | | Time points to simulate for |
| num_samples | int | 10 | Number of samples to generate |
| sigma_noise | float | 0.1 | Noise level to add to the data |
| initial_conditions_default | Tensor | tensor([0., 0.]) | Default initial conditions |
| sigma_initial_conditions | float | 0.1 | Noise level to add to the initial conditions |
| Returns | SimODEData | | |
This just takes in a differential equation model with some initial states and generates some time-series data from it (and adds in some gaussian noise). This data is then passed into our custom dataset container.
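A minimal sketch of such a generator, matching the signature above, could look like the following (the details here, such as the solver choice, are illustrative rather than the exact implementation):

def create_sim_dataset(model: nn.Module,          # model to simulate from
                       ts: torch.Tensor,          # time points to simulate for
                       num_samples: int = 10,     # number of samples to generate
                       sigma_noise: float = 0.1,  # noise level to add to the data
                       initial_conditions_default: torch.Tensor = torch.tensor([0.0, 0.0]),
                       sigma_initial_conditions: float = 0.1,  # noise level on the initial conditions
                      ) -> SimODEData:
    ts_list = []
    states_list = []
    dim = initial_conditions_default.shape[0]
    for _ in range(num_samples):
        # perturb the default initial condition for each sample
        x0 = initial_conditions_default + sigma_initial_conditions*torch.randn(dim)
        with torch.no_grad():
            states = odeint(model, x0, ts, method='dopri5')
        # add Gaussian measurement noise to the simulated trajectory
        states = states + sigma_noise*torch.randn_like(states)
        ts_list.append(ts)
        states_list.append(states)
    return SimODEData(ts_list, states_list, true_model=model)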
plot_time_series
plot_time_series (true_model:torch.nn.modules.module.Module, fit_model:torch.nn.modules.module.Module, data:__main__.SimODEData, time_range:tuple=(0.0, 30.0), ax:matplotlib.axes._axes.Axes=None, dyn_var_idx:int=0, title:str='Model fits', *args, **kwargs)
Plot the true model and fit model on the same axes.
| | Type | Default | Details |
|---|---|---|---|
| true_model | Module | | true underlying model for the simulated data |
| fit_model | Module | | model fit to the data |
| data | SimODEData | | data set to plot (scatter) |
| time_range | tuple | (0.0, 30.0) | range of times to simulate the models for |
| ax | Axes | None | |
| dyn_var_idx | int | 0 | |
| title | str | Model fits | |
| args | | | |
| kwargs | | | |
| Returns | typing.Tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes] | | |
plot_phase_plane
plot_phase_plane (true_model:torch.nn.modules.module.Module, fit_model:torch.nn.modules.module.Module, data:__main__.SimODEData, time_range:tuple=(0.0, 30.0), ax:matplotlib.axes._axes.Axes=None, dyn_var_idx:tuple=(0, 1), title:str='Model fits', *args, **kwargs)
Plot the true model and fit model on the same axes.
| | Type | Default | Details |
|---|---|---|---|
| true_model | Module | | true underlying model for the simulated data |
| fit_model | Module | | model fit to the data |
| data | SimODEData | | data set to plot (scatter) |
| time_range | tuple | (0.0, 30.0) | range of times to simulate the models for |
| ax | Axes | None | |
| dyn_var_idx | tuple | (0, 1) | |
| title | str | Model fits | |
| args | | | |
| kwargs | | | |
| Returns | typing.Tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes] | | |
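A rough sketch of a plot_time_series helper matching the signature above might look like this (here it takes a single (ts, values) item from the dataset, and the solver, colors, and styling choices are illustrative, not the exact implementation):

def plot_time_series(true_model: torch.nn.Module,  # true underlying model for the simulated data
                     fit_model: torch.nn.Module,   # model fit to the data
                     data: Tuple[torch.Tensor, torch.Tensor], # one (ts, values) item from SimODEData
                     time_range: tuple = (0.0, 30.0),  # range of times to simulate the models for
                     ax: plt.Axes = None,
                     dyn_var_idx: int = 0,         # which state variable to plot
                     title: str = "Model fits",
                     *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    ts_data, values = data
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
    # Integrate both models forward from the first observed state
    ts_sim = torch.linspace(time_range[0], time_range[1], 1000)
    initial_state = values[0, :].unsqueeze(0)
    with torch.no_grad():
        sol_true = odeint(true_model, initial_state, ts_sim, method='dopri5')
        sol_fit = odeint(fit_model, initial_state, ts_sim, method='dopri5')
    ax.plot(ts_sim, sol_true[:, 0, dyn_var_idx], color='black', lw=1.0, label='true model')
    ax.plot(ts_sim, sol_fit[:, 0, dyn_var_idx], color='red', lw=1.0, label='fit model')
    ax.scatter(ts_data, values[:, dyn_var_idx], s=10, label='data')
    ax.set_title(title)
    ax.set_xlabel('time')
    ax.legend()
    return fig, ax

plot_phase_plane follows the same pattern, plotting two state variables against each other instead of against time.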
Training Loop
Next we will create a wrapper function for a pytorch training loop. Training means we want to update the model parameters to increase the alignment with the data (or decrease the misalignment).
One of the tricks from deep learning is to not use all the data before taking a gradient step. Part of this is a necessity when using enormous datasets, as you can’t fit all of that data inside a GPU’s memory, but it can also help the gradient descent algorithm avoid getting stuck in local minima.
The training loop in words:

- Divide the dataset into mini-batches; these are subsets of your entire data set. You usually want to choose these randomly.
- Iterate through the mini-batches. For each mini-batch:
  - Generate the predictions using the current model parameters.
  - Calculate the loss (here we will use the mean squared error).
  - Calculate the gradients, using backpropagation.
  - Update the parameters using a gradient descent step. Here we use the Adam optimizer.
- Each full pass through the dataset is called an epoch.
Okay here is the code:
def train(model: torch.nn.Module, # Model to train
          data: SimODEData, # Data to train on
          lr: float = 1e-2, # learning rate for the Adam optimizer
          epochs: int = 10, # Number of epochs to train for
          batch_size: int = 5, # Batch size for training
          method = 'rk4', # ODE solver to use
          step_size: float = 0.10, # for fixed diffeq solver set the step size
          show_every: int = 10, # How often to print the loss function message
          save_plots_every: Union[int, None] = None, # save a plot of the fit, to disable make this None
          model_name: str = "", # string for the model, used to reference the saved plots
          *args: tuple,
          **kwargs: dict
         ):
    # Create a data loader to iterate over the data. This takes in our dataset and returns batches of data
    trainloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    # Choose an optimizer. Adam is a good default choice as a fancy gradient descent
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Create a loss function, this computes the error between the predicted and true values
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        running_loss = 0.0
        for batchdata in trainloader:
            # reset gradients, famous gotcha in a pytorch training loop
            optimizer.zero_grad()
            ts, states = batchdata # unpack the data
            initial_state = states[:, 0, :] # grab the initial state
            # Make the prediction and then flip the dimensions to be (batch, time, state_dim)
            # odeint returns the time dimension first; pytorch expects the batch dimension to be first
            pred = odeint(model, initial_state, ts[0], method=method, options={'step_size': step_size}).transpose(0, 1)
            # Compute the loss
            loss = criterion(pred, states)
            loss.backward() # compute gradients
            optimizer.step() # update parameters
            running_loss += loss.item() # record loss
        if epoch % show_every == 0:
            print(f"Loss at {epoch}: {running_loss}")
        # Use this to save plots of the fit every save_plots_every epochs
        if save_plots_every is not None and epoch % save_plots_every == 0:
            with torch.no_grad():
                fig, ax = plot_time_series(data.true_model, model, data[0])
                ax.set_title(f"Epoch: {epoch}")
                fig.savefig(f"./tmp_plots/{epoch}_{model_name}_fit_plot")
                plt.close()
train
train (model:torch.nn.modules.module.Module, data:__main__.SimODEData, lr:float=0.01, epochs:int=10, batch_size:int=5, method='rk4', step_size:float=0.1, show_every:int=10, save_plots_every:Optional[int]=None, model_name:str='', *args:tuple, **kwargs:dict)
| | Type | Default | Details |
|---|---|---|---|
| model | Module | | Model to train |
| data | SimODEData | | Data to train on |
| lr | float | 0.01 | learning rate for the Adam optimizer |
| epochs | int | 10 | Number of epochs to train for |
| batch_size | int | 5 | Batch size for training |
| method | str | rk4 | ODE solver to use |
| step_size | float | 0.1 | for fixed diffeq solver set the step size |
| show_every | int | 10 | How often to print the loss function message |
| save_plots_every | typing.Optional[int] | None | save a plot of the fit, to disable make this None |
| model_name | str | | string for the model, used to reference the saved plots |
| args | tuple | | |
| kwargs | dict | | |
Examples
Fitting the VDP Oscillator
Let’s use this training loop to recover the parameters from simulated VDP oscillator data.
true_mu = 0.30
model_sim = VDP(mu=true_mu)
ts_data = torch.linspace(0.0, 10.0, 10)
data_vdp = create_sim_dataset(model_sim,
                              ts=ts_data,
                              num_samples=10,
                              sigma_noise=0.01)
Let’s create a model with the wrong parameter value and visualize the starting point.
vdp_model = VDP(mu=0.10)

plot_time_series(model_sim,
                 vdp_model,
                 data_vdp[0],
                 dyn_var_idx=1,
                 title="VDP Model: Before Parameter Fits");
Now, we will use the training loop to fit the parameters of the VDP oscillator to the simulated data.
train(vdp_model, data_vdp, epochs=50, model_name="vdp");
print(f"After training: {vdp_model}, where the true value is {true_mu}")
print(f"Final Parameter Recovery Error: {vdp_model.mu - true_mu}")
Loss at 0: 0.09973308071494102
Loss at 10: 0.005132006946951151
Loss at 20: 0.0007074056047713384
Loss at 30: 0.00021287801791913807
Loss at 40: 0.00020217221026541665
After training: mu: 0.3017330467700958, where the true value is 0.3
Final Parameter Recovery Error: 0.0017330348491668701
Not too bad! Let’s see how the plot looks now…
plot_time_series(model_sim, vdp_model, data_vdp[0], dyn_var_idx=1, title="VDP Model: After Parameter Fits");
The plot confirms that we almost perfectly recovered the parameter. One more quick plot, where we plot the dynamics of the system in the phase plane (a parametric plot of the state variables).
0], title = "VDP Model: After Fitting"); plot_phase_plane(model_sim, vdp_model, data_vdp[
Lotka-Volterra Equations
Now let’s adapt our methods to fit simulated data from the Lotka-Volterra equations.
model_sim_lv = LotkaVolterra(1.5, 1.0, 3.0, 1.0)
ts_data = torch.arange(0.0, 10.0, 0.1)
data_lv = create_sim_dataset(model_sim_lv,
                             ts=ts_data,
                             num_samples=10,
                             sigma_noise=0.1,
                             initial_conditions_default=torch.tensor([2.5, 2.5]))
model_lv = LotkaVolterra(alpha=1.6, beta=1.1, delta=2.7, gamma=1.2)

plot_time_series(model_sim_lv, model_lv, data=data_lv[0], title="Lotka Volterra: Before Fitting");
train(model_lv, data_lv, epochs=60, lr=1e-2, model_name="lotkavolterra")
print(f"Fitted model: {model_lv}")
print(f"True model: {model_sim_lv}")
Loss at 0: 1.1245160698890686
Loss at 10: 0.13308029621839523
Loss at 20: 0.047104522585868835
Loss at 30: 0.023627933114767075
Loss at 40: 0.021535277366638184
Loss at 50: 0.02137285191565752
Fitted model: alpha: 1.5944933891296387, beta: 1.0464898347854614, delta: 2.819714307785034, gamma: 0.9388865232467651
True model: alpha: 1.5, beta: 1.0, delta: 3.0, gamma: 1.0
plot_time_series(model_sim_lv, model_lv, data=data_lv[0], title="Lotka Volterra: After Fitting");
Now let’s visualize the results using a phase plane plot.
0], title= "Phase Plane for Lotka Volterra: After Fitting"); plot_phase_plane(model_sim_lv, model_lv, data_lv[
Lorenz Equations
Finally, let’s try to fit the Lorenz equations.
model_sim_lorenz = Lorenz(sigma=10.0, rho=28.0, beta=8.0/3.0)
ts_data = torch.arange(0, 10.0, 0.05)
data_lorenz = create_sim_dataset(model_sim_lorenz,
                                 ts=ts_data,
                                 num_samples=30,
                                 initial_conditions_default=torch.tensor([1.0, 0.0, 0.0]),
                                 sigma_noise=0.01,
                                 sigma_initial_conditions=0.10)
lorenz_model = Lorenz(sigma=10.2, rho=28.2, beta=9.0/3)
fig, ax = plot_time_series(model_sim_lorenz, lorenz_model, data_lorenz[0], title="Lorenz Model: Before Fitting");

ax.set_xlim((2, 15))

(2.0, 15.0)
train(lorenz_model,
      data_lorenz,
      epochs=300,
      batch_size=5,
      method='rk4',
      step_size=0.05,
      show_every=50,
      lr=1e-3)
Loss at 0: 113.75426864624023
Loss at 50: 4.351463496685028
Loss at 100: 2.0423043966293335
Loss at 150: 1.2441555112600327
Loss at 200: 0.7774392068386078
Loss at 250: 0.5306024551391602
Let’s look at the results from the fitting procedure. Starting with a full plot of the dynamics.
fig, ax = plot_time_series(model_sim_lorenz, lorenz_model, data_lorenz[0], title="Lorenz Model: After Fitting");
Let’s zoom in on the bulk of the data and see how the fit looks.
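One way to zoom in, reusing the fig, ax handles returned above (the axis range here is an illustrative choice covering the training data):

# zoom in on the time window covered by the training data (illustrative range)
ax.set_xlim((0.0, 10.0))
fig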
You can see the model is very close to the true model for the data range. Now the phase plane plot.
0], title = "Lorenz Model: After Fitting", time_range=(0,20.0)); plot_phase_plane(model_sim_lorenz, lorenz_model, data_lorenz[
You can see that our fitted model performs well for t in [0,17] and then starts to diverge.
Intro to Neural Differential Equations
This is great for the situation where we know the form of the equations on the right-hand-side, but what if we don’t? Can we use this procedure to discover the model equations?
This is much too big of a subject to cover in this post (stay tuned), but one of the biggest advantages of moving our differential equations models into the torch framework is that we can mix and match them with artificial neural network layers.
The simplest thing we can do is to replace the right-hand side \(f(y,t; \theta)\) with a neural network layer \(l_\theta(y,t)\). These types of equations have been called neural differential equations, and they can be viewed as a generalization of recurrent neural networks (citation).
Let’s do this for our simple VDP oscillator system.
Let’s remake the simulated data. You will notice that I am creating a longer time series and more samples. Fitting a neural differential equation takes much more data and more computational power, since we have many more parameters that need to be determined.
# remake the data
model_sim_vdp = VDP(mu=0.20)
ts_data = torch.linspace(0.0, 30.0, 100) # longer time series than the custom ode layer
data_vdp = create_sim_dataset(model_sim_vdp,
                              ts=ts_data,
                              num_samples=30, # more samples than the custom ode layer
                              sigma_noise=0.1,
                              initial_conditions_default=torch.tensor([0.50, 0.10]))
class NeuralDiffEq(nn.Module):
    """
    Basic Neural ODE model
    """
    def __init__(self,
                 dim: int = 2, # dimension of the state vector
                ) -> None:
        super().__init__()
        self.ann = nn.Sequential(torch.nn.Linear(dim, 8),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(8, 16),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(16, 32),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(32, dim))

    def forward(self, t, state):
        return self.ann(state)
NeuralDiffEq
NeuralDiffEq (dim:int=2)
Basic Neural ODE model
| | Type | Default | Details |
|---|---|---|---|
| dim | int | 2 | dimension of the state vector |
| Returns | None | | |
model_vdp_nde = NeuralDiffEq(dim=2)
plot_time_series(model_sim_vdp, model_vdp_nde, data_vdp[0], title="Neural ODE: Before Fitting");
You can see we start very far away from the correct solution, but then again we are injecting much less information into our model. Let’s see if we can fit the model to get better results.
train(model_vdp_nde,
      data_vdp,
      epochs=1500,
      lr=1e-3,
      batch_size=5,
      show_every=100,
      model_name="nde")
Loss at 0: 84.39617252349854
Loss at 100: 84.34061241149902
Loss at 200: 73.75008296966553
Loss at 300: 3.4929964542388916
Loss at 400: 1.6555403769016266
Loss at 500: 0.7814530655741692
Loss at 600: 0.41551147401332855
Loss at 700: 0.3157300055027008
Loss at 800: 0.19066352397203445
Loss at 900: 0.15869349241256714
Loss at 1000: 0.12904016114771366
Loss at 1100: 0.23840919509530067
Loss at 1200: 0.1681726910173893
Loss at 1300: 0.09865255374461412
Loss at 1400: 0.09134986530989408
Visualizing the results, we can see that the model is able to fit the data and even extrapolate to the future (although it is not as good or fast as the specified model).
0], title = "Neural ODE: After Fitting", time_range=(0,60.0)); plot_time_series(model_sim_vdp, model_vdp_nde, data_vdp[
Now the phase plane plot of our neural differential equation model.
0], title = "Neural ODE Phase Plane: After Fitting"); plot_phase_plane(model_sim_vdp, model_vdp_nde, data_vdp[
These models take a long time to train and need more data to converge on a good fit. This makes sense, since we are trying to learn both the model and the parameters at the same time.
Conclusions and Wrap-Up
In this article I have demonstrated how we can use differential equation models within the pytorch ecosystem using the torchdiffeq package. The code from this article is available on GitHub and can be opened directly in Google Colab for experimentation. You can also install the code from this article using pip (pip install paramfittorchdemo).
This post is an introduction; in the future I will be writing more about the following topics:
- How to blend some mechanistic knowledge of the dynamics with deep learning. These have been called universal differential equations as they enable us to combine scientific knowledge with deep learning. This basically blends the two approaches together.
- How to combine differential equation layers with other deep learning layers.
- Model discovery: Can we recover the actual model equations from data? This uses tools like SINDy to extract the model equations from data.
- MLOps tools for managing the training of these models. This includes tools like MLFlow, Weights and Biases, and Tensorboard.
- Anything else I hear back about from you!
Happy modeling!