Coverage for blog/dsa/leetcode/sum_of_prefix_scores/__init__.py: 56%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-20 16:23 +0000

1import pathlib 

2 

3import pytest 

4import yaml 

5 

6from dsa.leetcode.longest_pfx import TrieNode 

7 

8 

9# start snippet solution_initial 

class SolutionInitial:
    """First attempt: sum per-prefix word counts taken from trie subtrees.

    Every word is inserted into a trie with ``TrieNode.insert``. The score of a
    prefix is obtained by walking to the prefix's node and adding up the
    ``terminates`` values of that node and of every node beneath it. Scores are
    cached by prefix string, so each subtree is summed at most once per call.
    """

    def sumPrefixScores(self, words: list[str]) -> list[int]:
        """Return, for each word, the sum of the scores of its non-empty prefixes.

        :param words: the input words.
        :returns: one total per word, in the same order as ``words``.
        """
        # prefix -> score computed for it (sum of ``terminates`` in its subtree)
        scores: dict[str, int] = {}

        trie = TrieNode(dict())
        for word in words:
            trie.insert(word)

        def score_of(prefix: str, *, node: TrieNode | None = None) -> int:
            """Score one prefix.

            When ``node`` is supplied it must already be the trie node reached by
            ``prefix``; otherwise the prefix is walked down from the root.
            """
            cached = scores.get(prefix)
            if cached is not None:
                return cached

            if node is None:
                node = trie
                for ch in prefix:
                    if ch not in node.children:
                        # The prefix falls off the trie: no word starts with it.
                        scores[prefix] = 0
                        return 0
                    node = node.children[ch]

            # Count at this node (only when truthy -- presumably falsy where no
            # word ends; depends on TrieNode.insert), plus every deeper subtree.
            total = (node.terminates or 0) + sum(
                score_of(prefix + ch, node=child)
                for ch, child in node.children.items()
            )
            scores[prefix] = total
            return total

        # Longest prefix first for each word; the cache makes shorter prefixes cheap.
        return [
            sum(score_of(word[:end]) for end in range(len(word), 0, -1))
            for word in words
        ]

59 # end snippet solution_initial 

60 

61 

62# start snippet solutiond 

class Solution:
    """Trie solution that records, on every node, how many words pass through it.

    The ``terminates`` field of ``TrieNode`` is used here as a pass-through
    counter rather than an end-of-word marker: inserting a word increments it
    on every node along the word's path. The score of a prefix is therefore the
    counter on the prefix's node, and a word's answer is the sum of the
    counters along its own path.
    """

    def sumPrefixScores(self, words: list[str]) -> list[int]:
        """Return, for each word, the sum of the scores of its non-empty prefixes.

        :param words: the input words.
        :returns: one total per word, in the same order as ``words``.
        """
        root = TrieNode(dict())

        # Build phase: count how many words reach every node.
        for word in words:
            cursor = root
            for ch in word:
                if ch in cursor.children:
                    cursor = cursor.children[ch]
                    cursor.terminates += 1
                else:
                    # First word to reach this node starts its counter at 1.
                    fresh = TrieNode(dict(), terminates=1)
                    cursor.children[ch] = fresh
                    cursor = fresh

        def path_total(word: str) -> int:
            # Every word was inserted above, so each child lookup must succeed.
            cursor, total = root, 0
            for ch in word:
                cursor = cursor.children[ch]
                total += cursor.terminates
            return total

        return list(map(path_total, words))

90 # end snippet solutiond 

91 

92 

# Load the parametrized test cases. Each YAML entry is a mapping whose values,
# in file order, form one ``(words, answer)`` argument tuple -- so the key order
# in cases.yaml must match the ``parametrize`` argnames below. YAML is UTF-8 by
# specification, so decode explicitly instead of relying on the locale default.
with open(
    pathlib.Path(__file__).resolve().parent / "cases.yaml", encoding="utf-8"
) as file:
    cases = [tuple(item.values()) for item in yaml.safe_load(file)]

95 

96 

@pytest.fixture
def solution() -> Solution:
    """Give each test its own freshly constructed ``Solution``."""
    instance = Solution()
    return instance

100 

101 

@pytest.mark.parametrize("words, answer", cases)
def test_solution(solution: Solution, words: list[str], answer: list[int]):
    """Check ``Solution.sumPrefixScores`` against one case from ``cases.yaml``.

    pytest's assertion rewriting reports both lists on a mismatch, so no extra
    debugging output is needed.
    """
    answer_computed = solution.sumPrefixScores(words)
    assert answer == answer_computed