Monday, December 10, 2018

[toph.co] Counting Triplets

Author            : Dipu Kumar Mohanto 
                    CSE, Batch - 6
                    BRUR.
Problem Statement : Counting Triplets
Source            : toph.co
Category          : Data Structure
Algorithm         : FFT with Block Decomposition
Verdict           : Accepted
#include <bits/stdc++.h>
 
using namespace std;
 
// Scalar aliases used throughout the file.
using ld = double;      // floating-point type of the FFT
using ll = long long;   // 64-bit integer type for counts and input values

// pi, obtained as twice acos(0).
const ld PI    = 2*acos(0.0);
// Capacity of the input array: 1e5 elements plus a little slack.
const int MAXN = 100005;
// Transform length and size of the value-indexed tables: 2^17.
const int N    = 1 << 17;
 
/*
 * Minimal complex number a + b*i used by the FFT code.
 *
 * Only the three operations the transform needs are provided:
 * addition, subtraction and multiplication. The operators are
 * const-qualified so they can also be applied to const operands.
 */
struct base
{
    ld a, b;    // a: real part, b: imaginary part

    base()             : a(0.0), b(0.0) {}   // zero
    base(ld aa)        : a(aa), b(0.0) {}    // real number aa
    base(ld aa, ld bb) : a(aa), b(bb) {}     // aa + bb*i

    // Component-wise sum.
    base operator + (const base &c) const
    {
        return base(a + c.a, b + c.b);
    }
    // Component-wise difference.
    base operator - (const base &c) const
    {
        return base(a - c.a, b - c.b);
    }
    // Complex product: (a + bi)(c.a + c.b i).
    base operator * (const base &c) const
    {
        return base(a * c.a - b * c.b, a * c.b + b * c.a);
    }
};
 
// Table of unit-circle points filled by calcw(); indices 0..N are written.
// N | 1 equals N + 1 here because N is even.
base w_pre[N | 1];
// Per-stage twiddle factors copied out of w_pre by fft().
base w[N | 1];
// Bit-reversal permutation filled by calcrev() and read by fft().
ll rev[N];
 
void calcw()
{
    for (int i = 0; i <= N; i++)
    {
        w_pre[i] = base(cos(2*PI/N*i), sin(2*PI/N*i));
    }
}
 
void calcrev(int n)
{
    int sz = 31 - __builtin_clz(n);
    sz = abs(sz);
 
    for (int i = 1; i < n-1; i++)
    {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << sz - 1);
    }
}
 
/*
 * In-place iterative radix-2 transform of p[0..n-1].
 *
 * p   : array of n complex values, overwritten with the result
 * n   : transform length; the twiddle lookup uses N / (2*h) for every
 *       stage width h < n, and the permutation is read from rev[]
 * dir : 0     -> stage twiddles are w_pre[N/l * j]
 *       other -> stage twiddles are w_pre[N - N/l * j] and every output
 *                is divided by n afterwards
 *
 * Relies on the tables written by calcw() (w_pre) and calcrev() (rev);
 * w[] is used as scratch space for the twiddles of the current stage.
 */
void fft(base *p, int n, int dir)
{
    // Reorder the input according to the precomputed permutation.
    for (int idx = 1; idx < n - 1; ++idx)
    {
        const ll partner = rev[idx];
        if (idx < partner)
            swap(p[idx], p[partner]);
    }

    // Butterfly stages: half-width doubles on every pass.
    for (int half = 1; half < n; half *= 2)
    {
        const int span   = half * 2;
        const int stride = N / span;

        // Twiddles for this stage, read forwards or from the far end of w_pre.
        for (int k = 0; k < half; ++k)
            w[k] = dir ? w_pre[N - stride * k] : w_pre[stride * k];

        for (int start = 0; start < n; start += span)
        {
            base *lo = p + start;
            base *hi = lo + half;
            for (int k = 0; k < half; ++k)
            {
                const base t = hi[k] * w[k];
                hi[k] = lo[k] - t;
                lo[k] = lo[k] + t;
            }
        }
    }

    // Scaling step, applied only when dir is non-zero.
    if (dir)
    {
        for (int idx = 0; idx < n; ++idx)
        {
            p[idx].a /= n;
            p[idx].b /= n;
        }
    }
}
#define BLOCK_SIZE    700
#define fil(a, b)     fill(begin(a), end(a), b)
 
base p[N], q[N], r[N];
ll res[N], arr[MAXN];
ll bef[N], inside[N];
 
int main()
{
    //freopen("in.txt", "r", stdin);
 
    calcw();
    calcrev(N);
 
    int tc;
    scanf("%d", &tc);
    for (int tcase = 1; tcase <= tc; tcase++)
    {
        int n;
        scanf("%d", &n);
        for (int i = 1; i <= n; i++)
            scanf("%lld", arr+i);
        fil(bef, 0);
        fil(inside, 0);
        ll ans = 0;
        for (int block = 1; block <= n; block += BLOCK_SIZE)
        {
            int st = block;
            int ed = block + BLOCK_SIZE - 1;
            if (ed > n) ed = n;
            for (int ii = st; ii <= ed; ii++)
            {
                for (int jj = ii+1; jj <= ed; jj++)
                {
                    // j, k
                    ll Aj = arr[ii];
                    ll Ak = arr[jj];
                    ll Ai = Ak - Aj;
                    if (Ai >= 0 && Ai < N)
                        ans += bef[Ai] + inside[Ai];
                }
                inside[ arr[ii] ]++;
            }
            if (st > 1)  // middle and last block
            {
                for (int i = 0; i < N; i++)
                {
                    p[i] = base();
                    //q[i] = base();
                    r[i] = base();
                    res[i] = 0;
                }
                for (int i = 0; i < N; i++)
                {
                    p[i] = base(bef[i]);
                    //q[i] = base(bef[i]);
                }
                fft(p, N, 0);
                //fft(q, N, 0);
                for (int i = 0; i < N; i++)
                    r[i] = p[i] * p[i];
 
                fft(r, N, 1);
 
                for (int i = 0; i < N; i++)
                    res[i] = floor(r[i].a + 0.5);
 
                for (int i = st; i <= ed; i++)
                {
                    ll Ak = arr[i];
                    if (Ak < N)
                    {
                        ll add = res[Ak];
                        add -= ((Ak%2 == 0) ? bef[Ak >> 1] : 0);
                        ans += (add >> 1);
                    }
                }
            }
            for (int i = st; i <= ed; i++)
            {
                bef[ arr[i] ]++;
                inside[ arr[i] ]--;
            }
        }
        printf("%lld\n", ans);
    }
}

No comments:

Post a Comment

Note: Only a member of this blog may post a comment.