Problem: 1803. Count Pairs With XOR in a Range
Category: DS
Algorithm: Trie
// LeetCode 1803: count index pairs i < j with low <= (nums[i] ^ nums[j]) <= high.
//
// Approach: a binary trie over the bits of the values seen so far. For each
// element num, count earlier elements y with (num ^ y) <= high, subtract those
// with (num ^ y) < low, then insert num. Each query/insert walks topBit_ + 1
// trie levels.
//
// Assumes nums, low and high are non-negative (as in the problem statement);
// negative values are not supported.
class Solution {
private:
    // Trie node stored by value in nodes_. Children are indices into nodes_;
    // 0 means "no child", which is unambiguous because index 0 is the root
    // and the root is never anyone's child.
    struct Node {
        int child[2] = {0, 0};
        int count = 0;  // how many inserted values pass through this node
    };

    std::vector<Node> nodes_;  // nodes_[0] is the root; the vector owns every node
    int topBit_ = 0;           // most significant bit position the trie branches on

    // Empties the trie (leaving only the root) and fixes the bit range used by
    // subsequent add/countXorBelow calls.
    void reset(int topBit) {
        topBit_ = topBit;
        nodes_.clear();
        nodes_.emplace_back();  // root
    }

    // Inserts num, creating nodes along its path from bit topBit_ down to bit 0.
    void add(int num) {
        int node = 0;
        for (int i = topBit_; i >= 0; --i) {
            const int bit = (num >> i) & 1;
            if (nodes_[node].child[bit] == 0) {
                const int fresh = static_cast<int>(nodes_.size());
                // emplace_back may reallocate, so index nodes_ again afterwards
                // instead of holding a reference across the call.
                nodes_.emplace_back();
                nodes_[node].child[bit] = fresh;
            }
            node = nodes_[node].child[bit];
            ++nodes_[node].count;
        }
    }

    // Counts inserted values y with (cur ^ y) < limit, or (cur ^ y) <= limit when
    // `inclusive` is true. Walks the single trie path on which the XOR equals
    // `limit` bit-for-bit; wherever limit has a 1 bit, the sibling subtree (whose
    // XOR bit there is 0) is entirely smaller and is counted wholesale.
    int countXorBelow(int cur, int limit, bool inclusive) const {
        int node = 0;
        int cnt = 0;
        for (int i = topBit_; i >= 0; --i) {
            const int curBit = (cur >> i) & 1;
            const int limitBit = (limit >> i) & 1;
            if (limitBit == 1) {
                // y with this bit equal to curBit makes the XOR bit 0 < 1.
                const int smaller = nodes_[node].child[curBit];
                if (smaller != 0) {
                    cnt += nodes_[smaller].count;
                }
            }
            // Follow y's bit that keeps the XOR equal to limit's prefix.
            const int next = nodes_[node].child[curBit ^ limitBit];
            if (next == 0) {
                return cnt;  // no y matches limit's prefix, so none is equal
            }
            node = next;
        }
        // Every y ending at this leaf satisfies cur ^ y == limit.
        if (inclusive) {
            cnt += nodes_[node].count;
        }
        return cnt;
    }

    // Position of the highest set bit of value (0 when value <= 1), capped at 30.
    static int highestBit(int value) {
        int bit = 0;
        while (bit < 30 && (value >> (bit + 1)) != 0) {
            ++bit;
        }
        return bit;
    }

public:
    // Returns the number of index pairs i < j with low <= nums[i] ^ nums[j] <= high.
    // Returns 0 for an empty input or an empty range (low > high).
    int countPairs(std::vector<int>& nums, int low, int high) {
        if (nums.empty() || low > high) {
            return 0;
        }
        // Take the trie width from max(nums, high): every value, low (<= high),
        // and every XOR of two values then fits in bits topBit..0, so the
        // bit-by-bit comparisons are exact instead of truncating high bits.
        int maxValue = high;
        for (const int v : nums) {
            if (v > maxValue) {
                maxValue = v;
            }
        }
        reset(highestBit(maxValue));

        int ans = 0;
        for (const int num : nums) {
            // Only earlier elements are in the trie, so each pair is counted once.
            ans += countXorBelow(num, high, true) - countXorBelow(num, low, false);
            add(num);
        }
        return ans;
    }
};
No comments:
Post a Comment
Note: Only a member of this blog may post a comment.