Skip to content
hotman edited this page Oct 25, 2020 · 1 revision

FFT

template<typename T>
class FFT{
    using u64 = std::uint_fast64_t;
    using real=long double;
	using C=complex<real>;
	int n=1;
    public:
	vector<T> mul(const vector<T>& a,const vector<T>& b){
		const int size=a.size()+b.size()-1;
        int h=0;
		while(n<size)n<<=1,++h;
		auto c=itoc(b,a,n);
		auto ic=fft(c,h,false);
		vector<C> ires(n);
		for(int i=0;i<n;++i){
            int j=i==0?0:n-i;
            ires[i]=(ic[i]+conj(ic[j]))*(ic[i]-conj(ic[j]))*C(0,-.25);
        }
		auto res=ctoi(fft(ires,h,true),size);
		return res;
	}
	private:
	vector<C> itoc(const vector<T>& s,const vector<T>& t,const int& n){
		vector<C> res(n);
		for(int i=0;i<n;++i)res[i]=C(i<(int)s.size()?cast(s[i]):0.,i<(int)t.size()?cast(t[i]):0);
		return res;
	}
	vector<T> ctoi(const vector<C>& v,const int& size){
		vector<T> res(size);
		for(int i=0;i<min<int>(size,v.size());i++)res[i]=recast(v[i].real());
		return res;
	}
	vector<C> fft(vector<C> v,const int& h,const bool& inv){
		int n=v.size(),mask=n-1;
		assert((n&(n-1))==0);
		vector<C> tmp(n);
        C table[n];
        real theta =2*M_PI*(inv?-1:1)/n;
        for(int i=0;i<n;++i)table[i]=C(cos(i*theta),sin(i*theta));
        for(int j=n>>1,t=h-1;j>=1;j>>=1,--t){
            for(int k=0;k<n;++k){
                int s=k&(j-1); // T は 下 t 桁
                int i=k>>t; // i は 上 h - t 桁
                tmp[k]=v[((i<<(t+1))|s)&mask]+table[i*j]*v[((i<<(t+1))|j|s)&mask];
                // ζ_(2^(h - t))^i
            }
            swap(v,tmp);
        }
        if(inv)for(int i=0;i<n;++i)v[i]/=n;
        return v;
	}
    inline real cast(const T& t){
        return t.a;
    }
    inline T recast(const real& t){
        return round(t-floor(t/MOD)*MOD);
    }
};

Clone this wiki locally