如何对SciPy曲线拟合施加约束?

问题描述:

我正在尝试使用自定义概率密度函数拟合一些实验值的分布.显然,结果函数的积分应始终等于1,但是简单scipy.optimize.curve_fit(function,dataBincenters,dataCounts)的结果永远不会满足此条件. 解决此问题的最佳方法是什么?

I'm trying to fit the distribution of some experimental values with a custom probability density function. Obviously, the integral of the resulting function should always be equal to 1, but the results of simple scipy.optimize.curve_fit(function, dataBincenters, dataCounts) never satisfy this condition. What is the best way to solve this problem?

您可以定义自己的残差函数,包括一个罚分参数,如下面的代码所示,该函数事先知道沿间隔的积分必须为2..如果您在没有处罚的情况下进行测试,您将看到您得到的是常规的curve_fit:

You can define your own residuals function, including a penalization parameter, like detailed in the code below, where it is known beforehand that the integral along the interval must be 2.. If you test without the penalization you will see that what your are getting is the conventional curve_fit:

import matplotlib.pyplot as plt
import scipy
from scipy.optimize import curve_fit, minimize, leastsq
from scipy.integrate import quad
from scipy import pi, sin
x = scipy.linspace(0, pi, 100)
y = scipy.sin(x) + (0. + scipy.rand(len(x))*0.4)
def func1(x, a0, a1, a2, a3):
    return a0 + a1*x + a2*x**2 + a3*x**3

# here you include the penalization factor
def residuals(p,x,y):
    integral = quad( func1, 0, pi, args=(p[0],p[1],p[2],p[3]))[0]
    penalization = abs(2.-integral)*10000
    return y - func1(x, p[0],p[1],p[2],p[3]) - penalization

popt1, pcov1 = curve_fit( func1, x, y )
popt2, pcov2 = leastsq(func=residuals, x0=(1.,1.,1.,1.), args=(x,y))
y_fit1 = func1(x, *popt1)
y_fit2 = func1(x, *popt2)
plt.scatter(x,y, marker='.')
plt.plot(x,y_fit1, color='g', label='curve_fit')
plt.plot(x,y_fit2, color='y', label='constrained')
plt.legend(); plt.xlim(-0.1,3.5); plt.ylim(0,1.4)
print 'Exact   integral:',quad(sin ,0,pi)[0]
print 'Approx integral1:',quad(func1,0,pi,args=(popt1[0],popt1[1],
                                                popt1[2],popt1[3]))[0]
print 'Approx integral2:',quad(func1,0,pi,args=(popt2[0],popt2[1],
                                                popt2[2],popt2[3]))[0]
plt.show()

#Exact   integral: 2.0
#Approx integral1: 2.60068579748
#Approx integral2: 2.00001911981

其他相关问题: