超前滞后相关用于分析两个序列之间的时滞相关性,可以在一定程度上判断两者的因果关系。例如一个序列与另一个序列存在显著的超前相关,那么说明前一个序列对应的要素可能会在显著相关的周期后影响一个序列对应的要素。
这里我将计算讲个序列的超前滞后相关封装成了现成函数方便大家调用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| import numpy as np from scipy.stats import pearsonr
def lead_lag_correlation(s1, s2, max_shift): correlations = np.empty(2 * max_shift + 1) p_values = np.empty(2 * max_shift + 1)
for shift in range(-max_shift, max_shift + 1): if shift < 0: corr, p_value = pearsonr(s1[:shift], s2[-shift:]) elif shift > 0: corr, p_value = pearsonr(s1[shift:], s2[:-shift]) else: corr, p_value = pearsonr(s1, s2) correlations[shift + max_shift] = corr p_values[shift + max_shift] = p_value return correlations, p_values
|
只需要输入S1,S2两个等长序列,以及需要计算的最大超前步长即可。下边是一个示例:
1 2 3 4 5 6 7 8 9 10
| s1 = np.random.rand(40) s2 = np.random.rand(40)
max_shift = 10 correlations, p_values = lead_lag_correlation(s1, s2, max_shift) print("超前/滞后相关系数数组:") print(correlations) print("\n对应的 P 值数组:") print(p_values)
|
输出结果为:
1 2 3 4 5 6 7 8 9 10 11
| 超前/滞后相关系数数组: [-0.0667701 -0.20196212 -0.19440555 -0.03344729 0.09401154 0.03682201 -0.08435947 -0.14058007 0.29389157 0.35276157 -0.08771688 -0.43661858 -0.29126164 0.00451985 -0.04504127 -0.02366273 0.10640454 0.0751925 0.1293342 0.26239277 -0.1882403 ]
对应的 P 值数组: [0.72591345 0.27591415 0.28634049 0.8533993 0.59691207 0.83366645 0.62472335 0.40659622 0.07330382 0.02761788 0.59042669 0.00545349 0.07603491 0.97881895 0.79421402 0.89267097 0.5492149 0.67749602 0.48050658 0.15385903 0.31916923]
|
可以对其进行一些简单的可视化:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| fig, axs = plt.subplots(2, 1, figsize=(8, 8))
axs[0].plot(sequence1, label='S1') axs[0].plot(sequence2, label='S2') axs[0].set_title('Original Sequences') axs[0].legend()
shifts = np.arange(-max_shift, max_shift + 1) axs[1].plot(shifts, correlations, label='Correlation') axs[1].set_title('Lagged Correlation')
threshold = 0.5 mask = p_values < threshold
axs[1].scatter(shifts[mask], correlations[mask], color='black', marker='o')
axs[1].set_xlabel('Shift') axs[1].set_ylabel('Correlation') axs[1].set_xticks(shifts) axs[1].set_xticklabels(shifts) plt.tight_layout() plt.show()
|
图形输出为: