package main

import (
	"bufio"
	"flag"
	"fmt"
	"io"
	"os"

	"github.com/fatih/color"
)

// main implements a grep-like command-line tool on top of this package's
// regex engine. It takes exactly one positional argument (the regex), reads
// stdin, and writes to stdout.
//
// By default stdin is processed one line at a time with the trailing "\n"
// stripped. With -t the whole input is read and matched as a single string.
// Matched text is printed in red (inverted with -v). With -s each match is
// replaced by the given text. With -p the start/end of each match is printed
// instead of the text.
//
// Usage errors and regex compile errors are reported on stderr with exit
// status 22. I/O errors are reported on stderr with exit status 1.
func main() {
	// Flags for the regex Compile function
	flagsToCompile := make([]ReFlag, 0)

	invertFlag := flag.Bool("v", false, "Invert match.")
	// This flag has two 'modes':
	// 1. Without '-v': Prints only matches. Prints a newline after every match.
	// 2. With '-v': Substitutes all matches with empty string.
	onlyFlag := flag.Bool("o", false, "Print only colored content. Overrides -l.")
	lineFlag := flag.Bool("l", false, "Only print lines with a match (or with no matches, if -v is enabled). Similar to grep's default.")
	multiLineFlag := flag.Bool("t", false, "Multi-line mode. Treats newline just like any character.")
	printMatchesFlag := flag.Bool("p", false, "Prints start and end index of each match. Can only be used with '-t' for multi-line mode.")
	caseInsensitiveFlag := flag.Bool("i", false, "Case-insensitive. Disregard the case of all characters.")
	matchNum := flag.Int("m", 0, "Print the match with the given index. Eg. -m 3 prints the third match.")
	substituteText := flag.String("s", "", "Substitute the contents of each match with the given string. Overrides -o and -v")
	flag.Parse()

	// These flags have to be passed to the Compile function
	if *multiLineFlag {
		flagsToCompile = append(flagsToCompile, RE_MULTILINE)
	}
	if *caseInsensitiveFlag {
		flagsToCompile = append(flagsToCompile, RE_CASE_INSENSITIVE)
	}

	// -l and -o are mutually exclusive: -o overrides -l
	if *onlyFlag {
		*lineFlag = false
	}

	// flag.Visit only visits flags that were explicitly set on the command
	// line, which distinguishes "-s ''" (substitute with nothing) from "not given".
	substituteFlagEnabled := false
	matchNumFlagEnabled := false
	flag.Visit(func(f *flag.Flag) {
		switch f.Name {
		case "s":
			substituteFlagEnabled = true
		case "m":
			matchNumFlagEnabled = true
		}
	})

	// Validate matchNumFlag - must be positive integer. This is a user
	// error, so report it cleanly instead of panicking with a stack trace.
	if matchNumFlagEnabled && *matchNum < 1 {
		fmt.Fprintln(os.Stderr, "Invalid match number to print.")
		os.Exit(22)
	}

	// Process:
	// 1. Convert regex into postfix notation (Shunting-Yard algorithm)
	// 		a. Add explicit concatenation operators to facilitate this
	// 2. Build NFA from postfix representation (Thompson's algorithm)
	// 3. Run the string against the NFA

	if len(flag.Args()) != 1 { // flag.Args() also strips out program name
		fmt.Fprintln(os.Stderr, "ERROR: Missing cmdline args")
		os.Exit(22) // EINVAL
	}
	re := flag.Args()[0]

	regComp, err := Compile(re, flagsToCompile...)
	if err != nil {
		// An invalid regex must produce a non-zero exit status so that
		// scripts can detect the failure.
		fmt.Fprintln(os.Stderr, err)
		os.Exit(22)
	}

	// Create reader for stdin and writer for stdout
	reader := bufio.NewReader(os.Stdin)
	out := bufio.NewWriter(os.Stdout)
	// Reused for every highlighted character rather than rebuilt per character.
	red := color.New(color.FgRed)

	// fatal reports an I/O error and terminates.
	fatal := func(err error) {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	var testStr string
	linesRead := false // set once stdin has been fully consumed
	lineNum := 0       // Current line number (single-line mode only)

	for !linesRead {
		if *multiLineFlag {
			// Multi-line mode: the whole input, newlines included, is one string.
			data, readErr := io.ReadAll(reader)
			if readErr != nil {
				fatal(readErr)
			}
			testStr = string(data)
			linesRead = true
		} else {
			line, readErr := reader.ReadString('\n')
			if readErr != nil {
				if readErr != io.EOF {
					fatal(readErr)
				}
				linesRead = true
				if len(line) == 0 {
					// Input was empty or ended with '\n': there is no final
					// partial line, so don't emit a phantom empty line.
					break
				}
			}
			lineNum++
			if len(line) > 0 && line[len(line)-1] == '\n' {
				line = line[:len(line)-1]
			}
			testStr = line
		}

		var matchIndices []Match
		if matchNumFlagEnabled {
			// An error from FindNthMatch is treated as "no such match"
			// (NOTE(review): assumed to be its only failure mode — confirm).
			if m, findErr := FindNthMatch(regComp, testStr, *matchNum); findErr == nil {
				matchIndices = append(matchIndices, m)
			}
		} else {
			matchIndices = FindAllMatches(regComp, testStr)
		}

		if *printMatchesFlag {
			// if we are in single line mode, print the line on which
			// the matches occur
			if len(matchIndices) > 0 {
				if !*multiLineFlag {
					fmt.Fprintf(out, "Line %d:\n", lineNum)
				}
				for _, m := range matchIndices {
					fmt.Fprintf(out, "%s\n", m.toString())
				}
				if flushErr := out.Flush(); flushErr != nil {
					fatal(flushErr)
				}
			}
			continue
		}

		// With -l, print only lines that have a match, or lines with no match
		// when -v is also given. Skip the line in the opposite case.
		if *lineFlag && (len(matchIndices) > 0) == *invertFlag {
			continue
		}

		if substituteFlagEnabled {
			// Walk the input by byte offset. Where a match starts, emit the
			// substitute text and jump to the match's end offset.
			// A range loop cannot do this: assigning to its iteration
			// variable does not skip ahead, which would leak the matched text.
			for i := 0; i <= len(testStr); {
				end, found := 0, false
				for _, m := range matchIndices {
					if m[0].startIdx == i {
						end, found = m[0].endIdx, true
						break
					}
				}
				if found {
					out.WriteString(*substituteText)
					if end > i {
						i = end
						continue
					}
					// Zero-length match: fall through so the byte at i is
					// still copied and the loop makes progress.
				}
				if i < len(testStr) {
					// Copy raw bytes so multi-byte UTF-8 characters pass
					// through unchanged. bufio.Writer errors are sticky and
					// are reported by the Flush below.
					out.WriteByte(testStr[i])
				}
				i++
			}
		} else {
			// Decompose the array of matchIndex structs into a flat unique array of ints - if matchIndex is {4,7}, flat array will contain 4,5,6
			// This should make checking O(1) instead of O(n)
			indicesToPrint := new_uniq_arr[int]()
			for _, m := range matchIndices {
				indicesToPrint.add(genRange(m[0].startIdx, m[0].endIdx)...)
			}
			// If we are inverting, color every offset in [0, len(testStr))
			// that did NOT fall inside a match.
			if *invertFlag {
				matched := indicesToPrint.values()
				indicesToPrint = new_uniq_arr[int]()
				indicesToPrint.add(setDifference(genRange(0, len(testStr)), matched)...)
			}

			for i, c := range testStr {
				if !indicesToPrint.contains(i) {
					if !*onlyFlag {
						fmt.Fprintf(out, "%c", c)
					}
					continue
				}
				red.Fprintf(out, "%c", c)
				// Newline after every match - only if -o is enabled and -v is disabled.
				if *onlyFlag && !*invertFlag {
					for _, m := range matchIndices {
						if i+1 == m[0].endIdx { // End index is one more than last index of match
							fmt.Fprintf(out, "\n")
							break
						}
					}
				}
			}
		}

		fmt.Fprintln(out)
		if flushErr := out.Flush(); flushErr != nil {
			fatal(flushErr)
		}
	}
}