Choosing an ODE Algorithm

Chris Rackauckas

While the default algorithms, along with alg_hints = [:stiff], will suffice in most cases, there are times when you may need to exert more control. The purpose of this part of the tutorial is to introduce you to some of the most widely used algorithm choices and when they should be used. The corresponding page of the documentation is the ODE Solvers page, which goes into more depth.

Diagnosing Stiffness

One of the key things to know when choosing an algorithm is whether your problem is stiff. Take, for example, the Van der Pol equation:

using DifferentialEquations, ParameterizedFunctions
van! = @ode_def VanDerPol begin
  dy = μ*((1-x^2)*y - x)
  dx = 1*y
end μ

prob = ODEProblem(van!,[0.0,2.0],(0.0,6.3),1e6)
ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true
timespan: (0.0, 6.3)
u0: [0.0, 2.0]
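The @ode_def macro generates the right-hand-side function for you. As a sketch of what it computes, here is an equivalent hand-written in-place function in the (du,u,p,t) form that ODEProblem accepts. The name vanderpol! is ours, and the sketch assumes the states are ordered as they appear in the macro (u[1] is y, u[2] is x), which is consistent with u0 = [0.0, 2.0] and with the solution output below, where the first component moves immediately while the second stays at 2.0.

```julia
# Hand-written equivalent of the @ode_def VanDerPol model (a sketch).
# Assumed state ordering: u[1] = y, u[2] = x.
function vanderpol!(du, u, p, t)
    μ = p[1]                          # works for a scalar p (1e6[1] == 1e6) or a 1-element vector
    y, x = u[1], u[2]
    du[1] = μ * ((1 - x^2) * y - x)   # dy/dt
    du[2] = y                         # dx/dt
    return nothing
end

# Evaluate the right-hand side at the initial condition used above
du = zeros(2)
vanderpol!(du, [0.0, 2.0], 1e6, 0.0)
println(du)   # [-2.0e6, 0.0]
```

With this function, ODEProblem(vanderpol!, [0.0,2.0], (0.0,6.3), 1e6) should describe the same problem. Note that the first derivative is already -2e6 at t = 0, an early hint of very fast dynamics.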

One sign that this model may be stiff is that the parameter μ is 1e6: very large parameters often mean stiff models, because they introduce very fast time scales alongside slow ones. If we try to solve this with a standard non-stiff method, Tsit5():

sol = solve(prob,Tsit5())
retcode: MaxIters
Interpolation: specialized 4th order "free" interpolation
t: 999978-element Array{Float64,1}:
 0.0                  
 4.997501249375313e-10
 5.4972513743128435e-9
 3.28990927256137e-8  
 9.055577676821075e-8 
 1.7309485648570045e-7
 2.793754678038464e-7 
 4.1495260542675094e-7
 5.807908778765186e-7 
 7.812798295243245e-7 
 ⋮                    
 1.8458616477168546   
 1.845863136999449    
 1.8458646262847271   
 1.8458661155726892   
 1.8458676048633353   
 1.8458690941566653   
 1.8458705834526792   
 1.845872072751377    
 1.8458735620527589   
u: 999978-element Array{Array{Float64,1},1}:
 [0.0, 2.0]          
 [-0.000998751, 2.0] 
 [-0.0109043, 2.0]   
 [-0.0626554, 2.0]   
 [-0.158595, 2.0]    
 [-0.270036, 2.0]    
 [-0.37832, 2.0]     
 [-0.474679, 2.0]    
 [-0.54993, 2.0]     
 [-0.602693, 2.0]    
 ⋮                   
 [-0.777547, 1.83159]
 [-0.777548, 1.83159]
 [-0.777549, 1.83159]
 [-0.77755, 1.83159] 
 [-0.777551, 1.83159]
 [-0.777552, 1.83158]
 [-0.777553, 1.83158]
 [-0.777553, 1.83158]
 [-0.777554, 1.83158]

The retcode shows that the maximum number of iterations was reached. Other possible failures are that the solver reports it was unstable (the solution exploded to infinity) or that dt became too small. If any of these happen, the first thing to do is to check that your model is correct. It could very well be that you made an error that causes the model to be unstable!
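To see how stiffness alone, with a perfectly correct model, can exhaust the iteration budget, here is a toy sketch. The integrator toy_adaptive is hypothetical and much simpler than Tsit5(): it pairs explicit Euler with Heun's method to estimate the local error, and counts accepted and rejected steps against a maxiters cap. On u' = -λ(u - cos t), an explicit method of this kind is only stable for steps up to about 2/λ, so with λ = 1e6 the step size stays tiny no matter how smooth the solution is, while the same code with λ = 1 finishes easily.

```julia
# Toy adaptive integrator (hypothetical, for illustration only): explicit Euler
# paired with Heun's method for a local error estimate, with an iteration cap.
function toy_adaptive(f, u0, tspan; reltol = 1e-3, abstol = 1e-6, maxiters = 10_000)
    t, tend = tspan
    u = u0
    dt = 1e-6                                   # small initial step
    iters = 0
    while t < tend
        iters += 1                              # counts accepted and rejected steps
        iters > maxiters && return (retcode = :MaxIters, t = t, iters = iters)
        dt = min(dt, tend - t)
        k1 = f(u, t)
        ueuler = u + dt * k1                    # 1st-order step
        k2 = f(ueuler, t + dt)
        uheun = u + dt / 2 * (k1 + k2)          # 2nd-order step
        err = abs(uheun - ueuler) / (abstol + reltol * max(abs(u), abs(uheun)))
        if isfinite(err) && err <= 1            # accept: advance time and state
            t += dt
            u = uheun
        end
        # step-size update for a 1st-order error estimate, limited to 0.2x..5x
        fac = isfinite(err) ? clamp(0.9 / sqrt(max(err, 1e-10)), 0.2, 5.0) : 0.2
        dt *= fac
    end
    return (retcode = :Success, t = t, iters = iters)
end

# Scalar test problem u' = -λ(u - cos t): stiff when λ is large
stiff    = toy_adaptive((u, t) -> -1e6 * (u - cos(t)), 0.0, (0.0, 10.0))
nonstiff = toy_adaptive((u, t) -> -1.0 * (u - cos(t)), 0.0, (0.0, 10.0))
println(stiff.retcode, " ", nonstiff.retcode)   # MaxIters Success
```

The stiff run stops far short of t = 10 even though the exact solution is just a smooth curve close to cos(t) after a brief initial transient: the step size is limited by stability, not by accuracy.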

If the model is correct, then stiffness could be the reason. We can thus hint to the solver to use an appropriate method:

sol = solve(prob,alg_hints = [:stiff])
retcode: Success
Interpolation: specialized 3rd order "free" stiffness-aware interpolation
t: 695-element Array{Float64,1}:
 0.0                  
 4.997501249375313e-10
 5.454138614593668e-9 
 1.8954284827811007e-8
 4.1496551232327575e-8
 7.308066628216586e-8 
 1.1714615060776353e-7
 1.7481240480546338e-7
 2.4862277925930763e-7
 3.4025374895995275e-7
 ⋮                    
 5.7409760021041745   
 5.801110722137093    
 5.8746506588671075   
 5.955930645265512    
 6.042472092689859    
 6.129115709541026    
 6.215759326392192    
 6.287868297594483    
 6.3                  
u: 695-element Array{Array{Float64,1},1}:
 [0.0, 2.0]          
 [-0.000998751, 2.0] 
 [-0.0108195, 2.0]   
 [-0.0368509, 2.0]   
 [-0.0780351, 2.0]   
 [-0.131248, 2.0]    
 [-0.19755, 2.0]     
 [-0.272074, 2.0]    
 [-0.350452, 2.0]    
 [-0.426453, 2.0]    
 ⋮                   
 [0.703333, -1.93784]
 [0.731566, -1.89471]
 [0.771692, -1.83948]
 [0.825655, -1.77465]
 [0.899292, -1.70015]
 [0.999836, -1.61812]
 [1.14931, -1.5255]  
 [1.35191, -1.43593] 
 [1.39928, -1.41925]
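Why can a stiff solver cover the whole interval with only 695 saved time points? A minimal illustration, using the standard linear test equation u' = λu rather than the Van der Pol model: explicit Euler multiplies the solution by (1 + dt*λ) each step and is only stable for dt ≤ 2/|λ|, while implicit Euler divides by (1 - dt*λ) and is stable for any dt > 0 when λ < 0. The methods selected by alg_hints = [:stiff] are more sophisticated implicit or Rosenbrock methods, but they rely on the same kind of stability. The helper euler_steps is ours.

```julia
# Repeatedly apply explicit or implicit Euler to u' = λu, starting from u = 1.
function euler_steps(λ, dt, n; implicit::Bool)
    u = 1.0
    for _ in 1:n
        # explicit: u_{n+1} = (1 + dt*λ) * u_n    stable only if |1 + dt*λ| <= 1
        # implicit: u_{n+1} = u_n / (1 - dt*λ)    stable for every dt > 0 when λ < 0
        u = implicit ? u / (1 - dt * λ) : (1 + dt * λ) * u
    end
    return u
end

λ  = -1e6     # fast decay rate, playing the role of the large μ above
dt = 1e-3     # far above the explicit stability limit 2/|λ| = 2e-6
println("explicit Euler: ", euler_steps(λ, dt, 10; implicit = false))  # grows like 999^10
println("implicit Euler: ", euler_steps(λ, dt, 10; implicit = true))   # decays like 1001^-10
```

The exact solution decays to essentially zero, which the implicit method reproduces with large steps while the explicit one explodes.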

Or we can use the default algorithm. By default, DifferentialEquations.jl uses algorithms like AutoTsit5(Rodas5()), which automatically detect stiffness and switch to an appropriate stiff method once stiffness is detected.

sol = solve(prob)
retcode: Success
Interpolation: Automatic order switching interpolation
t: 1927-element Array{Float64,1}:
 0.0                  
 4.997501249375313e-10
 5.4972513743128435e-9
 3.28990927256137e-8  
 9.055577676821075e-8 
 1.7309485648570045e-7
 2.793754678038464e-7 
 4.1495260542675094e-7
 5.807908778765186e-7 
 7.812798295243245e-7 
 ⋮                    
 6.204647119899009    
 6.219555079521211    
 6.233840699473001    
 6.247503397359622    
 6.260546169082511    
 6.272975181001707    
 6.284799378478759    
 6.296030113796843    
 6.3                  
u: 1927-element Array{Array{Float64,1},1}:
 [0.0, 2.0]         
 [-0.000998751, 2.0]
 [-0.0109043, 2.0]  
 [-0.0626554, 2.0]  
 [-0.158595, 2.0]   
 [-0.270036, 2.0]   
 [-0.37832, 2.0]    
 [-0.474679, 2.0]   
 [-0.54993, 2.0]    
 [-0.602693, 2.0]   
 ⋮                  
 [1.11731, -1.54298]
 [1.14817, -1.5261] 
 [1.1805, -1.50946] 
 [1.21435, -1.49311]
 [1.24979, -1.47704]
 [1.28689, -1.46128]
 [1.3257, -1.44583] 
 [1.36632, -1.43072]
 [1.38188, -1.42526]
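A more quantitative check than "the parameter is large" is to look at the eigenvalues of the Jacobian of the right-hand side. A common heuristic is that eigenvalues with negative real parts whose magnitudes differ by many orders of magnitude signal widely separated time scales, i.e. stiffness. As a sketch (the function name vdp_jacobian is ours, and it assumes the u = [y, x] ordering), at the initial condition the Jacobian is [-3μ -μ; 1 0], whose eigenvalues are approximately -3e6 and -1/3 for μ = 1e6.

```julia
using LinearAlgebra

# Analytic Jacobian of du[1] = μ*((1-x^2)*y - x), du[2] = y, with u = [y, x]
function vdp_jacobian(u, μ)
    y, x = u
    return [μ*(1 - x^2)   μ*(-2x*y - 1);
            1.0           0.0]
end

J  = vdp_jacobian([0.0, 2.0], 1e6)
λs = eigvals(J)
ratio = maximum(abs, λs) / minimum(abs, λs)   # stiffness ratio, about 9e6 here
println(λs)
println(ratio)
```

A ratio near 1e7 means the fastest mode is millions of times faster than the slowest one, which is what forces explicit methods into tiny steps.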

Another way to understand stiffness is to look at the solution.

using Plots; gr()
sol = solve(prob,alg_hints = [:stiff],reltol=1e-6)
plot(sol,denseplot=false)

Let's zoom in on the y-axis to see what's going on:

plot(sol,ylims = (-10.0,10.0))

Notice the extreme vertical shifts. These are places where the derivative term is very large: the solution changes over a much shorter time scale than elsewhere, and this mix of fast and slow time scales is indicative of stiffness. This is an extreme example chosen to highlight the behavior, but the same idea carries over to your own problem. When in doubt, simply time both a stiff solver and a non-stiff solver and see which is more efficient.
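For this model the fast and slow phases can be located by hand, following the standard relaxation-oscillation picture. When μ is large, du[1] = μ((1-x²)y - x) is huge unless (1-x²)y ≈ x, so between jumps the solution hugs the curve y = x/(1-x²). That curve is attracting only where |x| > 1 (there the derivative of du[1] with respect to y, μ(1-x²), is negative), and it blows up as |x| approaches 1, which is where y spikes and the solution jumps. The sketch below checks this against the last saved point of the stiff solve, [1.39928, -1.41925], read in the assumed u = [y, x] ordering; slow_manifold is our own name.

```julia
# Slow manifold of the large-μ Van der Pol model (u = [y, x]):
# (1 - x^2)*y - x = 0  =>  y = x / (1 - x^2), attracting only where |x| > 1.
slow_manifold(x) = x / (1 - x^2)

# Last saved point of the stiff solve above, copied from the printed output
y, x = 1.39928, -1.41925
println(abs(y - slow_manifold(x)))   # tiny: the solution sits on the manifold

# Approaching the fold at x = 1 the manifold value explodes, matching the spikes in y
for xf in (1.5, 1.1, 1.01, 1.001)
    println(xf, " => ", slow_manifold(xf))
end
```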

To compare stiff and non-stiff solvers this way, we will use BenchmarkTools, a package that lets us time code blocks relatively reliably. As a test problem, take the Lorenz system:

function lorenz!(du,u,p,t)
    σ,ρ,β = p
    du[1] = σ*(u[2]-u[1])
    du[2] = u[1]*(ρ-u[3]) - u[2]
    du[3] = u[1]*u[2] - β*u[3]
end
u0 = [1.0,0.0,0.0]
p = (10,28,8/3)
tspan = (0.0,100.0)
prob = ODEProblem(lorenz!,u0,tspan,p)
ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true
timespan: (0.0, 100.0)
u0: [1.0, 0.0, 0.0]

Now let's use the @btime macro from BenchmarkTools to compare non-stiff and stiff solvers on this problem.

using BenchmarkTools
@btime solve(prob);
995.395 μs (12678 allocations: 1.37 MiB)
@btime solve(prob,alg_hints = [:stiff]);
10.343 ms (38999 allocations: 2.23 MiB)

In this particular case, the default (non-stiff) choice reaches the solution roughly ten times faster than the stiff one (about 1.0 ms versus 10.3 ms), which suggests this problem is not stiff.
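@btime runs the expression many times and reports the minimum time, so that compilation and noise from other processes do not dominate. A rough stand-in using only Base Julia is sketched below; min_time and lorenz_euler are hypothetical helpers (the latter just takes fixed forward-Euler steps of the Lorenz system above), and BenchmarkTools handles many pitfalls this sketch ignores.

```julia
# Rough stand-in for @btime: run once so compilation is not timed, then
# report the minimum wall-clock time in seconds over several runs.
function min_time(f; samples = 20)
    f()
    return minimum([@elapsed(f()) for _ in 1:samples])
end

# Hypothetical workload: n forward-Euler steps of the Lorenz system
function lorenz_euler(n; dt = 1e-4, σ = 10.0, ρ = 28.0, β = 8/3)
    x, y, z = 1.0, 0.0, 0.0
    for _ in 1:n
        dx, dy, dz = σ*(y - x), x*(ρ - z) - y, x*y - β*z
        x, y, z = x + dt*dx, y + dt*dy, z + dt*dz
    end
    return (x, y, z)
end

t_short = min_time(() -> lorenz_euler(10_000))
t_long  = min_time(() -> lorenz_euler(100_000))
println("10k steps: $t_short s, 100k steps: $t_long s")
```

Taking the minimum rather than the mean is the same design choice @btime makes: the fastest run is the one least disturbed by outside noise.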

The Recommended Methods

When picking a method, the general rules are as follows:

  • Higher order is more efficient at lower tolerances, lower order is more efficient at higher tolerances

  • Adaptivity is essential in most real-world scenarios

  • Runge-Kutta methods do well with non-stiff equations, Rosenbrock methods do well with small stiff equations, BDF methods do well with large stiff equations
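The first rule can be seen in a small worked example: a method of order p has a global error roughly proportional to h^p, so halving the step size divides the error by about 2^p. The sketch below uses hand-written fixed-step methods (not the adaptive solvers of DifferentialEquations.jl) to compare explicit Euler (order 1) with the classic fourth-order Runge-Kutta method on u' = u, whose exact value at t = 1 is e.

```julia
# Fixed-step explicit Euler (order 1) for u' = f(u, t) on [0, T] with n steps
function euler(f, u0, T, n)
    h, u, t = T / n, u0, 0.0
    for _ in 1:n
        u += h * f(u, t)
        t += h
    end
    return u
end

# Fixed-step classic Runge-Kutta (order 4)
function rk4(f, u0, T, n)
    h, u, t = T / n, u0, 0.0
    for _ in 1:n
        k1 = f(u, t)
        k2 = f(u + h/2 * k1, t + h/2)
        k3 = f(u + h/2 * k2, t + h/2)
        k4 = f(u + h * k3, t + h)
        u += h/6 * (k1 + 2k2 + 2k3 + k4)
        t += h
    end
    return u
end

f(u, t) = u   # u' = u, exact solution u(1) = e
for n in (10, 20, 40)
    e1 = abs(euler(f, 1.0, 1.0, n) - exp(1))
    e4 = abs(rk4(f, 1.0, 1.0, n) - exp(1))
    println("n = $n: Euler error = $e1, RK4 error = $e4")
end
```

Each doubling of n cuts the Euler error by about 2 but the RK4 error by about 16, so tightening the tolerance costs a high-order method far fewer extra steps. That is why Vern7() or Vern9() pay off at low tolerances, while cheaper low-order methods such as BS3() can win when only a few digits are needed.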

While there are always exceptions, these rules are good guiding principles. Based on them, a simple way to choose methods is:

  • The default for non-stiff problems is Tsit5(), a Runge-Kutta method of order 5

  • If you use low tolerances (1e-8), try Vern7() or Vern9()

  • If you use high tolerances, try BS3()

  • If the problem is stiff, try Rosenbrock23(), Rodas5(), or CVODE_BDF()

  • If you don't know whether the problem is stiff, use AutoTsit5(Rosenbrock23()) or AutoVern9(Rodas5()).

(This is a simplified version of the default algorithm chooser.)
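The rules of thumb above can be encoded as a small lookup, sketched here. suggest_solver is a hypothetical helper that returns solver names as strings; it is not DifferentialEquations.jl's actual default chooser. The cutoff reltol ≥ 1e-2 for "high tolerances", and the use of AutoVern9(Rodas5()) only when stiffness is unknown and tolerances are low, are our own assumptions.

```julia
# Hypothetical helper encoding the rules of thumb above.
# stiff = true / false, or nothing when stiffness is unknown.
function suggest_solver(; reltol = 1e-3, stiff::Union{Bool,Nothing} = false)
    if stiff === nothing                                   # stiffness unknown
        return reltol <= 1e-8 ? "AutoVern9(Rodas5())" : "AutoTsit5(Rosenbrock23())"
    end
    stiff && return "Rosenbrock23(), Rodas5(), or CVODE_BDF()"
    reltol <= 1e-8 && return "Vern7() or Vern9()"          # low tolerances
    reltol >= 1e-2 && return "BS3()"                       # high tolerances (assumed cutoff)
    return "Tsit5()"                                       # non-stiff default
end

println(suggest_solver())                            # Tsit5()
println(suggest_solver(reltol = 1e-10))              # Vern7() or Vern9()
println(suggest_solver(stiff = nothing))             # AutoTsit5(Rosenbrock23())
```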

Comparison to other Software

If you are familiar with MATLAB, SciPy, or R's deSolve, here's a quick translation guide to help you transfer your knowledge.

  • ode23 -> BS3()

  • ode45/dopri5 -> DP5(), though in most cases Tsit5() is more efficient

  • ode23s -> Rosenbrock23(), though in most cases Rodas4() is more efficient

  • ode113 -> VCABM(), though in many cases Vern7() is more efficient

  • dop853 -> DP8(), though in most cases Vern7() is more efficient

  • ode15s/vode -> QNDF(), though in many cases CVODE_BDF(), Rodas4() or radau() are more efficient

  • ode23t -> Trapezoid() for efficiency and GenericTrapezoid() for robustness

  • ode23tb -> TRBDF2()

  • lsoda -> lsoda() (requires ]add LSODA; using LSODA)

  • ode15i -> IDA(), though in many cases Rodas4() can handle the DAE and is significantly more efficient

Appendix

This tutorial is part of the DiffEqTutorials.jl repository, found at: https://github.com/JuliaDiffEq/DiffEqTutorials.jl

To run this tutorial locally, use the following commands:

using DiffEqTutorials
DiffEqTutorials.weave_file("introduction","02-choosing_algs.jmd")

Computer Information:

Julia Version 1.1.1
Commit 55e36cc308 (2019-05-16 04:10 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-3770 CPU @ 3.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, ivybridge)

Package Information:

Status `~/.julia/environments/v1.1/Project.toml`
[7e558dbc-694d-5a72-987c-6f4ebed21442] ArbNumerics 0.5.4
[6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf] BenchmarkTools 0.4.2
[be33ccc6-a3ff-5ff2-a52e-74243cff1e17] CUDAnative 2.2.0
[3a865a2d-5b23-5a0f-bc46-62713ec82fae] CuArrays 1.0.2
[55939f99-70c6-5e9b-8bb0-5071ed7d61fd] DecFP 0.4.8
[abce61dc-4473-55a0-ba07-351d65e31d42] Decimals 0.4.0
[ebbdde9d-f333-5424-9be2-dbf1e9acfb5e] DiffEqBayes 1.1.0
[eb300fae-53e8-50a0-950c-e21f52c2b7e0] DiffEqBiological 3.8.2
[459566f4-90b8-5000-8ac3-15dfb0a30def] DiffEqCallbacks 2.5.2
[f3b72e0c-5b89-59e1-b016-84e28bfd966d] DiffEqDevTools 2.9.0
[1130ab10-4a5a-5621-a13d-e4788d82bd4c] DiffEqParamEstim 1.6.0
[055956cb-9e8b-5191-98cc-73ae4a59e68a] DiffEqPhysics 3.1.0
[6d1b261a-3be8-11e9-3f2f-0b112a9a8436] DiffEqTutorials 0.1.0
[0c46a032-eb83-5123-abaf-570d42b7fbaa] DifferentialEquations 6.4.0
[31c24e10-a181-5473-b8eb-7969acd0382f] Distributions 0.20.0
[497a8b3b-efae-58df-a0af-a86822472b78] DoubleFloats 0.9.1
[f6369f11-7733-5829-9624-2563aa707210] ForwardDiff 0.10.3
[c91e804a-d5a3-530f-b6f0-dfbca275c004] Gadfly 1.0.1
[7073ff75-c697-5162-941a-fcdaad2a7d2a] IJulia 1.18.1
[4138dd39-2aa7-5051-a626-17a0bb65d9c8] JLD 0.9.1
[23fbe1c1-3f47-55db-b15f-69d7ec21a316] Latexify 0.8.2
[eff96d63-e80a-5855-80a2-b1b0885c5ab7] Measurements 2.0.0
[961ee093-0014-501f-94e3-6117800e7a78] ModelingToolkit 0.2.0
[76087f3c-5699-56af-9a33-bf431cd00edd] NLopt 0.5.1
[2774e3e8-f4cf-5e23-947b-6d7e65073b56] NLsolve 4.0.0
[429524aa-4258-5aef-a3af-852621145aeb] Optim 0.18.1
[1dea7af3-3e70-54e6-95c3-0bf5283fa5ed] OrdinaryDiffEq 5.8.1
[65888b18-ceab-5e60-b2b9-181511a3b968] ParameterizedFunctions 4.1.1
[91a5bcdd-55d7-5caf-9e0b-520d859cae80] Plots 0.25.1
[d330b81b-6aea-500a-939a-2ce795aea3ee] PyPlot 2.8.1
[731186ca-8d62-57ce-b412-fbd966d074cd] RecursiveArrayTools 0.20.0
[90137ffa-7385-5640-81b9-e52037218182] StaticArrays 0.11.0
[f3b207a7-027a-5e70-b257-86293d7955fd] StatsPlots 0.11.0
[c3572dad-4567-51f8-b174-8c6c989267f4] Sundials 3.6.1
[1986cc42-f94f-5a68-af5c-568840ba703d] Unitful 0.15.0
[44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9] Weave 0.9.0
[b77e0a4c-d291-57a0-90e8-8db25a27a240] InteractiveUtils
[37e2e46d-f89d-539d-b4ee-838fcccc9c8e] LinearAlgebra
[44cfe95a-1eb2-52ea-b672-e2afdf69b78f] Pkg