
4sum problem

Given an array S of n integers, are there elements a, b, c, and d in S such that a + b + c + d = target? Find all unique quadruplets in the array which give the sum of target.

Note: The solution set must not contain duplicate quadruplets.

The idea is to put every pair sum a into a hashmap along with the corresponding indexes and, once that is done, check whether target - a is also present in the hashmap. If both a and target - a are present, then, since the question asks for unique quadruplets, we can filter out overlapping pairs by their indexes.

```python
class Solution(object):
    def fourSum(self, arr, target):
        seen = {}
        for i in range(len(arr)-1):
            for j in range(i+1, len(arr)):
                if arr[i]+arr[j] in seen:
                    seen[arr[i]+arr[j]].add((i,j))
                else:
                    seen[arr[i]+arr[j]] = {(i,j)}
        result = []
        for key in seen:
            if -key + target in seen:
                for (i,j) in seen[key]:
                    for (p,q) in seen[-key + target]:
                        sorted_index = sorted([arr[i], arr[j], arr[p], arr[q]])
                        if i not in (p, q) and j not in (p, q) and sorted_index not in result:
                            result.append(sorted_index)
        return result
```

    2 Answers

    • Use enumerate rather than range(len(...)) plus indexing (__getitem__). It is both faster and more readable.
    • To limit the items of the second iteration to those after the current item, you can use itertools.combinations.
    • To avoid special-casing "is the item already in the dictionary?", use a collections.defaultdict.
    • You could use a set rather than a list to store the final results, which removes the need to check for duplicates.
    • -key + target is better written as target - key.
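    This is a small sketch of the first three suggestions working together on a made-up input array. combinations(enumerate(array), 2) yields each pair of (index, value) tuples exactly once, with i < j. defaultdict(set) creates an empty bucket the first time a sum appears, so no membership test is needed.

```python
import itertools
from collections import defaultdict

array = [1, 0, -1, 0]  # made-up example input

# Pair sum -> set of index pairs (i, j) producing it; combinations guarantees i < j
seen = defaultdict(set)
for (i, first), (j, second) in itertools.combinations(enumerate(array), 2):
    seen[first + second].add((i, j))  # no "is the key present?" check needed

for pair_sum in sorted(seen):
    print(pair_sum, sorted(seen[pair_sum]))
# -1 [(1, 2), (2, 3)]
# 0 [(0, 2), (1, 3)]
# 1 [(0, 1), (0, 3)]
```

    Each pair sum maps to the set of index pairs that produce it. The final loop then searches this structure for complementary sums.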

```python
import itertools
from collections import defaultdict


def four_sum(array, target):
    seen = defaultdict(set)
    for (i, first), (j, second) in itertools.combinations(enumerate(array), 2):
        seen[first + second].add((i, j))

    result = set()
    for key, first_indices in seen.items():
        second_indices = seen.get(target - key, set())
        for p, q in second_indices:
            for i, j in first_indices:
                # Not reusing the same number twice
                if not ({i, j} & {p, q}):
                    indices = tuple(sorted(array[x] for x in (i, j, p, q)))
                    result.add(indices)
    return result
```
    • Yours is actually slower than OP's on LeetCode, though I must agree it is more readable. Yours: 335 ms, OP's: 239 ms. It must return a list, so I changed it a bit, but I still didn't really expect that. :) (commented Nov 27, 2017 at 9:51)
    • Note: The solution set must not contain duplicate quadruplets. Yeah, online judges and their requirements matching their specs… (commented Nov 27, 2017 at 9:53)

    Implementation

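    • Why not build the result only from index combinations with i < j < p < q? Then the two pairs can never share an index, and each set of four positions is combined only once, so the check that i and j are not in (p, q) becomes unnecessary.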
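    This is a minimal sketch of that suggestion applied to the question's pair-sum approach. The function name four_sum_ordered is made up. Requiring j < p means the two index pairs never overlap. A set of sorted value tuples is still needed, because equal values at different positions would otherwise produce repeated quadruplets.

```python
from collections import defaultdict
from itertools import combinations


def four_sum_ordered(nums, target):
    # Hypothetical helper: the question's pair-sum map, combined with i < j < p < q
    pairs_by_sum = defaultdict(list)  # sum -> list of index pairs (i, j), i < j
    for (i, x), (j, y) in combinations(enumerate(nums), 2):
        pairs_by_sum[x + y].append((i, j))

    result = set()
    for key, first_pairs in pairs_by_sum.items():
        second_pairs = pairs_by_sum.get(target - key, [])
        for i, j in first_pairs:
            for p, q in second_pairs:
                # j < p: the pairs never share an index, and each set of
                # four positions is combined exactly once
                if j < p:
                    result.add(tuple(sorted((nums[i], nums[j], nums[p], nums[q]))))
    # Equal values at different positions still need deduplication via the set
    return [list(quad) for quad in sorted(result)]
```

    For example, four_sum_ordered([1, 0, -1, 0, -2, 2], 0) returns [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]].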

    Algorithm

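    • The code builds the hash map from every pair of indexes in nums. Building it from pairs of unique values instead (plus (x, x) for values that occur at least twice) is a better choice. Case: fourSum([0 for x in range(n)], 0), where all n(n-1)/2 index pairs are stored under the single key 0.
    • The code puts integers from nums into the hash map even when they can never be part of a result. Case: fourSum([x for x in range(1, n, 1)], 0): every element is positive, so no quadruplet can sum to 0.
    • The code checks whether target - key exists only in the final loop; this check can be done earlier, while building the hash map. Case: fourSum([x for x in range(0, n*10, 10)], n*5+1): every pair sum is a multiple of 10 and the target is not, so target - key is never a key.
    • You can split the hash map into two parts, one for the (a, b) pairs and one for the (c, d) pairs. This doesn't change the complexity of building the hash map, but the final loop does roughly 1/2 * 1/2 = 1/4 of the work.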
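    Counting the entries each strategy would store makes the three cases above concrete. The helpers below (index_pairs, value_pairs and useful_value_pairs, all made-up names) only count entries for n = 200. useful_value_pairs applies the "does target - sum exist?" check at build time.

```python
from collections import Counter
from itertools import combinations


def index_pairs(nums):
    # What the question's code stores: one entry per index pair (i, j), i < j
    return list(combinations(range(len(nums)), 2))


def value_pairs(nums):
    # One entry per value pair x < y, plus (x, x) when x occurs at least twice
    counter = Counter(nums)
    uniques = sorted(counter)
    pairs = list(combinations(uniques, 2))
    pairs += [(x, x) for x in uniques if counter[x] > 1]
    return pairs


def useful_value_pairs(nums, target):
    # Early check: keep a pair only if some value pair sums to target - (x + y)
    pairs = value_pairs(nums)
    sums = {x + y for x, y in pairs}
    return [(x, y) for x, y in pairs if target - (x + y) in sums]


n = 200
zeros = [0] * n
positives = list(range(1, n))
tens = list(range(0, n * 10, 10))

print(len(index_pairs(zeros)), len(value_pairs(zeros)))  # 19900 1
print(len(index_pairs(positives)), len(useful_value_pairs(positives, 0)))  # 19701 0
print(len(index_pairs(tens)), len(useful_value_pairs(tens, n * 5 + 1)))  # 19900 0
```

    In the code further down, the second case is also caught by the x_min/x_max bound on individual values. For all-positive input and target 0, that bound leaves no values at all.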

    Speedup

    • Best: improve the algorithm itself (its big-O complexity), e.g. reduce O(n^2) memory to O(n).
    • Sometimes good: reduce the algorithm's constant factors, e.g. split the hash map into a part for the first pair and a part for the second pair.
    • Bad: dirty, low-level, language-specific constant-factor speed-ups, e.g. replacing itertools.combinations with explicit nested loops. This is an anti-pattern: the code becomes harder to understand, maintain and change, and, paradoxically, often slower. It ends up slower because bottlenecks usually come from cascading several algorithms, e.g. O(n^3) * O(n^3). With clean code it is easier to reduce the problem to O(n^5) or less; with dirty code we usually end up with O(n^6) with a small constant.
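    One well-known example of the "best" category is sorting plus two pointers. It is swapped in here for illustration and is not taken from either answer. It drops the O(n^2) pair-sum map entirely and runs in O(n^3) time. It needs O(n) extra memory for the sorted copy, not counting the output. The name four_sum_two_pointers is made up.

```python
def four_sum_two_pointers(nums, target):
    nums = sorted(nums)  # O(n) extra memory for the sorted copy
    n = len(nums)
    result = []
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue  # skip duplicate first values
        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue  # skip duplicate second values
            lo, hi = j + 1, n - 1
            while lo < hi:
                total = nums[i] + nums[j] + nums[lo] + nums[hi]
                if total < target:
                    lo += 1
                elif total > target:
                    hi -= 1
                else:
                    result.append([nums[i], nums[j], nums[lo], nums[hi]])
                    lo += 1
                    hi -= 1
                    # skip duplicates of the third and fourth values
                    while lo < hi and nums[lo] == nums[lo - 1]:
                        lo += 1
                    while lo < hi and nums[hi] == nums[hi + 1]:
                        hi -= 1
    return result
```

    Duplicates are skipped by comparing each value with its neighbour in the sorted array, so no set of results is needed.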

    Code (same O(n^2) memory)

```python
from itertools import combinations
from collections import defaultdict, Counter


def fourSum(self, nums, target):
    if len(nums) < 4:
        return []
    half_target = target // 2
    counter = Counter(nums)
    uniques_wide = sorted(counter)
    x_min, x_max = target - 3 * uniques_wide[-1], target - 3 * uniques_wide[0]  # bad
    uniques = [x for x in uniques_wide if x_min <= x <= x_max]
    duplicates = [x for x in uniques if counter[x] > 1]
    target_minus_xy_sums = set(target - x - y for x, y in combinations(uniques, 2))
    target_minus_xy_sums |= set(target - x - x for x in duplicates)
    ab_sum_pairs, cd_sum_pairs = defaultdict(list), defaultdict(list)
    for (x, y) in combinations(uniques, 2):
        if x + y in target_minus_xy_sums:
            if x + y <= half_target:
                ab_sum_pairs[x + y].append((x, y))
            if x + y >= half_target:
                cd_sum_pairs[x + y].append((x, y))
    for x in duplicates:
        if x + x in target_minus_xy_sums:
            if x + x <= half_target:
                ab_sum_pairs[x + x].append((x, x))
            if x + x >= half_target:
                cd_sum_pairs[x + x].append((x, x))
    return [[a, b, c, d]
            for ab in ab_sum_pairs
            for (a, b) in ab_sum_pairs[ab]
            for (c, d) in cd_sum_pairs[target - ab]
            if b < c or b == c and [a, b, c, d].count(b) <= counter[b]]  # if bi < ci
```
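    Because these versions prune aggressively, a brute-force reference is handy for checking them. The sketch below compares any function with a (nums, target) signature against an O(n^4) enumeration on small random inputs. four_sum_brute and check are made-up helper names.

```python
import random
from itertools import combinations


def four_sum_brute(nums, target):
    # O(n^4) reference: every choice of 4 positions, deduplicated by sorted values
    quads = {tuple(sorted(q)) for q in combinations(nums, 4) if sum(q) == target}
    return sorted(list(q) for q in quads)


def check(candidate, trials=200, seed=0):
    # Compare a fourSum-style function against the brute force on small random inputs
    rng = random.Random(seed)
    for _ in range(trials):
        nums = [rng.randint(-5, 5) for _ in range(rng.randint(0, 9))]
        target = rng.randint(-10, 10)
        # Normalise: accept sets or lists of tuples or lists
        got = sorted(sorted(q) for q in candidate(nums, target))
        if got != four_sum_brute(nums, target):
            return nums, target  # first counterexample found
    return None  # no mismatch in any trial
```

    The method above takes self first, so something like check(lambda nums, t: fourSum(None, nums, t)) should work. check returns None when no mismatch is found.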
