超前滞后相关系数

超前滞后相关用于分析两个序列之间的时滞相关性,可以在一定程度上判断两者的因果关系。例如一个序列与另一个序列存在显著的超前相关,那么说明前一个序列对应的要素可能会在显著相关的周期后影响一个序列对应的要素。

这里我将计算讲个序列的超前滞后相关封装成了现成函数方便大家调用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import numpy as np
from scipy.stats import pearsonr

def lead_lag_correlation(s1, s2, max_shift):
correlations = np.empty(2 * max_shift + 1)
p_values = np.empty(2 * max_shift + 1)

for shift in range(-max_shift, max_shift + 1):
if shift < 0:
corr, p_value = pearsonr(s1[:shift], s2[-shift:])
elif shift > 0:
corr, p_value = pearsonr(s1[shift:], s2[:-shift])
else:
corr, p_value = pearsonr(s1, s2)
correlations[shift + max_shift] = corr
p_values[shift + max_shift] = p_value
return correlations, p_values

只需要输入S1,S2两个等长序列,以及需要计算的最大超前步长即可。下边是一个示例:

1
2
3
4
5
6
7
8
9
10
# 示例调用
s1 = np.random.rand(40)
s2 = np.random.rand(40)

max_shift = 10
correlations, p_values = lead_lag_correlation(s1, s2, max_shift)
print("超前/滞后相关系数数组:")
print(correlations)
print("\n对应的 P 值数组:")
print(p_values)

输出结果为:

1
2
3
4
5
6
7
8
9
10
11
超前/滞后相关系数数组:
[-0.0667701 -0.20196212 -0.19440555 -0.03344729 0.09401154 0.03682201
-0.08435947 -0.14058007 0.29389157 0.35276157 -0.08771688 -0.43661858
-0.29126164 0.00451985 -0.04504127 -0.02366273 0.10640454 0.0751925
0.1293342 0.26239277 -0.1882403 ]

对应的 P 值数组:
[0.72591345 0.27591415 0.28634049 0.8533993 0.59691207 0.83366645
0.62472335 0.40659622 0.07330382 0.02761788 0.59042669 0.00545349
0.07603491 0.97881895 0.79421402 0.89267097 0.5492149 0.67749602
0.48050658 0.15385903 0.31916923]

可以对其进行一些简单的可视化:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
fig, axs = plt.subplots(2, 1, figsize=(8, 8))
# 第一个子图:绘制两条序列的原始值
axs[0].plot(sequence1, label='S1')
axs[0].plot(sequence2, label='S2')
axs[0].set_title('Original Sequences')
axs[0].legend()
# 第二个子图:绘制超前/滞后相关系数,并加粗 P 值小于 0.1 的部分
shifts = np.arange(-max_shift, max_shift + 1)
axs[1].plot(shifts, correlations, label='Correlation')
axs[1].set_title('Lagged Correlation')
# 根据 P 值小于 0.1 的部分加粗或是打点
threshold = 0.5
mask = p_values < threshold
#打点表示 P 值小于阈值的部分
axs[1].scatter(shifts[mask], correlations[mask], color='black', marker='o') # 使用 scatter 绘制黑色实心圆
#修改轴刻度
axs[1].set_xlabel('Shift')
axs[1].set_ylabel('Correlation')
axs[1].set_xticks(shifts)
axs[1].set_xticklabels(shifts)
plt.tight_layout()
plt.show()

图形输出为:

image-20200528144651404