热图(相关系数矩阵)

推荐用seaborn这个库中的函数来快速实现热图的绘制。可以通过

1
2
pip install seaborn
conda install seaborn -c conda-forge

来安装这个库。

直接上绘图代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import seaborn as sns
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
# 传入N个指数序列
index_a = np.random.rand(40)
index_b = np.random.rand(40)
index_c = np.random.rand(40)
index_d = np.random.rand(40)
index_e = np.random.rand(40)
# 堆叠
index_all = np.vstack([index_a,index_b,index_c,index_d,index_e])
#计算相关系数矩阵
correlation_matrix = np.zeros((index_all.shape[0],index_all.shape[0]))
p_value_matrix = np.zeros((index_all.shape[0],index_all.shape[0]))
for i in range(index_all.shape[0]):
for j in range(index_all.shape[0]):
correlation_matrix[i, j], p_value_matrix[i, j] = pearsonr(index_all[i], index_all[j])
#处理P值矩阵为对应的字符串矩阵(过检部分的*号显示)
p_value_str = p_value_matrix.astype(str)
p_value_str[p_value_matrix < 0.01] = '***'
p_value_str[(p_value_matrix >= 0.01) & (p_value_matrix < 0.05)] = '**'
p_value_str[(p_value_matrix >= 0.05) & (p_value_matrix < 0.1)] = '*'
p_value_str[(p_value_matrix >0.1)] = ''
strings = (np.asarray(["{0:.2f}\n{1}".format(cor,p_value) for cor, p_value in zip(correlation_matrix.flatten(), p_value_str.flatten())]))
strings = strings.reshape(correlation_matrix.shape)
#绘图
fig,ax = plt.subplots(figsize=(7,7),facecolor="w")
sns.heatmap(correlation_matrix, annot=strings,fmt='',cmap=plt.cm.coolwarm,vmin=-0.5, vmax=0.5,
annot_kws={"size":15,"fontweight":"bold"},linecolor="k",linewidths=.2,
ax=ax)
plt.xticks(ticks=np.arange(0.5, correlation_matrix.shape[0], 1), labels=['Index_a','Index_b','Index_c','Index_d','Index_e'],fontsize=12)
plt.yticks(ticks=np.arange(0.5, correlation_matrix.shape[1], 1), labels=['Index_a','Index_b','Index_c','Index_d','Index_e'],fontsize=12)
plt.title('Heatmap',fontsize=16)
plt.show()

图形输出:

image-20200523154901802

以下是sns.heatmap函数所支持的参数:

参数

  • data:二维数据集,可以是 ndarray 数组。如果提供 Pandas DataFrame,则会使用其索引/列信息来标记行和列。
  • vmin, vmax:浮点数,可选。用于锚定颜色映射的值,否则会根据数据和其他关键字参数进行推断。
  • cmap:Matplotlib 的颜色映射名称或对象,或颜色列表,可选。数据值到颜色空间的映射。
  • center:浮点数,可选。用于绘制差异性数据时设置颜色映射的中心值。如果使用此参数,且未指定 cmap,则会更改默认的颜色映射。
  • robust:布尔值,可选。如果为 True,且未指定 vmin 或 vmax,则将使用健壮的分位数计算颜色映射范围,而不是使用极端值。
  • annot:布尔值或二维数据集,可选。如果为 True,则在每个单元格中写入数据值。如果是与数据形状相同的类数组,则使用其来标注热图,而不是使用数据。注意,DataFrame 将根据位置匹配,而不是索引。
  • fmt:字符串,可选。添加注释时要使用的字符串格式化代码。
  • annot_kws:字典,可选。当 annot 为 True 时,用于 matplotlib.axes.Axes.text() 的关键字参数。
  • linewidths:浮点数,可选。单元格之间分隔线的宽度。
  • linecolor:颜色,可选。单元格之间分隔线的颜色。
  • cbar:布尔值,可选。是否绘制颜色条。
  • cbar_kws:字典,可选。用于 matplotlib.figure.Figure.colorbar() 的关键字参数。
  • cbar_ax:Matplotlib Axes 对象,可选。颜色条的绘制位置,否则将从主要 Axes 中获取空间。
  • square:布尔值,可选。如果为 True,则将 Axes 的纵横比设置为“equal”,使每个单元格都成为正方形。
  • xticklabels, yticklabels:字符串“auto”、布尔值、类列表或整数,可选。如果为 True,则绘制数据框的列名。如果为 False,则不绘制列名。如果是类似列表的对象,则将其作为 xticklabels 绘制替代标签。如果是整数,则使用列名,但仅每 n 个标签绘制一次。如果为“auto”,则尝试密集绘制非重叠标签。
  • mask:布尔数组或 DataFrame,可选。如果传递了 mask,则数据在 mask 为 True 的单元格中将不显示。具有缺失值的单元格会自动被屏蔽。
  • ax:Matplotlib Axes 对象,可选。要绘制图表的 Axes,否则使用当前正在使用的 Axes。