[1]:
import logging
import optimistix as optx
from atmodeller import (
ChemicalSpecies,
EquilibriumModel,
Planet,
SolverParameters,
SpeciesNetwork,
debug_logger,
earth_oceans_to_hydrogen_mass,
)
from atmodeller.solubility import get_solubility_models
# Atmodeller's debug logger at INFO level; switch to logging.DEBUG for more output
logger = debug_logger()
logger.setLevel(level=logging.INFO)
Atmodeller initialized with double precision (float64)
Iteration
This notebook is available at notebooks/iteration.ipynb and is easiest to obtain by downloading the source code.
Simple changing constraints
For models that require iterative updates where constraints evolve gradually—such as during time integration or other forms of sequential solving—Atmodeller can be used as follows. The order of the arguments and the size of the arrays must match those used to initialise the model, although the values themselves can vary between iterations.
A simple Python looping structure can be used to perform these updates. This approach is often the most intuitive way to couple Atmodeller with external codes or models that provide new constraints at each step. While not always the most performant strategy, especially for large parameter sweeps or high-resolution simulations, it offers a clear and flexible mechanism for driving iterative update processes.
[2]:
# Initialise Atmodeller once, outside the iterative update (e.g., a time loop). The objects
# created here are reused by every solve that follows.
solubility_models = get_solubility_models()
h2o_solubility = solubility_models["H2O_peridotite_sossi23"]
H2_g = ChemicalSpecies.create_gas("H2")
H2O_g = ChemicalSpecies.create_gas("H2O", solubility=h2o_solubility)
O2_g = ChemicalSpecies.create_gas("O2")
species = SpeciesNetwork((H2_g, H2O_g, O2_g))
planet = Planet()
model = EquilibriumModel(species)

# Optional: choose the solver and its parameters. Inside an iterative update loop you usually
# want failures reported (throw=True) so that you can handle them; otherwise a failed solve
# propagates through the loop and later iterations produce meaningless results.
solver = optx.Newton
solver_parameters = SolverParameters(solver=solver, throw=True)
# Initial solve. This first call compiles the solver, so it takes longer than the updates in
# the loop below, which reuse the compiled code.
oceans = 1
o_kg = 6.25774e20  # oxygen mass constraint (kg)
h_kg = earth_oceans_to_hydrogen_mass(oceans)
mass_constraints = dict(H=h_kg, O=o_kg)
model.solve(
    state=planet,
    mass_constraints=mass_constraints,
    solver_parameters=solver_parameters,
    solver="basic",
)

# Keep this solution: when constraints change only slightly between iterations, it makes a
# good initial guess for the next solve.
output = model.output
# Iterative loop parameters: the loop runs for ii = start_index, ..., end_index - 1
start_index = 1
end_index = 4

# Using Atmodeller in the iterative update loop.
# This is the update loop, where something changes and you want to re-solve using Atmodeller.
for ii in range(start_index, end_index):
    # Let's say we update the mass constraints. The number of constraints and the value type
    # (here, floats) must remain the same as the initialised model, but you can update their
    # values.
    logger.info("Iteration %d", ii)
    logger.info("Your code does something here to compute new masses")

    # For example, scale the H and O masses down by factors that depend on the iteration
    # number, mimicking atmospheric escape or other loss processes. Each factor is the fraction
    # of the initial inventory that remains (e.g. 0.9, 0.8, 0.7 for H when ii = 1, 2, 3).
    h_fraction_remaining = 1 - 0.1 * ii
    o_fraction_remaining = 1 - 0.05 * ii

    # Let's also change the melt fraction. We must create a new Planet with the desired
    # properties.
    planet = Planet(mantle_melt_fraction=1 - 0.1 * ii)
    mass_constraints = {"H": h_kg * h_fraction_remaining, "O": o_kg * o_fraction_remaining}

    # These solves are fast because they reuse the JAX-compiled code after compiling once.
    # Passing the previous solution as initial_log_number_moles helps both convergence and
    # speed. The solver mode and parameters match the initial solve.
    logger.info("Atmodeller solve using JIT compiled code")
    model.solve(
        state=planet,  # Pass in the new planet
        mass_constraints=mass_constraints,  # Pass in the new constraints
        solver_parameters=solver_parameters,  # Keep this the same
        solver="basic",  # Keep this the same as the initial solve
        initial_log_number_moles=output.log_number_moles,  # Pass in the previous solution
    )

    # Update output with the new solution to use as the initial guess for the next iteration
    output = model.output

    # Quick look at the solution
    solution = output.quick_look()
    logger.info("solution = %s", solution)

    # If required, get the complete output as a dictionary to feed back into other
    # calculations during the time loop:
    # solution_asdict = output.asdict()
[19:35:31 - atmodeller.classes - INFO ] - species_network = ('H2_g: IdealGas, NoSolubility', 'H2O_g: IdealGas, SolubilityPowerLaw', 'O2_g: IdealGas, NoSolubility')
[19:35:31 - atmodeller.classes - INFO ] - Thermodynamic data requires temperatures between 200 K and 6000 K
[19:35:31 - atmodeller.classes - INFO ] - reactions = {0: '2.0 H2O_g = 2.0 H2_g + 1.0 O2_g'}
[19:35:34 - atmodeller.classes - INFO ] - Solve (basic) complete: 1 (100.00%) successful model(s)
[19:35:35 - atmodeller.classes - INFO ] - Multistart summary: 1 (100.00%) models(s) required 1 attempt(s)
[19:35:35 - atmodeller.classes - INFO ] - Solver steps (max) = 35
[19:35:35 - atmodeller - INFO ] - Iteration 1
[19:35:35 - atmodeller - INFO ] - Your code does something here to compute new masses
[19:35:35 - atmodeller - INFO ] - Atmodeller solve using JIT compiled code
[19:35:35 - atmodeller.classes - INFO ] - Solve (basic) complete: 1 (100.00%) successful model(s)
[19:35:35 - atmodeller.classes - INFO ] - Multistart summary: 1 (100.00%) models(s) required 1 attempt(s)
[19:35:35 - atmodeller.classes - INFO ] - Solver steps (max) = 4
[19:35:36 - atmodeller - INFO ] - solution = {'H2_g': array(12.994488603473695), 'H2_g_activity': array(12.994488603473654), 'H2O_g': array(0.073890412085436), 'H2O_g_activity': array(0.073890412085436), 'O2_g': array(2.684268156152379e-12), 'O2_g_activity': array(2.684268156152368e-12)}
[19:35:36 - atmodeller - INFO ] - Iteration 2
[19:35:36 - atmodeller - INFO ] - Your code does something here to compute new masses
[19:35:36 - atmodeller - INFO ] - Atmodeller solve using JIT compiled code
[19:35:36 - atmodeller.classes - INFO ] - Solve (basic) complete: 1 (100.00%) successful model(s)
[19:35:36 - atmodeller.classes - INFO ] - Multistart summary: 1 (100.00%) models(s) required 1 attempt(s)
[19:35:36 - atmodeller.classes - INFO ] - Solver steps (max) = 4
[19:35:36 - atmodeller - INFO ] - solution = {'H2_g': array(10.833246671405849), 'H2_g_activity': array(10.833246671405849), 'H2O_g': array(0.083784798332519), 'H2O_g_activity': array(0.083784798332519), 'O2_g': array(4.965709399122488e-12), 'O2_g_activity': array(4.965709399122486e-12)}
[19:35:36 - atmodeller - INFO ] - Iteration 3
[19:35:36 - atmodeller - INFO ] - Your code does something here to compute new masses
[19:35:36 - atmodeller - INFO ] - Atmodeller solve using JIT compiled code
[19:35:36 - atmodeller.classes - INFO ] - Solve (basic) complete: 1 (100.00%) successful model(s)
[19:35:36 - atmodeller.classes - INFO ] - Multistart summary: 1 (100.00%) models(s) required 1 attempt(s)
[19:35:36 - atmodeller.classes - INFO ] - Solver steps (max) = 4
[19:35:36 - atmodeller - INFO ] - solution = {'H2_g': array(8.688663692609659), 'H2_g_activity': array(8.688663692609664), 'H2O_g': array(0.097384737802907), 'H2O_g_activity': array(0.097384737802907), 'O2_g': array(1.042903033608945e-11), 'O2_g_activity': array(1.042903033608945e-11)}
Fully JAX compatible approach
You may be thinking: “Atmodeller is a JAX-compatible code, so why would I embed JAX-compiled functions inside an inefficient Python for loop?” And you’d be absolutely right. While simple loops offer clarity, they also limit performance by returning control to the Python interpreter at every step. There is a far more efficient way to integrate an Atmodeller solver into a JAX workflow—one that keeps the entire update sequence within JAX’s functional, compiled execution model. By restructuring the iterative procedure into a form suitable for `jax.lax.scan` or similar control-flow primitives, the full computation can be jitted end-to-end, avoiding Python overhead and allowing XLA to optimise the entire sequence as a single compiled computation. This approach preserves the flexibility of iterative updates while achieving JAX-level performance and full accelerator compatibility.
[3]:
# TODO: add an example of the fully JAX-compatible approach (e.g., iterating solves with jax.lax.scan)