/*
 * Author            : Dipu Kumar Mohanto
 *                     CSE, Batch - 6
 *                     BRUR.
 * Problem Statement : Counting Triplets
 * Source            : toph.co
 * Category          : Data Structure
 * Algorithm         : FFT with Block Decomposition
 * Verdict           : Accepted
 */
#include <bits/stdc++.h>
using namespace std;
// Floating-point type used by the FFT. Note: despite the name, this is
// double, not long double.
typedef double ld;
typedef long long ll;
const ld PI = 2*acos(0.0);
const int MAXN = 1e5 + 5;   // capacity of the 1-indexed input array arr[]
const int N = 1<<17;        // FFT length (a power of two) and size of the count tables

// Minimal complex number a + b*i used by fft().
// The arithmetic operators are const member functions so they also accept
// const left-hand operands; none of them modifies *this.
struct base
{
    ld a, b;   // real and imaginary parts
    base() : a(0.0), b(0.0) {}
    base(ld aa) : a(aa), b(0.0) {}
    base(ld aa, ld bb) : a(aa), b(bb) {}
    base operator + (const base &c) const
    {
        return base(a + c.a, b + c.b);
    }
    base operator - (const base &c) const
    {
        return base(a - c.a, b - c.b);
    }
    // (a + bi)(c.a + c.b i) = (a*c.a - b*c.b) + (a*c.b + b*c.a) i
    base operator * (const base &c) const
    {
        return base(a * c.a - b * c.b, a * c.b + b * c.a);
    }
};
// Twiddle tables. calcw() fills w_pre[k] = e^{2*pi*i*k/N} for k = 0..N, so
// w_pre needs N + 1 slots; w[] is per-stage scratch written by fft() and is
// sized the same way.
base w_pre[N + 1];
base w[N + 1];
// Bit-reversal permutation table: filled by calcrev(), read by fft().
ll rev[N];
void calcw()
{
for (int i = 0; i <= N; i++)
{
w_pre[i] = base(cos(2*PI/N*i), sin(2*PI/N*i));
}
}
// Fill rev[0..n) with the bit-reversal permutation of log2(n)-bit indices,
// as consumed by fft(). Assumes n is a power of two with n <= N.
// Every entry in [0, n) is written explicitly, so the table stays correct
// even if calcrev was previously called with a different n.
void calcrev(int n)
{
    if (n <= 1)
    {
        // Nothing to permute; also avoids __builtin_clz(0), which is undefined.
        if (n == 1)
            rev[0] = 0;
        return;
    }
    const int bits = 31 - __builtin_clz(n);   // log2(n); n > 1 here, so bits >= 1
    rev[0] = 0;
    for (int i = 1; i < n; i++)
    {
        // Reverse of i = (reverse of i/2) shifted right once, with i's low bit
        // moved to the top position.
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    }
}
// In-place iterative radix-2 FFT over p[0..n).
//   dir == 0 : forward transform using twiddles e^{+2*pi*i*k/len}
//   dir != 0 : inverse transform (conjugate twiddles), result scaled by 1/n
// Assumes n is a power of two dividing N, that calcw() has run, and that
// calcrev(n) has filled rev[] for this same n.
// The global scratch array w[] is overwritten at every stage.
void fft(base *p, int n, int dir)
{
    // Bit-reversal permutation; indices 0 and n-1 are their own reversal.
    for (int idx = 1; idx + 1 < n; ++idx)
        if (idx < rev[idx])
            swap(p[idx], p[rev[idx]]);

    for (int half = 1; half < n; half <<= 1)
    {
        const int len = half << 1;
        const int stride = N / len;
        // The inverse reads from the top of w_pre, i.e. the complex conjugate.
        for (int k = 0; k < half; ++k)
            w[k] = dir ? w_pre[N - stride * k] : w_pre[stride * k];

        for (int start = 0; start < n; start += len)
        {
            for (int k = 0; k < half; ++k)
            {
                base &lo = p[start + k];
                base &hi = p[start + k + half];
                base t = hi * w[k];
                hi = lo - t;
                lo = lo + t;
            }
        }
    }

    if (dir)
    {
        for (int idx = 0; idx < n; ++idx)
        {
            p[idx].a /= n;
            p[idx].b /= n;
        }
    }
}
// Number of consecutive array positions processed together in main().
#define BLOCK_SIZE 700
// Set every element of the fixed-size array `a` to `b`.
#define fil(a, b) fill(begin(a), end(a), b)

// FFT work buffers (q is referenced only by commented-out code in main).
base p[N];
base q[N];
base r[N];
// Rounded convolution output.
ll res[N];
// Input values, 1-indexed.
ll arr[MAXN];
// Value-frequency tables: bef[] for earlier blocks, inside[] for the current block.
ll bef[N];
ll inside[N];
int main()
{
//freopen("in.txt", "r", stdin);
calcw();
calcrev(N);
int tc;
scanf("%d", &tc);
for (int tcase = 1; tcase <= tc; tcase++)
{
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%lld", arr+i);
fil(bef, 0);
fil(inside, 0);
ll ans = 0;
for (int block = 1; block <= n; block += BLOCK_SIZE)
{
int st = block;
int ed = block + BLOCK_SIZE - 1;
if (ed > n) ed = n;
for (int ii = st; ii <= ed; ii++)
{
for (int jj = ii+1; jj <= ed; jj++)
{
// j, k
ll Aj = arr[ii];
ll Ak = arr[jj];
ll Ai = Ak - Aj;
if (Ai >= 0 && Ai < N)
ans += bef[Ai] + inside[Ai];
}
inside[ arr[ii] ]++;
}
if (st > 1) // middle and last block
{
for (int i = 0; i < N; i++)
{
p[i] = base();
//q[i] = base();
r[i] = base();
res[i] = 0;
}
for (int i = 0; i < N; i++)
{
p[i] = base(bef[i]);
//q[i] = base(bef[i]);
}
fft(p, N, 0);
//fft(q, N, 0);
for (int i = 0; i < N; i++)
r[i] = p[i] * p[i];
fft(r, N, 1);
for (int i = 0; i < N; i++)
res[i] = floor(r[i].a + 0.5);
for (int i = st; i <= ed; i++)
{
ll Ak = arr[i];
if (Ak < N)
{
ll add = res[Ak];
add -= ((Ak%2 == 0) ? bef[Ak >> 1] : 0);
ans += (add >> 1);
}
}
}
for (int i = st; i <= ed; i++)
{
bef[ arr[i] ]++;
inside[ arr[i] ]--;
}
}
printf("%lld\n", ans);
}
}
/*
 * (Residue from the original blog page; not part of the program.)
 * No comments:
 * Post a Comment
 * Note: Only a member of this blog may post a comment.
 */