diff --git a/main.go b/main.go index 3d765a8..76661f8 100644 --- a/main.go +++ b/main.go @@ -38,7 +38,7 @@ func shuntingYard(re string) string { re_postfix = append(re_postfix, re_runes[i]) if re_runes[i] != '(' && re_runes[i] != '|' { if i < len(re_runes)-1 { - if re_runes[i+1] != '|' && re_runes[i+1] != '*' && re_runes[i+1] != ')' { + if re_runes[i+1] != '|' && re_runes[i+1] != '*' && re_runes[i+1] != '+' && re_runes[i+1] != ')' { re_postfix = append(re_postfix, CONCAT) } } @@ -139,7 +139,14 @@ func thompson(re string) *State { stateToAdd.transitions[s1.content] = append(stateToAdd.transitions[s1.content], s1) nfa = append(nfa, stateToAdd) case '+': - + s1 := pop(&nfa) + for i := range s1.output { + s1.output[i].transitions[s1.content] = append(s1.output[i].transitions[s1.content], s1) + } + // Reset output to s1 (in case s1 was a union operator state, which has multiple outputs) + s1.output = nil + s1.output = append(s1.output, s1) + nfa = append(nfa, s1) case '|': s1 := pop(&nfa) s2 := pop(&nfa) diff --git a/re_test.go b/re_test.go index 9ee3ff1..114e66b 100644 --- a/re_test.go +++ b/re_test.go @@ -14,12 +14,19 @@ var reTests = []struct { {"a", "bca", []matchIndex{{2, 3}}}, {"l", "ggllgg", []matchIndex{{2, 3}, {3, 4}}}, {"(b|c)", "abdceb", []matchIndex{{1, 2}, {3, 4}, {5, 6}}}, - {"a*", "brerereraaaaabbbbb", []matchIndex{{8, 13}}}, + {"a+", "brerereraaaaabbbbb", []matchIndex{{8, 13}}}, + {"ab+", "qweqweqweaqweqweabbbbbr", []matchIndex{{16, 22}}}, + {"(b|c|A)", "ooaoobocA", []matchIndex{{5, 6}, {7, 8}, {8, 9}}}, + {"ab*", "a", []matchIndex{{0, 1}}}, + {"ab*", "abb", []matchIndex{{0, 3}}}, + {"(abc)*", "abcabcabc", []matchIndex{{0, 9}}}, + {"((abc)|(def))*", "abcdef", []matchIndex{{0, 6}}}, + {"(abc)*|(def)*", "abcdef", []matchIndex{{0, 3}, {3, 6}}}, } func TestFindAllMatches(t *testing.T) { for _, test := range reTests { - t.Run(test.re+" "+test.str, func(t *testing.T) { + t.Run(test.re+" "+test.str, func(t *testing.T) { re_postfix := shuntingYard(test.re) startState := thompson(re_postfix) matchIndices := findAllMatches(startState, test.str)