多项式板子 | miaom's Blog
miaom's Blog

lych真的0

miaom's avatar miaom

多项式板子

最近突然想复习多项式了,然后就打了个板子。

多项式乘法

对于两个$n-1$次的多项式,考虑暴力乘法,复杂度为$O(n^2)$ 。

如果使用多项式的点值表示,那么直接乘的复杂度为$O(n)$。问题在于如何把一个多项式转化成为点值的形式然后转化回来。

具体的讲,令$w^n=1$且$w^0,w^1,\dots,w^{n-1}$互不相同,我们可以在$O(n\log(n))$的时间内,求出多项式在$w^0,w^1,\dots,w^{n-1}$的值或根据这些值求出多项式。

考虑计算多项式$f(x)$的点值形式,将它的每项按照指数的奇偶分成$f_0(x^2)+xf_1(x^2)$。递归计算这两个多项式。对于某个$w^k$,已经求出了$f_0$与$f_1$在$w^{2k}$的值,可以在$O(n)$时间内得到$f(x)$的点值形式。由主定理可知,总复杂度为$O(n\log(n))$。

从点值求多项式时类似,但是需要除以$n$。

也可以使用非递归的方式自底向上实现这个过程。

多项式乘法本质上相当于求$f(x)=\sum_{i=0}^xg(i)h(x-i)$,可以用来优化一些dp。

在模$P=998244353$意义下,由于有单位元$\omega$且$P-1$是一个较大的$2$的幂的倍数,容易进行上面的操作(NTT)。

多项式求逆

对于给定的$n-1$次多项式$f(x)$,求一个$n-1$次多项式$g(x)$满足$f(x)\times g(x)=1 (mod x^n)$。

考虑递归构造。当$n=1$时容易得出$g(x)_0={f(x)_0}^{-1}$。

假设$f(x)\times g(x)=1 (mod x^n)$,那么$\left(f(x)\times g(x)-1\right)^2=0 (mod x^{2n})$。

化简得$f(x)\cdot g(x)\cdot(2-f(x)\cdot g(x))=1 (mod x^{2n})$,所以$f^{-1}(x)=g(x)\cdot(2-f(x)\cdot g(x))$。

复杂度$O(n\log(n))$。

多项式除法

留坑

多项式ln

对于给定的$n-1$次多项式$f(x)$,求一个$n-1$次多项式$g(x)$满足$g(x)=ln(f(x)) (mod x^n)$。

由于有$ln(f(x))’=\frac{f’(x)}{f(x)}$,所以可以在$O(n\log(n))$时间内求出$g(x)$。

多项式exp

留坑

多项式多点求值

留坑

多项式多点插值

留坑

奇怪的板子(未完成)

#include<bits/stdc++.h>
using namespace std;
namespace poly
{
    typedef long long ll;
    const int P=998244353;
    const int W=3;
    const int N=400005;
    int pw(int x,int y)
    {
        int ans=1;
        for (;y;x=(ll)x*x%P,y>>=1)
            if (y&1) ans=(ll)ans*x%P;
        return ans;
    }
    void dft(int a[],int n,int p)
    {
        static int w[N],r[N];
        int W0=pw(W,(P-1)/n);
        if (p) W0=pw(W0,P-2);
        w[0]=1;
        for (int i=1;i<n;i++)
            w[i]=(ll)w[i-1]*W0%P;
        //r[0]=0;
        for (int i=1;i<n;i++)
            r[i]=r[i>>1]>>1|((n>>(i&1))&(n-1));
        for (int i=1;i<n;i++)
            if (r[i]<i) swap(a[r[i]],a[i]);

        for (int i=1;i<n;i<<=1)
        for (int j=0,t=n/(i<<1);j<n;j+=i<<1)
        for (int k=0,l=0;k<i;k++,l+=t)
        {
            int x=(ll)a[j+k+i]*w[l]%P,y=a[j+k];
            a[j+k]=(y+x)%P;
            a[j+k+i]=(y-x+P)%P;
        }

        if (p)
            for (int i=0,t=pw(n,P-2);i<n;i++)
                a[i]=(ll)a[i]*t%P;
    }
    void mul(int a[],int b[],int c[],int n)
    {
        int tn=1,tl=0;
        while(tn<n) tn<<=1,tl++;
        tn<<=1;tl++;
        static int A[N],B[N],C[N];
        for (int i=0;i<n;i++) A[i]=a[i],B[i]=b[i];
        for (int i=n;i<tn;i++) A[i]=B[i]=0;
        dft(A,tn,0);dft(B,tn,0);
        for (int i=0;i<tn;i++) C[i]=(ll)A[i]*B[i]%P;
        dft(C,tn,1);
        for (int i=0;i<n*2-1;i++) c[i]=C[i];
    }
    void inv(int f[],int g[],int n)
    {
        if (n==1)
            g[0]=pw(f[0],P-2);
        else
        {
            static int h[N];
            inv(f,g,(n+1)/2);
            for (int i=(n+1)/2;i<n;i++)
                g[i]=0;
            mul(f,g,h,n);
            for (int i=0;i<n;i++)
                h[i]=(P-h[i])%P;
            h[0]=(h[0]+2)%P;
            mul(h,g,g,n);
        }
    }
    void ln(int f[],int g[],int n)
    {
        //f[0]==1
        static int h[N];
        for (int i=1;i<n;i++)
            h[i-1]=(ll)f[i]*i%P;
        h[n-1]=0;
        inv(f,g,n);
        mul(h,g,g,n);
        for (int i=n-1;i>=1;i--)
            g[i]=(ll)g[i-1]*pw(i,P-2)%P;
        g[0]=0;
    }
    void exp(int f[],int g[],int n)
    {
        //f[0]==0
        if (n==1)
            g[0]=1;
        else
        {
            static int h[N];
            exp(f,g,(n+1)/2);
            for (int i=(n+1)/2;i<n;i++)
                g[i]=0;
            ln(g,h,n);
            for (int i=0;i<n;i++)
                h[i]=(f[i]+P-h[i])%P;
            h[0]=(h[0]+1)%P;
            mul(g,h,g,n);
        }
    }
}
//using namespace poly;
bool diff(int a[],int b[],int n)
{
    //for (int i=0;i<=n;i++) printf("%d ",a[i]);puts("");
    //for (int i=0;i<=n;i++) printf("%d ",b[i]);puts("");
    for (int i=0;i<=n;i++)
        if (a[i]!=b[i])
        {
            printf("WA at %d !\n",i);
            //system("pause");
            return 1;
        }
    printf("AC\n");
    return 0;
}
#define N 400005
#define ll long long
#define P 998244353
int n,a[N],b[N],c[N],d[N];
int Rd(){return rand()%P;}
int main()
{
    srand(time(0)+(unsigned long long)(new char));
    for (int hz=0;hz<10;hz++)
    {
        //n=rand()+1
        n=100000;
        /*
        for (int i=0;i<n;i++)
            a[i]=b[i]=Rd();
        int tn=1;
        while(tn<n)tn<<=1;
        for (int i=n;i<tn;i++)a[i]=0;
        poly::dft(a,tn,0);
        poly::dft(a,tn,1);
        diff(a,b,n-1);

        for (int i=0;i<=n;i++)
            a[i]=Rd(),b[i]=Rd();
        poly::mul(a,b,c,n+1);
        for (int i=0;i<=n*2;i++) d[i]=0;
        for (int i=0;i<=n;i++)
        for (int j=0;j<=n;j++)
            d[i+j]=(d[i+j]+(ll)a[i]*b[j])%P;
        diff(c,d,n*2);


        for (int i=0;i<n;i++)
            a[i]=Rd();
        poly::inv(a,b,n);
        poly::mul(a,b,c,n);
        d[0]=1;
        for (int i=1;i<n;i++) d[i]=0;
        diff(c,d,n-1);

        for (int i=0;i<n;i++)
            a[i]=b[i]=Rd();
        a[0]=b[0]=1;
        poly::ln(a,c,n);
        poly::exp(c,a,n);
        diff(a,b,n-1);
        */
    }
}