RUPC 2016 Day2 K Sum of Sequences
こういう計算コストの減らし方は面白い.定数倍が少しきつめ?
- 問題概要
n個の要素からなる数列aとm個の要素からなる数列bが与えられる.
また,以下のようなクエリがq個与えられる.
クエリ
整数cが与えられる,
数列aの部分列]の和と数列bの部分列]の和の差の絶対値がcとなる
の組み合わせを求めよ.
- 制約
- 解法
まず,数列の部分列の和を普通に全て計算すると,計算量はでTLEする.
このため,なんらかの方法でより高速に求める必要がある.
ここで,数列の各要素が高々5であることに注目する.
数列の長さの制約から部分列の和の最大値はとなる.
また,累積和を事前に計算しておくと2つの累積和の差で部分列の和を求めることができる.
これらと畳み込みを利用して差をまとめて計算することを考える.
以下ではd+(-e)という計算を行うと考える.
畳み込み
今回の場合は,dの負の部分とeの正の部分は必要ないので,次の計算を考えると良い.
畳み込み
こうすると,最初の例と比べて数列の長さを半分にできる.
#include <bits/stdc++.h> #define _overload(_1,_2,_3,name,...) name #define _rep(i,n) _range(i,0,n) #define _range(i,a,b) for(int i=int(a);i<int(b);++i) #define rep(...) _overload(__VA_ARGS__,_range,_rep,)(__VA_ARGS__) #define _rrep(i,n) _rrange(i,n,0) #define _rrange(i,a,b) for(int i=int(a)-1;i>=int(b);--i) #define rrep(...) _overload(__VA_ARGS__,_rrange,_rrep,)(__VA_ARGS__) #define _all(arg) begin(arg),end(arg) #define uniq(arg) sort(_all(arg)),(arg).erase(unique(_all(arg)),end(arg)) #define getidx(ary,key) lower_bound(_all(ary),key)-begin(ary) #define clr(a,b) memset((a),(b),sizeof(a)) #define bit(n) (1LL<<(n)) template<class T>bool chmax(T &a, const T &b) { return (a<b)?(a=b,1):0;} template<class T>bool chmin(T &a, const T &b) { return (b<a)?(a=b,1):0;} using namespace std; using ll=long long; inline ll extgcd(ll a,ll b,ll& x,ll& y){x=1,y=0;ll g=a;if(b!=0) g=extgcd(b,a%b,y,x),y-=a/b*x;return g;} inline ll ADD(const ll &a, const ll &b,const ll &mod) { return a+b<mod?a+b:a+b-mod;} inline ll SUB(const ll &a, const ll &b,const ll &mod) { return a-b>=0?a-b:a-b+mod;} inline ll MUL(const ll &a, const ll &b,const ll &mod) { return (1LL*a*b)%mod;} inline ll INV(ll a,ll mod){ll x,y;extgcd(a,mod,x,y);return (x%mod+mod)%mod;} inline ll DIV(const ll &a, const ll &b,const ll &mod) {return MUL(a,INV(b,mod),mod);} inline ll POW(ll a,ll n,ll mod){ll b=1LL;for(a%=mod;n;a=MUL(a,a,mod),n>>=1)if(n&1) b=MUL(b,a,mod); return b;} const ll mod[3]={998244353,897581057,645922817}; void ntt(vector<ll> &a,bool inv,ll mod){ const int n=a.size(); ll base=POW(3LL,(mod-1)/n,mod); if(inv) base=INV(base,mod); int rj=0; rep(j,1,n-1){ for(int k=n>>1;k>(rj^=k);k>>=1); if(j<rj) swap(a[j],a[rj]); } for(int m=1;m<n;m<<=1){ const ll wn=POW(base,n/2/m,mod); ll w=1LL; rep(p,m){ for(int s=p;s<n;s+=2*m){ ll u=a[s],v=MUL(a[s+m],w,mod); a[s]=ADD(u,v,mod),a[s+m]=SUB(u,v,mod); } w=MUL(w,wn,mod); } } const ll n_inv=INV(n,mod); if(inv) rep(i,n) a[i]=MUL(a[i],n_inv,mod); } vector<ll> convolution(vector<ll> a,vector<ll> b,ll mod){ int ntt_size=1; while (ntt_size < a.size()+b.size()) ntt_size <<= 1; a.resize(ntt_size),ntt(a,0,mod); b.resize(ntt_size),ntt(b,0,mod); vector<ll> c(ntt_size,0LL); rep(i,ntt_size) c[i]=MUL(a[i],b[i],mod); ntt(c,1,mod); return c; } int a[40010],b[40010]; const int limit=200000; int main(void){ int n,m,q; scanf("%d%d%d",&n,&m,&q); rep(i,n) scanf("%d",a+i); rep(i,m) scanf("%d",b+i); vector<ll> ares,bres,ans[3]; vector<ll> plus(limit+1),minus(limit+1); { fill(_all(plus),0),fill(_all(minus),0); int sum=0; plus[sum]++,minus[limit-sum]++; rep(i,n) sum+=a[i],plus[sum]++,minus[limit-sum]++; ares=convolution(plus,minus,mod[0]); } { fill(_all(plus),0),fill(_all(minus),0); int sum=0; plus[sum]++,minus[limit-sum]++; rep(i,m) sum+=b[i],plus[sum]++,minus[limit-sum]++; bres=convolution(plus,minus,mod[0]); } { fill(_all(plus),0),fill(_all(minus),0); rep(i,1,limit+1) plus[i]=ares[i+limit],minus[limit-i]=bres[i+limit]; rep(i,3) ans[i]=convolution(plus,minus,mod[i]); } auto chinese_remainder=[&](int idx){ const int n=3; const ll b[n]={ans[0][idx],ans[1][idx],ans[2][idx]}; const ll m[n]={mod[0],mod[1],mod[2]}; vector<ll> constant(n,0LL),coef(n,1LL),v(n,0LL); rep(i,n){ v[i]=SUB(b[i],constant[i],m[i]); v[i]=DIV(v[i],coef[i],m[i]); rep(j,i+1,n){ constant[j]=ADD(constant[j],MUL(v[i],coef[j],m[j]),m[j]); coef[j]=MUL(coef[j],m[i],m[j]); } } ll ret=0LL; rrep(i,n){ ret*=m[i]; ret+=v[i]; } return ret; }; rep(loop,q){ int c; scanf("%d",&c); ll res=0; if(c==0) res=chinese_remainder(limit); else res=chinese_remainder(c+limit)+chinese_remainder(-c+limit); cout << res << endl; } return 0; }