package main

import (
	"database/sql"
	"fmt"
	"log"
	"os"
	"time"

	"strings"

	"flag"

	"path/filepath"

	_ "github.com/lib/pq"
)

var (
	// printDebugMsg enables output from debugMsg; main sets it from the
	// --debug flag.
	printDebugMsg = false

	// schemas
	//
	// allSchemasQueryTemp returns (schema_name, "create schema ... authorization
	// <owner>;") for every schema except pg_* and information_schema.
	allSchemasQueryTemp = `select schema_name, 'create schema '||schema_name||' authorization '||schema_owner||';' as schema_ddl
		from information_schema.schemata 
		where schema_name not like 'pg_%' and schema_name not in ('information_schema') order by 1`

	// oneSchemaQueryTemp is the same query for a single schema; the
	// ${SCHEMANAME} placeholder is filled in by plain text substitution.
	oneSchemaQueryTemp = `select schema_name, 'create schema '||schema_name||' authorization '||schema_owner||';' as schema_ddl
		from information_schema.schemata where schema_name = '${SCHEMANAME}'`

	// functions
	//
	// dumpFuncsQueryTemp returns one row per pg_proc entry outside pg_catalog
	// and information_schema with four columns: the schema-qualified name, the
	// source code, the identity (qualified name plus argument types) and an
	// "ALTER FUNCTION ... OWNER TO ..." statement. ${SCHEMACOND} and
	// ${FUNCNAMECOND} are replaced by funcSchemaCondition / funcNameCondition
	// (with their own placeholders filled in) or by empty strings.
	// NOTE(review): pg_get_functiondef already emits non-default volatility and
	// COST clauses, so appending proc_volatile and COST again may produce
	// redundant options when the script is replayed — verify.
	// NOTE(review): pg_get_functiondef rejects aggregate functions, and this
	// query does not filter them out — verify on databases with aggregates.
	dumpFuncsQueryTemp = `
		select 
			s.proc_name, 
			pg_get_functiondef(s.oid)||' '||proc_volatile||' COST '||s.procost||';' as proc_source_code, 
			proc_identity,
			'ALTER FUNCTION '||s.proc_identity||' OWNER TO '||s.owner||';' as proc_grant
		from (
			select p.oid, p.pronamespace::regnamespace::text as schemaname, p.pronamespace::regnamespace::text||'.'||p.proname as proc_name, 
			p.pronamespace::regnamespace::text||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||')' as proc_identity,
			(select rolname from pg_roles r where r.oid = p.proowner) as owner,
			case when p.provolatile::text='v'::text then 'volatile'::text when p.provolatile::text='i'::text then 'immutable'::text 
			when p.provolatile::text='s'::text then 'stable'::text else p.provolatile::text end as proc_volatile,
			p.procost
			from pg_proc p 
		) s
		where s.schemaname not in ('pg_catalog','information_schema') ${SCHEMACOND} ${FUNCNAMECOND}
		order by s.schemaname, s.proc_name`

	// funcSchemaCondition restricts dumpFuncsQueryTemp to a single schema.
	funcSchemaCondition = ` and schemaname = '${SCHEMANAME}' `

	// funcNameCondition compares against proc_name. The query builds proc_name
	// as "schema.name", so ${OBJNAME} must be schema-qualified to match a
	// function.
	funcNameCondition = ` and proc_name = '${OBJNAME}' `

	// views
	//
	// dumpViewsQueryTemp returns (schema.view, "create or replace view ... as
	// <definition>", schema.view, '') for each row of pg_views. It applies no
	// schema exclusion of its own, so without ${SCHEMACOND} system views are
	// returned as well.
	dumpViewsQueryTemp = `select schemaname||'.'||viewname as viewname, 
		'create or replace view '||schemaname||'.'||viewname||' as '||definition as def,
		schemaname||'.'||viewname as viewidentity,
		'' as viewgrant
		from pg_views where 1=1 ${SCHEMACOND} ${VIEWNAMECOND}`

	// viewNameCondition compares against pg_views.viewname, which is the
	// unqualified view name. The qualified output alias is not visible in WHERE.
	viewNameCondition = ` and viewname = '${OBJNAME}' `

	// viewSchemaCondition restricts dumpViewsQueryTemp to a single schema.
	viewSchemaCondition = ` and schemaname = '${SCHEMANAME}' `

	// tables
	//
	// dumpTableQueryTemp returns, for every information_schema.tables entry
	// whose table_type contains "TABLE", the pieces dumpTables assembles into a
	// CREATE TABLE script:
	//   - existence and child counters;
	//   - the INHERITS clause;
	//   - the column list and constraints;
	//   - the tablespace clause;
	//   - the index definitions.
	// The "child" column of the tl subquery counts the table's pg_inherits
	// rows, and ${WITHCHILDCOND} filters on it.
	// NOTE(review): relation names are cast with ::regclass without quoting,
	// so names that need double quotes (mixed case, special characters) may
	// fail to resolve.
	// NOTE(review): the table_inherits scalar subquery errors for a table
	// that has more than one parent.
	// NOTE(review): columns are selected with attstattarget = -1. This skips
	// columns that have a custom statistics target. On server versions where
	// the default is NULL instead of -1, no columns may match at all. Consider
	// attnum > 0 and not attisdropped instead.
	// NOTE(review): pg_indexes also lists indexes that back PRIMARY KEY/UNIQUE
	// constraints, which the CONSTRAINT clauses already create.
	dumpTableQueryTemp = `
		select 
			tl.table_schema, 
			tl.table_name, 
			coalesce(ddl.table_exists, 0) as table_exists, 
			coalesce(ddl.table_is_child, 0) as table_is_child, 
			coalesce(ddl.table_inherits, '') as table_inherits, 
			coalesce(ddl.table_structure, '') as table_structure, 
			coalesce(ddl.table_constraints, '') as table_constraints, 
			coalesce(ddl.table_tablespace, '') as table_tablespace, 
			coalesce(ddl.table_indexes, '') as table_indexes
		from (
			select table_schema, table_name, table_type,
			(select count(*) from pg_inherits where inhrelid=(x.table_schema||'.'||x.table_name)::regclass) as child
			from information_schema.tables x
		) tl 
		join lateral (select
		(select count(*) from information_schema.tables t where table_schema||'.'||table_name=tl.table_schema||'.'||tl.table_name) 		as table_exists,
		(select count(*) from pg_inherits where inhrelid=(tl.table_schema||'.'||tl.table_name)::regclass) 		as table_is_child,
		(select 'INHERITS ('|| inhparent::regclass||')' from pg_inherits where inhrelid=(tl.table_schema||'.'||tl.table_name)::regclass) 		as table_inherits,
		(with srcdata as (SELECT a.attname ||' '|| format_type(a.atttypid, a.atttypmod) ||case when a.attnotnull then ' NOT NULL' else '' end as _def FROM pg_attribute a
		WHERE attrelid=(tl.table_schema||'.'||tl.table_name)::regclass and a.attstattarget = -1 order by attnum) select string_agg(_def,',
		') as _ddl from srcdata) 		as table_structure,
		(with srcdata as (select 'CONSTRAINT '||conname||' '||pg_catalog.pg_get_constraintdef(oid) as _ddl from pg_constraint where conrelid=(tl.table_schema||'.'||tl.table_name)::regclass
		order by case contype when 'p' then '1' when 'f' then '2'||conname else '3' end )
		select string_agg(_ddl,',
		') as _ddl from srcdata)		as table_constraints,
		(with srcdata as (select case when tablespace is not null then 'TABLESPACE '||t.tablespace else ' ' end as _ddl from pg_tables t
		where schemaname||'.'||tablename = (tl.table_schema||'.'||tl.table_name))
		select string_agg(_ddl,',
		') as _ddl from srcdata )		as table_tablespace,
		(with srcdata as (select indexdef||case when tablespace is not null then ' tablespace '||tablespace else '' end||';' as _ddl from pg_indexes  
		where schemaname||'.'||tablename = (tl.table_schema||'.'||tl.table_name))
		select string_agg(_ddl,'
		') as _ddl from srcdata)		as table_indexes) ddl
		on true
		where upper(table_type) like '%TABLE%' ${SCHEMACOND} ${TABLENAMECOND} ${WITHCHILDCOND}
		order by 1, 2`

	// Filters spliced into dumpTableQueryTemp. They select by table name, by
	// child status (child counts the table's pg_inherits rows) and by schema.
	tableNameCondition      = ` and table_name='${OBJNAME}' `
	tableWithChildCondition = ` and child=1 `
	tableNoChildCondition   = ` and child=0 `
	tableSchemaCondition    = ` and table_schema='${SCHEMANAME}' `
)

// pgdata describes one dumped database object (function, view or table).
// objName is shown in progress output, objSource is the DDL text written to
// disk, objIdentity (plus ".sql") names the output file, and objGrant carries
// the query's fourth column, which the file writers do not use.
type pgdata struct {
	objName, objSource, objIdentity, objGrant string
}

// schemadata is one row of the schema query. schName is the schema name, which
// is also the output file name stem, and schSource is its CREATE SCHEMA
// statement.
type schemadata struct {
	schName, schSource string
}

// pgtable receives one row of dumpTableQueryTemp. The fields follow the
// query's column order: schema, name, existence count, child count, INHERITS
// clause, column list, constraints, tablespace clause and index definitions.
type pgtable struct {
	tableSchema, tableName    string
	tableExists, tableIsChild int
	tableInherits             string
	tableStructure            string
	tableConstraints          string
	tableTablespace           string
	tableIndexes              string
}

func main() {

	_, prgName := filepath.Split(os.Args[0])
	usageHint := fmt.Sprintf(`
USAGE: %s --schema=schemaname | --all_schemas [ --functions | --views | --tables [ --withchildren ] | --all_objects ] [ --name=object_name ] [--debug ] [ --help ]
	
Program dumps source codes of postgresql functions in separate files with extension ".sql".
These files can be directly implemented back into database to install functions.
`, prgName)

	paramSchema := flag.String("schema", "", "name of pg schema in database to dump functions from")
	paramObjName := flag.String("name", "", "function name to be dump")
	paramAllSchemas := flag.Bool("all_schemas", false, "dumps functions from all schemas in database")
	paramDebug := flag.Bool("debug", false, "print debug messages")
	paramHelp := flag.Bool("help", false, "print help")
	paramAllObjects := flag.Bool("all_objects", false, "dump all objects source code")
	paramFuncs := flag.Bool("functions", false, "dump functions source code")
	paramTables := flag.Bool("tables", false, "dump tables source code (only parents or single tables)")
	paramWithChildTables := flag.Bool("withchildren", false, "dump tables source code - with children (inhereted) tables")
	paramViews := flag.Bool("views", false, "dump views source code")
	flag.Parse()

	pgSchema := *paramSchema
	pgDumpAllSchemas := *paramAllSchemas
	pgObjName := *paramObjName
	printDebugMsg = *paramDebug
	printHelp := *paramHelp
	dumpViews := *paramViews
	dumpFuncs := *paramFuncs
	dumpTabs := *paramTables
	dumpWithChildTables := *paramWithChildTables
	dumpAllObjects := *paramAllObjects
	debugMsg(fmt.Sprintf("pgSchema: %s", pgSchema))
	debugMsg(fmt.Sprintf("pgDumpAllSchemas: %v", pgDumpAllSchemas))
	debugMsg(fmt.Sprintf("pgObjName: %s", pgObjName))
	debugMsg(fmt.Sprintf("printDebugMsg: %v", printDebugMsg))
	debugMsg(fmt.Sprintf("printHelp: %v", printHelp))
	debugMsg(fmt.Sprintf("dumpViews: %v", dumpViews))
	debugMsg(fmt.Sprintf("dumpFuncs: %v", dumpFuncs))
	debugMsg(fmt.Sprintf("dumpTabs: %v", dumpTabs))
	debugMsg(fmt.Sprintf("dumpAllObjects: %v", dumpAllObjects))

	if printHelp == true || (pgSchema == "" && pgDumpAllSchemas == false) {
		log.Fatal(usageHint)
	}

	var err error
	var createSchemaQuery string
	pgTableSchemaCondition := ""
	pgViewSchemaCondition := ""
	pgFuncSchemaCondition := ""
	if pgSchema != "" {
		pgTableSchemaCondition = strings.Replace(tableSchemaCondition, "${SCHEMANAME}", pgSchema, -1)
		pgViewSchemaCondition = strings.Replace(viewSchemaCondition, "${SCHEMANAME}", pgSchema, -1)
		pgFuncSchemaCondition = strings.Replace(funcSchemaCondition, "${SCHEMANAME}", pgSchema, -1)
		createSchemaQuery = strings.Replace(oneSchemaQueryTemp, "${SCHEMANAME}", pgSchema, -1)
	} else {
		createSchemaQuery = allSchemasQueryTemp
	}

	var pgFuncNameCondition, pgViewNameCondition, pgTableNameCondition string
	if pgObjName != "" {
		pgFuncNameCondition = strings.Replace(funcNameCondition, "${OBJNAME}", pgObjName, -1)
		pgViewNameCondition = strings.Replace(viewNameCondition, "${OBJNAME}", pgObjName, -1)
		pgTableNameCondition = strings.Replace(tableNameCondition, "${OBJNAME}", pgObjName, -1)
	} else {
		pgFuncNameCondition = ""
		pgViewNameCondition = ""
		pgTableNameCondition = ""
	}

	pgURI := requireEnvVar("POSTGRES_URI")
	checkValue("postgresql URI", pgURI, true, false)

	pgDB, err := sql.Open("postgres", pgURI)
	if err != nil {
		log.Fatalln("Cannot connect into postgresql db: ", err)
	}
	defer func() {
		if errClose := pgDB.Close(); err != nil {
			log.Println(curTime(), "closing source database:", errClose.Error())
		}
	}()
	if err = pgDB.Ping(); err != nil {
		log.Fatalln("Cannot ping postgresql db: ", err)
	}

	var sd schemadata
	sch, err := pgDB.Query(createSchemaQuery)
	if err != nil {
		log.Fatalln("cannot not run query:", err)
	}
	if sch != nil {
		for sch.Next() {
			if err = sch.Scan(&sd.schName, &sd.schSource); err != nil {
				log.Fatalln("cannot get schema ", &sd.schName, " DDL:", err)
			}
		}
		fmt.Println(sd.schName)
		createSchemaSQLFile(sd)
	}

	if dumpTabs == true || dumpAllObjects == true {
		pgTabsQuery := strings.Replace(dumpTableQueryTemp, "${SCHEMACOND}", pgTableSchemaCondition, -1)
		pgTabsQuery = strings.Replace(pgTabsQuery, "${TABLENAMECOND}", pgTableNameCondition, -1)
		if dumpWithChildTables == true {
			pgTabsQuery = strings.Replace(pgTabsQuery, "${WITHCHILDCOND}", tableWithChildCondition, -1)
		} else {
			pgTabsQuery = strings.Replace(pgTabsQuery, "${WITHCHILDCOND}", tableNoChildCondition, -1)
		}
		debugMsg(fmt.Sprint("pgTabsQuery: ", pgTabsQuery))

		dumpTables(pgDB, pgTabsQuery, dumpWithChildTables)
	}

	if dumpFuncs == true || dumpAllObjects == true {
		pgFuncQuery := strings.Replace(dumpFuncsQueryTemp, "${SCHEMACOND}", pgFuncSchemaCondition, -1)
		pgFuncQuery = strings.Replace(pgFuncQuery, "${FUNCNAMECOND}", pgFuncNameCondition, -1)
		debugMsg(fmt.Sprint("pgFuncQuery: ", pgFuncQuery))

		dumpObjects(pgDB, pgFuncQuery, "functions")
	}

	if dumpViews == true || dumpAllObjects == true {
		pgViewsQuery := strings.Replace(dumpViewsQueryTemp, "${SCHEMACOND}", pgViewSchemaCondition, -1)
		pgViewsQuery = strings.Replace(pgViewsQuery, "${VIEWNAMECOND}", pgViewNameCondition, -1)
		debugMsg(fmt.Sprint("pgViewsQuery: ", pgViewsQuery))

		dumpObjects(pgDB, pgViewsQuery, "views")
	}

	fmt.Println(curTime(), "ALL DONE")
}

func dumpTables(pgDB *sql.DB, pgQuery string, dumpWithChildTables bool) {
	rows, err := pgDB.Query(pgQuery)
	if err != nil {
		log.Fatalln("cannot not run query:", pgQuery, " ERROR: ", err)
	}
	var count int
	var t pgtable
	var f pgdata
	if rows != nil {
		for rows.Next() {
			if err = rows.Scan(
				&t.tableSchema,
				&t.tableName,
				&t.tableExists,
				&t.tableIsChild,
				&t.tableInherits,
				&t.tableStructure,
				&t.tableConstraints,
				&t.tableTablespace,
				&t.tableIndexes,
			); err != nil {
				log.Fatalln("cannot parse data:", err)
			}
			count++
			if (dumpWithChildTables == false) || ((dumpWithChildTables == true) && (t.tableIsChild == 0)) {
				f.objIdentity = t.tableSchema + "." + t.tableName
				f.objName = f.objIdentity
				f.objSource = "CREATE TABLE " + f.objName + `(
		`
				if t.tableIsChild == 0 {
					f.objSource = f.objSource + t.tableStructure
					if t.tableConstraints != "" {
						f.objSource = f.objSource + `, 
	` + t.tableConstraints
					}
					f.objSource = f.objSource + ")"
				} else {
					if t.tableConstraints != "" {
						f.objSource = f.objSource + t.tableConstraints
					}
					f.objSource = f.objSource + ")"
					if t.tableInherits == "" {
						log.Fatal("table ", f.objName, " is child table but source code for inheritance is empty")
					} else {
						f.objSource = f.objSource + `
	` + t.tableInherits
					}
				}

				if t.tableTablespace != "" {
					f.objSource = f.objSource + `
` + t.tableTablespace
				}
				f.objSource = f.objSource + ";"
				if t.tableIndexes != "" {
					f.objSource = f.objSource + `

` + t.tableIndexes
				}
				fmt.Println(count, ":", f.objName)
				debugMsg(fmt.Sprint("DDL:", f.objSource))
				createObjSQLFile(f)
			}
		}
	}
	fmt.Println(curTime(), count, " tables DONE")
}

// dumpObjects runs pgQuery, which must return four text columns
// (name, source, identity, grant). It writes each row's source into
// "<identity>.sql" via createObjSQLFile. description (e.g. "functions",
// "views") only appears in the final progress line.
// NOTE(review): the fourth column (objGrant) is scanned but not written
// anywhere, so e.g. the ALTER FUNCTION ... OWNER TO statement is dropped.
func dumpObjects(pgDB *sql.DB, pgQuery string, description string) {
	rows, err := pgDB.Query(pgQuery)
	if err != nil {
		log.Fatalln("cannot run query:", pgQuery, " ERROR: ", err)
	}
	// Return the connection to the pool once iteration is finished.
	defer rows.Close()

	var f pgdata
	var count int
	for rows.Next() {
		if err = rows.Scan(
			&f.objName,
			&f.objSource,
			&f.objIdentity,
			&f.objGrant,
		); err != nil {
			log.Fatalln("cannot parse data:", err)
		}
		count++
		fmt.Println(count, ":", f.objName)
		createObjSQLFile(f)
	}
	// rows.Next returns false on errors as well as on exhaustion. Without this
	// check, a failed fetch would be reported as a complete dump.
	if err = rows.Err(); err != nil {
		log.Fatalln("cannot read", description, "rows:", err)
	}
	fmt.Println(curTime(), count, " ", description, " DONE")
}

// createSchemaSQLFile writes s.schSource into "<s.schName>.sql" in the current
// directory, replacing any existing file. Any create, write or close failure
// is fatal and is reported together with its cause.
func createSchemaSQLFile(s schemadata) {
	sname := s.schName + ".sql"
	so, err := os.Create(sname)
	if err != nil {
		log.Fatalln("Cannot create file", sname, ":", err)
	}
	if _, err = so.WriteString(s.schSource); err != nil {
		_ = so.Close() // the write error is the one worth reporting
		log.Fatalln("Cannot write file", sname, ":", err)
	}
	// A failed Close can mean the data never reached the disk.
	if err = so.Close(); err != nil {
		log.Fatalln("Cannot close file", sname, ":", err)
	}
}

// createObjSQLFile writes f.objSource into "<f.objIdentity>.sql" in the current
// directory, replacing any existing file. Any create, write or close failure
// is fatal and is reported together with its cause.
func createObjSQLFile(f pgdata) {
	fname := f.objIdentity + ".sql"
	debugMsg(fmt.Sprint("fname:", fname))
	fo, err := os.Create(fname)
	if err != nil {
		log.Fatalln("Cannot create file", fname, ":", err)
	}
	if _, err = fo.WriteString(f.objSource); err != nil {
		_ = fo.Close() // the write error is the one worth reporting
		log.Fatalln("Cannot write file", fname, ":", err)
	}
	// A failed Close can mean the data never reached the disk.
	if err = fo.Close(); err != nil {
		log.Fatalln("Cannot close file", fname, ":", err)
	}
}

// curTime returns the current UTC time in RFC 3339 form followed by a colon,
// for use as a prefix on progress and log lines.
func curTime() string {
	stamp := time.Now().UTC().AppendFormat(make([]byte, 0, 32), time.RFC3339)
	return string(append(stamp, ':'))
}

// requireEnvVar returns the value of the environment variable named s. If the
// variable is not set at all, the program exits via log.Fatalln. A variable
// that is set to the empty string is returned as "".
func requireEnvVar(s string) string {
	if value, found := os.LookupEnv(s); found {
		return value
	}
	log.Fatalln(curTime(), s, "isn't defined")
	return ""
}

// checkValue validates and optionally echoes a configuration value. If
// required is true and value is empty, the program exits via log.Fatal. If
// printit is true, name and value are printed to stdout after a timestamp.
func checkValue(name string, value string, required bool, printit bool) {
	if missing := required && len(value) == 0; missing {
		log.Fatal("ERROR: variable ", name, " cannot be empty!")
	}
	if !printit {
		return
	}
	fmt.Println(curTime(), name, ": ", value)
}

// debugMsg prints t to stdout, prefixed with the current UTC time. It prints
// only when debug output has been enabled through printDebugMsg.
func debugMsg(t string) {
	if !printDebugMsg {
		return
	}
	fmt.Println(curTime(), t)
}