极大似然拟合概率分布
对数据集应用概率模型是解释数据集的一个好方法,但是,如何找到一个合适的模型本身就是一项工作。在选定模型之后,还要将其与数据进行比较或者检验。在这个例子当中,我们针对statsmodels自带的数据集“心脏移植后存活时间(1967-1974)”,采用极大似然估计的方法拟合概率分布。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 import numpy as npimport pandas as pdimport scipy.stats as stimport statsmodels.datasets as datasetsfrom plotnine.ggplot import *from plotnine.qplot import *from plotnine.geoms import *from plotnine.coords import *from plotnine.labels import *from plotnine.facets import *from plotnine.scales import *from plotnine.themes import *import plotly.graph_objs as goimport plotly.offline as py_offlineimport plotly.plotly as pyfrom plotly import toolspy_offline.init_notebook_mode() %matplotlib inline
1 data = datasets.heart.load_pandas().data
1 2 data = data[data.censors==1 ] qplot(x=data.survival.index, y=sorted (data.survival)[::-1 ], xlab='Patient' , ylab='Survival time' )
1 qplot(x=data.survival,bins=12 ,geom='histogram' , xlab='Patient' , ylab='Survival time' )
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 trace1 = go.Scatter( x=data.index, y=sorted (data.survival)[::-1 ], mode='markers' , name='散点图' ) trace2 = go.Histogram(x=data.survival, name='直方图' ) fig = tools.make_subplots(rows=1 , cols=2 ) fig.append_trace(trace1, 1 , 1 ) fig.append_trace(trace2, 1 , 2 ) fig['layout' ]['xaxis1' ].update(title='病人' ) fig['layout' ]['xaxis2' ].update(title='存活时间(天)' ) fig['layout' ]['yaxis1' ].update(title='存活时间(天)' ) fig['layout' ]['yaxis2' ].update(title='病人数目' ) fig['layout' ].update(height=600 , width=1000 , title='病人情况' ) py_offline.iplot(fig)
This is the format of your plot grid:
[ (1,1) x1,y1 ] [ (1,2) x2,y2 ]
通过原始数据的散点图和直方图,我们可以看到,绝大部分的存活时间在心脏移植后不超过3年,当然这是上个世纪六七十年代的数据,时至今日,今天的存活率和存活时间已经打大提高。
从直方图上可以看出,存活时间的频数随着天数的增加快速下降,因此考虑采用指数分布来拟合数据。
一个指数分布的概率密度函数是:
f ( x ; λ ) = { λ e − λ x , x ≥ 0 , 0 , x < 0. f(x; \lambda)=
\begin{cases}
\lambda e^{-\lambda x}, &x\geq0,\\
0, &x<0.
\end{cases}
f ( x ; λ ) = { λ e − λ x , 0 , x ≥ 0 , x < 0 .
其中λ > 0 \lambda>0 λ > 0 是分布的一个参数,常被称为率参数(rate parameter),λ \lambda λ 的倒数被称为scale参数。
假设存活天数的数目为s s s ,s s s 是服从参数为λ \lambda λ 的随机变量,那么根据极大似然法容易得到,λ \lambda λ 的极大似然估计为1 s ˉ \frac{1}{\bar{s}} s ˉ 1 ,即s s s 的样本均值。
1 2 3 4 5 6 survival_mean = data.survival.mean() rate = 1. / survival_mean smax = data.survival.max () days = np.linspace(0. , smax, 1000 )
这样,我们就得到了拟合分布。
1 dist_exp = st.expon.pdf(days, scale=survival_mean)
然后将拟合分布的概率密度函数与原始数据的直方图进行比较,需要注意的是,一个是概率,一个是频数,所以需要转换为同样的标准。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 nbins = 30 trace0 = go.Histogram( x=data.survival, nbinsx=nbins, name = 'Emrpical data' , ) trace1 = go.Scatter( x=days, y=dist_exp*len (data.index)*smax/nbins, name='Fitted PDF' , mode='lines' , ) data_plot=[trace0, trace1] layout = dict (xaxis = dict (title = 'Survival time (days)' ), yaxis = dict (title = 'Number of patients' ), ) fig = go.Figure(data=data_plot, layout=layout) py_offline.iplot(fig)
从上图可以看出,直接用指数分布拟合的结果并不是很理想,这可能是由于极大似然估计本身的缺陷(在假设情况下只用到了样本的均值),此外,对于指数分布MLE能够得到解析解,但是有些分布可能得不到解析解,就需要通过EM算法等近似求解数值解。
那么,我们通过scipy数值求解指数分布的参数:
1 2 3 dist = st.expon args = dist.fit(data.survival) args
(1.0, 222.2888888888889)
正如我们在极值分析里所做的一样,我们可以用K-S检验衡量分布对数据的拟合优度。
1 st.kstest(data.survival, dist.cdf, args)
KstestResult(statistic=0.36199693810792966, pvalue=8.647045785181717e-06)
如此小的p值,意味着拒绝原假设(样本的分布与拟合分布相同),即两者存在这显著的差异,说明不应该指数分布拟合。
改用另外一种分布Brinbaum-Sanders distribution,这种分布通常用来拟合疲劳寿命或者失效时间。
1 2 3 dist = st.fatiguelife args = dist.fit(data.survival) st.kstest(data.survival, dist.cdf, args)
KstestResult(statistic=0.1877344610194689, pvalue=0.07321149700086327)
由上可知,p值为0.073,在5%的置信水平上无法拒绝原假设,说明BS分布比指数分布更加合适。另一方面需要注意的是,在scipy中,并不是直接采用原始数据进行拟合的,而是利用loc和scale两个参数进行规范化后,与标准分布进行拟合,这一点在实际应用中需要注意。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 dist_bs = st.fatiguelife.pdf(days, *args) trace0 = go.Histogram( x=data.survival, nbinsx=nbins, name = 'Emrpical data' , ) trace1 = go.Scatter( x=days, y=dist_exp*len (data.index)*smax/nbins, name='Fitted by EXP' , mode='lines' , ) trace2 = go.Scatter( x=days, y=dist_bs*len (data.index)*smax/nbins, name='Fitted by BS' , mode='lines' , ) data_plot=[trace0, trace1, trace2] layout = dict (xaxis = dict (title = 'Survival time (days)' ), yaxis = dict (title = 'Number of patients' ), ) fig = go.Figure(data=data_plot, layout=layout) py_offline.iplot(fig)