import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors
from . import util
"""
Visualization module for missing data patterns.

This module provides plotting functions to help understand the structure
and distribution of missingness in datasets. It includes a matrix plot of
the missingness mask (annotated with per-column missing rates) and a
nullity-correlation heatmap.

Functions
---------
- plot_missing_matrix : Visualize the missingness mask cell by cell, with
  column names on top and per-column missing rates at the bottom.
- plot_missing_heatmap : Show the pairwise correlation of missingness
  indicators between columns.

These tools are useful during EDA or when diagnosing missing data mechanisms.

Example usage::

    from missmecha import visual
    visual.plot_missing_matrix(X)
"""
def _get_auto_figsize(n_rows, n_cols, base_width=1.2, base_height=0.3, max_size=(20, 12)):
"""
Compute dynamic figsize based on DataFrame shape.
base_width: how wide each column should be (in inches)
base_height: how tall each row should be (in inches)
max_size: cap the maximum figsize to avoid excessive size
"""
width = min(max_size[0], max(6, n_cols * base_width))
height = min(max_size[1], max(4, n_rows * base_height))
return (width, height)
def plot_missing_matrix(df, figsize=None, cmap="Blues", sort_by=None, color=True, fontsize=14, label_rotation=45, show_colorbar=False, ts=False):
    """
    Visualize missing data patterns in a matrix-style heatmap.

    Each cell of ``df`` becomes one pixel of an RGB image. Missing cells are
    drawn white. Observed cells are drawn either in a fixed dark gray
    (``color=False``) or shaded by their column-wise min-max scaled value
    (``color=True``). Both standard tabular and time series layouts are
    supported.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame to visualize. Missing values are shown as white cells.
    figsize : tuple of float, optional
        Figure size (width, height) in inches. If None, it is derived from
        the DataFrame's shape via ``_get_auto_figsize``.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap applied to observed values when ``color=True``.
        Default is "Blues".
    sort_by : column label or list of labels, optional
        If given, rows are sorted by these column(s) in descending order
        before plotting, which can make missingness patterns easier to spot.
        The original index is kept, so ``ts=True`` still shows each row's
        true label.
    color : bool, optional
        If True, shade observed values with ``cmap``. If False, draw a binary
        dark-gray/white mask.
    fontsize : int, optional
        Font size for column labels and row ticks. Missing-rate labels use
        ``fontsize - 2``. Default is 14.
    label_rotation : int, optional
        Rotation angle, in degrees, for the column-name and missing-rate
        labels. Default is 45.
    show_colorbar : bool, optional
        Whether to draw a colorbar. Only has an effect when ``color=True``.
    ts : bool, optional
        If True, label the y-axis with the DataFrame index (e.g. timestamps),
        using at most about 50 evenly spaced labels plus the last row. If
        False, only the first and last (1-based) row numbers are shown.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes holding the matrix image. ``plt.show()`` has already been
        called on the figure.

    Raises
    ------
    ValueError
        If ``df`` has no rows or no columns.

    Notes
    -----
    - Top axis: column names. Bottom axis: column-wise missing rates (%).
    - With ``color=True``, values are first converted by
      ``util.type_convert``. Constant columns are drawn in a single shade.
    - Fully observed or fully missing columns are retained (not filtered).
    - For large datasets, consider subsampling before plotting for performance.

    Examples
    --------
    >>> from missmecha.visual import plot_missing_matrix
    >>> import pandas as pd
    >>> df = pd.read_csv("data.csv")
    >>> ax = plot_missing_matrix(df, color=False)
    """
    if df.shape[0] == 0 or df.shape[1] == 0:
        raise ValueError("plot_missing_matrix requires a non-empty DataFrame.")

    # Test against None rather than truthiness so falsy labels such as
    # integer column 0 still trigger sorting.
    if sort_by is not None:
        # The index is deliberately kept: with ts=True it labels the rows.
        df = df.sort_values(by=sort_by, ascending=False)

    n_rows, n_cols = df.shape
    missing_rates = df.isnull().mean() * 100
    if figsize is None:
        figsize = _get_auto_figsize(n_rows, n_cols)

    # RGB image initialised to white, so missing cells stay white.
    rgb = np.ones((n_rows, n_cols, 3))
    norm = cmap_obj = None
    if not color:
        rgb[df.notnull().to_numpy()] = (0.25, 0.25, 0.25)
    else:
        # np.array copies, so the rescaling below can never write into a
        # buffer that util.type_convert might share with the caller's df.
        values = np.array(util.type_convert(df), dtype=float)
        # NOTE(review): observed values are rescaled into [1, 2], but the norm
        # tops out at 1.5. The upper half of each column's range therefore
        # maps above 1.0 and renders in the colormap's "over" colour (its top
        # colour by default). Kept to preserve the existing look; confirm
        # this is intended.
        norm = mcolors.Normalize(vmin=0, vmax=1.5)
        cmap_obj = plt.get_cmap(cmap)
        for j in range(n_cols):
            column = values[:, j]
            valid = ~np.isnan(column)
            if not valid.any():
                continue
            observed = column[valid]
            lo, hi = observed.min(), observed.max()
            if lo != hi:
                scaled = (observed - lo) / (hi - lo) + 1
            else:
                # Constant column: there is no spread to scale.
                scaled = np.ones_like(observed)
            rgb[valid, j] = cmap_obj(norm(scaled))[:, :3]

    # === Plot ===
    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(rgb, interpolation="none", aspect="auto")
    ax.grid(False)
    # Column labels live on the twin axes below, not on the base axes.
    ax.set_xticks([])

    positions = range(n_cols)

    # --- Top: column names ---
    ax_top = ax.twiny()
    ax_top.set_xlim(ax.get_xlim())
    ax_top.set_xticks(positions)
    ax_top.set_xticklabels(df.columns, rotation=label_rotation, ha="left", fontsize=fontsize)
    ax_top.xaxis.set_ticks_position("top")
    ax_top.xaxis.set_label_position("top")

    # --- Bottom: missing rates ---
    ax_bottom = ax.twiny()
    ax_bottom.set_xlim(ax.get_xlim())
    ax_bottom.set_xticks(positions)
    ax_bottom.set_xticklabels([f"{rate:.1f}%" for rate in missing_rates],
                              rotation=label_rotation, ha="right", fontsize=fontsize - 2)
    ax_bottom.xaxis.set_ticks_position("bottom")
    ax_bottom.xaxis.set_label_position("bottom")

    # --- Y-axis row labels ---
    if not ts:
        # Avoid a duplicated tick when there is only one row.
        ticks = [0] if n_rows == 1 else [0, n_rows - 1]
        ax.set_yticks(ticks)
        ax.set_yticklabels([t + 1 for t in ticks], fontsize=fontsize)
    else:
        max_labels = 50
        # Ceiling division keeps the evenly spaced ticks at <= max_labels.
        step = max(1, -(-n_rows // max_labels))
        ticks = list(range(0, n_rows, step))
        if ticks[-1] != n_rows - 1:
            ticks.append(n_rows - 1)  # Ensure the last row is labeled
        ax.set_yticks(ticks)
        ax.set_yticklabels([df.index[i] for i in ticks], fontsize=fontsize)

    # Optional colorbar (only meaningful in color mode).
    if color and show_colorbar:
        sm = plt.cm.ScalarMappable(cmap=cmap_obj, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, orientation="vertical", fraction=0.02, pad=0.02)
        cbar.set_label("Normalized Values", fontsize=fontsize)

    plt.tight_layout()
    plt.show()
    return ax
def plot_missing_heatmap(df, figsize=(20, 12), fontsize=14, label_rotation=45, cmap='RdBu', method="pearson", sample_size=1000, random_state=42):
    """
    Plot a heatmap of pairwise nullity correlations.

    This function visualizes the pairwise correlation between the missingness
    indicators (True where a value is missing) of the columns in ``df``. The
    heatmap helps identify dependencies between missingness in different
    variables and can guide further missing data analysis.

    Parameters
    ----------
    df : pandas.DataFrame
        Input dataset to visualize. Each column should represent a feature.
    figsize : tuple of float, optional
        Figure size in inches (width, height). Default is (20, 12).
    fontsize : int, optional
        Font size for axis tick labels. Cell annotations use
        ``fontsize - 2``. Default is 14.
    label_rotation : int, optional
        Rotation angle, in degrees, for x-axis tick labels. Default is 45.
    cmap : str, optional
        Colormap for the heatmap (e.g. 'RdBu', 'viridis'). Default is 'RdBu'.
    method : {'pearson', 'kendall', 'spearman'}, optional
        Correlation method passed to ``DataFrame.corr``. Default is 'pearson'.
    sample_size : int or None, optional
        If ``df`` has more rows than this, a random sample of
        ``sample_size`` rows is used. None disables sampling. Default is 1000.
    random_state : int, optional
        Seed passed to ``DataFrame.sample`` when sampling. Default is 42.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes object containing the plotted heatmap. ``plt.show()`` has
        already been called on the figure.

    Raises
    ------
    ValueError
        If no column has partial missingness, i.e. every column is either
        fully observed or fully missing (this includes the case of no
        missing values at all).

    Notes
    -----
    - Fully observed or fully missing columns are excluded from the plot.
    - Types are converted with ``util.type_convert`` before missingness is
      evaluated.
    - Values range from -1 to 1 and are annotated with two decimals.

    Examples
    --------
    >>> from missmecha.visual import plot_missing_heatmap
    >>> import pandas as pd
    >>> df = pd.read_csv("my_data.csv")
    >>> ax = plot_missing_heatmap(df)
    """
    # Step 1: subsample large inputs to keep correlation and annotation cheap.
    if sample_size is not None and df.shape[0] > sample_size:
        df = df.sample(n=sample_size, random_state=random_state)

    # Convert types but preserve column labels and index.
    converted = pd.DataFrame(util.type_convert(df), columns=df.columns, index=df.index)
    nullity = converted.isnull()

    # Drop fully observed / fully missing columns: their indicator is
    # constant, so its correlation with any other column is undefined.
    nullity = nullity.loc[:, nullity.var(axis=0) > 0]
    if nullity.shape[1] == 0:
        raise ValueError("No missing values found in the dataset.")

    corr_mat = nullity.corr(method=method)

    _, ax = plt.subplots(figsize=figsize)
    sns.heatmap(corr_mat, cmap=cmap, vmin=-1, vmax=1,
                cbar=True, annot=True, fmt=".2f", annot_kws={"size": fontsize - 2}, ax=ax)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=label_rotation, ha='right', fontsize=fontsize)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=fontsize)
    plt.show()
    return ax