Sunday, March 5, 2023

[LeetCode] 1803. Count Pairs With XOR in a Range


Problem: 1803. Count Pairs With XOR in a Range
Category: DS
Algorithm: Trie
// Counts "nice" pairs (i < j) with low <= (nums[i] XOR nums[j]) <= high.
//
// Approach: scan nums left to right. Before inserting nums[j] into a binary
// trie of the earlier values, count the earlier x with (nums[j] ^ x) <= high
// and subtract those with (nums[j] ^ x) < low. Each query and each insertion
// walks a single root-to-leaf path, so the total cost is O(n * B), where B is
// the number of bits needed to represent the inputs.
//
// Precondition: every nums[i] >= 0 (as in the LeetCode constraints). The trie
// depth is derived from the largest of nums and high, so any non-negative int
// magnitude is handled.
class Solution {
private:
    // One trie node. Children are indices into nodes_ (-1 = absent), so the
    // whole trie lives in a single vector: no manual new/delete, and nothing
    // leaks when countPairs is called more than once.
    struct Node {
        int counter = 0;          // how many inserted values pass through this node
        int child[2] = {-1, -1};  // child[b] = node reached when the next bit is b
    };

    std::vector<Node> nodes_;  // nodes_[0] is the root
    int topBit_ = 0;           // highest bit index examined during the current call

    // Inserts num, incrementing the counter of every node on its path.
    void add(int num) {
        int node = 0;
        for (int i = topBit_; i >= 0; --i) {
            const int bit = (num >> i) & 1;
            if (nodes_[node].child[bit] == -1) {
                // Store the new index before growing; emplace_back may reallocate.
                nodes_[node].child[bit] = static_cast<int>(nodes_.size());
                nodes_.emplace_back();
            }
            node = nodes_[node].child[bit];
            ++nodes_[node].counter;
        }
    }

    // Returns how many inserted values x satisfy (cur ^ x) <= compareWith, or
    // (cur ^ x) < compareWith when lessEqual is false. Requires compareWith >= 0.
    int countLessEqual(int cur, int compareWith, bool lessEqual = true) const {
        int node = 0;
        int cnt = 0;
        for (int i = topBit_; i >= 0; --i) {
            const int curBit = (cur >> i) & 1;
            if ((compareWith >> i) & 1) {
                // Following cur's own bit makes this XOR bit 0 while the bound has 1,
                // so that whole subtree is strictly below compareWith.
                const int same = nodes_[node].child[curBit];
                if (same != -1) {
                    cnt += nodes_[same].counter;
                }
                node = nodes_[node].child[curBit ^ 1];
            } else {
                // The bound has 0 here, so the XOR bit must also be 0.
                node = nodes_[node].child[curBit];
            }
            if (node == -1) {
                return cnt;  // no stored value keeps the XOR equal to the bound's prefix
            }
        }
        // Every value under `node` has (cur ^ x) == compareWith exactly.
        if (lessEqual) {
            cnt += nodes_[node].counter;
        }
        return cnt;
    }

    // Number of inserted values x with (cur ^ x) < compareWith.
    int countLess(int cur, int compareWith) const {
        return countLessEqual(cur, compareWith, false);
    }

public:
    // Returns the number of pairs (i, j), i < j, with low <= nums[i] ^ nums[j] <= high.
    // An empty range (high < low, or high < 0) yields 0.
    int countPairs(std::vector<int>& nums, int low, int high) {
        if (high < 0 || low > high) {
            return 0;
        }
        if (low < 0) {
            low = 0;  // XOR of non-negative values is never negative
        }

        // Examine just enough bits to represent every value and the upper bound.
        int maxVal = high;
        for (const int x : nums) {
            if (x > maxVal) {
                maxVal = x;
            }
        }
        topBit_ = 0;
        while (topBit_ < 30 && (maxVal >> (topBit_ + 1)) != 0) {
            ++topBit_;
        }

        // Start each call from an empty trie (just the root) so calls are independent.
        nodes_.assign(1, Node{});

        int ans = 0;
        for (const int num : nums) {
            ans += countLessEqual(num, high) - countLess(num, low);
            add(num);
        }
        return ans;
    }
};

No comments:

Post a Comment

Note: Only a member of this blog may post a comment.