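Example #1: [LeetCode] Count of Smaller Numbers After Self (315) — given an integer array `nums`, return a list whose i-th entry is the number of elements to the right of `nums[i]` that are strictly smaller than it.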
Solution: https://docs.google.com/document/d/1kfn4MfnLLXVNFfayX6L6sjAsZSuUy6gJoJit3sngG-E/edit?usp=sharing
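- Segment Tree over value ranks (values are compressed to ranks; the array is scanned right to left, querying and updating the tree), in three variants:
- Top-Down Recursion (explicit node objects)
- Top-Down Iteration (implicit array-backed tree addressed from the root; the traversal itself is still recursive)
- Bottom-Up Iteration (array-backed tree of size 2n, updated and queried with loops starting at the leaves)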
Code (Top-Down Recursion):
class Solution { public List<Integer> countSmaller(int[] nums) { LinkedList<Integer> output = new LinkedList<>(); Map<Integer, Integer> rankMap = getRankMap(nums); SegmentTreeNode tree = buildTree(0, nums.length - 1); for (int i = nums.length - 1; i >= 0; i--) { int rank = rankMap.get(nums[i]); output.addFirst(sum(tree, 0, rank - 1)); add(tree, rank, 1); } return output; } private Map<Integer, Integer> getRankMap(int[] nums) { int[] sorted = Arrays.copyOf(nums, nums.length); Arrays.sort(sorted); Map<Integer, Integer> map = new HashMap<>(); for (int i = 0; i < sorted.length; i++) { map.put(sorted[i], i); } return map; } private SegmentTreeNode buildTree(int lower, int upper) { if (lower > upper) { return null; } else { SegmentTreeNode node = new SegmentTreeNode(lower, upper); if (lower < upper) { int mid = lower + (upper - lower) / 2; node.left = buildTree(lower, mid); node.right = buildTree(mid + 1, upper); } return node; } } private int sum(SegmentTreeNode node, int lower, int upper) { if (lower > upper) { return 0; } if (node.lower == lower && node.upper == upper) { return node.sum; } int mid = node.lower + (node.upper - node.lower) / 2; if (upper <= mid) { return sum(node.left, lower, upper); } else if (lower > mid) { return sum(node.right, lower, upper); } else { return sum(node.left, lower, mid) + sum(node.right, mid + 1, upper); } } private void add(SegmentTreeNode node, int x, int val) { if (node.lower == x && node.upper == x) { node.sum += val; return; } int mid = node.lower + (node.upper - node.lower) / 2; if (x <= mid) { add(node.left, x, val); } else { add(node.right, x, val); } node.sum = node.left.sum + node.right.sum; } private class SegmentTreeNode { public int lower; public int upper; public int sum; public SegmentTreeNode left; public SegmentTreeNode right; public SegmentTreeNode(int lower, int upper) { this.lower = lower; this.upper = upper; this.sum = 0; } } }
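Code (Top-Down Iteration — implicit array-backed tree indexed from the root; note the traversal is still implemented recursively):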
class Solution { public List<Integer> countSmaller(int[] nums) { LinkedList<Integer> output = new LinkedList<>(); Map<Integer, Integer> rankMap = getRankMap(nums); int[] tree = new int[4 * nums.length]; for (int i = nums.length - 1; i >= 0; i--) { int rank = rankMap.get(nums[i]); output.addFirst(sum(tree, 1, 0, nums.length - 1, 0, rank - 1)); add(tree, 1, 0, nums.length - 1, rank, 1); } return output; } private Map<Integer, Integer> getRankMap(int[] nums) { int[] sorted = Arrays.copyOf(nums, nums.length); Arrays.sort(sorted); Map<Integer, Integer> map = new HashMap<>(); for (int i = 0; i < sorted.length; i++) { map.put(sorted[i], i); } return map; } private void add(int[] tree, int node, int lower, int upper, int i, int val) { if (lower > upper || i < lower || i > upper) { return; } if (lower == upper) { tree[node] += val; return; } int mid = lower + (upper - lower) / 2; add(tree, 2 * node, lower, mid, i, val); add(tree, 2 * node + 1, mid + 1, upper, i, val); tree[node] = tree[2 * node] + tree[2 * node + 1]; } private int sum(int[] tree, int node, int lower, int upper, int i, int j) { if (lower > upper || j < lower || i > upper) { return 0; } if (lower >= i && upper <= j) { return tree[node]; } int mid = lower + (upper - lower) / 2; if (j <= mid) { return sum(tree, 2 * node, lower, mid, i, j); } else if (i > mid) { return sum(tree, 2 * node + 1, mid + 1, upper, i, j); } else { return sum(tree, 2 * node, lower, mid, i, j) + sum(tree, 2 * node + 1, mid + 1, upper, i, j); } } }
Code (Bottom-Up Iteration):
class Solution { public List<Integer> countSmaller(int[] nums) { LinkedList<Integer> output = new LinkedList<>(); Map<Integer, Integer> rankMap = getRankMap(nums); int[] tree = new int[2 * nums.length]; for (int i = nums.length - 1; i >= 0; i--) { int rank = rankMap.get(nums[i]); output.addFirst(sum(tree, nums.length, 0, rank - 1)); add(tree, nums.length, rank, 1); } return output; } private Map<Integer, Integer> getRankMap(int[] nums) { int[] sorted = Arrays.copyOf(nums, nums.length); Arrays.sort(sorted); Map<Integer, Integer> map = new HashMap<>(); for (int i = 0; i < sorted.length; i++) { map.put(sorted[i], i); } return map; } private int sum(int[] tree, int n, int i, int j) { i += n; j += n; int output = 0; while (i <= j) { if (i % 2 == 1) { output += tree[i++]; } if (j % 2 == 0) { output += tree[j--]; } i /= 2; j /= 2; } return output; } private void add(int[] tree, int n, int i, int val) { i += n; tree[i] += val; for (i /= 2; i > 0; i /= 2) { tree[i] = tree[2 * i] + tree[2 * i + 1]; } } }
References: