# -*- coding: utf-8 -*-
""" Tools for studying correlations
Author:
- Simon Wehle (swehle@desy.de)
"""
import numpy as np
from scipy.stats import binned_statistic_2d
from scipy import stats
import matplotlib.pylab as plt
import pandas as pd
import seaborn as sns
def corrmatrix(corr, separate_first=0, x_label_rot=45, invert_y=True, label_font_size=None, ax=None, *args, **kwargs):
    """Draw a correlation matrix as a seaborn heatmap on a fixed [-1, 1] color scale.

    Recommendation:
        with plt.style.context(['default', 'seaborn-bright']):
            corrmatrix(corrm_s, separate_first=2)

    Args:
        corr: square correlation matrix (e.g. the result of ``DataFrame.corr()``).
        separate_first: if > 0, draw gray separator lines at position
            ``separate_first`` on both axes.
        x_label_rot: rotation of the x tick labels in degrees.
        invert_y: invert the y axis after the limits have been reset.
        label_font_size: font size of the tick labels (None keeps the default).
        ax: axes to draw on; defaults to the current axes.
        *args, **kwargs: forwarded to ``sns.heatmap``. Keyword arguments
            override the defaults annot=False, cmap='PiYG', square=True,
            vmin=-1, vmax=1.

    Returns:
        The axes that was drawn on.
    """
    ax = plt.gca() if ax is None else ax
    # Defaults first so callers can override them without a duplicate-keyword error.
    heatmap_kwargs = dict(annot=False, cmap='PiYG', square=True, vmax=1, vmin=-1)
    heatmap_kwargs.update(kwargs)
    sns.heatmap(corr, *args, ax=ax, **heatmap_kwargs)
    # Reset the y range to the x range (the matrix is square). Presumably a
    # workaround for matplotlib versions that clip the outer heatmap rows.
    ax.set_ylim(*ax.get_xlim())
    ax.set_xticklabels(ax.get_xticklabels(), rotation=x_label_rot,
                       horizontalalignment='right', fontsize=label_font_size)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=label_font_size)
    if invert_y:
        ax.invert_yaxis()
    if separate_first > 0:
        ax.axhline(separate_first, color='gray', lw=1)
        ax.axvline(separate_first, color='gray', lw=1)
    return ax
def _flat_significance(x, y, nbins):
    """Return (significance map, chi2 p-value) for equal-frequency binning of x and y.

    The significance map has shape (n_y_bins, n_x_bins) so it can be passed
    straight to ``imshow``; bins with zero entries are NaN.
    """
    # Equal-frequency bin edges; coincident percentiles (discrete data) are merged,
    # so the actual number of bins may be smaller than nbins - 1.
    binsx = pd.unique(np.percentile(x, np.linspace(0, 100, nbins)))
    binsy = pd.unique(np.percentile(y, np.linspace(0, 100, nbins)))
    counts = binned_statistic_2d(x, y, values=x, statistic='count', bins=[binsx, binsy]).statistic.T
    # Expected count under independence: row_sum * col_sum / total.
    # For perfectly flat marginals this equals N_total / N_bins**2.
    total = counts.sum()
    expected = np.outer(counts.sum(axis=1), counts.sum(axis=0)) / total
    with np.errstate(divide='ignore', invalid='ignore'):
        significance = (counts - expected) / np.sqrt(expected)
    # Pearson chi2 over every bin with a defined expectation -- including empty
    # bins, which are a genuine deviation from the expectation.
    valid = expected > 0
    chi2_value = np.sum(significance[valid] ** 2)
    n_rows, n_cols = counts.shape
    dof = (n_rows - 1) * (n_cols - 1)
    flat_probability = stats.distributions.chi2.sf(chi2_value, dof)
    # Blank out empty bins for display only.
    significance[counts == 0] = np.nan
    return significance, flat_probability


def flat_correlation(x, y, nbins='auto', zoom=1, nlabels=5, ax=None, ax_fmt='%.2e', x_label_rot=45, invert_y=True, draw_labels=True, get_im=False, cmap='jet'):
    """Calculate and plot a 2D correlation in flat (equal-frequency) binning.

    Equal-frequency bin edges are computed separately for x and y, and a 2D
    histogram is filled with this binning. For continuous distributions each
    slice in x and in y therefore contains (nearly) the same number of entries.
    Each bin is compared with the count expected for independent x and y,
    N_expected = N_row * N_col / N_total. For perfectly flat marginals this
    reduces to N_total / N_bins**2. The plot shows the significance
    (N - N_expected) / sqrt(N_expected); empty bins are left blank.

    Args:
        x: array of values to be binned in x direction
        y: array of values to be binned in y direction
        nbins: int or 'auto'; number of percentile bin edges per axis, giving
            nbins - 1 bins (fewer if edges coincide). 'auto' uses
            int(2 * sqrt(3 * len(x) ** (1/3))).
        zoom: factor f of the color scale, which spans [-5 f, 5 f]
        nlabels: number of x, y tick labels
        ax: axes to draw on; if None, takes the current axes
        ax_fmt: %-style formatter for the tick labels
        x_label_rot: rotation for x tick labels
        invert_y: invert the y axis after drawing
        draw_labels: draw a colorbar, percentile tick labels and (for
            pd.Series input) axis names; otherwise clear the tick labels
        get_im: return the image object instead of the probability
        cmap: colormap name or Colormap instance
    Returns:
        The chi2 p-value of the independence hypothesis (Pearson chi2 with
        (n_x_bins - 1) * (n_y_bins - 1) degrees of freedom), or the image
        returned by ``ax.imshow`` if get_im is True.
    """
    ax = plt.gca() if ax is None else ax
    if isinstance(nbins, str) and nbins == 'auto':
        nbins = int(2 * (3 * len(x) ** (1 / 3)) ** (1 / 2))
    significance, flat_probability = _flat_significance(x, y, nbins)

    # imshow accepts both colormap names and Colormap instances.
    im = ax.imshow(significance, cmap=cmap, interpolation='nearest', origin='lower',
                   vmin=-5 * zoom, vmax=5 * zoom)
    if draw_labels:
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        # Ticks are spread evenly over the (flat) axis and labelled with the
        # matching percentiles of the data.
        ax.set_xticks(np.linspace(*ax.get_xlim(), nlabels))
        ax.set_xticklabels([ax_fmt % f for f in np.percentile(x, np.linspace(0, 100, nlabels))],
                           rotation=x_label_rot, ha='right')
        ax.set_yticks(np.linspace(*ax.get_ylim(), nlabels))
        ax.set_yticklabels([ax_fmt % f for f in np.percentile(y, np.linspace(0, 100, nlabels))])
        if isinstance(x, pd.Series):
            ax.set_xlabel(x.name)
        if isinstance(y, pd.Series):
            ax.set_ylabel(y.name)
    else:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    if invert_y:
        ax.invert_yaxis()
    if get_im:
        return im
    return flat_probability
def flat_corr_matrix(df, pdf=None, tight=False, labels=None, label_size=None, size=12, n_labels=3,
                     fontsize='auto', draw_cbar=False, tick_label_rotation=45, formatter='%.2e', label_rotation=45, cmap='PiYG'):
    """Draw a matrix of flat-binning correlation plots for every pair of columns of df.

    Cell (i, j) shows ``flat_correlation(df.iloc[:, j], df.iloc[:, i])``.
    Variable names are placed on the bottom row and the left column only.

    Args:
        df: pd.DataFrame whose columns are correlated pairwise.
        pdf: optional object with a ``savefig(figure)`` method (e.g.
            matplotlib's PdfPages). If given, the figure is saved to it and closed.
        tight: call ``plt.tight_layout()`` after drawing the cells.
        labels: axis labels for the columns; defaults to ``df.columns`` and
            must have one entry per column.
        label_size: if not None, draw percentile tick labels on the outer
            axes with this font size.
        size: figure width and height in inches.
        n_labels: number of percentile tick labels per outer axis.
        fontsize: font size of the axis labels, or 'auto' to interpolate
            from 22 (no columns) down to 10 (10 or more columns).
        draw_cbar: add a shared colorbar on the right of the figure.
        tick_label_rotation: rotation of the tick labels; 0 is replaced by 90
            on the x axis.
        formatter: %-style formatter for the tick labels.
        label_rotation: rotation of the axis labels.
        cmap: colormap passed to ``flat_correlation``.
    Returns:
        None
    """
    assert isinstance(df, pd.DataFrame), 'Argument of wrong type! Needs pd.DataFrame'
    n_vars = df.shape[1]
    if isinstance(fontsize, str) and fontsize == 'auto':
        fontsize = np.interp(n_vars, (0, 10), (22, 10))
    if labels is None:
        labels = df.columns
    else:
        assert len(labels) == len(df.columns), "Numbers of labels not matching the numbers of coulums in the df"

    im = None
    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(nrows=n_vars, ncols=n_vars, figsize=(size, size), squeeze=False)
    for i in range(n_vars):
        for j in range(n_vars):
            ax = axes[i, j]
            im = flat_correlation(df.iloc[:, j], df.iloc[:, i], ax=ax, draw_labels=False, get_im=True, cmap=cmap)
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
    if tight:
        plt.tight_layout()

    # Common outer labels: x labels on the bottom row, y labels on the left column.
    x_tick_rotation = 90 if tick_label_rotation == 0 else tick_label_rotation
    for j in range(n_vars):
        ax = axes[n_vars - 1, j]
        if label_size is not None:
            set_flat_labels(ax, df.iloc[:, j], axis=1, n_labels=n_labels, labelsize=label_size,
                            rotation=x_tick_rotation, formatter=formatter)
        ax.set_xlabel(labels[j], fontsize=fontsize, rotation=label_rotation, ha='right', va='top')
    for i in range(n_vars):
        ax = axes[i, 0]
        if label_size is not None:
            set_flat_labels(ax, df.iloc[:, i], axis=0, n_labels=n_labels, labelsize=label_size,
                            rotation=tick_label_rotation, formatter=formatter)
        ax.set_ylabel(labels[i], fontsize=fontsize, rotation=label_rotation, ha='right', va='bottom')

    # The colorbar must be added before saving, otherwise it never reaches the pdf.
    if draw_cbar:
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = plt.colorbar(im, cax=cbar_ax)
        cbar.ax.set_ylabel(r'$\sigma$', rotation=0, fontsize=fontsize * 1.2, va='center')
        cbar.ax.tick_params(labelsize=fontsize)
    if pdf is not None:
        pdf.savefig(fig)
        plt.close(fig)
def set_flat_labels(ax, x, n_labels=5, axis=1, labelsize=12, rotation=45,
                    formatter='%.3e'):
    """Put percentile-valued tick labels on one axis of a flat-binned plot.

    The ticks are spaced evenly over the current axis limits. Each tick is
    labelled with the matching percentile of ``x`` (0 %, ..., 100 %), which
    is the correct value scale for an equal-frequency binning.

    Args:
        ax: matplotlib axes to label.
        x: data whose percentiles give the label values.
        n_labels: number of ticks/labels.
        axis: 1 labels the x axis; any other value labels the y axis.
        labelsize: font size of the tick labels.
        rotation: rotation of the tick labels in degrees.
        formatter: %-style format string for each label value.
    Returns:
        None
    """
    # A single equality test drives both the limit lookup and the labelling, so
    # the two can never disagree (e.g. for numpy integer `axis` values).
    is_x_axis = axis == 1
    start, end = ax.get_xlim() if is_x_axis else ax.get_ylim()
    label_position = np.linspace(start, end, n_labels)
    new_labels = [formatter % value for value in np.percentile(x, np.linspace(0, 100, n_labels))]
    if is_x_axis:
        ha = 'center' if rotation != 0 else 'right'
        ax.set_xticks(label_position)
        ax.set_xticklabels(new_labels, fontsize=labelsize, rotation=rotation, ha=ha)
    else:
        va = 'center' if rotation == 0 else 'top'
        ax.set_yticks(label_position)
        ax.set_yticklabels(new_labels, fontsize=labelsize, rotation=rotation, va=va)
def heatmap(x, y, tfs=12, bkg_color='#F1F1F1', separate_first=0, **kwargs):
    """Draw a heatmap as a scatter of colored, sized markers on a categorical grid.

    Based on: https://towardsdatascience.com/better-heatmaps-and-correlation-matrix-plots-in-python-41445d0f2bec

    Args:
        x, y: sequences of category labels, one entry per marker.
        tfs: tick label font size.
        bkg_color: background color of the main axes.
        separate_first: if non-zero, draw a vertical gray line after the first
            ``separate_first`` x categories and a horizontal gray line
            ``separate_first`` rows below the top.
        **kwargs: optional settings:
            color: values mapped onto the palette (default: all 1).
            palette: list of colors (default: 500-color seaborn diverging palette).
            color_range: (min, max) mapped onto the palette (default: range of color).
            size: values mapped onto marker sizes (default: all 1).
            size_range: (min, max) mapped onto marker sizes (default: range of size).
            size_scale: marker size of the largest value (default 500).
            marker: marker style (default 's').
            x_order, y_order: category order per axis (default: sorted unique values).
            All other keyword arguments are forwarded to ``Axes.scatter``.
    Returns:
        The main heatmap axes (also left as the current axes).
    """
    color = kwargs['color'] if 'color' in kwargs else [1] * len(x)
    if 'palette' in kwargs:
        palette = kwargs['palette']
    else:
        palette = sns.diverging_palette(359, 122, s=90, n=500)
    # Map onto the whole palette; the color legend also draws one bar per palette entry.
    n_colors = len(palette)
    if 'color_range' in kwargs:
        color_min, color_max = kwargs['color_range']
    else:
        color_min, color_max = min(color), max(color)

    def value_to_color(val):
        """Return the palette entry for ``val``, clipped to the color range."""
        if color_min == color_max:
            return palette[-1]
        val_position = float(val - color_min) / (color_max - color_min)
        val_position = min(max(val_position, 0), 1)
        return palette[int(val_position * (n_colors - 1))]

    size = kwargs['size'] if 'size' in kwargs else [1] * len(x)
    if 'size_range' in kwargs:
        size_min, size_max = kwargs['size_range'][0], kwargs['size_range'][1]
    else:
        size_min, size_max = min(size), max(size)
    size_scale = kwargs.get('size_scale', 500)

    def value_to_size(val):
        """Return the marker area for ``val``, clipped to the size range."""
        if size_min == size_max:
            return 1 * size_scale
        # Map into [0.01, 1] so even the smallest value keeps a visible marker.
        val_position = (val - size_min) * 0.99 / (size_max - size_min) + 0.01
        val_position = min(max(val_position, 0), 1)
        return val_position * size_scale

    x_names = list(kwargs['x_order']) if 'x_order' in kwargs else sorted(set(x))
    x_to_num = {name: num for num, name in enumerate(x_names)}
    y_names = list(kwargs['y_order']) if 'y_order' in kwargs else sorted(set(y))
    y_to_num = {name: num for num, name in enumerate(y_names)}

    plot_grid = plt.GridSpec(1, 30, hspace=0.2, wspace=0.1)  # 1x30 grid
    main_ax = ax = plt.subplot(plot_grid[:, :-1])  # left 29/30 of the grid hold the main plot
    marker = kwargs.get('marker', 's')
    kwargs_pass_on = {k: v for k, v in kwargs.items() if k not in [
        'color', 'palette', 'color_range', 'size', 'size_range', 'size_scale', 'marker', 'x_order', 'y_order'
    ]}
    ax.scatter(
        x=[x_to_num[v] for v in x],
        y=[y_to_num[v] for v in y],
        marker=marker,
        s=[value_to_size(v) for v in size],
        c=[value_to_color(v) for v in color],
        **kwargs_pass_on
    )
    ax.set_xticks(list(x_to_num.values()))
    ax.set_xticklabels(list(x_to_num), rotation=45, horizontalalignment='right', fontsize=tfs)
    ax.set_yticks(list(y_to_num.values()))
    ax.set_yticklabels(list(y_to_num), fontsize=tfs)
    # Grid lines between the cells: hide major grid, show minor ticks offset by half a cell.
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)
    ax.set_xlim([-0.5, max(x_to_num.values()) + 0.5])
    ax.set_ylim([-0.5, max(y_to_num.values()) + 0.5])
    ax.set_facecolor(bkg_color)
    if separate_first:
        n_rows = len(y_names)
        ax.axvline(separate_first - .5, color='gray')
        ax.axhline(n_rows - .5 - separate_first, color='gray')

    # Color legend in the rightmost grid column.
    if color_min < color_max:
        ax = plt.subplot(plot_grid[:, -1])
        plt.box(on=None)
        bar_y = np.linspace(color_min, color_max, n_colors)  # one bar per palette color
        bar_height = bar_y[1] - bar_y[0] if n_colors > 1 else color_max - color_min
        ax.barh(
            y=bar_y,
            width=[15] * n_colors,
            left=[0] * n_colors,
            height=bar_height,
            color=palette,
            linewidth=0
        )
        # Pad the bar by half its span on each side (-2..2 for a -1..1 range).
        color_mid = (color_min + color_max) / 2
        color_span = color_max - color_min
        ax.set_ylim(color_mid - color_span, color_mid + color_span)
        ax.set_xlim(0, 5)  # bars are 15 wide; crop to a thin strip
        ax.grid(False)
        ax.set_facecolor('white')
        ax.set_xticks([])
        ax.set_yticks(np.linspace(min(bar_y), max(bar_y), 3))  # min, middle and max
        ax.yaxis.tick_right()
    plt.sca(main_ax)
    return main_ax
def corrplot(data, size_scale=500, marker='s', tfs=12,
             separate_first=0,
             *args, **kwargs):
    """Correlation plot: color and marker size both encode each correlation value.

    Based on: https://towardsdatascience.com/better-heatmaps-and-correlation-matrix-plots-in-python-41445d0f2bec

    Args:
        data: square correlation matrix (DataFrame) whose index labels match its columns.
        size_scale: marker size for |correlation| == 1.
        marker: marker style.
        tfs: tick label font size.
        separate_first: forwarded to ``heatmap`` to draw separator lines.
        *args, **kwargs: forwarded to ``heatmap``.
    """
    wide = data.reset_index()
    # reset_index names the new column after the index name ('index' only when
    # the index is unnamed), so take it by position instead of by name.
    index_col = wide.columns[0]
    corr = pd.melt(wide, id_vars=[index_col], var_name='y', value_name='value')
    corr.columns = ['x', 'y', 'value']
    heatmap(
        corr['x'], corr['y'],
        color=corr['value'], color_range=[-1, 1],
        size=corr['value'].abs(), size_range=[0, 1],
        marker=marker,
        x_order=data.columns,
        y_order=data.columns[::-1],
        size_scale=size_scale,
        tfs=tfs,
        separate_first=separate_first,
        *args, **kwargs
    )