From 8327450dd2ed52763a0b2b62abddb71d58299f1a Mon Sep 17 00:00:00 2001 From: Aadhavan Srinivasan Date: Tue, 11 Feb 2025 16:14:48 -0500 Subject: [PATCH] Started implementing backreferences (octal values should now be prefaced with \0) --- regex/compile.go | 59 +++++++++++++++++++++++++++++++++++----- regex/matching.go | 46 +++++++++++++++++++++++-------- regex/nfa.go | 11 +++++--- regex/postfixNode.go | 65 +++++++++++++++++++++++++++++++++++++------- regex/re_test.go | 14 +++++----- 5 files changed, 156 insertions(+), 39 deletions(-) diff --git a/regex/compile.go b/regex/compile.go index 0ae3d6b..0414ac8 100644 --- a/regex/compile.go +++ b/regex/compile.go @@ -313,13 +313,20 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { } else { return nil, fmt.Errorf("invalid hex value in expression") } - } else if isOctal(re_runes[i]) { + } else if re_runes[i] == '0' { // Start of octal value numDigits := 1 - for i+numDigits < len(re_runes) && numDigits < 3 && isOctal(re_runes[i+numDigits]) { // Skip while we see an octal character (max of 3) + for i+numDigits < len(re_runes) && numDigits < 4 && isOctal(re_runes[i+numDigits]) { // Skip while we see an octal character (max of 4, starting with 0) numDigits++ } re_postfix = append(re_postfix, re_runes[i:i+numDigits]...) i += (numDigits - 1) // I have to move back a step, so that I can add a concatenation operator if necessary, and so that the increment at the bottom of the loop works as intended + } else if unicode.IsDigit(re_runes[i]) { // Any other number - backreference + numDigits := 1 + for i+numDigits < len(re_runes) && unicode.IsDigit(re_runes[i+numDigits]) { // Skip while we see a digit + numDigits++ + } + re_postfix = append(re_postfix, re_runes[i:i+numDigits]...) + i += (numDigits - 1) // Move back a step to add concatenation operator } else { re_postfix = append(re_postfix, re_runes[i]) } @@ -364,7 +371,9 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { outQueue := make([]postfixNode, 0) // Output queue // Actual algorithm - numOpenParens := 0 // Number of open parentheses + numOpenParens := 0 // Number of open parentheses + parenIndices := make([]Group, 0) // I really shouldn't be using Group here, because that's strictly for matching purposes, but its a convenient way to store the indices of the opening and closing parens. + parenIndices = append(parenIndices, Group{0, 0}) // I append a weird value here, because the 0-th group doesn't have any parens. This way, the 1st group will be at index 1, 2nd at 2 ... for i := 0; i < len(re_postfix); i++ { /* Two cases: 1. Current character is alphanumeric - send to output queue @@ -420,11 +429,11 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { } else { return nil, fmt.Errorf("not enough hex characters found in expression") } - } else if isOctal(re_postfix[i]) { // Octal value + } else if re_postfix[i] == '0' { // Octal value var octVal int64 var octValStr string numDigitsParsed := 0 - for (i+numDigitsParsed) < len(re_postfix) && isOctal(re_postfix[i+numDigitsParsed]) && numDigitsParsed <= 3 { + for (i+numDigitsParsed) < len(re_postfix) && isOctal(re_postfix[i+numDigitsParsed]) && numDigitsParsed <= 4 { octValStr += string(re_postfix[i+numDigitsParsed]) numDigitsParsed++ } @@ -437,6 +446,20 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { } i += numDigitsParsed - 1 // Shift forward by the number of digits that were parsed. Move back one character, because the loop increment will move us back to the next character automatically outQueue = append(outQueue, newPostfixCharNode(rune(octVal))) + } else if unicode.IsDigit(re_postfix[i]) { // Backreference + var num int64 + var numStr string + numDigitsParsed := 0 + for (i+numDigitsParsed) < len(re_postfix) && unicode.IsDigit(re_postfix[i+numDigitsParsed]) { + numStr += string(re_postfix[i+numDigitsParsed]) + numDigitsParsed++ + } + num, err := strconv.ParseInt(numStr, 10, 32) + if err != nil { + return nil, fmt.Errorf("error parsing backreference in expresion") + } + i += numDigitsParsed - 1 + outQueue = append(outQueue, newPostfixBackreferenceNode(int(num))) } else { escapedNode, err := newEscapedNode(re_postfix[i], false) if err != nil { @@ -588,11 +611,11 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { } else { return nil, fmt.Errorf("not enough hex characters found in character class") } - } else if isOctal(re_postfix[i]) { // Octal value + } else if re_postfix[i] == '0' { // Octal value var octVal int64 var octValStr string numDigitsParsed := 0 - for (i+numDigitsParsed) < len(re_postfix)-1 && isOctal(re_postfix[i+numDigitsParsed]) && numDigitsParsed <= 3 { // The '-1' exists, because even in the worst case (the character class extends till the end), the last character must be a closing bracket (and nothing else) + for (i+numDigitsParsed) < len(re_postfix)-1 && isOctal(re_postfix[i+numDigitsParsed]) && numDigitsParsed <= 4 { // The '-1' exists, because even in the worst case (the character class extends till the end), the last character must be a closing bracket (and nothing else) octValStr += string(re_postfix[i+numDigitsParsed]) numDigitsParsed++ } @@ -796,6 +819,7 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { outQueue = append(outQueue, newPostfixNode(c)) } numOpenParens++ + parenIndices = append(parenIndices, Group{StartIdx: len(outQueue) - 1}) // Push the index of the lparen into parenIndices } if c == ')' { // Keep popping from opStack until we encounter an opening parantheses or a NONCAPLPAREN_CHAR. Throw error if we reach the end of the stack. @@ -812,6 +836,7 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { if val == '(' { // Whatever was inside the parentheses was a _capturing_ group, so we append the closing parentheses as well outQueue = append(outQueue, newPostfixNode(')')) // Add closing parentheses } + parenIndices[numOpenParens].EndIdx = len(outQueue) - 1 numOpenParens-- } } @@ -826,6 +851,11 @@ func shuntingYard(re string, flags ...ReFlag) ([]postfixNode, error) { return nil, fmt.Errorf("imbalanced parantheses") } + // outQueue, _, err := rewriteBackreferences(outQueue, parenIndices) + // if err != nil { + // return nil, err + // } + return outQueue, nil } @@ -1037,6 +1067,21 @@ func thompson(re []postfixNode) (Reg, error) { }) nfa = append(nfa, toAdd) } + if c.nodetype == backreferenceNode { + if c.referencedGroup > numGroups { + return Reg{}, fmt.Errorf("invalid backreference") + } + stateToAdd := &nfaState{} + stateToAdd.assert = noneAssert + stateToAdd.content = newContents(epsilon) + stateToAdd.isEmpty = true + stateToAdd.isBackreference = true + stateToAdd.output = make([]*nfaState, 0) + stateToAdd.output = append(stateToAdd.output, stateToAdd) + stateToAdd.referredGroup = c.referencedGroup + stateToAdd.threadBackref = 0 + nfa = append(nfa, stateToAdd) + } // Must be an operator if it isn't a character switch c.nodetype { case concatenateNode: diff --git a/regex/matching.go b/regex/matching.go index a344a40..230a658 100644 --- a/regex/matching.go +++ b/regex/matching.go @@ -228,25 +228,45 @@ func (re Reg) FindAllStringSubmatch(str string) [][]string { return rtv } -func addStateToList(str []rune, idx int, list []nfaState, state nfaState, threadGroups []Group, visited []nfaState, preferLongest bool) []nfaState { +// Second parameter is the 'new index' +func addStateToList(str []rune, idx int, list []nfaState, state nfaState, threadGroups []Group, visited []nfaState, preferLongest bool) ([]nfaState, int) { if stateExists(list, state) || stateExists(visited, state) { - return list + return list, idx } visited = append(visited, state) + if state.isBackreference { + if threadGroups[state.referredGroup].IsValid() { + groupLength := threadGroups[state.referredGroup].EndIdx - threadGroups[state.referredGroup].StartIdx + if state.threadBackref == groupLength { + state.threadBackref = 0 + copyThread(state.next, state) + return addStateToList(str, idx+groupLength, list, *state.next, threadGroups, visited, preferLongest) + } + idxInReferredGroup := threadGroups[state.referredGroup].StartIdx + state.threadBackref + if idxInReferredGroup < len(str) && idx+state.threadBackref < len(str) && str[idxInReferredGroup] == str[idx+state.threadBackref] { + state.threadBackref += 1 + return addStateToList(str, idx, list, state, threadGroups, visited, preferLongest) + } else { + return list, idx + } + } else { + return list, idx + } + } if state.isKleene || state.isQuestion { copyThread(state.splitState, state) - list = addStateToList(str, idx, list, *state.splitState, threadGroups, visited, preferLongest) + list, newIdx := addStateToList(str, idx, list, *state.splitState, threadGroups, visited, preferLongest) copyThread(state.next, state) - list = addStateToList(str, idx, list, *state.next, threadGroups, visited, preferLongest) - return list + list, newIdx = addStateToList(str, newIdx, list, *state.next, threadGroups, visited, preferLongest) + return list, newIdx } if state.isAlternation { copyThread(state.next, state) - list = addStateToList(str, idx, list, *state.next, threadGroups, visited, preferLongest) + list, newIdx := addStateToList(str, idx, list, *state.next, threadGroups, visited, preferLongest) copyThread(state.splitState, state) - list = addStateToList(str, idx, list, *state.splitState, threadGroups, visited, preferLongest) - return list + list, newIdx = addStateToList(str, newIdx, list, *state.splitState, threadGroups, visited, preferLongest) + return list, newIdx } state.threadGroups = append([]Group{}, threadGroups...) if state.assert != noneAssert { @@ -257,13 +277,15 @@ func addStateToList(str []rune, idx int, list []nfaState, state nfaState, thread } if state.groupBegin { state.threadGroups[state.groupNum].StartIdx = idx + copyThread(state.next, state) return addStateToList(str, idx, list, *state.next, state.threadGroups, visited, preferLongest) } if state.groupEnd { state.threadGroups[state.groupNum].EndIdx = idx + copyThread(state.next, state) return addStateToList(str, idx, list, *state.next, state.threadGroups, visited, preferLongest) } - return append(list, state) + return append(list, state), idx } @@ -293,7 +315,7 @@ func findAllSubmatchHelper(start *nfaState, str []rune, offset int, numGroups in start.threadGroups = newMatch(numGroups + 1) start.threadGroups[0].StartIdx = i - currentStates = addStateToList(str, i, currentStates, *start, start.threadGroups, nil, preferLongest) + currentStates, _ = addStateToList(str, i, currentStates, *start, start.threadGroups, nil, preferLongest) // We can't go forward at the beginning, so I discard the second retval var match Match = nil for idx := i; idx <= len(str); idx++ { if len(currentStates) == 0 { @@ -315,7 +337,9 @@ func findAllSubmatchHelper(start *nfaState, str []rune, offset int, numGroups in } } else if !currentState.isAlternation && !currentState.isKleene && !currentState.isQuestion && !currentState.groupBegin && !currentState.groupEnd && currentState.assert == noneAssert { // Normal character if currentState.contentContains(str, idx, preferLongest) { - nextStates = addStateToList(str, idx+1, nextStates, *currentState.next, currentState.threadGroups, nil, preferLongest) + var newIdx int + nextStates, newIdx = addStateToList(str, idx+1, nextStates, *currentState.next, currentState.threadGroups, nil, preferLongest) + idx = newIdx - 1 } } } diff --git a/regex/nfa.go b/regex/nfa.go index c649712..8f454cf 100644 --- a/regex/nfa.go +++ b/regex/nfa.go @@ -45,8 +45,10 @@ type nfaState struct { groupEnd bool // Whether or not the node ends a capturing group groupNum int // Which capturing group the node starts / ends // The following properties depend on the current match - I should think about resetting them for every match. - zeroMatchFound bool // Whether or not the state has been used for a zero-length match - only relevant for zero states - threadGroups []Group // Assuming that a state is part of a 'thread' in the matching process, this array stores the indices of capturing groups in the current thread. As matches are found for this state, its groups will be copied over. + threadGroups []Group // Assuming that a state is part of a 'thread' in the matching process, this array stores the indices of capturing groups in the current thread. As matches are found for this state, its groups will be copied over. + isBackreference bool // Whether or not current node is backreference + referredGroup int // If current node is a backreference, the node that it points to + threadBackref int // If current node is a backreference, how many characters to look forward into the referred group } // Clones the NFA starting from the given state. @@ -76,7 +78,6 @@ func cloneStateHelper(stateToClone *nfaState, cloneMap map[*nfaState]*nfaState) isQuestion: stateToClone.isQuestion, isAlternation: stateToClone.isAlternation, assert: stateToClone.assert, - zeroMatchFound: stateToClone.zeroMatchFound, allChars: stateToClone.allChars, except: append([]rune{}, stateToClone.except...), lookaroundRegex: stateToClone.lookaroundRegex, @@ -122,6 +123,7 @@ func resetThreadsHelper(state *nfaState, visitedMap map[*nfaState]bool) { } // Assuming it hasn't been visited state.threadGroups = nil + state.threadBackref = 0 visitedMap[state] = true if state.isAlternation { resetThreadsHelper(state.next, visitedMap) @@ -428,7 +430,8 @@ func (s nfaState) equals(other nfaState) bool { s.groupBegin == other.groupBegin && s.groupEnd == other.groupEnd && s.groupNum == other.groupNum && - slices.Equal(s.threadGroups, other.threadGroups) + slices.Equal(s.threadGroups, other.threadGroups) && + s.threadBackref == other.threadBackref } func stateExists(list []nfaState, s nfaState) bool { diff --git a/regex/postfixNode.go b/regex/postfixNode.go index c60de47..88e4c62 100644 --- a/regex/postfixNode.go +++ b/regex/postfixNode.go @@ -1,6 +1,8 @@ package regex -import "fmt" +import ( + "fmt" +) type nodeType int @@ -20,6 +22,7 @@ const ( assertionNode lparenNode rparenNode + backreferenceNode ) // Helper constants for lookarounds @@ -31,15 +34,16 @@ const lookbehind = -1 var infinite_reps int = -1 // Represents infinite reps eg. the end range in {5,} // This represents a node in the postfix representation of the expression type postfixNode struct { - nodetype nodeType - contents []rune // Contents of the node - startReps int // Minimum number of times the node should be repeated - used with numeric specifiers - endReps int // Maximum number of times the node should be repeated - used with numeric specifiers - allChars bool // Whether or not the current node represents all characters (eg. dot metacharacter) - except []postfixNode // For inverted character classes, we match every unicode character _except_ a few. In this case, allChars is true and the exceptions are placed here. - lookaroundSign int // ONLY USED WHEN nodetype == ASSERTION. Whether we have a positive or negative lookaround. - lookaroundDir int // Lookbehind or lookahead - nodeContents []postfixNode // ONLY USED WHEN nodetype == CHARCLASS. Holds all the nodes inside the given CHARCLASS node. + nodetype nodeType + contents []rune // Contents of the node + startReps int // Minimum number of times the node should be repeated - used with numeric specifiers + endReps int // Maximum number of times the node should be repeated - used with numeric specifiers + allChars bool // Whether or not the current node represents all characters (eg. dot metacharacter) + except []postfixNode // For inverted character classes, we match every unicode character _except_ a few. In this case, allChars is true and the exceptions are placed here. + lookaroundSign int // ONLY USED WHEN nodetype == ASSERTION. Whether we have a positive or negative lookaround. + lookaroundDir int // Lookbehind or lookahead + nodeContents []postfixNode // ONLY USED WHEN nodetype == CHARCLASS. Holds all the nodes inside the given CHARCLASS node. + referencedGroup int // ONLY USED WHEN nodetype == backreferenceNode. Holds the group which this one refers to. After parsing is done, the expression will be rewritten eg. (a)\1 will become (a)(a). So the return value of ShuntingYard() shouldn't contain a backreferenceNode. } // Converts the given list of postfixNodes to one node of type CHARCLASS. @@ -208,3 +212,44 @@ func newPostfixCharNode(contents ...rune) postfixNode { toReturn.contents = append(toReturn.contents, contents...) return toReturn } + +// newPostfixBackreferenceNode creates and returns a backreference node, referring to the given group +func newPostfixBackreferenceNode(referred int) postfixNode { + toReturn := postfixNode{} + toReturn.startReps = 1 + toReturn.endReps = 1 + toReturn.nodetype = backreferenceNode + toReturn.referencedGroup = referred + return toReturn +} + +// rewriteBackreferences rewrites any backreferences in the given postfixNode slice, into their respective groups. +// It stores the relation in a map, and returns it as the second return value. +// It uses parenIndices to determine where a group starts and ends in nodes. +// For example, \1(a) will be rewritten into (a)(a), and 1 -> 2 will be the hashmap value. +// It returns an error if a backreference points to an invalid group. +// func rewriteBackreferences(nodes []postfixNode, parenIndices []Group) ([]postfixNode, map[int]int, error) { +// rtv := make([]postfixNode, 0) +// referMap := make(map[int]int) +// numGroups := 0 +// groupIncrement := 0 // If we have a backreference before the group its referring to, then the group its referring to will have its group number incremented. +// for i, node := range nodes { +// if node.nodetype == backreferenceNode { +// if node.referencedGroup >= len(parenIndices) { +// return nil, nil, fmt.Errorf("invalid backreference") +// } +// rtv = slices.Concat(rtv, nodes[parenIndices[node.referencedGroup].StartIdx:parenIndices[node.referencedGroup].EndIdx+1]) // Add all the nodes in the group to rtv +// numGroups += 1 +// if i < parenIndices[node.referencedGroup].StartIdx { +// groupIncrement += 1 +// } +// referMap[numGroups] = node.referencedGroup + groupIncrement +// } else { +// rtv = append(rtv, node) +// if node.nodetype == lparenNode { +// numGroups += 1 +// } +// } +// } +// return rtv, referMap, nil +// } diff --git a/regex/re_test.go b/regex/re_test.go index f697e81..05230a1 100644 --- a/regex/re_test.go +++ b/regex/re_test.go @@ -179,7 +179,7 @@ var reTests = []struct { {"[[:graph:]]+", nil, "abcdefghijklmnopqrstuvwyxzABCDEFGHIJKLMNOPRQSTUVWXYZ0123456789!@#$%^&*", []Group{{0, 70}}}, // Test cases from Python's RE test suite - {`[\1]`, nil, "\x01", []Group{{0, 1}}}, + {`[\01]`, nil, "\x01", []Group{{0, 1}}}, {`\0`, nil, "\x00", []Group{{0, 1}}}, {`[\0a]`, nil, "\x00", []Group{{0, 1}}}, @@ -194,7 +194,7 @@ var reTests = []struct { {`\x00ffffffffffffff`, nil, "\xff", []Group{}}, {`\x00f`, nil, "\x0f", []Group{}}, {`\x00fe`, nil, "\xfe", []Group{}}, - {`^\w+=(\\[\000-\277]|[^\n\\])*`, nil, "SRC=eval.c g.c blah blah blah \\\\\n\tapes.c", []Group{{0, 32}}}, + {`^\w+=(\\[\000-\0277]|[^\n\\])*`, nil, "SRC=eval.c g.c blah blah blah \\\\\n\tapes.c", []Group{{0, 32}}}, {`a.b`, nil, `acb`, []Group{{0, 3}}}, {`a.b`, nil, "a\nb", []Group{}}, @@ -312,7 +312,7 @@ var reTests = []struct { {`a[-]?c`, nil, `ac`, []Group{{0, 2}}}, {`^(.+)?B`, nil, `AB`, []Group{{0, 2}}}, {`\0009`, nil, "\x009", []Group{{0, 2}}}, - {`\141`, nil, "a", []Group{{0, 1}}}, + {`\0141`, nil, "a", []Group{{0, 1}}}, // At this point, the python test suite has a bunch // of backreference tests. Since my engine doesn't @@ -433,7 +433,7 @@ var reTests = []struct { {`a[-]?c`, []ReFlag{RE_CASE_INSENSITIVE}, `AC`, []Group{{0, 2}}}, {`^(.+)?B`, []ReFlag{RE_CASE_INSENSITIVE}, `ab`, []Group{{0, 2}}}, {`\0009`, []ReFlag{RE_CASE_INSENSITIVE}, "\x009", []Group{{0, 2}}}, - {`\141`, []ReFlag{RE_CASE_INSENSITIVE}, "A", []Group{{0, 1}}}, + {`\0141`, []ReFlag{RE_CASE_INSENSITIVE}, "A", []Group{{0, 1}}}, {`a[-]?c`, []ReFlag{RE_CASE_INSENSITIVE}, `AC`, []Group{{0, 2}}}, @@ -473,7 +473,7 @@ var reTests = []struct { {`[\t][\n][\v][\r][\f][\b]`, nil, "\t\n\v\r\f\b", []Group{{0, 6}}}, {`.*d`, nil, "abc\nabd", []Group{{4, 7}}}, {`(`, nil, "-", nil}, - {`[\41]`, nil, `!`, []Group{{0, 1}}}, + {`[\041]`, nil, `!`, []Group{{0, 1}}}, {`(?