拉格朗日插值
简介
对一个次数不超过 n 的多项式,拉格朗日插值法可以在 O(n2) 的时间内,利用多项式的 n+1 个取值,计算出多项式的各项系数。
多项式的系数表示法
众所周知一个 k 次多项式可以被它的 k+1 个系数唯一地确定,即:
f(x)=i=0∑kaixi
这种确定多项式的方法叫做多项式的**「系数表示法」**。
多项式的点值表示法
对于多项式 f(x),如果 f(x0)=y0,不引起歧义的话,我们就可以说多项式 f 在 x0 处的点值是 y0。
还是众所周知地,一个 k 次多项式同样可以被它的 k+1 个点值所唯一地确定。
其中的道理不难讲:设这个多项式为
f(x)=i=0∑kaixi
那么,如果我们有了这个多项式上的 k+1 个点 (x1,y1),(x2,y2),⋯,(xk+1,yk+1),就相当于有了一个 k+1 次方程组:
⎩⎪⎪⎪⎨⎪⎪⎪⎧a0+a1x1+a2x12+⋯+akx1ka0+a1x2+a2x22+⋯+akx2k⋯a0+a1xk+1+a2xk+12+⋯+akxk+1k==⋯=y1y2⋯yk+1
这是一个由 k+1 条 k+1 元一次方程组成的一个方程组,并且不难证明在这 k+1 条方程中,不存在两个等价的方程。
因此 a0,a1,a2,⋯,ak 这 k+1 个未知数可以被唯一地解出来。也就是说,这 k+1 个点同样唯一地确定了多项式
f(x)=i=0∑kaixi
这种确定多项式的方法叫做**「点值表示法」**。
系数表示法转点值表示法
直接 O(n2) 算。略。
对于一些特殊的点值,我们可以在更低的复杂度内算出这些点值。
点值表示法转系数表示法
首先由之前的推导过程,可以发先转换的过程实际上就是解一个 k+1 元一次方程组。
因此可以利用高斯消元在 O(k3) 的时间内求解方程,确定其系数。
我们再来看 O(k2) 的拉格朗日插值。
首先有一个结论:
f(x)≡f(a)(mod(x−a))
这是因为当 x=a 时,有 f(x)−f(a)=0。
那么由我们初中就学过的因式定理,就有 f(x)−f(a)≡0(mod(x−a)),即 f(x)≡f(a)(mod(x−a))。
分别将 a 取 x1,x2,⋯,xk+1 ,那么 f(a) 就取遍了 y1,y2,⋯,yk+1。
因此就得到一个关于 f(x) 的方程组:
⎩⎪⎪⎪⎨⎪⎪⎪⎧f(x)≡y1(mod (x−x1))f(x)≡y2(mod (x−x2))⋯f(x)≡yk+1(mod (x−xk+1))
发现这个形式很像 中国剩余定理 对不对qwq
回忆一下中国剩余定理的过程:
-
计算所有模数的积 Mul=i=1∏k+1(x−xi)。
-
对每个模数,计算 mi=(x−xi)Mul=i=j∏(x−xj)。
-
计算 mi 在模第 i 个模数 (x−xi) 意义下的逆元:mi−1=i=j∏xi−xj1
-
计算 ci=mimi−1=i=j∏xi−xjx−xj
-
线性同余方程组的解即为:i=1∑k+1yici,即:
f(x)=i=1∑k+1yii=j∏xi−xjx−xj
这就是 f(x) 的表达式了qwq。
事实上,我们取一个点值 (xc,yc),验证一下 f(xc),可以发现:
f(xc)=i=1∑k+1yii=j∏xi−xjxc−xj
而当 i=c 时,后面的那个 i=j∏xi−xjxc−xj 中,总存在 j 使得 j=c,那么 xc−xj=0,乘起来也是 0,那这一整项就都是 0。
这意味着那么大个求和符号中,实际上被计算的只有 i=c 时的值,也就是: ycc=j∏xc−xjxc−xj=yc,因此确实有 f(xc)=yc。
于是 f(x) 符合要求。我们只需要进行简单的二元多项式乘法,就可以计算出 f(x) 的表达式。
不难发现这个表达式的计算复杂度是 O(k2) 的,那么我们就在 O(k2) 的时间内算出了 f(x)。
Coding 注意事项&技巧&实现
这里放上计算 f(k) 的能过板子题的代码qwq。
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<iostream>
#define LL long long
const LL mod=998244353;
using namespace std;
LL n,k;
LL x[2005],y[2005];
LL ksm(LL a,LL b){
LL ans=1ll,y=a;
while(b){
if(b&1)ans*=y,ans%=mod;
y*=y;y%=mod;
b>>=1;
}
return ans%mod;
}//快速幂,用于求逆元
LL inv(LL x){
return ksm(x,mod-2);
}//逆元qwq
LL calc(LL t){
LL ans=0;
for(int i=1;i<=n;i++){
LL p=y[i]%mod,q=1ll;
for(int j=1;j<=n;j++){
if(i==j)continue;
p=p*(t-x[j])%mod;
q=q*(x[i]-x[j])%mod;//直接照着公式算qwq
}
ans+=p*inv(q)%mod;
ans%=mod;
}
return (ans%mod+mod)%mod;
}
int main(void){
cin>>n>>k;
for(int i=1;i<=n;i++)cin>>x[i]>>y[i];
cout<<calc(k)<<endl;
return 0;
}