Source code for pyaerocom.plot.heatmaps

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter
from seaborn import heatmap

from pyaerocom.mathutils import exponent


def _format_annot_heatmap(annot, annot_fmt_rows, annot_fmt_exceed):
    """
    Process annotation formatting info for :func:`df_to_heatmap`

    Parameters
    ----------
    annot : ndarray
        2D pre-existing annotation information.
    annot_fmt_rows : list
        annotation formatting strings for each row of the input table. This
        parameter is only considered if `annot_fmt_rowwise` is True. See also
        :func:`_format_annot_heatmap`.
    annot_fmt_exceed : list, optional
        how to format annotated values that exceed a certain threshold.
        The list contains 2 entries, 1. the threshold values, 2. how values
        exceeding this threshold should be formatted. .

    Returns
    -------
    ndarray
        updated annotation information (input `annot`)
    list
        row formatting info.


    """
    _annot = []
    if not isinstance(annot_fmt_rows, list):
        annot_fmt_rows = []
        for row in annot:
            mask = row[~np.isnan(row)]
            if len(mask) == 0:  # all NaN
                annot_fmt_rows.append("")
                continue
            mask = mask[mask != 0]
            exps = exponent(mask)
            minexp = exps.min()
            if minexp < -3:
                annot_fmt_rows.append(".1E")
            elif minexp < 0:
                annot_fmt_rows.append(f".{-minexp + 1}f")
            elif minexp in [0, 1]:
                annot_fmt_rows.append(".1f")
            else:
                annot_fmt_rows.append(".0f")
    if isinstance(annot_fmt_exceed, list):
        exceed_val, exceed_fmt = annot_fmt_exceed
    else:
        exceed_val, exceed_fmt = None, None

    for i, row in enumerate(annot):
        rowfmt = annot_fmt_rows[i]
        if rowfmt == "":
            row_fmt = [""] * len(row)
        else:
            row_fmt = []
            for val in row:
                if np.isnan(val):
                    valstr = ""
                else:
                    if exceed_val is not None and val > exceed_val:
                        valstr = format(val, exceed_fmt)
                    else:
                        valstr = format(val, rowfmt)

                row_fmt.append(valstr)
        _annot.append(row_fmt)
    annot = np.asarray(_annot)
    return annot, annot_fmt_rows


[docs] def df_to_heatmap( df, cmap=None, center=None, low=0.3, high=0.3, vmin=None, vmax=None, color_rowwise=False, normalise_rows=False, normalise_rows_how=None, normalise_rows_col=None, norm_ref=None, sub_norm_before_div=True, annot=True, num_digits=None, ax=None, figsize=(12, 12), cbar=False, cbar_label=None, cbar_labelsize=None, xticklabels=None, xtick_rot=45, yticklabels=None, ytick_rot=45, xlabel=None, ylabel=None, title=None, labelsize=12, annot_fontsize=None, annot_fmt_rowwise=False, annot_fmt_exceed=None, annot_fmt_rows=None, # explicit formatting strings for rows cbar_ax=None, cbar_kws=None, **kwargs, ): """Plot a pandas dataframe as heatmap Parameters ---------- df : DataFrame table data cmap : str, optional string specifying colormap to be used center : float, optional value that is mapped to center colour of colormap (e.g. 0) low : float, optional Extends lower range of the table values so that when mapped to the colormap, it’s entire range isn’t used. E.g. 0.3 roughly corresponds to colormap crop of 30% at the lower end. high : float, optional Extends upper range of the table values so that when mapped to the colormap, it’s entire range isn’t used. E.g. 0.3 roughly corresponds to colormap crop of 30% at the upper end. vmin : float, optional lower end of value range to be plotted. If specified, input arg `low` will be ignored. vmax : float, optional upper end of value range to be plotted. If specified, input arg `low` will be ignored. color_rowwise : bool, optional if True, the color mapping is applied row by row, else, for the whole table. Defaults to False. normalise_rows : bool, optional if True, the table is normalised in a rowwise manner either using the mean value in each row (if argument ``normalise_rows_col`` is unspecified) or using the value in a specified column. Defaults to False. normalise_rows_how : str, optional aggregation string for row normalisation. Choose from mean or median. Only relevant if input arg ``normalise_rows==True``. normalise_rows_col : int, optional if provided and if arg. ``normalise_rows==True``, then the corresponding table column is used for normalisation rather than the mean value of each row. norm_ref : float or ndarray, optional reference value(s) used for rowwise normalisation. Only relevant if normalise_rows is True. If specified, normalise_rows_how and normalise_rows_col will be ignored. sub_norm_before_div : bool, optional if True, the rowwise normilisation is applied by subtracting the normalisation value for each row before dividing by it. This can be useful to visualise positive or negative departures from the mean or median. annot : bool or list or ndarray, optional if True, the table values are printed into the heatmap. Defaults to True, in which case the values are computed based on the table content. If list or ndarray, the shape needs to be the same as input table shape (no of rows and cols), in which case the values of that 2D frame are used. num_digits : int, optional number of digits printed in heatmap annotation. ax : axes, optional matplotlib axes instance used for plotting, if None, an axes will be created figsize : tuple, optional size of figure for plot cbar : bool, optional if True, a colorbar is included cbar_label : str, optional label of colorbar (if colorbar is included, see cbar option) cbar_labelsize : int, optional size of colorbar label xticklabels : list, optional List of x axis labels. xtick_rot : int, optional rotation of x axis labels, defaults to 45 degrees. yticklabels : list, optional List of string labels. ytick_rot : int, optional rotation of y axis labels, defaults to 45 degrees. xlabel : str, optional x axis label ylabel : str, optional y axis label title : str, optional title of heatmap labelsize : int, optional fontsize of labels, default to 12 annot_fontsize : int, optional fontsize of annotated text. annot_fmt_rowwise : bool rowwise formatting of annotation values, based on row value ranges. Defaults to False. annot_fmt_exceed : list, optional how to format annotated values that exceed a certain threshold. The list contains 2 entries, 1. the threshold values, 2. how values exceeding this thrshold should be formatted. This parameter is only considered if `annot_fmt_rowwise` is True. See also :func:`_format_annot_heatmap`. annot_fmt_rows : list annotation formatting strings for each row of the input table. This parameter is only considered if `annot_fmt_rowwise` is True. See also :func:`_format_annot_heatmap`. cbar_ax : Axes, optional axes instance for colorbar, parsed to :func:`seaborn.heatmap`. cbar_kws : dict, optional keywords for colorbar formatting, , parsed to :func:`seaborn.heatmap`. **kwargs further keyword args parsed to :func:`seaborn.heatmap` Raises ------ ValueError if input `annot` is list or ndarray and has a different shape than the input `df`. Returns ------- Axes plot axes instance list or None annotation information for rows """ if cmap is None: cmap = "bwr" if cbar_label is None: cbar_label = "" if normalise_rows_how is None: normalise_rows_how = "median" if ax is None: _, ax = plt.subplots(1, 1, figsize=figsize) if cbar_kws is None: cbar_kws = {} if annot_fontsize is None: annot_fontsize = labelsize - 4 if not "annot_kws" in kwargs: kwargs["annot_kws"] = {} kwargs["annot_kws"]["size"] = annot_fontsize df_hm = df if normalise_rows: if norm_ref is None: if normalise_rows_col is not None: if isinstance(normalise_rows_col, str): try: normalise_rows_col = df.columns.to_list().index(normalise_rows_col) except ValueError: raise ValueError(f"Failed to localise column {normalise_rows_col}") norm_ref = df.values[:, normalise_rows_col] else: if normalise_rows_how == "mean": norm_ref = df.mean(axis=1) elif normalise_rows_how == "median": norm_ref = df.median(axis=1) else: raise ValueError( f"Invalid input for normalise_rows_how ({normalise_rows_how}). " f"Choose mean or median" ) if sub_norm_before_div: df_hm = df.subtract(norm_ref, axis=0).div(norm_ref, axis=0) else: df_hm = df.div(norm_ref, axis=0) cbar_kws["format"] = FuncFormatter(lambda x, pos: f"{x:.0%}") if color_rowwise: df_hm = df_hm.div(abs(df_hm).max(axis=1), axis=0) cbar_kws["label"] = cbar_label if "norm" in kwargs: norm = kwargs["norm"] vmin = norm.boundaries[0] vmax = norm.boundaries[-1] else: if vmin is None: vmin = df_hm.min().min() # * (1 - low) if vmax is None: vmax = df_hm.max().max() # * (1 + high) vmin -= abs(vmin) * low vmax += abs(vmax) * high if annot is True: annot = df.values elif isinstance(annot, list): annot = np.asarray(annot) else: fmt = "" if isinstance(annot, np.ndarray): if not annot.shape == df.values.shape: raise ValueError( "Invalid input for annot: needs to have same shape as input dataframe" ) elif np.any([isinstance(x, str) for x in annot.flatten()]): fmt = "" elif annot_fmt_rowwise: annot, annot_fmt_rows = _format_annot_heatmap(annot, annot_fmt_rows, annot_fmt_exceed) fmt = "" elif num_digits is None or num_digits > 5: fmt = ".4g" else: fmt = f".{num_digits}f" ax = heatmap( df_hm, cmap=cmap, center=center, annot=annot, ax=ax, # changes this from df_hm to df because the annotation and colorbar didn't work. cbar=cbar, cbar_ax=cbar_ax, cbar_kws=cbar_kws, fmt=fmt, vmin=vmin, vmax=vmax, xticklabels=True, yticklabels=True, **kwargs, ) ax.figure.axes[-1].yaxis.label.set_size(labelsize) if title is not None: ax.set_title(title, fontsize=labelsize + 2) if yticklabels is None: yticklabels = ax.get_yticklabels() ax.set_yticklabels(yticklabels, rotation=ytick_rot, fontsize=labelsize - 2) if xticklabels is None: xticklabels = ax.get_xticklabels() ax.set_xticklabels(xticklabels, rotation=xtick_rot, ha="right", fontsize=labelsize - 2) if xlabel is None: xlabel = "" ax.set_xlabel(xlabel, fontsize=labelsize) if ylabel is None: ylabel = "" ax.set_ylabel(ylabel, fontsize=labelsize) if cbar_labelsize is not None: ax.figure.axes[-1].yaxis.label.set_size(cbar_labelsize) return ax, annot_fmt_rows