This repository has been archived by the owner on Mar 19, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrie_matching.py
95 lines (74 loc) · 2.23 KB
/
trie_matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# python3
##############################
#
# @author Daniel Avdar
#
# @description: Extend the TrieMatching algorithm so that it
# correctly handles cases where one of the
# patterns is a prefix of another.
#
# @input: The first line of the input contains a string Text,
# the second line contains the number of patterns,
# and the following lines contain the patterns
#
# @output:All starting positions in Text where a string
# from Patterns appears as a substring in
# increasing order (assuming that Text is a
# 0-based array of symbols).
# If more than one pattern appears starting at
# position 𝑖, output 𝑖 once.
#
##############################
import sys
# Sentinel value (-1); not referenced by any function visible in this file.
NA = -1
# Key under which a trie node stores the set of symbols that complete a
# pattern when followed from that node (written by build_trie, read by solve).
HIT = "hit"
def build_trie(patterns):
    """Build a trie over ``patterns`` as a dict of per-node edge dicts.

    The result maps ``node_id -> edges`` where ``edges`` maps a symbol to
    the id of the child node reached by that symbol. The root has id 0 and
    new nodes receive consecutive ids starting at 1.

    Pattern ends are recorded on the *parent* side of the final edge: the
    last symbol of every pattern is added to the set stored under the
    ``HIT`` key of the node the final edge leaves from. This keeps the end
    of a short pattern visible even when a longer pattern continues
    through the same edge (one pattern being a prefix of another).

    Empty patterns are skipped. They have no final edge to mark, and the
    previous implementation crashed on them (``AttributeError`` or
    ``TypeError``) after touching state left over from the preceding
    pattern.

    Args:
        patterns: Iterable of strings to insert into the trie.

    Returns:
        dict: The trie, ``{node_id: {symbol: child_id, ..., HIT: set}}``;
        the ``HIT`` entry is present only on nodes where a pattern ends.
    """
    tree = {0: {}}
    next_id = 1
    for pattern in patterns:
        if not pattern:
            # Nothing to insert; see the docstring for why this is skipped.
            continue
        node = tree[0]
        parent = node
        for symbol in pattern:
            if symbol not in node:
                node[symbol] = next_id
                tree[next_id] = {}
                next_id += 1
            parent = node
            node = tree[node[symbol]]
        # ``symbol`` is the pattern's last symbol and ``parent`` is the
        # node its edge leaves from.
        parent.setdefault(HIT, set()).add(symbol)
    return tree
def solve(text, patterns):
    """Return every start index in ``text`` where some pattern occurs.

    A trie is built from ``patterns`` with ``build_trie``; then, for each
    start index, the text is walked down the trie until either a pattern
    end is reached (the index is recorded) or no edge matches / the text
    runs out (the index is discarded).

    The walk is iterative rather than recursive so that long patterns
    cannot exhaust the interpreter's recursion limit.

    Args:
        text: The string to search in.
        patterns: Iterable of pattern strings.

    Returns:
        list[int]: 0-based start positions in increasing order; a position
        is listed once even if several patterns start there.
    """
    tree = build_trie(patterns)
    text_len = len(text)
    result = []
    for start in range(text_len):
        node = tree[0]
        for pos in range(start, text_len):
            symbol = text[pos]
            child = node.get(symbol)
            if child is None:
                # No pattern continues with this symbol.
                break
            if symbol in node.get(HIT, ()):
                # A pattern ends on this edge; the first hit is enough
                # because each start position is reported at most once.
                result.append(start)
                break
            node = tree[child]
    return result
def main():
    """Read the text and patterns from stdin and print match positions.

    Input format: the first line is the text, the second line is the
    number of patterns, and each following line is one pattern. The
    positions returned by ``solve`` are written to stdout on a single
    line, separated by spaces.
    """
    read_line = sys.stdin.readline
    text = read_line().strip()
    pattern_count = int(read_line().strip())
    patterns = [read_line().strip() for _ in range(pattern_count)]
    positions = solve(text, patterns)
    sys.stdout.write(' '.join(str(pos) for pos in positions) + '\n')
# Run the solver only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
# todo git