I'm trying to create a Dash application that displays a grid of subplots to visualize the pairwise comparison of the columns of a dataframe. To the top and left of each grid row and column will be the corresponding variables. The variable names can be quite long though, so it's easy to misalign them. I've tried, staggering the variable names, but eventually settled on line-wrapping them. See the picture below. I've also included my code at the end of this post
df = pd.DataFrame({
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA": ["1", "2", "3", "4"],
"BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"],
"CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC": ["cat", "dog", "cat", "mouse"],
"DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD": ["10.5", "20.3", "30.1", "40.2"],
'EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE': ['apple', 'apple', 'apple', 'banana']
})
For this dataframe, I'd like to get something like
As you can see, I'm having trouble aligning the row and column labels of the grid. Here is my code
import dash
from dash import dcc, html
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
import numpy as np
import plotly.graph_objects as go
# Sample DataFrame
df = pd.DataFrame({
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA": ["1", "2", "3", "4"],
"BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"],
"CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC": ["cat", "dog", "cat", "mouse"],
"DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD": ["10.5", "20.3", "30.1", "40.2"],
'EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE': ['apple', 'apple', 'apple', 'banana']
})
# Convert data types
def convert_dtypes(df):
for col in df.columns:
try:
df[col] = pd.to_numeric(df[col]) # Convert to int/float
except ValueError:
try:
df[col] = pd.to_datetime(df[col]) # Convert to datetime
except ValueError:
df[col] = df[col].astype("string") # Keep as string
return df
df = convert_dtypes(df)
columns = df.columns
num_cols = len(columns)
# Dash App
app = dash.Dash(__name__)
app.layout = html.Div([
html.H1("Pairwise Column Plots"),
dcc.Graph(id='grid-plots')
])
@app.callback(
dash.Output('grid-plots', 'figure'),
dash.Input('grid-plots', 'id') # Dummy input to trigger callback
)
def create_plot_grid(_):
fig = sp.make_subplots(rows = num_cols, cols = num_cols,
#subplot_titles = [f"{x} vs {y}" for x in columns for y in columns],
shared_xaxes = False, shared_yaxes = False)
annotations = [] # Store subplot titles dynamically
# Add column labels (Top Labels)
for j, col_label in enumerate(columns):
annotations.append(
dict(
#text=f"<b>{col_label}</b>", # Bold for emphasis
text=f"<b>{'<br>'.join(col_label[x:x+10] for x in range(0, len(col_label), 10))}</b>",
xref = "paper", yref = "paper",
x = (j) / num_cols, # Center over the column
y = 1.02, # Slightly above the top row
showarrow = False,
font = dict(size = 14, color = "black")
)
)
# Add row labels (Side Labels)
for i, row_label in enumerate(columns):
annotations.append(
dict(
#text = f"<b>{row_label}</b>", # Bold for emphasis
text=f"<b>{'<br>'.join(row_label[x:x+10] for x in range(0, len(row_label), 10))}</b>",
xref = "paper", yref = "paper",
x = -0.02, # Slightly to the left of the row
y = (1 - (i + 0.5) / num_cols), # Center next to the row
showarrow = False,
font = dict(size = 14, color = "black"),
textangle = -90 # Rotate text for vertical orientation
)
)
print(annotations)
for i, x_col in enumerate(columns):
for j, y_col in enumerate(columns):
dtype_x, dtype_y = df[x_col].dtype, df[y_col].dtype
row, col = i + 1, j + 1 # Adjust for 1-based indexing
# I only want to print the upper triangle of the grid
if j <= i:
trace = None
# Numeric vs Numeric: Scatter Plot
elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
trace = px.scatter(df, x = x_col, y = y_col).data[0]
# Numeric vs Categorical: Box Plot
elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y):
trace = px.box(df, x = y_col, y = x_col).data[0]
elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
trace = px.box(df, x = x_col, y = y_col).data[0]
# Categorical vs Categorical: Count Heatmap
elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y):
#trace = px.histogram(df, x = x_col, color = y_col, barmode = "group").data[0]
counts_df = (
df
.groupby([x_col, y_col])
.size()
.reset_index(name = 'count')
.pivot_table(index = x_col, columns = y_col, values = "count", aggfunc="sum")
)
trace = go.Heatmap(z = counts_df.values, x = counts_df.columns, y = counts_df.index, showscale = False)
# Datetime vs Numeric: Line Plot
elif pd.api.types.is_datetime64_any_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
trace = px.line(df, x = x_col, y = y_col).data[0]
elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_datetime64_any_dtype(dtype_y):
trace = px.line(df, x = y_col, y = x_col).data[0]
else:
trace = None # Unsupported combination
if trace:
fig.add_trace(trace, row = row, col = col)
fig.update_layout(height = 300 * num_cols,
width = 300 * num_cols,
showlegend = False,
annotations = annotations)
print(fig['layout'])
return fig
if __name__ == '__main__':
app.run_server(debug = True)