package regex

import (
	"fmt"
	"slices"
)

// epsilon is the content stored in empty (epsilon) states, such as those built
// by kleene, alternate, question and zeroLengthMatchState. 0xF0000 is the first
// code point of Unicode's Supplementary Private Use Area-A, presumably chosen so
// it never collides with a real input character - TODO confirm.
const epsilon int = 0xF0000

// assertType identifies which zero-width assertion, if any, a state checks
// (see nfaState.checkAssertion).
type assertType int

// Assertion kinds. noneAssert means the state asserts nothing and matches
// characters normally.
const (
	noneAssert assertType = iota
	sosAssert             // Start of string, or just after a newline in multiline mode (^)
	soiAssert             // Start of input (\A), regardless of mode
	eosAssert             // End of string, or on a newline in multiline mode ($)
	eoiAssert             // End of input (\Z), regardless of mode
	// True where isWordBoundary reports a word boundary (presumably \b).
	wboundAssert
	// True where isWordBoundary reports no word boundary (presumably \B).
	nonwboundAssert
	plaAssert        // Positive lookahead
	nlaAssert        // Negative lookahead
	plbAssert        // Positive lookbehind
	nlbAssert        // Negative lookbehind
	alwaysTrueAssert // An assertion that is always true
)

// nfaState is a single state of the NFA built by concatenate, kleene,
// alternate and question. States are linked through next and, for
// alternation-style states, splitState. output lists the dangling exit states
// of the fragment that starts here; concatenate points their next at the
// fragment that follows.
type nfaState struct {
	content stateContents // Characters this state matches, stored as ints (epsilon for empty states)
	isEmpty bool          // Whether the state consumes no input (alternation, Kleene star, question and zero-length-match states)
	isLast  bool          // If it is the last state (accept state)
	output  []*nfaState   // Dangling exit states of the fragment starting here; concatenate sets their 'next'. Alternation and question states collect more than one.
	//	transitions                map[int][]*nfaState // Transitions to different states (maps a character (int representation) to a _list of states. This is useful if one character can lead multiple states eg. ab|aa)
	next                       *nfaState  // The following state. For states built by alternate this is the first branch; for kleene/question states it is the exit path, filled in by a later concatenate.
	isKleene                   bool       // Identifies whether current node is a 0-state representing Kleene star
	isQuestion                 bool       // Identifies whether current node is a 0-state representing the question operator
	isAlternation              bool       // Identifies whether current node is a 0-state representing an alternation (also set for Kleene star and question states)
	splitState                 *nfaState  // Only for alternation-style states - the second branch for alternate, or the repeated/optional fragment for kleene/question
	assert                     assertType // Type of assertion of current node - noneAssert means that the node doesn't assert anything
	allChars                   bool       // Whether or not the state represents all characters (eg. a 'dot' metacharacter). A 'dot' node doesn't store any contents directly, as it would take up too much space
	except                     []rune     // Only valid if allChars is true - match all characters _except_ the ones in this block (and in notDotChars). Useful for inverting character classes.
	lookaroundRegex            string     // Only for lookaround states - Contents of the regex that the lookaround state holds
	lookaroundNFA              *nfaState  // Holds the NFA of the lookaroundRegex - if it exists
	lookaroundNumCaptureGroups int        // Number of capturing groups in lookaround regex if current node is a lookaround
	groupBegin                 bool       // Whether or not the node starts a capturing group
	groupEnd                   bool       // Whether or not the node ends a capturing group
	groupNum                   int        // Which capturing group the node starts / ends
	// Per-match state: resetThreads clears threadGroups and threadBackref. isBackreference and referredGroup are not touched by resetThreads.
	threadGroups    []Group // Assuming that a state is part of a 'thread' in the matching process, this array stores the indices of capturing groups in the current thread. As matches are found for this state, its groups will be copied over.
	isBackreference bool    // Whether or not current node is backreference
	referredGroup   int     // If current node is a backreference, the group that it refers to
	threadBackref   int     // If current node is a backreference, how many characters to look forward into the referred group
}

// cloneState returns a deep copy of the NFA reachable from start. States that
// are shared or form cycles in the original are shared or cyclic in the copy
// as well. A nil start yields nil.
func cloneState(start *nfaState) *nfaState {
	seen := map[*nfaState]*nfaState{}
	return cloneStateHelper(start, seen)
}

// cloneStateHelper recursively deep-copies stateToClone and every state
// reachable from it through output, lookaroundNFA, splitState and next.
//
// cloneMap maps each original state to its copy. It stops infinite recursion
// on cyclic graphs (e.g. Kleene-star loops), and it ensures a state reached by
// several paths is copied only once. A nil stateToClone yields nil.
//
// Every structural field is copied. The per-match thread fields
// (threadGroups, threadBackref) are not copied, so the clone starts with them
// zeroed.
// This function was created using output from Llama3.1:405B.
func cloneStateHelper(stateToClone *nfaState, cloneMap map[*nfaState]*nfaState) *nfaState {
	if stateToClone == nil {
		return nil
	}
	// Already copied: reuse that copy so sharing and cycles are preserved.
	if clone, exists := cloneMap[stateToClone]; exists {
		return clone
	}
	clone := &nfaState{
		content:                    append([]int{}, stateToClone.content...),
		isEmpty:                    stateToClone.isEmpty,
		isLast:                     stateToClone.isLast,
		output:                     make([]*nfaState, len(stateToClone.output)),
		isKleene:                   stateToClone.isKleene,
		isQuestion:                 stateToClone.isQuestion,
		isAlternation:              stateToClone.isAlternation,
		assert:                     stateToClone.assert,
		allChars:                   stateToClone.allChars,
		except:                     append([]rune{}, stateToClone.except...),
		lookaroundRegex:            stateToClone.lookaroundRegex,
		lookaroundNumCaptureGroups: stateToClone.lookaroundNumCaptureGroups,
		groupBegin:                 stateToClone.groupBegin,
		groupEnd:                   stateToClone.groupEnd,
		groupNum:                   stateToClone.groupNum,
		isBackreference:            stateToClone.isBackreference,
		referredGroup:              stateToClone.referredGroup,
	}
	// Register the clone before recursing, so every path leading back to
	// stateToClone (including direct self-loops) resolves to this clone.
	cloneMap[stateToClone] = clone
	for i, s := range stateToClone.output {
		clone.output[i] = cloneStateHelper(s, cloneMap)
	}
	clone.lookaroundNFA = cloneStateHelper(stateToClone.lookaroundNFA, cloneMap)
	clone.splitState = cloneStateHelper(stateToClone.splitState, cloneMap)
	clone.next = cloneStateHelper(stateToClone.next, cloneMap)
	return clone
}

// resetThreads clears the per-match thread fields (threadGroups and
// threadBackref) of every state reachable from start, using a fresh visited
// set for the walk.
func resetThreads(start *nfaState) {
	resetThreadsHelper(start, map[*nfaState]bool{})
}

// resetThreadsHelper clears threadGroups and threadBackref on state and on
// every state reachable from it through next, and also through splitState
// for alternation states. lookaroundNFA sub-graphs are not traversed.
//
// visitedMap is updated in place. Any state already present in it (whatever
// its value) is skipped, so shared states and cycles are processed only once.
func resetThreadsHelper(state *nfaState, visitedMap map[*nfaState]bool) {
	stack := []*nfaState{state}
	for len(stack) > 0 {
		cur := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if cur == nil {
			continue
		}
		if _, seen := visitedMap[cur]; seen {
			continue
		}
		cur.threadGroups = nil
		cur.threadBackref = 0
		visitedMap[cur] = true
		// Push splitState before next so next is explored first, mirroring
		// a recursive depth-first walk.
		if cur.isAlternation {
			stack = append(stack, cur.splitState)
		}
		stack = append(stack, cur.next)
	}
}

// checkAssertion reports whether the zero-width assertion carried by s holds
// at position idx of str. noneAssert, and any value not handled below, is
// treated as always satisfied.
//
// Lookaround states run their pre-built NFA (lookaroundNFA) over part of str:
//   - lookahead (plaAssert, nlaAssert) searches str[idx:], and a match counts
//     only if it starts at 0, i.e. exactly at idx;
//   - lookbehind (plbAssert, nlbAssert) searches str[:idx], and a match counts
//     only if it ends at idx.
//
// A positive lookaround holds when at least one match counts. A negative one
// holds when none does. preferLongest is passed through to the Reg that runs
// the lookaround.
//
// NOTE(review): if FindAll returns only non-overlapping matches, a lookbehind
// match ending at idx can be masked by an earlier overlapping match (e.g. "aa"
// behind idx 3 of "aaa") - confirm against FindAll's semantics.
func (s nfaState) checkAssertion(str []rune, idx int, preferLongest bool) bool {
	switch s.assert {
	case alwaysTrueAssert:
		return true
	case sosAssert:
		// Start of the string or, in multiline mode, just after a newline.
		return idx == 0 || (multilineMode && idx > 0 && str[idx-1] == '\n')
	case eosAssert:
		// End of the string or, in multiline mode, on a newline character.
		return idx == len(str) || (multilineMode && str[idx] == '\n')
	case soiAssert:
		// Start of input, whatever the mode.
		return idx == 0
	case eoiAssert:
		// End of input, whatever the mode.
		return idx == len(str)
	case wboundAssert:
		return isWordBoundary(str, idx)
	case nonwboundAssert:
		return !isWordBoundary(str, idx)
	case plaAssert, nlaAssert, plbAssert, nlbAssert:
		ahead := s.assert == plaAssert || s.assert == nlaAssert
		var window []rune
		if ahead {
			window = str[idx:]
		} else {
			window = str[:idx]
		}
		reg := Reg{s.lookaroundNFA, s.lookaroundNumCaptureGroups, s.lookaroundRegex, preferLongest}
		found := false
		for _, m := range reg.FindAll(string(window)) {
			if (ahead && m.StartIdx == 0) || (!ahead && m.EndIdx == idx) {
				found = true
				break
			}
		}
		if s.assert == plaAssert || s.assert == plbAssert {
			return found
		}
		return !found
	}
	return true
}

// contentContains reports whether state s accepts the rune at str[idx].
//
// Assertion states consume no input, so they are delegated to checkAssertion,
// which may be evaluated at idx == len(str). For every other state, an index
// at or past the end of str never matches.
//
// A 'dot' state (allChars) matches any rune that is in neither notDotChars
// nor the state's own except list. Any other state matches only the runes
// listed in its content.
func (s nfaState) contentContains(str []rune, idx int, preferLongest bool) bool {
	if s.assert != noneAssert {
		return s.checkAssertion(str, idx, preferLongest)
	}
	if idx >= len(str) {
		return false
	}
	r := str[idx]
	if s.allChars {
		// Check both exclusion lists directly. Concatenating them would
		// allocate a new slice on every call in this per-character hot path.
		return !slices.Contains(notDotChars, r) && !slices.Contains(s.except, r)
	}
	return slices.Contains(s.content, int(r))
}

// isLookaround reports whether s is a lookahead or lookbehind state, either
// positive or negative.
func (s nfaState) isLookaround() bool {
	switch s.assert {
	case plaAssert, nlaAssert, plbAssert, nlbAssert:
		return true
	default:
		return false
	}
}

// numTransitions returns how many of s.next and s.splitState are set (0, 1 or 2).
func (s nfaState) numTransitions() int {
	count := 0
	if s.next != nil {
		count++
	}
	if s.splitState != nil {
		count++
	}
	return count
}

// Returns the matches for the character at the given index of the given string.
// Also returns the number of matches. Returns -1 if an assertion failed.
//func (s nfaState) matchesFor(str []rune, idx int) ([]*nfaState, int) {
//	// Assertions can be viewed as 'checks'. If the check fails, we return
//	// an empty array and 0.
//	// If it passes, we treat it like any other state, and return all the transitions.
//	if s.assert != noneAssert {
//		if s.checkAssertion(str, idx) == false {
//			return make([]*nfaState, 0), -1
//		}
//	}
//	listTransitions := s.transitions[int(str[idx])]
//	for _, dest := range s.transitions[int(anyCharRune)] {
//		if !slices.Contains(slices.Concat(notDotChars, dest.except), str[idx]) {
//			// Add an allChar state to the list of matches if:
//			// 		a. The current character isn't a 'notDotChars' character. In single line mode, this includes newline. In multiline mode, it doesn't.
//			// 		b. The current character isn't the state's exception list.
//			listTransitions = append(listTransitions, dest)
//		}
//	}
//	numTransitions := len(listTransitions)
//	return listTransitions, numTransitions
//}

// verifyLastStatesHelper performs the depth-first recursion needed for verifyLastStates
//func verifyLastStatesHelper(st *nfaState, visited map[*nfaState]bool) {
//	if st.numTransitions() == 0 {
//		st.isLast = true
//		return
//	}
//	//	if len(state.transitions) == 1 && len(state.transitions[state.content]) == 1 && state.transitions[state.content][0] == state { // Eg. a*
//	if st.numTransitions() == 1 { // Eg. a*
//		var moreThanOneTrans bool // Dummy variable, check if all the transitions for the current's state's contents have a length of one
//		for _, c := range st.content {
//			if len(st.transitions[c]) != 1 || st.transitions[c][0] != st {
//				moreThanOneTrans = true
//			}
//		}
//		st.isLast = !moreThanOneTrans
//	}
//
//	if st.isKleene { // A State representing a Kleene Star has transitions going out, which loop back to it. If all those transitions point to the same (single) state, then it must be a last state
//		transitionDests := make([]*nfaState, 0)
//		for _, v := range st.transitions {
//			transitionDests = append(transitionDests, v...)
//		}
//		if allEqual(transitionDests...) {
//			st.isLast = true
//			return
//		}
//	}
//	if visited[st] == true {
//		return
//	}
//	visited[st] = true
//	for _, states := range st.transitions {
//		for i := range states {
//			if states[i] != st {
//				verifyLastStatesHelper(states[i], visited)
//			}
//		}
//	}
//}

// verifyLastStates enables the 'isLast' flag for the leaf nodes (last states)
//func verifyLastStates(start []*nfaState) {
//	verifyLastStatesHelper(start[0], make(map[*nfaState]bool))
//}

// concatenate joins two NFA fragments so that s2 follows s1, and returns the
// start of the combined fragment.
//
// Every dangling state in s1.output gets s2 as its next state. s1 then takes
// over s2's output list, so the combined fragment ends where s2 ended. s1 is
// modified in place.
//
// A nil fragment is treated as empty: if s1 is nil, s2 is returned; if s2 is
// nil, s1 is returned unchanged.
func concatenate(s1 *nfaState, s2 *nfaState) *nfaState {
	if s1 == nil {
		return s2
	}
	if s2 == nil {
		return s1
	}
	for _, exit := range s1.output {
		exit.next = s2
	}
	s1.output = s2.output
	return s1
}

// kleene wraps the fragment s1 in a Kleene star (s1*) and returns the new
// entry state.
//
// The entry is an empty alternation state whose splitState enters s1. Every
// dangling state of s1 loops back to the entry. The entry is also the only
// dangling state of the result, so a later concatenate sets its next to
// whatever follows the star.
//
// An error is returned if s1 is an empty assertion state, since those cannot
// be quantified.
func kleene(s1 *nfaState) (*nfaState, error) {
	if s1.isEmpty && s1.assert != noneAssert {
		return nil, fmt.Errorf("previous token is not quantifiable")
	}
	star := &nfaState{
		content:       newContents(epsilon),
		isEmpty:       true,
		isAlternation: true,
		isKleene:      true,
		splitState:    s1,
	}
	star.output = []*nfaState{star}
	for _, exit := range s1.output {
		exit.next = star
	}
	return star, nil
}

// alternate builds the alternation s1|s2 and returns its entry state: an
// empty alternation state whose next is s1 and whose splitState is s2.
//
// The entry's output collects the dangling states of both branches (s1's
// first, then s2's), so a later concatenate attaches the following fragment
// to the end of either branch.
func alternate(s1 *nfaState, s2 *nfaState) *nfaState {
	var exits []*nfaState
	exits = append(exits, s1.output...)
	exits = append(exits, s2.output...)
	return &nfaState{
		content:       newContents(epsilon),
		isEmpty:       true,
		isAlternation: true,
		next:          s1,
		splitState:    s2,
		output:        exits,
	}
}

// question builds s1? (zero or one occurrence of s1) and returns its entry
// state, using the identity ab? == a(b|).
//
// The entry is an empty alternation state whose splitState enters s1. Its
// next is left for a later concatenate to fill in (the "skip" path). The
// result's dangling states are the entry itself followed by s1's dangling
// states, so both the skip path and the path through s1 continue to whatever
// follows.
//
// An error is returned if s1 is an empty assertion state, since those cannot
// be quantified.
func question(s1 *nfaState) (*nfaState, error) {
	if s1.isEmpty && s1.assert != noneAssert {
		return nil, fmt.Errorf("previous token is not quantifiable")
	}
	opt := &nfaState{
		content:       newContents(epsilon),
		isEmpty:       true,
		isAlternation: true,
		isQuestion:    true,
		splitState:    s1,
	}
	opt.output = append([]*nfaState{opt}, s1.output...)
	return opt, nil
}

// newState returns an nfaState with default values: no assertion
// (noneAssert), an except list holding only the rune 0, no capturing-group
// markers, and an output list holding a single pointer.
//
// NOTE(review): the pointer placed in output refers to this function's local
// variable, not to the value the caller receives (a copy), so the returned
// state's output does not point to itself. Callers that need a
// self-referencing output should reset it after taking the state's address.
func newState() nfaState {
	var st nfaState
	st.assert = noneAssert
	st.except = append([]rune{}, 0)
	st.output = []*nfaState{&st}
	return st
}

// zeroLengthMatchState returns a standalone empty state whose assertion
// (alwaysTrueAssert) always passes, so it matches the empty string at any
// position. Its output list contains only itself, so it behaves like any
// other single-state fragment under concatenate.
func zeroLengthMatchState() *nfaState {
	st := &nfaState{
		content: newContents(epsilon),
		isEmpty: true,
		assert:  alwaysTrueAssert,
	}
	st.output = []*nfaState{st}
	return st
}

// equals reports whether s and other agree on every field.
//
// Slice fields are compared element-wise. Pointer fields (next, splitState,
// lookaroundNFA and the entries of output) are compared by identity, not by
// the contents of the states they point to. Both the structural fields and
// the per-match thread fields (threadGroups, threadBackref) take part, so two
// states differing only in, say, the group a backreference refers to are not
// considered equal.
func (s nfaState) equals(other nfaState) bool {
	return s.isEmpty == other.isEmpty &&
		s.isLast == other.isLast &&
		slices.Equal(s.output, other.output) &&
		slices.Equal(s.content, other.content) &&
		s.next == other.next &&
		s.isKleene == other.isKleene &&
		s.isQuestion == other.isQuestion &&
		s.isAlternation == other.isAlternation &&
		s.splitState == other.splitState &&
		s.assert == other.assert &&
		s.allChars == other.allChars &&
		slices.Equal(s.except, other.except) &&
		s.lookaroundRegex == other.lookaroundRegex &&
		s.lookaroundNFA == other.lookaroundNFA &&
		s.lookaroundNumCaptureGroups == other.lookaroundNumCaptureGroups &&
		s.groupBegin == other.groupBegin &&
		s.groupEnd == other.groupEnd &&
		s.groupNum == other.groupNum &&
		s.isBackreference == other.isBackreference &&
		s.referredGroup == other.referredGroup &&
		slices.Equal(s.threadGroups, other.threadGroups) &&
		s.threadBackref == other.threadBackref
}

// stateExists reports whether list contains a state that equals s, as
// defined by nfaState.equals.
func stateExists(list []nfaState, s nfaState) bool {
	return slices.ContainsFunc(list, func(candidate nfaState) bool {
		return candidate.equals(s)
	})
}