// Copyright (C) 2014 Jakob Borg. All rights reserved. Use of this source code
// is governed by an MIT-style license that can be found in the LICENSE file.

package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"log"
	"os"
	"regexp"
	"strconv"
	"strings"
	"text/template"
)

// fieldInfo describes how one struct field is encoded. handleStruct builds
// it from the field declaration; the code, diagram and XDR-spec generators
// read it (the code templates access the fields by name).
type fieldInfo struct {
	Name      string // Go field name
	IsBasic   bool   // handled by one of the native Read/WriteUint64 etc functions
	IsSlice   bool   // field is a slice of FieldType, written as a uint32 count followed by the elements
	FieldType string // original type of field, e.g. "int"; the element type for slices and for []byte
	Encoder   string // the encoder name, e.g. "Uint64" for Read/WriteUint64
	Convert   string // what to convert to when encoding, e.g. "uint64"; empty when no conversion is needed
	Max       int    // max size for slices and strings; 0 means no limit is checked
}

// structInfo is one struct type found in the input file: its name and the
// encodable fields in declaration order.
//
// Field order matters: inspector constructs this type with a positional
// composite literal.
type structInfo struct {
	Name   string
	Fields []fieldInfo
}

// headerTpl renders the preamble of the generated file: the do-not-edit
// banner, the package clause (.Package, taken from the input file) and the
// imports used by the methods that encodeTpl generates.
var headerTpl = template.Must(template.New("header").Parse(`// ************************************************************
// This file is automatically generated by genxdr. Do not edit.
// ************************************************************

package {{.Package}}

import (
	"bytes"
	"io"

	"github.com/calmh/xdr"
)
`))

// encodeTpl generates, for one struct type (.TypeName) and its fields
// (.Fields, a []fieldInfo), the exported EncodeXDR, MarshalXDR,
// MustMarshalXDR, AppendXDR, DecodeXDR and UnmarshalXDR methods, plus the
// unexported encodeXDR/decodeXDR workers they delegate to.
//
// The template is indented for readability. generateCode collapses the
// resulting blank lines, turns every "//+n" marker into a blank line
// between functions, and then gofmts the result.
//
// Non-basic fields are encoded and decoded by calling their own generated
// encodeXDR/decodeXDR. Errors from those calls are checked in if-statement
// scope. This matters for two reasons:
//   - several nested fields in one struct must not redeclare err at
//     function scope, which would not compile;
//   - decode errors such as an oversized nested slice are returned directly
//     rather than recorded in xr, so they must be propagated instead of
//     dropped.
var encodeTpl = template.Must(template.New("encoder").Parse(`
func (o {{.TypeName}}) EncodeXDR(w io.Writer) (int, error) {
	var xw = xdr.NewWriter(w)
	return o.encodeXDR(xw)
}//+n

func (o {{.TypeName}}) MarshalXDR() ([]byte, error) {
	return o.AppendXDR(make([]byte, 0, 128))
}//+n

func (o {{.TypeName}}) MustMarshalXDR() []byte {
	bs, err := o.MarshalXDR()
	if err != nil {
		panic(err)
	}
	return bs
}//+n

func (o {{.TypeName}}) AppendXDR(bs []byte) ([]byte, error) {
	var aw = xdr.AppendWriter(bs)
	var xw = xdr.NewWriter(&aw)
	_, err := o.encodeXDR(xw)
	return []byte(aw), err
}//+n

func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) {
	{{range $fieldInfo := .Fields}}
		{{if not $fieldInfo.IsSlice}}
			{{if ne $fieldInfo.Convert ""}}
				xw.Write{{$fieldInfo.Encoder}}({{$fieldInfo.Convert}}(o.{{$fieldInfo.Name}}))
			{{else if $fieldInfo.IsBasic}}
				{{if ge $fieldInfo.Max 1}}
					if l := len(o.{{$fieldInfo.Name}}); l > {{$fieldInfo.Max}} {
						return xw.Tot(), xdr.ElementSizeExceeded("{{$fieldInfo.Name}}", l, {{$fieldInfo.Max}})
					}
				{{end}}
				xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}})
			{{else}}
				if _, err := o.{{$fieldInfo.Name}}.encodeXDR(xw); err != nil {
					return xw.Tot(), err
				}
			{{end}}
		{{else}}
			{{if ge $fieldInfo.Max 1}}
				if l := len(o.{{$fieldInfo.Name}}); l > {{$fieldInfo.Max}} {
					return xw.Tot(), xdr.ElementSizeExceeded("{{$fieldInfo.Name}}", l, {{$fieldInfo.Max}})
				}
			{{end}}
			xw.WriteUint32(uint32(len(o.{{$fieldInfo.Name}})))
			for i := range o.{{$fieldInfo.Name}} {
			{{if ne $fieldInfo.Convert ""}}
				xw.Write{{$fieldInfo.Encoder}}({{$fieldInfo.Convert}}(o.{{$fieldInfo.Name}}[i]))
			{{else if $fieldInfo.IsBasic}}
				xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}}[i])
			{{else}}
				if _, err := o.{{$fieldInfo.Name}}[i].encodeXDR(xw); err != nil {
					return xw.Tot(), err
				}
			{{end}}
			}
		{{end}}
	{{end}}
	return xw.Tot(), xw.Error()
}//+n

func (o *{{.TypeName}}) DecodeXDR(r io.Reader) error {
	xr := xdr.NewReader(r)
	return o.decodeXDR(xr)
}//+n

func (o *{{.TypeName}}) UnmarshalXDR(bs []byte) error {
	var br = bytes.NewReader(bs)
	var xr = xdr.NewReader(br)
	return o.decodeXDR(xr)
}//+n

func (o *{{.TypeName}}) decodeXDR(xr *xdr.Reader) error {
	{{range $fieldInfo := .Fields}}
		{{if not $fieldInfo.IsSlice}}
			{{if ne $fieldInfo.Convert ""}}
				o.{{$fieldInfo.Name}} = {{$fieldInfo.FieldType}}(xr.Read{{$fieldInfo.Encoder}}())
			{{else if $fieldInfo.IsBasic}}
				{{if ge $fieldInfo.Max 1}}
					o.{{$fieldInfo.Name}} = xr.Read{{$fieldInfo.Encoder}}Max({{$fieldInfo.Max}})
				{{else}}
					o.{{$fieldInfo.Name}} = xr.Read{{$fieldInfo.Encoder}}()
				{{end}}
			{{else}}
				if err := (&o.{{$fieldInfo.Name}}).decodeXDR(xr); err != nil {
					return err
				}
			{{end}}
		{{else}}
			_{{$fieldInfo.Name}}Size := int(xr.ReadUint32())
			{{if ge $fieldInfo.Max 1}}
				if _{{$fieldInfo.Name}}Size > {{$fieldInfo.Max}} {
					return xdr.ElementSizeExceeded("{{$fieldInfo.Name}}", _{{$fieldInfo.Name}}Size, {{$fieldInfo.Max}})
				}
			{{end}}
			o.{{$fieldInfo.Name}} = make([]{{$fieldInfo.FieldType}}, _{{$fieldInfo.Name}}Size)
			for i := range o.{{$fieldInfo.Name}} {
				{{if ne $fieldInfo.Convert ""}}
					o.{{$fieldInfo.Name}}[i] = {{$fieldInfo.FieldType}}(xr.Read{{$fieldInfo.Encoder}}())
				{{else if $fieldInfo.IsBasic}}
					o.{{$fieldInfo.Name}}[i] = xr.Read{{$fieldInfo.Encoder}}()
				{{else}}
					if err := (&o.{{$fieldInfo.Name}}[i]).decodeXDR(xr); err != nil {
						return err
					}
				{{end}}
			}
		{{end}}
	{{end}}
	return xr.Error()
}`))

// maxRe extracts N from a "max:N" annotation in a field's trailing comment;
// N becomes fieldInfo.Max. The leading \W requires a non-word character
// (such as the space in "// max:64") immediately before "max".
var maxRe = regexp.MustCompile(`\Wmax:(\d+)`)

// typeSet is one entry of xdrEncoders. Field order matters: xdrEncoders is
// written with positional composite literals.
type typeSet struct {
	Type    string // type to convert to before writing; empty means no conversion (becomes fieldInfo.Convert)
	Encoder string // suffix of the xdr Read/Write method, e.g. "Uint64" (becomes fieldInfo.Encoder)
}

// xdrEncoders maps each Go type the generator encodes natively to how it
// is written. Encoder names the xdr Reader/Writer method suffix
// (Read<Encoder>/Write<Encoder>). A non-empty Type is the type the value is
// converted to before writing; after reading, the value is converted back
// to the field's own type. Signed integers use the unsigned type of the
// same width, and int is widened to uint64.
//
// handleStruct looks up the "[]byte" key for slice fields before the
// element type, so byte slices are encoded as a single Bytes value rather
// than element by element.
var xdrEncoders = map[string]typeSet{
	"int8":   {Type: "uint8", Encoder: "Uint8"},
	"uint8":  {Type: "", Encoder: "Uint8"},
	"int16":  {Type: "uint16", Encoder: "Uint16"},
	"uint16": {Type: "", Encoder: "Uint16"},
	"int32":  {Type: "uint32", Encoder: "Uint32"},
	"uint32": {Type: "", Encoder: "Uint32"},
	"int64":  {Type: "uint64", Encoder: "Uint64"},
	"uint64": {Type: "", Encoder: "Uint64"},
	"int":    {Type: "uint64", Encoder: "Uint64"},
	"string": {Type: "", Encoder: "String"},
	"[]byte": {Type: "", Encoder: "Bytes"},
	"bool":   {Type: "", Encoder: "Bool"},
}

// handleStruct converts the fields of a struct type into fieldInfo
// descriptors, in declaration order.
//
// A trailing "max:N" comment sets the maximum length checked for strings,
// byte slices and slices. A declaration naming several fields at once
// ("A, B int") yields one descriptor per name.
//
// The following fields are skipped:
//   - anonymous (embedded) fields;
//   - fields whose trailing comment contains "noencode";
//   - fields whose type cannot be encoded: fixed-size arrays, pointers,
//     maps, package-qualified types, or slices whose element is not a
//     plain identifier. These log a warning to stderr instead of crashing
//     the tool or producing an empty field name.
func handleStruct(t *ast.StructType) []fieldInfo {
	var fs []fieldInfo

	for _, sf := range t.Fields.List {
		if len(sf.Names) == 0 {
			// We don't handle anonymous fields
			continue
		}

		maxLen := 0
		if sf.Comment != nil {
			c := sf.Comment.List[0].Text
			if m := maxRe.FindStringSubmatch(c); m != nil {
				// m[1] is all digits, so Atoi can only fail on overflow;
				// that error is deliberately ignored, as before.
				maxLen, _ = strconv.Atoi(m[1])
			}
			if strings.Contains(c, "noencode") {
				continue
			}
		}

		proto, ok := fieldTemplate(sf.Type, maxLen)
		if !ok {
			log.Printf("skipping field %s: unsupported type %T", sf.Names[0].Name, sf.Type)
			continue
		}

		// One descriptor per declared name, so "A, B int" encodes both.
		for _, id := range sf.Names {
			f := proto
			f.Name = id.Name
			fs = append(fs, f)
		}
	}

	return fs
}

// fieldTemplate builds the name-independent part of a fieldInfo for a
// field of type expr, reporting false for type shapes it cannot encode.
func fieldTemplate(expr ast.Expr, maxLen int) (fieldInfo, bool) {
	switch ft := expr.(type) {
	case *ast.Ident:
		tn := ft.Name
		if enc, ok := xdrEncoders[tn]; ok {
			return fieldInfo{
				IsBasic:   true,
				FieldType: tn,
				Encoder:   enc.Encoder,
				Convert:   enc.Type,
				Max:       maxLen,
			}, true
		}
		// Any other named type is expected to have its own generated
		// encodeXDR/decodeXDR methods.
		return fieldInfo{FieldType: tn, Max: maxLen}, true

	case *ast.ArrayType:
		if ft.Len != nil {
			// We don't handle arrays
			return fieldInfo{}, false
		}
		elt, ok := ft.Elt.(*ast.Ident)
		if !ok {
			// e.g. []*T or []pkg.T
			return fieldInfo{}, false
		}
		tn := elt.Name
		if enc, ok := xdrEncoders["[]"+tn]; ok {
			// Whole-slice encoder (e.g. []byte as Bytes): not element-wise.
			return fieldInfo{
				IsBasic:   true,
				FieldType: tn,
				Encoder:   enc.Encoder,
				Convert:   enc.Type,
				Max:       maxLen,
			}, true
		}
		if enc, ok := xdrEncoders[tn]; ok {
			return fieldInfo{
				IsBasic:   true,
				IsSlice:   true,
				FieldType: tn,
				Encoder:   enc.Encoder,
				Convert:   enc.Type,
				Max:       maxLen,
			}, true
		}
		return fieldInfo{IsSlice: true, FieldType: tn, Max: maxLen}, true
	}

	return fieldInfo{}, false
}

// blankLinesRe matches runs of blank or whitespace-only lines (with any
// trailing whitespace before them) so they can be collapsed to one newline.
var blankLinesRe = regexp.MustCompile(`(\s*\n)+`)

// generateCode renders encodeTpl for s, tidies the result and writes it
// to output as gofmt-formatted Go source.
//
// It panics if the template fails to execute or the generated code does
// not parse. The latter panic message includes the struct name and the
// unformatted source, since the formatter's error positions refer to it.
func generateCode(output io.Writer, s structInfo) {
	var buf bytes.Buffer
	err := encodeTpl.Execute(&buf, map[string]interface{}{"TypeName": s.Name, "Fields": s.Fields})
	if err != nil {
		panic(err)
	}

	// The template is indented for readability and leaves many blank
	// lines; collapse them, then expand the explicit "//+n" markers into
	// the blank lines that separate the generated functions.
	bs := blankLinesRe.ReplaceAll(buf.Bytes(), []byte("\n"))
	bs = bytes.Replace(bs, []byte("//+n"), []byte("\n"), -1)

	formatted, err := format.Source(bs)
	if err != nil {
		panic(fmt.Sprintf("formatting generated code for %s: %v\n%s", s.Name, err, bs))
	}
	fmt.Fprintln(output, string(formatted))
}

// camelBoundaryRe matches a lower-case ASCII letter directly followed by an
// upper-case one, capturing each letter separately.
var camelBoundaryRe = regexp.MustCompile("([a-z])([A-Z])")

// uncamelize inserts a space at every lower-to-upper case transition, e.g.
// "NumberOfFiles" becomes "Number Of Files", for use as a diagram label.
// Runs of capitals ("ID") are left intact.
func uncamelize(s string) string {
	return camelBoundaryRe.ReplaceAllString(s, "$1 $2")
}

// generateDiagram writes an ASCII box diagram of the wire layout of s to
// output. It starts with a bit ruler; each field then gets one or more
// 32-bit-wide rows separated by "+-+-" lines:
//   - slice fields are preceded by a "Number of" row for their element
//     count;
//   - strings and byte slices get a "Length of" row followed by a
//     variable-length area;
//   - fields of other unrecognized types are drawn as a box labelled with
//     the type name.
//
// Signed integers are drawn like their unsigned counterparts, because
// xdrEncoders writes int16, int32 and int/int64 through the same Uint16,
// Uint32 and Uint64 writers. In particular, int is 64 bits on the wire.
func generateDiagram(output io.Writer, s structInfo) {
	const line = "+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+"

	fmt.Fprintln(output, s.Name+" Structure:")
	fmt.Fprintln(output)
	fmt.Fprintln(output, " 0                   1                   2                   3")
	fmt.Fprintln(output, " 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1")
	fmt.Fprintln(output, line)

	for _, f := range s.Fields {
		name := uncamelize(f.Name)

		if f.IsSlice {
			fmt.Fprintf(output, "| %s |\n", center("Number of "+name, 61))
			fmt.Fprintln(output, line)
		}
		switch f.FieldType {
		case "bool":
			fmt.Fprintf(output, "| %s |V|\n", center(name+" (V=0 or 1)", 59))
			fmt.Fprintln(output, line)
		case "int16", "uint16":
			fmt.Fprintf(output, "| %s | %s |\n", center("0x0000", 29), center(name, 29))
			fmt.Fprintln(output, line)
		case "int32", "uint32":
			fmt.Fprintf(output, "| %s |\n", center(name, 61))
			fmt.Fprintln(output, line)
		case "int", "int64", "uint64":
			fmt.Fprintf(output, "| %-61s |\n", "")
			fmt.Fprintf(output, "+ %s +\n", center(name+" (64 bits)", 61))
			fmt.Fprintf(output, "| %-61s |\n", "")
			fmt.Fprintln(output, line)
		case "string", "byte":
			// XXX "byte" is assumed to come from a []byte field (length-
			// prefixed opaque data); a plain byte field would be drawn wrongly.
			fmt.Fprintf(output, "| %s |\n", center("Length of "+name, 61))
			fmt.Fprintln(output, line)
			fmt.Fprintf(output, "/ %61s /\n", "")
			fmt.Fprintf(output, "\\ %s \\\n", center(name+" (variable length)", 61))
			fmt.Fprintf(output, "/ %61s /\n", "")
			fmt.Fprintln(output, line)
		default:
			if f.IsSlice {
				label := "Zero or more " + f.FieldType + " Structures"
				fmt.Fprintf(output, "/ %s /\n", center("", 61))
				fmt.Fprintf(output, "\\ %s \\\n", center(label, 61))
				fmt.Fprintf(output, "/ %s /\n", center("", 61))
			} else {
				fmt.Fprintf(output, "| %s |\n", center(f.FieldType, 61))
			}
			fmt.Fprintln(output, line)
		}
	}
	fmt.Fprintln(output)
	fmt.Fprintln(output)
}

// generateXdr writes an XDR-language struct definition for s to output.
//
// Slice fields get a variable-length array suffix "<max>", where an empty
// max means unbounded. Strings and byte slices use the XDR string/opaque
// types with their max length. Other types (nested structures, bool) are
// emitted under their Go type name.
//
// Go integer types are named after how xdrEncoders puts them on the wire:
//   - int32 is written through Uint32, so it is a 32-bit XDR int;
//   - int and int64 are written through Uint64, so they are XDR hypers.
func generateXdr(output io.Writer, s structInfo) {
	fmt.Fprintf(output, "struct %s {\n", s.Name)

	for _, f := range s.Fields {
		tn := f.FieldType
		fn := f.Name

		limit := ""
		if f.Max > 0 {
			limit = strconv.Itoa(f.Max)
		}
		suffix := ""
		if f.IsSlice {
			suffix = "<" + limit + ">"
		}

		switch tn {
		case "uint16", "uint32":
			fmt.Fprintf(output, "\tunsigned int %s%s;\n", fn, suffix)
		case "int32":
			fmt.Fprintf(output, "\tint %s%s;\n", fn, suffix)
		case "int", "int64":
			fmt.Fprintf(output, "\thyper %s%s;\n", fn, suffix)
		case "uint64":
			fmt.Fprintf(output, "\tunsigned hyper %s%s;\n", fn, suffix)
		case "string":
			fmt.Fprintf(output, "\tstring %s<%s>;\n", fn, limit)
		case "byte":
			fmt.Fprintf(output, "\topaque %s<%s>;\n", fn, limit)
		default:
			fmt.Fprintf(output, "\t%s %s%s;\n", tn, fn, suffix)
		}
	}
	fmt.Fprintln(output, "}")
	fmt.Fprintln(output)
}

// center pads s with spaces on both sides to a width of w characters,
// putting the odd extra space, if any, on the right.
//
// Width is counted in runes so non-ASCII labels line up in the diagram.
// If s is already w or more runes wide, it is returned unchanged. The
// diagram row then overflows, but a long label never panics; the previous
// version passed a negative count to strings.Repeat.
func center(s string, w int) string {
	pad := w - len([]rune(s))
	if pad <= 0 {
		return s
	}
	left := pad / 2
	right := pad - left
	return strings.Repeat(" ", left) + s + strings.Repeat(" ", right)
}

// inspector returns a visitor for ast.Inspect that appends a structInfo to
// *structs for every struct type declaration it reaches. The visitor does
// not descend into any type declaration, struct or not, but keeps walking
// every other kind of node, including nil.
func inspector(structs *[]structInfo) func(ast.Node) bool {
	return func(node ast.Node) bool {
		spec, isTypeSpec := node.(*ast.TypeSpec)
		if !isTypeSpec {
			return true
		}
		if st, isStruct := spec.Type.(*ast.StructType); isStruct {
			*structs = append(*structs, structInfo{Name: spec.Name.Name, Fields: handleStruct(st)})
		}
		return false
	}
}

// main parses the Go source file named by the first positional argument.
// For every struct type declared in it, main writes a comment containing
// the structure diagram and XDR definition, followed by the generated
// encoder/decoder methods. Output goes to the file given by -o, or to
// stdout when -o is empty.
func main() {
	outputFile := flag.String("o", "", "Output file, blank for stdout")
	flag.Parse()
	if flag.NArg() == 0 {
		fmt.Fprintf(os.Stderr, "usage: %s [-o output] input.go\n", os.Args[0])
		flag.PrintDefaults()
		os.Exit(2)
	}
	fname := flag.Arg(0)

	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments)
	if err != nil {
		log.Fatal(err)
	}

	var structs []structInfo
	ast.Inspect(f, inspector(&structs))

	var output io.Writer = os.Stdout
	var fd *os.File
	if *outputFile != "" {
		fd, err = os.Create(*outputFile)
		if err != nil {
			log.Fatal(err)
		}
		output = fd
	}

	if err = headerTpl.Execute(output, map[string]string{"Package": f.Name.Name}); err != nil {
		log.Fatal(err)
	}
	for _, s := range structs {
		fmt.Fprintf(output, "\n/*\n\n")
		generateDiagram(output, s)
		generateXdr(output, s)
		fmt.Fprintf(output, "*/\n")
		generateCode(output, s)
	}

	// Close explicitly and check the result. Some filesystems only report
	// write failures at close, and log.Fatal would skip a deferred Close.
	if fd != nil {
		if err = fd.Close(); err != nil {
			log.Fatal(err)
		}
	}
}
