package main

import (
	"slices"
)

// EPSILON is the content value used for empty (epsilon) states: kleene,
// alternate and question all give their empty state newContents(EPSILON) as
// its content, so epsilon transitions are stored under this key.
// NOTE(review): transitions are looked up by int(character) during matching,
// so an input character U+0000 maps to the same key as EPSILON — confirm NUL
// input is excluded or handled elsewhere.
const EPSILON int = 0

// assertType identifies which zero-width assertion, if any, a State checks
// (see State.checkAssertion).
type assertType int

const (
	// NONE: the state has no assertion.
	NONE assertType = iota
	// SOS: start of string; holds only at index 0.
	SOS
	// EOS: end of string; holds at len(str), or just before a trailing '\n'.
	EOS
	// WBOUND: holds where isWordBoundary reports a word boundary.
	WBOUND
	// NONWBOUND: holds where isWordBoundary reports no word boundary.
	NONWBOUND
	// Lookaheads run the lookaround NFA on str[idx:] and look for a match starting at idx.
	PLA // Positive lookahead
	NLA // Negative lookahead
	// Lookbehinds run the lookaround NFA on str[:idx] and look for a match ending at idx.
	PLB // Positive lookbehind
	NLB // Negative lookbehind
)

// State is a single node of the regex NFA. Besides its own contents and
// outgoing transitions, it carries the flags the matcher consults:
// assertions, dot / negated-class handling, lookaround sub-NFAs and
// capturing-group boundaries.
type State struct {
	content                    stateContents    // Contents of current state
	isEmpty                    bool             // If it is empty - Union operator and Kleene star states will be empty
	isLast                     bool             // If it is the last state (accept state)
	output                     []*State         // The outputs of the current state ie. the 'outward arrows'. A union operator state will have more than one of these.
	transitions                map[int][]*State // Transitions to different states (maps a character (int representation) to a list of states. This is useful if one character can lead to multiple states eg. ab|aa)
	isKleene                   bool             // Identifies whether current node is a 0-state representing Kleene star
	assert                     assertType       // Type of assertion of current node - NONE means that the node doesn't assert anything
	allChars                   bool             // Whether or not the state represents all characters (eg. a 'dot' metacharacter). A 'dot' node doesn't store any contents directly, as it would take up too much space
	except                     []rune           // Only valid if allChars is true - match all characters _except_ the ones in this block. Useful for inverting character classes.
	lookaroundRegex            string           // Only for lookaround states - Contents of the regex that the lookaround state holds
	lookaroundNFA              *State           // Holds the NFA of the lookaroundRegex - if it exists
	lookaroundNumCaptureGroups int              // Number of capturing groups in lookaround regex if current node is a lookaround
	groupBegin                 bool             // Whether or not the node starts a capturing group
	groupEnd                   bool             // Whether or not the node ends a capturing group
	groupNum                   int              // Which capturing group the node starts / ends
	// The following fields hold per-match state. NOTE(review): nothing in this file resets them between matches - consider doing so.
	zeroMatchFound bool    // Whether or not the state has been used for a zero-length match - only relevant for zero states
	threadGroups   []Group // Assuming that a state is part of a 'thread' in the matching process, this array stores the indices of capturing groups in the current thread. As matches are found for this state, its groups will be copied over.
}

// cloneState returns a deep copy of the NFA reachable from start. The memo
// map handed to cloneStateHelper ensures that states shared between paths,
// and cycles, are copied once and linked the same way in the copy.
func cloneState(start *State) *State {
	seen := map[*State]*State{}
	return cloneStateHelper(start, seen)
}

// cloneStateHelper is the recursive worker behind cloneState. cloneMap maps
// every original state that has already been copied to its clone. A state is
// registered in the map before its neighbours are visited, so shared states
// and cycles (a state whose outputs or transitions lead back to itself or to
// an ancestor) are copied exactly once. The clone reproduces the same graph
// shape.
//
// All structural fields are copied, including lookaroundNumCaptureGroups and
// the lookaround sub-NFA, which is cloned recursively with the same map.
// threadGroups is not copied; the clone starts without any. A nil state
// clones to nil.
// This function was created using output from Llama3.1:405B.
func cloneStateHelper(state *State, cloneMap map[*State]*State) *State {
	if state == nil {
		return nil
	}
	// Already copied (shared node or cycle): reuse the existing clone.
	if clone, exists := cloneMap[state]; exists {
		return clone
	}
	clone := &State{
		content:                    append([]int{}, state.content...),
		isEmpty:                    state.isEmpty,
		isLast:                     state.isLast,
		output:                     make([]*State, len(state.output)),
		transitions:                make(map[int][]*State),
		isKleene:                   state.isKleene,
		assert:                     state.assert,
		zeroMatchFound:             state.zeroMatchFound,
		allChars:                   state.allChars,
		except:                     append([]rune{}, state.except...),
		lookaroundRegex:            state.lookaroundRegex,
		lookaroundNumCaptureGroups: state.lookaroundNumCaptureGroups,
		groupEnd:                   state.groupEnd,
		groupBegin:                 state.groupBegin,
		groupNum:                   state.groupNum,
	}
	// Register before recursing; any path that leads back to state (including
	// a direct self-loop) then resolves to clone through the map.
	cloneMap[state] = clone
	for i, s := range state.output {
		clone.output[i] = cloneStateHelper(s, cloneMap)
	}
	for k, v := range state.transitions {
		dests := make([]*State, len(v))
		for i, s := range v {
			dests[i] = cloneStateHelper(s, cloneMap)
		}
		clone.transitions[k] = dests
	}
	clone.lookaroundNFA = cloneStateHelper(state.lookaroundNFA, cloneMap)
	return clone
}

// checkAssertion reports whether the assertion carried by s holds at position
// idx of str. A state with no assertion (or an unrecognised one) always
// passes.
//
// Lookaround states run their pre-built sub-NFA (lookaroundNFA) through
// findAllMatches:
//   - lookahead (PLA/NLA) runs on str[idx:] and counts matches starting at 0,
//     i.e. at idx in the full string;
//   - lookbehind (PLB/NLB) runs on str[:idx] and counts matches ending at idx.
//
// Positive variants need at least one such match; negative variants need none.
func (s State) checkAssertion(str []rune, idx int) bool {
	switch s.assert {
	case SOS:
		return idx == 0
	case EOS:
		// At the very end, or on a final trailing newline.
		if idx == len(str) {
			return true
		}
		return idx == len(str)-1 && str[len(str)-1] == '\n'
	case WBOUND:
		return isWordBoundary(str, idx)
	case NONWBOUND:
		return !isWordBoundary(str, idx)
	case PLA, NLA, PLB, NLB:
		ahead := s.assert == PLA || s.assert == NLA
		var subject []rune
		if ahead {
			subject = str[idx:]
		} else {
			subject = str[:idx]
		}
		// NOTE(review): if findAllMatches only returns non-overlapping matches
		// (TODO confirm), a lookbehind whose only match ending at idx overlaps an
		// earlier match (e.g. (?<=aa) at the end of "aaa") would be missed.
		hits := 0
		for _, m := range findAllMatches(Reg{s.lookaroundNFA, s.lookaroundNumCaptureGroups}, string(subject)) {
			if ahead {
				if m[0].startIdx == 0 {
					hits++
				}
			} else if m[0].endIdx == idx {
				hits++
			}
		}
		if s.assert == PLA || s.assert == PLB {
			return hits > 0
		}
		return hits == 0
	}
	return true
}

// contentContains reports whether state s accepts the character at str[idx].
//
//   - Assertion states delegate to checkAssertion (so idx may equal len(str),
//     e.g. for an end-of-string check).
//   - "All characters" (dot) states accept any character that is neither in
//     notDotChars nor in s.except.
//   - Otherwise the character must be one of s.content.
//
// For non-assertion states, an idx outside str yields false because there is
// no character to test.
func (s State) contentContains(str []rune, idx int) bool {
	if s.assert != NONE {
		return s.checkAssertion(str, idx)
	}
	if idx < 0 || idx >= len(str) {
		return false
	}
	ch := str[idx]
	if s.allChars {
		// Two lookups instead of slices.Concat: no merged slice is allocated
		// for every character tested.
		return !slices.Contains(notDotChars, ch) && !slices.Contains(s.except, ch)
	}
	return slices.Contains(s.content, int(ch))
}

// isLookaround reports whether s carries one of the four lookaround
// assertions: positive/negative lookahead or positive/negative lookbehind.
func (s State) isLookaround() bool {
	switch s.assert {
	case PLA, PLB, NLA, NLB:
		return true
	default:
		return false
	}
}

// matchesFor returns the states reachable from s by consuming the character
// at str[idx], together with how many there are.
//
// If s carries an assertion and that assertion fails at idx, it returns an
// empty slice and -1. Otherwise the result holds the transitions keyed by the
// character itself, followed by every ANY_CHAR (dot) destination that accepts
// the character. In single line mode notDotChars includes newline; in
// multiline mode it doesn't. A dot destination accepts the character if it is
// neither in notDotChars nor in the destination's except list. When idx is
// outside str (e.g. an end-of-string assertion that passed at
// idx == len(str)), there is no character to consume and it returns an empty
// slice and 0.
//
// The returned slice never shares spare capacity with s.transitions, so
// callers may keep it or append to it safely.
func (s State) matchesFor(str []rune, idx int) ([]*State, int) {
	// Assertions act as checks: on failure report -1; on success fall through
	// and treat s like any other state.
	if s.assert != NONE && !s.checkAssertion(str, idx) {
		return make([]*State, 0), -1
	}
	if idx < 0 || idx >= len(str) {
		return make([]*State, 0), 0
	}
	ch := str[idx]
	// Clip the stored slice so the appends below reallocate instead of writing
	// into its spare capacity, which could overwrite elements of slices
	// returned by earlier calls.
	listTransitions := slices.Clip(s.transitions[int(ch)])
	if !slices.Contains(notDotChars, ch) {
		for _, dest := range s.transitions[int(ANY_CHAR)] {
			if !slices.Contains(dest.except, ch) {
				listTransitions = append(listTransitions, dest)
			}
		}
	}
	return listTransitions, len(listTransitions)
}

// NFA pairs a start state with a list of output states.
// NOTE(review): NFA is not referenced anywhere in this chunk, and it stores
// State values (copies) rather than *State pointers like the rest of the NFA
// code - confirm whether it is still used.
type NFA struct {
	start   State
	outputs []State
}

// verifyLastStatesHelper performs the depth-first recursion needed for verifyLastStates.
// It sets state.isLast on states it considers last (accepting) states:
//   - a state with no transitions is always last, and recursion stops there;
//   - a state with exactly one transition key has isLast set to whether, for
//     every value c in its content, transitions[c] is exactly [state] (a
//     self-loop, eg. a*) - this may set isLast to false;
//   - a Kleene-star state whose transition destinations are all equal
//     (according to allEqual) is last, and recursion stops there.
//
// visited guards against infinite recursion through cycles. The isLast checks
// above run before the visited check, so a state reached again is
// re-evaluated, but its successors are not visited again. Self-loop
// transitions are never followed.
func verifyLastStatesHelper(state *State, visited map[*State]bool) {
	if len(state.transitions) == 0 {
		state.isLast = true
		return
	}
	// Earlier version of the check below, kept for reference:
	//	if len(state.transitions) == 1 && len(state.transitions[state.content]) == 1 && state.transitions[state.content][0] == state { // Eg. a*
	if len(state.transitions) == 1 { // Eg. a*
		var moreThanOneTrans bool // Set when some content value's transition list is not exactly [state]
		for _, c := range state.content {
			if len(state.transitions[c]) != 1 || state.transitions[c][0] != state {
				moreThanOneTrans = true
			}
		}
		state.isLast = !moreThanOneTrans
	}

	if state.isKleene { // A State representing a Kleene Star has transitions going out, which loop back to it. If all those transitions point to the same (single) state, then it must be a last state
		transitionDests := make([]*State, 0)
		for _, v := range state.transitions {
			transitionDests = append(transitionDests, v...)
		}
		if allEqual(transitionDests...) {
			state.isLast = true
			return
		}
	}
	if visited[state] == true {
		return
	}
	visited[state] = true
	// Recurse into every distinct successor, skipping self-loops.
	for _, states := range state.transitions {
		for i := range states {
			if states[i] != state {
				verifyLastStatesHelper(states[i], visited)
			}
		}
	}
}

// verifyLastStates enables the 'isLast' flag for the leaf (last) states of
// the NFA rooted at start[0]. Only start[0] is used as the traversal root.
// An empty slice or a nil first element is a no-op instead of a panic.
func verifyLastStates(start []*State) {
	if len(start) == 0 || start[0] == nil {
		return
	}
	verifyLastStatesHelper(start[0], make(map[*State]bool))
}

// concatenate links fragment s1 to fragment s2 (s1 followed by s2) and
// returns the start of the combined fragment.
//
// Every output state of s1 gets a transition to s2 under each value in
// s2.content, and the combined fragment's outputs become s2's outputs. s1 is
// modified in place and returned. If s1 is nil, s2 is returned; if s2 is nil,
// s1 is returned unchanged.
func concatenate(s1 *State, s2 *State) *State {
	if s1 == nil {
		return s2
	}
	if s2 == nil {
		return s1
	}
	for _, out := range s1.output {
		for _, c := range s2.content {
			// unique_append's second result is deliberately ignored.
			out.transitions[c], _ = unique_append(out.transitions[c], s2)
		}
	}
	s1.output = s2.output
	return s1
}

// kleene builds the Kleene-star fragment (s1*) and returns its entry state.
//
// The entry is a new empty node with content newContents(EPSILON), flagged
// isEmpty and isKleene, whose only output is itself. Every output state of s1
// gets a transition back to this node under each of the node's content
// values. The node gets a transition into s1 under each of s1's content
// values.
//
// NOTE(review): s1 is received by value, so the transitions added from the
// new node point at this function's local copy of s1, not at the caller's
// State. The copy shares s1's maps and slices, but later writes to plain
// fields (isLast, output, ...) on one are not seen by the other - confirm
// callers don't rely on that.
func kleene(s1 State) *State {
	star := &State{
		transitions: make(map[int][]*State),
		content:     newContents(EPSILON),
		isEmpty:     true,
		isKleene:    true,
	}
	star.output = []*State{star}
	for _, out := range s1.output {
		for _, c := range star.content {
			out.transitions[c], _ = unique_append(out.transitions[c], star)
		}
	}
	body := &s1
	for _, c := range s1.content {
		star.transitions[c], _ = unique_append(star.transitions[c], body)
	}
	return star
}

// alternate builds the union fragment (s1|s2) and returns its entry state.
//
// The entry is a new empty node with content newContents(EPSILON) whose
// outputs are s1's outputs followed by s2's. It transitions to each branch
// under every value of that branch's content. unique_append keeps a branch
// listed at most once per key. A duplicate entry would make the matcher
// treat the same branch as matching twice and report duplicate sets of match
// indices.
func alternate(s1 *State, s2 *State) *State {
	split := &State{
		transitions: make(map[int][]*State),
		content:     newContents(EPSILON),
		isEmpty:     true,
	}
	split.output = append(append(split.output, s1.output...), s2.output...)
	for _, branch := range []*State{s1, s2} {
		for _, c := range branch.content {
			split.transitions[c], _ = unique_append(split.transitions[c], branch)
		}
	}
	return split
}

// question builds the optional fragment (s1?) using the identity
// ab? == a(b|): it alternates s1 with a fresh empty state whose only output
// is itself, and returns the resulting alternation entry.
func question(s1 *State) *State {
	skip := &State{
		transitions: make(map[int][]*State),
		content:     newContents(EPSILON),
		isEmpty:     true,
	}
	skip.output = []*State{skip}
	return alternate(s1, skip)
}