Package pychnosz

pyCHNOSZ: Thermodynamic Calculations and Diagrams for Geochemistry

An integrated set of tools for thermodynamic calculations in aqueous geochemistry and geobiochemistry. Functions are provided for writing balanced reactions to form species from user-selected basis species and for calculating the standard molal properties of species and reactions, including the standard Gibbs energy and equilibrium constant.

A Python port of the CHNOSZ package for R. The original CHNOSZ package was developed by Dr. Jeffrey Dick.

Sub-modules

pychnosz.biomolecules

Biomolecule thermodynamics package for CHNOSZ …

pychnosz.core

Core thermodynamic calculation functions for CHNOSZ.

pychnosz.data

Data management and access for CHNOSZ thermodynamic database.

pychnosz.fortran

CHNOSZ Fortran interface package …

pychnosz.geochemistry

Geochemistry package for CHNOSZ …

pychnosz.models

Equation of state models and water property models for CHNOSZ.

pychnosz.utils

Utility functions for CHNOSZ calculations.

Functions

def Berman(name: str,
T: float | List[float] = 298.15,
P: float | List[float] = 1,
check_G: bool = False,
calc_transition: bool = True,
calc_disorder: bool = True) -> pandas.core.frame.DataFrame
Expand source code
def Berman(name: str, T: Union[float, List[float]] = 298.15, P: Union[float, List[float]] = 1, 
           check_G: bool = False, calc_transition: bool = True, calc_disorder: bool = True) -> pd.DataFrame:
    """
    Calculate thermodynamic properties of minerals using Berman equations.

    Standard properties are computed from the parameters in ``thermo().Berman``
    with the equations of Berman (1988). Polymorphic (lambda) transition and
    disorder contributions can optionally be included.

    Parameters
    ----------
    name : str
        Name of the mineral, matched against the ``name`` column of the Berman
        table. If the name is duplicated, the first entry is used.
    T : float or array-like, optional
        Temperature in Kelvin (default: 298.15)
    P : float or array-like, optional
        Pressure in bar (default: 1)
    check_G : bool, optional
        Warn if the tabulated GfPrTr differs from HfPrTr - Tr*(SPrTr - S_elements)
        by 1000 J/mol or more (default: False)
    calc_transition : bool, optional
        Calculate polymorphic transition contributions (default: True)
    calc_disorder : bool, optional
        Calculate disorder contributions (default: True)

    T and P are recycled (repeated cyclically) to the length of the longer of
    the two, like R's ``rep(length.out = ...)``.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns T, P, G, H, S, Cp, V.
        G includes Tr times the entropy of the elements (Benson-Helgeson
        convention). V is converted from J/bar to cm^3/mol.
        Missing k4-k6 and v1-v4 are taken as 0. A missing Tlambda/Tref or
        Tmin/Tmax skips the transition or disorder contribution. Other missing
        parameters propagate as NaN into the affected results, as NA does in
        R CHNOSZ.

    Raises
    ------
    RuntimeError
        If the Berman data are not loaded, or if the internal G = H - T*S
        consistency check fails.
    ValueError
        If ``name`` is not in the Berman table, or if T or P is empty.
    """

    # Reference temperature (K) and pressure (bar)
    Pr = 1
    Tr = 298.15

    # Make T and P the same length by recycling them (like R's rep(length.out = ...)).
    # np.asarray accepts Python/numpy scalars, lists, tuples, numpy arrays and
    # pandas Series alike. List repetition (`T * k`) would instead multiply
    # array-like inputs element-wise.
    T = np.atleast_1d(np.asarray(T, dtype=float)).ravel()
    P = np.atleast_1d(np.asarray(P, dtype=float)).ravel()
    if T.size == 0 or P.size == 0:
        raise ValueError("T and P must each contain at least one value")
    ncond = max(T.size, P.size)
    T = np.resize(T, ncond)
    P = np.resize(P, ncond)

    # Get parameters in the Berman equations
    # Start with thermodynamic parameters provided with CHNOSZ
    thermo_sys = thermo()
    if thermo_sys.Berman is None:
        raise RuntimeError("Berman data not loaded. Please run pychnosz.reset() first.")

    # TODO: Handle user-supplied data file (thermo()$opt$Berman)
    # For now, just use the default data

    # Which row has data for this mineral? Only the first (i.e. most recent)
    # entry is used if the name is duplicated.
    dat = thermo_sys.Berman
    matching_rows = dat[dat['name'] == name]
    if len(matching_rows) == 0:
        raise ValueError(f"Data for {name} not available in Berman database")
    dat_mineral = matching_rows.iloc[0]

    def param(col, default=np.nan):
        """Return parameter `col` of this mineral, or `default` if it is NA."""
        value = dat_mineral[col]
        return default if pd.isna(value) else value

    # Properties at the reference state
    GfPrTr = dat_mineral['GfPrTr']
    HfPrTr = dat_mineral['HfPrTr']
    SPrTr = dat_mineral['SPrTr']
    VPrTr = dat_mineral['VPrTr']

    # Heat capacity coefficients (k4, k5, k6 are optional and default to 0)
    k0 = dat_mineral['k0']
    k1 = dat_mineral['k1']
    k2 = dat_mineral['k2']
    k3 = dat_mineral['k3']
    k4 = param('k4', 0)
    k5 = param('k5', 0)
    k6 = param('k6', 0)

    # Volume coefficients: missing values default to 0, and the multipliers
    # used in the data table (10^5, 10^5, 10^5, 10^8) are removed
    v1 = param('v1', 0) / 10**5
    v2 = param('v2', 0) / 10**5
    v3 = param('v3', 0) / 10**5
    v4 = param('v4', 0) / 10**8

    # Transition and disorder parameters. Missing values stay NaN so they
    # propagate through the arithmetic like NA in R CHNOSZ, instead of
    # raising TypeError as None would.
    Tlambda = param('Tlambda')
    Tref = param('Tref')
    dTdP = param('dTdP')
    l1 = param('l1')
    l2 = param('l2')

    Tmin = param('Tmin')
    Tmax = param('Tmax')
    d0 = param('d0')
    d1 = param('d1')
    d2 = param('d2')
    d3 = param('d3')
    d4 = param('d4')
    Vad = param('Vad')

    # Get the entropy of the elements using the chemical formula
    # Get formula from OBIGT and calculate using entropy() function like in R CHNOSZ
    # NOTE(review): this stays 0 when the mineral is not in OBIGT, so G is then
    # not shifted to the Benson-Helgeson convention. Confirm this is intended.
    SPrTr_elements = 0
    if thermo_sys.obigt is not None:
        obigt_match = thermo_sys.obigt[thermo_sys.obigt['name'] == name]
        if len(obigt_match) > 0:
            formula = obigt_match.iloc[0]['formula']
            from ..utils.formula import entropy
            SPrTr_elements = entropy(formula)

    # Check that G in data file follows Benson-Helgeson convention
    if check_G and not pd.isna(GfPrTr):
        GfPrTr_calc = HfPrTr - Tr * (SPrTr - SPrTr_elements)
        Gdiff = GfPrTr_calc - GfPrTr
        if abs(Gdiff) >= 1000:
            warnings.warn(f"{name}: GfPrTr(calc) - GfPrTr(table) is too big! == {round(Gdiff)} J/mol")

    ### Thermodynamic properties ###
    # Calculate Cp and V (Berman, 1988 Eqs. 4 and 5)
    # k4, k5, k6 terms from winTWQ documentation (doi:10.4095/223425)
    Cp = k0 + k1 * T**(-0.5) + k2 * T**(-2) + k3 * T**(-3) + k4 * T**(-1) + k5 * T + k6 * T**2

    P_Pr = P - Pr
    T_Tr = T - Tr
    V = VPrTr * (1 + v1 * T_Tr + v2 * T_Tr**2 + v3 * P_Pr + v4 * P_Pr**2)

    # Calculate Ha (symbolically integrated using sympy - expressions not simplified)
    intCp = (T*k0 - Tr*k0 + k2/Tr - k2/T + k3/(2*Tr**2) - k3/(2*T**2) + 2.0*k1*T**0.5 - 2.0*k1*Tr**0.5 + 
             k4*np.log(T) - k4*np.log(Tr) + k5*T**2/2 - k5*Tr**2/2 - k6*Tr**3/3 + k6*T**3/3)

    intVminusTdVdT = (-VPrTr + P*(VPrTr + VPrTr*v4 - VPrTr*v3 - Tr*VPrTr*v1 + VPrTr*v2*Tr**2 - VPrTr*v2*T**2) +
                      P**2*(VPrTr*v3/2 - VPrTr*v4) + VPrTr*v3/2 - VPrTr*v4/3 + Tr*VPrTr*v1 + 
                      VPrTr*v2*T**2 - VPrTr*v2*Tr**2 + VPrTr*v4*P**3/3)

    Ha = HfPrTr + intCp + intVminusTdVdT

    # Calculate S (also symbolically integrated)
    intCpoverT = (k0*np.log(T) - k0*np.log(Tr) - k3/(3*T**3) + k3/(3*Tr**3) + k2/(2*Tr**2) - k2/(2*T**2) + 
                  2.0*k1*Tr**(-0.5) - 2.0*k1*T**(-0.5) + k4/Tr - k4/T + T*k5 - Tr*k5 + k6*T**2/2 - k6*Tr**2/2)

    intdVdT = -VPrTr*(v1 + v2*(-2*Tr + 2*T)) + P*VPrTr*(v1 + v2*(-2*Tr + 2*T))

    S = SPrTr + intCpoverT - intdVdT

    # Calculate Ga --> Berman-Brown convention (DG = DH - T*S, no S(element))
    Ga = Ha - T * S

    ### Polymorphic transition properties ###
    if calc_transition and not pd.isna(Tlambda) and not pd.isna(Tref) and np.any(T > Tref):

        # Starting transition contributions are 0
        Cptr = np.zeros(ncond)
        Htr = np.zeros(ncond)
        Str = np.zeros(ncond)

        # Eq. 9: Tlambda at P (all NaN if dTdP is missing)
        Tlambda_P = Tlambda + dTdP * (P - 1)

        # Eq. 8a: Cp at P
        Td = Tlambda - Tlambda_P
        Tprime = T + Td

        # With the condition that Tref < Tprime < Tlambda(1bar); NaN Tprime is excluded
        iTprime = (Tref < Tprime) & (Tprime < Tlambda) & ~np.isnan(Tprime)
        if np.any(iTprime):
            Tprime_valid = Tprime[iTprime]
            Cptr[iTprime] = Tprime_valid * (l1 + l2 * Tprime_valid)**2

        # We got Cp, now calculate the integrations for H and S.
        # The lower integration limit is Tref. Boolean indexing returns copies,
        # so the arrays below can be modified without touching T, Tlambda_P or Td.
        iTtr = T > Tref
        Ttr = T[iTtr]
        Tlambda_P_tr = Tlambda_P[iTtr]
        Td_tr = Td[iTtr]

        # Handle NA values: no upper limit where Tlambda_P is unknown
        Tlambda_P_tr[np.isnan(Tlambda_P_tr)] = np.inf

        # The upper integration limit is Tlambda_P
        above = Ttr >= Tlambda_P_tr
        Ttr[above] = Tlambda_P_tr[above]

        # Derived variables
        tref = Tref - Td_tr
        x1 = l1**2 * Td_tr + 2 * l1 * l2 * Td_tr**2 + l2**2 * Td_tr**3
        x2 = l1**2 + 4 * l1 * l2 * Td_tr + 3 * l2**2 * Td_tr**2
        x3 = 2 * l1 * l2 + 3 * l2**2 * Td_tr
        x4 = l2**2

        # Eqs. 10, 11, 12
        Htr[iTtr] = (x1 * (Ttr - tref) + x2/2 * (Ttr**2 - tref**2) + 
                     x3/3 * (Ttr**3 - tref**3) + x4/4 * (Ttr**4 - tref**4))
        Str[iTtr] = (x1 * (np.log(Ttr) - np.log(tref)) + x2 * (Ttr - tref) + 
                     x3/2 * (Ttr**2 - tref**2) + x4/3 * (Ttr**3 - tref**3))

        Gtr = Htr - T * Str

        # Apply the transition contributions
        Ga = Ga + Gtr
        Ha = Ha + Htr
        S = S + Str
        Cp = Cp + Cptr

    ### Disorder thermodynamic properties ###
    if calc_disorder and not pd.isna(Tmin) and not pd.isna(Tmax) and np.any(T > Tmin):

        # Starting disorder contributions are 0
        Cpds = np.zeros(ncond)
        Hds = np.zeros(ncond)
        Sds = np.zeros(ncond)
        Vds = np.zeros(ncond)

        # The lower integration limit is Tmin and the upper limit is Tmax
        iTds = T > Tmin
        Tds = np.minimum(T[iTds], Tmax)

        # Ber88 Eqs. 15, 16, 17
        Cpds[iTds] = d0 + d1*Tds**(-0.5) + d2*Tds**(-2) + d3*Tds + d4*Tds**2
        Hds[iTds] = (d0*(Tds - Tmin) + d1*(Tds**0.5 - Tmin**0.5)/0.5 +
                     d2*(Tds**(-1) - Tmin**(-1))/(-1) + d3*(Tds**2 - Tmin**2)/2 + d4*(Tds**3 - Tmin**3)/3)
        Sds[iTds] = (d0*(np.log(Tds) - np.log(Tmin)) + d1*(Tds**(-0.5) - Tmin**(-0.5))/(-0.5) +
                     d2*(Tds**(-2) - Tmin**(-2))/(-2) + d3*(Tds - Tmin) + d4*(Tds**2 - Tmin**2)/2)

        # Eq. 18; we can't do this if Vad == 0 (dolomite and gehlenite)
        if not pd.isna(Vad) and Vad != 0:
            Vds = Hds / Vad

        # Include the Vds term with Hds
        Hds = Hds + Vds * (P - Pr)

        # Disordering properties above Tmax (Eq. 20)
        ihigh = T > Tmax
        Hds[ihigh] = Hds[ihigh] - (T[ihigh] - Tmax) * Sds[ihigh]

        Gds = Hds - T * Sds

        # Apply the disorder contributions
        Ga = Ga + Gds
        Ha = Ha + Hds
        S = S + Sds
        V = V + Vds
        Cp = Cp + Cpds

    ### (for testing) Use G = H - TS to check that integrals for H and S are written correctly
    # NaN from missing parameters appears in the same places on both sides, so it
    # is treated as equal (like R's all.equal) instead of failing the check.
    Ga_fromHminusTS = Ha - T * S
    if not np.allclose(Ga_fromHminusTS, Ga, atol=1e-6, equal_nan=True):
        raise RuntimeError(f"{name}: incorrect integrals detected using DG = DH - T*S")

    ### Thermodynamic and unit conventions used in SUPCRT ###
    # Use entropy of the elements in calculation of G --> Benson-Helgeson convention (DG = DH - T*DS)
    G = Ga + Tr * SPrTr_elements
    # The output will just have "G" and "H"
    H = Ha

    # Convert J/bar to cm^3/mol
    V = V * 10

    return pd.DataFrame({
        'T': T,
        'P': P,
        'G': G,
        'H': H,
        'S': S,
        'Cp': Cp,
        'V': V
    })

Calculate thermodynamic properties of minerals using Berman equations.

Parameters

name : str
Name of the mineral
T : float or list, optional
Temperature in Kelvin (default: 298.15)
P : float or list, optional
Pressure in bar (default: 1)
check_G : bool, optional
Check consistency of G in data file (default: False)
calc_transition : bool, optional
Calculate polymorphic transition contributions (default: True)
calc_disorder : bool, optional
Calculate disorder contributions (default: True)

Returns

pd.DataFrame
DataFrame with columns T, P, G, H, S, Cp, V
def G2logK(G, Tc)
Expand source code
def G2logK(G, Tc):
    """
    Convert a standard Gibbs energy to the base-10 log of an equilibrium constant.

    Computes ``logK = G / (-ln(10) * R * T)`` with R = 1.9872 cal/(mol K) and
    T = Tc + 273.15 K. Only arithmetic operators are used, so numpy arrays
    work element-wise as well as scalars.

    Parameters
    ----------
    G : float or numpy.ndarray
        Gibbs energy in cal/mol (to match the units of R)
    Tc : float or numpy.ndarray
        Temperature in degrees Celsius

    Returns
    -------
    float or numpy.ndarray
        log10 of the equilibrium constant
    """
    gas_constant = 1.9872  # cal/(mol K)
    kelvin = 273.15 + Tc
    denominator = -math.log(10) * gas_constant * kelvin
    return G / denominator
def OBIGT2eos(OBIGT, fixGHS=True, tocal=True, messages=True)
Expand source code
def OBIGT2eos(OBIGT, fixGHS=True, tocal=True, messages=True):
    """
    Convert OBIGT dataframe to equation of state parameters.

    This function processes the OBIGT thermodynamic database to prepare it for
    equation of state calculations. It handles energy unit conversions and
    optionally fills in missing G, H, or S values. The input is not modified.

    Parameters
    ----------
    OBIGT : pd.DataFrame
        OBIGT thermodynamic database. Columns are located by name ('G', 'H',
        'S', 'Cp', 'V', 'omega.lambda', 'E_units', 'state' and, when a value
        is filled in, 'formula'). The conversion assumes the OBIGT column
        order G, H, S, Cp, V, followed by the EOS parameters up to
        'omega.lambda'.
    fixGHS : bool, default True
        For rows missing exactly one of G, H, S, fill it in from the other
        two at 298.15 K using the entropy of the elements.
    tocal : bool, default True
        For rows with E_units == "J", divide the energy-based values by 4.184
        and set E_units to "cal".
        Converted columns are G through Cp and the EOS parameters between V
        and omega.lambda. omega.lambda itself is converted only for aqueous
        species (omega); for other states it is lambda and is left as is.
        V is a volume and is never converted.
    messages : bool, default True
        Print informational messages (currently unused; reserved for future use)

    Returns
    -------
    pd.DataFrame
        Modified copy of the OBIGT dataframe with converted parameters
    """
    OBIGT_out = OBIGT.copy()
    columns = OBIGT.columns

    # Get column positions for named columns (to handle varying column positions)
    G_idx = columns.get_loc('G')
    H_idx = columns.get_loc('H')
    S_idx = columns.get_loc('S')
    Cp_idx = columns.get_loc('Cp')
    V_idx = columns.get_loc('V')
    omega_lambda_idx = columns.get_loc('omega.lambda')

    if tocal:
        is_J = (OBIGT['E_units'] == 'J').to_numpy(dtype=bool, na_value=False)
        is_aq = (OBIGT['state'] == 'aq').to_numpy(dtype=bool, na_value=False)

        # G, H, S, Cp and the EOS coefficients after V carry energy units.
        # V (cm3/mol) does not, so it is skipped (as in R CHNOSZ OBIGT2eos).
        energy_cols = list(range(G_idx, Cp_idx + 1)) + list(range(V_idx + 1, omega_lambda_idx))
        # We only convert omega for aqueous species, not lambda for cgl species
        for rows, jcols in ((np.flatnonzero(is_J & is_aq), energy_cols + [omega_lambda_idx]),
                            (np.flatnonzero(is_J & ~is_aq), energy_cols)):
            if rows.size > 0:
                OBIGT_out.iloc[rows, jcols] = OBIGT.iloc[rows, jcols].to_numpy(dtype=float) / 4.184

        if is_J.any():
            OBIGT_out.iloc[np.flatnonzero(is_J), columns.get_loc('E_units')] = "cal"

    # fill in one of missing G, H, S
    # for use esp. by subcrt because NA for one of G, H or S
    # will preclude calculations at high T
    if fixGHS:
        ghs_pos = [G_idx, H_idx, S_idx]
        missing = OBIGT.iloc[:, ghs_pos].isna().to_numpy()
        formula_idx = OBIGT_out.columns.get_loc('formula')
        Tr = 298.15
        # Only rows missing exactly one of G, H, S can be completed
        for i in np.flatnonzero(missing.sum(axis=1) == 1):
            G, H, S = (OBIGT_out.iloc[i, j] for j in ghs_pos)
            # NOTE(review): confirm entropy() returns values in the same energy
            # units as this row (cal when tocal converted it)
            Selem = entropy(OBIGT_out.iloc[i, formula_idx])
            if missing[i, 0]:    # G is missing
                OBIGT_out.iloc[i, G_idx] = H - Tr*(S - Selem)
            elif missing[i, 1]:  # H is missing
                OBIGT_out.iloc[i, H_idx] = G + Tr*(S - Selem)
            else:                # S is missing
                OBIGT_out.iloc[i, S_idx] = Selem + (H - G)/Tr

    return OBIGT_out

Convert OBIGT dataframe to equation of state parameters.

This function processes the OBIGT thermodynamic database to prepare it for equation of state calculations. It handles energy unit conversions and optionally fills in missing G, H, or S values.

Parameters

OBIGT : pd.DataFrame
OBIGT thermodynamic database
fixGHS : bool, default True
Fill in one missing value among G, H, S using thermodynamic relations
tocal : bool, default True
Convert energy units from Joules to calories
messages : bool, default True
Print informational messages (currently unused; reserved for future use)

Returns

pd.DataFrame
Modified OBIGT dataframe with converted parameters
def ZC(formula: str | int | List[str | int]) -> float | List[float]
Expand source code
def ZC(formula: Union[str, int, List[Union[str, int]]]) -> Union[float, List[float]]:
    """
    Calculate average oxidation state of carbon in chemical formulas.

    For each formula, ZC = sum(count * nominal_charge) / count(C), summed over
    the elements other than carbon. Nominal charges are H: -1, N: +3, O: +2,
    S: +2, Z: +1 (Z presumably denotes charge in makeup() output). Any other
    element is left out of the sum, with a warning.

    Parameters
    ----------
    formula : str, int, or list
        Chemical formula(s) or species index(es), passed to makeup()

    Returns
    -------
    float or list of float
        Average oxidation state(s) of carbon; NaN for a composition that is
        None or has no carbon. A single result is returned as a bare float.
    """
    # Nominal charge of each non-carbon element that contributes to ZC
    nominal = {'H': -1, 'N': 3, 'O': 2, 'S': 2, 'Z': 1}

    compositions = makeup(formula, count_zero=False)
    if not isinstance(compositions, list):
        compositions = [compositions]

    zc_values = []
    for comp in compositions:
        # ZC is undefined without a composition or without carbon
        if comp is None or 'C' not in comp:
            zc_values.append(np.nan)
            continue

        unknown = [el for el, _ in comp.items() if el != 'C' and el not in nominal]
        if unknown:
            warnings.warn(f"element(s) {' '.join(unknown)} not in "
                          f"{' '.join(nominal)} so not included in ZC calculation")

        charge = sum(n * nominal[el] for el, n in comp.items() if el in nominal)
        zc_values.append(charge / comp['C'])

    return zc_values[0] if len(zc_values) == 1 else zc_values

Calculate average oxidation state of carbon in chemical formulas.

Parameters

formula : str, int, or list
Chemical formula(s) or species index(es)

Returns

float or list of float
Average oxidation state(s) of carbon
def add_OBIGT(file: str | pandas.core.frame.DataFrame,
force: bool = True,
messages: bool = True) -> List[int]
Expand source code
def add_OBIGT(file: Union[str, pd.DataFrame], force: bool = True, messages: bool = True) -> List[int]:
    """
    Add or replace entries in the thermodynamic database from external files or DataFrames.

    This function replicates the behavior of R CHNOSZ add.OBIGT(). It loads
    a CSV file (by name from the package's extdata/OBIGT directory, or from
    a path) or accepts a pandas DataFrame directly. Entries whose name and
    state match an existing species replace that entry; the rest are added.

    Parameters
    ----------
    file : str, os.PathLike or pd.DataFrame
        One of:

        - Name of a database file to load (e.g., "SUPCRT92"); the function
          looks for ``<name>.csv`` in the package's extdata/OBIGT directory
        - Path to a CSV file
        - A pandas DataFrame containing OBIGT data
    force : bool, default True
        If True, input rows whose name and state match an existing entry
        replace that entry. If False, such rows are ignored and only new
        species are added.
    messages : bool, default True
        If True, print a summary of rows read, replacements and additions.
        If False, suppress this output (equivalent to R's suppressMessages()).
        Warnings, such as the one about an auto-generated 'model' column,
        are still issued.

    Returns
    -------
    list of int
        Species indices (1-based) that were replaced or added

    Raises
    ------
    FileNotFoundError
        If ``file`` is a name that cannot be found
    ValueError
        If the file cannot be read, contains no rows, or lacks required columns

    Examples
    --------
    >>> import pychnosz
    >>> import pandas as pd
    >>>
    >>> # Example 1: Load from file name
    >>> pychnosz.reset()
    >>> indices = pychnosz.add_OBIGT("SUPCRT92")
    >>>
    >>> # Example 2: Load from DataFrame
    >>> thermo_df = pd.read_csv("thermodata.csv")
    >>> indices = pychnosz.add_OBIGT(thermo_df)
    >>>
    >>> # Example 3: Suppress messages
    >>> indices = pychnosz.add_OBIGT(thermo_df, messages=False)

    Notes
    -----
    The updated database replaces ``thermo().obigt`` (and
    ``thermo().formula_ox``) only after all changes succeed. Rows of the
    stored database are renumbered 1..n.
    """

    # Get the thermo system
    thermo_sys = thermo()

    # Ensure the thermodynamic system is initialized
    if not thermo_sys.is_initialized() or thermo_sys.obigt is None:
        thermo_sys.reset()

    # Read the input into a DataFrame owned by this function
    if isinstance(file, pd.DataFrame):
        new_data = file.copy()
        file_path = "<DataFrame>"
        file_basename = None
    else:
        # Accept os.PathLike objects (e.g. pathlib.Path) as well as str
        file = os.fspath(file)
        if os.path.exists(file):
            # Use the file path as provided
            file_path = file
        else:
            # Otherwise look for <file>.csv in the package's OBIGT data directory
            file_to_find = file if file.endswith('.csv') else file + '.csv'
            file_path = os.path.join(os.path.dirname(__file__), 'extdata', 'OBIGT', file_to_find)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Could not find OBIGT file: {file}")

        # Extract the basename for source_file column
        file_basename = os.path.basename(file_path)

        try:
            new_data = pd.read_csv(file_path)
        except Exception as e:
            raise ValueError(f"Error reading {file_path}: {e}") from e

    if new_data.empty:
        raise ValueError(f"No data found in {file_path if isinstance(file, str) else 'DataFrame'}")

    # Work on a copy with 1-based row labels, so labels equal species indices
    # (R's rownames) and a failure part-way through leaves thermo().obigt unchanged
    to1 = thermo_sys.obigt.copy()
    to1.index = range(1, len(to1) + 1)

    # Define core required columns that all species must have.
    # Model-specific columns (logK*, T*, P*, etc.) are optional.
    core_required_columns = [
        'name', 'abbrv', 'formula', 'state', 'ref1', 'ref2', 'date', 'E_units',
        'G', 'H', 'S', 'Cp', 'V',
        'a1.a', 'a2.b', 'a3.c', 'a4.d', 'c1.e', 'c2.f', 'omega.lambda', 'z.T'
    ]

    # Only require columns that exist in the current OBIGT (for compatibility)
    required_columns = [col for col in core_required_columns if col in to1.columns]
    missing_columns = [col for col in required_columns if col not in new_data.columns]
    if missing_columns:
        raise ValueError(
            f"Missing required columns in input data: {', '.join(missing_columns)}. "
            f"Please ensure the CSV file contains all necessary OBIGT database columns."
        )

    # The 'model' column is optional and is auto-generated if missing:
    # aqueous species (state == 'aq') get 'HKF', all others get 'CGL'
    if 'model' not in new_data.columns:
        new_data['model'] = new_data['state'].apply(lambda x: 'HKF' if x == 'aq' else 'CGL')
        warnings.warn(
            "The 'model' column was not found in the input data. "
            "Auto-generating 'model' column: 'HKF' for aqueous species (state='aq'), "
            "'CGL' for all other species.",
            UserWarning
        )

    # Energy units for the summary message: the unique "cal"/"J" values in
    # order of appearance (like R's paste(unique(...), collapse = " and "))
    if 'E_units' in new_data.columns:
        unique_units = new_data['E_units'].dropna().unique().tolist()
        energy_unit_names = [str(u) for u in unique_units if str(u) in ['cal', 'J']]
        energy_units_str = ' and '.join(energy_unit_names) if energy_unit_names else 'cal'
    else:
        energy_units_str = 'cal'

    # Make sure both replaced and added rows can record their source file.
    # Without this, added rows would lose it when reindexed to to1's columns.
    if file_basename is not None and 'source_file' not in to1.columns:
        to1['source_file'] = pd.Series(dtype=object)

    # Create identifier strings for matching (name + state)
    id1 = to1['name'].astype(str) + ' ' + to1['state'].astype(str)
    id2 = new_data['name'].astype(str) + ' ' + new_data['state'].astype(str)

    # First existing index for each identifier, like R's match(id2, id1)
    first_match = {}
    for idx, key in zip(id1.index, id1):
        first_match.setdefault(key, idx)
    ispecies_exist = [first_match.get(key) for key in id2]
    does_exist = np.array([idx is not None for idx in ispecies_exist], dtype=bool)
    nexist = int(does_exist.sum())

    # Add columns that new_data has but to1 doesn't.
    # Object dtype for string columns avoids dtype incompatibility on assignment.
    for col in new_data.columns:
        if col not in to1.columns:
            dtype = new_data[col].dtype
            if dtype == object or pd.api.types.is_string_dtype(dtype):
                to1[col] = pd.Series(dtype=object)
            else:
                to1[col] = np.nan

    # Track the indices we've modified/added
    inew = []

    if force:
        # Replace existing entries in place
        for i, idx in enumerate(ispecies_exist):
            if idx is None:
                continue
            row = new_data.iloc[i]
            for col in new_data.columns:
                to1.loc[idx, col] = row[col]
            if file_basename is not None:
                to1.loc[idx, 'source_file'] = file_basename
        inew.extend(idx for idx in ispecies_exist if idx is not None)
    else:
        # Ignore any new entries that already exist
        nexist = 0

    # Rows that do not already exist are added (to2 <- to2[!does.exist, ])
    to2 = new_data[~does_exist].copy()

    if len(to2) > 0:
        len_id1 = len(id1)

        # Ensure new entries have all columns of the database
        for col in to1.columns:
            if col not in to2.columns:
                to2[col] = np.nan

        # Set source_file for new entries
        if file_basename is not None:
            to2['source_file'] = file_basename

        # Reorder columns to match current OBIGT and append
        to2 = to2.reindex(columns=to1.columns)
        to1 = pd.concat([to1, to2], ignore_index=True, sort=False)

        # Add new indices: (length(id1)+1):nrow(to1)
        inew.extend(range(len_id1 + 1, len(to1) + 1))

    # Reset rownames to 1:nrow (matching R's rownames(thermo$OBIGT) <- 1:nrow(thermo$OBIGT))
    to1.index = range(1, len(to1) + 1)

    # Update the thermo system with modified database
    thermo_sys.obigt = to1

    # Update formula_ox (sharing the 1-based index) if the column exists
    if 'formula_ox' in to1.columns:
        formula_ox_df = pd.DataFrame({
            'name': to1['name'],
            'formula_ox': to1['formula_ox']
        })
        formula_ox_df.index = to1.index
        thermo_sys.formula_ox = formula_ox_df
    else:
        thermo_sys.formula_ox = None

    # Print summary (matching R CHNOSZ output)
    if messages:
        print(f"add_OBIGT: read {len(new_data)} rows; made {nexist} replacements, {len(to2)} additions [energy units: {energy_units_str}]")

    return inew

Add or replace entries in the thermodynamic database from external files or DataFrames.

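This function replicates the behavior of R CHNOSZ add.OBIGT(). It loads CSV files (by name from the package's extdata/OBIGT directory, or from a given path) or accepts pandas DataFrames directly. Entries whose name and state match an existing species replace that entry, and the rest are added.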

Parameters

file : str or pd.DataFrame
Either: (1) the name of a database file to load (e.g., "SUPCRT92"), in which case the function looks for file.csv in the package's extdata/OBIGT directory; (2) the full path to a CSV file; or (3) a pandas DataFrame containing OBIGT data.
force : bool, default True
If True, input rows whose name and state match an existing entry replace that entry; if False, such rows are ignored and only new species are added
messages : bool, default True
If True, print informational messages about additions/replacements If False, suppress all output (equivalent to R's suppressMessages())

Returns

list of int
List of species indices (1-based) that were added or replaced

Examples

>>> import pychnosz
>>> import pandas as pd
>>>
>>> # Example 1: Load from file name
>>> pychnosz.reset()
>>> indices = pychnosz.add_OBIGT("SUPCRT92")
>>>
>>> # Example 2: Load from DataFrame
>>> thermo_df = pd.read_csv("thermodata.csv")
>>> indices = pychnosz.add_OBIGT(thermo_df)
>>>
>>> # Example 3: Suppress messages
>>> indices = pychnosz.add_OBIGT(thermo_df, messages=False)

Notes

This function modifies the thermo() object, replacing entries with matching name and state and adding new entries for species not in the database. The behavior is intended to match R CHNOSZ add.OBIGT().

def add_legend(ax,
labels: list = None,
loc: str = 'best',
frameon: bool = False,
fontsize: float = 9,
**kwargs)
Expand source code
def add_legend(ax, labels: list = None, loc: str = 'best',
               frameon: bool = False, fontsize: float = 9, **kwargs):
    """
    Attach a legend to a matplotlib Axes or a Plotly Figure.

    The legend uses defaults chosen to resemble R CHNOSZ legend styling
    (no frame, slightly reduced font). The backend is detected from the
    object passed as ``ax``, and the work is delegated to the matching
    backend-specific helper.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or plotly.graph_objs.Figure
        Target to receive the legend. For interactive diagrams,
        pass the figure from d['fig'] or d['ax'].
    labels : list of str
        Legend labels (e.g. output of describe_property, describe_basis).
        Required; ``None`` raises ValueError.
    loc : str, default 'best'
        Legend location. Options: 'best', 'upper left', 'upper right',
        'lower left', 'lower right', 'right', 'center left', 'center right',
        'lower center', 'upper center', 'center'.
        For Plotly: 'best' defaults to 'lower right'.
    frameon : bool, default False
        Draw a frame around the legend (R bty="n" equivalent when False).
    fontsize : float, default 9
        Font size for legend text (R cex=0.9 equivalent).
    **kwargs
        Forwarded to matplotlib legend() or to the Plotly annotation.

    Returns
    -------
    matplotlib.legend.Legend or plotly.graph_objs.Figure
        The legend object (matplotlib) or the figure (Plotly).

    Raises
    ------
    ValueError
        If ``labels`` is not supplied.

    Examples
    --------
    >>> from pychnosz.utils.expression import add_legend, describe_property
    >>> # Matplotlib diagram with plot_it=False
    >>> d1 = diagram(a, interactive=False, plot_it=False)
    >>> dprop = describe_property(["T", "P"], [300, 1000])
    >>> add_legend(d1['ax'], dprop, loc='lower right')
    >>> # Display the figure in Jupyter:
    >>> from IPython.display import display
    >>> display(d1['fig'])
    >>> # Or save it:
    >>> d1['fig'].savefig('diagram.png')

    >>> # Plotly diagram
    >>> d1 = diagram(a, interactive=True, plot_it=False)
    >>> dprop = describe_property(["T", "P"], [300, 1000])
    >>> add_legend(d1['fig'], dprop, loc='lower right')
    >>> d1['fig'].show()

    Notes
    -----
    Common R legend locations and their matplotlib equivalents:
    - "bottomright" → "lower right"
    - "topleft" → "upper left"
    - "topright" → "upper right"
    - "bottomleft" → "lower left"

    With plot_it=False the figure must be displayed explicitly after adding
    legends: in Jupyter use display(d['fig']), or d['fig'].show() for Plotly
    diagrams; outside Jupyter use plt.show() or d['fig'].savefig().
    """
    if labels is None:
        raise ValueError("labels must be provided")

    # Plotly figures get the Plotly helper; anything else is treated as a
    # matplotlib Axes.
    if _is_plotly_figure(ax):
        legend_builder = _add_plotly_legend
    else:
        legend_builder = _add_matplotlib_legend
    return legend_builder(ax, labels, loc, frameon, fontsize, **kwargs)

Add a legend to a diagram with matplotlib or Plotly formatting.

This is a convenience function that adds a legend with sensible defaults matching R CHNOSZ legend styling. Works with both matplotlib and Plotly figures.

Parameters

ax : matplotlib.axes.Axes or plotly.graph_objs.Figure
Axes/Figure object to add legend to. For interactive diagrams, pass the figure from d['fig'] or d['ax'].
labels : list of str
Legend labels (can be from describe_property, describe_basis, etc.)
loc : str, default 'best'
Legend location. Options: 'best', 'upper left', 'upper right', 'lower left', 'lower right', 'right', 'center left', 'center right', 'lower center', 'upper center', 'center'. For Plotly, 'best' defaults to 'lower right'.
frameon : bool, default False
Whether to draw a frame around the legend (R bty="n" equivalent)
fontsize : float, default 9
Font size for legend text (R cex=0.9 equivalent)
**kwargs
Additional arguments passed to matplotlib legend() or Plotly annotation

Returns

matplotlib.legend.Legend or plotly.graph_objs.Figure
The legend object (matplotlib) or the figure (Plotly)

Examples

>>> from pychnosz.utils.expression import add_legend, describe_property
>>> # Matplotlib diagram with plot_it=False
>>> d1 = diagram(a, interactive=False, plot_it=False)
>>> dprop = describe_property(["T", "P"], [300, 1000])
>>> add_legend(d1['ax'], dprop, loc='lower right')
>>> # Display the figure in Jupyter:
>>> from IPython.display import display
>>> display(d1['fig'])
>>> # Or save it:
>>> d1['fig'].savefig('diagram.png')
>>> # Plotly diagram
>>> d1 = diagram(a, interactive=True, plot_it=False)
>>> dprop = describe_property(["T", "P"], [300, 1000])
>>> add_legend(d1['fig'], dprop, loc='lower right')
>>> d1['fig'].show()

Notes

Common R legend locations and their matplotlib equivalents: "bottomright" → "lower right"; "topleft" → "upper left"; "topright" → "upper right"; "bottomleft" → "lower left".

When using plot_it=False, you need to explicitly display the figure after adding legends. In Jupyter notebooks, use display(d['fig']) or d['fig'].show() for Plotly diagrams. Outside Jupyter, use plt.show() or save with d['fig'].savefig().

def add_protein(aa: pandas.core.frame.DataFrame, as_residue: bool = False) ‑> numpy.ndarray
Expand source code
def add_protein(aa: pd.DataFrame, as_residue: bool = False) -> np.ndarray:
    """
    Add protein amino acid compositions to thermo().protein.

    Proteins are identified by "<protein>_<organism>". IDs not yet present
    are appended to the table; IDs already present have their rows replaced.
    The caller's ``aa`` DataFrame is not modified.

    Parameters
    ----------
    aa : DataFrame
        DataFrame with protein amino acid compositions.
        Must have the same columns as thermo().protein
    as_residue : bool, default False
        Normalize the chain and amino acid counts by protein length

    Returns
    -------
    array
        Row numbers of added/updated proteins in thermo().protein

    Raises
    ------
    RuntimeError
        If the protein database has not been loaded.
    ValueError
        If ``aa`` does not have the same columns as thermo().protein, or if
        some protein IDs in ``aa`` are duplicated.

    Examples
    --------
    >>> import pandas as pd
    >>> from pychnosz import *
    >>> aa = pd.read_csv("POLG.csv")
    >>> iprotein = add_protein(aa)
    """
    t = thermo()

    if t.protein is None:
        raise RuntimeError("Protein database not loaded. Run reset() first.")

    # Check that columns match
    if list(aa.columns) != list(t.protein.columns):
        raise ValueError("'aa' does not have the same columns as thermo().protein")

    # Check that new protein IDs are unique
    po = aa['protein'] + '_' + aa['organism']
    idup = po.duplicated()
    if idup.any():
        dup_proteins = po[idup].unique()
        raise ValueError(f"some protein IDs are duplicated: {' '.join(dup_proteins)}")

    # Normalize by protein length if as_residue = True
    if as_residue:
        # Work on a copy so the caller's DataFrame is left untouched
        aa = aa.copy()
        pl = protein_length(aa)
        # Positional columns 4..24 are "chains" followed by the 20 amino acid
        # counts (R CHNOSZ columns 5:25). The slice end is exclusive, so 25 is
        # needed to include the last amino acid column.
        aa.iloc[:, 4:25] = aa.iloc[:, 4:25].div(pl, axis=0)

    # Find any protein IDs that are already present (NaN marks absent IDs).
    # atleast_1d handles scalar, list, and array results from pinfo uniformly.
    ip = np.atleast_1d(np.asarray(pinfo(po.tolist()), dtype=float))
    ip_present = ~np.isnan(ip)
    n_added = int(np.sum(~ip_present))
    n_replaced = int(np.sum(ip_present))

    # Now we're ready to go
    tp_new = t.protein.copy()

    # Add new proteins
    if n_added > 0:
        new_proteins = aa[~ip_present].copy()
        tp_new = pd.concat([tp_new, new_proteins], ignore_index=True)

    # Update existing proteins (ip holds positional row numbers)
    if n_replaced > 0:
        valid_ip = ip[ip_present].astype(int)
        tp_new.iloc[valid_ip] = aa[ip_present].values

    # Update the protein database
    tp_new.reset_index(drop=True, inplace=True)
    t.protein = tp_new

    # Return the new row numbers
    ip_new = np.atleast_1d(pinfo(po.tolist()))

    # Print messages
    if n_added > 0:
        print(f"add_protein: added {n_added} new protein(s) to thermo().protein")
    if n_replaced > 0:
        print(f"add_protein: replaced {n_replaced} existing protein(s) in thermo().protein")

    return ip_new

Add protein amino acid compositions to thermo().protein.

Parameters

aa : DataFrame
DataFrame with protein amino acid compositions. Must have the same columns as thermo().protein.
as_residue : bool, default False
Normalize amino acid counts by protein length

Returns

array
Row numbers of added/updated proteins in thermo().protein

Examples

>>> import pandas as pd
>>> from pychnosz import *
>>> aa = pd.read_csv("POLG.csv")
>>> iprotein = add_protein(aa)
def affinity(messages: bool = True,
basis: pandas.core.frame.DataFrame | None = None,
species: pandas.core.frame.DataFrame | None = None,
iprotein: int | List[int] | numpy.ndarray | None = None,
loga_protein: float | List[float] = 0.0,
**kwargs) ‑> Dict[str, Any]
Expand source code
def affinity(messages: bool = True, basis: Optional[pd.DataFrame] = None,
             species: Optional[pd.DataFrame] = None, iprotein: Optional[Union[int, List[int], np.ndarray]] = None,
             loga_protein: Union[float, List[float]] = 0.0, **kwargs) -> Dict[str, Any]:
    """
    Calculate affinities of formation reactions.

    This function calculates chemical affinities for the formation reactions of
    species of interest from user-selected basis species. The affinities are
    calculated as A/2.303RT where A is the chemical affinity.

    Parameters
    ----------
    messages : bool, default True
        Whether to print informational messages
    basis : pd.DataFrame, optional
        Basis species definition to use (if not using global basis)
    species : pd.DataFrame, optional
        Species definition to use (if not using global species)
    iprotein : int, list of int, or array, optional
        Build proteins from residues (non-negative row numbers in thermo().protein)
    loga_protein : float or list of float, default 0.0
        Activity of proteins (log scale); recycled to the number of proteins
    **kwargs : dict
        Variable arguments defining calculation conditions:
        - Basis species names (e.g., CO2=[-60, 20, 5]): Variable basis species activities
        - T : float or list, Temperature in °C
        - P : float or list, Pressure in bar
        - property : str, Property to calculate ("A", "logK", "G", etc.)
        - exceed_Ttr : bool, Allow extrapolation beyond transition temperatures
          (default True; the value used is recorded in the returned 'args')
        - exceed_rhomin : bool, Allow calculations below minimum water density
          (default False)
        - return_buffer : bool, Return buffer activities
        - balance : str, Balance method for protein buffers

    Returns
    -------
    dict
        Dictionary containing:
        - fun : str, Function name ("affinity")
        - args : dict, Arguments used in calculation
        - sout : dict, Subcrt calculation results
        - property : str, Property calculated
        - basis : pd.DataFrame, Basis species definition
        - species : pd.DataFrame, Species of interest definition
        - T : float or array, Temperature(s) in Kelvin (empty list if T is a variable)
        - P : float or array, Pressure(s) in bar (empty list if P is a variable)
        - vars : list, Variable names
        - vals : dict, Variable values
        - values : dict, Calculated affinity values by species

    Raises
    ------
    AffinityError
        If basis species or species are not defined, or if `iprotein` contains
        NA values or values that are not row numbers of thermo().protein.

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.reset()
    >>> pychnosz.basis(["CO2", "H2O", "NH3", "H2S", "H+", "O2"])
    >>> pychnosz.species(["glycine", "tyrosine", "serine", "methionine"])
    >>> result = pychnosz.affinity(CO2=[-60, 20, 5], T=350, P=2000)
    >>> print(result['values'][1566])  # Glycine affinities

    >>> # With proteins
    >>> import pandas as pd
    >>> aa = pd.read_csv("POLG.csv")
    >>> iprotein = pychnosz.add_protein(aa)
    >>> pychnosz.basis("CHNOSe")
    >>> a = pychnosz.affinity(iprotein=iprotein, pH=[2, 14], Eh=[-1, 1])

    Notes
    -----
    This implementation maintains complete fidelity to R CHNOSZ affinity():
    - Identical argument processing including dynamic basis species parameters
    - Same variable expansion and multi-dimensional calculations
    - Exact energy() function behavior for property calculations
    - Identical output structure and formatting
    - Support for protein calculations via iprotein parameter
    """

    # Get thermo object for protein handling
    thermo_obj = thermo()

    # Handle iprotein parameter
    ires = None
    if iprotein is not None:
        # Convert scalars, lists, and arrays to a 1-D array of row numbers
        iprotein = np.atleast_1d(np.asarray(iprotein))

        # Check all proteins are available (row numbers must be in range;
        # negative values would otherwise silently index from the end)
        if np.any(np.isnan(iprotein)):
            raise AffinityError("`iprotein` has some NA values")
        if (thermo_obj.protein is None
                or not np.all((iprotein >= 0) & (iprotein < len(thermo_obj.protein)))):
            raise AffinityError("some value(s) of `iprotein` are not rownumbers of thermo().protein")

        # Add protein residues to the species list
        # Amino acids in 3-letter code
        aminoacids_3 = ["Ala", "Cys", "Asp", "Glu", "Phe", "Gly", "His", "Ile", "Lys", "Leu",
                        "Met", "Asn", "Pro", "Gln", "Arg", "Ser", "Thr", "Val", "Trp", "Tyr"]

        # Use _RESIDUE notation (matches R CHNOSZ affinity.R line 84)
        resnames_residue = ["H2O_RESIDUE"] + [f"{aa}_RESIDUE" for aa in aminoacids_3]

        from .species import species as species_func

        # Add residue species with activity 0 (all in "aq" state)
        species_func(resnames_residue, state="aq", add=True, messages=messages)

        # Get indices of residues in species list
        species_df_temp = get_species()
        ires = []
        for name in resnames_residue:
            idx = np.where(species_df_temp['name'] == name)[0]
            if len(idx) > 0:
                ires.append(idx[0])
        ires = np.array(ires)

    # Check if basis and species are defined (use provided or global)
    if basis is None:
        if not is_basis_defined():
            raise AffinityError("basis species are not defined")
        basis_df = get_basis()
    else:
        basis_df = basis

    if species is None:
        if not is_species_defined():
            raise AffinityError("species are not defined")
        species_df = get_species()
    else:
        species_df = species

    # Process arguments
    args_orig = dict(kwargs)

    # Handle argument recall (if first argument is previous affinity result)
    if len(args_orig) > 0:
        first_value = next(iter(args_orig.values()))
        if isinstance(first_value, dict) and first_value.get('fun') == 'affinity':
            # Copy so the previous result's 'args' dict is not modified, then
            # override with the new arguments (skipping the recalled result)
            aargs = dict(first_value.get('args', {}))
            aargs.update(list(args_orig.items())[1:])
            return affinity(**aargs)

    # Process energy arguments
    args = energy_args(args_orig, messages, basis_df=basis_df)

    # Get property to calculate
    property_name = args.get('what', 'A')

    # Resolve these once so the values used in the calculation are the same
    # values recorded in the returned 'args' (and therefore reused on recall)
    exceed_Ttr = kwargs.get('exceed_Ttr', True)
    exceed_rhomin = kwargs.get('exceed_rhomin', False)

    # Calculate the requested property; a missing/empty property means
    # affinities (A/2.303RT)
    energy_result = energy(
        what=property_name if property_name else 'A',
        vars=args['vars'],
        vals=args['vals'],
        lims=args['lims'],
        T=args['T'],
        P=args['P'],
        IS=args.get('IS', 0),
        exceed_Ttr=exceed_Ttr,
        exceed_rhomin=exceed_rhomin,
        basis_df=basis_df,
        species_df=species_df,
        messages=messages
    )
    affinity_values = energy_result['a']
    energy_sout = energy_result['sout']

    # Handle protein affinity calculations if iprotein was provided
    if iprotein is not None and ires is not None:
        # Normalize loga_protein to one value per protein
        if np.ndim(loga_protein) == 0:
            loga_protein_arr = np.full(len(iprotein), loga_protein, dtype=float)
        else:
            loga_protein_arr = np.asarray(loga_protein, dtype=float)
            if len(loga_protein_arr) < len(iprotein):
                loga_protein_arr = np.resize(loga_protein_arr, len(iprotein))

        # Composition of the requested proteins (n_proteins x 21).
        # thermo$protein[iprotein, 5:25] in R: "chains" followed by the 20
        # amino acid counts. In 0-based positional indexing these are columns
        # 4..24, so the exclusive slice end must be 25 to include Tyr.
        aa_composition = thermo_obj.protein.iloc[list(iprotein), 4:25].values.astype(float)

        # Species list still containing the residue species (needed below for
        # both the residue affinities and the residue stoichiometry)
        species_df_with_residues = get_species()
        residue_ispecies = species_df_with_residues.iloc[ires]['ispecies'].values

        # Protein affinity = sum of residue affinities weighted by composition,
        # minus log activity of the protein. affinity_values keys are ispecies.
        protein_affinities = {}
        first_residue_key = residue_ispecies[0]
        if first_residue_key in affinity_values:
            template_affinity = affinity_values[first_residue_key]
            for ip, iprot in enumerate(iprotein):
                aa_counts = aa_composition[ip]
                protein_affinity = np.zeros_like(template_affinity)
                for i, res_ispecies in enumerate(residue_ispecies):
                    if res_ispecies in affinity_values:
                        residue_contrib = affinity_values[res_ispecies] * aa_counts[i]
                        protein_affinity = protein_affinity + residue_contrib

                # Subtract protein activity
                protein_affinity = protein_affinity - loga_protein_arr[ip]

                # Use negative index to denote protein (matches R CHNOSZ convention)
                protein_key = -(iprot + 1)  # Negative of (row number + 1)
                protein_affinities[protein_key] = protein_affinity

        # Add ionization affinity if H+ is in basis (matching R CHNOSZ behavior)
        if 'H+' in basis_df.index:
            if messages:
                print("affinity: ionizing proteins ...")

            # Get protein amino acid compositions
            from ..biomolecules.proteins import pinfo
            from ..biomolecules.ionize_aa import ionize_aa

            # Get aa compositions for these proteins
            aa = pinfo(iprotein)

            # Determine pH values from vars/vals or basis
            if 'H+' in args['vars']:
                # H+ is a variable - get pH from vals
                iHplus = args['vars'].index('H+')
                pH_vals = -np.array(args['vals'][iHplus])  # pH = -log(a_H+)
            else:
                # H+ is constant - get from basis
                pH_val = -basis_df.loc['H+', 'logact']  # pH = -log(a_H+)
                pH_vals = np.array([pH_val])

            # ionize_aa takes temperature in degrees Celsius; args['T'] is Kelvin
            T_celsius = args['T'] - 273.15

            # Get P values
            P_vals = args['P']

            # ionize_aa expects arrays, so ensure T, P, pH are properly shaped.
            # For grid calculations, T, P, pH are expanded into a grid matching
            # the affinity grid.
            if len(args['vars']) >= 2:
                # Multi-dimensional case - create grid
                var_names = args['vars']
                has_T_var = 'T' in var_names
                has_Hplus_var = 'H+' in var_names

                if has_T_var and has_Hplus_var:
                    # Both T and pH vary - create meshgrid
                    T_grid, pH_grid = np.meshgrid(T_celsius, pH_vals, indexing='ij')
                    T_flat = T_grid.flatten()
                    pH_flat = pH_grid.flatten()
                    if isinstance(P_vals, str):
                        P_flat = np.array([P_vals] * len(T_flat))
                    else:
                        P_flat = np.full(len(T_flat), P_vals if isinstance(P_vals, (int, float)) else P_vals[0])
                elif has_T_var:
                    # Only T varies
                    T_flat = T_celsius if isinstance(T_celsius, np.ndarray) else np.array([T_celsius])
                    pH_flat = np.full(len(T_flat), pH_vals[0])
                    P_flat = np.array([P_vals] * len(T_flat)) if isinstance(P_vals, str) else np.full(len(T_flat), P_vals if isinstance(P_vals, (int, float)) else P_vals[0])
                elif has_Hplus_var:
                    # Only pH varies
                    pH_flat = pH_vals
                    T_flat = np.full(len(pH_flat), T_celsius if isinstance(T_celsius, (int, float)) else T_celsius[0])
                    P_flat = np.array([P_vals] * len(pH_flat)) if isinstance(P_vals, str) else np.full(len(pH_flat), P_vals if isinstance(P_vals, (int, float)) else P_vals[0])
                else:
                    # No T or pH variables
                    T_flat = np.array([T_celsius if isinstance(T_celsius, (int, float)) else T_celsius[0]])
                    pH_flat = pH_vals
                    P_flat = np.array([P_vals] if isinstance(P_vals, str) else [P_vals if isinstance(P_vals, (int, float)) else P_vals[0]])
            else:
                # Single or no variable case
                T_flat = np.array([T_celsius if isinstance(T_celsius, (int, float)) else T_celsius[0]])
                pH_flat = pH_vals if isinstance(pH_vals, np.ndarray) else np.array([pH_vals[0] if hasattr(pH_vals, '__getitem__') else pH_vals])
                P_flat = np.array([P_vals] if isinstance(P_vals, str) else [P_vals if isinstance(P_vals, (int, float)) else P_vals[0]])

            # Call ionize_aa to get ionization affinity
            ionization_result = ionize_aa(aa, property="A", T=T_flat, P=P_flat, pH=pH_flat)

            # Add ionization affinity to formation affinity for each protein
            for ip, iprot in enumerate(iprotein):
                protein_key = -(iprot + 1)
                ionization_affinity = ionization_result.iloc[:, ip].values

                # Reshape to match formation affinity dimensions if needed
                formation_affinity = protein_affinities[protein_key]
                if isinstance(formation_affinity, np.ndarray):
                    if formation_affinity.shape != ionization_affinity.shape:
                        ionization_affinity = ionization_affinity.reshape(formation_affinity.shape)

                # Add ionization to formation affinity
                protein_affinities[protein_key] = formation_affinity + ionization_affinity

        # Replace affinity_values with protein affinities
        affinity_values = protein_affinities

        # Stoichiometric coefficients for proteins by matrix multiplication.
        # R CHNOSZ: protbasis <- t(t((resspecies[ires, 1:nrow(thermo$basis)])) %*% t((thermo$protein[iprotein, 5:25])))
        basis_cols = list(basis_df.index)  # e.g., ['CO2', 'H2O', 'NH3', 'H2S', 'e-', 'H+']

        # Residue coefficient matrix (n_residues x n_basis); first row is H2O,
        # followed by the 20 amino acids
        res_coeffs = species_df_with_residues.iloc[ires][basis_cols].values.astype(float)

        # (n_proteins x 21) @ (21 x n_basis) = (n_proteins x n_basis)
        protein_coeffs = aa_composition @ res_coeffs

        # Delete residue species from species list now that we have the coefficients
        # (species_func was imported above when the residues were added)
        species_func(ires.tolist(), delete=True, messages=False)

        # Create DataFrame for proteins with basis species coefficients
        species_data = {}
        for j, basis_sp in enumerate(basis_cols):
            species_data[basis_sp] = protein_coeffs[:, j]

        # Add metadata columns
        protein_names = []
        for iprot in iprotein:
            prot_row = thermo_obj.protein.iloc[iprot]
            protein_name = f"{prot_row['protein']}_{prot_row['organism']}"
            # Escape underscores for matplotlib/LaTeX rendering of diagram labels
            protein_names.append(protein_name.replace('_', r'\_'))

        species_data['ispecies'] = [-(iprot + 1) for iprot in iprotein]  # Negative index
        species_data['logact'] = loga_protein_arr[:len(iprotein)]
        species_data['state'] = ['aq'] * len(iprotein)
        species_data['name'] = protein_names

        species_df = pd.DataFrame(species_data)

    # Convert variable names and values for output.
    # vars_list keeps the actual basis species names (H+, e-) for internal use;
    # vars_list_display / vals_dict use user-friendly names (pH, pe, Eh).
    vars_list = args['vars']
    vals_dict = {}
    vars_list_display = vars_list.copy()
    for i, var in enumerate(vars_list):
        if var == 'H+' and 'pH' in args_orig:
            vars_list_display[i] = 'pH'
            vals_dict['pH'] = [-val for val in args['vals'][i]]
        elif var == 'e-' and 'pe' in args_orig:
            vars_list_display[i] = 'pe'
            vals_dict['pe'] = [-val for val in args['vals'][i]]
        elif var == 'e-' and 'Eh' in args_orig:
            vars_list_display[i] = 'Eh'
            # Convert from log(a_e-) back to Eh using temperature-dependent formula
            # log(a_e-) = -pe, so pe = -log(a_e-)
            # Eh = pe * (ln(10) * R * T) / F = -log(a_e-) * T / 5039.76
            # NOTE(review): only the first temperature is used if T varies
            T_kelvin = args['T'] if isinstance(args['T'], (int, float)) else args['T'][0] if hasattr(args['T'], '__len__') else 298.15
            conversion_factor = T_kelvin / 5039.76  # volts per pe unit
            vals_dict['Eh'] = [-val * conversion_factor for val in args['vals'][i]]
        else:
            vals_dict[var] = args['vals'][i]

    # Variable T/P are reported as empty lists; variable T values are
    # converted back to Celsius for output
    if 'T' in vars_list:
        T_out = []
        vals_dict['T'] = [T - 273.15 for T in vals_dict['T']]
    else:
        T_out = args['T']

    P_out = [] if 'P' in vars_list else args['P']

    # Build output dictionary matching R CHNOSZ structure
    result = {
        'fun': 'affinity',
        'args': {
            **args_orig,
            'property': property_name,
            'exceed_Ttr': exceed_Ttr,
            'exceed_rhomin': exceed_rhomin,
            'return_buffer': kwargs.get('return_buffer', False),
            'balance': kwargs.get('balance', 'PBB')
        },
        'sout': energy_sout,
        'property': property_name,
        'basis': basis_df,
        'species': species_df,
        'T': T_out,
        'P': P_out,
        'vars': vars_list_display,  # Display version with 'Eh', 'pH', 'pe'
        'vals': vals_dict,
        'values': affinity_values
    }

    return result

Calculate affinities of formation reactions.

This function calculates chemical affinities for the formation reactions of species of interest from user-selected basis species. The affinities are calculated as A/2.303RT where A is the chemical affinity.

Parameters

messages : bool, default True
Whether to print informational messages
basis : pd.DataFrame, optional
Basis species definition to use (if not using global basis)
species : pd.DataFrame, optional
Species definition to use (if not using global species)
iprotein : int, list of int, or array, optional
Build proteins from residues (row numbers in thermo().protein)
loga_protein : float or list of float, default 0.0
Activity of proteins (log scale)
**kwargs : dict
Variable arguments defining calculation conditions: - Basis species names (e.g., CO2=[-60, 20, 5]): Variable basis species activities - T : float or list, Temperature in °C - P : float or list, Pressure in bar - property : str, Property to calculate ("A", "logK", "G", etc.) - exceed_Ttr : bool, Allow extrapolation beyond transition temperatures - exceed_rhomin : bool, Allow calculations below minimum water density - return_buffer : bool, Return buffer activities - balance : str, Balance method for protein buffers

Returns

dict
Dictionary containing: - fun : str, Function name ("affinity") - args : dict, Arguments used in calculation - sout : dict, Subcrt calculation results - property : str, Property calculated - basis : pd.DataFrame, Basis species definition - species : pd.DataFrame, Species of interest definition - T : float or array, Temperature(s) in Kelvin - P : float or array, Pressure(s) in bar - vars : list, Variable names - vals : dict, Variable values - values : dict, Calculated affinity values by species

Examples

>>> import pychnosz
>>> pychnosz.reset()
>>> pychnosz.basis(["CO2", "H2O", "NH3", "H2S", "H+", "O2"])
>>> pychnosz.species(["glycine", "tyrosine", "serine", "methionine"])
>>> result = pychnosz.affinity(CO2=[-60, 20, 5], T=350, P=2000)
>>> print(result['values'][1566])  # Glycine affinities
>>> # With proteins
>>> import pandas as pd
>>> aa = pd.read_csv("POLG.csv")
>>> iprotein = pychnosz.add_protein(aa)
>>> pychnosz.basis("CHNOSe")
>>> a = pychnosz.affinity(iprotein=iprotein, pH=[2, 14], Eh=[-1, 1])

Notes

This implementation maintains complete fidelity to R CHNOSZ affinity(): - Identical argument processing including dynamic basis species parameters - Same variable expansion and multi-dimensional calculations - Exact energy() function behavior for property calculations - Identical output structure and formatting - Support for protein calculations via iprotein parameter

def animation(basis_args={},
species_args={},
affinity_args={},
equilibrate_args=None,
diagram_args={},
anim_var='T',
anim_range=[0, 350, 8],
xlab=None,
ylab=None,
save_as='newanimationframe',
save_format='png',
height=300,
width=400,
save_scale=1,
messages=False)
Expand source code
def animation(basis_args=None, species_args=None, affinity_args=None,
              equilibrate_args=None, diagram_args=None,
              anim_var="T", anim_range=None, xlab=None, ylab=None,
              save_as="newanimationframe", save_format="png", height=300,
              width=400, save_scale=1,
              messages=False):
    """
    Produce an animated interactive affinity, activity, or predominance diagram.

    Parameters
    ----------
    basis_args : dict
        Dictionary of options for defining basis species (see `basis`) in the
        animated diagram. Must contain the key 'species'.
        Example: basis_args={'species':['CO2', 'O2', 'H2O', 'H+']}

    species_args : dict or list of dict
        Dictionary of options for defining species (see `species`) in the
        animated diagram, or a list of dicts. Each dict must contain the key
        'species'. An optional 'logact' list is applied to the listed species,
        one by one, after they have been defined.
        Example 1: species_args={'species':['CO2', 'HCO3-', 'CO3-2']}
        Example 2: species_args=[
                {'species':['CO2', 'HCO3-', 'CO3-2'], 'state':[-4]},
                {'species':['graphite'], 'state':[0], 'add':True}]

    affinity_args : dict, optional
        Dictionary of options for defining the affinity calculation (see
        `affinity`).
        Example: affinity_args={"pH":[2, 12, 100]}
        Example: affinity_args={"pH":[2, 12, 100], "P":[2000, 4000, 100]}

    equilibrate_args : dict or None, default None
        Dictionary of options for defining equilibration calculation
        (see `equilibrate`). If None, plots output from `affinity`.
        Example: equilibrate_args={"balance":1}

    diagram_args : dict, optional
        Dictionary of options for diagramming (see `diagram`). Diagram option
        `interactive` is set to True and `plot_it` defaults to False.
        Example: diagram_args={"alpha":True}

    anim_var : str, default "T"
        Variable that changes with each frame of animation.

    anim_range : list of numeric, default [0, 350, 8]
        The first two numbers in the list are the starting and ending
        values for `anim_var`. The third number in the list is the desired
        number of animation frames (1 to 30).

    xlab, ylab : str, optional
        Custom names for the X and Y axes. If omitted, labels are derived
        from the plotted variables. For predominance diagrams,
        diagram_args['xlab'] / diagram_args['ylab'] take precedence.

    save_as, save_format, height, width, save_scale
        Options for the plot's "download as image" button: file name,
        format (one of png, svg, jpeg, webp), image height, width and scale.

    messages : bool, default False
        Display messages from CHNOSZ?

    Returns
    -------
    None
        The interactive animated plot is displayed with `fig.show()`.

    Raises
    ------
    Exception
        If `anim_range`, `basis_args` or `species_args` are malformed.

    Notes
    -----
    The argument dictionaries are copied before use, so dictionaries passed
    by the caller are never modified.
    """

    # --- validate all arguments before touching the global basis/species ---
    if anim_range is None:
        anim_range = [0, 350, 8]
    if not isinstance(anim_range, (list, tuple)) or len(anim_range) != 3:
        raise Exception("anim_range must be a list with three values: starting "
                        "value of anim_var, stopping value, and number of "
                        "frames in the animation")
    anim_start, anim_stop, anim_res = anim_range
    # cap number of frames in animation. Remove limitation after more testing.
    if anim_res > 30:
        raise Exception("anim_range is limited to 30 frames.")
    if anim_res < 1:
        raise Exception("anim_range must request at least one animation frame.")

    if basis_args is None:
        basis_args = {}
    if not isinstance(basis_args, dict):
        raise Exception("basis_args needs to be a Python dictionary with a key "
                        "called 'species' (additional keys are optional). "
                        "Example: basis_args={'species':['CO2', 'O2', 'H2O', 'H+']}")
    if "species" not in basis_args:
        raise Exception("basis_args needs to contain a list of species for 'species'. "
                        "Example: basis_args={'species':['CO2', 'O2', 'H2O', 'H+']}")

    if species_args is None:
        species_args = {}
    if isinstance(species_args, dict):
        species_args_list = [species_args]
    elif isinstance(species_args, list):
        species_args_list = species_args
    else:
        raise Exception("species_args needs to be either a Python dictionary with a key "
                        "called 'species' (additional keys are optional). "
                        "Example: species_args={'species':['CO2', 'HCO3-', 'CO3-2']} "
                        "or else species_args needs to be a list of Python dictionaries. "
                        "Example: species_args=[{'species':['CO2', 'HCO3-', 'CO3-2'], 'state':[-4]}, "
                        "{'species':['graphite'], 'state':[0], 'add':True}]")
    if not species_args_list or any("species" not in one.keys() for one in species_args_list):
        raise Exception("species_args needs to contain a list of species for 'species'. "
                        "Example: species_args={'species':['CO2', 'HCO3-', 'CO3-2']}")

    # Work on copies so the caller's dictionaries are never modified.
    basis_args = dict(basis_args)
    affinity_args = {} if affinity_args is None else dict(affinity_args)
    diagram_args = {} if diagram_args is None else dict(diagram_args)
    if equilibrate_args is not None:
        equilibrate_args = dict(equilibrate_args)
        equilibrate_args.setdefault("messages", messages)

    basis_args.setdefault("messages", messages)
    basis_out = basis(**basis_args)
    basis_sp = list(basis_out.index)
    basis_state = list(basis_out["state"])

    # There may be multiple arguments passed to species, especially in cases
    # where add=True. Loop through all the arguments to apply them.
    for one_args in species_args_list:
        one_args = dict(one_args)
        # 'logact' is not passed to species() directly; it is applied per species below.
        mod_species_logact = one_args.pop("logact", None)
        one_args.setdefault("messages", messages)

        species_out = species(**one_args)

        if mod_species_logact is None:
            continue
        if np.ndim(mod_species_logact) == 0:
            mod_species_logact = [mod_species_logact]
        sp_names = one_args["species"]
        if np.ndim(sp_names) == 0:
            # a single species given as a scalar (name or index)
            sp_names = [sp_names]
        if len(mod_species_logact) > len(sp_names):
            raise Exception("species_args 'logact' has more values than 'species'.")
        for sp_name, sp_logact in zip(sp_names, mod_species_logact):
            species_out = species(sp_name, sp_logact, messages=messages)

    sp = list(species_out["name"])

    if isinstance(sp[0], (int, np.integer)):
        sp = [info(s, messages=messages)["name"].values[0] for s in sp]

    dfs = []
    dmaps = []
    dmaps_names = []

    zvals = __seq(anim_start, anim_stop, length_out=anim_res)

    affinity_args.setdefault("messages", messages)
    diagram_args.setdefault("messages", messages)
    diagram_args.setdefault("plot_it", False)
    diagram_args["interactive"] = True
    format_names = diagram_args.get("format_names", True)

    res_default = 256  # default affinity resolution

    for z in zvals:

        if anim_var in basis_out.index:
            basis_out = basis(anim_var, z, messages=messages)
        elif anim_var in list(species_out["name"]):
            # NOTE(review): species activities are set to -z, whereas basis
            # activities above use z; confirm this sign convention is intended.
            species_out = species(anim_var, -z, messages=messages)
        elif anim_var == "pH":
            basis_out = basis("H+", -z, messages=messages)
        else:
            affinity_args[anim_var] = z

        aeout = affinity(**affinity_args)

        if equilibrate_args is not None:
            equilibrate_args["aout"] = aeout
            aeout = equilibrate(**equilibrate_args)

        # Get affinity arguments from the result dictionary
        aeout_args = aeout.get("args", {})
        xvar = list(aeout_args.keys())[0]
        xrange = list(aeout_args[xvar])
        xres = int(xrange[2]) if len(xrange) == 3 else res_default

        diagram_args["eout"] = aeout

        # diagram_interactive does not take 'interactive'
        diagram_args_copy = dict(diagram_args)
        diagram_args_copy.pop('interactive', None)
        df, fig = diagram_interactive(**diagram_args_copy)

        # Check if this is a predominance plot (2D) or affinity/activity plot (1D)
        if 'pred' not in df.columns:
            # affinity/activity plot (1D) - melt to long format for animation
            is_predom_plot = False
            value_vars = [col for col in df.columns if col != xvar]
            df_melted = df.melt(id_vars=[xvar], value_vars=value_vars,
                                var_name='variable', value_name='value')
            df_melted[anim_var] = z
            dfs.append(df_melted)
        else:
            # predominance plot (2D) - keep original format with pred and prednames
            is_predom_plot = True
            df[anim_var] = z
            dfs.append(df)
            yvar = list(aeout_args.keys())[1]
            yrange = list(aeout_args[yvar])
            yres = int(yrange[2]) if len(yrange) == 3 else res_default

            shape = (xres, yres)
            dmaps.append(np.array(df.pred).reshape(shape))
            dmaps_names.append(np.array(df.prednames).reshape(shape))

    xvals = __seq(xrange[0], xrange[1], length_out=xres)

    unit_dict = {"P": "bar", "T": "°C", "pH": "", "Eh": "volts", "IS": "mol/kg"}

    if ((anim_var in basis_out.index or anim_var in list(species_out["name"]))
            and anim_var not in unit_dict):
        unit_dict[anim_var] = "logact " + anim_var

    # activity ("a") for aqueous, liquid and crystalline basis species, otherwise fugacity ("f")
    for s, s_state in zip(basis_sp, basis_state):
        symbol = "a" if s_state in ["aq", "liq", "cr"] else "f"
        label = chemlabel(s) if format_names else s
        unit_dict[s] = "log <i>{}</i><sub>{}</sub>".format(symbol, label)

    default_xlab = xvar + ", " + unit_dict[xvar]
    if xvar in basis_sp:
        default_xlab = unit_dict[xvar]
    if xvar == "pH":
        default_xlab = "pH"

    if is_predom_plot:
        default_ylab = yvar + ", " + unit_dict[yvar]
        if yvar in basis_sp:
            default_ylab = unit_dict[yvar]
        if yvar == "pH":
            default_ylab = "pH"

    if not is_predom_plot:

        if 'loga.equil' not in aeout.keys():
            value_label = "A/(2.303RT)"
        else:
            value_label = "log a"
        if diagram_args.get("alpha"):
            value_label = "alpha"

        df_c = pd.concat(dfs)

        fill = diagram_args.get("fill")
        line_kwargs = dict(x=xvar, y="value", color='variable', template="simple_white",
                           width=500, height=400, animation_frame=anim_var,
                           labels=dict(value=value_label, x=xvar))
        if "fill" in diagram_args:
            if isinstance(fill, list):
                # map each species (in order of first appearance) to a colour
                line_kwargs["color_discrete_map"] = dict(zip(dict.fromkeys(df_c["variable"]), fill))
            else:
                line_kwargs["color_discrete_map"] = fill
        fig = px.line(df_c, **line_kwargs)

        if "annotation" in diagram_args:
            coords = diagram_args.setdefault("annotation_coords", [0, 0])
            fig.add_annotation(x=coords[0],
                               y=coords[1],
                               xref="paper",
                               yref="paper",
                               align='left',
                               text=diagram_args["annotation"],
                               bgcolor="rgba(255, 255, 255, 0.5)",
                               showarrow=False)

        if 'main' in diagram_args:
            fig.update_layout(title={'text': diagram_args["main"], 'x': 0.5, 'xanchor': 'center'})

        fig.update_layout(xaxis_title=xlab if isinstance(xlab, str) else default_xlab)
        if isinstance(ylab, str):
            fig.update_layout(yaxis_title=ylab)

        if isinstance(fill, list):
            # zip guards against more colours than traces
            for trace, color in zip(fig.data, fill):
                trace.line.color = color

        fig.update_layout(legend_title=None)

        config = {'displaylogo': False,
                  'modeBarButtonsToRemove': ['resetScale2d', 'toggleSpikelines'],
                  'toImageButtonOptions': {
                      'format': save_format,  # one of png, svg, jpeg, webp
                      'filename': save_as,
                      'height': height,
                      'width': width,
                      'scale': save_scale,
                  },
                  }

        fig.show(config=config)
        return

    yvals = __seq(yrange[0], yrange[1], length_out=yres)

    # Axis titles and hover labels are computed once for all frames.
    # Precedence: diagram_args['xlab'/'ylab'], then xlab/ylab arguments, then defaults.
    if 'ylab' in diagram_args:
        ylab = diagram_args["ylab"]
        hover_ylab = ylab + ': %{y} '
    else:
        hover_ylab = yvar + ': %{y} ' + unit_dict[yvar]
        if not isinstance(ylab, str):
            ylab = chemlabel(default_ylab)

    if 'xlab' in diagram_args:
        xlab = diagram_args["xlab"]
        hover_xlab = xlab + ': %{x} '
    else:
        hover_xlab = xvar + ': %{x} ' + unit_dict[xvar]
        if not isinstance(xlab, str):
            xlab = chemlabel(default_xlab)

    hovertemplate = hover_xlab + '<br>' + hover_ylab + '<br>Region: %{customdata}<extra></extra>'

    # position of invisible placeholder annotations (frame-invariant)
    hidden_x = statistics.mean(xvals)
    hidden_y = statistics.mean(yvals)

    if "annotation" in diagram_args:
        diagram_args.setdefault("annotation_coords", [0, 0])

    frames = []
    slider_steps = []
    annotations = []
    heatmaps = []

    # i is a frame in the animation
    for i, zval in enumerate(zvals):

        df_i = dfs[i]
        present = set(df_i["prednames"])
        annotations_i = []
        for s in sp:
            if s in present:
                # label the species at the mean position of its predominance field
                df_s = df_i.loc[df_i["prednames"] == s]
                a = go.layout.Annotation(
                    x=df_s[xvar].mean(),
                    y=df_s[yvar].mean(),
                    xref="x",
                    yref="y",
                    text=chemlabel(s),
                    bgcolor="rgba(255, 255, 255, 0.5)",
                    showarrow=False,
                    )
            else:
                # if an annotation shouldn't appear, make an invisible annotation
                # (workaround for a plotly bug where annotations won't clear in an animation)
                a = go.layout.Annotation(
                    x=hidden_x,
                    y=hidden_y,
                    xref="x",
                    yref="y",
                    text="",
                    bgcolor="rgba(255, 255, 255, 0)",
                    showarrow=False,
                    )
            annotations_i.append(a)

        # allows adding a custom annotation; append to frame
        if "annotation" in diagram_args:
            annotations_i.append(go.layout.Annotation(
                    x=diagram_args["annotation_coords"][0],
                    y=diagram_args["annotation_coords"][1],
                    xref="paper",
                    yref="paper",
                    align='left',
                    text=diagram_args["annotation"],
                    bgcolor="rgba(255, 255, 255, 0.5)",
                    showarrow=False,
                    ))

        annotations.append(annotations_i)

        heatmaps_i = go.Heatmap(z=dmaps[i], x=xvals, y=yvals, zmin=0, zmax=len(sp)-1,
                                customdata=dmaps_names[i],
                                hovertemplate=hovertemplate)
        heatmaps.append(heatmaps_i)

        frames.append(go.Frame(data=[heatmaps_i],
                               name=str(i),
                               layout=go.Layout(annotations=annotations_i)))

        slider_steps.append(dict(
            method='animate',
            label=zval,
            value=i,
            args=[
                [i],
                dict(
                    frame=dict(duration=300, redraw=True),
                    mode='immediate',
                    transition=dict(duration=0)
                )
            ]
        ))

    fig = go.Figure(
        data=heatmaps[0],
        layout=go.Layout(
            title_x=0.5,
            width=500, height=500,
            annotations=annotations[0],
            sliders=[dict(
                active=0,
                yanchor='top',
                xanchor='left',
                currentvalue=dict(
                    font=dict(size=12),
                    prefix='{}: '.format(anim_var),
                    suffix=' ' + unit_dict.get(anim_var, ""),
                    visible=True,
                    xanchor='right'
                ),
                transition=dict(duration=0, easing='cubic-in-out'),
                pad=dict(b=10, t=50),
                len=0.9,
                x=0.1,
                y=0,
                steps=slider_steps
            )],
            updatemenus=[dict(
                type="buttons",
                buttons=[dict(label="Play",
                              method="animate",
                              args=[None, {"fromcurrent": True}]),
                         dict(label="Pause",
                              method="animate",
                              args=[[None],
                                    {"frame": {"duration": 0, "redraw": True},
                                     "mode": "immediate",
                                     "transition": {"duration": 0}}],
                              )],
                direction="left",
                pad={"r": 10, "t": 87},
                showactive=False,
                x=0.1,
                xanchor="right",
                y=0,
                yanchor="top",
            )]
        ),
        frames=frames
    )

    # Plotly colorscale positions must lie in [0, 1]; spread the fill colours
    # evenly so colour k lines up with predominance value k in [zmin, zmax].
    fill = diagram_args.get("fill")
    if isinstance(fill, list) and len(fill) > 1:
        n_steps = len(fill) - 1
        colorscale = [[i / n_steps, v] for i, v in enumerate(fill)]
    elif isinstance(fill, list) and len(fill) == 1:
        colorscale = [[0, fill[0]], [1, fill[0]]]
    elif isinstance(fill, str):
        colorscale = fill
    else:
        colorscale = "viridis"

    fig.update_traces(dict(showscale=False,
                           colorscale=colorscale),
                      selector={'type': 'heatmap'})

    fig.update_layout(
        xaxis_title=xlab,
        yaxis_title=ylab,
        xaxis={"range": [dfs[0][xvar].iloc[0], dfs[0][xvar].iloc[-1]]},
        yaxis={"range": [dfs[0][yvar].iloc[0], dfs[0][yvar].iloc[-1]]},
        margin={"t": 60, "r": 60},
    )

    if 'main' in diagram_args:
        fig.update_layout(title={'text': diagram_args['main'], 'x': 0.5, 'xanchor': 'center'})

    config = {'displaylogo': False,
              'modeBarButtonsToRemove': ['zoom2d', 'pan2d', 'zoomIn2d', 'zoomOut2d',
                                         'autoScale2d', 'toggleSpikelines',
                                         'hoverClosestCartesian', 'hoverCompareCartesian'],
              'toImageButtonOptions': {
                  'format': save_format,  # one of png, svg, jpeg, webp
                  'filename': save_as,
                  'height': height,
                  'width': width,
                  'scale': save_scale,
              },
              }

    fig.show(config=config)

Produce an animated interactive affinity, activity, or predominance diagram.

Parameters

basis_args : dict
Dictionary of options for defining basis species (see basis()) in the animated diagram. Example: basis_args={'species':['CO2', 'O2', 'H2O', 'H+']}
species_args : dict
Dictionary of options for defining species (see species()) in the animated diagram, or a list of dicts. Example 1: species_args={'species':['CO2', 'HCO3-', 'CO3-2']} Example 2: species_args=[ {'species':['CO2', 'HCO3-', 'CO3-2'], 'state':[-4]}, {'species':['graphite'], 'state':[0], 'add':True}]
affinity_args : dict
Dictionary of options for defining the affinity calculation (see affinity()). Example: affinity_args={"pH":[2, 12, 100]} Example: affinity_args={"pH":[2, 12, 100], "P":[2000, 4000, 100]}
equilibrate_args : dict or None, default None
Dictionary of options for defining equilibration calculation (see equilibrate()). If None, plots output from affinity(). Example: equilibrate_args={"balance":1}
diagram_args : dict
Dictionary of options for diagramming (see diagram()). Diagram option interactive is set to True. Example: diagram_args={"alpha":True}
anim_var : str, default "T"
Variable that changes with each frame of animation.
anim_range : list of numeric, default [0, 350, 8]
The first two numbers in the list are the starting and ending values for anim_var. The third number in the list is the desired number of animation frames.
xlab, ylab : str, optional
Custom names for the X and Y axes.
messages : bool, default False
Whether to display messages from CHNOSZ.

Returns

An interactive animated plot.

def balance_reaction(species: str | List[str] | int | List[int],
coeff: int | float | List[int | float],
state: str | List[str] | None = None,
basis: pandas.core.frame.DataFrame | None = None,
messages: bool = False) ‑> Tuple[List, List] | None
Expand source code
def balance_reaction(species: Union[str, List[str], int, List[int]],
                     coeff: Union[int, float, List[Union[int, float]]],
                     state: Optional[Union[str, List[str]]] = None,
                     basis: Optional[pd.DataFrame] = None,
                     messages: bool = False) -> Optional[Tuple[List, List]]:
    """
    Balance a chemical reaction using basis species.

    This function checks if a reaction is balanced and, if not, attempts to
    balance it by adding basis species. Unlike subcrt(), this function only
    performs the balancing calculation without computing thermodynamic properties,
    making it much more efficient for reaction generation.

    Parameters
    ----------
    species : str, int, list of str, or list of int
        Species names or indices in the reaction
    coeff : int, float, or list
        Stoichiometric coefficients for the species
    state : str, list of str, or None
        Physical states for species (optional)
    basis : pd.DataFrame, optional
        Basis species definition to use. If None, uses global basis from thermo()
    messages : bool
        Whether to print informational messages

    Returns
    -------
    tuple or None
        If reaction is balanced or can be balanced:
            (balanced_species, balanced_coeffs) where both are lists.
            When basis species had to be added, the species are returned as
            species indices (names are looked up with `info`), with duplicate
            species merged and cancelled-out species removed.
        If reaction cannot be balanced (or the balance check fails):
            None

    Raises
    ------
    ValueError
        If `species` and `coeff` differ in length, or a species name is not found.
    RuntimeError
        If no basis is given and no global basis has been defined.

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.reset()
    >>> pychnosz.basis(['H2O', 'H+', 'Fe+2'])
    >>> # Balance reaction for Fe+3
    >>> species, coeffs = pychnosz.balance_reaction('Fe+3', [-1])
    >>> print(f"Species: {species}")
    >>> print(f"Coefficients: {coeffs}")
    """

    def _is_missing(idx):
        # info() signals an unknown species with None or NaN
        return idx is None or (isinstance(idx, float) and np.isnan(idx))

    # Convert inputs to lists
    if not isinstance(species, list):
        species = [species]
    if not isinstance(coeff, list):
        coeff = [coeff]
    if state is not None and not isinstance(state, list):
        state = [state]

    # Validate lengths
    if len(species) != len(coeff):
        raise ValueError("Length of species and coeff must match")

    # Get basis definition
    thermo_sys = thermo()
    if basis is None:
        if getattr(thermo_sys, 'basis', None) is not None:
            basis = thermo_sys.basis
        else:
            raise RuntimeError("Basis species not defined. Call pychnosz.basis() first.")

    # Look up species indices
    ispecies = []
    for i, sp in enumerate(species):
        if isinstance(sp, (int, np.integer)):
            ispecies.append(int(sp))
            continue
        sp_state = state[i] if state and i < len(state) else None
        sp_idx = info(sp, sp_state, messages=messages)
        if _is_missing(sp_idx):
            raise ValueError(f"Species not found: {sp}")
        ispecies.append(sp_idx)

    tolerance = 1e-6

    # Calculate mass balance
    try:
        mass_balance = makeup(ispecies, coeff, sum_formulas=True)

        # Check if balanced
        unbalanced_elements = {elem: val for elem, val in mass_balance.items()
                               if abs(val) > tolerance}

        if not unbalanced_elements:
            # Already balanced
            if messages:
                print("Reaction is already balanced")
            return (species, coeff)

        # Reaction is unbalanced - try to balance using basis species
        missing_composition = {elem: -val for elem, val in unbalanced_elements.items()}

        if messages:
            print("Reaction is not balanced; missing composition:")
            print(" ".join(missing_composition.keys()))
            print(" ".join(f"{val:.4f}" for val in missing_composition.values()))

        # Get basis element columns
        basis_elements = [col for col in basis.columns
                          if col not in ['ispecies', 'logact', 'state']]

        # Check if all missing elements are in basis
        missing_elements = set(missing_composition.keys())
        if not missing_elements.issubset(set(basis_elements)):
            if messages:
                print(f"Cannot balance: elements {missing_elements - set(basis_elements)} not in basis")
            return None

        # Calculate coefficients for missing composition from basis species
        missing_matrix = np.zeros((1, len(basis_elements)))
        for i, elem in enumerate(basis_elements):
            missing_matrix[0, i] = missing_composition.get(elem, 0)

        # Get basis matrix
        basis_matrix = basis[basis_elements].values.T  # Transpose: (elements × basis_species)

        try:
            # Try to find simple integer solutions first
            basis_coeffs = _find_simple_integer_solution(
                basis_matrix.T,
                missing_matrix.flatten(),
                basis['ispecies'].tolist(),
                missing_composition
            )

            if basis_coeffs is None:
                # Fall back to linear algebra solution
                basis_coeffs = np.linalg.solve(basis_matrix, missing_matrix.T).flatten()

                # Apply zapsmall equivalent (digits=7)
                basis_coeffs = np.around(basis_coeffs, decimals=7)

                # Clean up very small numbers
                basis_coeffs[np.abs(basis_coeffs) < 1e-7] = 0

            # Ensure an ndarray so boolean-mask indexing below is valid
            basis_coeffs = np.asarray(basis_coeffs).ravel()

            # Get non-zero coefficients and corresponding basis species
            nonzero_indices = np.abs(basis_coeffs) > 1e-6
            if not np.any(nonzero_indices):
                if messages:
                    print("No basis species needed to balance (coefficients are zero)")
                return (species, coeff)

            # Get basis species info
            basis_indices = basis['ispecies'].values[nonzero_indices]
            basis_coeffs_nz = basis_coeffs[nonzero_indices]

            # Create new species list and coefficients
            new_species = list(species) + [int(idx) for idx in basis_indices]
            new_coeff = list(coeff) + list(basis_coeffs_nz)

            if messages:
                print("Balanced reaction by adding basis species:")
                for sp_idx, cf in zip(basis_indices, basis_coeffs_nz):
                    sp_name = thermo_sys.obigt.loc[int(sp_idx)]['name']
                    print(f"  {cf:.4f} {sp_name}")

            # CRITICAL: Consolidate duplicate species by summing coefficients
            # This prevents infinite recursion and matches subcrt's behavior

            # Convert all species to indices for consolidation
            species_indices = []
            for sp in new_species:
                if isinstance(sp, (int, np.integer)):
                    species_indices.append(int(sp))
                    continue
                sp_idx = info(sp, None, messages=False)
                # Keep as string if not found
                species_indices.append(sp if _is_missing(sp_idx) else sp_idx)

            # Group by species index and sum coefficients (insertion order is kept)
            species_coeff_map = {}
            for sp_key, cf in zip(species_indices, new_coeff):
                species_coeff_map[sp_key] = species_coeff_map.get(sp_key, 0) + cf

            # Remove species with zero coefficient (cancelled out)
            consolidated_species = []
            consolidated_coeffs = []
            for sp_key, cf in species_coeff_map.items():
                if abs(cf) > tolerance:
                    consolidated_species.append(sp_key)
                    consolidated_coeffs.append(cf)

            # Now check if consolidated reaction is balanced
            # If not, recursively balance again
            try:
                final_mass_balance = makeup(consolidated_species, consolidated_coeffs, sum_formulas=True)
                final_unbalanced = {elem: val for elem, val in final_mass_balance.items()
                                    if abs(val) > tolerance}

                if final_unbalanced:
                    # Still unbalanced after consolidation - recursively balance.
                    # NOTE(review): there is no explicit depth limit; a reaction that
                    # never balances only stops via RecursionError (caught below).
                    if messages:
                        print(f"After consolidation, reaction still unbalanced: {final_unbalanced}")
                        print("Attempting recursive balance...")
                    return balance_reaction(consolidated_species, consolidated_coeffs, state=None,
                                            basis=basis, messages=messages)

                # Balanced! Return consolidated result
                if messages:
                    print("Reaction balanced after consolidation")
                return (consolidated_species, consolidated_coeffs)

            except Exception as e:
                # If check fails, return consolidated result anyway
                if messages:
                    print(f"Could not verify final balance: {e}")
                return (consolidated_species, consolidated_coeffs)

        except np.linalg.LinAlgError:
            if messages:
                print("Cannot balance: singular basis matrix")
            return None

    except Exception as e:
        # Best-effort: any failure in the balance check yields None
        if messages:
            print(f"Error checking reaction balance: {e}")
            import traceback
            traceback.print_exc()
        return None

Balance a chemical reaction using basis species.

This function checks if a reaction is balanced and, if not, attempts to balance it by adding basis species. Unlike subcrt(), this function only performs the balancing calculation without computing thermodynamic properties, making it much more efficient for reaction generation.

Parameters

species : str, int, list of str, or list of int
Species names or indices in the reaction
coeff : int, float, or list
Stoichiometric coefficients for the species
state : str, list of str, or None
Physical states for species (optional)
basis : pd.DataFrame, optional
Basis species definition to use. If None, uses global basis from thermo()
messages : bool
Whether to print informational messages

Returns

tuple or None
If reaction is balanced or can be balanced: (balanced_species, balanced_coeffs) where both are lists If reaction cannot be balanced: None

Examples

>>> import pychnosz
>>> pychnosz.reset()
>>> pychnosz.basis(['H2O', 'H+', 'Fe+2'])
>>> # Balance reaction for Fe+3
>>> species, coeffs = pychnosz.balance_reaction('Fe+3', [-1])
>>> print(f"Species: {species}")
>>> print(f"Coefficients: {coeffs}")
def basis(species: str | int | List[str | int] | None = None,
state: str | List[str] | None = None,
logact: float | List[float] | None = None,
delete: bool = False,
add: bool = False,
messages: bool = True,
global_state: bool = True) ‑> pandas.core.frame.DataFrame | None
Expand source code
def basis(species: Optional[Union[str, int, List[Union[str, int]]]] = None,
          state: Optional[Union[str, List[str]]] = None,
          logact: Optional[Union[float, List[float]]] = None,
          delete: bool = False,
          add: bool = False,
          messages: bool = True,
          global_state: bool = True) -> Optional[pd.DataFrame]:
    """
    Define, query, extend, modify or delete the basis species of a system.

    Parameters
    ----------
    species : str, int, list, or None
        Species name(s), formula(s), index(es), or a preset keyword.
        If None, the current basis definition is returned unchanged.
    state : str, list of str, or None
        Physical state(s) of the species.
    logact : float, list of float, or None
        Log activities of the basis species. When a new basis is built and
        no activities are given, 0 is used for every species.
    delete : bool, default False
        If True (or if species is ""), clear both the basis and the species
        list.
    add : bool, default False
        If True, append to the existing basis instead of replacing it.
    messages : bool, default True
        If True, print informational messages about species lookup.
        If False, suppress all output (equivalent to R's suppressMessages()).
    global_state : bool, default True
        If True, store the basis definition in the global thermo().basis.
        If False, return the definition without storing it globally.

    Returns
    -------
    pd.DataFrame or None
        The resulting basis definition (from put_basis, mod_basis or
        preset_basis). On delete or query, the basis that was in place
        beforehand (None if there was none).

    Raises
    ------
    ValueError
        If species is an empty list or contains duplicate names.
    BasisError
        If add=True and one of the species is already in the basis.

    Examples
    --------
    >>> # Set up a simple basis
    >>> basis(["H2O", "CO2", "NH3"], logact=[0, -3, -4])

    >>> # Use a preset basis
    >>> basis("CHNOS")

    >>> # Add species to existing basis
    >>> basis("Fe2O3", add=True)

    >>> # Delete basis
    >>> basis(delete=True)

    >>> # Suppress messages
    >>> basis("CHNOS", messages=False)
    """
    thermo_obj = thermo()
    current = thermo_obj.basis

    # Deletion: wipe basis and species together, hand back the old basis.
    if delete or species == "":
        thermo_obj.basis = None
        thermo_obj.species = None
        return current

    # Query only.
    if species is None:
        return current

    # Preset keywords (e.g. "CHNOS") are handled entirely by preset_basis.
    if isinstance(species, str) and species in _get_preset_basis_keywords():
        return preset_basis(species, messages=messages, global_state=global_state)

    if isinstance(species, list):
        if not species:
            raise ValueError("species argument is empty")
        if len({str(s) for s in species}) != len(species):
            raise ValueError("species names are not unique")

    species, state, logact = _process_basis_arguments(species, state, logact)
    species, logact = _handle_special_species(species, logact)

    # Changing only the state/logact of species already in the basis.
    only_existing = (current is not None and not add
                     and _all_species_in_basis(species, current))
    if only_existing and (state is not None or logact is not None):
        return mod_basis(species, state, logact, messages=messages)

    if logact is None:
        logact = [0.0] * len(species)

    ispecies = _get_species_indices(species, state, messages=messages)

    if add and current is not None:
        known = current['ispecies'].tolist()
        for pos, idx in enumerate(ispecies):
            if idx in known:
                raise BasisError(f"Species {str(species[pos])} is already in the basis definition")
        # existing basis entries come first, then the new ones
        ispecies = known + ispecies
        logact = current['logact'].tolist() + logact

    new_basis = put_basis(ispecies, logact, global_state=global_state)

    # The global species list is only touched when working on global state.
    if global_state:
        if add and thermo_obj.species is not None:
            _update_species_for_added_basis(current, new_basis)
        else:
            # the basis changed, so the old species list no longer applies
            from .species import species as species_func
            species_func(delete=True)

    return new_basis

Set up the basis species of a thermodynamic system.

Parameters

species : str, int, list, or None
Species name(s), formula(s), or index(es), or preset keyword. If None, returns current basis definition.
state : str, list of str, or None
Physical state(s) for the species
logact : float, list of float, or None
Log activities for the basis species
delete : bool, default False
If True, delete the basis definition
add : bool, default False
If True, add to existing basis instead of replacing
messages : bool, default True
If True, print informational messages about species lookup. If False, suppress all output (equivalent to R's suppressMessages()).
global_state : bool, default True
If True, store basis definition in global thermo().basis (default behavior). If False, return basis definition without storing globally (local state).

Returns

pd.DataFrame or None
Basis species definition DataFrame, or None if deleted

Examples

>>> # Set up a simple basis
>>> basis(["H2O", "CO2", "NH3"], logact=[0, -3, -4])
>>> # Use a preset basis
>>> basis("CHNOS")
>>> # Add species to existing basis
>>> basis("Fe2O3", add=True)
>>> # Delete basis
>>> basis(delete=True)
>>> # Suppress messages
>>> basis("CHNOS", messages=False)
def calc_G_TP(OBIGT, Tc, P, water_model)
Expand source code
def calc_G_TP(OBIGT, Tc, P, water_model):
    """
    Compute Gibbs energies of the species in an OBIGT table at one T-P point.

    HKF results take precedence; where they are missing, the CGL result is
    used. The combined values are assigned to a "G_TP" column on the
    ``OBIGT`` frame that was passed in (the caller's frame is modified).
    If no species named "H2O" or "H+" is present, a row is appended for
    each: water with G from water() (converted from J/mol to cal/mol), and
    the proton with G_TP = 0.

    Parameters
    ----------
    OBIGT : pd.DataFrame
        Species parameter table; must contain a "name" column.
    Tc : float
        Temperature in degrees Celsius.
    P : float
        Pressure, passed unchanged to hkf(), cgl() and water().
    water_model : str
        Water model name, activated with water(water_model).

    Returns
    -------
    tuple of (pd.DataFrame, int)
        The table with the "G_TP" column (plus any appended rows), and the
        number of rows appended at its end.
    """
    T_K = 273.15 + Tc

    hkf_out, _H2O_Pt = hkf(property=["G"], parameters=OBIGT,
                           T=T_K, P=P, contrib=["n", "s", "o"],
                           H2O_props=["rho"], water_model=water_model)
    cgl_out = cgl(property=["G"], parameters=OBIGT, T=T_K, P=P)

    G_both = pd.concat(
        [pd.DataFrame.from_dict(out, orient="index") for out in (hkf_out, cgl_out)],
        axis=1,
    )
    G_both.columns = ['aq', 'cgl']
    # prefer the HKF value, fall back to CGL where HKF has none
    OBIGT["G_TP"] = G_both['aq'].combine_first(G_both['cgl'])

    def _append(frame, name, G_value):
        # one extra row carrying only name, tag and G_TP
        extra = pd.DataFrame({"name": name, "tag": "nan", "G_TP": G_value},
                             index=[frame.shape[0]])
        return pd.concat([frame, extra], ignore_index=True)

    n_appended = 0

    if "H2O" not in list(OBIGT["name"]):
        # set the water model quietly, then get G of water in J/mol
        water(water_model, messages=False)
        G_water = water("G", T=Tc+273.15, P=P, messages=False)
        # water() may hand back either a DataFrame or a scalar
        G_joule = G_water.iloc[0]["G"] if isinstance(G_water, pd.DataFrame) else float(G_water)
        OBIGT = _append(OBIGT, "H2O", G_joule / 4.184)
        n_appended += 1

    if "H+" not in list(OBIGT["name"]):
        OBIGT = _append(OBIGT, "H+", 0)
        n_appended += 1

    return OBIGT, n_appended
def calc_logK(OBIGT_df, Tc, P, TP_i, water_model)
Expand source code
def calc_logK(OBIGT_df, Tc, P, TP_i, water_model):
    """
    Add a column of dissociation-reaction logK values at one T-P condition.

    Gibbs energies at (Tc, P) are computed with calc_G_TP(). Then
    dissrxn2logK() is evaluated for every row, and the results are stored
    in a column named "dissrxn_logK_<TP_i>". Any rows that calc_G_TP()
    appended (H2O / H+) are dropped again before returning. When no rows
    were appended, the returned frame is the same object as OBIGT_df,
    modified in place.

    Parameters
    ----------
    OBIGT_df : pd.DataFrame
        Species parameter table.
    Tc : float
        Temperature in degrees Celsius.
    P : float
        Pressure, passed through to calc_G_TP().
    TP_i : int or str
        Label of this T-P condition, used in the new column name.
    water_model : str
        Water model name, passed through to calc_G_TP().

    Returns
    -------
    pd.DataFrame
        The table with the "G_TP" and "dissrxn_logK_<TP_i>" columns.
    """
    table, n_extra = calc_G_TP(OBIGT_df, Tc, P, water_model)

    logK_values = [dissrxn2logK(table, idx, Tc) for idx in table.index]
    assert len(logK_values) == table.shape[0]

    table['dissrxn_logK_' + str(TP_i)] = logK_values

    # the helper rows were appended at the end by calc_G_TP
    table.drop(table.tail(n_extra).index, inplace=True)

    return table
def cgl(property=None, parameters=None, T=298.15, P=1)
Expand source code
def cgl(property = None, parameters = None, T = 298.15, P = 1):
    """
    Calculate properties of crystalline, liquid (except H2O) and gas species.

    Parameters
    ----------
    property : list of str
        Names of the properties to calculate. The regular CGL branch
        recognizes "G", "H", "S", "Cp", "V", "E" and "kT" ("E" and "kT" are
        returned as NaN there).
    parameters : pd.DataFrame
        OBIGT-style parameter rows, one species per row. Columns are read by
        name: "name", "state", "G", "H", "S", "Cp", "V", "a1.a", "a2.b",
        "a3.c", "a4.d", "c1.e", "c2.f", "omega.lambda" (and "z.T" for the
        Berman test).
    T : float or array-like, default 298.15
        Temperature in K. Integer input is accepted and promoted to float.
    P : float or array-like, default 1
        Pressure in bar. Integer input is accepted and promoted to float.

    Returns
    -------
    dict
        Keyed by each row's index label (labels rather than names, because
        names repeat, e.g. for cr polymorphs). Each value maps property name
        to an array with one entry per T-P condition. Rows with state "aq"
        get scalar NaNs instead, because they are handled by HKF. Energies
        are in calories, the units of the OBIGT parameters.

    Notes
    -----
    A row without G, H or S whose heat capacity, lambda and transition
    coefficients are all zero or NA is treated as a Berman mineral and is
    computed with ``Berman()``. Its J-based energies are converted to cal.
    If that calculation fails, a warning is issued and NaNs are returned
    for the row. The ``parameters`` frame itself is never modified.
    """
    import warnings

    Tr = 298.15
    Pr = 1
    # heat capacity coefficient columns (a, b, c, d, e, f)
    hc_cols = ['a1.a', 'a2.b', 'a3.c', 'a4.d', 'c1.e', 'c2.f']
    # all zero/NA for minerals whose parameters are in the Berman data files
    berman_cols = hc_cols + ['omega.lambda', 'z.T']
    # in CHNOSZ, we have
    # 1 cm^3 bar --> convert(1, "calories") == 0.02390057 cal
    # but REAC92D.F in SUPCRT92 uses
    cm3bar_to_cal = 0.023901488 # cal
    # Berman() returns these in J-based units; they are converted to cal
    energy_props = ['G', 'H', 'S', 'Cp']
    is_numeric = pd.api.types.is_numeric_dtype

    # Convert T and P to arrays for vectorized operations. Integer input is
    # promoted to float first: NumPy raises on negative integer powers of
    # integer arrays (T**-2 below), and np.full_like would otherwise
    # truncate the broadcast value to the other array's integer dtype.
    T = np.atleast_1d(T)
    P = np.atleast_1d(P)
    if np.issubdtype(T.dtype, np.integer):
        T = T.astype(float)
    if np.issubdtype(P.dtype, np.integer):
        P = P.astype(float)

    # make T and P equal length
    if P.size < T.size:
        P = np.full_like(T, P[0] if P.size == 1 else P)
    if T.size < P.size:
        T = np.full_like(P, T[0] if T.size == 1 else T)

    n_conditions = T.size
    out_dict = dict()

    # Iterate by position so that duplicate index labels are handled properly
    for pos in range(len(parameters)):
        k = parameters.index[pos]
        # work on a copy: NA coefficients are zero-filled below, and that
        # must never write back into the caller's parameters
        PAR = parameters.iloc[pos].copy()

        if PAR["state"] == "aq":
            # aqueous species are processed by HKF instead
            out_dict[k] = {p: float('NaN') for p in property}
            continue

        # A mineral is only Berman if it LACKS standard thermodynamic data
        # (G, H, S). If G, H, S are present, use regular CGL even if the heat
        # capacity coefficients are all zero.
        has_standard_thermo = (pd.notna(PAR.get('G', np.nan))
                               and pd.notna(PAR.get('H', np.nan))
                               and pd.notna(PAR.get('S', np.nan)))
        all_coeffs_zero_or_na = all(pd.isna(PAR.get(col, np.nan)) or PAR.get(col, 0) == 0
                                    for col in berman_cols)
        is_berman_mineral = all_coeffs_zero_or_na and not has_standard_thermo

        if is_berman_mineral:
            # Use Berman equations (parameters not in thermo()$OBIGT)
            from .berman import Berman
            try:
                # Berman is already vectorized - pass T and P arrays directly
                properties_df = Berman(PAR["name"], T=T, P=P)
                values = {}
                for prop in property:
                    if prop in properties_df.columns:
                        prop_values = properties_df[prop].values
                        # Berman returns J/mol; CGL works in cal/mol
                        if prop in energy_props:
                            prop_values = prop_values / 4.184
                        values[prop] = prop_values
                    else:
                        values[prop] = np.full(n_conditions, float('NaN'))
            except Exception as e:
                # best effort: fall back to NaN, but do not fail silently
                warnings.warn(f"cgl: Berman calculation failed for {PAR['name']}: {e}")
                values = {prop: np.full(n_conditions, float('NaN')) for prop in property}
        else:
            # Use regular CGL equations
            values = dict()

            # Is at least one heat capacity coefficient available?
            # (NaN and non-numeric entries do not count)
            try:
                has_hc_coeffs = any(pd.notna(p) and p != 0
                                    for p in (PAR[col] for col in hc_cols)
                                    if is_numeric(type(p)))
            except (KeyError, TypeError, ValueError):
                has_hc_coeffs = False

            if has_hc_coeffs:
                # zero out any NA's among the coefficients (leave lambda and
                # T of transition alone)
                for col in hc_cols:
                    coeff = PAR[col]
                    if pd.isna(coeff) or not is_numeric(type(coeff)):
                        PAR[col] = 0.0
                # calculate the heat capacity and its integrals (vectorized)
                Cp = PAR["a1.a"] + PAR["a2.b"]*T + PAR["a3.c"]*T**-2 + PAR["a4.d"]*T**-0.5 + PAR["c1.e"]*T**2 + PAR["c2.f"]*T**PAR["omega.lambda"]
                intCpdT = PAR["a1.a"]*(T - Tr) + PAR["a2.b"]*(T**2 - Tr**2)/2 + PAR["a3.c"]*(1/T - 1/Tr)/-1 + PAR["a4.d"]*(T**0.5 - Tr**0.5)/0.5 + PAR["c1.e"]*(T**3-Tr**3)/3
                intCpdlnT = PAR["a1.a"]*np.log(T / Tr) + PAR["a2.b"]*(T - Tr) + PAR["a3.c"]*(T**-2 - Tr**-2)/-2 + PAR["a4.d"]*(T**-0.5 - Tr**-0.5)/-0.5  + PAR["c1.e"]*(T**2 - Tr**2)/2

                # do we also have the lambda parameter (Cp term with adjustable exponent on T)?
                if pd.notna(PAR["omega.lambda"]) and PAR["omega.lambda"] != 0:
                    # equations for lambda adapted from Helgeson et al., 1998 (doi:10.1016/S0016-7037(97)00219-6)
                    if PAR["omega.lambda"] == -1:
                        intCpdT = intCpdT + PAR["c2.f"]*np.log(T/Tr)
                    else:
                        intCpdT = intCpdT - PAR["c2.f"]*( T**(PAR["omega.lambda"] + 1) - Tr**(PAR["omega.lambda"] + 1) ) / (PAR["omega.lambda"] + 1)
                    intCpdlnT = intCpdlnT + PAR["c2.f"]*(T**PAR["omega.lambda"] - Tr**PAR["omega.lambda"]) / PAR["omega.lambda"]
            else:
                # constant heat capacity when coefficients are not available;
                # an NA Cp is taken as 0 (matching R CHNOSZ behavior)
                Cp_value = PAR["Cp"] if pd.notna(PAR["Cp"]) else 0.0
                Cp = np.full(n_conditions, Cp_value)
                intCpdT = Cp_value*(T - Tr)
                intCpdlnT = Cp_value*np.log(T / Tr)
                # the integrals are exactly 0 at Tr
                at_Tr = (T == Tr)
                intCpdT = np.where(at_Tr, 0, intCpdT)
                intCpdlnT = np.where(at_Tr, 0, intCpdlnT)

            # volume and its integrals (vectorized)
            if PAR["name"] in ["quartz", "coesite"]:
                qtz = quartz_coesite(PAR, T, P)
                V = qtz["V"]
                intVdP = qtz["intVdP"]
                intdVdTdP = qtz["intdVdTdP"]
            else:
                # for other minerals, volume is constant (Helgeson et al., 1978)
                V = np.full(n_conditions, PAR["V"])
                intdVdTdP = np.zeros(n_conditions)
                if pd.isna(PAR["V"]):
                    # if the volume is NA, set its integrals to zero
                    intVdP = np.zeros(n_conditions)
                else:
                    intVdP = PAR["V"]*(P - Pr) * cm3bar_to_cal

            # get the values of each of the requested thermodynamic properties
            for prop in property:
                if prop == "Cp": values["Cp"] = Cp
                if prop == "V": values["V"] = V
                if prop == "E": values["E"] = np.full(n_conditions, float('NaN'))
                if prop == "kT": values["kT"] = np.full(n_conditions, float('NaN'))
                if prop == "G":
                    # S * (T - Tr), set to 0 at Tr (in case S is NA)
                    Sterm = PAR["S"]*(T - Tr)
                    Sterm = np.where(T == Tr, 0, Sterm)
                    values["G"] = PAR["G"] - Sterm + intCpdT - T*intCpdlnT + intVdP
                if prop == "H":
                    values["H"] = PAR["H"] + intCpdT + intVdP - T*intdVdTdP
                if prop == "S": values["S"] = PAR["S"] + intCpdlnT - intdVdTdP

        # species are keyed by index label, not name, because of name repeats (e.g., cr polymorphs)
        out_dict[k] = values

    return out_dict
def convert(value: float | numpy.ndarray | List[float],
units: str,
T: float | numpy.ndarray = 298.15,
P: float | numpy.ndarray = 1,
pH: float | numpy.ndarray = 7,
logaH2O: float | numpy.ndarray = 0,
messages: bool = True) ‑> float | numpy.ndarray
Expand source code
def convert(value: Union[float, np.ndarray, List[float]],
            units: str,
            T: Union[float, np.ndarray] = 298.15,
            P: Union[float, np.ndarray] = 1,
            pH: Union[float, np.ndarray] = 7,
            logaH2O: Union[float, np.ndarray] = 0,
            messages: bool = True) -> Union[float, np.ndarray]:
    """
    Convert values to the specified units.

    This function converts thermodynamic values between different units commonly
    used in geochemistry.

    Parameters
    ----------
    value : float, ndarray, or list
        Value(s) to convert
    units : str
        Target units (case-insensitive). Options include:
        - Temperature: 'C', 'K'
        - Energy: 'J', 'cal'
        - Pressure: 'bar', 'MPa'
        - Thermodynamic: 'G', 'logK'
        - Electrochemical: 'Eh', 'pe', 'E0', 'logfO2'
        - Volume: 'cm3bar', 'joules'
    T : float or ndarray, default 298.15
        Temperature in K (for Eh/pe/logK conversions)
    P : float, ndarray, or 'Psat', default 1
        Pressure in bar (for E0/logfO2 conversions). The string 'Psat'
        (any case) is passed through to subcrt() as 'Psat'.
    pH : float or ndarray, default 7
        pH value (for E0/logfO2 conversions)
    logaH2O : float or ndarray, default 0
        Log activity of water (for E0/logfO2 conversions)
    messages : bool, default True
        Whether to print informational messages

    Returns
    -------
    float or ndarray
        Converted value(s). None if `value` is None. For unrecognized units
        a warning is issued and `value` is returned unchanged (as an array).

    Examples
    --------
    >>> float(convert(25, 'K'))  # Convert 25°C to K
    298.15
    >>> round(float(convert(1.0, 'pe', T=298.15)), 1)  # Convert 1 V Eh to pe
    16.9
    """

    if value is None:
        return None

    # Convert to numpy array for uniform handling
    value = np.asarray(value)
    T = np.asarray(T)
    P = np.asarray(P)
    pH = np.asarray(pH)
    logaH2O = np.asarray(logaH2O)

    units = units.lower()

    # Temperature conversions (C <-> K)
    if units in ['c', 'k']:
        CK = 273.15
        if units == 'k':
            return value + CK
        if units == 'c':
            return value - CK

    # Energy conversions (J <-> cal)
    elif units in ['j', 'cal']:
        Jcal = 4.184
        if units == 'j':
            return value * Jcal
        if units == 'cal':
            return value / Jcal

    # Gibbs energy <-> logK conversions
    elif units in ['g', 'logk']:
        # Gas constant (J K^-1 mol^-1)
        R = 8.314463  # NIST value
        if units == 'logk':
            return value / (-np.log(10) * R * T)
        if units == 'g':
            return value * (-np.log(10) * R * T)

    # Volume conversions (cm3bar <-> joules)
    elif units in ['cm3bar', 'joules']:
        if units == 'cm3bar':
            return value * 10
        if units == 'joules':
            return value / 10

    # Electrochemical potential conversions (Eh <-> pe)
    elif units in ['eh', 'pe']:
        R = 0.00831470  # Gas constant in kJ K^-1 mol^-1
        F = 96.4935     # Faraday constant in kJ V^-1 mol^-1
        if units == 'pe':
            return value * F / (np.log(10) * R * T)
        if units == 'eh':
            return value * (np.log(10) * R * T) / F

    # Pressure conversions (bar <-> MPa)
    elif units in ['bar', 'mpa']:
        barmpa = 10
        if units == 'mpa':
            return value / barmpa
        if units == 'bar':
            return value * barmpa

    # Eh <-> logfO2 conversions
    elif units in ['e0', 'logfo2']:
        # Calculate equilibrium constant for: H2O = 1/2 O2 + 2 H+ + 2 e-
        # P already went through np.asarray above, so P="Psat" (or ["Psat"])
        # arrives as a single-element (possibly 0-d) string array.
        # np.str_ is a subclass of str, so one isinstance check covers both.
        P_first = P.flat[0] if P.size == 1 else None
        P_is_psat = isinstance(P_first, str) and P_first.lower() == 'psat'

        if P_is_psat:
            P_arg = 'Psat'
            T_arg = np.atleast_1d(T)
            if len(T_arg) == 1:
                T_arg = float(T_arg[0])
            else:
                T_arg = T_arg.tolist()
        else:
            # Convert T and P to proper format for subcrt
            T_vals = np.atleast_1d(T)
            P_vals = np.atleast_1d(P)

            # subcrt needs lists for multiple T/P values
            if len(T_vals) > 1 or len(P_vals) > 1:
                T_arg = T_vals.tolist() if len(T_vals) > 1 else float(T_vals[0])
                P_arg = P_vals.tolist() if len(P_vals) > 1 else float(P_vals[0])
            else:
                T_arg = float(T_vals[0])
                P_arg = float(P_vals[0])

        supcrt_out = subcrt(['H2O', 'oxygen', 'H+', 'e-'],
                           [-1, 0.5, 2, 2],
                           T=T_arg, P=P_arg, convert=False, messages=messages, show=False)

        # Extract logK values
        if hasattr(supcrt_out.out, 'logK'):
            logK = supcrt_out.out.logK
        else:
            logK = supcrt_out.out['logK']

        # Convert to numpy array
        logK = np.asarray(logK)

        if units == 'logfo2':
            # Convert Eh to logfO2
            pe_value = convert(value, 'pe', T=T, messages=messages)
            return 2 * (logK + logaH2O + 2*pH + 2*pe_value)
        if units == 'e0':
            # Convert logfO2 to Eh
            pe_value = (-logK - 2*pH + value/2 - logaH2O) / 2
            return convert(pe_value, 'Eh', T=T, messages=messages)

    else:
        warnings.warn(f"convert: no conversion to {units} found")
        return value

Convert values to the specified units.

This function converts thermodynamic values between different units commonly used in geochemistry.

Parameters

value : float, ndarray, or list
Value(s) to convert
units : str
Target units. Options include: - Temperature: 'C', 'K' - Energy: 'J', 'cal' - Pressure: 'bar', 'MPa' - Thermodynamic: 'G', 'logK' - Electrochemical: 'Eh', 'pe', 'E0', 'logfO2' - Volume: 'cm3bar', 'joules'
T : float or ndarray, default 298.15
Temperature in K (for Eh/pe/logK conversions)
P : float or ndarray, default 1
Pressure in bar (for E0/logfO2 conversions)
pH : float or ndarray, default 7
pH value (for E0/logfO2 conversions)
logaH2O : float or ndarray, default 0
Log activity of water (for E0/logfO2 conversions)
messages : bool, default True
Whether to print informational messages

Returns

float or ndarray
Converted value(s)

Examples

>>> convert(25, 'K')  # Convert 25°C to K
298.15
>>> convert(1.0, 'pe', T=298.15)  # Convert 1V Eh to pe
16.9
def copy_plot(diagram_result: Dict[str, Any]) ‑> Dict[str, Any]
Expand source code
def copy_plot(diagram_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Create a deep copy of a diagram result, allowing independent modification.

    This function addresses a fundamental limitation in Python plotting libraries:
    matplotlib figure and axes objects are mutable, so passing them between
    functions causes modifications to affect all references. This function
    creates a true deep copy that can be modified independently.

    Parameters
    ----------
    diagram_result : dict
        Result dictionary from diagram(), which may contain 'fig' and 'ax' keys

    Returns
    -------
    dict
        A deep copy (via copy.deepcopy) of the diagram result; nested
        objects, including any figure and axes, are copied rather than shared

    Examples
    --------
    Manual copying workflow (advanced usage - normally use add_to parameter instead):

    >>> import pychnosz
    >>> # Create base plot (Plot A)
    >>> basis(['SiO2', 'Ca+2', 'Mg+2', 'CO2', 'H2O', 'O2', 'H+'])
    >>> species(['quartz', 'talc', 'chrysotile', 'forsterite'])
    >>> a = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
    >>> plot_a = diagram(a, fill='terrain')
    >>>
    >>> # Manual approach: create copies first, then modify the axes directly
    >>> plot_a1 = copy_plot(plot_a)  # For modification 1
    >>> plot_a2 = copy_plot(plot_a)  # For modification 2
    >>> # ... then modify plot_a1['ax'] and plot_a2['ax'] directly
    >>>
    >>> # Recommended approach: use add_to parameter instead
    >>> # This automatically handles copying internally
    >>> basis('CO2', -1)
    >>> species(['calcite', 'dolomite'])
    >>> a2 = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
    >>> plot_a1 = diagram(a2, type='saturation', add_to=plot_a, col='blue')
    >>> plot_a2 = diagram(a2, type='saturation', add_to=plot_a, col='red')
    >>> # Now you have three independent plots: plot_a, plot_a1, plot_a2

    Notes
    -----
    - This function uses copy.deepcopy() which works well for matplotlib figures
    - For very large plots, copying may be memory-intensive
    - Interactive plots (plotly) may not copy perfectly - test before relying on this
    - The copied plot is fully independent and can be saved, displayed, or modified
      without affecting the original

    Limitations
    -----------
    Python's matplotlib (unlike R's base graphics) uses mutable objects for plots.
    Without explicit copying, all references point to the same plot. This is a
    known limitation of matplotlib that this function works around.

    See Also
    --------
    diagram : Create plots that can be copied with this function
    """
    return copy.deepcopy(diagram_result)

Create a deep copy of a diagram result, allowing independent modification.

This function addresses a fundamental limitation in Python plotting libraries: matplotlib figure and axes objects are mutable, so passing them between functions causes modifications to affect all references. This function creates a true deep copy that can be modified independently.

Parameters

diagram_result : dict
Result dictionary from diagram(), which may contain 'fig' and 'ax' keys

Returns

dict
A deep copy of the diagram result with independent figure and axes objects

Examples

Manual copying workflow (advanced usage - normally use add_to parameter instead):

>>> import pychnosz
>>> # Create base plot (Plot A)
>>> basis(['SiO2', 'Ca+2', 'Mg+2', 'CO2', 'H2O', 'O2', 'H+'])
>>> species(['quartz', 'talc', 'chrysotile', 'forsterite'])
>>> a = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
>>> plot_a = diagram(a, fill='terrain')
>>>
>>> # Manual approach: create copies first, then modify the axes directly
>>> plot_a1 = copy_plot(plot_a)  # For modification 1
>>> plot_a2 = copy_plot(plot_a)  # For modification 2
>>> # ... then modify plot_a1['ax'] and plot_a2['ax'] directly
>>>
>>> # Recommended approach: use add_to parameter instead
>>> # This automatically handles copying internally
>>> basis('CO2', -1)
>>> species(['calcite', 'dolomite'])
>>> a2 = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
>>> plot_a1 = diagram(a2, type='saturation', add_to=plot_a, col='blue')
>>> plot_a2 = diagram(a2, type='saturation', add_to=plot_a, col='red')
>>> # Now you have three independent plots: plot_a, plot_a1, plot_a2

Notes

  • This function uses copy.deepcopy() which works well for matplotlib figures
  • For very large plots, copying may be memory-intensive
  • Interactive plots (plotly) may not copy perfectly - test before relying on this
  • The copied plot is fully independent and can be saved, displayed, or modified without affecting the original

Limitations

Python's matplotlib (unlike R's base graphics) uses mutable objects for plots. Without explicit copying, all references point to the same plot. This is a known limitation of matplotlib that this function works around.

See Also

diagram()
Create plots that can be copied with this function
def describe_basis(ibasis: list = None,
digits: int = 1,
oneline: bool = False,
molality: bool = False,
use_pH: bool = True) ‑> list
Expand source code
def describe_basis(ibasis: list = None, digits: int = 1,
                   oneline: bool = False, molality: bool = False,
                   use_pH: bool = True) -> list:
    """
    Create formatted text describing basis species activities/fugacities.

    This function generates formatted strings for displaying the chemical
    activities or fugacities of basis species, typically for plot legends.

    Parameters
    ----------
    ibasis : list of int, optional
        Indices of basis species to describe (1-based). If None, describes all.
    digits : int, default 1
        Number of decimal places to display
    oneline : bool, default False
        If True, combine all species on one line (not yet implemented;
        currently ignored, one string per species is always returned)
    molality : bool, default False
        If True, use molality (m) instead of activity (a)
    use_pH : bool, default True
        If True, display H+ as pH instead of log a_H+

    Returns
    -------
    list of str
        Formatted basis species descriptions, one per entry of ``ibasis``.
        Numeric values are rounded to ``digits`` decimals; a value that
        rounds to zero is shown without a minus sign (e.g. "pH = 0.0").
        Non-numeric values (buffers) are shown as stored.

    Raises
    ------
    RuntimeError
        If basis species are not defined.
    IndexError
        If an entry of ``ibasis`` is not between 1 and the number of basis
        species.

    Examples
    --------
    >>> from pychnosz.core.basis import basis
    >>> basis(["H2O", "H+", "O2"], [-10, -7, -80])
    >>> describe_basis([2, 3])
    ['pH = 7.0', 'log $f_{O_2}$ = -80.0']

    >>> describe_basis()  # All basis species
    ['log $a_{H_2O}$ = -10.0', 'pH = 7.0', 'log $f_{O_2}$ = -80.0']

    Notes
    -----
    This is used to create legend entries showing the basis species
    activities used in thermodynamic calculations.
    """
    from ..core.basis import get_basis

    basis_df = get_basis()
    if basis_df is None:
        raise RuntimeError("Basis species are not defined")

    nbasis = len(basis_df)

    # Default to all basis species
    if ibasis is None:
        ibasis = list(range(1, nbasis + 1))

    def _format_number(x: float) -> str:
        # Round, then print exactly `digits` decimals. A result that rounds to
        # zero is replaced by +0.0 so it prints as "0.0" rather than "-0.0".
        rounded = round(x, digits)
        if rounded == 0:
            rounded = 0.0
        return format(rounded, f'.{digits}f')

    descriptions = []

    for idx in ibasis:
        # ibasis is 1-based. Without this check, 0 or negative entries would
        # become negative positions and silently select species from the end.
        if not 1 <= idx <= nbasis:
            raise IndexError(
                f"basis species index {idx} is out of range (1 to {nbasis})"
            )
        i = idx - 1
        species_name = basis_df.index[i]
        state = basis_df.iloc[i]['state']
        logact = basis_df.iloc[i]['logact']

        # A logact that cannot be converted to float is treated as a buffer name.
        try:
            logact_val = float(logact)
        except (ValueError, TypeError):
            logact_val = None

        # H+ is reported as pH (the negated log activity) when requested
        if species_name == "H+" and use_pH:
            if logact_val is None:
                descriptions.append(f"pH = {logact}")
            else:
                descriptions.append(f"pH = {_format_number(-logact_val)}")
            continue

        # Activity (or molality) for aq/liq/cr states, fugacity otherwise
        if state in ('aq', 'liq', 'cr'):
            a_or_f = "m" if molality else "a"
        else:
            a_or_f = "f"

        species_formatted = _format_species_latex(species_name)

        if logact_val is None:
            # For buffers, just show the buffer name
            descriptions.append(f"${a_or_f}_{{{species_formatted}}}$ = {logact}")
        else:
            descriptions.append(
                f"log ${a_or_f}_{{{species_formatted}}}$ = {_format_number(logact_val)}"
            )

    return descriptions

Create formatted text describing basis species activities/fugacities.

This function generates formatted strings for displaying the chemical activities or fugacities of basis species, typically for plot legends.

Parameters

ibasis : list of int, optional
Indices of basis species to describe (1-based). If None, describes all.
digits : int, default 1
Number of decimal places to display
oneline : bool, default False
If True, combine all species on one line (not yet implemented; currently ignored)
molality : bool, default False
If True, use molality (m) instead of activity (a)
use_pH : bool, default True
If True, display H+ as pH instead of log a_H+

Returns

list of str
Formatted basis species descriptions

Examples

>>> from pychnosz.core.basis import basis
>>> basis(["H2O", "H+", "O2"], [-10, -7, -80])
>>> describe_basis([2, 3])
['pH = 7.0', 'log $f_{O_2}$ = -80.0']
>>> describe_basis()  # All basis species
['log $a_{H_2O}$ = -10.0', 'pH = 7.0', 'log $f_{O_2}$ = -80.0']

Notes

This is used to create legend entries showing the basis species activities used in thermodynamic calculations.

def describe_basis_html(ibasis: list = None,
digits: int = 1,
oneline: bool = False,
molality: bool = False,
use_pH: bool = True) ‑> list
Expand source code
def describe_basis_html(ibasis: list = None, digits: int = 1,
                        oneline: bool = False, molality: bool = False,
                        use_pH: bool = True) -> list:
    """
    Create HTML-formatted text describing basis species (for Plotly).

    This function generates HTML-formatted strings for displaying the chemical
    activities or fugacities of basis species, typically for plot legends in
    interactive diagrams.

    Parameters
    ----------
    ibasis : list of int, optional
        Indices of basis species to describe (1-based). If None, describes all.
    digits : int, default 1
        Number of decimal places to display
    oneline : bool, default False
        If True, combine all species on one line (not yet implemented;
        currently ignored, one string per species is always returned)
    molality : bool, default False
        If True, use molality (m) instead of activity (a)
    use_pH : bool, default True
        If True, display H+ as pH instead of log a_H+

    Returns
    -------
    list of str
        HTML-formatted basis species descriptions, one per entry of
        ``ibasis``. Numeric values are rounded to ``digits`` decimals; a
        value that rounds to zero is shown without a minus sign. Non-numeric
        values (buffers) are shown as stored.

    Raises
    ------
    ImportError
        If the optional HTML dependencies (WORMutils) are not available.
    RuntimeError
        If basis species are not defined.
    IndexError
        If an entry of ``ibasis`` is not between 1 and the number of basis
        species.

    Examples
    --------
    >>> from pychnosz.core.basis import basis
    >>> basis(["H2O", "H+", "O2"], [-10, -7, -80])
    >>> describe_basis_html([2, 3])
    ['pH = 7.0', 'log <i>f</i><sub>O<sub>2</sub></sub> = -80.0']

    >>> describe_basis_html([4])  # assumes CO2 (gas) was added as a 4th basis species
    ['log <i>f</i><sub>CO<sub>2</sub></sub> = -1.0']

    Notes
    -----
    Use this instead of describe_basis() when creating legends for
    interactive (Plotly) diagrams.
    """
    if not _HTML_DEPS_AVAILABLE:
        raise ImportError(
            "describe_basis_html() requires 'WORMutils' package.\n"
            "Install with: pip install WORMutils"
        )

    from ..core.basis import get_basis

    basis_df = get_basis()
    if basis_df is None:
        raise RuntimeError("Basis species are not defined")

    nbasis = len(basis_df)

    # Default to all basis species
    if ibasis is None:
        ibasis = list(range(1, nbasis + 1))

    def _format_number(x: float) -> str:
        # Round, then print exactly `digits` decimals. A result that rounds to
        # zero is replaced by +0.0 so it prints as "0.0" rather than "-0.0".
        rounded = round(x, digits)
        if rounded == 0:
            rounded = 0.0
        return format(rounded, f'.{digits}f')

    descriptions = []

    for idx in ibasis:
        # ibasis is 1-based. Without this check, 0 or negative entries would
        # become negative positions and silently select species from the end.
        if not 1 <= idx <= nbasis:
            raise IndexError(
                f"basis species index {idx} is out of range (1 to {nbasis})"
            )
        i = idx - 1
        species_name = basis_df.index[i]
        state = basis_df.iloc[i]['state']
        logact = basis_df.iloc[i]['logact']

        # A logact that cannot be converted to float is treated as a buffer name.
        try:
            logact_val = float(logact)
        except (ValueError, TypeError):
            logact_val = None

        # H+ is reported as pH (the negated log activity) when requested
        if species_name == "H+" and use_pH:
            if logact_val is None:
                descriptions.append(f"pH = {logact}")
            else:
                descriptions.append(f"pH = {_format_number(-logact_val)}")
            continue

        # Activity (or molality) for aq/liq/cr states, fugacity otherwise
        if state in ('aq', 'liq', 'cr'):
            a_or_f = "m" if molality else "a"
        else:
            a_or_f = "f"

        # Format the species name using HTML
        species_formatted = chemlabel(species_name)

        if logact_val is None:
            # For buffers, just show the buffer name
            descriptions.append(f"<i>{a_or_f}</i><sub>{species_formatted}</sub> = {logact}")
        else:
            descriptions.append(
                f"log <i>{a_or_f}</i><sub>{species_formatted}</sub> = {_format_number(logact_val)}"
            )

    return descriptions

Create HTML-formatted text describing basis species (for Plotly).

This function generates HTML-formatted strings for displaying the chemical activities or fugacities of basis species, typically for plot legends in interactive diagrams.

Parameters

ibasis : list of int, optional
Indices of basis species to describe (1-based). If None, describes all.
digits : int, default 1
Number of decimal places to display
oneline : bool, default False
If True, combine all species on one line (not yet implemented; currently ignored)
molality : bool, default False
If True, use molality (m) instead of activity (a)
use_pH : bool, default True
If True, display H+ as pH instead of log a_H+

Returns

list of str
HTML-formatted basis species descriptions

Examples

>>> from pychnosz.core.basis import basis
>>> basis(["H2O", "H+", "O2"], [-10, -7, -80])
>>> describe_basis_html([2, 3])
['pH = 7.0', 'log <i>f</i><sub>O<sub>2</sub></sub> = -80.0']
>>> describe_basis_html([4])  # assumes CO2 (gas) was added as a 4th basis species
['log <i>f</i><sub>CO<sub>2</sub></sub> = -1.0']

Notes

Use this instead of describe_basis() when creating legends for interactive (Plotly) diagrams.

def describe_property(property: list = None,
value: list = None,
digits: int = 0,
oneline: bool = False,
ret_val: bool = False) ‑> list
Expand source code
def describe_property(property: list = None, value: list = None,
                      digits: int = 0, oneline: bool = False,
                      ret_val: bool = False) -> list:
    """
    Create formatted text describing thermodynamic properties and their values.

    This function generates formatted strings for displaying property-value pairs
    in legends, typically for temperature, pressure, and other conditions.

    Parameters
    ----------
    property : list of str
        Property names (e.g., ["T", "P"])
    value : list
        Property values (e.g., [300, 1000]). For "T" and "P", the string
        "Psat" or "NA" is rendered as $P_{sat}$.
    digits : int, default 0
        Number of decimal places to display
    oneline : bool, default False
        If True, combine all properties on one line (not implemented;
        currently ignored)
    ret_val : bool, default False
        If True, return only values with units (not property names)

    Returns
    -------
    list of str
        Formatted property descriptions, one per property. Numbers that round
        to zero are shown without a minus sign (e.g. "0", not "-0").

    Raises
    ------
    ValueError
        If ``property`` or ``value`` is None, or a value is not numeric
        (other than "Psat"/"NA" for T and P).
    IndexError
        If ``value`` has fewer entries than ``property``.

    Examples
    --------
    >>> describe_property(["T", "P"], [300, 1000])
    ['$T$ = 300 °C', '$P$ = 1000 bar']

    >>> describe_property(["T"], [25], digits=1)
    ['$T$ = 25.0 °C']

    Notes
    -----
    This is used to create legend entries showing the conditions
    used in thermodynamic calculations.
    """
    if property is None or value is None:
        raise ValueError("property or value is None")

    # Symbols written as plain text; every other property becomes "$<prop>$".
    plain_symbols = ("pH", "Eh")
    # Units appended to numeric values; properties not listed are unitless.
    units = {"T": "°C", "P": "bar", "Eh": "V"}

    def _format_number(val) -> str:
        # Round, then print exactly `digits` decimals. A result that rounds to
        # zero is replaced by +0.0 so it prints as "0" rather than "-0".
        rounded = round(float(val), digits)
        if rounded == 0:
            rounded = 0.0
        return format(rounded, f'.{digits}f')

    descriptions = []

    for i, prop in enumerate(property):
        # Index into `value` (rather than zip) so that a too-short `value`
        # list still raises IndexError instead of silently dropping entries.
        val = value[i]

        prop_str = prop if prop in plain_symbols else f"${prop}$"

        if prop in ("T", "P") and (val == "Psat" or val == "NA"):
            val_str = "$P_{sat}$"
        else:
            val_str = _format_number(val)
            unit = units.get(prop)
            if unit is not None:
                val_str = f"{val_str} {unit}"

        descriptions.append(val_str if ret_val else f"{prop_str} = {val_str}")

    return descriptions

Create formatted text describing thermodynamic properties and their values.

This function generates formatted strings for displaying property-value pairs in legends, typically for temperature, pressure, and other conditions.

Parameters

property : list of str
Property names (e.g., ["T", "P"])
value : list
Property values (e.g., [300, 1000])
digits : int, default 0
Number of decimal places to display
oneline : bool, default False
If True, combine all properties on one line (not implemented)
ret_val : bool, default False
If True, return only values with units (not property names)

Returns

list of str
Formatted property descriptions

Examples

>>> describe_property(["T", "P"], [300, 1000])
['$T$ = 300 °C', '$P$ = 1000 bar']
>>> describe_property(["T"], [25], digits=1)
['$T$ = 25.0 °C']

Notes

This is used to create legend entries showing the conditions used in thermodynamic calculations.

def describe_property_html(property: list = None,
value: list = None,
digits: int = 0,
oneline: bool = False,
ret_val: bool = False) ‑> list
Expand source code
def describe_property_html(property: list = None, value: list = None,
                           digits: int = 0, oneline: bool = False,
                           ret_val: bool = False) -> list:
    """
    Create HTML-formatted text describing thermodynamic properties (for Plotly).

    This function generates HTML-formatted strings for displaying thermodynamic
    properties and their values, typically for plot legends in interactive diagrams.

    Parameters
    ----------
    property : list of str
        Property names (e.g., ["T", "P"])
    value : list
        Property values. For "T" and "P", the string "Psat" or "NA" is
        rendered as <i>P</i><sub>sat</sub>.
    digits : int, default 0
        Number of decimal places to display
    oneline : bool, default False
        If True, format on one line (not implemented; currently ignored)
    ret_val : bool, default False
        If True, return only values without property names

    Returns
    -------
    list of str
        HTML-formatted property descriptions, one per property. Numbers that
        round to zero are shown without a minus sign (e.g. "0", not "-0").

    Raises
    ------
    ValueError
        If ``property`` or ``value`` is None, or a value is not numeric
        (other than "Psat"/"NA" for T and P).
    IndexError
        If ``value`` has fewer entries than ``property``.

    Examples
    --------
    >>> describe_property_html(["T", "P"], [300, 1000])
    ['<i>T</i> = 300 °C', '<i>P</i> = 1000 bar']

    Notes
    -----
    Use this instead of describe_property() when creating legends for
    interactive (Plotly) diagrams.
    """
    if property is None or value is None:
        raise ValueError("property or value is None")

    # Symbols written as plain text; every other property is italicized.
    plain_symbols = ("pH", "Eh")
    # Units appended to numeric values; properties not listed are unitless.
    units = {"T": "°C", "P": "bar", "Eh": "V"}

    def _format_number(val) -> str:
        # Round, then print exactly `digits` decimals. A result that rounds to
        # zero is replaced by +0.0 so it prints as "0" rather than "-0".
        rounded = round(float(val), digits)
        if rounded == 0:
            rounded = 0.0
        return format(rounded, f'.{digits}f')

    descriptions = []

    for i, prop in enumerate(property):
        # Index into `value` (rather than zip) so that a too-short `value`
        # list still raises IndexError instead of silently dropping entries.
        val = value[i]

        prop_str = prop if prop in plain_symbols else f"<i>{prop}</i>"

        if prop in ("T", "P") and (val == "Psat" or val == "NA"):
            val_str = "<i>P</i><sub>sat</sub>"
        else:
            val_str = _format_number(val)
            unit = units.get(prop)
            if unit is not None:
                val_str = f"{val_str} {unit}"

        descriptions.append(val_str if ret_val else f"{prop_str} = {val_str}")

    return descriptions

Create HTML-formatted text describing thermodynamic properties (for Plotly).

This function generates HTML-formatted strings for displaying thermodynamic properties and their values, typically for plot legends in interactive diagrams.

Parameters

property : list of str
Property names (e.g., ["T", "P"])
value : list
Property values
digits : int, default 0
Number of decimal places to display
oneline : bool, default False
If True, format on one line (not implemented)
ret_val : bool, default False
If True, return only values without property names

Returns

list of str
HTML-formatted property descriptions

Examples

>>> describe_property_html(["T", "P"], [300, 1000])
['<i>T</i> = 300 °C', '<i>P</i> = 1000 bar']

Notes

Use this instead of describe_property() when creating legends for interactive (Plotly) diagrams.

def diagram(eout: Dict[str, Any],
type: str = 'auto',
alpha: bool = False,
balance: str | float | List[float] | None = None,
names: List[str] | None = None,
format_names: bool = True,
xlab: str | None = None,
ylab: str | None = None,
xlim: List[float] | None = None,
ylim: List[float] | None = None,
col: str | List[str] | None = None,
col_names: str | List[str] | None = None,
lty: str | int | List | None = None,
lwd: float | List[float] = 1,
cex: float | List[float] = 1.0,
main: str | None = None,
fill: str | None = None,
fill_NA: str = '0.8',
limit_water: bool | None = None,
plot_it: bool = True,
add_to: Dict[str, Any] | None = None,
contour_method: str | List[str] | None = 'edge',
messages: bool = True,
interactive: bool = False,
annotation: str | None = None,
annotation_coords: List[float] = [0, 0],
width: int = 600,
height: int = 520,
save_as: str | None = None,
save_format: str | None = None,
save_scale: float = 1,
normalize: bool | List[bool] = False,
as_residue: bool = False,
**kwargs) ‑> Dict[str, Any]
Expand source code
def diagram(eout: Dict[str, Any],
            type: str = "auto",
            alpha: bool = False,
            balance: Optional[Union[str, float, List[float]]] = None,
            names: Optional[List[str]] = None,
            format_names: bool = True,
            xlab: Optional[str] = None,
            ylab: Optional[str] = None,
            xlim: Optional[List[float]] = None,
            ylim: Optional[List[float]] = None,
            col: Optional[Union[str, List[str]]] = None,
            col_names: Optional[Union[str, List[str]]] = None,
            lty: Optional[Union[str, int, List]] = None,
            lwd: Union[float, List[float]] = 1,
            cex: Union[float, List[float]] = 1.0,
            main: Optional[str] = None,
            fill: Optional[str] = None,
            fill_NA: str = "0.8",
            limit_water: Optional[bool] = None,
            plot_it: bool = True,
            add_to: Optional[Dict[str, Any]] = None,
            contour_method: Optional[Union[str, List[str]]] = "edge",
            messages: bool = True,
            interactive: bool = False,
            annotation: Optional[str] = None,
            annotation_coords: List[float] = [0, 0],
            width: int = 600,
            height: int = 520,
            save_as: Optional[str] = None,
            save_format: Optional[str] = None,
            save_scale: float = 1,
            normalize: Union[bool, List[bool]] = False,
            as_residue: bool = False,
            **kwargs) -> Dict[str, Any]:
    """
    Plot equilibrium chemical activity and predominance diagrams.

    This function creates plots from the output of affinity() or equilibrate().
    For 1D diagrams, it produces line plots showing how affinity or activity
    varies with a single variable. For 2D diagrams, it creates predominance
    field diagrams.

    Parameters
    ----------
    eout : dict
        Output from affinity() or equilibrate()
    type : str, default "auto"
        Type of diagram:
        - "auto" (default): Plot affinity values (A/2.303RT)
        - "loga.equil": Plot equilibrium activities from equilibrate()
        - "saturation": Draw affinity=0 contour lines (mineral saturation)
        - Basis species name (e.g., "O2", "H2O", "CO2"): Plot equilibrium
          log activity/fugacity of the specified basis species where affinity=0
          for each formed species. Useful for Eh-pH diagrams and showing
          oxygen/water fugacities at equilibrium.
    alpha : bool or str, default False
        Plot degree of formation instead of activities?
        If "balance", scale by balancing coefficients
    balance : str, float, or list of float, optional
        Balancing coefficients or method for balancing reactions
    names : list of str, optional
        Custom names for species (for labels)
    format_names : bool, default True
        Apply formatting to chemical formulas?
    xlab : str, optional
        Custom x-axis label
    ylab : str, optional
        Custom y-axis label
    xlim : list of float, optional
        X-axis limits [min, max]
    ylim : list of float, optional
        Y-axis limits [min, max]
    col : str or list of str, optional
        Line colors for 1-D plots and boundary lines in 2-D plots (matplotlib color specs)
    col_names : str or list of str, optional
        Text colors for field labels in 2-D plots (matplotlib color specs)
    lty : str, int, or list, optional
        Line styles (matplotlib linestyle specs)
    lwd : float or list of float, default 1
        Line widths for 1-D plots and boundary lines in 2-D predominance
        diagrams. Set to 0 to disable borders in 2-D diagrams. If fill is
        None and lwd > 0, uses white fill with black borders (R CHNOSZ default).
    cex : float or list of float, default 1.0
        Character expansion factor for text labels. Values > 1 make text larger,
        values < 1 make text smaller. Can be a single value or a list (one per species).
        Used for contour labels in type="saturation" plots.
    main : str, optional
        Plot title
    fill : str, optional
        Color palette for 2-D predominance diagrams. Can be any matplotlib
        colormap name (e.g., 'viridis', 'plasma', 'terrain', 'rainbow',
        'Set1', 'tab10', 'Pastel1'). If None, uses discrete colors from
        the default color cycle. Ignored for 1-D diagrams.
    fill_NA : str, default "0.8"
        Color for regions outside water stability limits (water instability regions).
        Matplotlib color specification (e.g., "0.8" for gray, "#CCCCCC").
        Set to "transparent" to disable shading. Default "0.8" matches R's "gray80".
    limit_water : bool, optional
        Whether to show water stability limits as shaded regions (default True for
        2-D diagrams). If True, also clips the diagram to the water stability region.
        Set to False to disable water stability shading.
    plot_it : bool, default True
        Display the plot?
    add_to : dict, optional
        A diagram result dictionary from a previous diagram() call. When provided,
        this plot will be AUTOMATICALLY COPIED and the new diagram will be added to
        the copy. This preserves the original plot while creating a modified version.
        The axes object is extracted from add_to['ax'].

        This parameter eliminates the need for a separate 'add' boolean - when
        add_to is provided, the function automatically operates in "add" mode.

        Example workflow:
        >>> plot_a = diagram(affinity1, fill='terrain')  # Create base plot
        >>> plot_a1 = diagram(affinity2, add_to=plot_a, col='blue')  # Copy and add
        >>> plot_a2 = diagram(affinity3, add_to=plot_a, col='red')   # Copy and add again
        >>> # plot_a remains unchanged, plot_a1 and plot_a2 are independent modifications
    contour_method : str or list of str, optional
        Method for labeling contour lines. Default "edge" labels at plot edges.
        Can be a single value (applied to all species) or a list (one per species).
        Set to None, NA, or "" to disable labels (only for type="saturation").
        In R CHNOSZ, different methods like "edge", "flattest", "simple" control
        label placement; in Python, this mainly controls whether labels are shown.
    interactive : bool, default False
        Create an interactive plot using Plotly instead of matplotlib?
        If True, calls diagram_interactive() with the appropriate parameters.
    annotation : str, optional
        For interactive plots only. Annotation text to add to the plot.
    annotation_coords : list of float, default [0, 0]
        For interactive plots only. Coordinates of annotation, where [0, 0] is
        bottom left and [1, 1] is top right.
    width : int, default 600
        For interactive plots only. Width of the plot in pixels.
    height : int, default 520
        For interactive plots only. Height of the plot in pixels.
    save_as : str, optional
        For interactive plots only. Provide a filename to save this figure.
        Filetype is determined by `save_format`.
    save_format : str, optional
        For interactive plots only. Desired format of saved or downloaded figure.
        Can be 'png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', or 'html'.
        If 'html', an interactive plot will be saved.
    save_scale : float, default 1
        For interactive plots only. Multiply title/legend/axis/canvas sizes by
        this factor when saving the figure.
    **kwargs
        Additional arguments passed to matplotlib plotting functions

    Returns
    -------
    dict
        Dictionary containing:
        - plotvar : str, Variable that was plotted
        - plotvals : dict, Values that were plotted
        - names : list, Names used for labels
        - predominant : array or NA, Predominance matrix (for 2D)
        - balance : str or list, Balancing coefficients used
        - n.balance : list, Numerical balancing coefficients
        - ax : matplotlib.axes.Axes, The axes object used for plotting (if plot_it=True)
        - fig : matplotlib.figure.Figure, The figure object used for plotting (if plot_it=True)
        - All original eout contents

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.basis(["Fe2O3", "CO2", "H2O", "NH3", "H2S", "oxygen", "H+"],
    ...              [0, -3, 0, -4, -7, -80, -7])
    >>> pychnosz.species(["pyrite", "goethite"])
    >>> a = pychnosz.affinity(H2S=[-60, 20, 5], T=25, P=1)
    >>> d = diagram(a)

    Notes
    -----
    This implementation is based on R CHNOSZ diagram() function but adapted
    for Python's matplotlib plotting instead of R's base graphics. The key
    differences from diagram_from_WORM.py are:
    - Works directly with Python dict output from affinity() (no rpy2)
    - Uses matplotlib for 1D plots by default
    - Can optionally use plotly if requested
    """

    # Handle add_to parameter: automatically copy the provided plot
    # This extracts the axes object and creates an independent copy
    # When add_to is provided, we're in "add" mode
    ax = None
    add = add_to is not None
    plot_was_provided = add

    if add_to is not None:
        # Make a deep copy of the provided plot to preserve the original
        plot_copy = copy_plot(add_to)
        # Extract the axes from the copied plot
        if 'ax' in plot_copy:
            ax = plot_copy['ax']
        else:
            raise ValueError("The 'add_to' parameter must contain an 'ax' key (a diagram result dictionary)")

    # If interactive mode is requested, delegate to diagram_interactive
    if interactive:
        df, fig = diagram_interactive(
            eout=eout,
            type=type,
            main=main,
            borders=lwd,
            names=names,
            format_names=format_names,
            annotation=annotation,
            annotation_coords=annotation_coords,
            balance=balance,
            xlab=xlab,
            ylab=ylab,
            fill=fill,
            width=width,
            height=height,
            alpha=alpha,
            plot_it=plot_it,
            add=add,
            ax=ax,
            col=col,
            lty=lty,
            lwd=lwd,
            cex=cex,
            contour_method=contour_method,
            save_as=save_as,
            save_format=save_format,
            save_scale=save_scale,
            messages=messages
        )
        # Return in a format compatible with diagram's normal output
        # diagram_interactive returns (df, fig), wrap in a dict for consistency
        # Include eout data so water_lines() can access vars, vals, basis, etc.
        result = {
            **eout,  # Include all original eout data
            'df': df,
            'fig': fig,
            'ax': fig  # For compatibility, store fig in ax key for add=True workflow
        }
        return result

    # Check that eout is valid
    efun = eout.get('fun', '')
    if efun not in ['affinity', 'equilibrate', 'solubility']:
        raise ValueError("'eout' is not the output from affinity(), equilibrate(), or solubility()")

    # Determine if eout is from affinity() (as opposed to equilibrate())
    # Check for both Python naming (loga_equil) and R naming (loga.equil)
    eout_is_aout = 'loga_equil' not in eout and 'loga.equil' not in eout

    # Check if type is a basis species name
    plot_loga_basis = False
    if type not in ["auto", "saturation", "loga.equil", "loga_equil", "loga.balance", "loga_balance"]:
        # Check if type matches a basis species name
        if 'basis' in eout:
            basis_species = list(eout['basis'].index) if hasattr(eout['basis'], 'index') else []
            if type in basis_species:
                plot_loga_basis = True
                if alpha:
                    raise ValueError("equilibrium activities of basis species not available with alpha = TRUE")

    # Handle type="saturation" - requires affinity output
    if type == "saturation":
        if not eout_is_aout:
            raise ValueError("type='saturation' requires output from affinity(), not equilibrate()")
        # Set eout_is_aout flag
        eout_is_aout = True

    # Get number of dimensions
    # Handle both dict (affinity) and list (equilibrate) values structures
    if isinstance(eout['values'], dict):
        first_values = list(eout['values'].values())[0]
    elif isinstance(eout['values'], list):
        first_values = eout['values'][0]
    else:
        first_values = eout['values']

    if hasattr(first_values, 'shape'):
        nd = len(first_values.shape)
    elif hasattr(first_values, '__len__'):
        nd = 1
    else:
        nd = 0  # Single value

    # For affinity output, get balancing coefficients
    if eout_is_aout and type == "auto":
        n_balance, balance = _get_balance(eout, balance, messages)
    elif eout_is_aout and type == "saturation":
        # For saturation diagrams, use n_balance = 1 for all species (don't normalize by stoichiometry)
        if isinstance(eout['values'], dict):
            n_balance = [1] * len(eout['values'])
        elif isinstance(eout['values'], list):
            n_balance = [1] * len(eout['values'])
        else:
            n_balance = [1]
        if balance is None:
            balance = 1
    else:
        # For equilibrate output, use n_balance from equilibrate if available
        if 'n_balance' in eout:
            n_balance = eout['n_balance']
            balance = eout.get('balance', 1)
        else:
            if isinstance(eout['values'], dict):
                n_balance = [1] * len(eout['values'])
            elif isinstance(eout['values'], list):
                n_balance = [1] * len(eout['values'])
            else:
                n_balance = [1]
            if balance is None:
                balance = 1

    # Determine what to plot
    plotvals = {}
    plotvar = eout.get('property', 'A')

    # Calculate equilibrium log activity/fugacity of basis species
    if plot_loga_basis:
        # Find the index of the basis species
        basis_df = eout['basis']
        ibasis = list(basis_df.index).index(type)

        # Get the logarithm of activity used in the affinity calculation
        logact = basis_df.iloc[ibasis]['logact']

        # Check if logact is numeric
        try:
            loga_basis = float(logact)
        except (ValueError, TypeError):
            raise ValueError(f"the logarithm of activity for basis species {type} is not numeric - was a buffer selected?")

        # Get the reaction coefficients for this basis species
        # eout['species'] is a DataFrame with basis species as columns
        nu_basis = eout['species'].iloc[:, ibasis].values

        # Calculate the logarithm of activity where affinity = 0
        # loga_equilibrium = loga_basis - affinity / nu_basis
        plotvals = {}
        for i, (sp_idx, affinity_vals) in enumerate(eout['values'].items()):
            plotvals[sp_idx] = loga_basis - affinity_vals / nu_basis[i]

        plotvar = type

        # Set n_balance (not used for basis species plots, but needed for compatibility)
        n_balance = [1] * len(plotvals)
        if balance is None:
            balance = 1
    elif eout_is_aout:
        # Plot affinity values divided by balancing coefficients
        # DEBUG: Check balance application
        if False:  # Set to True for debugging
            print(f"\nDEBUG: Applying balance to affinity values")
            print(f"  n_balance: {n_balance}")

        # Handle dict-based values (from affinity)
        if isinstance(eout['values'], dict):
            for i, (species_idx, values) in enumerate(eout['values'].items()):
                if False:  # Set to True for debugging
                    print(f"  Species {i} (ispecies {species_idx}): values/n_balance[{i}]={n_balance[i]}")
                plotvals[species_idx] = values / n_balance[i]
        # Handle list-based values
        elif isinstance(eout['values'], list):
            for i, values in enumerate(eout['values']):
                species_idx = eout['species']['ispecies'].iloc[i]
                plotvals[species_idx] = values / n_balance[i]

        if plotvar == 'A':
            plotvar = 'A/(2.303RT)'
            if nd == 1:
                if messages:
                    print(f"diagram: plotting {plotvar} / n.balance")
    else:
        # Plot equilibrated activities
        # Check for both Python naming (loga_equil) and R naming (loga.equil)
        loga_equil_key = 'loga_equil' if 'loga_equil' in eout else 'loga.equil'
        loga_equil_list = eout[loga_equil_key]

        # For equilibrate output, keep plotvals as a dict with INTEGER indices as keys
        # This preserves the 1:1 correspondence with the species list, including duplicates
        # Do NOT use ispecies as keys because duplicates would overwrite each other
        if isinstance(loga_equil_list, list):
            for i, loga_val in enumerate(loga_equil_list):
                plotvals[i] = loga_val  # Use integer index, not ispecies
        else:
            # Already a dict
            plotvals = loga_equil_list

        plotvar = 'loga.equil'

    # Handle alpha (degree of formation)
    if alpha:
        # Convert to activities (remove logarithms)
        # Use numpy arrays for proper element-wise operations
        act_vals = {}
        for k, v in plotvals.items():
            if isinstance(v, np.ndarray):
                act_vals[k] = 10**v
            else:
                act_vals[k] = np.power(10, v)

        # Scale by balance if requested
        if alpha == "balance":
            species_keys = list(act_vals.keys())
            for i, k in enumerate(species_keys):
                act_vals[k] = act_vals[k] * n_balance[i]

        # Calculate sum of activities (element-wise for arrays)
        # Get the first value to determine shape
        first_val = list(act_vals.values())[0]
        if isinstance(first_val, np.ndarray):
            # Multi-dimensional case
            sum_act = np.zeros_like(first_val)
            for v in act_vals.values():
                sum_act = sum_act + v
        else:
            # Single value case
            sum_act = sum(act_vals.values())

        # Calculate alpha (fraction) - element-wise division
        plotvals = {k: v / sum_act for k, v in act_vals.items()}
        plotvar = "alpha"

    # Get species information for labels
    species_df = eout['species']
    if names is None:
        names = species_df['name'].tolist()

    # Format chemical names if requested
    if format_names and not alpha:
        names = [_format_chemname(name) for name in names]

    # Prepare for plotting
    if nd == 0:
        # 0-D: Bar plot (not implemented yet)
        raise NotImplementedError("0-D bar plots not yet implemented")

    elif nd == 1:
        # 1-D: Line plot
        result = _plot_1d(eout, plotvals, plotvar, names, n_balance, balance,
                       xlab, ylab, xlim, ylim, col, lty, lwd, main, add, plot_it, ax, width, height, plot_was_provided, **kwargs)

    elif nd == 2:
        # 2-D: Predominance diagram or saturation lines
        # Pass lty and cex through kwargs for saturation plots
        result = _plot_2d(eout, plotvals, plotvar, names, n_balance, balance,
                       xlab, ylab, xlim, ylim, col, col_names, fill, fill_NA, limit_water, lwd, main, add, plot_it, ax,
                       type, contour_method, messages, width, height, plot_was_provided, lty=lty, cex=cex, **kwargs)

    else:
        raise ValueError(f"Cannot create diagram with {nd} dimensions")

    # Handle Jupyter display behavior
    # When plot_it=True, we want the figure to display
    # When plot_it=False, we want to suppress display and close the figure
    if not plot_it and result is not None and 'fig' in result:
        # Close the figure to prevent auto-display in Jupyter
        # The figure is still in the result dict, so users can access it via result['fig']
        # but it won't be displayed automatically
        plt.close(result['fig'])
    elif plot_it and result is not None and 'fig' in result:
        # Try to use IPython display if available (for Jupyter notebooks)
        try:
            from IPython.display import display
            display(result['fig'])
        except (ImportError, NameError):
            # Not in IPython/Jupyter, regular matplotlib display
            pass

    return result

Plot equilibrium chemical activity and predominance diagrams.

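This function creates plots from the output of affinity() or equilibrate(). For 1-D diagrams, it produces line plots showing how affinity or activity varies with a single variable. For 2-D diagrams, it creates predominance field diagrams, or affinity = 0 saturation lines when type="saturation". 0-D (single-point) diagrams are not yet implemented and raise NotImplementedError.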

Parameters

eout : dict
Output from affinity() or equilibrate()
type : str, default "auto"
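Type of diagram:
- "auto" (default): For affinity() output, plot affinity values A/(2.303RT) divided by the balancing coefficients. For equilibrate() output, plot the equilibrium log activities.
- "loga.equil": Plot equilibrium activities from equilibrate().
- "saturation": Draw affinity = 0 contour lines (mineral saturation). Requires output from affinity(), not equilibrate().
- Basis species name (e.g., "O2", "H2O", "CO2"): Plot the equilibrium log activity/fugacity of the specified basis species at which affinity = 0 for each formed species. Useful for Eh-pH diagrams and for showing oxygen/water fugacities at equilibrium. Not available together with alpha.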
alpha : bool or str, default False
Plot degree of formation instead of activities? If "balance", scale by balancing coefficients
balance : str, float, or list of float, optional
Balancing coefficients or method for balancing reactions
names : list of str, optional
Custom names for species (for labels)
format_names : bool, default True
Apply formatting to chemical formulas?
xlab : str, optional
Custom x-axis label
ylab : str, optional
Custom y-axis label
xlim : list of float, optional
X-axis limits [min, max]
ylim : list of float, optional
Y-axis limits [min, max]
col : str or list of str, optional
Line colors for 1-D plots and boundary lines in 2-D plots (matplotlib color specs)
col_names : str or list of str, optional
Text colors for field labels in 2-D plots (matplotlib color specs)
lty : str, int, or list, optional
Line styles (matplotlib linestyle specs)
lwd : float or list of float, default 1
Line widths for 1-D plots and boundary lines in 2-D predominance diagrams. Set to 0 to disable borders in 2-D diagrams. If fill is None and lwd > 0, uses white fill with black borders (R CHNOSZ default).
cex : float or list of float, default 1.0
Character expansion factor for text labels. Values > 1 make text larger, values < 1 make text smaller. Can be a single value or a list (one per species). Used for contour labels in type="saturation" plots.
main : str, optional
Plot title
fill : str, optional
Color palette for 2-D predominance diagrams. Can be any matplotlib colormap name (e.g., 'viridis', 'plasma', 'terrain', 'rainbow', 'Set1', 'tab10', 'Pastel1'). If None, uses discrete colors from the default color cycle. Ignored for 1-D diagrams.
fill_NA : str, default "0.8"
Color for regions outside water stability limits (water instability regions). Matplotlib color specification (e.g., "0.8" for gray, "#CCCCCC"). Set to "transparent" to disable shading. Default "0.8" matches R's "gray80".
limit_water : bool, optional
Whether to show water stability limits as shaded regions (default True for 2-D diagrams). If True, also clips the diagram to the water stability region. Set to False to disable water stability shading.
plot_it : bool, default True
Display the plot?
add_to : dict, optional

A diagram result dictionary from a previous diagram() call. When provided, the plot stored in add_to is automatically copied and the new diagram is drawn on the copy. This preserves the original plot while creating a modified version. The axes object is taken from add_to['ax'].

This parameter eliminates the need for a separate 'add' boolean - when add_to is provided, the function automatically operates in "add" mode.

Example workflow:

plot_a = diagram(affinity1, fill='terrain')                # Create base plot
plot_a1 = diagram(affinity2, add_to=plot_a, col='blue')    # Copy and add
plot_a2 = diagram(affinity3, add_to=plot_a, col='red')     # Copy and add again

# plot_a remains unchanged; plot_a1 and plot_a2 are independent modifications

contour_method : str or list of str, optional
Method for labeling contour lines. Default "edge" labels at plot edges. Can be a single value (applied to all species) or a list (one per species). Set to None, NA, or "" to disable labels (only for type="saturation"). In R CHNOSZ, different methods like "edge", "flattest", "simple" control label placement; in Python, this mainly controls whether labels are shown.
interactive : bool, default False
Create an interactive plot using Plotly instead of matplotlib? If True, calls diagram_interactive() with the appropriate parameters.
annotation : str, optional
For interactive plots only. Annotation text to add to the plot.
annotation_coords : list of float, default [0, 0]
For interactive plots only. Coordinates of annotation, where [0, 0] is bottom left and [1, 1] is top right.
width : int, default 600
For interactive plots only. Width of the plot in pixels.
height : int, default 520
For interactive plots only. Height of the plot in pixels.
save_as : str, optional
For interactive plots only. Provide a filename to save this figure. Filetype is determined by save_format.
save_format : str, optional
For interactive plots only. Desired format of saved or downloaded figure. Can be 'png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', or 'html'. If 'html', an interactive plot will be saved.
save_scale : float, default 1
For interactive plots only. Multiply title/legend/axis/canvas sizes by this factor when saving the figure.
**kwargs
Additional arguments passed to matplotlib plotting functions

Returns

dict
Dictionary containing:
- plotvar : str, Variable that was plotted
- plotvals : dict, Values that were plotted
- names : list, Names used for labels
- predominant : array or NA, Predominance matrix (for 2-D)
- balance : str or list, Balancing coefficients used
- n.balance : list, Numerical balancing coefficients
- ax : matplotlib.axes.Axes, The axes object used for plotting (if plot_it=True)
- fig : matplotlib.figure.Figure, The figure object used for plotting. When plot_it=False, the figure is closed to prevent automatic display in Jupyter, but it is still accessible as result['fig'].
- All original eout contents

Examples

>>> import pychnosz
>>> pychnosz.basis(["Fe2O3", "CO2", "H2O", "NH3", "H2S", "oxygen", "H+"],
...              [0, -3, 0, -4, -7, -80, -7])
>>> pychnosz.species(["pyrite", "goethite"])
>>> a = pychnosz.affinity(H2S=[-60, 20, 5], T=25, P=1)
>>> d = pychnosz.diagram(a)

Notes

This implementation is based on the R CHNOSZ diagram() function but is adapted to use Python's matplotlib instead of R's base graphics. The key differences from diagram_from_WORM.py are:
- Works directly with the Python dict output from affinity() (no rpy2)
- Uses matplotlib for 1-D plots by default
- Can optionally use Plotly if requested (interactive=True)

def diagram_interactive(eout: Dict[str, Any],
type: str = 'auto',
main: str | None = None,
borders: float | str = 0,
names: List[str] | None = None,
format_names: bool = True,
annotation: str | None = None,
annotation_coords: List[float] = [0, 0],
balance: str | float | List[float] | None = None,
xlab: str | None = None,
ylab: str | None = None,
fill: str | List[str] | None = 'viridis',
width: int = 600,
height: int = 520,
alpha: bool | str = False,
add: bool = False,
ax: Any | None = None,
col: str | List[str] | None = None,
lty: str | int | List | None = None,
lwd: float | List[float] = 1,
cex: float | List[float] = 1.0,
contour_method: str | List[str] | None = 'edge',
messages: bool = True,
plot_it: bool = True,
save_as: str | None = None,
save_format: str | None = None,
save_scale: float = 1) ‑> Tuple[pandas.core.frame.DataFrame, Any]
Expand source code
def diagram_interactive(eout: Dict[str, Any],
                        type: str = "auto",
                        main: Optional[str] = None,
                        borders: Union[float, str] = 0,
                        names: Optional[List[str]] = None,
                        format_names: bool = True,
                        annotation: Optional[str] = None,
                        annotation_coords: List[float] = [0, 0],
                        balance: Optional[Union[str, float, List[float]]] = None,
                        xlab: Optional[str] = None,
                        ylab: Optional[str] = None,
                        fill: Optional[Union[str, List[str]]] = "viridis",
                        width: int = 600,
                        height: int = 520,
                        alpha: Union[bool, str] = False,
                        add: bool = False,
                        ax: Optional[Any] = None,
                        col: Optional[Union[str, List[str]]] = None,
                        lty: Optional[Union[str, int, List]] = None,
                        lwd: Union[float, List[float]] = 1,
                        cex: Union[float, List[float]] = 1.0,
                        contour_method: Optional[Union[str, List[str]]] = "edge",
                        messages: bool = True,
                        plot_it: bool = True,
                        save_as: Optional[str] = None,
                        save_format: Optional[str] = None,
                        save_scale: float = 1) -> Tuple[pd.DataFrame, Any]:
    """
    Create an interactive diagram using Plotly.

    This function produces interactive versions of the diagrams created by diagram(),
    using Plotly for interactivity. It accepts output from affinity() or equilibrate()
    and creates either 1D line plots or 2D predominance diagrams.

    Parameters
    ----------
    eout : dict
        Output from affinity() or equilibrate().
    main : str, optional
        Title of the plot.
    borders : float or str, default 0
        Controls boundary lines between regions in 2D predominance diagrams.
        - If numeric > 0: draws grid-aligned borders with specified thickness (pixels)
        - If "contour": draws smooth contour-based boundaries (like diagram())
        - If 0 or None: no borders drawn
    names : list of str, optional
        Names of species for activity lines or predominance fields.
    format_names : bool, default True
        Apply formatting to chemical formulas?
    annotation : str, optional
        Annotation to add to the plot.
    annotation_coords : list of float, default [0, 0]
        Coordinates of annotation, where 0,0 is bottom left and 1,1 is top right.
    balance : str or numeric, optional
        How to balance the transformations.
    xlab : str, optional
        Custom x-axis label.
    ylab : str, optional
        Custom y-axis label.
    fill : str or list of str, default "viridis"
        For 2D diagrams: colormap name (e.g., "viridis", "hot") or list of colors.
        For 1D diagrams: list of line colors.
    width : int, default 600
        Width of the plot in pixels.
    height : int, default 520
        Height of the plot in pixels.
    alpha : bool or str, default False
        For speciation diagrams, plot degree of formation instead of activities?
        If True, plots mole fractions. If "balance", scales by stoichiometry.
    messages : bool, default True
        Display messages?
    plot_it : bool, default True
        Show the plot?
    save_as : str, optional
        Provide a filename to save this figure. Filetype of saved figure is
        determined by save_format.
    save_format : str, default "png"
        Desired format of saved or downloaded figure. Can be 'png', 'jpg', 'jpeg',
        'webp', 'svg', 'pdf', 'eps', 'json', or 'html'. If 'html', an interactive
        plot will be saved.
    save_scale : float, default 1
        Multiply title/legend/axis/canvas sizes by this factor when saving.

    Returns
    -------
    tuple
        (df, fig) where df is a pandas DataFrame with the data and fig is the
        Plotly figure object.

    Examples
    --------
    1D diagram:
    >>> basis("CHNOS+")
    >>> species(info(["glycinium", "glycine", "glycinate"]))
    >>> a = affinity(pH=[0, 14])
    >>> e = equilibrate(a)
    >>> diagram_interactive(e, alpha=True)

    2D diagram:
    >>> basis(["Fe", "oxygen", "S2"])
    >>> species(["iron", "ferrous-oxide", "magnetite", "hematite", "pyrite", "pyrrhotite"])
    >>> a = affinity(S2=[-50, 0], O2=[-90, -10], T=200)
    >>> diagram_interactive(a, fill="hot")

    Notes
    -----
    This function requires plotly to be installed. Install with:
        pip install plotly

    The function adapts the pyCHNOSZ diagram_interactive() implementation
    to work with Python CHNOSZ's native data structures.
    """

    # Import plotly (lazy import to avoid dependency issues)
    try:
        import plotly.express as px
        import plotly.graph_objects as go
        import plotly.io as pio
    except ImportError:
        raise ImportError("diagram_interactive() requires plotly. Install with: pip install plotly")

    # Check that eout is valid
    efun = eout.get('fun', '')
    if efun not in ['affinity', 'equilibrate', 'solubility']:
        raise ValueError("'eout' is not the output from affinity(), equilibrate(), or solubility()")

    # Determine if this is affinity or equilibrate output
    calc_type = "a" if ('loga_equil' not in eout and 'loga.equil' not in eout) else "e"

    # Get basis species and their states
    basis_df = eout['basis']
    basis_sp = list(basis_df.index)
    basis_state = list(basis_df['state'])

    # Get variable names and values
    xyvars = eout['vars']
    xyvals_dict = eout['vals']
    # Convert vals dict to list format for easier access
    xyvals = [xyvals_dict[var] for var in xyvars]

    # Determine balance if not provided
    if balance is None or balance == "":
        # For saturation diagrams, use balance=1 (formula units) to match R behavior
        # This avoids issues when minerals don't have a common basis element
        if type == "saturation":
            balance = 1
            n_balance = [1] * len(eout['values'])
        else:
            # Call diagram with plot_it=False to get balance
            # Need to import matplotlib to close the figure afterward
            import matplotlib.pyplot as plt_temp
            temp_result = diagram(eout, messages=False, plot_it=False)
            balance = temp_result.get('balance', 1)
            n_balance = temp_result.get('n_balance', [1])
            # Close the matplotlib figure created by diagram() since we don't need it
            if 'fig' in temp_result and temp_result['fig'] is not None:
                plt_temp.close(temp_result['fig'])
    else:
        # Calculate n_balance from balance
        try:
            balance_float = float(balance)
            n_balance = [balance_float] * len(eout['values'])
        except (ValueError, TypeError):
            # balance is a string (element name)
            # Get species from eout instead of global state
            if 'species' in eout and eout['species'] is not None:
                sp_df = eout['species']
            else:
                # Fallback to global species if not in eout
                from .species import species as species_func
                sp_df = species_func()

            # Check if balance is a list (user-provided values) or a string (column name)
            if isinstance(balance, list):
                n_balance = balance
            elif balance in sp_df.columns:
                n_balance = list(sp_df[balance])
            else:
                n_balance = [1] * len(eout['values'])

    # Get output values
    if calc_type == "a":
        # handling output of affinity()
        out_vals = eout['values']
        out_units = "A/(2.303RT)"
    else:
        # handling output of equilibrate()
        loga_equil_key = 'loga_equil' if 'loga_equil' in eout else 'loga.equil'
        out_vals = eout[loga_equil_key]
        out_units = "log a"

    # Convert values to a list format
    if isinstance(out_vals, dict):
        nsp = len(out_vals)
        values_list = list(out_vals.values())
        species_indices = list(out_vals.keys())
    else:
        nsp = len(out_vals)
        values_list = out_vals
        species_indices = eout['species']['ispecies'].tolist()

    # Get species names
    from .info import info as info_func
    # Convert numpy types to Python types
    species_indices_py = [int(idx) for idx in species_indices]
    sp_info = info_func(species_indices_py, messages=False)
    sp_names = sp_info['name'].tolist()

    # Use custom names if provided
    if isinstance(names, list) and len(names) == len(sp_names):
        sp_names = names

    # Determine dimensions
    first_val = values_list[0]
    if hasattr(first_val, 'shape'):
        nd = len(first_val.shape)
    else:
        nd = 1 if hasattr(first_val, '__len__') else 0

    # Handle type="saturation" - plot contour lines where affinity=0
    if type == "saturation":
        if nd != 2:
            raise ValueError("type='saturation' requires 2-D diagram")
        if calc_type != "a":
            raise ValueError("type='saturation' requires output from affinity(), not equilibrate()")

        # Delegate to saturation plotting function
        return _plot_saturation_interactive(
            eout, values_list, sp_names, xyvars, xyvals,
            xlab, ylab, col, lwd, lty, cex, contour_method,
            main, add, ax, width, height, plot_it,
            save_as, save_format, save_scale, messages
        )

    # Build DataFrame
    if nd == 2:
        # 2D case - flatten the data
        xvals = xyvals[0]
        yvals = xyvals[1]
        xvar = xyvars[0]
        yvar = xyvars[1]

        # Flatten the data - transpose first so coordinates match
        # Original shape is (nx, ny) where nx=len(xvals), ny=len(yvals)
        # After transpose, shape is (ny, nx)
        # Flattening with C-order then gives: [row0, row1, ...] = [x-values at y[0], x-values at y[1], ...]
        flat_out_vals = []
        for v in values_list:
            # Transpose then flatten so coordinates align correctly
            flat_out_vals.append(v.T.flatten())
        df = pd.DataFrame(flat_out_vals, index=sp_names).T

        # Apply balance if needed
        if calc_type == "a":
            if isinstance(balance, str):
                # Get balance from species dataframe
                # Get species from eout instead of global state
                if 'species' in eout and eout['species'] is not None:
                    sp_df = eout['species']
                else:
                    # Fallback to global species if not in eout
                    from .species import species as species_func
                    sp_df = species_func()

                # Check if balance is a list (user-provided values) or a string (column name)
                if isinstance(balance, list):
                    n_balance = balance
                elif balance in sp_df.columns:
                    n_balance = list(sp_df[balance])
            # Divide by balance
            for i, sp in enumerate(sp_names):
                df[sp] = df[sp] / n_balance[i]

        # Find predominant species
        df["pred"] = df.idxmax(axis=1, skipna=True)
        df["prednames"] = df["pred"]

        # Add x and y coordinates
        # After transpose and flatten, data is ordered as:
        # [x0,y0], [x1,y0], ..., [xn,y0], [x0,y1], [x1,y1], ...
        xvals_full = list(xvals) * len(yvals)
        yvals_full = []
        for y in yvals:
            yvals_full.extend([y] * len(xvals))
        df[xvar] = xvals_full
        df[yvar] = yvals_full

    else:
        # 1D case
        xvar = xyvars[0]
        xvals = xyvals[0]

        flat_out_vals = []
        for v in values_list:
            flat_out_vals.append(v)
        df = pd.DataFrame(flat_out_vals, index=sp_names).T

        # Apply balance if needed
        if calc_type == "a":
            if isinstance(balance, str):
                # Get species from eout instead of global state
                if 'species' in eout and eout['species'] is not None:
                    sp_df = eout['species']
                else:
                    # Fallback to global species if not in eout
                    from .species import species as species_func
                    sp_df = species_func()

                # Check if balance is a list (user-provided values) or a string (column name)
                if isinstance(balance, list):
                    n_balance = balance
                elif balance in sp_df.columns:
                    n_balance = list(sp_df[balance])
            # Divide by balance
            for i, sp in enumerate(sp_names):
                df[sp] = df[sp] / n_balance[i]

        # Handle alpha (degree of formation)
        if alpha:
            df = df.apply(lambda x: 10**x)
            df = df[sp_names].div(df[sp_names].sum(axis=1), axis=0)

        df[xvar] = xvals

    # Create axis labels
    unit_dict = {"P": "bar", "T": "°C", "pH": "", "Eh": "volts", "IS": "mol/kg"}

    for i, s in enumerate(basis_sp):
        if basis_state[i] in ["aq", "liq", "cr"]:
            if format_names:
                unit_dict[s] = f"log <i>a</i><sub>{_format_html_species(s)}</sub>"
            else:
                unit_dict[s] = f"log <i>a</i><sub>{s}</sub>"
        else:
            if format_names:
                unit_dict[s] = f"log <i>f</i><sub>{_format_html_species(s)}</sub>"
            else:
                unit_dict[s] = f"log <i>f</i><sub>{s}</sub>"

    # Set x-axis label
    if not isinstance(xlab, str):
        xlab = xvar + ", " + unit_dict.get(xvar, "")
        if xvar == "pH":
            xlab = "pH"
        if xvar in basis_sp:
            xlab = unit_dict[xvar]

    # Create the plot
    if nd == 1:
        # 1D plot
        # Melt the dataframe for plotting
        df_melted = pd.melt(df, id_vars=[xvar], value_vars=sp_names, var_name='variable', value_name='value')

        # Format species names if requested
        if format_names:
            df_melted['variable'] = df_melted['variable'].apply(_format_html_species)

        # Set y-axis label
        if not isinstance(ylab, str):
            if alpha:
                ylab = "alpha"
            else:
                ylab = out_units

        fig = px.line(df_melted, x=xvar, y="value", color='variable',
                      template="simple_white", width=width, height=height,
                      labels={'value': ylab, xvar: xlab},
                      render_mode='svg')

        # Apply custom colors if provided
        if isinstance(fill, list):
            for i, color in enumerate(fill):
                if i < len(fig.data):
                    fig.data[i].line.color = color

        # Check for LaTeX format in axis labels
        if xlab and _detect_latex_format(xlab):
            warnings.warn(
                "LaTeX formatting detected in 'xlab' parameter. "
                "Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
                "For activity ratios, use ratlab_html() instead of ratlab().",
                UserWarning
            )
        if ylab and _detect_latex_format(ylab):
            warnings.warn(
                "LaTeX formatting detected in 'ylab' parameter. "
                "Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
                "For activity ratios, use ratlab_html() instead of ratlab().",
                UserWarning
            )

        fig.update_layout(xaxis_title=xlab,
                          yaxis_title=ylab,
                          legend_title=None)

        if isinstance(main, str):
            fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})

        if isinstance(annotation, str):
            # Check for LaTeX format and warn user
            if _detect_latex_format(annotation):
                warnings.warn(
                    "LaTeX formatting detected in 'annotation' parameter. "
                    "Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
                    "For activity ratios, use ratlab_html() instead of ratlab().",
                    UserWarning
                )

            fig.add_annotation(
                x=annotation_coords[0],
                y=annotation_coords[1],
                text=annotation,
                showarrow=False,
                xref="paper",
                yref="paper",
                align='left',
                bgcolor="rgba(255, 255, 255, 0.5)")

        # Configure download button
        save_as_name, save_format_final = _save_figure(fig, save_as, save_format, save_scale,
                                                        plot_width=width, plot_height=height, ppi=1)

        config = {'displaylogo': False,
                  'modeBarButtonsToRemove': ['resetScale2d', 'toggleSpikelines'],
                  'toImageButtonOptions': {
                      'format': save_format_final,
                      'filename': save_as_name,
                      'height': height,
                      'width': width,
                      'scale': save_scale,
                  }}

        # Store config on figure so it persists when fig.show() is called later
        fig._config = fig._config | config

    else:
        # 2D plot
        # Map species names to numeric values
        mappings = {s: lab for s, lab in zip(sp_names, range(len(sp_names)))}
        df['pred'] = df['pred'].map(mappings).astype(int)

        # Reshape data
        # Data is flattened as [x0,y0], [x1,y0], ..., [xn,y0], [x0,y1], ...
        # Reshape to (ny, nx) for proper orientation in Plotly
        # Plotly expects data[i,j] to correspond to x[j], y[i]
        data = np.array(df['pred'])
        shape = (len(yvals), len(xvals))
        dmap = data.reshape(shape)

        data_names = np.array(df['prednames'])
        dmap_names = data_names.reshape(shape)

        # Set y-axis label
        if not isinstance(ylab, str):
            ylab = yvar + ", " + unit_dict.get(yvar, "")
            if yvar in basis_sp:
                ylab = unit_dict[yvar]
            if yvar == "pH":
                ylab = "pH"

        # Check for LaTeX format in axis labels (2D plot)
        if xlab and _detect_latex_format(xlab):
            warnings.warn(
                "LaTeX formatting detected in 'xlab' parameter. "
                "Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
                "For activity ratios, use ratlab_html() instead of ratlab().",
                UserWarning
            )
        if ylab and _detect_latex_format(ylab):
            warnings.warn(
                "LaTeX formatting detected in 'ylab' parameter. "
                "Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
                "For activity ratios, use ratlab_html() instead of ratlab().",
                UserWarning
            )

        # Create heatmap
        fig = px.imshow(dmap, width=width, height=height, aspect="auto",
                        labels={'x': xlab, 'y': ylab, 'color': "region"},
                        x=xvals, y=yvals, template="simple_white")

        fig.update(data=[{'customdata': dmap_names,
                          'hovertemplate': xlab + ': %{x}<br>' + ylab + ': %{y}<br>Region: %{customdata}<extra></extra>'}])

        # Set colormap
        if fill == 'none':
            colormap = [[0, 'white'], [1, 'white']]
        elif isinstance(fill, list):
            colmap_temp = []
            for i, v in enumerate(fill):
                colmap_temp.append([i / (len(fill) - 1) if len(fill) > 1 else 0, v])
            colormap = colmap_temp
        else:
            colormap = fill

        fig.update_traces(dict(showscale=False,
                               coloraxis=None,
                               colorscale=colormap),
                          selector={'type': 'heatmap'})

        fig.update_yaxes(autorange=True)

        if isinstance(main, str):
            fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})

        # Add species labels
        for s in sp_names:
            if s in set(df["prednames"]):
                df_s = df.loc[df["prednames"] == s]
                namex = df_s[xvar].mean()
                namey = df_s[yvar].mean()

                if format_names:
                    annot_text = _format_html_species(s)
                else:
                    annot_text = str(s)

                fig.add_annotation(x=namex, y=namey,
                                   text=annot_text,
                                   bgcolor="rgba(255, 255, 255, 0.5)",
                                   showarrow=False)

        if isinstance(annotation, str):
            # Check for LaTeX format and warn user
            if _detect_latex_format(annotation):
                warnings.warn(
                    "LaTeX formatting detected in 'annotation' parameter. "
                    "Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
                    "For activity ratios, use ratlab_html() instead of ratlab().",
                    UserWarning
                )

            fig.add_annotation(
                x=annotation_coords[0],
                y=annotation_coords[1],
                text=annotation,
                showarrow=False,
                xref="paper",
                yref="paper",
                align='left',
                bgcolor="rgba(255, 255, 255, 0.5)")

        # Add borders if requested
        if borders == "contour":
            # Use contour-based boundaries (smooth, like diagram())
            # Draw boundaries using matplotlib contour extraction without filling

            # Get unique species (excluding any that don't appear)
            unique_species_names = sorted(df["prednames"].unique())

            # Create a temporary matplotlib figure to extract contour paths
            # We won't display it, just use it to calculate contours
            temp_fig, temp_ax = plt.subplots()

            # For each species, create a binary mask and extract contours
            for i, sp_name in enumerate(unique_species_names):
                # Create binary mask: 1 where this species predominates, 0 elsewhere
                z = (dmap_names == sp_name).astype(float)

                # Create meshgrid for contour
                X, Y = np.meshgrid(xvals, yvals)

                # Find contours at level 0.5 using matplotlib
                try:
                    cs = temp_ax.contour(X, Y, z, levels=[0.5])

                    # Extract the contour segments
                    # cs.allsegs is a list of lists: [level][segment]
                    for level_segs in cs.allsegs:
                        for segment in level_segs:
                            # segment is an (N, 2) array of (x, y) coordinates
                            # Add as a scatter trace with lines
                            fig.add_trace(
                                go.Scatter(
                                    x=segment[:, 0],
                                    y=segment[:, 1],
                                    mode='lines',
                                    line=dict(color='black', width=2),
                                    hoverinfo='skip',
                                    showlegend=False
                                )
                            )

                    # Clear the temp axes for next species
                    temp_ax.clear()
                except Exception as e:
                    if messages:
                        warnings.warn(f"Could not draw contour for {sp_name}: {e}")
                    pass  # Skip if contour can't be drawn

            # Close the temporary figure
            plt.close(temp_fig)

        elif isinstance(borders, (int, float)) and borders > 0:
            unique_x_vals = sorted(list(set(df[xvar])))
            unique_y_vals = sorted(list(set(df[yvar])))

            # Skip border drawing if there are fewer than 2 unique values
            # (single point or single line - no borders to draw between regions)
            if len(unique_x_vals) < 2 or len(unique_y_vals) < 2:
                if messages:
                    warnings.warn("Skipping border drawing: need at least 2 unique values in each dimension")
            else:
                def mov_mean(numbers, window_size=2):
                    moving_averages = []
                    for i in range(len(numbers) - window_size + 1):
                        window_average = sum(numbers[i:i + window_size]) / window_size
                        moving_averages.append(window_average)
                    return moving_averages

                x_mov_mean = mov_mean(unique_x_vals)
                y_mov_mean = mov_mean(unique_y_vals)

                x_plot_min = x_mov_mean[0] - (x_mov_mean[1] - x_mov_mean[0])
                y_plot_min = y_mov_mean[0] - (y_mov_mean[1] - y_mov_mean[0])

                x_plot_max = x_mov_mean[-1] + (x_mov_mean[1] - x_mov_mean[0])
                y_plot_max = y_mov_mean[-1] + (y_mov_mean[1] - y_mov_mean[0])

                x_vals_border = [x_plot_min] + x_mov_mean + [x_plot_max]
                y_vals_border = [y_plot_min] + y_mov_mean + [y_plot_max]

                # Find border lines
                def find_line(dmap, row_index):
                    return [i for i in range(len(dmap[row_index]) - 1) if dmap[row_index][i] != dmap[row_index][i + 1]]

                nrows, ncols = dmap.shape
                vlines = [find_line(dmap, row_i) for row_i in range(nrows)]

                dmap_transposed = dmap.transpose()
                nrows_t, ncols_t = dmap_transposed.shape
                hlines = [find_line(dmap_transposed, row_i) for row_i in range(nrows_t)]

                y_coord_list_vertical = []
                x_coord_list_vertical = []
                for i, row in enumerate(vlines):
                    for line in row:
                        x_coord_list_vertical += [x_vals_border[line + 1], x_vals_border[line + 1], np.nan]
                        y_coord_list_vertical += [y_vals_border[i], y_vals_border[i + 1], np.nan]

                y_coord_list_horizontal = []
                x_coord_list_horizontal = []
                for i, col in enumerate(hlines):
                    for line in col:
                        y_coord_list_horizontal += [y_vals_border[line + 1], y_vals_border[line + 1], np.nan]
                        x_coord_list_horizontal += [x_vals_border[i], x_vals_border[i + 1], np.nan]

                fig.add_trace(
                    go.Scatter(
                        mode='lines',
                        x=x_coord_list_horizontal,
                        y=y_coord_list_horizontal,
                        line={'width': borders, 'color': 'black'},
                        hoverinfo='skip',
                        showlegend=False))

                fig.add_trace(
                    go.Scatter(
                        mode='lines',
                        x=x_coord_list_vertical,
                        y=y_coord_list_vertical,
                        line={'width': borders, 'color': 'black'},
                        hoverinfo='skip',
                        showlegend=False))

                fig.update_yaxes(range=[min(yvals), max(yvals)], autorange=False, mirror=True)
                fig.update_xaxes(range=[min(xvals), max(xvals)], autorange=False, mirror=True)

        # Configure download button
        save_as_name, save_format_final = _save_figure(fig, save_as, save_format, save_scale,
                                                        plot_width=width, plot_height=height, ppi=1)

        config = {'displaylogo': False,
                  'modeBarButtonsToRemove': ['zoom2d', 'pan2d', 'zoomIn2d', 'zoomOut2d',
                                             'autoScale2d', 'resetScale2d', 'toggleSpikelines',
                                             'hoverClosestCartesian', 'hoverCompareCartesian'],
                  'toImageButtonOptions': {
                      'format': save_format_final,
                      'filename': save_as_name,
                      'height': height,
                      'width': width,
                      'scale': save_scale,
                  }}

        # Store config on figure so it persists when fig.show() is called later
        fig._config = fig._config | config

    if plot_it:
        fig.show(config=config)

    return df, fig

Create an interactive diagram using Plotly.

This function produces interactive versions of the diagrams created by diagram(), using Plotly for interactivity. It accepts output from affinity() or equilibrate() and creates either 1D line plots or 2D predominance diagrams.

Parameters

eout : dict
Output from affinity() or equilibrate().
main : str, optional
Title of the plot.
borders : float or str, default 0
Controls boundary lines between regions in 2D predominance diagrams:
- If numeric > 0: draws grid-aligned borders with the specified thickness (pixels).
- If "contour": draws smooth contour-based boundaries (like diagram()).
- If 0 or None: no borders are drawn.
names : list of str, optional
Names of species for activity lines or predominance fields.
format_names : bool, default True
Apply formatting to chemical formulas?
annotation : str, optional
Annotation to add to the plot.
annotation_coords : list of float, default [0, 0]
Coordinates of annotation, where 0,0 is bottom left and 1,1 is top right.
balance : str or numeric, optional
How to balance the transformations.
xlab : str, optional
Custom x-axis label.
ylab : str, optional
Custom y-axis label.
fill : str or list of str, default "viridis"
For 2D diagrams: colormap name (e.g., "viridis", "hot") or list of colors. For 1D diagrams: list of line colors.
width : int, default 600
Width of the plot in pixels.
height : int, default 520
Height of the plot in pixels.
alpha : bool or str, default False
For speciation diagrams, plot degree of formation instead of activities? If True, plots mole fractions. If "balance", scales by stoichiometry.
messages : bool, default True
Display messages?
plot_it : bool, default True
Show the plot?
save_as : str, optional
Provide a filename to save this figure. Filetype of saved figure is determined by save_format.
save_format : str, default "png"
Desired format of saved or downloaded figure. Can be 'png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', or 'html'. If 'html', an interactive plot will be saved.
save_scale : float, default 1
Multiply title/legend/axis/canvas sizes by this factor when saving.

Returns

tuple
(df, fig) where df is a pandas DataFrame with the data and fig is the Plotly figure object.

Examples

1D diagram:

>>> basis("CHNOS+")
>>> species(info(["glycinium", "glycine", "glycinate"]))
>>> a = affinity(pH=[0, 14])
>>> e = equilibrate(a)
>>> diagram_interactive(e, alpha=True)

2D diagram:

>>> basis(["Fe", "oxygen", "S2"])
>>> species(["iron", "ferrous-oxide", "magnetite", "hematite", "pyrite", "pyrrhotite"])
>>> a = affinity(S2=[-50, 0], O2=[-90, -10], T=200)
>>> diagram_interactive(a, fill="hot")

Notes

This function requires plotly to be installed. Install with: pip install plotly

The function adapts the pyCHNOSZ diagram_interactive() implementation to work with Python CHNOSZ's native data structures.

def dissrxn2logK(OBIGT, i, Tc)
Expand source code
def dissrxn2logK(OBIGT, i, Tc):
    
    this_dissrxn = OBIGT.iloc[i, OBIGT.columns.get_loc('dissrxn')]
    
    if this_dissrxn == "nan":
        this_dissrxn = OBIGT.iloc[i, OBIGT.columns.get_loc('regenerate_dissrxn')]
    
#     print(OBIGT["name"][i], this_dissrxn)
    
    try:
        this_dissrxn = this_dissrxn.strip()
        split_dissrxn = this_dissrxn.split(" ")
    except:
        return float('NaN')
    
    
    
    coeff = [float(n) for n in split_dissrxn[::2]]
    species = split_dissrxn[1::2]
    try:
        G = sum([float(c*OBIGT.loc[OBIGT["name"]==sp, "G_TP"].iloc[0]) for c,sp in zip(coeff, species)])
    except:
        G_list = []
        for ii, sp in enumerate(species):
            G_TP = OBIGT.loc[OBIGT["name"]==sp, "G_TP"]
            if len(G_TP) == 1:
                G_list.append(float(coeff[ii]*OBIGT.loc[OBIGT["name"]==sp, "G_TP"]))
            else:
                ### check valid polymorph T

                # get polymorph entries of OBIGT that match mineral
                poly_df = copy.copy(OBIGT.loc[OBIGT["name"]==sp,:])
                # ensure polymorph df is sorted according to cr, cr2, cr3... etc.
                poly_df = poly_df.sort_values("state")

                z_Ts = list(poly_df.loc[poly_df["name"]==sp, "z.T"])

                last_t = float('-inf')
                appended=False
                for iii,t in enumerate(z_Ts):

                    if Tc+273.15 > last_t and Tc+273.15 < t:
                        G_list.append(float(coeff[ii]*list(poly_df.loc[poly_df["name"]==sp, "G_TP"])[iii]))
                        appended=True
                    if not appended and z_Ts[-1] == t:
                        G_list.append(float(coeff[ii]*list(poly_df.loc[poly_df["name"]==sp, "G_TP"])[iii]))
                    last_t = t

        G = sum(G_list)

    return G2logK(G, Tc)
def entropy(formula: str | int | List[str | int]) ‑> float | List[float]
Expand source code
def entropy(formula: Union[str, int, List[Union[str, int]]]) -> Union[float, List[float]]:
    """
    Compute the standard molal entropy of the elements making up formula(s).

    For each formula, every element's per-atom entropy (``s / n`` from
    ``thermo().element``) is multiplied by its atom count. The counts are
    summed, and the total is multiplied by 4.184 (treated as a cal -> J
    conversion).

    Elements absent from the element table trigger a warning and are skipped.
    Elements whose ``s`` or ``n`` is NA are also skipped. NaN is returned for
    a formula only when such a skip occurred and the running total is still
    zero.

    Parameters
    ----------
    formula : str, int, or list
        Chemical formula(s) or species index(es).

    Returns
    -------
    float or list of float
        Standard entropy in J/(mol*K). A single value when exactly one formula
        results, otherwise a list.

    Raises
    ------
    RuntimeError
        If ``thermo().element`` is None.
    """
    thermo_obj = thermo()
    if thermo_obj.element is None:
        raise RuntimeError("Element data not available")

    parsed = makeup(formula, count_zero=False)
    formulas = parsed if isinstance(parsed, list) else [parsed]
    element_table = thermo_obj.element

    def _formula_entropy(comp):
        # makeup() yields None for formulas it could not parse
        if comp is None:
            return np.nan

        running_total = 0.0
        skipped = False
        for symbol, n_atoms in comp.items():
            rows = element_table[element_table['element'] == symbol]
            if len(rows) == 0:
                warnings.warn(f"Element {symbol} not available in thermo().element")
                skipped = True
                continue

            first_row = rows.iloc[0]
            s_value = first_row['s']
            n_value = first_row['n']
            if pd.isna(s_value) or pd.isna(n_value):
                skipped = True
                continue

            running_total += n_atoms * (s_value / n_value)

        if skipped and running_total == 0:
            return np.nan
        # Element entropies are assumed to be tabulated in cal
        return running_total * 4.184

    results = [_formula_entropy(comp) for comp in formulas]
    return results[0] if len(results) == 1 else results

Calculate standard molal entropy of elements in chemical formulas.

Parameters

formula : str, int, or list
Chemical formula(s) or species index(es)

Returns

float or list of float
Standard entropy(ies) in J/(mol*K)
def envert(value: float | numpy.ndarray | List[float], units: str) ‑> float | numpy.ndarray
Expand source code
def envert(value: Union[float, np.ndarray, List[float]],
           units: str) -> Union[float, np.ndarray]:
    """
    Convert value(s) into `units` from the units configured in thermo().opt.

    Temperature (C/K), energy (J/cal) and pressure (bar/MPa) are handled.
    A conversion is performed only when the unit currently selected in
    thermo().opt differs from the requested one. Otherwise the numeric input
    is returned as a numpy array without conversion.

    Parameters
    ----------
    value : float, ndarray, or list
        Value(s) to convert. Objects of any other type are returned as-is.
        Non-numeric lists or arrays are returned as numpy arrays without
        conversion.
    units : str
        Target units ('C', 'K', 'bar', 'MPa', 'J', 'cal'); case-insensitive.

    Returns
    -------
    float or ndarray
        Converted value(s).
    """
    if not isinstance(value, (int, float, np.ndarray, list)):
        return value

    arr = np.asarray(value)
    if arr.size > 0 and not np.issubdtype(arr.dtype, np.number):
        return arr

    target = units.lower()
    opt = thermo().opt

    # Each rule: unit names for one quantity, the thermo().opt key holding the
    # current setting, and for each convertible target the current setting
    # that calls for a conversion.
    rules = (
        (('c', 'k', 't.units'), 'T.units', {'c': 'K', 'k': 'C'}),
        (('j', 'cal', 'e.units'), 'E.units', {'j': 'cal', 'cal': 'J'}),
        (('bar', 'mpa', 'p.units'), 'P.units', {'mpa': 'bar', 'bar': 'MPa'}),
    )
    for names, opt_key, trigger in rules:
        if target in names and target in trigger and opt[opt_key] == trigger[target]:
            return convert(arr, target)

    return arr

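Convert values to the specified units from the units currently set in thermo().opt.

This function is used internally to convert values expressed in the user's preferred units (stored in thermo().opt) into the units named by the units argument. Non-numeric values, and values already in the requested units, are returned without conversion.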

Parameters

value : float, ndarray, or list
Value(s) to convert
units : str
Target units ('C', 'K', 'bar', 'MPa', 'J', 'cal')

Returns

float or ndarray
Converted value(s)
def equilibrate(aout: Dict[str, Any],
balance: str | int | List[float] | None = None,
loga_balance: float | List[float] | None = None,
ispecies: List[int] | List[bool] | None = None,
normalize: bool | List[bool] = False,
as_residue: bool = False,
method: str | List[str] | None = None,
tol: float = np.float64(0.0001220703125),
messages: bool = True) ‑> Dict[str, Any]
Expand source code
def equilibrate(aout: Dict[str, Any],
                balance: Optional[Union[str, int, List[float]]] = None,
                loga_balance: Optional[Union[float, List[float]]] = None,
                ispecies: Optional[Union[List[int], List[bool]]] = None,
                normalize: Union[bool, List[bool]] = False,
                as_residue: bool = False,
                method: Optional[Union[str, List[str]]] = None,
                tol: float = np.finfo(float).eps ** 0.25,
                messages: bool = True) -> Dict[str, Any]:
    """
    Calculate equilibrium activities of species from affinities.

    This function calculates the equilibrium activities of species in
    (metastable) equilibrium from the affinities of their formation reactions
    from basis species at given activities.

    Parameters
    ----------
    aout : dict
        Output from affinity() containing chemical affinities
    balance : str, int, or list of float, optional
        Balancing method:
        - None: Autoselect using which_balance()
        - str: Name of basis species to balance on
        - "length": Balance on protein length (for proteins)
        - "volume": Balance on standard-state volume
        - 1: Balance on one mole of species (formula units)
        - list: User-defined balancing coefficients
    loga_balance : float or list of float, optional
        Logarithm of total activity of the balancing basis species
        If None, calculated from species initial activities and n.balance
    ispecies : list of int or list of bool, optional
        Indices or boolean mask of species to include in equilibration
        Default: all species except those with state "cr" (crystalline)
    normalize : bool or list of bool, default False
        Normalize formulas by balancing coefficients?
    as_residue : bool, default False
        Use residue basis for proteins?
    method : str or list of str, optional
        Equilibration method:
        - "boltzmann": Boltzmann distribution (for n.balance = 1)
        - "reaction": Reaction-based equilibration (general method)
        If None, chooses "boltzmann" if all n.balance == 1, else "reaction"
    tol : float, default np.finfo(float).eps**0.25
        Tolerance for root-finding in reaction method
    messages : bool, default True
        Whether to print informational messages

    Returns
    -------
    dict
        Dictionary containing all aout contents plus:
        - balance : str or list, Balancing description
        - m_balance : list, Molar formula divisors
        - n_balance : list, Balancing coefficients
        - loga_balance : float or array, Log activity of balanced quantity
        - Astar : list of arrays, Normalized affinities
        - loga_equil : list of arrays, Equilibrium log activities

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.basis("CHNOS")
    >>> pychnosz.basis("NH3", -2)
    >>> pychnosz.species(["alanine", "glycine", "serine"])
    >>> a = pychnosz.affinity(NH3=[-80, 60], T=55, P=2000)
    >>> e = pychnosz.equilibrate(a, balance="CO2")

    Notes
    -----
    This is a 1:1 replica of the R CHNOSZ equilibrate() function.
    - Handles both Boltzmann and reaction-based equilibration
    - Supports normalization and residue basis for proteins
    - Properly handles crystalline species via predominance diagrams
    - Implements identical balancing logic to R version
    """

    # Handle mosaic output (not implemented yet, but keep structure)
    if aout.get('fun') == 'mosaic':
        raise NotImplementedError("mosaic equilibration not yet implemented")

    # Number of possible species
    # affinity() returns values as a dict with ispecies as keys
    if isinstance(aout['values'], dict):
        # Convert dict to list ordered by species dataframe
        values_list = []
        for i in range(len(aout['species'])):
            species_idx = aout['species']['ispecies'].iloc[i]
            if species_idx in aout['values']:
                values_list.append(aout['values'][species_idx])
            else:
                # Species not in values dict - use NaN array
                values_list.append(np.array([np.nan]))
        aout['values'] = values_list

    nspecies = len(aout['values'])

    # Get the balancing coefficients
    bout = _balance(aout, balance, messages)
    n_balance_orig = bout['n_balance'].copy()
    n_balance = bout['n_balance'].copy()
    balance = bout['balance']

    # If solids (cr) species are present, find them on a predominance diagram
    iscr = [('cr' in str(state)) for state in aout['species']['state']]
    ncr = sum(iscr)

    # Set default ispecies to exclude cr species (matching R default)
    if ispecies is None:
        ispecies = [not is_cr for is_cr in iscr]

    if ncr > 0:
        # Import diagram here to avoid circular imports
        from .diagram import diagram
        dout = diagram(aout, balance=balance, normalize=normalize,
                      as_residue=as_residue, plot_it=False, limit_water=False, messages=messages)

    if ncr == nspecies:
        # We get here if there are only solids
        m_balance = None
        Astar = None
        loga_equil = []
        for i in range(len(aout['values'])):
            la = np.array(aout['values'][i], copy=True)
            la[:] = np.nan
            loga_equil.append(la)
    else:
        # We get here if there are any aqueous species
        # Take selected species in 'ispecies'
        if len(ispecies) == 0:
            raise ValueError("the length of ispecies is zero")

        # Convert boolean to indices if needed
        if isinstance(ispecies, list) and len(ispecies) > 0:
            if isinstance(ispecies[0], bool):
                ispecies = [i for i, x in enumerate(ispecies) if x]

        # Take out species that have NA affinities
        ina = [all(np.isnan(np.array(x).flatten())) for x in aout['values']]
        ispecies = [i for i in ispecies if not ina[i]]

        if len(ispecies) == 0:
            raise ValueError("all species have NA affinities")

        if ispecies != list(range(nspecies)):
            if messages:
                print(f"equilibrate: using {len(ispecies)} of {nspecies} species")
            aout_species_df = aout['species']
            aout['species'] = aout_species_df.iloc[ispecies].reset_index(drop=True)
            aout['values'] = [aout['values'][i] for i in ispecies]
            n_balance = [n_balance[i] for i in ispecies]

        # Number of species that are left
        nspecies = len(aout['values'])

        # Say what the balancing coefficients are
        if len(n_balance) < 100:
            if messages:
                print(f"equilibrate: n.balance is {', '.join(map(str, n_balance))}")

        # Logarithm of total activity of the balancing basis species
        if loga_balance is None:
            # Sum up the activities, then take absolute value
            # in case n.balance is negative
            logact = np.array([aout['species']['logact'].iloc[i] for i in range(len(aout['species']))])
            sumact = abs(sum(10**logact * n_balance))
            loga_balance = np.log10(sumact)

        # Make loga.balance the same length as the values of affinity
        if isinstance(loga_balance, (int, float)):
            loga_balance = float(loga_balance)
        else:
            loga_balance = np.array(loga_balance).flatten()

        nvalues = len(np.array(aout['values'][0]).flatten())

        if isinstance(loga_balance, float) or len(np.atleast_1d(loga_balance)) == 1:
            # We have a constant loga.balance
            if isinstance(loga_balance, np.ndarray):
                loga_balance = float(loga_balance[0])
            if messages:
                print(f"equilibrate: loga.balance is {loga_balance}")
            loga_balance = np.full(nvalues, loga_balance)
        else:
            # We are using a variable loga.balance (supplied by the user)
            if len(loga_balance) != nvalues:
                raise ValueError(f"length of loga.balance ({len(loga_balance)}) doesn't match "
                               f"the affinity values ({nvalues})")
            if messages:
                print(f"equilibrate: loga.balance has same length as affinity values ({len(loga_balance)})")

        # Normalize the molar formula by the balance coefficients
        m_balance = n_balance.copy()
        isprotein = ['_' in str(name) for name in aout['species']['name']]

        # Handle normalize parameter
        if isinstance(normalize, bool):
            normalize = [normalize] * nspecies
        elif not isinstance(normalize, list):
            normalize = list(normalize)

        if any(normalize) or as_residue:
            if any(n < 0 for n in n_balance):
                raise ValueError("one or more negative balancing coefficients prohibit using normalized molar formulas")

            for i in range(nspecies):
                if normalize[i] or as_residue:
                    n_balance[i] = 1

            if as_residue:
                if messages:
                    print("equilibrate: using 'as.residue' for molar formulas")
            else:
                if messages:
                    print("equilibrate: using 'normalize' for molar formulas")

            # Set the formula divisor (m.balance) to 1 for species whose formulas are *not* normalized
            m_balance = [m_balance[i] if (normalize[i] or as_residue) else 1
                        for i in range(nspecies)]
        else:
            m_balance = [1] * nspecies

        # Astar: the affinities/2.303RT of formation reactions with
        # formed species in their standard-state activities
        Astar = []
        for i in range(nspecies):
            # 'starve' the affinity of the activity of the species,
            # and normalize the value by the molar ratio
            logact_i = aout['species']['logact'].iloc[i]
            astar_i = (np.array(aout['values'][i]) + logact_i) / m_balance[i]
            Astar.append(astar_i)

        # Choose a method and compute the equilibrium activities of species
        if method is None:
            if all(n == 1 for n in n_balance):
                method = ["boltzmann"]
            else:
                method = ["reaction"]
        elif isinstance(method, str):
            method = [method]

        if messages:
            print(f"equilibrate: using {method[0]} method")

        if method[0] == "boltzmann":
            loga_equil = equil_boltzmann(Astar, n_balance, loga_balance)
        elif method[0] == "reaction":
            loga_equil = equil_reaction(Astar, n_balance, loga_balance, tol)
        else:
            raise ValueError(f"unknown method: {method[0]}")

        # If we normalized the formulas, get back to activities of species
        if any(normalize) and not as_residue:
            loga_equil = [loga_equil[i] - np.log10(m_balance[i])
                         for i in range(nspecies)]

    # Process cr species
    if ncr > 0:
        # cr species were excluded from equilibrium calculation,
        # so get values back to original lengths
        norig = len(dout['values'])
        n_balance = n_balance_orig

        # Ensure ispecies is in index form (not boolean)
        # When ncr == nspecies, ispecies was never converted from boolean to indices
        if isinstance(ispecies, list) and len(ispecies) > 0:
            if isinstance(ispecies[0], bool):
                ispecies = [i for i, x in enumerate(ispecies) if x]

        # Match indices back to original
        imatch = [None] * norig
        for j, orig_idx in enumerate(range(norig)):
            if orig_idx in ispecies:
                imatch[orig_idx] = ispecies.index(orig_idx)

        # Handle None values (when ncr == nspecies, these are set to None)
        # In R, indexing NULL returns NULL, so we need to check for None in Python
        if m_balance is not None:
            m_balance = [m_balance[imatch[i]] if imatch[i] is not None else None
                        for i in range(norig)]
        if Astar is not None:
            Astar = [Astar[imatch[i]] if imatch[i] is not None else None
                    for i in range(norig)]

        # Get a template from first loga_equil to determine shape
        loga_equil1 = loga_equil[0]
        loga_equil_orig = [None] * norig

        for i in range(norig):
            if imatch[i] is not None:
                loga_equil_orig[i] = loga_equil[imatch[i]]

        # Replace None loga_equil with -999 for cr-only species (will be set to 0 where predominant)
        # Use np.full with shape, not full_like, to avoid inheriting NaN values
        ina = [i for i in range(norig) if imatch[i] is None]
        for i in ina:
            loga_equil_orig[i] = np.full(loga_equil1.shape, -999.0)
        loga_equil = loga_equil_orig
        aout['species'] = dout['species']
        aout['values'] = dout['values']

        # Find the grid points where any cr species is predominant
        icr = [i for i in range(len(dout['species']))
               if 'cr' in str(dout['species']['state'].iloc[i])]

        # predominant uses 1-based R indexing (1, 2, 3, ...), convert to 0-based for Python
        predominant = dout['predominant']
        iscr_mask = np.zeros_like(predominant, dtype=bool)
        for icr_idx in icr:
            # Compare with icr_idx + 1 because predominant is 1-based
            iscr_mask |= (predominant == icr_idx + 1)

        # At those grid points, make the aqueous species' activities practically zero
        for i in range(norig):
            if i not in icr:
                loga_equil[i] = np.array(loga_equil[i], copy=True)
                loga_equil[i][iscr_mask] = -999

        # At the grid points where cr species predominate, set their loga_equil to 0 (standard state)
        for i in icr:
            # Compare with i + 1 because predominant is 1-based
            ispredom = (predominant == i + 1)
            loga_equil[i] = np.array(loga_equil[i], copy=True)
            # Set to standard state activity (logact, typically 0) where predominant
            loga_equil[i][ispredom] = dout['species']['logact'].iloc[i]

    # Put together the output
    out = aout.copy()
    out['fun'] = 'equilibrate'  # Mark this as equilibrate output
    out['balance'] = balance
    out['m_balance'] = m_balance
    out['n_balance'] = n_balance
    out['loga_balance'] = loga_balance
    out['Astar'] = Astar
    out['loga_equil'] = loga_equil

    return out

Calculate equilibrium activities of species from affinities.

This function calculates the equilibrium activities of species in (metastable) equilibrium from the affinities of their formation reactions from basis species at given activities.

Parameters

aout : dict
Output from affinity() containing chemical affinities
balance : str, int, or list of float, optional
Balancing method:
- None: autoselect using which_balance()
- str: name of the basis species to balance on
- "length": balance on protein length (for proteins)
- "volume": balance on standard-state volume
- 1: balance on one mole of species (formula units)
- list: user-defined balancing coefficients
loga_balance : float or list of float, optional
Logarithm of the total activity of the balancing basis species. If None, it is calculated from the species' initial activities and n_balance.
ispecies : list of int or list of bool, optional
Indices or boolean mask of the species to include in the equilibration. Default: all species except those with state "cr" (crystalline).
normalize : bool or list of bool, default False
Normalize formulas by balancing coefficients?
as_residue : bool, default False
Use residue basis for proteins?
method : str or list of str, optional
Equilibration method:
- "boltzmann": Boltzmann distribution (for n_balance = 1)
- "reaction": reaction-based equilibration (general method)
If None, "boltzmann" is chosen when all n_balance == 1; otherwise "reaction" is used.
tol : float, default np.finfo(float).eps**0.25
Tolerance for root-finding in reaction method
messages : bool, default True
Whether to print informational messages

Returns

dict
Dictionary containing all aout contents plus:
- fun : str, set to 'equilibrate'
- balance : str or list, balancing description
- m_balance : list, molar formula divisors
- n_balance : list, balancing coefficients
- loga_balance : float or array, log activity of the balanced quantity
- Astar : list of arrays, normalized affinities
- loga_equil : list of arrays, equilibrium log activities

Examples

>>> import pychnosz
>>> pychnosz.basis("CHNOS")
>>> pychnosz.basis("NH3", -2)
>>> pychnosz.species(["alanine", "glycine", "serine"])
>>> a = pychnosz.affinity(NH3=[-80, 60], T=55, P=2000)
>>> e = pychnosz.equilibrate(a, balance="CO2")

Notes

This is a 1:1 replica of the R CHNOSZ equilibrate() function. It:
- handles both Boltzmann and reaction-based equilibration
- supports normalization and a residue basis for proteins
- handles crystalline species via predominance diagrams
- implements balancing logic identical to the R version

def expr_species(formula: str, state: str | None = None, use_state: bool = False) ‑> str
Expand source code
def expr_species(formula: str, state: Optional[str] = None, use_state: bool = False) -> str:
    """
    Render a chemical formula as a LaTeX math-mode string for matplotlib.

    Formatting of the formula itself is delegated to
    ``_format_species_latex``; this wrapper encloses the result in ``$...$``
    and, when requested, appends the physical state as a subscript. (The R
    version of this function returns plotmath expressions instead.)

    Parameters
    ----------
    formula : str
        Chemical formula
    state : str, optional
        Physical state (aq, cr, gas, liq)
    use_state : bool, default False
        When True and ``state`` is non-empty, append ``state`` as a
        LaTeX subscript after the formula

    Returns
    -------
    str
        LaTeX-formatted formula string

    Examples
    --------
    >>> expr_species("H2O")
    '$H_{2}O$'

    >>> expr_species("Ca+2")
    '$Ca^{2+}$'

    >>> expr_species("SO4-2")
    '$SO_{4}^{2-}$'
    """
    formatted = _format_species_latex(formula)

    # No state to show: put the bare formula straight into math mode
    if not (use_state and state):
        return f"${formatted}$"

    # Attach the state as a subscript after the whole formula
    return f"${formatted}_{{{state}}}$"

Format a chemical species formula for display.

This is a simplified version that returns LaTeX-formatted strings suitable for matplotlib. The R version returns plotmath expressions.

Parameters

formula : str
Chemical formula
state : str, optional
Physical state (aq, cr, gas, liq)
use_state : bool, default False
Whether to include state in the formatted output

Returns

str
LaTeX-formatted formula string

Examples

>>> expr_species("H2O")
'$H_{2}O$'
>>> expr_species("Ca+2")
'$Ca^{2+}$'
>>> expr_species("SO4-2")
'$SO_{4}^{2-}$'
def find_tp(predominant: numpy.ndarray) ‑> numpy.ndarray
Expand source code
def find_tp(predominant: np.ndarray) -> np.ndarray:
    """
    Find triple points in a predominance diagram.

    This function identifies the approximate positions of triple points
    (where three phases meet) in a 2-D predominance diagram by locating
    cells with the greatest number of different neighboring values.

    Parameters
    ----------
    predominant : np.ndarray or array_like
        Matrix of integers from diagram() output indicating which species
        predominates at each point. Must be 2-D; each value represents a
        different species/phase. Array-like input (e.g. nested lists) is
        converted with ``np.asarray``.

    Returns
    -------
    np.ndarray
        Integer array of shape (n, 2), where n is the number of triple points
        found. Each row contains [row_index, col_index] of a triple point
        location in the rearranged matrix. Indices are 1-based to match R
        behavior. If no cell has a value greater than zero, an empty array
        of shape (0, 2) is returned.

    Raises
    ------
    ValueError
        If ``predominant`` is not 2-D.

    Examples
    --------
    >>> from pychnosz import *
    >>> reset()
    >>> basis(["corundum", "quartz", "oxygen"])
    >>> species(["kyanite", "sillimanite", "andalusite"])
    >>> a = affinity(T=[200, 900, 99], P=[0, 9000, 101], exceed_Ttr=True)
    >>> d = diagram(a)
    >>> tp = find_tp(d['predominant'])
    >>> # Get T and P at the triple point
    >>> Ttp = a['vals'][0][tp[0, 1] - 1]  # -1 for 0-based indexing
    >>> Ptp = a['vals'][1][::-1][tp[0, 0] - 1]  # reversed and -1

    Notes
    -----
    This is a Python translation of the R function find.tp() from CHNOSZ.
    The R version returns 1-based indices, and this Python version does too
    for consistency. When using these indices to access Python arrays,
    remember to subtract 1.

    The function works by:
    1. Rearranging the matrix as done by diagram() for plotting
    2. For each position with a value > 0, examining its 3x3 neighborhood
       (clipped at the matrix edges)
    3. Counting the number of unique values in that neighborhood
    4. Returning positions with the maximum count (typically 3 or more)
    """
    predominant = np.asarray(predominant)
    if predominant.ndim != 2:
        raise ValueError(
            f"predominant must be a 2-D array, got {predominant.ndim}-D input"
        )

    # Rearrange the matrix in the same way that diagram() does for 2-D predominance diagrams
    # R code: x <- t(x[, ncol(x):1])
    # i.e. first reverse the column order, then transpose
    x = predominant[:, ::-1].T
    nrow, ncol = x.shape

    # All positions with a value greater than zero, in row-major order.
    # For a 2-D input np.argwhere always yields shape (k, 2), so returning
    # it directly when k == 0 keeps the documented (n, 2) output shape.
    valid_positions = np.argwhere(x > 0)
    if len(valid_positions) == 0:
        return valid_positions

    def _n_unique_around(row, col):
        # Number of distinct values in the 3x3 window centred on (row, col),
        # clipped at the matrix edges
        window = x[max(row - 1, 0):min(row + 2, nrow),
                   max(col - 1, 0):min(col + 2, ncol)]
        return len(np.unique(window))

    counts = np.array([_n_unique_around(row, col) for row, col in valid_positions])

    # Keep every position that reaches the maximum count, then convert to
    # 1-based [row, col] indices to match R
    return valid_positions[counts == counts.max()] + 1

Find triple points in a predominance diagram.

This function identifies the approximate positions of triple points (where three phases meet) in a 2-D predominance diagram by locating cells with the greatest number of different neighboring values.

Parameters

predominant : np.ndarray
Matrix of integers from diagram() output indicating which species predominates at each point. Should be a 2-D array where each value represents a different species/phase.

Returns

np.ndarray
Array of shape (n, 2) where n is the number of triple points found. Each row contains [row_index, col_index] of a triple point location. Indices are 1-based to match R behavior.

Examples

>>> from pychnosz import *
>>> reset()
>>> basis(["corundum", "quartz", "oxygen"])
>>> species(["kyanite", "sillimanite", "andalusite"])
>>> a = affinity(T=[200, 900, 99], P=[0, 9000, 101], exceed_Ttr=True)
>>> d = diagram(a)
>>> tp = find_tp(d['predominant'])
>>> # Get T and P at the triple point
>>> Ttp = a['vals'][0][tp[0, 1] - 1]  # -1 for 0-based indexing
>>> Ptp = a['vals'][1][::-1][tp[0, 0] - 1]  # reversed and -1

Notes

This is a Python translation of the R function find.tp() from CHNOSZ. The R version returns 1-based indices, and this Python version does too for consistency. When using these indices to access Python arrays, remember to subtract 1.

The function works by:
1. rearranging the matrix as done by diagram() for plotting;
2. for each position with a value greater than zero, examining its 3x3 neighborhood (clipped at the matrix edges);
3. counting the number of unique values in that neighborhood;
4. returning the positions with the maximum count (typically 3 or more).

def format_reaction(species: List[str | int], coeffs: List[float]) ‑> str
Expand source code
def format_reaction(species: List[Union[str, int]], coeffs: List[float]) -> str:
    """
    Format a reaction as a string for EQ3/6 input.

    Each species contributes its coefficient, formatted to four decimal
    places, followed by its name. All tokens are joined by single spaces.
    Integer entries in ``species`` are resolved to names through
    ``thermo().obigt``. The name 'water' is written as 'H2O' for EQ3
    compatibility.

    Parameters
    ----------
    species : list
        Species names or indices
    coeffs : list
        Stoichiometric coefficients, one per species

    Returns
    -------
    str
        Formatted reaction string like "-1.0000 Fe+3 1.0000 Fe+2 0.2500 O2(g)"

    Raises
    ------
    ValueError
        If ``species`` and ``coeffs`` have different lengths. Pairing them
        with a plain ``zip`` would silently drop the surplus entries and
        produce an unbalanced reaction.
    """
    if len(species) != len(coeffs):
        raise ValueError(
            "species and coeffs must have the same length "
            f"(got {len(species)} and {len(coeffs)})"
        )

    # The database table is only needed to resolve integer indices, so it
    # is fetched lazily on the first index encountered
    obigt = None
    parts = []

    for sp, coeff in zip(species, coeffs):
        # Get species name if we have an index
        if isinstance(sp, (int, np.integer)):
            if obigt is None:
                obigt = thermo().obigt
            sp_name = obigt.loc[int(sp)]['name']
        else:
            sp_name = sp

        # Replace 'water' with 'H2O' for EQ3 compatibility
        if sp_name == 'water':
            sp_name = 'H2O'

        parts.append(f"{coeff:.4f} {sp_name}")

    return " ".join(parts)

Format a reaction as a string for EQ3/6 input.

Parameters

species : list
Species names or indices
coeffs : list
Stoichiometric coefficients

Returns

str
Formatted reaction string like "-1.0000 Fe+3 1.0000 Fe+2 0.2500 O2(g)"
def get_formula_ox(name: str | int) ‑> Dict[str, float]
Expand source code
def get_formula_ox(name: Union[str, int]) -> Dict[str, float]:
    """
    Get quantities of elements and their oxidation states in a chemical compound.

    This function only works when a thermodynamic database with the 'formula_ox'
    column is loaded (e.g., the WORM database). For example, an input of "magnetite"
    would return the following: {'Fe+3': 2.0, 'Fe+2': 1.0, 'O-2': 4.0}.

    Parameters
    ----------
    name : str or int
        The name or database index of the chemical species of interest. Example:
        "magnetite" or 738. numpy integer indices are accepted as well;
        ``bool`` values are not treated as indices.

    Returns
    -------
    dict
        A dictionary where each key represents an element in a specific
        oxidation state, and its value is the number of that element in the
        chemical species' formula.

    Raises
    ------
    TypeError
        If input is not a string or integer (checked before anything else).
    ImportError
        If the wormutils package is not installed.
    AttributeError
        If the WORM thermodynamic database is not loaded (no formula_ox attribute).
    ValueError
        If the species is not found in the database or does not have oxidation
        state information.

    Examples
    --------
    >>> import pychnosz
    >>> # Load the WORM database
    >>> pychnosz.thermo("WORM")
    >>> # Get formula with oxidation states for magnetite
    >>> pychnosz.get_formula_ox("magnetite")
    {'Fe+3': 2.0, 'Fe+2': 1.0, 'O-2': 4.0}
    >>> # Can also use species index
    >>> pychnosz.get_formula_ox(738)
    {'Fe+3': 2.0, 'Fe+2': 1.0, 'O-2': 4.0}

    Notes
    -----
    This function requires the wormutils package to be installed for parsing
    the formula_ox strings. Install it with: pip install wormutils
    """

    # Validate input type first, so a bad argument is reported even when the
    # optional wormutils dependency is missing. bool subclasses int but is
    # never a meaningful species index, so it is excluded explicitly.
    is_index = isinstance(name, (int, np.integer)) and not isinstance(name, bool)
    if not (isinstance(name, str) or is_index):
        raise TypeError(
            "Must provide input as a string (chemical species name) or "
            "an integer (chemical species index)."
        )

    # Import parse_formula_ox from wormutils
    try:
        from wormutils import parse_formula_ox
    except ImportError as err:
        raise ImportError(
            "The wormutils package is required to use get_formula_ox(). "
            "Install it with: pip install wormutils"
        ) from err

    # Get the thermo system
    thermo_sys = thermo()

    # Convert index to name if necessary
    if is_index:
        species_info = info(int(name), messages=False)
        if species_info is None or len(species_info) == 0:
            raise ValueError(f"Species index {name} not found in the database.")
        name = species_info.name.iloc[0]

    # Check if formula_ox exists in thermo()
    df = getattr(thermo_sys, 'formula_ox', None)
    if df is None:
        raise AttributeError(
            "The 'formula_ox' attribute is not available. "
            "This function only works when the WORM thermodynamic database "
            "is loaded. Load it with: pychnosz.thermo('WORM')"
        )

    # Check if the species name exists in the database (single vectorized lookup)
    rows = df[df["name"] == name]
    if rows.empty:
        raise ValueError(
            f"The species '{name}' was not found in the loaded thermodynamic database."
        )

    # Get the formula_ox string for this species
    try:
        formula_ox_str = rows["formula_ox"].iloc[0]
    except (KeyError, IndexError):
        raise ValueError(
            f"The species '{name}' does not have elemental oxidation states "
            "given in the 'formula_ox' column of the loaded thermodynamic database."
        )

    # Check if formula_ox is valid (not NaN or empty)
    if formula_ox_str is None or (isinstance(formula_ox_str, float) and pd.isna(formula_ox_str)) or formula_ox_str == "":
        raise ValueError(
            f"The species '{name}' does not have elemental oxidation states "
            "given in the 'formula_ox' column of the loaded thermodynamic database."
        )

    # Parse the formula_ox string and return
    return parse_formula_ox(formula_ox_str)

Get quantities of elements and their oxidation states in a chemical compound.

This function only works when a thermodynamic database with the 'formula_ox' column is loaded (e.g., the WORM database). For example, an input of "magnetite" would return the following: {'Fe+3': 2.0, 'Fe+2': 1.0, 'O-2': 4.0}.

Parameters

name : str or int
The name or database index of the chemical species of interest. Example: "magnetite" or 738.

Returns

dict
A dictionary where each key represents an element in a specific oxidation state, and its value is the number of that element in the chemical species' formula.

Raises

TypeError
If input is not a string or integer.
AttributeError
If the WORM thermodynamic database is not loaded (no formula_ox attribute).
ValueError
If the species is not found in the database or does not have oxidation state information.

Examples

>>> import pychnosz
>>> # Load the WORM database
>>> pychnosz.thermo("WORM")
>>> # Get formula with oxidation states for magnetite
>>> pychnosz.get_formula_ox("magnetite")
{'Fe+3': 2.0, 'Fe+2': 1.0, 'O-2': 4.0}
>>> # Can also use species index
>>> pychnosz.get_formula_ox(738)
{'Fe+3': 2.0, 'Fe+2': 1.0, 'O-2': 4.0}

Notes

This function requires the wormutils package to be installed for parsing the formula_ox strings. Install it with: pip install wormutils

def get_n_element_ox(names: str | int | List[str | int] | pandas.core.series.Series,
element_ox: str,
binary: bool = False) ‑> List[float | bool]
Expand source code
def get_n_element_ox(names: Union[str, int, List[Union[str, int]], pd.Series],
                     element_ox: str,
                     binary: bool = False) -> List[Union[float, bool]]:
    """
    Get the number of an element of a chosen oxidation state in chemical species formulas.

    This function only works when a thermodynamic database with the 'formula_ox'
    column is loaded (e.g., the WORM database).

    If binary is False, returns a list containing the number of the chosen
    element and oxidation state in the chemical species. For example, how many
    ferrous irons are in the formulae of hematite, fayalite, and magnetite,
    respectively?

    >>> get_n_element_ox(names=["hematite", "fayalite", "magnetite"],
    ...                  element_ox="Fe+2",
    ...                  binary=False)
    [0, 2.0, 1.0]

    If binary is True, returns a list of whether or not ferrous iron is in their
    formulas:

    >>> get_n_element_ox(names=["hematite", "fayalite", "magnetite"],
    ...                  element_ox="Fe+2",
    ...                  binary=True)
    [False, True, True]

    Parameters
    ----------
    names : str, int, list/tuple of str/int, np.ndarray, or pd.Series
        The name or database index of a chemical species, or a collection of
        names or indices. Can also be a pandas Series (e.g., from retrieve())
        or a numpy array (flattened before use).
        Example: ["hematite", "fayalite", "magnetite"] or [788, 782, 798].
    element_ox : str
        An element with a specific oxidation state. For example: "Fe+2" for
        ferrous iron.
    binary : bool, default False
        Should the output list show True/False for presence or absence of the
        element defined by `element_ox`? By default, this parameter is set to
        False so the output list shows quantities of the element instead.

    Returns
    -------
    list of float or list of bool
        A list containing quantities of the chosen element oxidation state in
        the formulas of the chemical species (if `binary=False`) or whether the
        chosen element oxidation state is present in the formulae (if `binary=True`).

    Raises
    ------
    AttributeError
        If the WORM thermodynamic database is not loaded (no formula_ox attribute).
    ValueError
        If a species is not found in the database or does not have oxidation
        state information.

    Examples
    --------
    >>> import pychnosz
    >>> # Load the WORM database
    >>> pychnosz.thermo("WORM")
    >>> # Get counts of Fe+2 in several minerals
    >>> pychnosz.get_n_element_ox(["hematite", "fayalite", "magnetite"], "Fe+2")
    [0, 2.0, 1.0]
    >>> # Get binary presence/absence
    >>> pychnosz.get_n_element_ox(["hematite", "fayalite", "magnetite"], "Fe+2", binary=True)
    [False, True, True]
    >>> # Can also use with retrieve()
    >>> r = pychnosz.retrieve("Fe", ["Si", "O", "H"], state=["cr"])
    >>> pychnosz.get_n_element_ox(r, "Fe+2")
    [1, 0, 0, 2.0, 1, 0, 1, 3.0, 1, 3.0, 0, 7.0]

    Notes
    -----
    This function requires the wormutils package to be installed for parsing
    the formula_ox strings. Install it with: pip install wormutils
    """

    # Normalize `names` to a flat Python list of names/indices
    if isinstance(names, pd.Series):
        # e.g. the output of retrieve(); tolist() yields plain Python scalars
        names = names.values.tolist()
    elif isinstance(names, np.ndarray):
        names = names.ravel().tolist()
    elif isinstance(names, (list, tuple)):
        names = list(names)
    else:
        # A single name or index
        names = [names]

    # Count element_ox in each species' formula (0 when it does not appear)
    n_list = [get_formula_ox(name).get(element_ox, 0) for name in names]

    if binary:
        # bool() keeps plain Python bools even when the counts are numpy scalars
        return [bool(n != 0) for n in n_list]
    return n_list

Get the number of an element of a chosen oxidation state in chemical species formulas.

This function only works when a thermodynamic database with the 'formula_ox' column is loaded (e.g., the WORM database).

If binary is False, returns a list containing the number of the chosen element and oxidation state in the chemical species. For example, how many ferrous irons are in the formulae of hematite, fayalite, and magnetite, respectively?

>>> get_n_element_ox(names=["hematite", "fayalite", "magnetite"],
...                  element_ox="Fe+2",
...                  binary=False)
[0, 2.0, 1.0]

If binary is True, returns a list of whether or not ferrous iron is in their formulas:

>>> get_n_element_ox(names=["hematite", "fayalite", "magnetite"],
...                  element_ox="Fe+2",
...                  binary=True)
[False, True, True]

Parameters

names : str, int, list of str/int, or pd.Series
The name or database index of a chemical species, or a list of names or indices. Can also be a pandas Series (e.g., from retrieve()). Example: ["hematite", "fayalite", "magnetite"] or [788, 782, 798].
element_ox : str
An element with a specific oxidation state. For example: "Fe+2" for ferrous iron.
binary : bool, default False
Should the output list show True/False for presence or absence of the element defined by element_ox? By default, this parameter is set to False so the output list shows quantities of the element instead.

Returns

list of float or list of bool
A list containing quantities of the chosen element oxidation state in the formulas of the chemical species (if binary=False) or whether the chosen element oxidation state is present in the formulae (if binary=True).

Raises

AttributeError
If the WORM thermodynamic database is not loaded (no formula_ox attribute).
ValueError
If a species is not found in the database or does not have oxidation state information.

Examples

>>> import pychnosz
>>> # Load the WORM database
>>> pychnosz.thermo("WORM")
>>> # Get counts of Fe+2 in several minerals
>>> pychnosz.get_n_element_ox(["hematite", "fayalite", "magnetite"], "Fe+2")
[0, 2.0, 1.0]
>>> # Get binary presence/absence
>>> pychnosz.get_n_element_ox(["hematite", "fayalite", "magnetite"], "Fe+2", binary=True)
[False, True, True]
>>> # Can also use with retrieve()
>>> r = pychnosz.retrieve("Fe", ["Si", "O", "H"], state=["cr"])
>>> pychnosz.get_n_element_ox(r, "Fe+2")
[1, 0, 0, 2.0, 1, 0, 1, 3.0, 1, 3.0, 0, 7.0]

Notes

This function requires the wormutils package to be installed for parsing the formula_ox strings. Install it with: pip install wormutils

def gfun(rhohat, Tc, P, alpha, daldT, beta)
Expand source code
def gfun(rhohat, Tc, P, alpha, daldT, beta):
    """
    Calculate the HKF g function and its temperature/pressure derivatives.

    The g function describes the solvent contribution to the effective
    electrostatic radii of ions (Shock et al., 1992). Inputs are promoted to
    at least 1-D arrays and broadcast against each other.

    Parameters
    ----------
    rhohat : float or array_like
        Density of water in g/cm3. All outputs are zero wherever ``rhohat``
        is not less than 1 (NaN densities included, since NaN < 1 is False).
    Tc : float or array_like
        Temperature in degrees Celsius.
    P : float or array_like
        Pressure in bars. g is set to zero above 6000 bar.
    alpha : float or array_like
        Coefficient of isobaric expansivity of water (K^-1).
    daldT : float or array_like
        Temperature derivative of the coefficient of isobaric expansivity (K^-2).
    beta : float or array_like
        Coefficient of isothermal compressibility of water (bar^-1).

    Returns
    -------
    dict
        Arrays ``g``, ``dgdT``, ``d2gdT2`` and ``dgdP`` with the broadcast
        shape of the inputs. ``dgdT`` and ``d2gdT2`` stay zero where ``alpha``
        or ``daldT`` is NaN; ``dgdP`` stays zero where ``beta`` is NaN.
    """
    ## g and f functions for describing effective electrostatic radii of ions
    ## split from hkf() 20120123 jmd
    ## based on equations in
    ## Shock EL, Oelkers EH, Johnson JW, Sverjensky DA, Helgeson HC, 1992
    ## Calculation of the Thermodynamic Properties of Aqueous Species at High Pressures
    ## and Temperatures: Effective Electrostatic Radii, Dissociation Constants and
    ## Standard Partial Molal Properties to 1000 degrees C and 5 kbar
    ## J. Chem. Soc. Faraday Trans., 88(6), 803-826  doi:10.1039/FT9928800803

    # Vectorized version - handle both scalars and arrays
    rhohat = np.atleast_1d(rhohat)
    Tc = np.atleast_1d(Tc)
    P = np.atleast_1d(P)
    alpha = np.atleast_1d(alpha)
    daldT = np.atleast_1d(daldT)
    beta = np.atleast_1d(beta)

    # Broadcast to same shape
    shape = np.broadcast_shapes(rhohat.shape, Tc.shape, P.shape, alpha.shape, daldT.shape, beta.shape)
    rhohat = np.broadcast_to(rhohat, shape)
    Tc = np.broadcast_to(Tc, shape)
    P = np.broadcast_to(P, shape)
    alpha = np.broadcast_to(alpha, shape)
    daldT = np.broadcast_to(daldT, shape)
    beta = np.broadcast_to(beta, shape)

    # Initialize output arrays (zero wherever the calculation is skipped)
    g = np.zeros(shape)
    dgdT = np.zeros(shape)
    d2gdT2 = np.zeros(shape)
    dgdP = np.zeros(shape)

    # only rhohat less than 1 will give results other than zero
    mask = rhohat < 1
    if not np.any(mask):
        return {"g": g, "dgdT": dgdT, "d2gdT2": d2gdT2, "dgdP": dgdP}

    # Table 3
    ag1 = -2.037662
    ag2 = 5.747000E-3
    ag3 = -6.557892E-6
    bg1 = 6.107361
    bg2 = -1.074377E-2
    bg3 = 1.268348E-5

    # Work only with masked values
    Tc_m = Tc[mask]
    P_m = P[mask]
    rhohat_m = rhohat[mask]
    alpha_m = alpha[mask]
    daldT_m = daldT[mask]
    beta_m = beta[mask]

    # Eq. 25
    ag = ag1 + ag2 * Tc_m + ag3 * Tc_m ** 2
    # Eq. 26
    bg = bg1 + bg2 * Tc_m + bg3 * Tc_m ** 2
    # Eq. 24
    g_m = ag * (1 - rhohat_m) ** bg

    # Table 4
    af1 = 0.3666666E2
    af2 = -0.1504956E-9
    af3 = 0.5017997E-13

    # limits of the f function (region II of Fig. 6)
    ifg = (Tc_m > 155) & (P_m < 1000) & (Tc_m < 355)

    # Temperature and pressure terms shared by f and its derivatives.
    # Below 155 degC, tr is negative, so its fractional powers (4.8, 3.8, 2.8)
    # evaluate to NaN. Those values are always discarded, because every f
    # correction below is applied only where ifg is True. Silence numpy's
    # "invalid value" warnings instead of emitting them on every
    # low-temperature call.
    tr = (Tc_m - 155) / 300
    dp = 1000 - P_m
    with np.errstate(invalid="ignore"):
        pterm = af2 * dp ** 3 + af3 * dp ** 4
        # Eq. 33
        f = (tr ** 4.8 + af1 * tr ** 16) * pterm
        # Eqn. 76
        d2fdT2 = (0.0608 / 300 * tr ** 2.8 + af1 / 375 * tr ** 14) * pterm
        # Eqn. 75
        dfdT = (0.016 * tr ** 3.8 + 16 * af1 / 300 * tr ** 15) * pterm
        # Eqn. 74
        dfdP = -(tr ** 4.8 + af1 * tr ** 16) * (3 * af2 * dp ** 2 + 4 * af3 * dp ** 3)

    # Eq. 32 - apply the f correction only in region II
    g_m = np.where(ifg, g_m - f, g_m)

    # at P > 6000 bar (in DEW calculations), g is zero 20170926
    g_m = np.where(P_m > 6000, 0, g_m)

    ## now we have g at P, T
    # put the results in their right place (where rhohat < 1)
    g[mask] = g_m

    ## the rest is to get its partial derivatives with pressure and temperature
    ## after Johnson et al., 1992
    # alpha - coefficient of isobaric expansivity (K^-1)
    # daldT - temperature derivative of coefficient of isobaric expansivity (K^-2)
    # beta - coefficient of isothermal compressibility (bar^-1)
    d2bdT2 = 2 * bg3  # Eqn. 73
    d2adT2 = 2 * ag3  # Eqn. 72
    dbdT = bg2 + 2*bg3*Tc_m  # Eqn. 71
    dadT = ag2 + 2*ag3*Tc_m  # Eqn. 70

    # Initialize derivative arrays for masked region
    dgdT_m = np.zeros_like(g_m)
    d2gdT2_m = np.zeros_like(g_m)
    dgdP_m = np.zeros_like(g_m)

    # Calculate temperature derivatives where alpha and daldT are not NaN
    alpha_valid = ~np.isnan(alpha_m) & ~np.isnan(daldT_m)
    if np.any(alpha_valid):
        # Work with valid subset
        av_idx = alpha_valid
        bg_av = bg[av_idx]
        rhohat_av = rhohat_m[av_idx]
        alpha_av = alpha_m[av_idx]
        daldT_av = daldT_m[av_idx]
        g_av = g_m[av_idx]
        ag_av = ag[av_idx]
        dbdT_av = dbdT[av_idx]
        dadT_av = dadT[av_idx]

        # Handle log of (1-rhohat) safely
        with np.errstate(divide='ignore', invalid='ignore'):
            log_term = np.log(1 - rhohat_av)
            log_term = np.where(np.isfinite(log_term), log_term, 0)

        # Eqn. 69
        dgadT = bg_av*rhohat_av*alpha_av*(1-rhohat_av)**(bg_av-1) + log_term*g_av/ag_av*dbdT_av
        D = rhohat_av

        # transcribed from SUPCRT92/reac92.f
        dDdT = -D * alpha_av
        dDdTT = -D * (daldT_av - alpha_av**2)
        Db = (1-D)**bg_av
        dDbdT = -bg_av*(1-D)**(bg_av-1)*dDdT + log_term*Db*dbdT_av
        dDbdTT = -(bg_av*(1-D)**(bg_av-1)*dDdTT + (1-D)**(bg_av-1)*dDdT*dbdT_av + \
            bg_av*dDdT*(-(bg_av-1)*(1-D)**(bg_av-2)*dDdT + log_term*(1-D)**(bg_av-1)*dbdT_av)) + \
            log_term*(1-D)**bg_av*d2bdT2 - (1-D)**bg_av*dbdT_av*dDdT/(1-D) + log_term*dbdT_av*dDbdT
        d2gdT2_calc = ag_av*dDbdTT + 2*dDbdT*dadT_av + Db*d2adT2

        # Apply f correction where ifg is True
        ifg_av = ifg[av_idx]
        d2fdT2_av = d2fdT2[av_idx]
        dfdT_av = dfdT[av_idx]
        d2gdT2_calc = np.where(ifg_av, d2gdT2_calc - d2fdT2_av, d2gdT2_calc)

        dgdT_calc = g_av/ag_av*dadT_av + ag_av*dgadT  # Eqn. 67
        dgdT_calc = np.where(ifg_av, dgdT_calc - dfdT_av, dgdT_calc)

        dgdT_m[av_idx] = dgdT_calc
        d2gdT2_m[av_idx] = d2gdT2_calc

    # Calculate dgdP where beta is not NaN
    beta_valid = ~np.isnan(beta_m)
    if np.any(beta_valid):
        bv_idx = beta_valid
        bg_bv = bg[bv_idx]
        rhohat_bv = rhohat_m[bv_idx]
        beta_bv = beta_m[bv_idx]
        g_bv = g_m[bv_idx]

        dgdP_calc = -bg_bv*rhohat_bv*beta_bv*g_bv*(1-rhohat_bv)**-1  # Eqn. 66
        ifg_bv = ifg[bv_idx]
        dfdP_bv = dfdP[bv_idx]
        dgdP_calc = np.where(ifg_bv, dgdP_calc - dfdP_bv, dgdP_calc)
        dgdP_m[bv_idx] = dgdP_calc

    # Put results back into full arrays
    dgdT[mask] = dgdT_m
    d2gdT2[mask] = d2gdT2_m
    dgdP[mask] = dgdP_m

    return {"g": g, "dgdT": dgdT, "d2gdT2": d2gdT2, "dgdP": dgdP}
def group_formulas() ‑> pandas.core.frame.DataFrame
Expand source code
def group_formulas() -> pd.DataFrame:
    """
    Return chemical formulas of water and the 20 amino acid residues.

    Every residue formula is built as the sum of its sidechain group
    (e.g. [Ala] = CH3) and the unfolded protein backbone group
    [UPBB] = C2H2NO. Rows are labeled "H2O" followed by the sidechain
    names "[Ala]" ... "[Tyr]"; the [UPBB] group itself is not returned
    as a row.

    Returns
    -------
    DataFrame
        Integer element counts with columns C, H, N, O, S and 21 rows
        (H2O and the 20 residues).
    """
    # element counts are listed in the order C, H, N, O, S
    h2o = [0, 2, 0, 1, 0]
    backbone = np.array([2, 2, 1, 1, 0])  # [UPBB]
    sidechains = {
        '[Ala]': [1, 3, 0, 0, 0],
        '[Cys]': [1, 3, 0, 0, 1],
        '[Asp]': [2, 3, 0, 2, 0],
        '[Glu]': [3, 5, 0, 2, 0],
        '[Phe]': [7, 7, 0, 0, 0],
        '[Gly]': [0, 1, 0, 0, 0],
        '[His]': [4, 5, 2, 0, 0],
        '[Ile]': [4, 9, 0, 0, 0],
        '[Lys]': [4, 10, 1, 0, 0],
        '[Leu]': [4, 9, 0, 0, 0],
        '[Met]': [3, 7, 0, 0, 1],
        '[Asn]': [2, 4, 1, 1, 0],
        '[Pro]': [3, 5, 0, 0, 0],
        '[Gln]': [3, 6, 1, 1, 0],
        '[Arg]': [4, 10, 3, 0, 0],
        '[Ser]': [1, 3, 0, 1, 0],
        '[Thr]': [2, 5, 0, 1, 0],
        '[Val]': [3, 7, 0, 0, 0],
        '[Trp]': [9, 8, 1, 0, 0],
        '[Tyr]': [7, 7, 0, 1, 0],
    }

    # residue = sidechain group + backbone group (row-wise broadcast)
    residues = np.array(list(sidechains.values())) + backbone
    table = np.vstack([h2o, residues])

    return pd.DataFrame(table,
                        index=['H2O', *sidechains],
                        columns=['C', 'H', 'N', 'O', 'S'])

Return chemical formulas of amino acid residues.

This function returns a DataFrame with the chemical formulas of H2O and the 20 amino acid residues; each residue formula is the sum of its sidechain group and the unfolded protein backbone group [UPBB].

Returns

DataFrame
Chemical formulas with elements C, H, N, O, S as columns and residues as rows
def hkf(property=None,
parameters=None,
T=298.15,
P=1,
contrib=['n', 's', 'o'],
H2O_props=['rho'],
water_model='SUPCRT92')
Expand source code
def hkf(property=None, parameters=None, T=298.15, P=1,
    contrib = ["n", "s", "o"], H2O_props=["rho"], water_model="SUPCRT92"):
    """
    Calculate G, H, S, Cp, V, kT and/or E of aqueous species using the
    revised HKF equations of state.

    Parameters
    ----------
    property : list of str
        Names of the properties to calculate ("G", "H", "S", "Cp", "V",
        "kT", "E").
    parameters : pandas.DataFrame
        One row per species. Rows whose "state" is not "aq" get a NaN for
        every property. For aqueous species the columns "name", "G", "H",
        "S", "z.T", "a1.a", "a2.b", "a3.c", "a4.d", "c1.e", "c2.f" and
        "omega.lambda" are read; a1, a2, a4, c2 and omega are multiplied by
        10**-1, 10**2, 10**4, 10**4 and 10**5 before use (on a copy, the
        caller's DataFrame is not modified).
    T : float or array-like
        Temperature(s) in K.
    P : float or array-like
        Pressure(s) in bar. When one of T and P has a single value it is
        repeated to the length of the other.
    contrib : list of str
        Contributions to sum: "n" (nonsolvation), "s" (solvation) and/or
        "o" (origination).
    H2O_props : list of str
        Water properties needed by the caller (e.g. for subcrt() output).
        The Born functions, epsilon and the water-model-specific properties
        are appended to a copy of this list; the argument is not modified.
    water_model : str
        "SUPCRT92", "IAPWS95" or "DEW"; selects which extra water properties
        are requested from water().

    Returns
    -------
    tuple
        (out_dict, H2O_PT): out_dict maps each row label of `parameters` to
        a dict {property: values over the T/P conditions}; H2O_PT is the
        output of water() at T and P.

    Raises
    ------
    ValueError
        If `contrib` contains anything other than "n", "s" or "o".

    Notes
    -----
    The nonsolvation contributions to kT and E have not been ported from
    the R version yet. When the "n" contribution is requested for them (or
    for any other unsupported property) that contribution is NaN and a
    warning is issued.
    """
    # constants
    Tr = 298.15 # K
    Pr = 1      # bar
    Theta = 228 # K
    Psi = 2600  # bar

    # nonsolvation, solvation, and origination contribution
    notcontrib = [c for c in contrib if c not in ("n", "s", "o")]
    if notcontrib:
        raise ValueError(f"contrib must be in ['n', 's', 'o']; got {notcontrib}")

    # Convert T and P to float arrays for vectorized operations. Float dtype
    # matters: np.full_like() below takes the dtype of its template, and
    # integer arrays cannot be raised to the negative powers used in the
    # S and Cp equations.
    T = np.atleast_1d(np.asarray(T, dtype=float))
    P = np.atleast_1d(np.asarray(P, dtype=float))

    # make T and P equal length
    if P.size < T.size:
        P = np.full_like(T, P[0] if P.size == 1 else P)
    if T.size < P.size:
        T = np.full_like(P, T[0] if T.size == 1 else T)

    n_conditions = T.size

    # get water properties
    # rho - for subcrt() output and g function
    # Born functions and epsilon - for HKF calculations
    # Build a new list: extending H2O_props in place would modify the
    # caller's list and the shared default ["rho"], so the same names would
    # be appended again on every call.
    if isinstance(H2O_props, str):
        H2O_props = [H2O_props]
    H2O_props = list(H2O_props) + ["QBorn", "XBorn", "YBorn", "epsilon"]

    if water_model == "SUPCRT92":
        # using H2O92D.f from SUPCRT92: alpha, daldT, beta - for partial derivatives of omega (g function)
        H2O_props += ["alpha", "daldT", "beta"]
    elif water_model == "IAPWS95":
        # using IAPWS-95: NBorn, UBorn - for compressibility, expansibility
        H2O_props += ["alpha", "daldT", "beta", "NBorn", "UBorn"]
    elif water_model == "DEW":
        # using DEW model: get beta to calculate dgdP
        H2O_props += ["alpha", "daldT", "beta"]

    H2O_PrTr = water(H2O_props, T=Tr, P=Pr)
    H2O_PT = water(H2O_props, T=T, P=P)

    # Handle dict output from water function
    def get_water_prop(water_dict, prop):
        """Helper function to get water property from dict or DataFrame"""
        if isinstance(water_dict, dict):
            return water_dict[prop]
        else:
            return water_dict.loc["1", prop]

    # Get epsilon values and handle potential zeros
    epsilon_PT = get_water_prop(H2O_PT, "epsilon")
    epsilon_PrTr = get_water_prop(H2O_PrTr, "epsilon")

    # Check for zero or very small epsilon values and warn
    if np.any(epsilon_PT == 0) or np.any(np.abs(epsilon_PT) < 1e-10):
        warnings.warn(f"HKF: epsilon at P,T is zero or very small: {epsilon_PT}. H2O_PT keys: {H2O_PT.keys() if isinstance(H2O_PT, dict) else 'not dict'}")

    with np.errstate(divide='ignore', invalid='ignore'):
        ZBorn = -1 / epsilon_PT
        ZBorn_PrTr = -1 / epsilon_PrTr

    out_dict = {} # dictionary to store output

    for k in parameters.index:

        if parameters["state"][k] != "aq":
            out_dict[k] = {p: float('NaN') for p in property}
            continue

        # parameters of this species; work on a copy so the unit scaling
        # below does not touch the caller's DataFrame
        PAR = parameters.loc[k, :].copy()

        PAR["a1.a"] = PAR["a1.a"]*10**-1
        PAR["a2.b"] = PAR["a2.b"]*10**2
        PAR["a4.d"] = PAR["a4.d"]*10**4
        PAR["c2.f"] = PAR["c2.f"]*10**4
        PAR["omega.lambda"] = PAR["omega.lambda"]*10**5

        # TODO (GB conversion note): the R version substitutes Cp and V for
        # missing EoS parameters and zeroes remaining NA parameters:
        #   if "n" in contrib:
        #       if all(is.na(PAR[, 18:19])): PAR[, 18] = PAR["Cp"]
        #       if all(is.na(PAR[, 14:17])): PAR[, 14] = convert(PAR["V"], "calories")
        #       hasEOS = any(!is.na(PAR[, 14:21]))
        #       if hasEOS: PAR[, 14:21][, is.na(PAR[, 14:21])] = 0

        # compute values of omega(P,T) from those of omega(Pr,Tr)
        # using g function etc. (Shock et al., 1992 and others)
        omega = PAR["omega.lambda"]  # omega_PrTr
        # its derivatives are zero unless the g function kicks in
        dwdP = np.zeros(n_conditions)
        dwdT = np.zeros(n_conditions)
        d2wdT2 = np.zeros(n_conditions)
        Z = PAR["z.T"]

        omega_PT = np.full(n_conditions, omega)
        # a missing charge disables the g function, as in R's !is.na(Z)
        z_missing = (
            Z is None
            or (isinstance(Z, str) and Z == "NA")
            or (isinstance(Z, (float, np.floating)) and np.isnan(Z))
        )
        if not z_missing and Z != 0 and PAR["name"] != "H+":
            # compute derivatives of omega: g and f functions (Shock et al., 1992; Johnson et al., 1992)
            rhohat = get_water_prop(H2O_PT, "rho")/1000  # just converting kg/m3 to g/cm3

            # temporarily filter out Python's warnings about dividing by zero, which is possible
            # with the equations in the gfunction
            # Possible complex output is accounted for in gfun().
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                g = gfun(rhohat, T-273.15, P, get_water_prop(H2O_PT, "alpha"), get_water_prop(H2O_PT, "daldT"), get_water_prop(H2O_PT, "beta"))

            # after SUPCRT92/reac92.f
            eta = 1.66027E5
            reref = (Z**2) / (omega/eta + Z/(3.082 + 0))
            re = reref + abs(Z) * g["g"]
            omega_PT = eta * (Z**2/re - Z/(3.082 + g["g"]))
            Z3 = abs(Z**3)/re**2 - Z/(3.082 + g["g"])**2
            Z4 = abs(Z**4)/re**3 - Z/(3.082 + g["g"])**3
            dwdP = (-eta * Z3 * g["dgdP"])
            dwdT = (-eta * Z3 * g["dgdT"])
            d2wdT2 = (2 * eta * Z4 * g["dgdT"]**2 - eta * Z3 * g["d2gdT2"])

        species_out = {}
        for PROP in property:
            # sum the requested contributions (vectorized over T/P)
            hkf_p = np.zeros(n_conditions)

            for icontrib in contrib:
                if icontrib == "n":
                    # nonsolvation ghs equations
                    if PROP == "H":
                        p_c = PAR["c1.e"]*(T-Tr) - PAR["c2.f"]*(1/(T-Theta)-1/(Tr-Theta))
                        p_a = PAR["a1.a"]*(P-Pr) + PAR["a2.b"]*np.log((Psi+P)/(Psi+Pr)) + \
                          ((2*T-Theta)/(T-Theta)**2)*(PAR["a3.c"]*(P-Pr)+PAR["a4.d"]*np.log((Psi+P)/(Psi+Pr)))
                        p = p_c + p_a
                    elif PROP == "S":
                        p_c = PAR["c1.e"]*np.log(T/Tr) - \
                          (PAR["c2.f"]/Theta)*( 1/(T-Theta)-1/(Tr-Theta) + \
                          np.log( (Tr*(T-Theta))/(T*(Tr-Theta)) )/Theta )
                        p_a = (T-Theta)**(-2)*(PAR["a3.c"]*(P-Pr)+PAR["a4.d"]*np.log((Psi+P)/(Psi+Pr)))
                        p = p_c + p_a
                    elif PROP == "G":
                        p_c = -PAR["c1.e"]*(T*np.log(T/Tr)-T+Tr) - \
                          PAR["c2.f"]*( (1/(T-Theta)-1/(Tr-Theta))*((Theta-T)/Theta) - \
                          (T/Theta**2)*np.log((Tr*(T-Theta))/(T*(Tr-Theta))) )
                        p_a = PAR["a1.a"]*(P-Pr) + PAR["a2.b"]*np.log((Psi+P)/(Psi+Pr)) + \
                          (PAR["a3.c"]*(P-Pr) + PAR["a4.d"]*np.log((Psi+P)/(Psi+Pr)))/(T-Theta)
                        p = p_c + p_a
                        # at Tr,Pr, if the origination contribution is not NA, ensure the nonsolvation contribution is 0, not NA
                        if not np.isnan(PAR["G"]):
                            p = np.where((T==Tr) & (P==Pr), 0, p)
                    # nonsolvation cp v kt e equations
                    elif PROP == "Cp":
                        p = PAR["c1.e"] + PAR["c2.f"] * ( T - Theta ) ** (-2)
                    elif PROP == "V":
                        p = convert_cm3bar(PAR["a1.a"]) + \
                          convert_cm3bar(PAR["a2.b"]) / (Psi + P) + \
                          (convert_cm3bar(PAR["a3.c"]) + convert_cm3bar(PAR["a4.d"]) / (Psi + P)) / (T - Theta)
                    else:
                        # TODO: kT and E are not ported yet. R version:
                        #   kT: (convert(a2, "cm3bar") + convert(a4, "cm3bar") / (T - Theta)) * (Psi + P)^(-2)
                        #   E:  convert(-(a3 + a4 / convert(Psi + P, "calories")) * (T - Theta)^(-2), "cm3bar")
                        # Use NaN instead of silently reusing a value left over
                        # from a previous property.
                        warnings.warn(f"hkf: nonsolvation contribution to {PROP} is not available; using NaN")
                        p = np.full(n_conditions, np.nan)

                elif icontrib == "s":
                    # solvation ghs equations
                    if PROP == "G":
                        p = -omega_PT*(ZBorn+1) + omega*(ZBorn_PrTr+1) + omega*get_water_prop(H2O_PrTr, "YBorn")*(T-Tr)
                        # at Tr,Pr, if the origination contribution is not NA, ensure the solvation contribution is 0, not NA
                        if not np.isnan(PAR["G"]):
                            p = np.where((T==Tr) & (P==Pr), 0, p)
                    elif PROP == "H":
                        p = -omega_PT*(ZBorn+1) + omega_PT*T*get_water_prop(H2O_PT, "YBorn") + T*(ZBorn+1)*dwdT + \
                               omega*(ZBorn_PrTr+1) - omega*Tr*get_water_prop(H2O_PrTr, "YBorn")
                    elif PROP == "S":
                        p = omega_PT*get_water_prop(H2O_PT, "YBorn") + (ZBorn+1)*dwdT - omega*get_water_prop(H2O_PrTr, "YBorn")
                    # solvation cp v kt e equations
                    elif PROP == "Cp":
                        p = omega_PT*T*get_water_prop(H2O_PT, "XBorn") + 2*T*get_water_prop(H2O_PT, "YBorn")*dwdT + T*(ZBorn+1)*d2wdT2
                    elif PROP == "V":
                        term1 = -convert_cm3bar(omega_PT) * get_water_prop(H2O_PT, "QBorn")
                        term2 = convert_cm3bar(dwdP) * (-ZBorn - 1)
                        p = term1 + term2
                    # TODO: the partial derivatives of omega are not included here for kT and E
                    # (to do it, see p. 820 of SOJ+92 ... but kT requires d2wdP2 which we don't have yet)
                    elif PROP == "kT":
                        p = convert_cm3bar(omega) * get_water_prop(H2O_PT, "NBorn")
                    elif PROP == "E":
                        p = -convert_cm3bar(omega) * get_water_prop(H2O_PT, "UBorn")
                    else:
                        p = np.full(n_conditions, np.nan)

                else:  # "o" (contrib was validated above)
                    # origination ghs equations
                    if PROP == "G":
                        p = PAR["G"] - PAR["S"] * (T-Tr)
                        # don't inherit NA from PAR["S"] at Tr
                        p = np.where(T == Tr, PAR["G"], p)
                    elif PROP == "H":
                        p = np.full(n_conditions, PAR["H"])
                    elif PROP == "S":
                        p = np.full(n_conditions, PAR["S"])
                    # origination eos equations (Cp, V, kT, E): senseless
                    else:
                        p = np.zeros(n_conditions)

                # accumulate the contribution
                hkf_p = hkf_p + p

            species_out[PROP] = hkf_p

        # species have to be numbered (k) instead of named because of name repeats in db (e.g., cr polymorphs)
        out_dict[k] = species_out

    return out_dict, H2O_PT
def info(species: str | int | List[str | int] | pandas.core.series.Series | None = None,
state: str | List[str] | None = None,
check_it: bool = True,
messages: bool = True) ‑> pandas.core.frame.DataFrame | int | List[int] | None
Expand source code
def info(species: Optional[Union[str, int, List[Union[str, int]], pd.Series]] = None,
         state: Optional[Union[str, List[str]]] = None,
         check_it: bool = True,
         messages: bool = True) -> Union[pd.DataFrame, int, List[int], None]:
    """
    Search for species in the thermodynamic database.

    Parameters
    ----------
    species : str, int, list of str/int, pd.Series, or None
        Species name, formula, abbreviation, or OBIGT index.
        Integer indices may also be NumPy integer scalars (e.g. np.int64);
        they are converted to plain int before the lookup.
        Can also be a pandas Series (e.g., from retrieve()).
        If None, returns summary information about the database.
    state : str, list of str, or None
        Physical state(s) to match ('aq', 'cr', 'gas', 'liq')
    check_it : bool, default True
        Whether to perform consistency checks on thermodynamic data
    messages : bool, default True
        Whether to print informational messages

    Returns
    -------
    pd.DataFrame, int, list of int, or None
        - If species is None: prints database summary, returns None
        - If species is numeric: returns DataFrame with species data
        - If species is string: returns species index(es) or NA if not found

    Raises
    ------
    ValueError
        If species is not one of the supported types.

    Examples
    --------
    >>> # Get database summary
    >>> info()

    >>> # Find species index
    >>> info("H2O")

    >>> # Get species data by index
    >>> info(1)

    >>> # Search with specific state
    >>> info("CO2", "aq")

    >>> # Use output from retrieve()
    >>> zn_species = retrieve("Zn", ["O", "H"], state="aq")
    >>> info(zn_species)
    """
    import numbers

    thermo_obj = thermo()

    # Initialize database if needed
    if not thermo_obj.is_initialized():
        thermo_obj.reset()

    # Return database summary if no species specified
    if species is None:
        return _print_database_summary(thermo_obj, messages)

    # Handle pandas Series (e.g., from retrieve())
    if isinstance(species, pd.Series):
        # Extract the integer indices from the Series values
        indices = species.values.tolist()
        return _info_numeric(indices, thermo_obj, check_it, messages)

    # Handle numeric species indices. numbers.Integral covers Python ints and
    # NumPy integer scalars, which are not instances of int; they are
    # normalized to int so _info_numeric sees the same types either way.
    if isinstance(species, numbers.Integral):
        return _info_numeric(int(species), thermo_obj, check_it, messages)
    if isinstance(species, list) and all(isinstance(s, numbers.Integral) for s in species):
        return _info_numeric([int(s) for s in species], thermo_obj, check_it, messages)

    # Handle string species names/formulas
    if isinstance(species, (str, list)):
        return _info_character(species, state, thermo_obj, messages)

    raise ValueError(f"Invalid species type: {type(species)}")

Search for species in the thermodynamic database.

Parameters

species : str, int, list of str/int, pd.Series, or None
Species name, formula, abbreviation, or OBIGT index. Can also be a pandas Series (e.g., from retrieve()). If None, returns summary information about the database.
state : str, list of str, or None
Physical state(s) to match ('aq', 'cr', 'gas', 'liq')
check_it : bool, default True
Whether to perform consistency checks on thermodynamic data
messages : bool, default True
Whether to print informational messages

Returns

pd.DataFrame, int, list of int, or None
  • If species is None: prints database summary, returns None
  • If species is numeric: returns DataFrame with species data
  • If species is string: returns species index(es) or NA if not found

Examples

>>> # Get database summary
>>> info()
>>> # Find species index
>>> info("H2O")
>>> # Get species data by index
>>> info(1)
>>> # Search with specific state
>>> info("CO2", "aq")
>>> # Use output from retrieve()
>>> zn_species = retrieve("Zn", ["O", "H"], state="aq")
>>> info(zn_species)
def ionize_aa(aa: pandas.core.frame.DataFrame,
property: str = 'Z',
T: float | numpy.ndarray = 25.0,
P: float | str | numpy.ndarray = 'Psat',
pH: float | numpy.ndarray = 7.0,
ret_val: str | None = None,
suppress_Cys: bool = False) ‑> pandas.core.frame.DataFrame
Expand source code
def ionize_aa(aa: pd.DataFrame,
              property: str = "Z",
              T: Union[float, np.ndarray] = 25.0,
              P: Union[float, str, np.ndarray] = "Psat",
              pH: Union[float, np.ndarray] = 7.0,
              ret_val: Optional[str] = None,
              suppress_Cys: bool = False) -> pd.DataFrame:
    """
    Calculate additive ionization properties of proteins.

    This function calculates the net charge or other ionization properties
    of proteins based on amino acid composition at specified T, P, and pH.
    T, P and pH are recycled to a common length, and subcrt() is called
    once for each distinct (T, P) pair.

    Parameters
    ----------
    aa : DataFrame
        Amino acid composition data; the columns "Cys", "Asp", "Glu",
        "His", "Lys", "Arg", "Tyr" and "chains" are used
    property : str, default "Z"
        Property to calculate:
        - "Z": net charge
        - "A": chemical affinity
        - Other subcrt properties (G, H, S, Cp, V)
    T : float or array, default 25.0
        Temperature in degrees Celsius
    P : float, str, or array, default "Psat"
        Pressure in bar, or "Psat" for saturation
    pH : float or array, default 7.0
        pH value(s)
    ret_val : str, optional
        Return value type:
        - "pK": return pK values
        - "alpha": return degree of formation
        - "aavals": return amino acid values
        - None: return ionization property (default)
    suppress_Cys : bool, default False
        Suppress cysteine ionization

    Returns
    -------
    DataFrame
        Ionization properties (one row per T/P/pH condition, one column per
        protein in `aa`), or the intermediate table selected by `ret_val`.

    Examples
    --------
    >>> from pychnosz import *
    >>> aa = pinfo(pinfo("LYSC_CHICK"))
    >>> Z = ionize_aa(aa, pH=7.0)
    """
    # Ensure inputs are arrays
    T = np.atleast_1d(T)
    if isinstance(P, str):
        P = np.array([P] * len(T))
    else:
        P = np.atleast_1d(P)
    pH_arr = np.atleast_1d(pH)

    # Get maximum length and recycle arrays cyclically (np.resize)
    lmax = max(len(T), len(P), len(pH_arr))
    T = np.resize(T, lmax)
    if isinstance(P[0], str):
        P = np.array([P[0]] * lmax)
    else:
        P = np.resize(P, lmax)
    pH_arr = np.resize(pH_arr, lmax)

    # Turn pH into a matrix with as many columns as ionizable groups (9)
    pH_matrix = np.tile(pH_arr[:, np.newaxis], (1, 9))

    # Charges for ionizable groups
    charges = np.array([-1, -1, -1, 1, 1, 1, -1, 1, -1])
    charges_matrix = np.tile(charges, (lmax, 1))

    # The ionizable groups
    neutral = ["[Cys]", "[Asp]", "[Glu]", "[His]", "[Lys]", "[Arg]", "[Tyr]", "[AABB]", "[AABB]"]
    charged = ["[Cys-]", "[Asp-]", "[Glu-]", "[His+]", "[Lys+]", "[Arg+]", "[Tyr-]", "[AABB+]", "[AABB-]"]

    # Get row numbers in OBIGT
    ineutral = [info(g, "aq") for g in neutral]
    icharged = [info(g, "aq") for g in charged]

    # Find the distinct (T, P) pairs. unique_pTP holds the position of the
    # first occurrence of each pair, in order of appearance, and indices
    # maps every condition to that pair's row in the subcrt() output. A dict
    # preserves insertion order; a set has no defined order and cannot be
    # used for this mapping.
    pTP = [f"{t}_{p}" for t, p in zip(T, P)]
    row_of_key = {}
    unique_pTP = []
    indices = []
    for i, tp in enumerate(pTP):
        if tp not in row_of_key:
            row_of_key[tp] = len(unique_pTP)
            unique_pTP.append(i)
        indices.append(row_of_key[tp])

    # Determine which properties to calculate; G is always needed for the
    # pK values and is requested only once
    sprop = ["G"] if property in ("A", "Z", "G") else ["G", property]

    # Convert T to Kelvin for subcrt
    TK = convert(T, "K")

    # Call subcrt for unique T, P combinations
    unique_T = TK[unique_pTP]
    unique_P = P[unique_pTP]

    all_species = ineutral + icharged
    sout = subcrt(all_species, T=unique_T, P=unique_P, property=sprop, convert=False)

    # Extract G values
    Gs = np.zeros((len(unique_pTP), len(all_species)))
    for i in range(len(all_species)):
        if isinstance(sout['out'], dict):
            # Single species result
            Gs[:, i] = sout['out']['G']
        else:
            # Multiple species result
            Gs[:, i] = sout['out'][i]['G'].values

    # Gibbs energy difference for each group
    DG = Gs[:, 9:18] - Gs[:, 0:9]

    # Build matrix for all T, P values (including duplicates)
    DG_full = DG[indices, :]

    # Calculate pK values (-logK); the charge sign orients each reaction
    DG_full = DG_full * charges
    pK = np.zeros_like(DG_full)
    for i in range(pK.shape[1]):
        pK[:, i] = convert(DG_full[:, i], "logK", T=TK)

    # Return pK if requested
    if ret_val == "pK":
        return pd.DataFrame(pK, columns=charged)

    # Calculate alpha (degree of formation)
    alpha = 1 / (1 + 10 ** (charges_matrix * (pH_matrix - pK)))

    # Suppress cysteine ionization if requested
    if suppress_Cys:
        alpha[:, 0] = 0

    # Return alpha if requested
    if ret_val == "alpha":
        return pd.DataFrame(alpha, columns=charged)

    # Calculate amino acid values
    if property == "Z":
        aavals = charges_matrix.copy()
    elif property == "A":
        aavals = -charges_matrix * (pH_matrix - pK)
    else:
        # Extract property values from subcrt output
        prop_vals = np.zeros((len(unique_pTP), len(all_species)))
        for i in range(len(all_species)):
            if isinstance(sout['out'], dict):
                prop_vals[:, i] = sout['out'][property]
            else:
                prop_vals[:, i] = sout['out'][i][property].values

        # Build matrix for all T, P values
        prop_vals_full = prop_vals[indices, :]

        # Property difference for each group
        aavals = prop_vals_full[:, 9:18] - prop_vals_full[:, 0:9]

    # Return aavals if requested
    if ret_val == "aavals":
        return pd.DataFrame(aavals, columns=charged)

    # Contribution from each group
    aavals = aavals * alpha

    # Get counts of ionizable groups from aa
    # Columns: Cys, Asp, Glu, His, Lys, Arg, Tyr, chains, chains
    ionize_cols = ["Cys", "Asp", "Glu", "His", "Lys", "Arg", "Tyr", "chains", "chains"]
    aa_counts = aa[ionize_cols].values.astype(float)

    # Calculate total ionization property
    out = np.dot(aavals, aa_counts.T)

    return pd.DataFrame(out)

Calculate additive ionization properties of proteins.

This function calculates the net charge or other ionization properties of proteins based on amino acid composition at specified T, P, and pH.

Parameters

aa : DataFrame
Amino acid composition data
property : str, default "Z"
Property to calculate:
  • "Z": net charge
  • "A": chemical affinity
  • Other subcrt properties (G, H, S, Cp, V)
T : float or array, default 25.0
Temperature in degrees Celsius
P : float, str, or array, default "Psat"
Pressure in bar, or "Psat" for saturation
pH : float or array, default 7.0
pH value(s)
ret_val : str, optional
Return value type:
  • "pK": return pK values
  • "alpha": return degree of formation
  • "aavals": return amino acid values
  • None: return ionization property (default)
suppress_Cys : bool, default False
Suppress cysteine ionization

Returns

DataFrame
Ionization properties

Examples

>>> from pychnosz import *
>>> aa = pinfo(pinfo("LYSC_CHICK"))
>>> Z = ionize_aa(aa, pH=7.0)
def list_OBIGT_files() ‑> List[str]
Expand source code
def list_OBIGT_files() -> List[str]:
    """
    List available OBIGT database files.

    Looks in the package's ``extdata/OBIGT`` directory for files ending in
    ``.csv``.

    Returns
    -------
    list of str
        Sorted, de-duplicated file names with the ``.csv`` extension
        removed; an empty list when the directory does not exist.
    """
    obigt_dir = os.path.join(os.path.dirname(__file__), 'extdata', 'OBIGT')
    if not os.path.exists(obigt_dir):
        return []

    suffix = '.csv'
    stems = {entry[:-len(suffix)] for entry in os.listdir(obigt_dir) if entry.endswith(suffix)}
    return sorted(stems)

List available OBIGT database files.

Returns

list of str
Names of the available .csv files in the OBIGT directory, without the .csv extension
def load_WORM(keep_default: bool = False, messages: bool = True) ‑> bool
Expand source code
def load_WORM(keep_default: bool = False, messages: bool = True) -> bool:
    """
    Load the WORM (Water-Organic-Rock-Microbe) thermodynamic database.

    This function downloads and loads the WORM database from the WORM-db GitHub
    repository. By default, it replaces the OBIGT database with WORM data,
    keeping only water, H+, and e- from the original database. If the WORM
    references can be downloaded, they replace the current references.

    Parameters
    ----------
    keep_default : bool, default False
        If False, replace OBIGT with minimal species (water, H+, e-) before
        loading WORM. If True, add WORM species to the existing OBIGT database.
    messages : bool, default True
        Whether to print informational messages

    Returns
    -------
    bool
        True if WORM database was loaded successfully, False otherwise

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.reset()
    >>> # Load WORM database (replaces default OBIGT)
    >>> pychnosz.load_WORM()
    >>>
    >>> # Load WORM database while keeping default OBIGT species
    >>> pychnosz.reset()
    >>> pychnosz.load_WORM(keep_default=True)

    Notes
    -----
    The WORM database is downloaded from:

    - Species data: https://raw.githubusercontent.com/worm-portal/WORM-db/master/wrm_data_latest.csv
    - References: https://raw.githubusercontent.com/worm-portal/WORM-db/master/references.csv

    This feature is exclusive to the Python version of CHNOSZ.
    """

    # WORM database URLs
    url_data = "https://raw.githubusercontent.com/worm-portal/WORM-db/master/wrm_data_latest.csv"
    url_refs = "https://raw.githubusercontent.com/worm-portal/WORM-db/master/references.csv"

    # Name for source_file column
    worm_source_name = "wrm_data_latest.csv"

    # Check if we can connect to the WORM database
    if not can_connect_to(url_data):
        if messages:
            print("load_WORM: could not reach WORM database repository")
        return False

    # Download WORM species data
    worm_data = download_worm_data(url_data)
    if worm_data is None:
        if messages:
            print("load_WORM: failed to download WORM species data")
        return False

    # Get the thermodynamic system
    thermo_sys = thermo()

    if not keep_default:
        # Keep only essential species (water, H+, e-)
        from ..core.info import info
        try:
            # Get indices for essential species
            essential_species = []
            for species in ["water", "H+", "e-"]:
                idx = info(species)
                if idx is not None:
                    if isinstance(idx, (list, tuple)):
                        essential_species.extend(idx)
                    else:
                        essential_species.append(idx)

            if essential_species:
                # Keep only essential species
                minimal_obigt = thermo_sys.obigt.loc[essential_species].copy()
                thermo_sys.obigt = minimal_obigt
        except Exception as e:
            # best effort: on failure the current OBIGT is kept and WORM is added to it
            if messages:
                print(f"load_WORM: warning - error keeping essential species: {e}")

    # Add WORM species data (suppress add_OBIGT messages)
    try:
        # Add source_file column to worm_data before adding
        worm_data['source_file'] = worm_source_name

        add_OBIGT(worm_data, messages=False)
    except Exception as e:
        if messages:
            print(f"load_WORM: failed to add WORM species: {e}")
        return False

    # Try to download and load WORM references
    if can_connect_to(url_refs):
        worm_refs = download_worm_data(url_refs)
        if worm_refs is not None:
            # Replace refs with WORM refs
            thermo_sys.refs = worm_refs

    # Update formula_ox if it exists in WORM data
    # This is already handled by add_OBIGT, but we ensure it's set correctly
    if 'formula_ox' in thermo_sys.obigt.columns:
        formula_ox_df = pd.DataFrame({
            'name': thermo_sys.obigt['name'],
            'formula_ox': thermo_sys.obigt['formula_ox']
        })
        formula_ox_df.index = thermo_sys.obigt.index
        thermo_sys.formula_ox = formula_ox_df

    # Print single summary message
    if messages:
        final_obigt = thermo_sys.obigt
        total_species = len(final_obigt)
        aqueous_species = len(final_obigt[final_obigt['state'] == 'aq'])
        print(f"The WORM thermodynamic database has been loaded: {aqueous_species} aqueous, {total_species} total species")

    return True

Load the WORM (Water-Organic-Rock-Microbe) thermodynamic database.

This function downloads and loads the WORM database from the WORM-db GitHub repository. By default, it replaces the OBIGT database with WORM data, keeping only water, H+, and e- from the original database.

Parameters

keep_default : bool, default False
If False, replace OBIGT with minimal species (water, H+, e-) before loading WORM. If True, add WORM species to the existing OBIGT database.
messages : bool, default True
Whether to print informational messages

Returns

bool
True if WORM database was loaded successfully, False otherwise

Examples

>>> import pychnosz
>>> pychnosz.reset()
>>> # Load WORM database (replaces default OBIGT)
>>> pychnosz.load_WORM()
>>>
>>> # Load WORM database while keeping default OBIGT species
>>> pychnosz.reset()
>>> pychnosz.load_WORM(keep_default=True)

Notes

The WORM database is downloaded from:
  • Species data: https://raw.githubusercontent.com/worm-portal/WORM-db/master/wrm_data_latest.csv
  • References: https://raw.githubusercontent.com/worm-portal/WORM-db/master/references.csv

This feature is exclusive to the Python version of CHNOSZ.

def makeup(formula: str | int | List[str | int],
multiplier: float | List[float] = 1.0,
sum_formulas: bool = False,
count_zero: bool = False) ‑> Dict[str, float] | List[Dict[str, float]]
Expand source code
def makeup(formula: Union[str, int, List[Union[str, int]]], 
           multiplier: Union[float, List[float]] = 1.0,
           sum_formulas: bool = False,
           count_zero: bool = False) -> Union[Dict[str, float], List[Dict[str, float]]]:
    """
    Return elemental makeup (counts) of chemical formula(s).
    
    Handles formulas with parenthetical subformulas, suffixed formulas,
    charges, and fractional coefficients.
    
    Parameters
    ----------
    formula : str, int, or list
        Chemical formula(s) or species index(es). Integer indices (Python
        ``int`` or NumPy integer) are looked up in ``thermo().obigt``.
        A dict with string keys (an already-computed makeup), or a list
        whose first item is such a dict, is returned unchanged.
    multiplier : float or list of float
        Multiplier(s) to apply to formula coefficients. A list must have
        length 1 or the same length as ``formula``.
    sum_formulas : bool
        If True (and ``formula`` is a list), return the element-wise sum of
        all formulas as a single dict
    count_zero : bool
        If True (and ``formula`` is a list), include zero counts for all
        elements appearing in any formula; formulas that resolve to None
        get NaN for every element
        
    Returns
    -------
    dict or list of dict
        Elemental composition(s) as {element: count} dictionaries. With
        ``sum_formulas`` or ``count_zero`` the elements are ordered by first
        appearance across the input formulas, so the order is reproducible.

    Raises
    ------
    ValueError
        If ``multiplier`` has neither length 1 nor the number of formulas.
    FormulaError
        If a species index is unknown, the database is not initialized,
        or a formula cannot be parsed or validated.
        
    Examples
    --------
    >>> makeup("H2O")
    {'H': 2, 'O': 1}
    
    >>> makeup("Ca(OH)2")
    {'Ca': 1, 'O': 2, 'H': 2}
    
    >>> makeup(["H2O", "CO2"])
    [{'H': 2, 'O': 1}, {'C': 1, 'O': 2}]
    """
    # Handle matrix input: one makeup per row
    if isinstance(formula, np.ndarray) and formula.ndim == 2:
        return [makeup(formula[i, :]) for i in range(formula.shape[0])]
    
    # Handle named numeric objects (return unchanged)
    if isinstance(formula, dict) and all(isinstance(k, str) for k in formula.keys()):
        return formula
    
    # Handle list of named objects
    if isinstance(formula, list) and len(formula) > 0:
        if isinstance(formula[0], dict) and all(isinstance(k, str) for k in formula[0].keys()):
            return formula
    
    # Prepare multiplier
    if not isinstance(multiplier, list):
        multiplier = [multiplier]
    
    # Handle multiple formulas
    if isinstance(formula, list):
        if len(multiplier) != 1 and len(multiplier) != len(formula):
            raise ValueError("multiplier does not have length = 1 or length = number of formulas")
        
        if len(multiplier) == 1:
            multiplier = multiplier * len(formula)
        
        # Get formulas for any species indices
        formula = get_formula(formula)
        
        results = [makeup(f, multiplier[i]) for i, f in enumerate(formula)]
        
        if not (sum_formulas or count_zero):
            return results
        
        # Union of elements in order of first appearance. A set would make
        # the key order depend on string hash randomization between runs.
        all_elements = list(dict.fromkeys(
            element
            for result in results if result is not None
            for element in result
        ))
        
        # Handle sum_formulas option
        if sum_formulas:
            return {
                element: sum(result.get(element, 0) for result in results if result is not None)
                for element in all_elements
            }
        
        # Handle count_zero option: add zero counts for missing elements
        return [
            {element: np.nan for element in all_elements} if result is None
            else {element: result.get(element, 0) for element in all_elements}
            for result in results
        ]
    
    # Handle single formula given as a species index (1-based OBIGT label)
    if isinstance(formula, (int, np.integer)):
        thermo_obj = thermo()
        if thermo_obj.obigt is None:
            raise FormulaError("Thermodynamic database not initialized")
        if formula not in thermo_obj.obigt.index:
            raise FormulaError(f"Species index {formula} not found in OBIGT database")
        formula = thermo_obj.obigt.loc[formula, 'formula']
    
    if formula is None or pd.isna(formula):
        return None
    
    # Parse single formula
    try:
        result = _parse_formula(str(formula))
        
        # Apply multiplier
        if multiplier[0] != 1.0:
            result = {element: count * multiplier[0] for element, count in result.items()}
        
        # Validate elements
        _validate_elements(result)
        
        return result
    
    except Exception as e:
        # Keep the original error attached for debugging
        raise FormulaError(f"Error parsing formula '{formula}': {e}") from e

Return elemental makeup (counts) of chemical formula(s).

Handles formulas with parenthetical subformulas, suffixed formulas, charges, and fractional coefficients.

Parameters

formula : str, int, or list
Chemical formula(s) or species index(es)
multiplier : float or list of float
Multiplier(s) to apply to formula coefficients
sum_formulas : bool
If True, return sum of all formulas
count_zero : bool
If True, include zero counts for all elements appearing in any formula

Returns

dict or list of dict
Elemental composition(s) as {element: count} dictionaries

Examples

>>> makeup("H2O")
{'H': 2, 'O': 1}
>>> makeup("Ca(OH)2")
{'Ca': 1, 'O': 2, 'H': 2}
>>> makeup(["H2O", "CO2"])
[{'H': 2, 'O': 1}, {'C': 1, 'O': 2}]
def mass(formula: str | int | List[str | int]) ‑> float | List[float]
Expand source code
def mass(formula: Union[str, int, List[Union[str, int]]]) -> Union[float, List[float]]:
    """
    Calculate molecular mass of chemical formula(s).
    
    Parameters
    ----------
    formula : str, int, or list
        Chemical formula(s) or species index(es); anything accepted by
        ``makeup``.
        
    Returns
    -------
    float or list of float
        Molecular mass(es) in g/mol. A single float is returned when there
        is exactly one composition. NaN is returned for formulas that
        ``makeup`` resolves to None. The charge pseudo-element 'Z' adds no mass.

    Raises
    ------
    RuntimeError
        If ``thermo().element`` is not loaded.
    FormulaError
        If a formula contains an element missing from the element table.
    """
    thermo_obj = thermo()
    if thermo_obj.element is None:
        raise RuntimeError("Element data not available")
    
    # Convert to elemental compositions
    compositions = makeup(formula, count_zero=False)
    if not isinstance(compositions, list):
        compositions = [compositions]
    
    # Build the element -> mass lookup once instead of filtering the element
    # table for every element of every formula. The first row for an element
    # wins, which matches selecting `.iloc[0]` of the filtered rows.
    element_table = thermo_obj.element
    element_mass = {}
    for symbol, atomic_mass in zip(element_table['element'], element_table['mass']):
        element_mass.setdefault(symbol, atomic_mass)
    
    masses = []
    for comp in compositions:
        if comp is None:
            masses.append(np.nan)
            continue
        
        total_mass = 0.0
        for element, count in comp.items():
            if element == 'Z':
                continue  # Charge has no mass
            if element not in element_mass:
                raise FormulaError(f"Element {element} not found in element database")
            total_mass += count * element_mass[element]
        
        masses.append(total_mass)
    
    if len(masses) == 1:
        return masses[0]
    return masses

Calculate molecular mass of chemical formula(s).

Parameters

formula : str, int, or list
Chemical formula(s) or species index(es)

Returns

float or list of float
Molecular mass(es) in g/mol
def mod_OBIGT(*args, zap: bool = False, **kwargs) ‑> int | List[int]
Expand source code
def mod_OBIGT(*args, zap: bool = False, **kwargs) -> Union[int, List[int]]:
    """
    Add or modify species in the thermodynamic database.

    This function replicates the behavior of R CHNOSZ mod.OBIGT() by allowing
    modification of existing species or addition of new species to thermo().obigt.

    Parameters
    ----------
    *args : int, str, list, or dict
        If first argument is numeric: species index or indices to modify.
        If first argument is str: species name to modify or add.
        If first argument is a list/tuple: species indices (if numeric) or
        species names.
        If the only argument is a dict: it contains all parameters (keyword
        arguments are then not used).
    zap : bool, default False
        If True, clear all properties of existing species except state and
        model before updating
    **kwargs : any
        Named properties to set (e.g., G=-100, S=50, formula="H2O").
        Special properties: name, state, formula, model, E_units.
        A key may be an OBIGT column name (``E_units``, ``a1_a``), its dotted
        form (``E.units``, ``a1.a``), or the part before the first underscore
        (``a1``, ``omega``). A scalar value applies to every species; a list
        must have one value per species.

    Returns
    -------
    int or list of int
        Species index or indices that were modified/added, in the order the
        species were given (a single int when there is one species)

    Raises
    ------
    RuntimeError
        If the thermodynamic system is not initialized.
    ValueError
        If no species or no property is supplied, a list-valued property has
        the wrong length, a property is not an OBIGT column, or the same
        property is supplied more than once.

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.reset()
    >>> # Add new species
    >>> i = pychnosz.mod_OBIGT("myspecies", formula="C2H6", G=-100, S=50)

    >>> # Modify existing species
    >>> i = pychnosz.mod_OBIGT("water", state="liq", G=-56690)

    >>> # Modify by index
    >>> i_h2o = pychnosz.info("water", "liq")
    >>> i = pychnosz.mod_OBIGT(i_h2o, G=-56690)

    >>> # Add multiple species
    >>> i = pychnosz.mod_OBIGT(["X", "Y"], formula=["C12", "C13"], state=["aq", "cr"])

    Notes
    -----
    This function modifies the thermo() object in place.
    The behavior is intended to match R CHNOSZ mod.OBIGT().
    """

    # Get the thermo system
    thermo_sys = thermo()

    # Ensure the thermodynamic system is initialized
    if not thermo_sys.is_initialized() or thermo_sys.obigt is None:
        raise RuntimeError("Thermodynamic system not initialized. Run reset() first.")

    # ---- Collect parameters -------------------------------------------------
    if len(args) == 1 and isinstance(args[0], dict):
        # Called with a dict as first arg (like R's list)
        params = args[0].copy()
    elif len(args) > 0:
        # First positional argument could be species index or name
        first_arg = args[0]
        params = kwargs.copy()
        if isinstance(first_arg, (int, np.integer)):
            params['_index'] = first_arg
        elif isinstance(first_arg, (list, tuple)) and len(first_arg) > 0:
            if isinstance(first_arg[0], (int, np.integer)):
                params['_index'] = list(first_arg)
            else:
                # First arg is list of names
                params['name'] = list(first_arg)
        elif 'name' not in params:
            # First arg is the species name (an explicit name= keyword wins)
            params['name'] = first_arg
    else:
        params = kwargs.copy()

    # Validate we have at least a name/index and one property
    if '_index' not in params and 'name' not in params:
        raise ValueError("Please supply at least a species name and a property to update")

    # The identifier (index or name) and the state are not counted as properties
    if '_index' in params:
        property_keys = set(params.keys()) - {'_index', 'state'}
    else:
        property_keys = set(params.keys()) - {'name', 'state'}

    if len(property_keys) == 0:
        raise ValueError("Please supply at least a species name and a property to update")

    # ---- Resolve species indices (None = species to be added) ---------------
    if '_index' in params:
        ispecies = params.pop('_index')
        if not isinstance(ispecies, list):
            ispecies = [ispecies]

        # Get species names from indices
        speciesname = []
        for idx in ispecies:
            sp_info = info(idx)
            speciesname.append(sp_info['name'].iloc[0] if isinstance(sp_info, pd.DataFrame) else sp_info['name'])
    else:
        names = params.get('name')
        if not isinstance(names, list):
            names = [names]

        states = params.get('state')
        if states is not None and not isinstance(states, list):
            states = [states]

        speciesname = names

        ispecies = []
        for i, name in enumerate(names):
            if not states:
                state = None
            elif len(states) == 1:
                # A single state applies to every name, just as a scalar state
                # is broadcast to every row of the property table below
                state = states[0]
            else:
                state = states[i] if i < len(states) else None
            try:
                idx = info(name, state) if state else info(name)
            except Exception:
                # Species doesn't exist - it will be added
                idx = None
            # info() returns an int if found
            ispecies.append(int(idx) if isinstance(idx, (int, np.integer)) else None)

    # ---- Build the table of new values (one row per species) ----------------
    nspecies = len(ispecies)
    param_df = {}
    for key, value in params.items():
        if isinstance(value, list):
            if len(value) != nspecies:
                raise ValueError(f"Length of '{key}' ({len(value)}) doesn't match number of species ({nspecies})")
            param_df[key] = value
        else:
            param_df[key] = [value] * nspecies

    args_df = pd.DataFrame(param_df)

    # Map parameter names to OBIGT column names: the column itself, its dotted
    # form (e.g. "E.units" -> "E_units"), and the part before the first
    # underscore (e.g. "a1" -> "a1_a"). Exact column names take precedence.
    obigt_cols = thermo_sys.obigt.columns.tolist()
    col_mapping = {}
    for col in obigt_cols:
        col_mapping.setdefault(col.replace('_', '.'), col)
        if '_' in col:
            col_mapping.setdefault(col.split('_')[0], col)
    col_mapping.update({col: col for col in obigt_cols})

    icol_names = []
    for key in args_df.columns:
        if key not in col_mapping:
            raise ValueError(f"Property '{key}' not in thermo$OBIGT")
        icol_names.append(col_mapping[key])

    duplicated = sorted({col for col in icol_names if icol_names.count(col) > 1})
    if duplicated:
        raise ValueError(f"Properties supplied more than once: {', '.join(duplicated)}")

    # From here on refer to the arguments by their OBIGT column names, so that
    # values given under an alias (e.g. 'E.units' or 'a1') are really written
    args_df.columns = icol_names

    # Separate new species from existing ones (positions in ispecies)
    inew = [i for i, idx in enumerate(ispecies) if idx is None]
    iold = [i for i, idx in enumerate(ispecies) if idx is not None]

    # ---- Add new species -----------------------------------------------------
    if inew:
        # Create blank rows
        newrows = pd.DataFrame(index=range(len(inew)), columns=obigt_cols)
        newrows[:] = np.nan

        # Set defaults
        newrows['state'] = thermo_sys.opt.get('state', 'aq')
        newrows['E_units'] = thermo_sys.opt.get('E.units', 'J')

        # Fill in provided columns; the formula defaults to the species name
        formula_col = 'formula' if 'formula' in args_df.columns else 'name'
        for i, idx in enumerate(inew):
            newrows.at[i, 'formula'] = args_df.iloc[idx][formula_col]
            for col_name in icol_names:
                newrows.at[i, col_name] = args_df.at[idx, col_name]

        # Guess model from state
        for i in range(len(newrows)):
            if pd.isna(newrows.at[i, 'model']):
                newrows.at[i, 'model'] = 'HKF' if newrows.at[i, 'state'] == 'aq' else 'CGL'

        # Validate formulas
        for formula in newrows['formula']:
            try:
                makeup(formula)
            except Exception:
                warnings.warn("Please supply a valid chemical formula as the species name or in the 'formula' argument")
                raise

        # Add to OBIGT. newrows was created with object columns; inferring real
        # dtypes keeps concatenation from turning numeric OBIGT columns into object.
        ntotal_before = len(thermo_sys.obigt)
        thermo_sys.obigt = pd.concat([thermo_sys.obigt, newrows.infer_objects()], ignore_index=True)

        # Reset index to 1-based
        thermo_sys.obigt.index = range(1, len(thermo_sys.obigt) + 1)

        # Record the indices of the new entries and report them
        for i, idx in enumerate(inew):
            ispecies[idx] = ntotal_before + i + 1
            row = newrows.iloc[i]
            print(f"mod_OBIGT: added {row['name']}({row['state']}) with {row['model']} model and energy units of {row['E_units']}")

    # ---- Modify existing species ---------------------------------------------
    for i in iold:
        idx = ispecies[i]

        # Get old values
        oldprop = thermo_sys.obigt.loc[idx, icol_names].copy()
        state = thermo_sys.obigt.loc[idx, 'state']
        model = thermo_sys.obigt.loc[idx, 'model']

        # If zap, clear all values except state and model
        if zap:
            thermo_sys.obigt.loc[idx, :] = np.nan
            thermo_sys.obigt.loc[idx, 'state'] = state
            thermo_sys.obigt.loc[idx, 'model'] = model

        # Get new properties
        newprop = args_df.iloc[i][icol_names].copy()

        # Compare values element-wise, treating NaN as equal to NaN
        has_change = False
        for col in icol_names:
            old_val = oldprop[col] if col in oldprop.index else np.nan
            new_val = newprop[col] if col in newprop.index else np.nan
            if pd.isna(old_val) and pd.isna(new_val):
                continue
            if pd.isna(old_val) or pd.isna(new_val) or old_val != new_val:
                has_change = True
                break

        if not has_change:
            print(f"mod_OBIGT: no change for {speciesname[i]}({state})")
        else:
            for col_name in icol_names:
                thermo_sys.obigt.loc[idx, col_name] = args_df.iloc[i][col_name]
            print(f"mod_OBIGT: updated {speciesname[i]}({state})")

    # Return indices in the order the species were given (like R's ispecies)
    if len(ispecies) == 1:
        return ispecies[0]
    return ispecies

Add or modify species in the thermodynamic database.

This function replicates the behavior of R CHNOSZ mod.OBIGT() by allowing modification of existing species or addition of new species to thermo().obigt.

Parameters

*args : int, str, list, or dict
If the first argument is numeric: species index or indices to modify. If the first argument is a str: species name to modify or add. If the first argument is a list: species indices (if numeric) or species names. If the only argument is a dict: it contains all parameters.
zap : bool, default False
If True, clear all properties except state and model before updating
**kwargs : any
Named properties to set (e.g., G=-100, S=50, formula="H2O"). Special properties: name, state, formula, model, E_units.

Returns

int or list of int
Species index or indices that were modified/added

Examples

>>> import pychnosz
>>> pychnosz.reset()
>>> # Add new species
>>> i = pychnosz.mod_OBIGT("myspecies", formula="C2H6", G=-100, S=50)
>>> # Modify existing species
>>> i = pychnosz.mod_OBIGT("water", state="liq", G=-56690)
>>> # Modify by index
>>> i_h2o = pychnosz.info("water", "liq")
>>> i = pychnosz.mod_OBIGT(i_h2o, G=-56690)
>>> # Add multiple species
>>> i = pychnosz.mod_OBIGT(["X", "Y"], formula=["C12", "C13"], state=["aq", "cr"])

Notes

This function modifies the thermo() object in place. The behavior exactly matches R CHNOSZ mod.OBIGT().

def pinfo(protein: str | int | pandas.core.frame.DataFrame | List,
organism: str | None = None,
residue: bool = False,
regexp: bool = False) ‑> pandas.core.frame.DataFrame | numpy.ndarray | int
Expand source code
def _protein_per_residue(out: pd.DataFrame) -> pd.DataFrame:
    """
    Normalize protein composition rows to a per-residue basis, in place.

    Divides the chains column (position 4) and the 20 amino acid count
    columns (positions 5-24) by each row's total amino acid count.

    Parameters
    ----------
    out : DataFrame
        Protein composition table with the thermo().protein column layout;
        it is modified in place.

    Returns
    -------
    DataFrame
        The same object, for convenience.
    """
    row_sums = out.iloc[:, 5:25].sum(axis=1)
    cols = out.columns[4:25]
    # Assign whole columns so integer count columns are replaced by float ones
    out[cols] = out[cols].div(row_sums, axis=0)
    return out


def pinfo(protein: Union[str, int, pd.DataFrame, List],
          organism: Optional[str] = None,
          residue: bool = False,
          regexp: bool = False) -> Union[pd.DataFrame, np.ndarray, int]:
    """
    Get protein information from thermo().protein.

    This function retrieves protein data from the thermodynamic database.
    The behavior depends on the input type:
    - DataFrame: returns a copy of the DataFrame (possibly per residue);
      the protein database is not needed for this
    - int or list of ints: returns rows from thermo().protein (row numbers
      outside the table are dropped)
    - str or list of str: searches for protein by name, returns row number(s)
      (NaN where there is no match)

    Parameters
    ----------
    protein : str, int, DataFrame, or list
        Protein identifier(s) or data
    organism : str, optional
        Organism identifier (used with protein name)
    residue : bool, default False
        Return per-residue amino acid composition: the chains column and all
        20 amino acid columns are divided by the total amino acid count
    regexp : bool, default False
        Use regular expression matching for protein search

    Returns
    -------
    DataFrame, array, or int
        Protein information or row numbers

    Raises
    ------
    RuntimeError
        If the protein database is needed but not loaded.
    TypeError
        If ``protein`` has an unsupported type.

    Examples
    --------
    >>> # Get protein by name
    >>> iprotein = pinfo("LYSC_CHICK")
    >>> # Get protein data by row number
    >>> protein_data = pinfo(iprotein)
    """
    # A DataFrame already holds the composition data; the database is not used
    if isinstance(protein, pd.DataFrame):
        out = protein.copy()
        return _protein_per_residue(out) if residue else out

    t_p = thermo().protein

    if t_p is None:
        raise RuntimeError("Protein database not loaded. Run reset() first.")

    # If input is numeric, get rows from thermo().protein
    if isinstance(protein, (int, np.integer)):
        protein = [protein]

    if isinstance(protein, (list, np.ndarray)) and all(isinstance(x, (int, np.integer)) for x in protein):
        nrow = len(t_p)
        # Keep only row numbers that exist in the table
        valid_indices = [p for p in protein if 0 <= p < nrow]

        if not valid_indices:
            return pd.DataFrame()

        out = t_p.iloc[valid_indices].copy()
        return _protein_per_residue(out) if residue else out

    # If input is string or list of strings, search for protein
    if isinstance(protein, str):
        protein = [protein]

    if isinstance(protein, list) and all(isinstance(x, str) for x in protein):
        matches = []
        if regexp:
            # Regular expression matching: every matching row is returned
            for prot in protein:
                iprotein = t_p['protein'].str.contains(prot, regex=True, na=False)
                if organism is not None:
                    iorganism = t_p['organism'].str.contains(organism, regex=True, na=False)
                    iprotein = iprotein & iorganism
                indices = np.where(iprotein)[0]
                if len(indices) > 0:
                    matches.extend(indices.tolist())
                else:
                    matches.append(np.nan)
        else:
            # Exact matching against "protein_organism": first matching row
            t_p_names = t_p['protein'] + '_' + t_p['organism']
            if organism is None:
                my_names = protein
            else:
                my_names = [f"{p}_{organism}" for p in protein]
            for name in my_names:
                idx = np.where(t_p_names == name)[0]
                matches.append(idx[0] if len(idx) > 0 else np.nan)

        if len(matches) == 1:
            if np.isnan(matches[0]):
                return np.nan
            return int(matches[0])
        return np.array(matches)

    raise TypeError(f"Unsupported protein type: {type(protein)}")

Get protein information from thermo().protein.

This function retrieves protein data from the thermodynamic database. The behavior depends on the input type:

- DataFrame: returns the DataFrame (possibly per residue)
- int or list of ints: returns rows from thermo().protein
- str: searches for protein by name, returns row number(s)

Parameters

protein : str, int, DataFrame, or list
Protein identifier(s) or data
organism : str, optional
Organism identifier (used with protein name)
residue : bool, default False
Return per-residue amino acid composition
regexp : bool, default False
Use regular expression matching for protein search

Returns

DataFrame, array, or int
Protein information or row numbers

Examples

>>> # Get protein by name
>>> iprotein = pinfo("LYSC_CHICK")
>>> # Get protein data by row number
>>> protein_data = pinfo(iprotein)
def protein_OBIGT(protein: int | List[int] | pandas.core.frame.DataFrame,
organism: str | None = None,
state: str | None = None) ‑> pandas.core.frame.DataFrame
Expand source code
def protein_OBIGT(protein: Union[int, List[int], pd.DataFrame],
                 organism: Optional[str] = None,
                 state: Optional[str] = None) -> pd.DataFrame:
    """
    Calculate protein properties using group additivity.

    This function calculates thermodynamic properties of proteins
    from amino acid composition using the group additivity approach.
    Each protein is the sum of one [AABB] group per chain, one sidechain
    group per amino acid residue, and one backbone group for every residue
    beyond the number of chains.

    Parameters
    ----------
    protein : int, list of int, or DataFrame
        Protein identifier(s) or amino acid composition data
    organism : str, optional
        Organism identifier
    state : str, optional
        Physical state ('aq' or 'cr'). If None, uses thermo().opt['state']

    Returns
    -------
    DataFrame
        Thermodynamic properties in OBIGT format (one row per protein)

    Raises
    ------
    TypeError
        If no protein composition data could be retrieved.
    ValueError
        If a required group is not in thermo().obigt for ``state``.

    Examples
    --------
    >>> iprotein = pinfo("LYSC_CHICK")
    >>> props = protein_OBIGT(iprotein)
    """
    # Get amino acid composition
    aa = pinfo(pinfo(protein, organism))

    if not isinstance(aa, pd.DataFrame):
        raise TypeError("Could not retrieve protein data")

    if state is None:
        state = thermo().opt.get('state', 'aq')

    # The protein backbone group depends on the state: [UPBB] for aq, [PBB] otherwise
    bbgroup = 'UPBB' if state == 'aq' else 'PBB'

    # Group names, in the same order as the group counts built below:
    # [AABB], the sidechain groups (amino acid column names), backbone group
    aa_cols = aa.columns[5:25].tolist()
    groups = [f"[{g}]" for g in ['AABB', *aa_cols, bbgroup]]

    # Row labels of the groups in thermo().obigt; the state filter is applied
    # once instead of for every group
    obigt = thermo().obigt
    obigt_state = obigt[obigt['state'] == state]
    igroup = []
    for group_name in groups:
        matches = obigt_state.index[(obigt_state['name'] == group_name).to_numpy()]
        if len(matches) == 0:
            # Try without brackets if not found
            bare_name = group_name.strip('[]')
            matches = obigt_state.index[(obigt_state['name'] == bare_name).to_numpy()]
        if len(matches) == 0:
            raise ValueError(f"Group {group_name} not found in OBIGT for state {state}")
        igroup.append(matches[0])

    # The properties are in columns 9:21 of thermo().OBIGT (G, H, S, Cp, V, etc.)
    # Column indices: G=9, H=10, S=11, Cp=12, V=13, a1.a=14, a2.b=15, a3.c=16, a4.d=17, c1.e=18, c2.f=19, omega.lambda=20, z.T=21
    groupprops = obigt.loc[igroup, obigt.columns[9:22]]

    # The elements in each of the groups
    groupelements = i2A(igroup)
    element_names = groupelements.columns

    model = 'HKF' if state == 'aq' else 'CGL'

    results = []
    for irow in range(len(aa)):
        aa_row = aa.iloc[irow]

        # Numbers of groups: chains [=AABB], sidechains, protein backbone
        nchains = float(aa_row.iloc[4])
        length = float(aa_row.iloc[5:25].sum())
        npbb = length - nchains
        ngroups = np.array([nchains] + aa_row.iloc[5:25].tolist() + [npbb], dtype=float)

        # Thermodynamic properties and elemental formula by group additivity
        eos = (groupprops.values * ngroups[:, np.newaxis]).sum(axis=0)
        f_in = (groupelements.values * ngroups[:, np.newaxis]).sum(axis=0).round(3)

        # Keep only elements that appear, then turn them into a formula string
        f_dict = {elem: f_in[i] for i, elem in enumerate(element_names) if f_in[i] != 0}
        f = as_chemical_formula(f_dict)

        name = f"{aa_row['protein']}_{aa_row['organism']}"
        print(f"protein_OBIGT: found {name} ({f}, {round(length, 3)} residues)")

        header = {
            'name': name,
            'abbrv': None,
            'formula': f,
            'state': state,
            'ref1': aa_row['ref'],
            'ref2': None,
            'date': None,
            'model': model,
            'E_units': 'cal'
        }

        # Combine header and eos
        results.append({**header, **dict(zip(groupprops.columns, eos))})

    return pd.DataFrame(results)

Calculate protein properties using group additivity.

This function calculates thermodynamic properties of proteins from amino acid composition using the group additivity approach.

Parameters

protein : int, list of int, or DataFrame
Protein identifier(s) or amino acid composition data
organism : str, optional
Organism identifier
state : str, optional
Physical state ('aq' or 'cr'). If None, uses thermo().opt['state']

Returns

DataFrame
Thermodynamic properties in OBIGT format

Examples

>>> iprotein = pinfo("LYSC_CHICK")
>>> props = protein_OBIGT(iprotein)
def protein_basis(protein: int | List[int] | pandas.core.frame.DataFrame,
T: float = 25.0,
normalize: bool = False) ‑> pandas.core.frame.DataFrame
Expand source code
def protein_basis(protein: Union[int, List[int], pd.DataFrame],
                 T: float = 25.0,
                 normalize: bool = False) -> pd.DataFrame:
    """
    Calculate coefficients of basis species in protein formation reactions.

    The coefficients come from ``species_basis`` applied to the protein
    formulas. When H+ is one of the basis species, its coefficient is
    shifted by the protein charge from ``ionize_aa`` at temperature ``T``
    and pH = -logact(H+) taken from the basis definition.

    Parameters
    ----------
    protein : int, list of int, or DataFrame
        Protein identifier(s) or amino acid composition data
    T : float, default 25.0
        Temperature in degrees Celsius
    normalize : bool, default False
        Divide each protein's coefficients by its length

    Returns
    -------
    DataFrame
        Coefficients of basis species (one row per protein)

    Raises
    ------
    TypeError
        If no protein composition data could be retrieved.

    Examples
    --------
    >>> from pychnosz import *
    >>> basis("CHNOSe")
    >>> iprotein = pinfo("LYSC_CHICK")
    >>> coeffs = protein_basis(iprotein)
    """
    composition = pinfo(pinfo(protein))
    if not isinstance(composition, pd.DataFrame):
        raise TypeError("Could not retrieve protein data")

    # Formation-reaction coefficients from the protein formulas
    coeffs = species_basis(protein_formula(composition))

    # Account for protein ionization when H+ is a basis species
    basis_df = thermo().basis
    if basis_df is not None:
        basis_names = basis_df.index.tolist()
        if 'H+' in basis_names:
            hplus_col = basis_names.index('H+')
            ph_value = -basis_df.loc['H+', 'logact']
            charge = ionize_aa(composition, T=T, pH=ph_value).iloc[0, :]
            coeffs.iloc[:, hplus_col] = coeffs.iloc[:, hplus_col] + charge.values

    if normalize:
        coeffs = coeffs.div(protein_length(composition), axis=0)

    return coeffs

Calculate coefficients of basis species in protein formation reactions.

Parameters

protein : int, list of int, or DataFrame
Protein identifier(s) or amino acid composition data
T : float, default 25.0
Temperature in degrees Celsius
normalize : bool, default False
Normalize by protein length

Returns

DataFrame
Coefficients of basis species

Examples

>>> from pychnosz import *
>>> basis("CHNOSe")
>>> iprotein = pinfo("LYSC_CHICK")
>>> coeffs = protein_basis(iprotein)
def protein_formula(protein: int | List[int] | pandas.core.frame.DataFrame,
organism: str | None = None,
residue: bool = False) ‑> pandas.core.frame.DataFrame
Expand source code
def protein_formula(protein: Union[int, List[int], pd.DataFrame],
                   organism: Optional[str] = None,
                   residue: bool = False) -> pd.DataFrame:
    """
    Calculate chemical formulas of proteins.

    The formula of each protein is the sum of its amino acid residue
    formulas plus one H2O per polypeptide chain.

    Parameters
    ----------
    protein : int, list of int, or DataFrame
        Protein identifier(s) or amino acid composition data
    organism : str, optional
        Organism identifier (used with protein number)
    residue : bool, default False
        Return per-residue formula (divide by the number of residues)

    Returns
    -------
    DataFrame
        Chemical formulas with elements C, H, N, O, S as columns, indexed by
        "<protein>_<organism>"; repeated names get ".1", ".2", ... suffixes.

    Raises
    ------
    TypeError
        If the amino acid composition could not be retrieved.

    Examples
    --------
    >>> iprotein = pinfo("LYSC_CHICK")
    >>> formula = protein_formula(iprotein)
    """
    # Resolve identifiers into an amino acid composition table
    aa = pinfo(pinfo(protein, organism))
    if not isinstance(aa, pd.DataFrame):
        raise TypeError("Could not retrieve protein data")

    # Row 0 of group_formulas() is H2O; the remaining rows are the residues
    rf = group_formulas()
    h2o_formula = rf.iloc[0, :].values.astype(float)
    residue_formulas = rf.iloc[1:, :].values.astype(float)

    # Column 4 holds the number of chains, columns 5..24 the residue counts
    n_chains = aa.iloc[:, 4].values.astype(float)
    residue_counts = aa.iloc[:, 5:25].values.astype(float)

    # residues * residue formulas, plus one H2O per chain
    formula = residue_counts @ residue_formulas + np.outer(n_chains, h2o_formula)

    if residue:
        n_residues = aa.iloc[:, 5:25].sum(axis=1).values
        formula = formula / n_residues[:, np.newaxis]

    labels = aa['protein'] + '_' + aa['organism']
    if labels.duplicated().any():
        # First occurrence keeps its name; later repeats get .1, .2, ...
        seen = {}
        deduped = []
        for label in labels:
            previous = seen.get(label)
            if previous is None:
                seen[label] = 0
                deduped.append(label)
            else:
                seen[label] = previous + 1
                deduped.append(f"{label}.{previous + 1}")
        labels = deduped

    return pd.DataFrame(formula, index=labels, columns=['C', 'H', 'N', 'O', 'S'])

Calculate chemical formulas of proteins.

Parameters

protein : int, list of int, or DataFrame
Protein identifier(s) or amino acid composition data
organism : str, optional
Organism identifier (used with protein number)
residue : bool, default False
Return per-residue formula

Returns

DataFrame
Chemical formulas with elements C, H, N, O, S as columns

Examples

>>> iprotein = pinfo("LYSC_CHICK")
>>> formula = protein_formula(iprotein)
def protein_length(protein: int | List[int] | pandas.core.frame.DataFrame,
organism: str | None = None) ‑> int | numpy.ndarray
Expand source code
def protein_length(protein: Union[int, List[int], pd.DataFrame],
                   organism: Optional[str] = None) -> Union[int, np.ndarray]:
    """
    Calculate the length(s) of proteins.

    Parameters
    ----------
    protein : int, list of int, or DataFrame
        Protein identifier(s) or amino acid composition data
    organism : str, optional
        Organism identifier (used with protein number)

    Returns
    -------
    int or array
        Number of amino acid residues for each protein, or 0 if no amino
        acid composition could be retrieved.

    Examples
    --------
    >>> iprotein = pinfo("LYSC_CHICK")
    >>> length = protein_length(iprotein)
    """
    aa = pinfo(pinfo(protein, organism))

    # Nothing usable came back: report zero length
    if not isinstance(aa, pd.DataFrame):
        return 0

    # Columns 5..24 hold the per-residue amino acid counts
    residue_counts = aa.iloc[:, 5:25]
    return residue_counts.sum(axis=1).values

Calculate the length(s) of proteins.

Parameters

protein : int, list of int, or DataFrame
Protein identifier(s) or amino acid composition data
organism : str, optional
Organism identifier (used with protein number)

Returns

int or array
Protein length(s) in amino acid residues

Examples

>>> iprotein = pinfo("LYSC_CHICK")
>>> length = protein_length(iprotein)
def quartz_coesite(PAR, T, P)
Expand source code
def quartz_coesite(PAR, T, P):
    """
    Calculate volume terms for quartz and coesite (SUPCRT92 treatment).

    The code is adapted from REAC92D.f, with constants from SUP92D.f and
    REAC92D.f, following the R CHNOSZ implementation.

    Parameters
    ----------
    PAR : mapping
        Species parameters. Only the ``"name"`` and ``"state"`` entries are
        read. The phase is ``state`` with ``"cr"`` removed, so ``"cr2"`` gives
        phase ``"2"``.
    T : float or array-like
        Temperature in K.
    P : float or array-like
        Pressure in bar.

    Returns
    -------
    dict
        For quartz and coesite: ``intVdP`` and ``intdVdTdP`` (cm3*bar terms
        converted to cal with the SUPCRT92 factor) and ``V``. Each is a
        float array with the broadcast shape of ``T`` and ``P``.
        For any other species: zero arrays of length ``np.size(T)`` under
        the keys ``G``, ``H``, ``S``, ``V``, ``intVdP`` and ``intdVdTdP``.
    """
    # the corrections are 0 for anything other than quartz and coesite
    if PAR["name"] not in ["quartz", "coesite"]:
        # np.size handles scalars, lists and arrays alike
        n = np.size(T)
        # Provide the same keys as the quartz/coesite path as well as G/H/S/V
        return dict(G=np.zeros(n), H=np.zeros(n), S=np.zeros(n), V=np.zeros(n),
                    intVdP=np.zeros(n), intdVdTdP=np.zeros(n))

    # Float arrays with a common shape. The float dtype matters because
    # filling an integer array with Stran or VPrTtb would truncate them.
    T, P = np.broadcast_arrays(np.atleast_1d(np.asarray(T, dtype=float)),
                               np.atleast_1d(np.asarray(P, dtype=float)))

    # Tr, Pr and TtPr (transition temperature at Pr)
    Pr = 1      # bar
    Tr = 298.15 # K
    TtPr = 848  # K
    # constants from SUP92D.f
    aa = 549.824
    ba = 0.65995
    ca = -0.4973e-4
    VPtTta = 23.348
    VPrTtb = 23.72
    Stran = 0.342
    # constants from REAC92D.f
    VPrTra = 22.688 # VPrTr(a-quartz)
    Vdiff = 2.047   # VPrTr(a-quartz) - VPrTr(coesite)
    k = 38.5       # dPdTtr(a/b-quartz)
    #k <- 38.45834    # calculated in CHNOSZ: dPdTtr(info("quartz"))
    # code adapted from REAC92D.f
    qphase = PAR["state"].replace("cr", "")

    if qphase == "2":
        # phase "cr2": uses the VPrTtb volume constant
        Pstar = P.copy()
        Sstar = np.zeros(T.shape)
        V = np.full(T.shape, VPrTtb)
    else:
        Pstar = Pr + k * (T - TtPr)
        Sstar = np.full(T.shape, Stran)
        V = VPrTra + ca*(P-Pr) + (VPtTta - VPrTra - ca*(P-Pr))*(T-Tr) / (TtPr + (P-Pr)/k - Tr)

    # Below the transition temperature there is no transition contribution
    below_transition = T < TtPr
    Pstar = np.where(below_transition, Pr, Pstar)
    Sstar = np.where(below_transition, 0, Sstar)

    if PAR["name"] == "coesite":
        VPrTra = VPrTra - Vdiff
        VPrTtb = VPrTtb - Vdiff
        V = V - Vdiff

    # cm3*bar -> cal conversion factor used by REAC92D.f
    cm3bar_to_cal = 0.023901488

    # Non-finite logarithms (e.g. non-positive arguments) are replaced by 0
    with np.errstate(divide='ignore', invalid='ignore'):
        log_term = np.log((aa + P/k) / (aa + Pstar/k))
        log_term = np.where(np.isfinite(log_term), log_term, 0)

    GVterm = cm3bar_to_cal * (VPrTra * (P - Pstar) + VPrTtb * (Pstar - Pr) - \
        0.5 * ca * (2 * Pr * (P - Pstar) - (P**2 - Pstar**2)) - \
        ca * k * (T - Tr) * (P - Pstar) + \
        k * (ba + aa * ca * k) * (T - Tr) * log_term)
    SVterm = cm3bar_to_cal * (-k * (ba + aa * ca * k) * log_term + ca * k * (P - Pstar)) - Sstar

    # note the minus sign on "SVterm" in order that intdVdTdP has the correct sign
    return dict(intVdP=GVterm, intdVdTdP=-SVterm, V=V)
def ratlab(top: str = 'K+',
bottom: str = 'H+',
molality: bool = False,
reverse_charge: bool = False) ‑> str
Expand source code
def ratlab(top: str = "K+", bottom: str = "H+", molality: bool = False,
           reverse_charge: bool = False) -> str:
    """
    Create formatted text label for activity ratio.

    This function generates a LaTeX-formatted string suitable for use as
    axis labels in matplotlib plots. It shows the ratio of activities of
    two ions, each raised to a power based on the other ion's charge.

    Parameters
    ----------
    top : str, default "K+"
        Chemical formula for the numerator ion
    bottom : str, default "H+"
        Chemical formula for the denominator ion
    molality : bool, default False
        If True, use 'm' (molality) instead of 'a' (activity)
    reverse_charge : bool, default False
        If True, reverse the charge order in formatting (e.g., "Fe+3" becomes
        "Fe^{3+}"). If False, keep the original order (e.g., "Fe+3" becomes
        "Fe^{+3}").

    Returns
    -------
    str
        LaTeX-formatted string for the activity ratio label

    Raises
    ------
    ValueError
        If either species is neutral (has no charge).

    Examples
    --------
    >>> ratlab("K+", "H+")
    'log($a_{K^{+}}$ / $a_{H^{+}}$)'

    >>> ratlab("Ca+2", "H+")
    'log($a_{Ca^{+2}}$ / $a_{H^{+}}^{2}$)'

    >>> ratlab("Ca+2", "H+", reverse_charge=True)
    'log($a_{Ca^{2+}}$ / $a_{H^{+}}^{2}$)'

    >>> ratlab("Mg+2", "Ca+2")
    'log($a_{Mg^{+2}}^{2}$ / $a_{Ca^{+2}}^{2}$)'

    Notes
    -----
    The exponent on each term is the absolute charge of the other ion. This
    keeps the ratio charge-balanced. For example, for Ca+2/H+, the H+ term
    is squared because Ca has a +2 charge. Exponents are not reduced, so
    Mg+2/Ca+2 shows both terms squared.

    The output format is compatible with matplotlib's LaTeX rendering.
    In R CHNOSZ, this uses plotmath expressions; here we use LaTeX strings
    that matplotlib can render.
    """
    # Charges of the ions; a missing 'Z' entry means a neutral species
    Z_top = makeup(top).get('Z', 0)
    Z_bottom = makeup(bottom).get('Z', 0)

    # An exponent of 0 would otherwise be rendered for a neutral species
    if Z_top == 0 or Z_bottom == 0:
        raise ValueError("Cannot create an ion ratio involving one or more neutral species.")

    def _exponent_latex(charge):
        # Superscript text for |charge|. A power of 1 is not shown.
        # Whole numbers are printed without ".0"; fractional charges are
        # kept instead of being truncated.
        power = abs(charge)
        if float(power).is_integer():
            power = int(power)
        return "" if power == 1 else f"^{{{power}}}"

    # If top has charge +2 and bottom has +1, bottom gets exponent 2
    exp_top_str = _exponent_latex(Z_bottom)
    exp_bottom_str = _exponent_latex(Z_top)

    # Format the ion formulas for display
    top_formatted = _format_species_latex(top, reverse_charge=reverse_charge)
    bottom_formatted = _format_species_latex(bottom, reverse_charge=reverse_charge)

    # Choose activity or molality symbol
    a = "m" if molality else "a"

    # Format: log(a_top^exp / a_bottom^exp)
    numerator = f"${a}_{{{top_formatted}}}{exp_top_str}$"
    denominator = f"${a}_{{{bottom_formatted}}}{exp_bottom_str}$"

    return f"log({numerator} / {denominator})"

Create formatted text label for activity ratio.

This function generates a LaTeX-formatted string suitable for use as axis labels in matplotlib plots, showing the ratio of activities of two ions raised to appropriate powers based on their charges.

Parameters

top : str, default "K+"
Chemical formula for the numerator ion
bottom : str, default "H+"
Chemical formula for the denominator ion
molality : bool, default False
If True, use 'm' (molality) instead of 'a' (activity)
reverse_charge : bool, default False
If True, reverse the charge order in formatting (e.g., "Fe+3" becomes "Fe^{3+}"). If False, keep the original order (e.g., "Fe+3" becomes "Fe^{+3}").

Returns

str
LaTeX-formatted string for the activity ratio label

Examples

>>> ratlab("K+", "H+")
'log($a_{K^{+}}$ / $a_{H^{+}}$)'
>>> ratlab("Ca+2", "H+")
'log($a_{Ca^{+2}}$ / $a_{H^{+}}^{2}$)'
>>> ratlab("Ca+2", "H+", reverse_charge=True)
'log($a_{Ca^{2+}}$ / $a_{H^{+}}^{2}$)'
>>> ratlab("Mg+2", "Ca+2")
'log($a_{Mg^{+2}}^{2}$ / $a_{Ca^{+2}}^{2}$)'

Notes

The exponents are determined by the charges of the ions to maintain charge balance in the ratio. For example, for Ca+2/H+, the H+ term is squared because Ca has a +2 charge.

The output format is compatible with matplotlib's LaTeX rendering. In R CHNOSZ, this uses plotmath expressions; here we use LaTeX strings that matplotlib can render.

def ratlab_html(top: str = 'K+', bottom: str = 'H+', molality: bool = False) ‑> str
Expand source code
def ratlab_html(top: str = "K+", bottom: str = "H+", molality: bool = False) -> str:
    """
    Create HTML-formatted text label for activity ratio (for Plotly/HTML rendering).

    This function generates an HTML-formatted string suitable for use with
    Plotly interactive plots. It shows the ratio of activities of two ions,
    each raised to a power based on the other ion's charge.

    This is a companion function to ratlab(), which produces LaTeX format for
    matplotlib. Use ratlab_html() when creating labels for diagram(..., interactive=True).

    Parameters
    ----------
    top : str, default "K+"
        Chemical formula for the numerator ion
    bottom : str, default "H+"
        Chemical formula for the denominator ion
    molality : bool, default False
        If True, use 'm' (molality) instead of 'a' (activity)

    Returns
    -------
    str
        HTML-formatted string for the activity ratio label

    Raises
    ------
    ImportError
        If the optional WORMutils/chemparse dependencies are unavailable.
    ValueError
        If either species is neutral.

    Examples
    --------
    >>> ratlab_html("K+", "H+")
    'log(a<sub>K<sup>+</sup></sub>/a<sub>H<sup>+</sup></sub>)'

    >>> ratlab_html("Ca+2", "H+")
    'log(a<sub>Ca<sup>2+</sup></sub>/a<sup>2</sup><sub>H<sup>+</sup></sub>)'

    >>> ratlab_html("Mg+2", "Ca+2")
    'log(a<sup>2</sup><sub>Mg<sup>2+</sup></sub>/a<sup>2</sup><sub>Ca<sup>2+</sup></sub>)'

    Notes
    -----
    The exponent on each term is the absolute charge of the other ion. This
    keeps the ratio charge-balanced. For example, for Ca+2/H+, the H+ term
    is squared because Ca has a +2 charge. Exponents are not reduced.

    The output format uses HTML tags (<sub>, <sup>) compatible with Plotly.
    For matplotlib plots with LaTeX rendering, use ratlab() instead.

    Requires: WORMutils (for chemlabel) and chemparse (for parse_formula)

    See Also
    --------
    ratlab : LaTeX version for matplotlib
    """
    if not _HTML_DEPS_AVAILABLE:
        raise ImportError(
            "ratlab_html() requires 'WORMutils' and 'chemparse' packages.\n"
            "Install with: pip install WORMutils chemparse"
        )

    def _signed_charge(species):
        # parse_formula() reports the charge count under a "+" or "-" key
        parsed = parse_formula(species)
        if "+" in parsed.keys():
            charge = parsed["+"]
        elif "-" in parsed.keys():
            charge = -parsed["-"]
        else:
            raise ValueError("Cannot create an ion ratio involving one or more neutral species.")
        # Whole-number charges are shown without ".0". float() also covers
        # int charges, which lack is_integer() before Python 3.12.
        return int(charge) if float(charge).is_integer() else charge

    # Parse top first, then bottom (same order of errors as before)
    top_charge = _signed_charge(top)
    bottom_charge = _signed_charge(bottom)

    # If top has charge +2 and bottom has +1, bottom gets exponent 2
    exp_bottom = abs(top_charge)
    exp_top = abs(bottom_charge)

    # Exponents as superscripts (not shown when = 1)
    top_exp_str = "" if exp_top == 1 else "<sup>" + str(exp_top) + "</sup>"
    bottom_exp_str = "" if exp_bottom == 1 else "<sup>" + str(exp_bottom) + "</sup>"

    # Choose activity or molality symbol
    sym = "m" if molality else "a"

    # Format the chemical formulas with chemlabel
    top_formatted = chemlabel(top)
    bottom_formatted = chemlabel(bottom)

    # Format: log(a_top^exp / a_bottom^exp)
    return f"log({sym}{top_exp_str}<sub>{top_formatted}</sub>/{sym}{bottom_exp_str}<sub>{bottom_formatted}</sub>)"

Create HTML-formatted text label for activity ratio (for Plotly/HTML rendering).

This function generates an HTML-formatted string suitable for use with Plotly interactive plots, showing the ratio of activities of two ions raised to appropriate powers based on their charges.

This is a companion function to ratlab() which produces LaTeX format for matplotlib. Use ratlab_html() when creating labels for diagram(…, interactive=True).

Parameters

top : str, default "K+"
Chemical formula for the numerator ion
bottom : str, default "H+"
Chemical formula for the denominator ion
molality : bool, default False
If True, use 'm' (molality) instead of 'a' (activity)

Returns

str
HTML-formatted string for the activity ratio label

Examples

>>> ratlab_html("K+", "H+")
'log(a<sub>K<sup>+</sup></sub>/a<sub>H<sup>+</sup></sub>)'
>>> ratlab_html("Ca+2", "H+")
'log(a<sub>Ca<sup>2+</sup></sub>/a<sup>2</sup><sub>H<sup>+</sup></sub>)'
>>> ratlab_html("Mg+2", "Ca+2")
'log(a<sup>2</sup><sub>Mg<sup>2+</sup></sub>/a<sup>2</sup><sub>Ca<sup>2+</sup></sub>)'

Notes

The exponents are determined by the charges of the ions to maintain charge balance in the ratio. For example, for Ca+2/H+, the H+ term is squared because Ca has a +2 charge.

The output format uses HTML tags (<sub>, <sup>) compatible with Plotly. For matplotlib plots with LaTeX rendering, use ratlab() instead.

Requires: WORMutils (for chemlabel) and chemparse (for parse_formula)

See Also

ratlab()
LaTeX version for matplotlib
def reset(messages: bool = True)
Expand source code
def reset(messages: bool = True):
    """
    Initialize or reset the CHNOSZ thermodynamic system.

    This delegates to the global thermodynamic system object's ``reset``
    method. That method loads the thermodynamic data files, sets up the
    OBIGT database and prepares the system for calculations.

    This is equivalent to the reset() function in the R version of CHNOSZ.

    Parameters
    ----------
    messages : bool, default True
        Whether to print informational messages

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.reset()  # Initialize the system
    reset: thermodynamic system initialized
    """
    get_thermo_system().reset(messages=messages)

Initialize or reset the CHNOSZ thermodynamic system.

This function initializes the global thermodynamic system by loading all thermodynamic data files, setting up the OBIGT database, and preparing the system for calculations.

This is equivalent to the reset() function in the R version of CHNOSZ.

Parameters

messages : bool, default True
Whether to print informational messages

Examples

>>> import pychnosz
>>> pychnosz.reset()  # Initialize the system
reset: thermodynamic system initialized
def reset_OBIGT() ‑> None
Expand source code
def reset_OBIGT(messages: bool = True) -> None:
    """
    Reset OBIGT database to default state.

    This reloads the default thermodynamic database and removes any
    modifications made by add_OBIGT(). It is implemented by calling the
    system-wide reset() from ``..utils.reset``.

    Parameters
    ----------
    messages : bool, default True
        Whether to print informational messages. The value is passed on to
        reset() and also controls the confirmation message printed here.
    """
    from ..utils.reset import reset

    reset(messages=messages)
    if messages:
        print("OBIGT database reset to default state")

Reset OBIGT database to default state.

This function reloads the default thermodynamic database, removing any modifications made by add_OBIGT().

def reset_WORM(messages: bool = True) ‑> None
Expand source code
def reset_WORM(messages: bool = True) -> None:
    """
    Initialize the thermodynamic system with the WORM database.

    Convenience wrapper that resets the system and then loads the WORM
    database in a single call. If loading WORM reports failure and
    ``messages`` is True, a fallback notice is printed.

    Parameters
    ----------
    messages : bool, default True
        Whether to print informational messages (also passed to reset()
        and load_WORM())

    Examples
    --------
    >>> import pychnosz
    >>> # Initialize with WORM database
    >>> pychnosz.reset_WORM()

    Notes
    -----
    This performs reset(messages=messages) followed by
    load_WORM(keep_default=False, messages=messages).
    """
    from ..utils.reset import reset as reset_system

    # Start from a clean system, then swap in the WORM data
    reset_system(messages=messages)
    loaded = load_WORM(keep_default=False, messages=messages)

    if not loaded and messages:
        print("reset_WORM: falling back to default OBIGT database")

Initialize the thermodynamic system with the WORM database.

This is a convenience function that combines reset() and load_WORM(). It initializes the system and loads the WORM database in one step.

Parameters

messages : bool, default True
Whether to print informational messages

Examples

>>> import pychnosz
>>> # Initialize with WORM database
>>> pychnosz.reset_WORM()

Notes

This is equivalent to calling pychnosz.reset() followed by pychnosz.load_WORM(keep_default=False).

def retrieve(elements: str | List[str] | Tuple[str] | None = None,
ligands: str | List[str] | Tuple[str] | None = None,
state: str | List[str] | Tuple[str] | None = None,
T: float | List[float] | None = None,
P: str | float | List[float] = 'Psat',
add_charge: bool = True,
hide_groups: bool = True,
messages: bool = True) ‑> pandas.core.series.Series
Expand source code
def retrieve(elements: Optional[Union[str, List[str], Tuple[str]]] = None,
            ligands: Optional[Union[str, List[str], Tuple[str]]] = None,
            state: Optional[Union[str, List[str], Tuple[str]]] = None,
            T: Optional[Union[float, List[float]]] = None,
            P: Union[str, float, List[float]] = "Psat",
            add_charge: bool = True,
            hide_groups: bool = True,
            messages: bool = True) -> pd.Series:
    """
    Retrieve species containing specified elements.

    Parameters
    ----------
    elements : str, list of str, or tuple of str, optional
        Elements in a chemical system. If `elements` is a string, retrieve
        species containing that element.

        E.g., `retrieve("Au")` will return all species containing Au.

        If `elements` is a list, retrieve species that have all of the elements
        in the list.

        E.g., `retrieve(["Au", "Cl"])` will return all species that have both
        Au and Cl.

        If `elements` is a tuple, retrieve species relevant to the system,
        including charged species.

        E.g., `retrieve(("Au", "Cl"))` will return species that have Au
        and/or Cl, including charged species, but no other elements.

    ligands : str, list of str, or tuple of str, optional
        Elements present in any ligands. This affects the species search:

        - If `ligands` is a state name ('cr', 'liq', 'gas', 'aq'), it is
          used as the state filter instead.
        - Otherwise, the result is restricted to species that contain
          `elements` and lie in the system defined by `elements` plus
          `ligands`.

    state : str, list of str, or tuple of str, optional
        Filter the result on these state(s) ('aq', 'cr', 'gas', 'liq').

    T : float or list of float, optional
        Temperature passed to subcrt(); species whose Gibbs energy is NA
        there are dropped. If subcrt() fails, no species are dropped.

    P : str, float, or list of float, default "Psat"
        Pressure for Gibbs energy calculation. Default is "Psat" (saturation).

    add_charge : bool, default True
        For chemical systems (tuple input), automatically include charge (Z).

    hide_groups : bool, default True
        Exclude group species (names in brackets like [CH2]).

    messages : bool, default True
        Print informational messages. If False, suppress messages about
        updating the stoichiometric matrix and other information.

    Returns
    -------
    pd.Series
        Series of species indices (1-based) with chemical formulas as index.
        This behaves like R's named vector: you can access by name or position.
        Names are chemical formulas (or 'e-' for electrons).
        Values are species indices that match the criteria.

    Raises
    ------
    ValueError
        If an entry of `elements` or `ligands` (other than "all" or a state
        name given as `ligands`) is not present in any species in the
        database.

    Examples
    --------
    >>> # All species containing Au
    >>> retrieve("Au")

    >>> # All species that have both Au and Cl
    >>> retrieve(["Au", "Cl"])

    >>> # Au-Cl system: species with Au and/or Cl, including charged species
    >>> retrieve(("Au", "Cl"))

    >>> # All Au-bearing species in the Au-Cl system
    >>> retrieve("Au", ("Cl",))

    >>> # All uncharged Au-bearing species in the Au-Cl system
    >>> retrieve("Au", ("Cl",), add_charge=False)

    >>> # All Au-bearing crystalline species (state given as 'ligands')
    >>> retrieve("Au", "cr")

    >>> # Minerals in the system SiO2-MgO-CaO-CO2
    >>> retrieve(("Si", "Mg", "Ca", "C", "O"), state="cr")

    Notes
    -----
    This function uses 1-based indexing to match R CHNOSZ conventions.
    The returned indices are labels that can be used with .loc[], not positions.
    """
    # Empty argument handling
    if elements is None:
        return pd.Series([], dtype=int)

    # A state name given as 'ligands' selects a state filter, not an element
    state_names = ('cr', 'liq', 'gas', 'aq')
    ligands_is_state = isinstance(ligands, str) and ligands in state_names

    thermo_obj = thermo()

    # Initialize database if needed
    if not thermo_obj.is_initialized():
        thermo_obj.reset()

    ## Stoichiometric matrix
    stoich = _get_or_update_stoich(thermo_obj, messages=messages)

    ## Generate error for missing element(s)
    allelements = []
    if isinstance(elements, (list, tuple)):
        allelements.extend(elements)
    else:
        allelements.append(elements)
    # Skip state names here, otherwise the state branch below is unreachable
    if ligands is not None and not ligands_is_state:
        if isinstance(ligands, (list, tuple)):
            allelements.extend(ligands)
        else:
            allelements.append(ligands)

    not_present = [elem for elem in allelements if elem not in stoich.columns and elem != "all"]
    if not_present:
        if len(not_present) == 1:
            raise ValueError(f'"{not_present[0]}" is not an element that is present in any species in the database')
        else:
            raise ValueError(f'"{", ".join(not_present)}" are not elements that are present in any species in the database')

    ## Handle 'ligands' argument
    if ligands is not None:
        if ligands_is_state:
            # Use the ligands argument as the state filter
            state = ligands
            ispecies = retrieve(elements, add_charge=add_charge,
                                hide_groups=hide_groups, messages=messages)
        else:
            # Include the element in the system defined by the ligands list
            if isinstance(ligands, str):
                ligands_tuple = (ligands,)
            elif isinstance(ligands, list):
                ligands_tuple = tuple(ligands)
            else:
                ligands_tuple = ligands

            # Combine elements with ligands
            if isinstance(elements, str):
                combined = (elements,) + ligands_tuple
            elif isinstance(elements, list):
                combined = tuple(elements) + ligands_tuple
            else:
                combined = elements + ligands_tuple

            # Call retrieve() for each argument and take the intersection
            r1 = retrieve(elements, add_charge=add_charge,
                          hide_groups=hide_groups, messages=messages)
            r2 = retrieve(combined, add_charge=add_charge,
                          hide_groups=hide_groups, messages=messages)
            ispecies = np.intersect1d(r1, r2)
    else:
        ## Species identification
        ispecies_list = []

        # Determine if elements is a tuple (chemical system)
        is_system = isinstance(elements, tuple)

        # Convert single string to list for iteration
        if isinstance(elements, str):
            elements_iter = [elements]
        else:
            elements_iter = list(elements)

        # Automatically add charge to a system
        if add_charge and is_system and "Z" not in elements_iter:
            elements_iter.append("Z")

        # Proceed element-by-element
        for element in elements_iter:
            if element == "all":
                ispecies_list.append(np.array(thermo_obj.obigt.index.tolist()))
            else:
                # Identify the species that have the element
                has_element = (stoich[element] != 0)
                ispecies_list.append(np.array(stoich.index[has_element].tolist()))

        if is_system:
            # Chemical system: species made only of the listed elements
            ispecies = np.unique(np.concatenate(ispecies_list))

            # Columns not in elements
            other_columns = [col for col in stoich.columns if col not in elements_iter]

            if other_columns:
                # Drop species that contain any other element
                otherstoich = stoich.loc[ispecies, other_columns]
                iother = (otherstoich != 0).any(axis=1)
                ispecies = ispecies[~iother.values]
        else:
            # Species that have all the elements: intersection of all arrays
            ispecies = ispecies_list[0]
            for arr in ispecies_list[1:]:
                ispecies = np.intersect1d(ispecies, arr)

    # Recursive calls return a formula-indexed Series. Continue with plain
    # index values, because pd.Series(<Series>, index=...) below would
    # reindex by formula and fail on duplicated formulas.
    ispecies = np.asarray(ispecies)

    # Exclude groups (dtype=bool so that empty masks stay valid indexers)
    if hide_groups:
        obigt = thermo_obj.obigt
        names = obigt.loc[ispecies, 'name'].values
        is_group = np.array([bool(re.match(r'^\[.*\]$', str(name))) for name in names],
                            dtype=bool)
        ispecies = ispecies[~is_group]

    # Filter on state
    if state is not None:
        obigt = thermo_obj.obigt

        # Ensure state is a list
        if isinstance(state, str):
            state_list = [state]
        elif isinstance(state, tuple):
            state_list = list(state)
        else:
            state_list = state

        species_states = obigt.loc[ispecies, 'state'].values
        istate = np.array([s in state_list for s in species_states], dtype=bool)
        ispecies = ispecies[istate]

    # Require non-NA Delta G0 at specific temperature
    if T is not None and len(ispecies) > 0:
        from .subcrt import subcrt
        # Suppress warnings and messages from subcrt()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                result = subcrt(ispecies.tolist(), T=T, P=P, messages=False, show=False)
                if result is not None and 'out' in result:
                    G_values = []
                    for species_out in result['out']:
                        if isinstance(species_out, dict) and 'G' in species_out:
                            G = species_out['G']
                            if isinstance(G, (list, np.ndarray)):
                                G_values.append(G[0] if len(G) > 0 else np.nan)
                            else:
                                G_values.append(G)
                        else:
                            G_values.append(np.nan)

                    # Filter out species with NA G values
                    has_G = np.array([not pd.isna(g) for g in G_values], dtype=bool)
                    ispecies = ispecies[has_G]
            except Exception:
                # Best effort: if subcrt fails, keep all species
                pass

    # Return empty Series if nothing found
    if len(ispecies) == 0:
        return pd.Series([], dtype=int)

    obigt = thermo_obj.obigt
    formulas = obigt.loc[ispecies, 'formula'].values

    # Use e- instead of (Z-1) for electron
    formulas = np.array([f if f != '(Z-1)' else 'e-' for f in formulas])

    # R-style named vector: access by formula name or by position
    return pd.Series(ispecies, index=formulas)

Retrieve species containing specified elements.

Parameters

elements : str, list of str, or tuple of str, optional

Elements in a chemical system. If elements is a string, retrieve species containing that element.

E.g., retrieve("Au") will return all species containing Au.

If elements is a list, retrieve species that have all of the elements in the list.

E.g., retrieve(["Au", "Cl"]) will return all species that have both Au and Cl.

If elements is a tuple, retrieve species relevant to the system, including charged species.

E.g., retrieve(("Au", "Cl")) will return species that have Au and/or Cl, including charged species, but no other elements.

ligands : str, list of str, or tuple of str, optional
Elements present in any ligands. This affects the species search: if ligands is a state ('cr', 'liq', 'gas', 'aq'), it is used as the state filter; otherwise, the elements in ligands are added to the system defined by elements.
state : str, list of str, or tuple of str, optional
Filter the result on these state(s) ('aq', 'cr', 'gas', 'liq').
T : float or list of float, optional
Temperature (K) for filtering species with non-NA Gibbs energy.
P : str, float, or list of float, default "Psat"
Pressure for Gibbs energy calculation. Default is "Psat" (saturation).
add_charge : bool, default True
For chemical systems (tuple input), automatically include charge (Z).
hide_groups : bool, default True
Exclude group species (names in brackets like [CH2]).
messages : bool, default True
Print informational messages. If False, suppress messages about updating the stoichiometric matrix and other information.

Returns

pd.Series
Series of species indices (1-based) with chemical formulas as index. This behaves like R's named vector - you can access by name or position. Names are chemical formulas (or 'e-' for electrons). Values are species indices that match the criteria.

Examples

>>> # All species containing Au
>>> retrieve("Au")
>>> # All species that have both Au and Cl
>>> retrieve(["Au", "Cl"])
>>> # Au-Cl system: species with Au and/or Cl, including charged species
>>> retrieve(("Au", "Cl"))
>>> # All Au-bearing species in the Au-Cl system
>>> retrieve("Au", ("Cl",))
>>> # All uncharged Au-bearing species in the Au-Cl system
>>> retrieve("Au", ("Cl",), add_charge=False)
>>> # Minerals in the system SiO2-MgO-CaO-CO2
>>> retrieve(("Si", "Mg", "Ca", "C", "O"), state="cr")

Notes

This function uses 1-based indexing to match R CHNOSZ conventions. The returned indices are labels that can be used with .loc[], not positions.

def set_title(ax_or_fig, title: str, fontsize: float = 12, **kwargs)
Expand source code
def set_title(ax_or_fig, title: str, fontsize: float = 12, **kwargs):
    """
    Set title on a matplotlib axes or Plotly figure.

    This function provides a unified interface for setting titles on both
    matplotlib and Plotly plots, allowing seamless switching between
    interactive=True and interactive=False.

    Parameters
    ----------
    ax_or_fig : matplotlib.axes.Axes or plotly.graph_objs.Figure
        Axes or Figure object to set title on
    title : str
        The title text
    fontsize : float or None, default 12
        Font size for the title. Pass None to leave the font size to the
        plotting backend's default title size. For Plotly, any falsy value
        (None or 0) leaves the title font unset.
    **kwargs
        Additional arguments passed to matplotlib ``Axes.set_title()`` or
        Plotly ``Figure.update_layout()``

    Returns
    -------
    matplotlib.text.Text or plotly.graph_objs.Figure
        The title object (matplotlib) or the figure (Plotly)

    Examples
    --------
    >>> from pychnosz.utils.expression import set_title, syslab, syslab_html
    >>> # Matplotlib diagram
    >>> d1 = diagram(a, interactive=False, plot_it=False)
    >>> title_text = syslab(["H2O", "CO2", "CaO", "MgO", "SiO2"])
    >>> set_title(d1['ax'], title_text, fontsize=12)
    >>> # Display the figure in Jupyter:
    >>> from IPython.display import display
    >>> display(d1['fig'])

    >>> # Plotly diagram
    >>> d1 = diagram(a, interactive=True, plot_it=False)
    >>> title_text = syslab_html(["H2O", "CO2", "CaO", "MgO", "SiO2"])
    >>> set_title(d1['ax'], title_text, fontsize=12)
    >>> d1['fig'].show()

    Notes
    -----
    When using plot_it=False, you need to explicitly display the figure after
    setting the title. In Jupyter notebooks, use display(d['fig']) or d['fig'].show()
    for Plotly diagrams. Outside Jupyter, use plt.show() or save with d['fig'].savefig().
    """
    if _is_plotly_figure(ax_or_fig):
        # Plotly figure: centre the title; only set a font size when one is given.
        title_dict = {'text': title, 'x': 0.5, 'xanchor': 'center'}
        if fontsize:
            title_dict['font'] = {'size': fontsize}
        ax_or_fig.update_layout(title=title_dict, **kwargs)
        return ax_or_fig

    # Matplotlib axes: forward fontsize only when given, so that None falls back
    # to matplotlib's default title size instead of the generic font size
    # (mirrors the Plotly branch, which leaves the font unset in that case).
    if fontsize is not None:
        kwargs['fontsize'] = fontsize
    return ax_or_fig.set_title(title, **kwargs)

Set title on a matplotlib axes or Plotly figure.

This function provides a unified interface for setting titles on both matplotlib and Plotly plots, allowing seamless switching between interactive=True and interactive=False.

Parameters

ax_or_fig : matplotlib.axes.Axes or plotly.graph_objs.Figure
Axes or Figure object to set title on
title : str
The title text
fontsize : float, default 12
Font size for the title
**kwargs
Additional arguments passed to matplotlib set_title() or Plotly update_layout()

Returns

matplotlib.text.Text or plotly.graph_objs.Figure
The title object (matplotlib) or the figure (Plotly)

Examples

>>> from pychnosz.utils.expression import set_title, syslab, syslab_html
>>> # Matplotlib diagram
>>> d1 = diagram(a, interactive=False, plot_it=False)
>>> title_text = syslab(["H2O", "CO2", "CaO", "MgO", "SiO2"])
>>> set_title(d1['ax'], title_text, fontsize=12)
>>> # Display the figure in Jupyter:
>>> from IPython.display import display
>>> display(d1['fig'])
>>> # Plotly diagram
>>> d1 = diagram(a, interactive=True, plot_it=False)
>>> title_text = syslab_html(["H2O", "CO2", "CaO", "MgO", "SiO2"])
>>> set_title(d1['ax'], title_text, fontsize=12)
>>> d1['fig'].show()

Notes

When using plot_it=False, you need to explicitly display the figure after setting the title. In Jupyter notebooks, use display(d['fig']) or d['fig'].show() for Plotly diagrams. Outside Jupyter, use plt.show() or save with d['fig'].savefig().

def species(species: str | int | List[str | int] | pandas.core.series.Series | None = None,
state: str | List[str] | None = None,
delete: bool = False,
add: bool = False,
index_return: bool = False,
global_state: bool = True,
basis: pandas.core.frame.DataFrame | None = None,
messages: bool = True) ‑> pandas.core.frame.DataFrame | List[int] | None
Expand source code
def species(species: Optional[Union[str, int, List[Union[str, int]], pd.Series]] = None,
            state: Optional[Union[str, List[str]]] = None,
            delete: bool = False,
            add: bool = False,
            index_return: bool = False,
            global_state: bool = True,
            basis: Optional[pd.DataFrame] = None,
            messages: bool = True) -> Optional[Union[pd.DataFrame, List[int]]]:
    """
    Define species of interest for thermodynamic calculations.

    Parameters
    ----------
    species : str, int, list, pd.Series, or None
        Species name(s), formula(s), or index(es).
        Can also be a pandas Series (e.g., from retrieve()); its values are used.
        If None, returns current species definition.
    state : str, list of str, or None
        Physical state(s) for the species. If the state values are numeric
        (or numeric strings), they are treated as log activities instead.
    delete : bool, default False
        If True, delete species (all if species is None)
    add : bool, default False
        If True, add to existing species instead of replacing
    index_return : bool, default False
        If True, return species indices instead of DataFrame
    global_state : bool, default True
        If True, store species in global thermo().species (default behavior).
        If False, return species definition without storing globally (local state).
    basis : pd.DataFrame, optional
        Basis species definition to use (if not using global basis).
        Required when global_state=False and basis is not defined globally.
    messages : bool, default True
        If True, print informational messages

    Returns
    -------
    pd.DataFrame, list of int, or None
        Species definition DataFrame or indices, or None if deleted

    Raises
    ------
    SpeciesError
        If ``species`` is a missing value (``pd.NA`` or a floating-point NaN).

    Examples
    --------
    >>> # Define species of interest
    >>> species(["CO2", "HCO3-", "CO3-2"])

    >>> # Add more species
    >>> species(["CH4", "C2H4"], add=True)

    >>> # Delete specific species
    >>> species(["CO2"], delete=True)

    >>> # Delete all species
    >>> species(delete=True)

    >>> # Use output from retrieve()
    >>> zn_species = retrieve("Zn", ["O", "H"], state="aq")
    >>> species(zn_species)
    """
    thermo_obj = thermo()

    # Handle pandas Series (e.g., from retrieve())
    if isinstance(species, pd.Series):
        # Extract the integer indices from the Series values
        species = species.values.tolist()

    # Handle NA species. An identity test against np.nan only catches that one
    # object, so check the value instead (covers float('nan'), np.float32, ...).
    if species is pd.NA or (isinstance(species, (float, np.floating)) and np.isnan(species)):
        raise SpeciesError("'species' is NA")

    # Handle deletion
    if delete:
        return _delete_species(species, thermo_obj)

    # Return current species if no arguments
    if species is None and state is None:
        if index_return:
            if thermo_obj.species is not None:
                return list(range(1, len(thermo_obj.species) + 1))
            else:
                return []
        return thermo_obj.species

    # Use all species indices if species is None but state is given
    if species is None and thermo_obj.species is not None:
        species = list(range(1, len(thermo_obj.species) + 1))

    # Process state argument
    state = _process_state_argument(state)

    # Make species and state same length
    species, state = _match_argument_lengths(species, state)

    # Handle numeric state (treat as logact)
    logact = None
    if state is not None and len(state) > 0:
        if isinstance(state[0], (int, float)):
            logact = [float(s) for s in state]
            state = None
        elif _can_be_numeric(state[0]):
            logact = [float(s) for s in state]
            state = None

    # Handle species-state combinations for proteins
    if state is not None:
        species, state = _handle_protein_naming(species, state, thermo_obj)

    # Process species argument
    iOBIGT = None
    if isinstance(species[0], str):
        # Check if species are in current definition
        if thermo_obj.species is not None:
            existing_indices = _match_existing_species(species, thermo_obj.species)
            if all(idx is not None for idx in existing_indices) and logact is not None:
                # Update activities of existing species directly
                species_indices = [i+1 for i in existing_indices]  # Convert to 1-based
                return _update_existing_species(species_indices, None, logact, index_return, thermo_obj)

        # Look up species in database
        iOBIGT = _lookup_species_indices(species, state, messages)

    else:
        # Handle numeric species
        if thermo_obj.species is not None:
            max_current = len(thermo_obj.species)
            if all(isinstance(s, int) and s <= max_current for s in species):
                # Referring to existing species
                return _update_existing_species(species, state, logact, index_return, thermo_obj)

        # Referring to OBIGT indices
        iOBIGT = species

    # Create or modify species definition
    if iOBIGT is not None:
        return _create_species_definition(iOBIGT, state, logact, add, index_return, thermo_obj, global_state, basis)
    else:
        return _update_existing_species(species, state, logact, index_return, thermo_obj)

Define species of interest for thermodynamic calculations.

Parameters

species : str, int, list, pd.Series, or None
Species name(s), formula(s), or index(es). Can also be a pandas Series (e.g., from retrieve()). If None, returns current species definition.
state : str, list of str, or None
Physical state(s) for the species
delete : bool, default False
If True, delete species (all if species is None)
add : bool, default False
If True, add to existing species instead of replacing
index_return : bool, default False
If True, return species indices instead of DataFrame
global_state : bool, default True
If True (default), store species in the global thermo().species. If False, return the species definition without storing it globally (local state).
basis : pd.DataFrame, optional
Basis species definition to use (if not using the global basis). Required when global_state=False and no basis is defined globally.
messages : bool, default True
If True, print informational messages

Returns

pd.DataFrame, list of int, or None
Species definition DataFrame or indices, or None if deleted

Examples

>>> # Define species of interest
>>> species(["CO2", "HCO3-", "CO3-2"])
>>> # Add more species
>>> species(["CH4", "C2H4"], add=True)
>>> # Delete specific species
>>> species(["CO2"], delete=True)
>>> # Delete all species
>>> species(delete=True)
>>> # Use output from retrieve()
>>> zn_species = retrieve("Zn", ["O", "H"], state="aq")
>>> species(zn_species)
def subcrt(species: str | List[str] | int | List[int],
coeff: int | float | List[int | float] | None = 1,
state: str | List[str] | None = None,
property: List[str] = ['logK', 'G', 'H', 'S', 'V', 'Cp'],
T: float | numpy.ndarray | List[float] = array([273.16, 298.15, 323.15, 348.15, 373.15, 398.15, 423.15, 448.15, 473.15, 498.15, 523.15, 548.15, 573.15, 598.15, 623.15]),
P: float | List[float] | numpy.ndarray | str = 'Psat',
grid: str | None = None,
convert: bool = True,
exceed_Ttr: bool = True,
exceed_rhomin: bool = False,
logact: List[float] | None = None,
autobalance: bool = True,
use_polymorphs: bool = True,
IS: float | List[float] = 0,
messages: bool = True,
show: bool = True,
basis: pandas.core.frame.DataFrame | None = None) ‑> pychnosz.core.subcrt.SubcrtResult
Expand source code
def subcrt(species: Union[str, List[str], int, List[int]],
           coeff: Union[int, float, List[Union[int, float]], None] = 1,
           state: Optional[Union[str, List[str]]] = None,
           property: List[str] = ["logK", "G", "H", "S", "V", "Cp"],
           T: Union[float, List[float], np.ndarray] = np.concatenate([[273.16], 273.15 + np.arange(25, 351, 25)]),
           P: Union[float, List[float], np.ndarray, str] = "Psat",
           grid: Optional[str] = None,
           convert: bool = True,
           exceed_Ttr: bool = True,
           exceed_rhomin: bool = False,
           logact: Optional[List[float]] = None,
           autobalance: bool = True,
           use_polymorphs: bool = True,
           IS: Union[float, List[float]] = 0,
           messages: bool = True,
           show: bool = True,
           basis: Optional[pd.DataFrame] = None,
           _recursion_count: int = 0) -> SubcrtResult:
    """
    Calculate standard molal thermodynamic properties of species and reactions.
    
    This function reproduces the behavior of R CHNOSZ subcrt(), including
    argument handling, validation, calculations, and output formatting.
    
    Parameters
    ----------
    species : str, list of str, int, or list of int
        Species names, formulas, or indices in thermodynamic database
    coeff : int, float, list, tuple, numpy array, or None
        Stoichiometric coefficients for reaction calculation.
        If 1 (default), calculate individual species properties.
        If a sequence (or a number other than 1), calculate a reaction with
        the given coefficients. Tuples, numpy arrays and numpy scalars are
        treated like lists and Python numbers.
        If state names are passed here instead, they are used as ``state``
        (and a numeric ``state`` argument is then used as ``coeff``).
    state : str, list of str, or None
        Physical states ("aq", "cr", "gas", "liq") for species
    property : list of str
        Properties to calculate: "logK", "G", "H", "S", "V", "Cp", "rho", "kT", "E"
    T : float, list, or ndarray
        Temperature(s). With convert=True (default), user-supplied values are
        interpreted as degrees Celsius and converted to K. The default
        (273.16 K, then 298.15 to 623.15 K by 25 K) is already in K.
    P : float, list, ndarray, or "Psat"
        Pressure(s) in bar or "Psat" for saturation pressure
    grid : str or None
        Grid calculation mode: "T", "P", "IS", or None
    convert : bool
        If True (default), interpret user-supplied T as degrees Celsius
        (converted to K) and apply R CHNOSZ-compatible unit conversion to the
        output. Pressure is always taken in bar.
    exceed_Ttr : bool
        Allow calculations beyond transition temperatures (default: True)
    exceed_rhomin : bool
        Allow calculations below minimum water density (default: False)
    logact : list of float or None
        Base-10 logarithms of the activities of the species (one value per species)
    autobalance : bool
        Automatically balance reactions using basis species (default: True)
    use_polymorphs : bool
        Include polymorphic phases for minerals (default: True)
    IS : float or list of float
        Ionic strength for activity corrections (default: 0)
    messages : bool, default True
        Whether to print informational messages
    show : bool, default True
        Whether to display result tables in Jupyter notebooks (default: True).
        Set to False when calling subcrt() from other functions.
    basis : pd.DataFrame, optional
        Basis species definition to use for autobalancing (if not using global basis)

    Returns
    -------
    SubcrtResult
        Object containing:
        - species: DataFrame with species information
        - out: DataFrame with calculated thermodynamic properties
        - reaction: DataFrame with reaction stoichiometry (if reaction)
        - warnings: List of warning messages

    Raises
    ------
    ValueError
        For invalid property names, mismatched ``coeff``/``logact`` lengths,
        or an ambiguous combination of ``coeff`` and ``state``.
        
    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.reset()
    >>> 
    >>> # Single species properties
    >>> result = subcrt("H2O", T=25, P=1)
    >>> print(result.out[["G", "H", "S", "Cp"]])
    >>> 
    >>> # Reaction calculation
    >>> result = subcrt(["H2O", "H+", "OH-"], [-1, 1, 1], T=25, P=1)
    >>> print(f"Water dissociation ΔG° = {result.out.G[0]:.3f} J/mol")
    >>> 
    >>> # Temperature array
    >>> result = subcrt("quartz", T=[25, 100, 200], P=1)
    >>> print(result.out[["T", "G", "H", "S"]])
    
    Notes
    -----
    This implementation aims for fidelity to R CHNOSZ subcrt():
    - Identical argument processing and validation
    - Same species lookup and polymorphic handling
    - Exact HKF and CGL equation-of-state calculations
    - Same reaction balancing and autobalance logic
    - Identical output structure and formatting
    - Same warning and error messages
    """
    
    result = SubcrtResult()

    # Prevent infinite recursion in auto-balancing
    if _recursion_count > 5:
        result.warnings.append("Maximum recursion depth reached in auto-balancing")
        return result

    try:
        # === Phase 1: Argument Processing and Validation ===
        # (Exactly matching R subcrt.R lines 21-77)

        # Normalize numpy scalars/arrays and tuples to plain Python numbers and
        # lists. Otherwise np.int64 (not a subclass of int) and tuple coefficients
        # fail the isinstance checks below and the call silently falls back to
        # single-species mode, and numpy arrays make `coeff != 1` ambiguous.
        if isinstance(coeff, np.generic):
            coeff = coeff.item()
        elif isinstance(coeff, np.ndarray):
            coeff = coeff.tolist()
        elif isinstance(coeff, tuple):
            coeff = list(coeff)

        # Handle argument reordering if states are second argument.
        # The length guard avoids an IndexError on an empty coefficient list.
        coeff_is_states = isinstance(coeff, str) or (
            isinstance(coeff, list) and len(coeff) > 0 and isinstance(coeff[0], str))
        if coeff_is_states:
            # States were passed as second argument - reorder
            if state is not None:
                if isinstance(state, (int, float)) or (isinstance(state, list) and all(isinstance(x, (int, float)) for x in state)):
                    # Third argument is coefficients
                    new_coeff = state
                    new_state = coeff
                    return subcrt(species, new_coeff, new_state, property, T, P, grid,
                                convert, exceed_Ttr, exceed_rhomin, logact, autobalance, use_polymorphs, IS,
                                messages, show, basis, _recursion_count)
                else:
                    raise ValueError("If both coeff and state are given, one should be numeric coefficients")
            else:
                # Only states provided, no coefficients
                new_state = coeff
                return subcrt(species, 1, new_state, property, T, P, grid,
                            convert, exceed_Ttr, exceed_rhomin, logact, autobalance, use_polymorphs, IS,
                            messages, show, basis, _recursion_count)
        
        # Determine if this is a reaction calculation
        do_reaction = (coeff != 1 and coeff is not None and 
                      (isinstance(coeff, list) or isinstance(coeff, (int, float)) and coeff != 1))
        
        # Convert inputs to consistent formats
        species = [species] if isinstance(species, (str, int)) else list(species)
        if state is not None:
            state = [state] if isinstance(state, str) else list(state)
            # Make species and state same length
            if len(state) > len(species):
                species = species * (len(state) // len(species) + 1)
                species = species[:len(state)]
            elif len(species) > len(state):
                state = state * (len(species) // len(state) + 1)
                state = state[:len(species)]
        
        if do_reaction:
            if isinstance(coeff, (int, float)):
                coeff = [coeff]
            coeff = list(coeff)
        
        # Validate properties
        allowed_properties = ["rho", "logK", "G", "H", "S", "Cp", "V", "kT", "E"]
        if isinstance(property, str):
            property = [property]
        
        invalid_props = [p for p in property if p not in allowed_properties]
        if invalid_props:
            if len(invalid_props) == 1:
                raise ValueError(f"invalid property name: {invalid_props[0]}")
            else:
                raise ValueError(f"invalid property names: {', '.join(invalid_props)}")
        
        # Length checking
        if do_reaction and len(species) != len(coeff):
            raise ValueError("the length of 'coeff' must equal the number of species")
        
        if logact is not None and len(logact) != len(species):
            raise ValueError("the length of 'logact' must equal the number of species")
        
        # Unit conversion
        T_array = np.atleast_1d(np.asarray(T, dtype=float))
        # Convert temperature to Kelvin if convert=True (matching R CHNOSZ behavior)
        # R: if(convert) T <- envert(T, "K") - converts Celsius input to Kelvin
        # Default parameter is [273.16, 298.15, 323.15, ..., 623.15] which is already in K, so only convert user input
        default_T = np.concatenate([[273.16], 273.15 + np.arange(25, 351, 25)])
        if convert and not np.array_equal(T_array, default_T[:len(T_array)]):
            # User provided temperature, assume Celsius and convert to Kelvin
            T_array = T_array + 273.15

        # Handle T=273.15K (0°C) exactly - R CHNOSZ uses 273.16K (0.01°C) instead
        # This avoids numerical issues at the freezing point
        T_array = np.where(np.abs(T_array - 273.15) < 1e-10, 273.16, T_array)
        
        if isinstance(P, str) and P == "Psat":
            P_array = "Psat"
        else:
            P_array = np.atleast_1d(np.asarray(P, dtype=float))
            # P is assumed to be in bar (R CHNOSZ standard)
        
        # Warning for high temperatures with Psat
        # Check if P is "Psat" (compare to the original P, not P_array which may be converted)
        if isinstance(P, str) and P == "Psat" and np.any(T_array > 647.067):
            n_over = np.sum(T_array > 647.067)
            vtext = "value" if n_over == 1 else "values"
            result.warnings.append(f"P = 'Psat' undefined for T > Tcritical ({n_over} T {vtext})")
        
        # === Phase 2: Grid Processing ===
        # Handle grid calculations (T-P arrays)
        if grid is not None:
            if grid == "T":
                # Grid over temperature
                new_T = []
                for temp in T_array:
                    if isinstance(P_array, str):
                        new_T.extend([temp] * 1)
                    else:
                        new_T.extend([temp] * len(P_array))
                if isinstance(P_array, str):
                    new_P = P_array
                else:
                    new_P = list(P_array) * len(T_array)
                T_array = np.array(new_T)
                P_array = new_P
            elif grid == "P":
                # Grid over pressure
                if not isinstance(P_array, str):
                    new_P = []
                    for press in P_array:
                        new_P.extend([press] * len(T_array))
                    new_T = list(T_array) * len(P_array)
                    T_array = np.array(new_T)
                    P_array = np.array(new_P)
            elif grid == "IS":
                # Grid over ionic strength
                IS_array = np.atleast_1d(np.asarray(IS))
                original_len = max(len(T_array), len(P_array) if not isinstance(P_array, str) else 1)
                new_IS = []
                for ionic_str in IS_array:
                    new_IS.extend([ionic_str] * original_len)
                T_array = np.tile(T_array, len(IS_array))
                if isinstance(P_array, str):
                    P_array = P_array
                else:
                    P_array = np.tile(P_array, len(IS_array))
                IS = new_IS
        else:
            # Ensure T and P are same length
            if isinstance(P_array, str):
                # P = "Psat", keep T as is
                pass
            else:
                max_len = max(len(T_array), len(P_array))
                if len(T_array) < max_len:
                    T_array = np.resize(T_array, max_len)
                if len(P_array) < max_len:
                    P_array = np.resize(P_array, max_len)
        
        # === Phase 3: Species Lookup and Validation ===
        result.species, result.reaction, iphases, isaq, isH2O, iscgl, polymorph_species, ispecies = _process_species(
            species, state, coeff, do_reaction, use_polymorphs, messages=messages)
        
        # === Phase 4: Generate Output Message ===
        if (len(species) > 1 or convert) and messages:
            _print_subcrt_message(species, T_array, P_array, isaq.any() or isH2O.any(), messages)
        
        # === Phase 5: Reaction Balance Check ===
        if do_reaction and autobalance:
            # Use original ispecies and coeff for balance check (before polymorph expansion)
            # This matches R CHNOSZ behavior where balance check happens before polymorph expansion
            rebalanced_result = _check_reaction_balance(result, species, coeff, state, property,
                                                      T_array, P_array, grid, convert, logact,
                                                      exceed_Ttr, exceed_rhomin, IS, ispecies, _recursion_count, basis, T, P, messages, show)
            if rebalanced_result is not None:  # If reaction was rebalanced, return the result
                return rebalanced_result
        
        # === Phase 6: Property Calculations ===
        result.out, calc_warnings = _calculate_properties(property, iphases, isaq, isH2O, iscgl,
                                         T_array, P_array, exceed_rhomin, exceed_Ttr, IS, logact, do_reaction)
        # Add calculation warnings to result
        result.warnings.extend(calc_warnings)
        
        # === Phase 6.5: Polymorph Selection ===
        if use_polymorphs:
            # Select stable polymorphs based on minimum Gibbs energy
            # Apply to both individual species AND reactions (matching R CHNOSZ behavior)
            thermo_sys = thermo()
            if do_reaction:
                # For reactions, also update coefficients and rebuild reaction DataFrame
                result.out, updated_coeff, updated_iphases = _select_stable_polymorphs(result.out, iphases, polymorph_species, ispecies, thermo_sys, result.reaction['coeff'].tolist(), messages)
                # Rebuild reaction DataFrame with updated species list
                reaction_data = []
                for i, iph in enumerate(updated_iphases):
                    row = thermo_sys.obigt.loc[iph]
                    model = row.get('model', 'unknown')
                    if model == "H2O":
                        water_model = thermo_sys.get_option('water', 'SUPCRT92')
                        model = f"water.{water_model}"
                    reaction_data.append({
                        'coeff': updated_coeff[i],
                        'name': row['name'],
                        'formula': row['formula'],
                        'state': row['state'],
                        'ispecies': iph,
                        'model': model
                    })
                result.reaction = pd.DataFrame(reaction_data)
            else:
                # For individual species, no coefficient update needed
                result.out, _ = _select_stable_polymorphs(result.out, iphases, polymorph_species, ispecies, thermo_sys, None, messages)
            
            # For single species (non-reaction), convert back to DataFrame format
            if not do_reaction and isinstance(result.out, dict) and 'species_data' in result.out and len(result.out['species_data']) == 1:
                result.out = result.out['species_data'][0]
        
        # === Phase 7: Reaction Property Summation ===
        if do_reaction:
            result.out = _sum_reaction_properties(result.out, result.reaction['coeff'])
        
        # === Phase 8: Unit Conversion (convert=True) ===
        if convert:
            # Apply R CHNOSZ compatible conversion
            # This matches the observed behavior where convert=TRUE gives different results
            # than just multiplying by 4.184
            result.out = _apply_r_chnosz_conversion(result.out, do_reaction)
            
            # Recalculate logK after unit conversion to ensure consistency
            # (G is expected in J/mol here, matching R in J/(mol·K))
            if do_reaction and 'logK' in property and 'G' in result.out.columns:
                if not result.out['G'].isna().all():
                    R = 8.314462618  # J/(mol·K) - CODATA 2018 value
                    T_array = np.atleast_1d(T_array)
                    result.out['logK'] = -result.out['G'] / (np.log(10) * R * T_array)

        # Display tables in Jupyter notebooks if show=True
        if show:
            _display_subcrt_result(result)

        # Print warnings (matching R CHNOSZ behavior - lines 621-624)
        if result.warnings and messages:
            for warn in result.warnings:
                warnings.warn(warn)

        return result
        
    except Exception as e:
        result.warnings.append(f"subcrt error: {str(e)}")
        raise

Calculate standard molal thermodynamic properties of species and reactions.

This function reproduces the behavior of R CHNOSZ subcrt() exactly, including all argument handling, validation, calculations, and output formatting.

Parameters

species : str, list of str, int, or list of int
Species names, formulas, or indices in thermodynamic database
coeff : int, float, list, or None
Stoichiometric coefficients for reaction calculation. If 1 (default), calculate individual species properties. If a list, calculate a reaction with the given coefficients.
state : str, list of str, or None
Physical states ("aq", "cr", "gas", "liq") for species
property : list of str
Properties to calculate: "logK", "G", "H", "S", "V", "Cp", "rho", "kT", "E"
T : float, list, or ndarray
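Temperature(s). With convert=True (default), user-supplied values are interpreted as degrees Celsius and converted to K; the default (273.16 K, then 298.15 to 623.15 K by 25 K) is already in K.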
P : float, list, ndarray, or "Psat"
Pressure(s) in bar or "Psat" for saturation pressure
grid : str or None
Grid calculation mode: "T", "P", "IS", or None
convert : bool
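If True (default), interpret user-supplied T as degrees Celsius (converted to K) and apply R CHNOSZ-compatible unit conversion to the output. Pressure is always taken in bar.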
exceed_Ttr : bool
Allow calculations beyond transition temperatures (default: True)
exceed_rhomin : bool
Allow calculations below minimum water density (default: False)
logact : list of float or None
Base-10 logarithms of the activities of the species (one value per species)
autobalance : bool
Automatically balance reactions using basis species (default: True)
use_polymorphs : bool
Include polymorphic phases for minerals (default: True)
IS : float or list of float
Ionic strength for activity corrections (default: 0)
messages : bool, default True
Whether to print informational messages
show : bool, default True
Whether to display result tables in Jupyter notebooks (default: True). Set to False when calling subcrt() from other functions.
basis : pd.DataFrame, optional
Basis species definition to use for autobalancing (if not using global basis)

Returns

SubcrtResult
Object containing: species (DataFrame with species information); out (DataFrame with calculated thermodynamic properties); reaction (DataFrame with reaction stoichiometry, for reactions); and warnings (list of warning messages).

Examples

>>> import pychnosz
>>> pychnosz.reset()
>>> 
>>> # Single species properties
>>> result = subcrt("H2O", T=25, P=1)
>>> print(result.out[["G", "H", "S", "Cp"]])
>>> 
>>> # Reaction calculation
>>> result = subcrt(["H2O", "H+", "OH-"], [-1, 1, 1], T=25, P=1)
>>> print(f"Water dissociation ΔG° = {result.out.G[0]:.3f} J/mol")
>>> 
>>> # Temperature array
>>> result = subcrt("quartz", T=[25, 100, 200], P=1)
>>> print(result.out[["T", "G", "H", "S"]])

Notes

This implementation aims for fidelity to R CHNOSZ subcrt(): identical argument processing and validation; the same species lookup and polymorphic handling; exact HKF and CGL equation-of-state calculations; the same reaction balancing and autobalance logic; identical output structure and formatting; and the same warning and error messages.

def syslab(system: list = None, dash: str = '-') ‑> str
Expand source code
def syslab(system: list = None, dash: str = "-") -> str:
    """
    Build a LaTeX label naming the components of a thermodynamic system.

    Every component formula is passed through ``_add_subscripts``, the
    results are joined with ``dash``, and the joined text is wrapped in
    ``$...$`` so it renders in math mode.

    Parameters
    ----------
    system : list of str, optional
        Component formulas. When omitted (None), the system
        ["K2O", "Al2O3", "SiO2", "H2O"] is used.
    dash : str, default "-"
        Text inserted between consecutive components.

    Returns
    -------
    str
        LaTeX-formatted label for the system.

    Examples
    --------
    >>> syslab(["K2O", "Al2O3", "SiO2", "H2O"])
    '$K_{2}O-Al_{2}O_{3}-SiO_{2}-H_{2}O$'

    >>> syslab(["CaO", "MgO", "SiO2"], dash="–")
    '$CaO–MgO–SiO_{2}$'
    """
    components = ["K2O", "Al2O3", "SiO2", "H2O"] if system is None else system
    body = dash.join(_add_subscripts(formula) for formula in components)
    # Math-mode delimiters make matplotlib render the subscripts.
    return "$" + body + "$"

Create formatted text for thermodynamic system.

This generates a label showing the components of a thermodynamic system, separated by dashes (or other separator).

Parameters

system : list of str, optional
List of component formulas. Default: ["K2O", "Al2O3", "SiO2", "H2O"]
dash : str, default "-"
Separator between components

Returns

str
LaTeX-formatted string for the system label

Examples

>>> syslab(["K2O", "Al2O3", "SiO2", "H2O"])
'$K_{2}O-Al_{2}O_{3}-SiO_{2}-H_{2}O$'
>>> syslab(["CaO", "MgO", "SiO2"], dash="–")
'$CaO–MgO–SiO_{2}$'
def syslab_html(system: list = None, dash: str = '-') ‑> str
Expand source code
def syslab_html(system: list = None, dash: str = "-") -> str:
    """
    Build an HTML label naming the components of a thermodynamic system.

    This is the Plotly-oriented counterpart of ``syslab()``. Each component
    formula is rendered with ``chemlabel`` (from WORMutils), and the rendered
    pieces are joined by ``dash``. No math-mode wrapper is added.

    Parameters
    ----------
    system : list of str, optional
        Component formulas, in display order. When omitted (None), the
        ["K2O", "Al2O3", "SiO2", "H2O"] system is used.
    dash : str, default "-"
        Text inserted between consecutive components.

    Returns
    -------
    str
        The joined HTML label.

    Raises
    ------
    ImportError
        If the optional WORMutils dependency is not available.

    Examples
    --------
    >>> syslab_html(["K2O", "Al2O3", "SiO2", "H2O"])
    'K<sub>2</sub>O-Al<sub>2</sub>O<sub>3</sub>-SiO<sub>2</sub>-H<sub>2</sub>O'

    >>> syslab_html(["CaO", "MgO", "SiO2"], dash="–")
    'CaO–MgO–SiO<sub>2</sub>'

    Notes
    -----
    Prefer this over syslab() for titles of interactive (Plotly) diagrams,
    since Plotly understands the HTML subscript tags.

    Requires: WORMutils (for chemlabel)
    """
    # Fail early, before touching any formula, when chemlabel is unavailable.
    if not _HTML_DEPS_AVAILABLE:
        raise ImportError(
            "syslab_html() requires 'WORMutils' package.\n"
            "Install with: pip install WORMutils"
        )

    components = ["K2O", "Al2O3", "SiO2", "H2O"] if system is None else system

    return dash.join([chemlabel(formula) for formula in components])

Create HTML-formatted text for thermodynamic system (for Plotly).

This generates a label showing the components of a thermodynamic system, separated by dashes (or other separator), using HTML formatting compatible with Plotly instead of LaTeX.

Parameters

system : list of str, optional
List of component formulas. Default: ["K2O", "Al2O3", "SiO2", "H2O"]
dash : str, default "-"
Separator between components

Returns

str
HTML-formatted string for the system label

Examples

>>> syslab_html(["K2O", "Al2O3", "SiO2", "H2O"])
'K<sub>2</sub>O-Al<sub>2</sub>O<sub>3</sub>-SiO<sub>2</sub>-H<sub>2</sub>O'
>>> syslab_html(["CaO", "MgO", "SiO2"], dash="–")
'CaO–MgO–SiO<sub>2</sub>'

Notes

Use this function instead of syslab() when creating titles for interactive (Plotly) diagrams. The HTML formatting is compatible with Plotly's rendering.

Requires: WORMutils (for chemlabel)

def thermo(*args, messages=True, **kwargs)
Expand source code
def _resolve_thermo_slot(obj, slot, path):
    """
    Look up one ``$``-separated component of a thermo() path on ``obj``.

    The lookup order is:
    1. the lowercased attribute name, so R-style names such as ``OBIGT``
       reach a lowercase Python attribute;
    2. the attribute name as written;
    3. a dictionary key.

    Getters, the "original value" capture and setter path traversal all use
    this function, so they always agree on which object a path refers to.

    Parameters
    ----------
    obj : object
        Object (or dict) on which to look up ``slot``.
    slot : str
        One component of the path.
    path : str
        The full path as given by the caller; used only in the error message.

    Returns
    -------
    object
        The resolved attribute or dict value.

    Raises
    ------
    AttributeError
        If ``slot`` matches neither an attribute nor a dict key of ``obj``.
    """
    slot_lower = slot.lower()
    if hasattr(obj, slot_lower):
        return getattr(obj, slot_lower)
    if hasattr(obj, slot):
        return getattr(obj, slot)
    if isinstance(obj, dict) and slot in obj:
        return obj[slot]
    raise AttributeError(f"Attribute '{path}' not found in thermo object")


def thermo(*args, messages=True, **kwargs):
    """
    Access or modify the thermodynamic system data object.

    This function provides a convenient interface to get or set parts of the
    thermodynamic system, similar to R's par() function for graphics parameters.

    Parameters
    ----------
    *args : str or list of str
        Names of attributes to retrieve (e.g., "element", "opt$ideal.H").
        For nested access, use "$" notation (e.g., "opt$E.units").
        Each path component is looked up as a lowercase attribute, then as
        the attribute name as written, then as a dict key. This lets R-style
        names such as "OBIGT" work. Lists or tuples of strings are flattened.
        Special values:
        - "WORM": Load the WORM thermodynamic database (Python-exclusive
          feature). The value returned by the loader is returned in its place.
    messages : bool, default True
        Whether to print informational messages during operations.
    **kwargs : any
        Named arguments to set attributes (e.g., element=new_df,
        opt={'E.units': 'cal'}). For nested attributes, use "$" in the name
        (e.g., **{"opt$ideal.H": False}). Paths are resolved with the same
        rules as for getters, so a value set through a path is the value
        later read back through that path.

    Returns
    -------
    various
        - If no arguments: returns the ThermoSystem object
        - If single unnamed argument: returns the requested value
        - If multiple unnamed arguments: returns list of requested values
        - If only named arguments: returns a dict mapping each name to its
          original value before modification
        - If both: returns a dict of the requested values (keyed by name),
          updated with the original values of the modified attributes

    Raises
    ------
    TypeError
        If an unnamed argument is not a string.
    AttributeError
        If a requested or modified path does not exist.

    Examples
    --------
    >>> import pychnosz
    >>> # Get the entire thermo object
    >>> ts = pychnosz.thermo()

    >>> # Get a specific attribute
    >>> elem = pychnosz.thermo("element")

    >>> # Get nested attribute
    >>> e_units = pychnosz.thermo("opt$E.units")

    >>> # Get multiple attributes
    >>> elem, buf = pychnosz.thermo("element", "buffer")

    >>> # Set an attribute
    >>> old_elem = pychnosz.thermo(element=new_element_df)

    >>> # Set nested attribute
    >>> old_units = pychnosz.thermo(**{"opt$ideal.H": False})

    >>> # Load WORM database (Python-exclusive feature)
    >>> pychnosz.thermo("WORM")

    >>> # Suppress messages
    >>> pychnosz.thermo("WORM", messages=False)

    Notes
    -----
    This function mimics the behavior of R CHNOSZ thermo() function,
    providing flexible access to the thermodynamic data object.

    The "WORM" special argument is a Python-exclusive feature that loads
    the Water-Organic-Rock-Microbe thermodynamic database from the
    WORM-db GitHub repository.
    """
    # Get the global thermo system
    thermo_sys = get_thermo_system()

    # If no arguments, return the entire object
    if len(args) == 0 and len(kwargs) == 0:
        return thermo_sys

    # Flatten character vectors passed as args (like R's c("basis", "species"))
    flat_args = []
    for arg in args:
        if isinstance(arg, (list, tuple)) and all(isinstance(x, str) for x in arg):
            flat_args.extend(arg)
        else:
            flat_args.append(arg)
    args = flat_args

    # Initialize explicitly (honouring `messages`) before any property access,
    # so that auto-initialization does not run with hardcoded messages=True.
    if not thermo_sys.is_initialized() and (args or kwargs):
        thermo_sys.reset(messages=messages)

    # Process unnamed arguments (getters)
    return_values = []
    for arg in args:
        if not isinstance(arg, str):
            raise TypeError(f"Unnamed arguments must be strings, got {type(arg)}")

        # Special handling for "WORM" - load WORM database
        if arg.upper() == "WORM":
            from ..data.worm import load_WORM
            return_values.append(load_WORM(keep_default=False, messages=messages))
            continue

        value = thermo_sys
        for slot in arg.split('$'):
            value = _resolve_thermo_slot(value, slot, arg)
        return_values.append(value)

    # Process named arguments (setters)
    setter_returns = {}
    for key, new_value in kwargs.items():
        slots = key.split('$')

        # Capture the original value; this also validates the whole path
        # before anything is modified.
        orig_value = thermo_sys
        for slot in slots:
            orig_value = _resolve_thermo_slot(orig_value, slot, key)
        setter_returns[key] = orig_value

        # Special handling for OBIGT - normalize index and handle refs
        if len(slots) == 1 and slots[0].upper() == 'OBIGT':
            _set_obigt_data(thermo_sys, new_value)
            continue

        # Walk to the parent with the same resolution rules as the getter.
        # Otherwise a mixed-case component could land the write on a
        # different (or wrong) object than the one read above.
        parent = thermo_sys
        for slot in slots[:-1]:
            parent = _resolve_thermo_slot(parent, slot, key)

        final_slot = slots[-1]
        if isinstance(parent, dict):
            parent[final_slot] = new_value
        elif hasattr(parent, final_slot.lower()):
            # Prefer the lowercase (Python-convention) attribute, matching
            # the attribute the getter reads first.
            setattr(parent, final_slot.lower(), new_value)
        else:
            setattr(parent, final_slot, new_value)

    # Determine return value based on R's behavior
    if kwargs:
        # In R, setters always return a named list
        if not args:
            return setter_returns
        # Mix of getters and setters - return all values keyed by name
        combined = dict(zip(args, return_values))
        combined.update(setter_returns)
        return combined

    # Only getters: a single unnamed argument returns the value directly
    if len(return_values) == 1:
        return return_values[0]
    return return_values

Access or modify the thermodynamic system data object.

This function provides a convenient interface to get or set parts of the thermodynamic system, similar to R's par() function for graphics parameters.

Parameters

*args : str or list of str
Names of attributes to retrieve (e.g., "element", "opt$ideal.H"). For nested access, use "$" notation (e.g., "opt$E.units"). Special values:
  • "WORM": Load the WORM thermodynamic database (Python-exclusive feature)
messages : bool, default True
Whether to print informational messages during operations
**kwargs : any
Named arguments to set attributes (e.g., element=new_df, opt={'E.units': 'cal'}). For nested attributes, use "$" in the name (e.g., **{"opt$ideal.H": False}).

Returns

various
  • If no arguments: returns the ThermoSystem object
  • If single unnamed argument: returns the requested value
  • If multiple unnamed arguments: returns list of requested values
  • If named arguments: returns original values before modification

Examples

>>> import pychnosz
>>> # Get the entire thermo object
>>> ts = pychnosz.thermo()
>>> # Get a specific attribute
>>> elem = pychnosz.thermo("element")
>>> # Get nested attribute
>>> e_units = pychnosz.thermo("opt$E.units")
>>> # Get multiple attributes
>>> elem, buf = pychnosz.thermo("element", "buffer")
>>> # Set an attribute
>>> old_elem = pychnosz.thermo(element=new_element_df)
>>> # Set nested attribute
>>> old_units = pychnosz.thermo(**{"opt$ideal.H": False})
>>> # Load WORM database (Python-exclusive feature)
>>> pychnosz.thermo("WORM")
>>> # Suppress messages
>>> pychnosz.thermo("WORM", messages=False)

Notes

This function mimics the behavior of R CHNOSZ thermo() function, providing flexible access to the thermodynamic data object.

The "WORM" special argument is a Python-exclusive feature that loads the Water-Organic-Rock-Microbe thermodynamic database from the WORM-db GitHub repository.

def unicurve(logK: int | float | List[int | float],
species: str | List[str] | int | List[int],
coeff: int | float | List[int | float],
state: str | List[str],
pressures: float | List[float] = 1,
temperatures: float | List[float] = 25,
IS: float = 0,
minT: float = 0.1,
maxT: float = 100,
minP: float = 1,
maxP: float = 500,
tol: float | None = None,
solve: str = 'T',
messages: bool = True,
show: bool = True,
plot_it: bool = True,
width: int = 600,
height: int = 400,
res: int = 200) ‑> UnivariantResult | List[UnivariantResult]
Expand source code
def _unicurve_as_list(value):
    """Return ``value`` as a list: lists pass through, tuples are converted, anything else is wrapped."""
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [value]


def _unicurve_default_tol(logK_value):
    """
    Return the default convergence tolerance for one logK value.

    The tolerance is ``10 ** -(n + 2)``, capped at 1e-5, where ``n`` is the
    number of decimal places in the shortest repr of ``float(logK_value)``.
    Decimal places are counted from the Decimal exponent. This keeps values
    whose repr uses scientific notation (e.g. 1.5e-05 or 1.5e+16) from having
    their exponent text miscounted as decimal digits. Non-finite values count
    as having no decimal places.
    """
    import math
    from decimal import Decimal

    value = float(logK_value)
    if math.isfinite(value):
        # normalize() strips trailing zeros, e.g. Decimal("2.0") -> Decimal("2")
        exponent = Decimal(repr(value)).normalize().as_tuple().exponent
        n_decimals = max(0, -exponent)
    else:
        n_decimals = 0
    return min(10 ** (-(n_decimals + 2)), 1e-5)


def unicurve(logK: Union[float, int, List[Union[float, int]]],
             species: Union[str, List[str], int, List[int]],
             coeff: Union[int, float, List[Union[int, float]]],
             state: Union[str, List[str]],
             pressures: Union[float, List[float]] = 1,
             temperatures: Union[float, List[float]] = 25,
             IS: float = 0,
             minT: float = 0.1,
             maxT: float = 100,
             minP: float = 1,
             maxP: float = 500,
             tol: Optional[float] = None,
             solve: str = "T",
             messages: bool = True,
             show: bool = True,
             plot_it: bool = True,
             width: int = 600,
             height: int = 400,
             res: int = 200) -> Union[UnivariantResult, List[UnivariantResult]]:
    """
    Solve for temperatures or pressures of equilibration for a given logK value(s).

    This function calculates univariant curves useful for aqueous geothermometry
    and geobarometry. Given a measured equilibrium constant (logK) for a reaction,
    it solves for the temperatures (at specified pressures) or pressures (at
    specified temperatures) where the reaction would produce that logK value.

    The solver uses scipy.optimize.brentq (Brent's method), which combines
    bisection, secant, and inverse quadratic interpolation for efficient and
    robust convergence. This is ~100x faster than the original binary search
    algorithm while maintaining identical numerical accuracy.

    Parameters
    ----------
    logK : float, int, or list/tuple of float or int
        Logarithm (base 10) of the equilibrium constant(s). When a list or
        tuple is provided, each logK value is processed separately and a list
        of results is returned.
    species : str, int, or list/tuple of str or int
        Name, formula, or database index of species involved in the reaction.
    coeff : int, float, or list/tuple
        Reaction stoichiometric coefficients (negative for reactants, positive
        for products).
    state : str or list/tuple of str
        Physical state(s) of species: "aq", "cr", "gas", "liq".
    pressures : float or list/tuple of float, default 1
        Pressure(s) in bars (used when solving for temperature).
    temperatures : float or list/tuple of float, default 25
        Temperature(s) in °C (used when solving for pressure).
    IS : float, default 0
        Ionic strength for activity corrections (mol/kg).
    minT : float, default 0.1
        Minimum temperature (°C) to search (ignored when solving for pressure).
    maxT : float, default 100
        Maximum temperature (°C) to search (ignored when solving for pressure).
    minP : float, default 1
        Minimum pressure (bar) to search (ignored when solving for temperature).
    maxP : float, default 500
        Maximum pressure (bar) to search (ignored when solving for temperature).
    tol : float, optional
        Tolerance for convergence. Default: 1/(10^(n+2)), where n is the
        number of decimal places in logK, capped at a maximum of 1e-5.
    solve : str, default "T"
        What to solve for: "T" for temperature or "P" for pressure
        (case-insensitive).
    messages : bool, default True
        Print informational messages.
    show : bool, default True
        Display result table.
    plot_it : bool, default True
        Display interactive plotly plot showing logK vs T (or P) with target logK
        as horizontal line and intersection points marked.
    width : int, default 600
        Plot width in pixels (used if plot_it=True).
    height : int, default 400
        Plot height in pixels (used if plot_it=True).
    res : int, default 200
        Number of points to calculate for plotting the logK curve
        (used if plot_it=True).

    Returns
    -------
    UnivariantResult or list of UnivariantResult
        When logK is a single value: returns a UnivariantResult object.
        When logK is a list or tuple: returns a list of UnivariantResult objects.
        Each result contains:
        - reaction: DataFrame with reaction stoichiometry
        - out: DataFrame with solved T or P values and thermodynamic properties
        - warnings: List of warning messages

    Raises
    ------
    ValueError
        If ``solve`` is not "T" or "P". This is checked before any
        thermodynamic calculation is performed.

    Examples
    --------
    >>> from pychnosz import unicurve, reset
    >>> reset()
    >>>
    >>> # Solve for temperature: quartz dissolution
    >>> # SiO2(quartz) = SiO2(aq)
    >>> result = unicurve(logK=-2.71, species=["quartz", "SiO2"],
    ...                   state=["cr", "aq"], coeff=[-1, 1],
    ...                   pressures=200, minT=1, maxT=350)
    >>> print(result.out[["P", "T", "logK"]])
    >>>
    >>> # Solve for pressure: water dissociation
    >>> result = unicurve(logK=-14, species=["H2O", "H+", "OH-"],
    ...                   state=["liq", "aq", "aq"], coeff=[-1, 1, 1],
    ...                   temperatures=[25, 50, 75], solve="P",
    ...                   minP=1, maxP=1000)
    >>> print(result.out[["T", "P", "logK"]])

    Notes
    -----
    This function uses scipy.optimize.brentq for root-finding, which provides:
    - Guaranteed convergence if the root is bracketed
    - Typical convergence in 5-15 function evaluations
    - ~100x speedup compared to custom binary search (1600 → 15 evaluations)
    - Identical numerical results to the original implementation

    The algorithm also implements "warm start" optimization: when solving for
    multiple pressures/temperatures, previous solutions are used to narrow
    the bracket for subsequent searches. If the narrowed search fails, the
    full range is searched.

    References
    ----------
    Based on univariant.r from pyCHNOSZ by Grayson Boyer.
    Optimized using Brent, R. P. (1973). Algorithms for Minimization without Derivatives.
    """
    # Validate the solve mode before doing any (expensive) calculation.
    solve_mode = solve.upper()
    if solve_mode not in ("T", "P"):
        raise ValueError(f"solve must be 'T' or 'P', got '{solve}'")

    # Track whether input was a single value or a sequence
    single_logK_input = not isinstance(logK, (list, tuple))

    # Normalize sequence-like inputs once, outside the per-logK loop
    logK_list = _unicurve_as_list(logK)
    species = _unicurve_as_list(species)
    state = _unicurve_as_list(state)
    coeff = _unicurve_as_list(coeff)
    pressures = _unicurve_as_list(pressures)
    temperatures = _unicurve_as_list(temperatures)

    results = []

    for this_logK in logK_list:
        result = UnivariantResult()

        # Default tolerance depends on the precision of this logK value
        this_tol = _unicurve_default_tol(this_logK) if tol is None else tol

        # Reaction information is best-effort: failures are reported (when
        # messages=True) but do not stop the solve.
        try:
            initial_calc = subcrt(species, coeff=coeff, state=state, T=25, P=1,
                                  exceed_Ttr=True, messages=False, show=False)
            result.reaction = initial_calc.reaction
        except Exception as e:
            if messages:
                warnings.warn(f"Error getting reaction information: {str(e)}")
            result.reaction = None

        if solve_mode == "T":
            # Solve for temperature at each given pressure
            results_list = []
            prev_T = None  # previous solution, used for warm starts

            for pressure in pressures:
                if messages:
                    print(f"Solving for T at P = {pressure} bar (logK = {this_logK})...")

                # Warm start: search a ±50 °C window around the previous
                # solution; the full range is retried below if this fails.
                current_minT, current_maxT = minT, maxT
                if prev_T is not None and minT < prev_T < maxT:
                    margin = 50
                    current_minT = max(minT, prev_T - margin)
                    current_maxT = min(maxT, prev_T + margin)
                    if messages:
                        print(f"  Using warm start: searching {current_minT:.1f} to {current_maxT:.1f}°C")

                result_dict = _solve_T_for_pressure(this_logK, species, state, coeff, pressure,
                                                    IS, current_minT, current_maxT, this_tol,
                                                    initial_guess=prev_T, messages=messages)

                # If warm start failed, try full range
                if result_dict['T'] is None and prev_T is not None:
                    if messages:
                        print("  Warm start failed, searching full range...")
                    result_dict = _solve_T_for_pressure(this_logK, species, state, coeff, pressure,
                                                        IS, minT, maxT, this_tol, messages=messages)

                results_list.append(result_dict)

                if result_dict['T'] is not None:
                    prev_T = result_dict['T']

            result.out = pd.DataFrame(results_list)

        else:
            # Solve for pressure at each given temperature
            results_list = []
            prev_P = None  # previous solution, used for warm starts

            for temperature in temperatures:
                if messages:
                    print(f"Solving for P at T = {temperature} °C (logK = {this_logK})...")

                # Warm start: search a ±500 bar window around the previous
                # solution; the full range is retried below if this fails.
                current_minP, current_maxP = minP, maxP
                if prev_P is not None and minP < prev_P < maxP:
                    margin = 500
                    current_minP = max(minP, prev_P - margin)
                    current_maxP = min(maxP, prev_P + margin)
                    if messages:
                        print(f"  Using warm start: searching {current_minP:.0f} to {current_maxP:.0f} bar")

                result_dict = _solve_P_for_temperature(this_logK, species, state, coeff, temperature,
                                                       IS, current_minP, current_maxP, this_tol,
                                                       initial_guess=prev_P, messages=messages)

                # If warm start failed, try full range
                if result_dict['P'] is None and prev_P is not None:
                    if messages:
                        print("  Warm start failed, searching full range...")
                    result_dict = _solve_P_for_temperature(this_logK, species, state, coeff, temperature,
                                                           IS, minP, maxP, this_tol, messages=messages)

                results_list.append(result_dict)

                if result_dict['P'] is not None:
                    prev_P = result_dict['P']

            result.out = pd.DataFrame(results_list)

        # Create interactive plot if requested
        if plot_it:
            if not PLOTLY_AVAILABLE:
                warnings.warn("plotly is not installed. Set plot_it=False to suppress this warning, "
                              "or install plotly with: pip install plotly")
            else:
                result.fig = _create_unicurve_plot(this_logK, species, state, coeff, result, solve,
                                                   minT, maxT, minP, maxP, IS, width, height, res, messages)

        # Display result if requested
        if show and result.out is not None:
            try:
                from IPython.display import display
                if result.reaction is not None:
                    print("\nReaction:")
                    display(result.reaction)
                print(f"\nResults (logK = {this_logK}):")
                display(result.out)
            except ImportError:
                # Not in Jupyter, just print
                if result.reaction is not None:
                    print("\nReaction:")
                    print(result.reaction)
                print(f"\nResults (logK = {this_logK}):")
                print(result.out)

        results.append(result)

    # Return single result or list based on input
    if single_logK_input:
        return results[0]
    return results

Solve for temperatures or pressures of equilibration for a given logK value(s).

This function calculates univariant curves useful for aqueous geothermometry and geobarometry. Given a measured equilibrium constant (logK) for a reaction, it solves for the temperatures (at specified pressures) or pressures (at specified temperatures) where the reaction would produce that logK value.

The solver uses scipy.optimize.brentq (Brent's method), which combines bisection, secant, and inverse quadratic interpolation for efficient and robust convergence. This is ~100x faster than the original binary search algorithm while maintaining identical numerical accuracy.

Parameters

logK : float, int, or list of float or int
Logarithm (base 10) of the equilibrium constant(s). When a list is provided, each logK value is processed separately and a list of results is returned.
species : str, int, or list of str or int
Name, formula, or database index of species involved in the reaction
coeff : int, float, or list
Reaction stoichiometric coefficients (negative for reactants, positive for products)
state : str or list of str
Physical state(s) of species: "aq", "cr", "gas", "liq"
pressures : float or list of float, default 1
Pressure(s) in bars (used when solving for temperature)
temperatures : float or list of float, default 25
Temperature(s) in °C (used when solving for pressure)
IS : float, default 0
Ionic strength for activity corrections (mol/kg)
minT : float, default 0.1
Minimum temperature (°C) to search (ignored when solving for pressure)
maxT : float, default 100
Maximum temperature (°C) to search (ignored when solving for pressure)
minP : float, default 1
Minimum pressure (bar) to search (ignored when solving for temperature)
maxP : float, default 500
Maximum pressure (bar) to search (ignored when solving for temperature)
tol : float, optional
Tolerance for convergence. Default: 1/(10^(n+2)), where n is the number of decimal places in logK, capped at a maximum of 1e-5.
solve : str, default "T"
What to solve for: "T" for temperature or "P" for pressure
messages : bool, default True
Print informational messages
show : bool, default True
Display result table
plot_it : bool, default True
Display interactive plotly plot showing logK vs T (or P) with target logK as horizontal line and intersection points marked
width : int, default 600
Plot width in pixels (used if plot_it=True)
height : int, default 400
Plot height in pixels (used if plot_it=True)
res : int, default 200
Number of points to calculate for plotting the logK curve (used if plot_it=True)

Returns

UnivariantResult or list of UnivariantResult
When logK is a single value: returns a UnivariantResult object. When logK is a list: returns a list of UnivariantResult objects. Each result contains:
  • reaction: DataFrame with reaction stoichiometry
  • out: DataFrame with solved T or P values and thermodynamic properties
  • warnings: List of warning messages

Examples

>>> from pychnosz import unicurve, reset
>>> reset()
>>>
>>> # Solve for temperature: quartz dissolution
>>> # SiO2(quartz) = SiO2(aq)
>>> result = unicurve(logK=-2.71, species=["quartz", "SiO2"],
...                   state=["cr", "aq"], coeff=[-1, 1],
...                   pressures=200, minT=1, maxT=350)
>>> print(result.out[["P", "T", "logK"]])
>>>
>>> # Solve for pressure: water dissociation
>>> result = unicurve(logK=-14, species=["H2O", "H+", "OH-"],
...                   state=["liq", "aq", "aq"], coeff=[-1, 1, 1],
...                   temperatures=[25, 50, 75], solve="P",
...                   minP=1, maxP=1000)
>>> print(result.out[["T", "P", "logK"]])

Notes

This function uses scipy.optimize.brentq for root-finding, which provides:
  • Guaranteed convergence if the root is bracketed
  • Typical convergence in 5-15 function evaluations
  • ~100x speedup compared to custom binary search (1600 → 15 evaluations)
  • Identical numerical results to the original implementation

The algorithm also implements "warm start" optimization: when solving for multiple pressures/temperatures, previous solutions are used to intelligently bracket subsequent searches, further improving performance.

References

Based on univariant.r from pyCHNOSZ by Grayson Boyer. Optimized using Brent, R. P. (1973). Algorithms for Minimization without Derivatives.

def univariant_TP(logK: int | float | List[int | float],
species: str | List[str] | int | List[int],
coeff: int | float | List[int | float],
state: str | List[str],
Trange: List[float],
Prange: List[float],
IS: float = 0,
xlim: List[float] | None = None,
ylim: List[float] | None = None,
line_type: str = 'markers+lines',
tol: float | None = None,
title: str | None = None,
res: int = 10,
width: int = 500,
height: int = 400,
save_as: str | None = None,
save_format: str = 'png',
save_scale: float = 1,
show: bool = False,
messages: bool = False,
parallel: bool = True,
plot_it: bool = True) ‑> List[UnivariantResult]
Expand source code
def univariant_TP(logK: Union[float, int, List[Union[float, int]]],
                  species: Union[str, List[str], int, List[int]],
                  coeff: Union[int, float, List[Union[int, float]]],
                  state: Union[str, List[str]],
                  Trange: List[float],
                  Prange: List[float],
                  IS: float = 0,
                  xlim: Optional[List[float]] = None,
                  ylim: Optional[List[float]] = None,
                  line_type: str = "markers+lines",
                  tol: Optional[float] = None,
                  title: Optional[str] = None,
                  res: int = 10,
                  width: int = 500,
                  height: int = 400,
                  save_as: Optional[str] = None,
                  save_format: str = "png",
                  save_scale: float = 1,
                  show: bool = False,
                  messages: bool = False,
                  parallel: bool = True,
                  plot_it: bool = True) -> List[UnivariantResult]:
    """
    Solve for temperatures and pressures of equilibration for given logK value(s)
    and produce an interactive T-P diagram.

    This function calculates univariant curves in temperature-pressure (T-P) space
    for one or more logK values. For each pressure in a range, it solves for the
    temperature where the reaction achieves the target logK. The resulting curves
    show phase boundaries or equilibrium conditions in T-P space.

    Parameters
    ----------
    logK : float, int, or list/tuple/array of them
        Logarithm (base 10) of equilibrium constant(s). Multiple values produce
        multiple curves on the same plot.
    species : str, int, or list of str or int
        Name, formula, or database index of species involved in the reaction
    coeff : int, float, or list
        Reaction stoichiometric coefficients (negative for reactants, positive for products)
    state : str or list of str
        Physical state(s) of species: "aq", "cr", "gas", "liq"
    Trange : list of two floats
        [min, max] temperature range (°C) to search for solutions
    Prange : list of two floats
        [min, max] pressure range (bar) to calculate along
    IS : float, default 0
        Ionic strength for activity corrections (mol/kg)
    xlim : list of two floats, optional
        [min, max] range for x-axis (temperature) in plot
    ylim : list of two floats, optional
        [min, max] range for y-axis (pressure) in plot
    line_type : str, default "markers+lines"
        Plotly line type: "markers+lines", "markers", or "lines"
    tol : float, optional
        Convergence tolerance. Default (sequential mode): 10**-(n+2), where n is
        the number of decimal places in logK, capped at a maximum of 1e-5.
        In parallel mode ``tol`` is forwarded unchanged to the worker.
    title : str, optional
        Plot title. Default: auto-generated from the reaction of the first result
    res : int, default 10
        Number of pressure points to calculate along the curve
    width : int, default 500
        Plot width in pixels
    height : int, default 400
        Plot height in pixels
    save_as : str, optional
        Filename to save plot (without extension)
    save_format : str, default "png"
        Save format: "png", "jpg", "jpeg", "webp", "svg", "pdf", "html"
    save_scale : float, default 1
        Scale factor for saved plot
    show : bool, default False
        Display subcrt result tables
    messages : bool, default False
        Print informational messages
    parallel : bool, default True
        Use parallel processing across multiple logK values for faster computation.
        Utilizes multiple CPU cores when processing multiple logK curves.
    plot_it : bool, default True
        Display the plot

    Returns
    -------
    list of UnivariantResult
        List of UnivariantResult objects in the same order as ``logK``.
        Each contains reaction information and T-P curve data. In parallel
        mode, a logK value whose worker raised an exception is omitted from
        the list and a warning is issued instead.

    Examples
    --------
    >>> from pychnosz import univariant_TP, reset
    >>> reset()
    >>>
    >>> # Calcite-aragonite phase boundary
    >>> result = univariant_TP(
    ...     logK=0,
    ...     species=["calcite", "aragonite"],
    ...     state=["cr", "cr"],
    ...     coeff=[-1, 1],
    ...     Trange=[0, 700],
    ...     Prange=[2000, 16000]
    ... )
    >>>
    >>> # Multiple curves for K-feldspar stability
    >>> result = univariant_TP(
    ...     logK=[-8, -6, -4, -2],
    ...     species=["K-feldspar", "kaolinite", "H2O", "SiO2", "muscovite"],
    ...     state=["cr", "cr", "liq", "aq", "cr"],
    ...     coeff=[-1, -1, 1, 2, 1],
    ...     Trange=[0, 350],
    ...     Prange=[1, 5000],
    ...     res=20
    ... )

    Notes
    -----
    This function creates T-P diagrams by:
    1. Generating a range of pressures from Prange[0] to Prange[1]
    2. For each pressure, solving for T where logK matches the target
    3. Plotting the resulting T-P points as a curve

    For multiple logK values, each curve represents a different equilibrium
    condition. This is useful for:
    - Phase diagrams (e.g., mineral stability fields)
    - Isopleths (lines of constant logK)
    - Reaction boundaries

    Requires plotly for interactive plotting. If plotly is not installed,
    set plot_it=False to just return the data without plotting.

    References
    ----------
    Based on univariant_TP from pyCHNOSZ by Grayson Boyer
    """

    def _default_tol(value):
        # min(10**-(n+2), 1e-5), where n = number of decimal places in value
        text = str(float(value))
        n_decimals = len(text.split('.')[1].rstrip('0')) if '.' in text else 0
        return min(10 ** (-(n_decimals + 2)), 1e-5)

    def _add_curve(out, this_logK):
        # Plot one solved curve, or report that no T solution was found at all.
        T_values = out.out['T']
        if T_values.isnull().all():
            if messages:
                print(f"Could not find any T or P values in this range that correspond to a logK value of {this_logK}")
            return
        if plot_it:
            fig.add_trace(go.Scatter(
                x=T_values,
                y=out.out['P'],
                mode=line_type,
                name=f"logK={this_logK}",
                text=[f"logK={this_logK}"] * len(T_values),
                hovertemplate='%{text}<br>T, °C=%{x:.2f}<br>P, bar=%{y:.2f}<extra></extra>',
            ))

    def _coeff_text(value):
        # Omit a unit coefficient; keep fractional coefficients (e.g. 0.5) intact
        # instead of truncating them with int().
        magnitude = abs(float(value))
        if magnitude == 1:
            return ""
        return str(int(magnitude)) if magnitude.is_integer() else f"{magnitude:g}"

    def _reaction_title(react_grid):
        # Build "reactants = products" from the reaction table of a result.
        reactants = []
        products = []
        for _, row in react_grid.iterrows():
            coeff_val = row['coeff']
            name = 'H2O' if row['name'] == 'water' else row['name']
            term = f"{_coeff_text(coeff_val)} {name}".strip()
            if coeff_val < 0:
                reactants.append(term)
            elif coeff_val > 0:
                products.append(term)
        return " + ".join(reactants) + " = " + " + ".join(products)

    # Check if plotly is available
    if plot_it and not PLOTLY_AVAILABLE:
        warnings.warn("plotly is not installed. Set plot_it=False to suppress this warning, "
                     "or install plotly with: pip install plotly")
        plot_it = False

    # Normalize logK to a list (tuples and arrays are accepted as sequences)
    if isinstance(logK, (list, tuple, np.ndarray)):
        logK = list(logK)
    else:
        logK = [logK]

    fig = go.Figure() if plot_it else None
    output = []

    # Pressures along which T is solved
    pressures = np.linspace(Prange[0], Prange[1], res)

    if parallel and len(logK) > 1:
        max_workers = min(len(logK), multiprocessing.cpu_count())
        # Results are stored by position so duplicate logK values cannot collide
        results = [None] * len(logK)

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(
                    _process_single_logK,
                    (this_logK, species, state, coeff, pressures, Trange, IS, tol, show, messages),
                ): index
                for index, this_logK in enumerate(logK)
            }
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    results[index] = future.result()
                except Exception as e:
                    # Do not fail the whole diagram, but never drop a curve silently
                    warnings.warn(f"Error processing logK={logK[index]}: {e}")

        # Emit results in input order
        for this_logK, out in zip(logK, results):
            if out is None:
                continue
            output.append(out)
            _add_curve(out, this_logK)

    else:
        for this_logK in logK:
            this_tol = _default_tol(this_logK) if tol is None else tol

            # Solve for T at each pressure
            out = unicurve(
                solve="T",
                logK=this_logK,
                species=species,
                state=state,
                coeff=coeff,
                pressures=list(pressures),
                minT=Trange[0],
                maxT=Trange[1],
                IS=IS,
                tol=this_tol,
                show=show,
                messages=messages,
                plot_it=False  # univariant_TP makes its own combined plot
            )
            _add_curve(out, this_logK)
            output.append(out)

    if plot_it:
        # Generate plot title if not specified
        if title is None and output and output[0].reaction is not None:
            title = _reaction_title(output[0].reaction)

        fig.update_layout(
            template="simple_white",
            title=str(title) if title else "",
            xaxis_title="T, °C",
            yaxis_title="P, bar",
            width=width,
            height=height,
            hoverlabel=dict(bgcolor="white"),
        )

        if xlim is not None:
            fig.update_xaxes(range=xlim)
        if ylim is not None:
            fig.update_yaxes(range=ylim)

        # Configure plot controls
        config = {
            'displaylogo': False,
            'modeBarButtonsToRemove': ['resetScale2d', 'toggleSpikelines'],
            'toImageButtonOptions': {
                'format': save_format,
                'filename': save_as if save_as else 'univariant_TP',
                'height': height,
                'width': width,
                'scale': save_scale,
            },
        }

        # Save plot if requested
        if save_as is not None:
            full_filename = f"{save_as}.{save_format}"
            if save_format == 'html':
                fig.write_html(full_filename)
            else:
                fig.write_image(full_filename, format=save_format,
                              width=width, height=height, scale=save_scale)
            if messages:
                print(f"Plot saved to {full_filename}")

        fig.show(config=config)

        # Store figure in all result objects
        for out in output:
            out.fig = fig

    return output

Solve for temperatures and pressures of equilibration for given logK value(s) and produce an interactive T-P diagram.

This function calculates univariant curves in temperature-pressure (T-P) space for one or more logK values. For each pressure in a range, it solves for the temperature where the reaction achieves the target logK. The resulting curves show phase boundaries or equilibrium conditions in T-P space.

Parameters

logK : float, int, or list
Logarithm (base 10) of equilibrium constant(s). Multiple values produce multiple curves on the same plot.
species : str, int, or list of str or int
Name, formula, or database index of species involved in the reaction
coeff : int, float, or list
Reaction stoichiometric coefficients (negative for reactants, positive for products)
state : str or list of str
Physical state(s) of species: "aq", "cr", "gas", "liq"
Trange : list of two floats
[min, max] temperature range (°C) to search for solutions
Prange : list of two floats
[min, max] pressure range (bar) to calculate along
IS : float, default 0
Ionic strength for activity corrections (mol/kg)
xlim : list of two floats, optional
[min, max] range for x-axis (temperature) in plot
ylim : list of two floats, optional
[min, max] range for y-axis (pressure) in plot
line_type : str, default "markers+lines"
Plotly line type: "markers+lines", "markers", or "lines"
tol : float, optional
Convergence tolerance. Default: 10^-(n+2), where n is the number of decimal places in logK, capped at a maximum of 1e-5.
title : str, optional
Plot title. Default: auto-generated from reaction
res : int, default 10
Number of pressure points to calculate along the curve
width : int, default 500
Plot width in pixels
height : int, default 400
Plot height in pixels
save_as : str, optional
Filename to save plot (without extension)
save_format : str, default "png"
Save format: "png", "jpg", "jpeg", "webp", "svg", "pdf", "html"
save_scale : float, default 1
Scale factor for saved plot
show : bool, default False
Display subcrt result tables
messages : bool, default False
Print informational messages
parallel : bool, default True
Use parallel processing across multiple logK values for faster computation. Utilizes multiple CPU cores when processing multiple logK curves.
plot_it : bool, default True
Display the plot

Returns

list of UnivariantResult
List of UnivariantResult objects, one for each logK value. Each contains reaction information and T-P curve data.

Examples

>>> from pychnosz import univariant_TP, reset
>>> reset()
>>>
>>> # Calcite-aragonite phase boundary
>>> result = univariant_TP(
...     logK=0,
...     species=["calcite", "aragonite"],
...     state=["cr", "cr"],
...     coeff=[-1, 1],
...     Trange=[0, 700],
...     Prange=[2000, 16000]
... )
>>>
>>> # Multiple curves for K-feldspar stability
>>> result = univariant_TP(
...     logK=[-8, -6, -4, -2],
...     species=["K-feldspar", "kaolinite", "H2O", "SiO2", "muscovite"],
...     state=["cr", "cr", "liq", "aq", "cr"],
...     coeff=[-1, -1, 1, 2, 1],
...     Trange=[0, 350],
...     Prange=[1, 5000],
...     res=20
... )

Notes

This function creates T-P diagrams by:

1. Generating a range of pressures from Prange[0] to Prange[1]
2. For each pressure, solving for T where logK matches the target
3. Plotting the resulting T-P points as a curve

For multiple logK values, each curve represents a different equilibrium condition. This is useful for:

- Phase diagrams (e.g., mineral stability fields)
- Isopleths (lines of constant logK)
- Reaction boundaries

Requires plotly for interactive plotting. If plotly is not installed, set plot_it=False to just return the data without plotting.

References

Based on univariant_TP from pyCHNOSZ by Grayson Boyer

def water(property: str | List[str] | None = None,
T: float | numpy.ndarray | List[float] = 298.15,
P: float | List[float] | numpy.ndarray | str = 1.0,
Psat_floor: float | None = 1.0,
model: str | None = None,
messages: bool = True) ‑> str | float | numpy.ndarray | Dict[str, Any]
Expand source code
def water(property: Optional[Union[str, List[str]]] = None,
          T: Union[float, np.ndarray, List[float]] = 298.15,
          P: Union[float, np.ndarray, List[float], str] = 1.0,
          Psat_floor: Union[float, None] = 1.0,
          model: Optional[str] = None,
          messages: bool = True) -> Union[str, float, np.ndarray, Dict[str, Any], None]:
    """
    Calculate thermodynamic and electrostatic properties of liquid H2O.

    This is the main water function that provides the same interface as the
    R CHNOSZ water() function, with support for multiple water models.

    Parameters
    ----------
    property : str, list of str, or None
        Properties to calculate. If None, returns current water model.
        If a water model name (SUPCRT92/SUPCRT, IAPWS95/IAPWS, DEW; case-insensitive),
        sets the water model and returns None.
        Available properties depend on the water model used.
    T : float or array-like
        Temperature in Kelvin
    P : float, array-like, or "Psat"
        Pressure in bar, or "Psat" for saturation pressure. Numeric T and P of
        different lengths are recycled (np.resize) to the longer length.
    Psat_floor : float or None
        Minimum pressure floor for Psat calculations; passed to the SUPCRT92
        and IAPWS95 calculations (not used by DEW)
    model : str, optional
        Override the default water model for this calculation. Unknown names
        trigger a warning and fall back to SUPCRT92.
    messages : bool, default True
        Whether to print informational messages

    Returns
    -------
    str, float, array, dict, or None
        Current water model name (property=None), None (when setting the model),
        or the calculated property value(s).

    Raises
    ------
    WaterModelError
        If the underlying water model calculation fails (the original
        exception is chained as ``__cause__``).

    Examples
    --------
    >>> import numpy as np
    >>> import pychnosz
    >>> pychnosz.reset()
    >>>
    >>> # Get current water model
    >>> model = pychnosz.water()
    >>> print(model)  # 'SUPCRT92'
    >>>
    >>> # Set water model (prints a message, returns None)
    >>> pychnosz.water('IAPWS95')
    >>>
    >>> # Calculate single property
    >>> density = pychnosz.water('rho', T=298.15, P=1.0)
    >>>
    >>> # Calculate multiple properties
    >>> props = pychnosz.water(['rho', 'epsilon'], T=298.15, P=1.0)
    >>>
    >>> # Temperature array
    >>> temps = np.array([273.15, 298.15, 373.15])
    >>> densities = pychnosz.water('rho', T=temps, P=1.0)
    >>>
    >>> # Saturation pressure
    >>> psat = pychnosz.water('Psat', T=373.15)
    """

    # Accepted (upper-cased) spellings of each water model -> canonical name
    model_aliases = {
        'SUPCRT92': 'SUPCRT92',
        'SUPCRT': 'SUPCRT92',
        'IAPWS95': 'IAPWS95',
        'IAPWS': 'IAPWS95',
        'DEW': 'DEW',
    }

    thermo_system = thermo()

    # Ensure thermo is initialized before accessing/setting options
    # This prevents reset() from clearing options later
    if not thermo_system.is_initialized():
        thermo_system.reset(messages=False)

    # Case 1: Query current water model
    if property is None:
        return thermo_system.get_option('water', 'SUPCRT92')

    # Case 2: Set water model (deliberately returns None, not the old model)
    if isinstance(property, str) and property.upper() in model_aliases:
        new_model = model_aliases[property.upper()]
        thermo_system.set_option('water', new_model)
        if messages:
            print(f"water: setting water model to {new_model}")
        return None

    # Case 3: Calculate properties with the override model or the configured one
    if model is not None:
        requested = model.upper()
    else:
        requested = thermo_system.get_option('water', 'SUPCRT92').upper()

    water_model = model_aliases.get(requested)
    if water_model is None:
        warnings.warn(f"Unknown water model '{requested}', using SUPCRT92")
        water_model = 'SUPCRT92'

    # Convert inputs
    T = np.atleast_1d(np.asarray(T, dtype=float))

    if isinstance(P, str):
        P_input = P
    else:
        P_input = np.atleast_1d(np.asarray(P, dtype=float))
        # Recycle the shorter of T/P to the length of the longer one
        if len(P_input) < len(T):
            P_input = np.resize(P_input, len(T))
        elif len(T) < len(P_input):
            T = np.resize(T, len(P_input))

    try:
        if water_model == 'SUPCRT92':
            result = _call_supcrt92(property, T, P_input, Psat_floor)
        elif water_model == 'IAPWS95':
            result = _call_iapws95(property, T, P_input, Psat_floor)
        elif water_model == 'DEW':
            result = _call_dew(property, T, P_input)
        else:
            raise ValueError(f"Unsupported water model: {water_model}")
    except Exception as e:
        # Chain the original error so its traceback is preserved
        raise WaterModelError(f"Error calculating water properties with {water_model} model: {e}") from e

    # Apply Psat rounding to match R CHNOSZ behavior
    # Round Psat values to 4 decimal places (round up to ensure liquid phase)
    return _apply_psat_rounding(result, property)

Calculate thermodynamic and electrostatic properties of liquid H2O.

This is the main water function that provides the same interface as the R CHNOSZ water() function, with support for multiple water models.

Parameters

property : str, list of str, or None
Properties to calculate. If None, returns current water model. If water model name (SUPCRT92, IAPWS95, DEW), sets the water model. Available properties depend on the water model used.
T : float or array-like
Temperature in Kelvin
P : float, array-like, or "Psat"
Pressure in bar, or "Psat" for saturation pressure
Psat_floor : float or None
Minimum pressure floor for Psat calculations; passed to the SUPCRT92 and IAPWS95 calculations (not used by DEW)
model : str, optional
Override the default water model for this calculation
messages : bool, default True
Whether to print informational messages

Returns

str, float, array, or dict
Current water model name, single property value, array of values, or dictionary with calculated properties

Examples

>>> import numpy as np
>>> import pychnosz
>>> pychnosz.reset()
>>> 
>>> # Get current water model
>>> model = pychnosz.water()
>>> print(model)  # 'SUPCRT92'
>>> 
>>> # Set water model
>>> pychnosz.water('IAPWS95')  # sets the model; returns None
>>> 
>>> # Calculate single property
>>> density = pychnosz.water('rho', T=298.15, P=1.0)
>>> 
>>> # Calculate multiple properties
>>> props = pychnosz.water(['rho', 'epsilon'], T=298.15, P=1.0)
>>> 
>>> # Temperature array
>>> temps = np.array([273.15, 298.15, 373.15])
>>> densities = pychnosz.water('rho', T=temps, P=1.0)
>>> 
>>> # Saturation pressure
>>> psat = pychnosz.water('Psat', T=373.15)
def water_lines(eout: Dict[str, Any],
which: str | List[str] = ['oxidation', 'reduction'],
lty: str | int = 2,
lwd: float = 1,
col: str | None = None,
plot_it: bool = True,
messages: bool = True) ‑> Dict[str, Any]
Expand source code
def water_lines(eout: Dict[str, Any],
                which: Union[str, List[str]] = ['oxidation', 'reduction'],
                lty: Union[int, str] = 2,
                lwd: float = 1,
                col: Optional[str] = None,
                plot_it: bool = True,
                messages: bool = True) -> Dict[str, Any]:
    """
    Draw water stability limits for Eh-pH, logfO2-pH, logfO2-T or Eh-T diagrams.

    This function adds lines showing the oxidation and reduction stability limits
    of water to diagrams. Above the oxidation line, water breaks down to O2.
    Below the reduction line, water breaks down to H2.

    Parameters
    ----------
    eout : dict
        Output from affinity(), equilibrate(), or diagram()
    which : str or list of str, default ['oxidation', 'reduction']
        Which line(s) to draw: 'oxidation', 'reduction', or both
    lty : int or str, default 2
        Line style (matplotlib linestyle or numeric code)
    lwd : float, default 1
        Line width
    col : str, optional
        Line color (matplotlib color spec). If None, uses current foreground color
    plot_it : bool, default True
        Whether to plot the lines and display the figure. When True, the lines
        are added to the diagram and the figure is displayed (useful when the
        original diagram was created with plot_it=False). When False, only
        calculates and returns the water line coordinates without plotting.

    Returns
    -------
    dict
        Dictionary containing all keys from the input diagram (including 'fig', 'ax',
        'plotvar', 'plotvals', 'names', 'predominant', etc. if present) plus the
        following water line specific keys:
        - xpoints: x-axis values
        - y_oxidation: y values for oxidation line (or None)
        - y_reduction: y values for reduction line (or None)
        - swapped: whether axes were swapped

    Examples
    --------
    >>> # Add water lines to an existing displayed diagram
    >>> basis(["Fe+2", "SO4-2", "H2O", "H+", "e-"], [0, math.log10(3), math.log10(0.75), 999, 999])
    >>> species(["rhomboclase", "ferricopiapite", "hydronium jarosite", "goethite", "melanterite", "pyrite"])
    >>> a = affinity(pH=[-1, 4, 256], pe=[-5, 23, 256])
    >>> d = diagram(a, main="Fe-S-O-H, after Majzlan et al., 2006")
    >>> water_lines(d, lwd=2)

    >>> # Add water lines and display when diagram was created with plot_it=False
    >>> d = diagram(a, main="Fe-S-O-H", plot_it=False)
    >>> water_lines(d, lwd=2)  # This will display the figure with water lines

    Notes
    -----
    This function only works on diagrams with a redox variable (Eh, pe, O2, or H2)
    on one axis and pH, T, P, or another non-redox variable on the other axis.
    For 1-D diagrams, vertical lines are drawn.
    """

    # Import here to avoid circular imports
    from ..utils.units import convert, envert
    from ..core.subcrt import subcrt

    # Create a deep copy of the input to preserve all diagram information
    # This allows us to return all the original keys plus water line data
    result = copy_plot(eout)

    # Detect if this is a Plotly figure (interactive diagram)
    is_plotly = False
    if 'fig' in result and result['fig'] is not None:
        is_plotly = hasattr(result['fig'], 'add_trace') and hasattr(result['fig'], 'update_layout')

    # Ensure which is a list
    if isinstance(which, str):
        which = [which]

    # Get number of variables used in affinity()
    nvar1 = len(result['vars'])

    # Determine actual number of variables from array dimensions
    # Check both loga.equil (equilibrate output) and values (affinity output)
    if 'loga_equil' in result or 'loga.equil' in result:
        loga_key = 'loga_equil' if 'loga_equil' in result else 'loga.equil'
        first_val = result[loga_key][0] if isinstance(result[loga_key], list) else list(result[loga_key].values())[0]
    else:
        first_val = list(result['values'].values())[0] if isinstance(result['values'], dict) else result['values'][0]

    if hasattr(first_val, 'shape'):
        dim = first_val.shape
    elif hasattr(first_val, '__len__'):
        dim = (len(first_val),)
    else:
        dim = ()

    nvar2 = len(dim)

    # We only work on diagrams with 1 or 2 variables
    if nvar1 not in [1, 2] or nvar2 not in [1, 2]:
        result.update({'xpoints': None, 'y_oxidation': None, 'y_reduction': None, 'swapped': False})
        return result

    # Get variables from result
    vars_list = result['vars'].copy()

    # If needed, swap axes so redox variable is on y-axis
    # Also do this for 1-D diagrams
    if len(vars_list) == 1:
        vars_list.append('nothing')

    swapped = False
    if vars_list[1] in ['T', 'P', 'nothing']:
        vars_list = list(reversed(vars_list))
        vals_dict = {vars_list[0]: result['vals'][vars_list[0]]} if vars_list[0] != 'nothing' else {}
        if len(result['vars']) > 1:
            vals_dict[vars_list[1]] = result['vals'][vars_list[1]]
        swapped = True
    else:
        vals_dict = result['vals']

    xaxis = vars_list[0]
    yaxis = vars_list[1]
    xpoints = np.asarray(vals_dict[xaxis]) if xaxis in vals_dict else np.array([0])

    # Make xaxis "nothing" if it is not pH, T, or P
    # (so that horizontal water lines can be drawn for any non-redox variable on the x-axis)
    if xaxis not in ['pH', 'T', 'P']:
        xaxis = 'nothing'

    # T and P are constants unless they are plotted on one of the axes
    T = result['T']
    if vars_list[0] == 'T':
        T = envert(xpoints, 'K')
    P = result['P']
    if vars_list[0] == 'P':
        P = envert(xpoints, 'bar')

    # Handle the case where P is "Psat" - keep it as is for subcrt
    # (subcrt knows how to handle "Psat")

    # logaH2O is 0 unless given in result['basis']
    basis_df = result['basis']
    if 'H2O' in basis_df.index:
        logaH2O = float(basis_df.loc['H2O', 'logact'])
    else:
        logaH2O = 0

    # pH is 7 unless given in eout['basis'] or plotted on one of the axes
    if vars_list[0] == 'pH':
        pH = xpoints
    elif 'H+' in basis_df.index:
        minuspH = basis_df.loc['H+', 'logact']
        # Special treatment for non-numeric value (happens when a buffer is used)
        try:
            pH = -float(minuspH)
        except (ValueError, TypeError):
            pH = np.nan
    else:
        pH = 7

    # O2 state is gas unless given in eout['basis']
    O2state = 'gas'
    if 'O2' in basis_df.index:
        O2state = basis_df.loc['O2', 'state']

    # H2 state is gas unless given in eout['basis']
    H2state = 'gas'
    if 'H2' in basis_df.index:
        H2state = basis_df.loc['H2', 'state']

    # Where the calculated values will go
    y_oxidation = None
    y_reduction = None

    if xaxis in ['pH', 'T', 'P', 'nothing'] and yaxis in ['Eh', 'pe', 'O2', 'H2']:
        # Eh/pe/logfO2/logaO2/logfH2/logaH2 vs pH/T/P

        # Reduction line (H2O + e- = 1/2 H2 + OH-)
        if 'reduction' in which:
            logfH2 = logaH2O  # usually 0

            if yaxis == 'H2':
                # Calculate equilibrium constant for gas-aqueous conversion if needed
                logK = subcrt(['H2', 'H2'], [-1, 1], ['gas', H2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfH2 if H2state == "gas", or logaH2 if H2state == "aq"
                logfH2 = logfH2 + logK
                # Broadcast to match xpoints length
                if isinstance(logfH2, (int, float)):
                    y_reduction = np.full_like(xpoints, logfH2)
                else:
                    logfH2_val = float(logfH2.iloc[0]) if hasattr(logfH2, 'iloc') else float(logfH2[0])
                    y_reduction = np.full_like(xpoints, logfH2_val)
            else:
                # Calculate logfO2 from H2O = 1/2 O2 + H2
                logK = subcrt(['H2O', 'O2', 'H2'], [-1, 0.5, 1], ['liq', O2state, 'gas'], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfO2 if O2state == "gas", or logaO2 if O2state == "aq"
                logfO2 = 2 * (logK - logfH2 + logaH2O)

                if yaxis == 'O2':
                    # Broadcast to match xpoints length
                    if isinstance(logfO2, (int, float)):
                        y_reduction = np.full_like(xpoints, logfO2)
                    else:
                        logfO2_val = float(logfO2.iloc[0]) if hasattr(logfO2, 'iloc') else float(logfO2[0])
                        y_reduction = np.full_like(xpoints, logfO2_val)
                elif yaxis == 'Eh':
                    y_reduction = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                elif yaxis == 'pe':
                    Eh_val = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                    y_reduction = convert(Eh_val, 'pe', T=T, messages=messages)

        # Oxidation line (H2O = 1/2 O2 + 2H+ + 2e-)
        if 'oxidation' in which:
            logfO2 = logaH2O  # usually 0

            if yaxis == 'H2':
                # Calculate logfH2 from H2O = 1/2 O2 + H2
                logK = subcrt(['H2O', 'O2', 'H2'], [-1, 0.5, 1], ['liq', 'gas', H2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfH2 if H2state == "gas", or logaH2 if H2state == "aq"
                logfH2 = logK - 0.5*logfO2 + logaH2O
                # Broadcast to match xpoints length
                if isinstance(logfH2, (int, float)):
                    y_oxidation = np.full_like(xpoints, logfH2)
                else:
                    logfH2_val = float(logfH2.iloc[0]) if hasattr(logfH2, 'iloc') else float(logfH2[0])
                    y_oxidation = np.full_like(xpoints, logfH2_val)
            else:
                # Calculate equilibrium constant for gas-aqueous conversion if needed
                logK = subcrt(['O2', 'O2'], [-1, 1], ['gas', O2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfO2 if O2state == "gas", or logaO2 if O2state == "aq"
                logfO2 = logfO2 + logK

                if yaxis == 'O2':
                    # Broadcast to match xpoints length
                    if isinstance(logfO2, (int, float)):
                        y_oxidation = np.full_like(xpoints, logfO2)
                    else:
                        logfO2_val = float(logfO2.iloc[0]) if hasattr(logfO2, 'iloc') else float(logfO2[0])
                        y_oxidation = np.full_like(xpoints, logfO2_val)
                elif yaxis == 'Eh':
                    y_oxidation = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                elif yaxis == 'pe':
                    Eh_val = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                    y_oxidation = convert(Eh_val, 'pe', T=T, messages=messages)

    else:
        # Invalid axis combination
        result.update({'xpoints': xpoints, 'y_oxidation': None, 'y_reduction': None, 'swapped': swapped})
        return result

    # Route to Plotly or matplotlib implementation
    if is_plotly:
        return _water_lines_plotly(result, xpoints, y_oxidation, y_reduction, swapped,
                                  lty, lwd, col, plot_it)

    # Matplotlib implementation
    # Only draw water lines if eout already has an axes (meaning it's from a diagram)
    # If no axes, this is being called just for calculation (e.g., from within diagram())
    if 'ax' not in eout or eout['ax'] is None:
        # No axes to plot on - just return the calculated values
        result.update({'xpoints': xpoints, 'y_oxidation': y_oxidation, 'y_reduction': y_reduction, 'swapped': swapped})
        return result

    # Use the axes from result
    ax = result['ax']

    # First, shade the water-unstable regions with gray
    # This creates the same effect as R's fill.NA for H2O.predominant
    if y_oxidation is not None and y_reduction is not None:
        from matplotlib.colors import ListedColormap

        # Get current axis limits to create shading
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        # Create a high-resolution mesh for smooth shading
        n_points = 500
        if swapped:
            # When swapped, xpoints is on the y-axis
            y_mesh = np.linspace(ylim[0], ylim[1], n_points)
            x_mesh = np.linspace(xlim[0], xlim[1], n_points)
            X, Y = np.meshgrid(x_mesh, y_mesh)

            # For each y-value, determine if it's in water-unstable region
            # Interpolate oxidation and reduction values to the mesh
            y_ox_interp = np.interp(y_mesh, xpoints, y_oxidation)
            y_red_interp = np.interp(y_mesh, xpoints, y_reduction)

            # Create mask: unstable where x < min or x > max
            unstable = np.zeros_like(X, dtype=bool)
            for i in range(n_points):
                ymin = min(y_ox_interp[i], y_red_interp[i])
                ymax = max(y_ox_interp[i], y_red_interp[i])
                unstable[i, :] = (X[i, :] < ymin) | (X[i, :] > ymax)
        else:
            # Normal: xpoints on x-axis, y values on y-axis
            x_mesh = np.linspace(xlim[0], xlim[1], n_points)
            y_mesh = np.linspace(ylim[0], ylim[1], n_points)
            X, Y = np.meshgrid(x_mesh, y_mesh)

            # Interpolate oxidation and reduction values to the mesh
            y_ox_interp = np.interp(x_mesh, xpoints, y_oxidation)
            y_red_interp = np.interp(x_mesh, xpoints, y_reduction)

            # Create mask: unstable where y < min or y > max
            unstable = np.zeros_like(Y, dtype=bool)
            for i in range(n_points):
                ymin = min(y_ox_interp[i], y_red_interp[i])
                ymax = max(y_ox_interp[i], y_red_interp[i])
                unstable[:, i] = (Y[:, i] < ymin) | (Y[:, i] > ymax)

        # Create masked array for unstable regions
        import numpy.ma as ma
        unstable_mask = ma.masked_where(~unstable, np.ones_like(X))

        # Draw the shading with gray (matching R's gray80 = 0.8)
        fill_na_cmap = ListedColormap(['0.8'])
        extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
        ax.imshow(unstable_mask, aspect='auto', origin='lower',
                 extent=extent, interpolation='nearest',
                 cmap=fill_na_cmap, vmin=0, vmax=1, zorder=1)

    # Set line color
    if col is None:
        col = 'black'

    # Convert numeric line style to matplotlib style
    lty_map = {1: '-', 2: '--', 3: '-.', 4: ':', 5: '-', 6: '--'}
    if isinstance(lty, int):
        lty = lty_map.get(lty, '--')

    if swapped:
        if nvar1 == 1 or nvar2 == 2:
            # Add vertical lines on 1-D diagram
            if y_oxidation is not None and len(y_oxidation) > 0:
                ax.axvline(x=y_oxidation[0], linestyle=lty, linewidth=lwd, color=col)
            if y_reduction is not None and len(y_reduction) > 0:
                ax.axvline(x=y_reduction[0], linestyle=lty, linewidth=lwd, color=col)
        else:
            # xpoints above is really the ypoints
            if y_oxidation is not None:
                ax.plot(y_oxidation, xpoints, linestyle=lty, linewidth=lwd, color=col)
            if y_reduction is not None:
                ax.plot(y_reduction, xpoints, linestyle=lty, linewidth=lwd, color=col)
    else:
        if y_oxidation is not None:
            ax.plot(xpoints, y_oxidation, linestyle=lty, linewidth=lwd, color=col)
        if y_reduction is not None:
            ax.plot(xpoints, y_reduction, linestyle=lty, linewidth=lwd, color=col)

    # Update the figure and axes references in result to reflect the water lines
    fig = ax.get_figure()
    result['fig'] = fig
    result['ax'] = ax

    # Display the figure if plot_it=True
    # This allows water_lines() to display a figure that was created with plot_it=False
    if plot_it and fig is not None:
        try:
            from IPython.display import display
            display(fig)
        except (ImportError, NameError):
            # Not in IPython/Jupyter, matplotlib will handle display
            pass

    # Update result with water line data and return
    result.update({'xpoints': xpoints, 'y_oxidation': y_oxidation, 'y_reduction': y_reduction, 'swapped': swapped})
    return result

Draw water stability limits on Eh-pH, pe-pH, logfO2-pH, logfH2-pH, logfO2-T, Eh-T and similar diagrams.

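This function adds lines showing the oxidation and reduction stability limits of water to diagrams. On Eh, pe, and O2 axes, water breaks down to O2 above the oxidation line and to H2 below the reduction line. On an H2 axis the order is reversed: the oxidation line lies at low logfH2 and the reduction line at high logfH2.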

Parameters

eout : dict
Output from affinity(), equilibrate(), or diagram()
which : str or list of str, default ['oxidation', 'reduction']
Which line(s) to draw: 'oxidation', 'reduction', or both
lty : int or str, default 2
Line style (matplotlib linestyle or numeric code)
lwd : float, default 1
Line width
col : str, optional
Line color (matplotlib color spec). If None, the lines are drawn in black on matplotlib figures.
plot_it : bool, default True
Whether to plot the lines and display the figure. When True, the lines are added to the diagram and the figure is displayed (useful when the original diagram was created with plot_it=False). When False, only calculates and returns the water line coordinates without plotting.

Returns

dict
Dictionary containing all keys from the input diagram (including 'fig', 'ax', 'plotvar', 'plotvals', 'names', 'predominant', etc. if present) plus the following water line specific keys:

- xpoints: values of the non-redox variable (plotted on the x-axis, or on the y-axis when swapped is True)
- y_oxidation: redox-variable values along the oxidation line (or None)
- y_reduction: redox-variable values along the reduction line (or None)
- swapped: whether the axes were swapped (i.e. the redox variable is on the x-axis)

Examples

>>> # Add water lines to an existing displayed diagram
>>> basis(["Fe+2", "SO4-2", "H2O", "H+", "e-"], [0, math.log10(3), math.log10(0.75), 999, 999])
>>> species(["rhomboclase", "ferricopiapite", "hydronium jarosite", "goethite", "melanterite", "pyrite"])
>>> a = affinity(pH=[-1, 4, 256], pe=[-5, 23, 256])
>>> d = diagram(a, main="Fe-S-O-H, after Majzlan et al., 2006")
>>> water_lines(d, lwd=2)
>>> # Add water lines and display when diagram was created with plot_it=False
>>> d = diagram(a, main="Fe-S-O-H", plot_it=False)
>>> water_lines(d, lwd=2)  # This will display the figure with water lines

Notes

This function only works on diagrams with a redox variable (Eh, pe, O2, or H2) on one axis and pH, T, P, or another non-redox variable on the other axis. For 1-D diagrams, vertical lines are drawn.

Classes

class ThermoSystem
Expand source code
class ThermoSystem:
    """
    Global thermodynamic system manager for CHNOSZ.

    This class manages the thermodynamic database, basis species,
    formed species, and calculation options - essentially serving
    as the global state container for all CHNOSZ calculations.
    """

    def __init__(self):
        """Create an empty, uninitialized system; call reset() to load the data."""
        self._data_loader = DataLoader()
        self._obigt_db = None
        self._initialized = False

        # Core data containers (similar to R thermo object)
        self.opt: Dict[str, Any] = {}
        self.element: Optional[pd.DataFrame] = None
        self.obigt: Optional[pd.DataFrame] = None
        self.refs: Optional[pd.DataFrame] = None
        self.Berman: Optional[pd.DataFrame] = None
        self.buffer: Optional[pd.DataFrame] = None
        self.protein: Optional[pd.DataFrame] = None
        self.groups: Optional[pd.DataFrame] = None
        self.stoich: Optional[np.ndarray] = None
        self.stoich_formulas: Optional[np.ndarray] = None
        self.bdot_acirc: Optional[Dict[str, float]] = None
        self.formula_ox: Optional[pd.DataFrame] = None

        # System state
        self.basis: Optional[pd.DataFrame] = None
        self.species: Optional[pd.DataFrame] = None

        # Options and parameters
        self.opar: Dict[str, Any] = {}

    @staticmethod
    def _default_options() -> Dict[str, Any]:
        """Return a fresh dict of the built-in default thermodynamic options.

        Used when thermo/opt.csv is missing or cannot be read. A new dict is
        built on every call so that set_option() on one system can never leak
        into a later reset or into another instance.
        """
        return {
            'E.units': 'J',
            'T.units': 'C',
            'P.units': 'bar',
            'state': 'aq',
            'water': 'SUPCRT92',
            'G.tol': 100,
            'Cp.tol': 1,
            'V.tol': 1,
            'varP': False,
            'IAPWS.sat': 'liquid',
            'paramin': 1000,
            'ideal.H': True,
            'ideal.e': True,
            'nonideal': 'Bdot',
            'Setchenow': 'bgamma0',
            'Berman': np.nan,
            'maxcores': 2,
            'ionize.aa': True,
        }

    def reset(self, messages: bool = True) -> None:
        """
        Initialize/reset the thermodynamic system.

        This is equivalent to reset() in the R version, loading all
        the thermodynamic data and initializing the system.

        Parameters
        ----------
        messages : bool, default True
            Whether to print informational messages

        Raises
        ------
        RuntimeError
            If any loading step raises; the original exception is chained
            as ``__cause__``.
        """
        try:
            # Load core data files (each loader handles its own recoverable errors)
            self._load_options(messages)
            self._load_element_data(messages)
            self._load_berman_data(messages)
            self._load_buffer_data(messages)
            self._load_protein_data(messages)
            self._load_stoich_data(messages)
            self._load_bdot_data(messages)
            self._load_refs_data(messages)

            # Initialize OBIGT database
            self._obigt_db = OBIGTDatabase()
            self.obigt = self._obigt_db.get_combined_data()

            # Reset system state
            self.basis = None
            self.species = None
            self.opar = {}

            self._initialized = True
            if messages:
                print('reset: thermodynamic system initialized')

        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(f"Failed to initialize thermodynamic system: {e}") from e

    def _load_options(self, messages: bool = True) -> None:
        """Load thermodynamic options from thermo/opt.csv.

        The CSV's column headers are the option names and its first data row
        holds the values. The built-in defaults are used when the file is
        absent (silently) or cannot be read (with a warning if messages=True).
        """
        try:
            opt_file = self._data_loader.get_data_path() / "thermo" / "opt.csv"
            if opt_file.exists():
                df = pd.read_csv(opt_file)
                # Convert to dictionary format (first row contains values)
                self.opt = dict(zip(df.columns, df.iloc[0]))
            else:
                # Default options if file not found
                self.opt = self._default_options()
        except Exception as e:
            if messages:
                print(f"Warning: Could not load options: {e}")
            # Fallback to hardcoded defaults with critical unit options
            self.opt = self._default_options()

    def _load_element_data(self, messages: bool = True) -> None:
        """Load element properties data; sets self.element to None on failure."""
        try:
            self.element = self._data_loader.load_elements()
        except Exception as e:
            if messages:
                print(f"Warning: Could not load element data: {e}")
            self.element = None

    def _load_berman_data(self, messages: bool = True) -> None:
        """Load Berman mineral parameters from the CSV files in the Berman directory.

        The second underscore-separated field of each filename is parsed as a
        year and files are read youngest first, then concatenated into one
        DataFrame whose parameter columns are coerced to numeric. Sets
        self.Berman to None when nothing can be loaded.
        """
        try:
            # Get path to Berman directory
            berman_path = self._data_loader.data_path / "Berman"

            if not berman_path.exists():
                if messages:
                    print(f"Warning: Berman directory not found: {berman_path}")
                self.Berman = None
                return

            # Find all CSV files in the directory
            csv_files = list(berman_path.glob("*.csv"))

            if not csv_files:
                if messages:
                    print(f"Warning: No CSV files found in {berman_path}")
                self.Berman = None
                return

            # Extract year from filename and sort in reverse chronological order (youngest first)
            # Following R logic: files <- rev(files[order(sapply(strsplit(files, "_"), "[", 2))])
            def extract_year(filepath):
                filename = filepath.name
                parts = filename.split('_')
                if len(parts) >= 2:
                    year_part = parts[1].replace('.csv', '')
                    try:
                        return int(year_part)
                    except ValueError:
                        return 0
                return 0

            # Sort files by year (youngest first)
            sorted_files = sorted(csv_files, key=extract_year, reverse=True)

            # Read parameters from each file
            berman_dfs = []
            for file_path in sorted_files:
                try:
                    berman_dfs.append(pd.read_csv(file_path))
                except Exception as e:
                    # Skip unreadable files but keep loading the rest; like every
                    # other warning in this class, report only when messages=True.
                    if messages:
                        print(f"Warning: Could not read Berman file {file_path}: {e}")

            # Combine all data frames (equivalent to do.call(rbind, Berman))
            if berman_dfs:
                self.Berman = pd.concat(berman_dfs, ignore_index=True)
                # Ensure all numeric columns are properly typed
                numeric_cols = ['GfPrTr', 'HfPrTr', 'SPrTr', 'VPrTr', 'k0', 'k1', 'k2', 'k3', 'k4', 'k5', 'k6',
                               'v1', 'v2', 'v3', 'v4', 'Tlambda', 'Tref', 'dTdP', 'l1', 'l2', 'DtH', 'Tmax', 'Tmin',
                               'd0', 'd1', 'd2', 'd3', 'd4', 'Vad']
                for col in numeric_cols:
                    if col in self.Berman.columns:
                        self.Berman[col] = pd.to_numeric(self.Berman[col], errors='coerce')
            else:
                self.Berman = None

        except Exception as e:
            if messages:
                print(f"Warning: Could not load Berman data: {e}")
            self.Berman = None

    def _load_buffer_data(self, messages: bool = True) -> None:
        """Load buffer definitions; sets self.buffer to None on failure."""
        try:
            self.buffer = self._data_loader.load_buffers()
        except Exception as e:
            if messages:
                print(f"Warning: Could not load buffer data: {e}")
            self.buffer = None

    def _load_protein_data(self, messages: bool = True) -> None:
        """Load protein composition data; sets self.protein to None on failure."""
        try:
            self.protein = self._data_loader.load_proteins()
        except Exception as e:
            if messages:
                print(f"Warning: Could not load protein data: {e}")
            self.protein = None

    def _load_stoich_data(self, messages: bool = True) -> None:
        """Load stoichiometric matrix data.

        The first column of the loaded table becomes self.stoich_formulas and
        the remaining columns become the self.stoich array. Both are None if
        the loader returns None or fails.
        """
        try:
            stoich_df = self._data_loader.load_stoich()
            if stoich_df is not None:
                # Extract formulas and convert to matrix
                self.stoich_formulas = stoich_df.iloc[:, 0].values
                self.stoich = stoich_df.iloc[:, 1:].values
            else:
                self.stoich_formulas = None
                self.stoich = None
        except Exception as e:
            if messages:
                print(f"Warning: Could not load stoichiometric data: {e}")
            self.stoich_formulas = None
            self.stoich = None

    def _load_bdot_data(self, messages: bool = True) -> None:
        """Load B-dot activity coefficient parameters from thermo/Bdot_acirc.csv.

        Maps the first CSV column to the second; an empty dict is used when
        the file is missing, has fewer than two columns, or cannot be read.
        """
        try:
            bdot_file = self._data_loader.get_data_path() / "thermo" / "Bdot_acirc.csv"
            if bdot_file.exists():
                df = pd.read_csv(bdot_file)
                if len(df.columns) >= 2:
                    self.bdot_acirc = dict(zip(df.iloc[:, 0], df.iloc[:, 1]))
                else:
                    self.bdot_acirc = {}
            else:
                self.bdot_acirc = {}
        except Exception as e:
            if messages:
                print(f"Warning: Could not load B-dot data: {e}")
            self.bdot_acirc = {}

    def _load_refs_data(self, messages: bool = True) -> None:
        """Load references data; sets self.refs to None on failure."""
        try:
            self.refs = self._data_loader.load_refs()
        except Exception as e:
            if messages:
                print(f"Warning: Could not load refs data: {e}")
            self.refs = None

    def is_initialized(self) -> bool:
        """Check if the thermodynamic system is initialized."""
        return self._initialized

    def get_obigt_db(self) -> OBIGTDatabase:
        """Get the OBIGT database instance, running reset() first if needed."""
        if not self._initialized:
            self.reset()
        return self._obigt_db

    def get_option(self, key: str, default: Any = None) -> Any:
        """Get a thermodynamic option value, or *default* if it is not set."""
        return self.opt.get(key, default)

    def set_option(self, key: str, value: Any) -> None:
        """Set a thermodynamic option value."""
        self.opt[key] = value

    def info(self) -> Dict[str, Any]:
        """Get information about the current thermodynamic system.

        Returns the status, the length of each loaded table (0 when a table is
        None), and a shallow copy of the options dict.
        """
        if not self._initialized:
            return {"status": "Not initialized"}

        info = {
            "status": "Initialized",
            "obigt_species": len(self.obigt) if self.obigt is not None else 0,
            "elements": len(self.element) if self.element is not None else 0,
            "berman_minerals": len(self.Berman) if self.Berman is not None else 0,
            "buffers": len(self.buffer) if self.buffer is not None else 0,
            "proteins": len(self.protein) if self.protein is not None else 0,
            "stoich_species": len(self.stoich_formulas) if self.stoich_formulas is not None else 0,
            "basis_species": len(self.basis) if self.basis is not None else 0,
            "formed_species": len(self.species) if self.species is not None else 0,
            "current_options": dict(self.opt)
        }
        return info

    def __repr__(self) -> str:
        """String representation of the thermodynamic system."""
        if not self._initialized:
            return "ThermoSystem(uninitialized)"

        info = self.info()
        return (f"ThermoSystem("
                f"obigt={info['obigt_species']} species, "
                f"basis={info['basis_species']}, "
                f"formed={info['formed_species']})")

    # R-style uppercase property aliases for compatibility
    @property
    def OBIGT(self):
        """Alias for obigt (R compatibility)."""
        # Auto-initialize if needed AND obigt is None (matches R behavior)
        if self.obigt is None and not self._initialized:
            self.reset(messages=True)
        return self.obigt

    @OBIGT.setter
    def OBIGT(self, value):
        """Setter for OBIGT (R compatibility); delegates to _set_obigt_data."""
        _set_obigt_data(self, value)

Global thermodynamic system manager for CHNOSZ.

This class manages the thermodynamic database, basis species, formed species, and calculation options - essentially serving as the global state container for all CHNOSZ calculations.

Initialize the thermodynamic system.

Instance variables

prop OBIGT
Expand source code
@property
def OBIGT(self):
    """Alias for obigt (R compatibility).

    On first access of a system that is neither initialized nor has obigt
    assigned, this runs reset(messages=True) before returning obigt.
    """
    needs_init = self.obigt is None and not self._initialized
    if needs_init:
        self.reset(messages=True)
    return self.obigt

Alias for obigt (R compatibility).

Methods

def get_obigt_db(self) ‑> OBIGTDatabase
Expand source code
def get_obigt_db(self) -> OBIGTDatabase:
    """Return the OBIGT database handle, running reset() first if the system is uninitialized."""
    if self._initialized:
        return self._obigt_db
    self.reset()
    return self._obigt_db

Get the OBIGT database instance.

def get_option(self, key: str, default: Any = None) ‑> Any
Expand source code
def get_option(self, key: str, default: Any = None) -> Any:
    """Look up a thermodynamic option by name, falling back to *default* when unset."""
    options = self.opt
    return options.get(key, default)

Get a thermodynamic option value.

def info(self) ‑> Dict[str, Any]
Expand source code
def info(self) -> Dict[str, Any]:
    """Summarize the current thermodynamic system.

    Before reset() has run this returns only {"status": "Not initialized"}.
    Afterwards it returns the status, the length of each loaded table (0 for
    tables that are None), and a shallow copy of the options dict.
    """
    if not self._initialized:
        return {"status": "Not initialized"}

    def size(table):
        # Tables that were never loaded count as empty.
        return 0 if table is None else len(table)

    return {
        "status": "Initialized",
        "obigt_species": size(self.obigt),
        "elements": size(self.element),
        "berman_minerals": size(self.Berman),
        "buffers": size(self.buffer),
        "proteins": size(self.protein),
        "stoich_species": size(self.stoich_formulas),
        "basis_species": size(self.basis),
        "formed_species": size(self.species),
        "current_options": dict(self.opt),
    }

Get information about the current thermodynamic system.

def is_initialized(self) ‑> bool
Expand source code
def is_initialized(self) -> bool:
    """Report whether reset() has completed on this system."""
    status = self._initialized
    return status

Check if the thermodynamic system is initialized.

def reset(self, messages: bool = True) ‑> None
Expand source code
def reset(self, messages: bool = True) -> None:
    """
    Initialize/reset the thermodynamic system.

    This is equivalent to reset() in the R version, loading all
    the thermodynamic data and initializing the system.

    Parameters
    ----------
    messages : bool, default True
        Whether to print informational messages

    Raises
    ------
    RuntimeError
        If any loading step raises; the original exception is chained
        as ``__cause__`` so its traceback is preserved.
    """
    try:
        # Load core data files
        self._load_options(messages)
        self._load_element_data(messages)
        self._load_berman_data(messages)
        self._load_buffer_data(messages)
        self._load_protein_data(messages)
        self._load_stoich_data(messages)
        self._load_bdot_data(messages)
        self._load_refs_data(messages)

        # Initialize OBIGT database
        self._obigt_db = OBIGTDatabase()
        self.obigt = self._obigt_db.get_combined_data()

        # Reset system state
        self.basis = None
        self.species = None
        self.opar = {}

        self._initialized = True
        if messages:
            print('reset: thermodynamic system initialized')

    except Exception as e:
        # Explicitly chain the cause so the underlying failure is not hidden.
        raise RuntimeError(f"Failed to initialize thermodynamic system: {e}") from e

Initialize/reset the thermodynamic system.

This is equivalent to reset() in the R version, loading all the thermodynamic data and initializing the system.

Parameters

messages : bool, default True
Whether to print informational messages
def set_option(self, key: str, value: Any) ‑> None
Expand source code
def set_option(self, key: str, value: Any) -> None:
    """Store *value* under *key* in this system's options, replacing any earlier value."""
    options = self.opt
    options[key] = value

Set a thermodynamic option value.