package regex

import (
	"fmt"
	"sort"
)

// Match holds every group of a single regex match. Index 0 is the 0-group, the
// span of the entire match without submatches; index k (k >= 1) is capturing
// group k. A group that did not take part in the match is invalid
// (see [Group.IsValid]).
type Match []Group

// Group is the span matched by one capturing group, or by the whole match for
// the 0-group. Indices are rune offsets into the searched string and describe
// the half-open interval [StartIdx, EndIdx). A group that is unset has negative
// indices (newMatch initializes both to -1).
type Group struct {
	StartIdx int // Offset of the first rune in the group, or -1 if unset.
	EndIdx   int // Offset one past the last rune in the group, or -1 if unset.
}

// newMatch allocates a Match with size groups, each one marked unset by
// having both of its indices equal to -1.
func newMatch(size int) Match {
	m := make(Match, size)
	for i := range m {
		m[i] = Group{StartIdx: -1, EndIdx: -1}
	}
	return m
}

// numValidGroups counts the groups in m for which IsValid reports true,
// i.e. those whose start and end indices are both non-negative.
func (m Match) numValidGroups() int {
	count := 0
	for i := range m {
		if m[i].IsValid() {
			count++
		}
	}
	return count
}

// String implements fmt.Stringer. For each valid group k it emits a line
// "Group k" followed by a line "<StartIdx>\t<EndIdx>"; invalid groups are
// skipped. The result is empty when no group is valid.
func (m Match) String() string {
	// Append into a byte slice so building the result is linear in its length,
	// rather than quadratic as with repeated string +=.
	var buf []byte
	for i, g := range m {
		if !g.IsValid() {
			continue
		}
		buf = append(buf, fmt.Sprintf("Group %d\n", i)...)
		buf = append(buf, g.toString()...)
		buf = append(buf, '\n')
	}
	return string(buf)
}

// toString renders g as its start and end indices separated by a tab,
// e.g. "3\t7".
func (g Group) toString() string {
	return fmt.Sprint(g.StartIdx) + "\t" + fmt.Sprint(g.EndIdx)
}

// IsValid reports whether g was set during matching, that is, whether
// neither of its indices is negative. A valid group may still span zero runes.
func (g Group) IsValid() bool {
	return !(g.StartIdx < 0 || g.EndIdx < 0)
}

// getZeroGroup returns the 0-group of m: the span of the entire match without
// its submatches. It is a named function so it can be passed to funcMap.
// m must hold at least one group.
func getZeroGroup(m Match) Group {
	whole := m[0]
	return whole
}

// takeZeroState follows one step of epsilon (0-state) transitions from every
// state in states and returns all states reached, in order. isZero is true
// when at least one reached state has epsilon transitions of its own, meaning
// another step could be taken.
//
// Side effect: every reached state gets its threadGroups allocated (if nil) and
// then overwritten by copying its predecessor's threadGroups into it. If the
// reached state begins or ends capturing group groupNum, that group's StartIdx
// or EndIdx is set to idx. When several predecessors lead to the same state,
// the one processed last determines its indices.
func takeZeroState(states []*nfaState, numGroups int, idx int) (rtv []*nfaState, isZero bool) {
	for _, src := range states {
		targets := src.transitions[epsilon]
		for _, dst := range targets {
			if dst.threadGroups == nil {
				dst.threadGroups = newMatch(numGroups + 1)
			}
			// The target continues the source's thread, so it inherits its captures.
			copy(dst.threadGroups, src.threadGroups)
			if dst.groupBegin {
				dst.threadGroups[dst.groupNum].StartIdx = idx
			}
			if dst.groupEnd {
				dst.threadGroups[dst.groupNum].EndIdx = idx
			}
		}
		rtv = append(rtv, targets...)
	}
	for _, dst := range rtv {
		if len(dst.transitions[epsilon]) > 0 {
			isZero = true
			break
		}
	}
	return rtv, isZero
}

// zeroMatchPossible reports whether a zero-length match can complete at rune
// position idx of str from any of the given states. It builds the epsilon
// closure of states by calling takeZeroState repeatedly until no new unique
// states appear. It then looks in that closure for a state that is empty,
// whose assertion (if any) holds at idx, and that is a last-state.
// As with takeZeroState, the threadGroups of reached states are updated.
func zeroMatchPossible(str []rune, idx int, numGroups int, states ...*nfaState) bool {
	reached, more := takeZeroState(states, numGroups, idx)
	closure := make([]*nfaState, 0, len(reached)+len(states))
	closure = append(append(closure, states...), reached...)
	for more {
		var added int
		reached, more = takeZeroState(closure, numGroups, idx)
		closure, added = unique_append(closure, reached...)
		if added == 0 {
			// No new unique states: the closure is complete.
			break
		}
	}
	for _, s := range closure {
		if !s.isEmpty {
			continue
		}
		if s.assert != noneAssert && !s.checkAssertion(str, idx) {
			continue
		}
		if s.isLast {
			return true
		}
	}
	return false
}

// pruneIndices removes overlapping matches, comparing only their 0-groups.
// indices is sorted in place by start index and then scanned once. A match
// starting at or after the end of the currently kept match is non-overlapping
// and is kept as well. An overlapping match replaces the kept one only if it
// ends later; on equal ends the earlier match is kept.
// An empty input is returned unchanged.
func pruneIndices(indices []Match) []Match {
	if len(indices) == 0 {
		return indices
	}
	// A stable sort keeps matches that share a start index in the order they
	// were discovered, so the result is deterministic.
	sort.SliceStable(indices, func(i, j int) bool {
		return indices[i][0].StartIdx < indices[j][0].StartIdx
	})
	toRet := make([]Match, 0, len(indices))
	current := indices[0]
	for _, m := range indices[1:] {
		switch {
		case m[0].StartIdx >= current[0].EndIdx:
			// No overlap: current is final, and m becomes the new candidate.
			toRet = append(toRet, current)
			current = m
		case m[0].EndIdx > current[0].EndIdx:
			// Overlap, but m extends further, so it replaces current.
			current = m
		}
	}
	// The last candidate is never followed by a non-overlapping match, so add it here.
	return append(toRet, current)
}

// Find returns the 0-group (the span of the entire match) of the leftmost
// match of the regex in str. A non-nil error means the regex matched nowhere.
func (regex Reg) Find(str string) (Group, error) {
	m, err := regex.FindNthMatch(str, 1)
	if err == nil {
		return getZeroGroup(m), nil
	}
	return Group{}, fmt.Errorf("no matches found")
}

// FindAll returns the 0-group of every match reported by FindAllSubmatch.
// A 0-group is the span of a whole match, ignoring its submatches.
func (regex Reg) FindAll(str string) []Group {
	return funcMap(regex.FindAllSubmatch(str), getZeroGroup)
}

// FindString returns the text of the leftmost match of the regex in the given string.
// The return value will be an empty string in two situations:
//  1. No match was found
//  2. The match was an empty string
func (regex Reg) FindString(str string) string {
	match, err := regex.FindNthMatch(str, 1)
	if err != nil {
		return ""
	}
	zeroGroup := getZeroGroup(match)
	// Group indices are rune offsets (FindNthMatch searches []rune(str)), so the
	// text must be sliced from the rune form. Byte-slicing str would return the
	// wrong text, or split a UTF-8 sequence, for non-ASCII input.
	return string([]rune(str)[zeroGroup.StartIdx:zeroGroup.EndIdx])
}

// FindSubmatch returns the leftmost match of the regex in the given string,
// including the submatches recorded by capturing groups. The returned [Match]
// always has the same number of groups for a given regex. Whether a particular
// group matched anything can be determined with [Group.IsValid], or by checking
// that both of its indices are >= 0.
// The error is nil when a match was found, and non-nil when there was none.
func (regex Reg) FindSubmatch(str string) (Match, error) {
	match, err := regex.FindNthMatch(str, 1)
	if err != nil {
		return Match{}, fmt.Errorf("no match found")
	}
	return match, nil
}

// FindAllString is the 'all' version of FindString.
// It returns a slice of strings containing the text of all matches of
// the regex in the given string.
func (regex Reg) FindAllString(str string) []string {
	// Group indices are rune offsets, so the text is sliced from the rune form
	// of str. Converting once here avoids re-converting for every match.
	runes := []rune(str)
	zerogroups := regex.FindAll(str)
	return funcMap(zerogroups, func(g Group) string {
		return string(runes[g.StartIdx:g.EndIdx])
	})
}

// FindNthMatch returns the n'th match (counting from 1) of the regex in the
// given string, scanning left to right. It returns a non-nil error if n < 1 or
// if the string holds fewer than n matches.
func (regex Reg) FindNthMatch(str string, n int) (Match, error) {
	// Without this guard, n == 0 would return the empty placeholder Match with a
	// nil error as soon as one search attempt failed, and a caller indexing
	// match[0] would then panic.
	if n < 1 {
		return nil, fmt.Errorf("invalid match index - n must be at least 1")
	}
	runes := []rune(str)
	matchNum := 0
	for idx := 0; idx <= len(runes); {
		found, match, next := findAllSubmatchHelper(regex.start, runes, idx, regex.numGroups)
		if found {
			matchNum++
			if matchNum == n {
				return match, nil
			}
		}
		idx = next
	}
	// We haven't found the nth match after scanning the whole string.
	return nil, fmt.Errorf("invalid match index - too few matches found")
}

// FindAllSubmatch returns all matches of the regex in str, including their
// capture-group submatches. Overlapping matches are pruned and the result is
// ordered by start index. If nothing matched, the result is an empty, non-nil
// slice.
func (regex Reg) FindAllSubmatch(str string) []Match {
	runes := []rune(str)
	matches := []Match{}
	for offset := 0; offset <= len(runes); {
		found, m, next := findAllSubmatchHelper(regex.start, runes, offset, regex.numGroups)
		if found {
			matches = append(matches, m)
		}
		offset = next
	}
	if len(matches) == 0 {
		return matches
	}
	return pruneIndices(matches)
}

// findAllSubmatchHelper is the matching engine behind FindNthMatch and
// FindAllSubmatch. Starting at rune offset offset, it simulates the NFA that
// begins at start over str and reports the first match it finds.
//
// It returns three values:
//   - whether a match was found;
//   - when one was, a Match with numGroups+1 groups (group 0 is the whole match,
//     groups 1..numGroups are the capturing groups); otherwise an empty Match
//     that the caller should ignore;
//   - the offset at which the caller's next search should start. On every path
//     except the offset > len(str) base case this is greater than offset, which
//     is what lets the callers' scanning loops terminate. After a zero-length
//     match it is one past the match's end.
//
// Capture indices travel on the states themselves (nfaState.threadGroups).
// This function and takeZeroState overwrite them as the simulation proceeds.
//
// Results from successive calls might be duplicates or overlap, so a caller
// that collects several matches must prune them (FindAllSubmatch uses
// pruneIndices).
func findAllSubmatchHelper(start *nfaState, str []rune, offset int, numGroups int) (bool, Match, int) {
	// Base case - exit if offset exceeds string's length
	if offset > len(str) {
		// The second value here shouldn't be used, because we should exit when the third return value is > than len(str)
		return false, []Group{}, offset
	}

	// tempIndices holds the match most recently recorded during this run: group 0
	// spans the whole match, groups 1..numGroups are the capturing groups. It is
	// overwritten whenever a last-state is reached. Because i only grows, the final
	// recording is the one that extends furthest, which gives greedy matching.
	tempIndices := newMatch(numGroups + 1)

	foundPath := false
	startIdx := offset
	endIdx := offset
	currentStates := make([]*nfaState, 0)
	tempStates := make([]*nfaState, 0) // Used to store states that should be used in next loop iteration
	i := offset                        // Index in string
	startingFrom := i                  // Store starting index

	// If the start state is an assertion, it must hold at offset before anything
	// else is tried. Otherwise report no match and have the caller retry one rune later.
	if start.assert != noneAssert {
		if start.checkAssertion(str, offset) == false {
			i++
			return false, []Group{}, i
		}
	}
	// If the start state consumes a character (is not a 0-state), skip ahead to the
	// first rune it matches; the match begins there. If no rune matches, i ends at
	// len(str), and after the increment below it is len(str)+1.
	if start.isEmpty == false {
		for i < len(str) && !start.contentContains(str, i) {
			i++
		}
		startIdx = i
		startingFrom = i
		i++ // Advance to next character (if we aren't at a 0-state, which doesn't match anything), so that we can check for transitions. If we advance at a 0-state, we will never get a chance to match the first character
	}

	start.threadGroups = newMatch(numGroups + 1)
	// Check if the start state begins a group - if so, add the start index to our list.
	// NOTE(review): when start is not a 0-state, i was already incremented past the
	// first matched rune above, so this records startIdx+1 as the group start.
	// Confirm this is intended.
	if start.groupBegin {
		start.threadGroups[start.groupNum].StartIdx = i
	}

	currentStates = append(currentStates, start)

	// Main loop: consume one rune of str per iteration while input remains.
	for i < len(str) {
		foundPath = false

		// NOTE(review): this allocation is discarded; the next statement reassigns zeroStates.
		zeroStates := make([]*nfaState, 0)
		// Expand currentStates with their epsilon closure: keep taking 0-state
		// transitions until none are left or no new unique states are produced.
		zeroStates, isZero := takeZeroState(currentStates, numGroups, i)
		tempStates = append(tempStates, zeroStates...)
		num_appended := 0
		for isZero == true {
			zeroStates, isZero = takeZeroState(tempStates, numGroups, i)
			tempStates, num_appended = unique_append(tempStates, zeroStates...)
			if num_appended == 0 { // Break if we haven't appended any more unique values
				break
			}
		}

		currentStates, _ = unique_append(currentStates, tempStates...)
		tempStates = nil

		// Step every current state on the rune at i. matchesFor returns the
		// successor states and a count: a positive count means the state matched,
		// and its successors are queued in tempStates and inherit its capture
		// indices. A negative count is treated as a failed assertion.
		numStatesMatched := 0            // The number of states which had at least 1 match for this round
		assertionFailed := false         // Whether or not an assertion failed for this round
		lastStateInList := false         // Whether or not a last state was in our list of states
		var lastStatePtr *nfaState = nil // Pointer to the last-state, if it was found
		lastLookaroundInList := false    // Whether or not a last state (that is a lookaround) was in our list of states
		for _, state := range currentStates {
			matches, numMatches := state.matchesFor(str, i)
			if numMatches > 0 {
				numStatesMatched++
				tempStates = append(tempStates, matches...)
				foundPath = true
				for _, m := range matches {
					if m.threadGroups == nil {
						m.threadGroups = newMatch(numGroups + 1)
					}
					copy(m.threadGroups, state.threadGroups)
				}
			}
			if numMatches < 0 {
				assertionFailed = true
			}
			if state.isLast {
				if state.isLookaround() {
					lastLookaroundInList = true
				}
				lastStateInList = true
				lastStatePtr = state
			}
		}

		if assertionFailed && numStatesMatched == 0 { // Nothing has matched and an assertion has failed
			// The original author notes that it is unclear why this must check
			// specifically for a _lookaround_ last-state: replacing
			// 'lastLookaroundInList' with 'lastStateInList' makes a test case fail.
			//
			// Their explanation: if one of the states was a last state and a
			// lookaround, we don't abort on the assertion failure, because another
			// path to a final state exists. 'break' leaves the main loop and falls
			// through to the end-of-input checks below, which may still record a
			// match at i.
			// Otherwise give up on this run, making sure the caller's next search
			// starts past startingFrom.
			if lastLookaroundInList {
				break
			} else {
				if i == startingFrom {
					i++
				}
				return false, []Group{}, i
			}
		}
		// Check if we can find a state in our list that is:
		// 	a. A last-state
		// 	b. Empty
		// 	c. Doesn't assert anything
		// Such a state takes precedence as lastStatePtr.
		for _, s := range currentStates {
			if s.isLast && s.isEmpty && s.assert == noneAssert {
				lastStatePtr = s
				lastStateInList = true
			}
		}
		if lastStateInList { // A last-state was reached: record a match ending at i, with capture groups from that state's thread
			for j := 1; j < numGroups+1; j++ {
				tempIndices[j] = lastStatePtr.threadGroups[j]
			}
			endIdx = i
			tempIndices[0] = Group{startIdx, endIdx}
		}

		// No state consumed the rune at i, so this run ends here. If nothing was
		// recorded yet but a zero-length match is possible, record one at startIdx.
		if foundPath == false {
			if ok := zeroMatchPossible(str, i, numGroups, currentStates...); ok {
				if tempIndices[0].IsValid() == false {
					tempIndices[0] = Group{startIdx, startIdx}
				}
			}
			// Advance startIdx so that, if no match was recorded, the caller's next
			// search starts one rune later instead of retrying the same position.
			startIdx++
			// (numValidGroups() > 0 is implied by tempIndices[0].IsValid().)
			if tempIndices.numValidGroups() > 0 && tempIndices[0].IsValid() {
				if tempIndices[0].StartIdx == tempIndices[0].EndIdx { // If we have a zero-length match, we have to shift the index at which we start. Otherwise we keep looking at the same part of the string over and over.
					return true, tempIndices, tempIndices[0].EndIdx + 1
				} else {
					return true, tempIndices, tempIndices[0].EndIdx
				}
			}
			return false, []Group{}, startIdx
		}
		// The successors collected this round become the current states for the next rune.
		currentStates = make([]*nfaState, len(tempStates))
		copy(currentStates, tempStates)
		tempStates = nil

		i++
	}

	// End of input reached, or the loop was left early via the lookaround 'break'.
	// Take the epsilon closure once more, then check whether any state is a
	// last-state. TODO: this closure loop duplicates the one above and the one in
	// zeroMatchPossible.
	zeroStates, isZero := takeZeroState(currentStates, numGroups, i)
	tempStates = append(tempStates, zeroStates...)
	num_appended := 0 // Number of unique states added to tempStates
	for isZero == true {
		zeroStates, isZero = takeZeroState(tempStates, numGroups, i)
		tempStates, num_appended = unique_append(tempStates, zeroStates...)
		if num_appended == 0 { // Break if we haven't appended any more unique values
			break
		}
	}

	// Unlike the main loop, this is a plain append, so currentStates may contain duplicates.
	currentStates = append(currentStates, tempStates...)
	tempStates = nil

	for _, state := range currentStates {
		// Record a match for each last-state whose assertion (if any) holds at i;
		// later states overwrite earlier ones. i can be len(str)+1 when the start
		// state matched no rune at all (see the skip loop above), in which case
		// nothing is recorded.
		if state.isLast && i <= len(str) {
			if state.assert == noneAssert || state.checkAssertion(str, i) {
				for j := 1; j < numGroups+1; j++ {
					tempIndices[j] = state.threadGroups[j]
				}
				endIdx = i
				tempIndices[0] = Group{startIdx, endIdx}
			}
		}
	}

	// A match was recorded: return it, with the same zero-length adjustment as in the main loop.
	if tempIndices.numValidGroups() > 0 {
		if tempIndices[0].StartIdx == tempIndices[0].EndIdx { // If we have a zero-length match, we have to shift the index at which we start. Otherwise we keep looking at the same part of the string over and over.
			return true, tempIndices, tempIndices[0].EndIdx + 1
		} else {
			return true, tempIndices, tempIndices[0].EndIdx
		}
	}
	if startIdx == startingFrom { // Increment starting index if we haven't moved in the string. Prevents us from matching the same part of the string over and over.
		startIdx++
	}
	return false, []Group{}, startIdx
}