package main

import (
	"database/sql"
	"fmt"
	"log"
	"os"
	"time"

	"strings"

	"flag"

	"path/filepath"

	_ "github.com/lib/pq"
)

var (
	// printDebugMsg turns debugMsg output on; main sets it from the --debug flag.
	printDebugMsg = false

	// mainQuery selects every function in pg_proc outside pg_catalog and
	// information_schema. Result columns, in order: schema-qualified name
	// (schema.name), SQL definition, identity (schema.name(arg types)) and an
	// ALTER FUNCTION ... OWNER TO statement.
	//
	// ${SCHEMACOND} and ${FUNCNAMECOND} are plain-text placeholders that are
	// replaced with a filled-in schemaCondition / functionCondition, or with "".
	//
	// Notes on the SQL:
	//   - pg_get_functiondef already emits IMMUTABLE/STABLE and any non-default
	//     COST. Appending them a second time made the generated CREATE statement
	//     fail with "conflicting or redundant options", so only the terminating
	//     ';' is added.
	//   - pg_get_functiondef raises an error for aggregate functions, which would
	//     abort the whole query, so aggregates are filtered out through
	//     pg_aggregate (independent of prokind / proisagg, which differ between
	//     PostgreSQL versions).
	//   - Overloads share proc_name, so proc_identity is the last sort key to keep
	//     the output order deterministic.
	mainQuery = `
		select
			s.proc_name,
			pg_get_functiondef(s.oid)||';' as proc_source_code,
			s.proc_identity,
			'ALTER FUNCTION '||s.proc_identity||' OWNER TO '||quote_ident(s.owner)||';' as proc_grant
		from (
			select p.oid,
				p.pronamespace::regnamespace::text as schemaname,
				p.pronamespace::regnamespace::text||'.'||p.proname as proc_name,
				p.pronamespace::regnamespace::text||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||')' as proc_identity,
				(select rolname from pg_roles r where r.oid = p.proowner) as owner
			from pg_proc p
			where not exists (select 1 from pg_aggregate a where a.aggfnoid = p.oid)
		) s
		where s.schemaname not in ('pg_catalog','information_schema') ${SCHEMACOND} ${FUNCNAMECOND}
		order by s.schemaname, s.proc_name, s.proc_identity`

	// schemaCondition restricts mainQuery to one schema. The value is compared
	// with regnamespace::text, which is double-quoted for names that need quoting.
	schemaCondition = ` and schemaname = '${SCHEMANAME}' `

	// functionCondition restricts mainQuery to one function. proc_name is
	// schema-qualified, so the value has to be given as schema.name.
	functionCondition = ` and proc_name = '${FUNCNAME}' `
)

// fncdata holds one result row of mainQuery: a single user-defined function
// and the SQL text generated for it.
type fncdata struct {
	// fncName is the schema-qualified function name (schema.name).
	fncName string
	// fncSource is the function's SQL definition (built from
	// pg_get_functiondef); this is what gets written to disk.
	fncSource string
	// fncIdentity is schema.name(argument types). It tells overloads apart
	// and serves as the base name of the output file.
	fncIdentity string
	// fncGrant is an ALTER FUNCTION ... OWNER TO statement for the function.
	fncGrant string
}

// main dumps user-defined PostgreSQL functions into the current working
// directory, one .sql file per function (see createSQLFile).
//
// The connection string is taken from the POSTGRES_URI environment variable.
// Flags:
//
//	--schema    only dump functions from this schema
//	--function  only dump this function; it is matched against the
//	            schema-qualified name (schema.name) produced by mainQuery
//	--debug     print debug messages
//	--help      print a usage hint and exit with status 1
//
// Every error is fatal: it is logged and the program exits.
func main() {
	_, prgName := filepath.Split(os.Args[0])
	usageHint := fmt.Sprintf("USAGE: %s [ --schema=schemaname ] [ --function=functionname ] [--debug ] [ --help ]", prgName)

	paramSchema := flag.String("schema", "", "name of pg schema in database")
	paramFunction := flag.String("function", "", "function name to be dump")
	paramDebug := flag.Bool("debug", false, "print debug messages - t / f")
	paramHelp := flag.Bool("help", false, "print help - t / f")
	flag.Parse()

	pgSchema := *paramSchema
	pgFunction := *paramFunction
	printDebugMsg = *paramDebug
	printHelp := *paramHelp
	debugMsg(fmt.Sprintf("pgSchema: %s", pgSchema))
	debugMsg(fmt.Sprintf("pgFunction: %s", pgFunction))
	debugMsg(fmt.Sprintf("printDebugMsg: %v", printDebugMsg))
	debugMsg(fmt.Sprintf("printHelp: %v", printHelp))

	if printHelp {
		// log.Fatal prints the hint and exits with status 1.
		log.Fatal(usageHint)
	}

	pgMainQuery := buildMainQuery(pgSchema, pgFunction)

	pgURI := requireEnvVar("POSTGRES_URI")
	checkValue("postgresql URI", pgURI, true, false)

	// sql.Open may not connect yet; Ping below verifies the connection.
	pgDB, err := sql.Open("postgres", pgURI)
	if err != nil {
		log.Fatalln("Cannot connect into postgresql db: ", err)
	}
	defer func() {
		if errClose := pgDB.Close(); errClose != nil {
			log.Println(curTime(), "closing source database:", errClose.Error())
		}
	}()
	if err = pgDB.Ping(); err != nil {
		log.Fatalln("Cannot ping postgresql db: ", err)
	}

	rows, err := pgDB.Query(pgMainQuery)
	if err != nil {
		log.Fatalln(curTime(), "could not run query:", err)
	}
	// Release the result set and its connection; runs before pgDB.Close (LIFO).
	defer rows.Close()

	var f fncdata
	var count int
	for rows.Next() {
		if err = rows.Scan(&f.fncName, &f.fncSource, &f.fncIdentity, &f.fncGrant); err != nil {
			log.Fatalln(curTime(), "cannot parse data:", err)
		}
		count++
		fmt.Println(count, ":", f.fncName)
		createSQLFile(f)
	}
	// rows.Next returns false both at the end and on failure; tell them apart.
	if err = rows.Err(); err != nil {
		log.Fatalln(curTime(), "error while reading query results:", err)
	}

	fmt.Println(curTime(), count, "functions DONE")
}

// buildMainQuery fills the placeholders of mainQuery with the optional schema
// and function filters. An empty schema or function leaves that filter out.
func buildMainQuery(schema, function string) string {
	schemaCond := ""
	if schema != "" {
		schemaCond = strings.Replace(schemaCondition, "${SCHEMANAME}", escapeSQLLiteral(schema), -1)
	}
	funcCond := ""
	if function != "" {
		funcCond = strings.Replace(functionCondition, "${FUNCNAME}", escapeSQLLiteral(function), -1)
	}
	query := strings.Replace(mainQuery, "${SCHEMACOND}", schemaCond, -1)
	return strings.Replace(query, "${FUNCNAMECOND}", funcCond, -1)
}

// escapeSQLLiteral doubles single quotes so that s can be placed between
// single quotes in a SQL string literal without terminating it early.
// NOTE(review): assumes standard_conforming_strings = on (the PostgreSQL
// default since 9.1); with it off, backslashes in s would still act as escapes.
func escapeSQLLiteral(s string) string {
	return strings.Replace(s, "'", "''", -1)
}

// createSQLFile writes f.fncSource to "<f.fncIdentity>.sql" in the current
// working directory. os.Create truncates any existing file of that name. Any
// failure is logged and terminates the program.
//
// Path separators inside the identity (possible with quoted identifiers) are
// replaced with '_'. This keeps the file in the working directory instead of
// some other, possibly non-existent or unintended, directory.
func createSQLFile(f fncdata) {
	base := strings.Map(func(r rune) rune {
		if r == '/' || r == filepath.Separator {
			return '_'
		}
		return r
	}, f.fncIdentity)
	fname := base + ".sql"

	fo, err := os.Create(fname)
	if err != nil {
		// The *PathError already names the file; keep its cause in the log.
		log.Fatalln("Cannot create file:", err)
	}
	if _, err = fo.WriteString(f.fncSource); err != nil {
		_ = fo.Close() // best effort; the write error is the one reported
		log.Fatalln("Cannot write file:", err)
	}
	// Close can surface delayed write errors, so its result matters.
	if err = fo.Close(); err != nil {
		log.Fatalln("Cannot close file:", err)
	}
}

// curTime returns the current UTC time in RFC 3339 form followed by a colon
// (for example "2006-01-02T15:04:05Z:"), used as a prefix for output lines.
func curTime() string {
	now := time.Now().UTC()
	stamp := now.Format(time.RFC3339)
	return stamp + ":"
}

// requireEnvVar returns the value of the environment variable named s.
// A variable set to the empty string yields "". An unset variable is logged
// as fatal and ends the program.
func requireEnvVar(s string) string {
	if value, found := os.LookupEnv(s); found {
		return value
	}
	log.Fatalln(curTime(), s, "isn't defined")
	return "" // not reached: log.Fatalln exits the process
}

// checkValue validates and optionally prints a configuration value.
//
// name is used only in messages. If required is true and value is empty, an
// error is logged and the program exits. If printit is true, the value is
// printed to stdout with a timestamp prefix, so avoid printit for secrets
// such as connection strings.
func checkValue(name string, value string, required bool, printit bool) {
	if required && value == "" {
		log.Fatal("ERROR: variable ", name, " cannot be empty!")
	}
	if printit {
		fmt.Println(curTime(), name, ": ", value)
	}
}

// debugMsg prints t to stdout with a timestamp prefix, but only when
// printDebugMsg is set (main sets it from the --debug flag).
func debugMsg(t string) {
	if printDebugMsg {
		fmt.Println(curTime(), t)
	}
}