From 038671ce8699dd67be952ac936ea375a8fa6b069 Mon Sep 17 00:00:00 2001 From: chris Date: Fri, 9 Aug 2019 10:58:51 -0400 Subject: [PATCH] [WIP] Schema versioning --- pkg/e2db/db_test.go | 6 +- pkg/e2db/model.go | 216 +++++++++++++++++++++++++++++++++++------ pkg/e2db/model_test.go | 48 +++++++++ pkg/e2db/query.go | 8 +- pkg/e2db/tx.go | 38 ++++---- 5 files changed, 260 insertions(+), 56 deletions(-) create mode 100644 pkg/e2db/model_test.go diff --git a/pkg/e2db/db_test.go b/pkg/e2db/db_test.go index a053111..a1c85f0 100644 --- a/pkg/e2db/db_test.go +++ b/pkg/e2db/db_test.go @@ -15,7 +15,10 @@ import ( var db *DB -func init() { +func initDB() { + if db != nil { + return + } log.SetLevel(zapcore.DebugLevel) if err := os.RemoveAll("testdata"); err != nil { @@ -68,6 +71,7 @@ var newRoles = []*Role{ } func resetTable(t *testing.T) { + initDB() roles := db.Table(&Role{}) if err := roles.Drop(); err != nil && errors.Cause(err) != ErrTableNotFound { t.Fatal(err) diff --git a/pkg/e2db/model.go b/pkg/e2db/model.go index b34cb35..487f3c3 100644 --- a/pkg/e2db/model.go +++ b/pkg/e2db/model.go @@ -1,8 +1,14 @@ package e2db import ( + "bytes" + "crypto/sha1" + "encoding/hex" + "fmt" "reflect" + "sort" "strings" + "unicode" "github.com/criticalstack/e2d/pkg/e2db/key" "github.com/pkg/errors" @@ -16,15 +22,50 @@ type Tag struct { Name, Value string } +func (t *Tag) String() string { + if t.Value == "" { + return t.Name + } + return fmt.Sprintf("%s=%s", t.Name, t.Value) +} + type FieldDef struct { - Name string - Tags []*Tag + Name string + Kind reflect.Kind + Type string + Tags []*Tag + Fields []*FieldDef +} + +func (f *FieldDef) String() string { + var tags string + if len(f.Tags) > 0 { + tt := make([]string, 0) + for _, t := range f.Tags { + tt = append(tt, t.String()) + } + tags = fmt.Sprintf(" `%s`", strings.Join(tt, ",")) + } + t := f.Kind.String() + if f.Kind == reflect.Struct { + t = f.Type + } + return fmt.Sprintf("%s %s%s", f.Name, t, tags) } func (f *FieldDef) isIndex() bool { return f.isPrimaryKey() || f.hasTag("index", "unique") } +func (f *FieldDef) getTag(name string) *Tag { + for _, t := range f.Tags { + if t.Name == name { + return t + } + } + return nil +} + func (f *FieldDef) hasTag(tags ...string) bool { for _, t := range f.Tags { for _, tag := range tags { @@ -55,7 +96,7 @@ const ( UniqueIndex ) -func (f *FieldDef) Type() IndexType { +func (f *FieldDef) indexType() IndexType { switch { case f.hasTag("increment", "id"): return PrimaryKey @@ -69,7 +110,7 @@ func (f *FieldDef) Type() IndexType { } func (f *FieldDef) indexKey(tableName string, value string) (string, error) { - switch f.Type() { + switch f.indexType() { case PrimaryKey: return key.ID(tableName, value), nil case SecondaryIndex, UniqueIndex: @@ -79,27 +120,13 @@ func (f *FieldDef) indexKey(tableName string, value string) (string, error) { } } -type ModelDef struct { - Name string - Fields map[string]*FieldDef - - t reflect.Type -} - -func NewModelDef(t reflect.Type) *ModelDef { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - if t.NumField() == 0 { - panic("must have at least 1 struct field") - } - m := &ModelDef{ - Name: t.Name(), - Fields: make(map[string]*FieldDef), - t: t, - } +func newFieldDefs(t reflect.Type) []*FieldDef { + fields := make([]*FieldDef, 0) for i := 0; i < t.NumField(); i++ { ft := t.Field(i) + if unicode.IsLower([]rune(ft.Name)[0]) { + continue + } tags := make([]*Tag, 0) if tagValue, ok := ft.Tag.Lookup("e2db"); ok { for _, t := range strings.Split(tagValue, ",") { @@ -111,11 +138,58 @@ func NewModelDef(t reflect.Type) *ModelDef { } } } - m.Fields[ft.Name] = &FieldDef{ + sort.Slice(tags, func(i, j int) bool { + return tags[i].Name < tags[j].Name + }) + fd := &FieldDef{ Name: ft.Name, + Kind: ft.Type.Kind(), + Type: ft.Type.String(), Tags: tags, } + if ft.Type.Kind() == reflect.Struct { + fd.Fields = newFieldDefs(ft.Type) + } + fields = append(fields, fd) + } + sort.Slice(fields, func(i, j int) bool { + return fields[i].Name < fields[j].Name + }) + return fields +} + +type ModelDef struct { + Name string + Fields []*FieldDef + CheckSum string + Version string + + t reflect.Type +} + +func NewModelDef(t reflect.Type) *ModelDef { + if t.Kind() == reflect.Ptr { + t = t.Elem() } + if t.NumField() == 0 { + panic("must have at least 1 struct field") + } + m := &ModelDef{ + Name: t.Name(), + Fields: newFieldDefs(t), + t: t, + } + if !m.hasPrimaryKey() { + panic("must specify a primary key") + } + pk := m.getPrimaryKey() + vt := pk.getTag("v") + if vt == nil { + vt = &Tag{Name: "v", Value: "0"} + pk.Tags = append(pk.Tags, vt) + } + m.Version = vt.Value + m.CheckSum = SchemaCheckSum(m) return m } @@ -127,18 +201,45 @@ func (m *ModelDef) New() *reflect.Value { return &v } +func (m *ModelDef) getPrimaryKey() *FieldDef { + for _, f := range m.Fields { + if f.isPrimaryKey() { + return f + } + } + return nil +} + +func (m *ModelDef) hasPrimaryKey() bool { + return m.getPrimaryKey() != nil +} + +func (m *ModelDef) getFieldByName(name string) (*FieldDef, bool) { + for _, f := range m.Fields { + if f.Name == name { + return f, true + } + } + return nil, false +} + +func (m *ModelDef) String() string { + return m.t.String() +} + type Field struct { *FieldDef - value reflect.Value + + v reflect.Value } func (f *Field) isZero() bool { - return f.value.IsValid() && reflect.DeepEqual(f.value.Interface(), reflect.Zero(f.value.Type()).Interface()) + return f.v.IsValid() && reflect.DeepEqual(f.v.Interface(), reflect.Zero(f.v.Type()).Interface()) } type ModelItem struct { *ModelDef - Fields map[string]*Field + Fields []*Field } func NewModelItem(v reflect.Value) *ModelItem { @@ -148,13 +249,13 @@ func NewModelItem(v reflect.Value) *ModelItem { } m := &ModelItem{ ModelDef: NewModelDef(v.Type()), - Fields: make(map[string]*Field), + Fields: make([]*Field, 0), } - for name, f := range m.ModelDef.Fields { - m.Fields[name] = &Field{ + for _, f := range m.ModelDef.Fields { + m.Fields = append(m.Fields, &Field{ FieldDef: f, - value: v.FieldByName(name), - } + v: v.FieldByName(f.Name), + }) } return m } @@ -167,3 +268,54 @@ func (m *ModelItem) getPrimaryKey() (*Field, error) { } return nil, ErrNoPrimaryKey } + +func schemaCheckSumFieldDef(f *FieldDef) string { + var sb strings.Builder + sb.WriteString(f.String()) + for _, f := range f.Fields { + switch f.Kind { + case reflect.Struct: + sb.WriteString(schemaCheckSumFieldDef(f)) + default: + sb.WriteString(f.String()) + } + } + return sb.String() +} + +func SchemaCheckSum(m *ModelDef) string { + var b bytes.Buffer + for _, f := range m.Fields { + b.WriteString(schemaCheckSumFieldDef(f)) + } + h := sha1.Sum(b.Bytes()) + name := hex.EncodeToString(h[:]) + if len(name) > 5 { + name = name[:5] + } + return strings.ToLower(name) +} + +func printFieldDef(f *FieldDef) { + fmt.Println(f) + for _, f := range f.Fields { + switch f.Kind { + case reflect.Struct: + printFieldDef(f) + default: + fmt.Println(f) + } + } +} + +func PrintModelDef(m *ModelDef) { + fmt.Println(m) + for _, f := range m.Fields { + switch f.Kind { + case reflect.Struct: + printFieldDef(f) + default: + fmt.Println(f) + } + } +} diff --git a/pkg/e2db/model_test.go b/pkg/e2db/model_test.go new file mode 100644 index 0000000..3e4f6fa --- /dev/null +++ b/pkg/e2db/model_test.go @@ -0,0 +1,48 @@ +package e2db_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/criticalstack/e2d/pkg/e2db" +) + +type ModelEnum int + +const ( + Invalid ModelEnum = iota + EnumVal1 + EnumVal2 +) + +type NestedStruct struct { + Name string + Count int +} + +type Model1 struct { + Name string `e2db:"unique,required"` + ID int `e2db:"id"` + CreatedAt time.Time + Stats NestedStruct + Enum ModelEnum +} + +type Model2 struct { + ID int `e2db:"id"` + Name string `e2db:"unique,required"` + CreatedAt time.Time + Stats NestedStruct + Enum ModelEnum +} + +func TestSchemaCheckSumArbitraryOrder(t *testing.T) { + m := e2db.NewModelDef(reflect.TypeOf(&Model1{})) + fmt.Println(m.String()) + fmt.Println(m.CheckSum) + m = e2db.NewModelDef(reflect.TypeOf(&Model2{})) + fmt.Println(m.String()) + fmt.Println(m.CheckSum) +} diff --git a/pkg/e2db/query.go b/pkg/e2db/query.go index 8956706..83ca6c5 100644 --- a/pkg/e2db/query.go +++ b/pkg/e2db/query.go @@ -219,7 +219,7 @@ func (q *query) Count(fieldName string, data interface{}) (int64, error) { if err := q.t.tableMustExist(); err != nil { return 0, err } - f, ok := q.t.meta.Fields[fieldName] + f, ok := q.t.meta.getFieldByName(fieldName) if !ok { return 0, errors.Wrap(ErrInvalidField, fieldName) } @@ -246,7 +246,7 @@ func (q *query) Find(fieldName string, data interface{}, to interface{}) error { if err := q.t.validateSchema(v.Type()); err != nil { return err } - f, ok := q.t.meta.Fields[fieldName] + f, ok := q.t.meta.getFieldByName(fieldName) if !ok { return errors.Wrap(ErrInvalidField, fieldName) } @@ -265,7 +265,7 @@ func (q *query) Find(fieldName string, data interface{}, to interface{}) error { } return q.findManyByIndex(key.Indexes(q.t.meta.Name, f.Name, k), v) } - switch f.Type() { + switch f.indexType() { case PrimaryKey: return q.findOneByPrimaryKey(key.ID(q.t.meta.Name, k), v) case UniqueIndex: @@ -273,7 +273,7 @@ func (q *query) Find(fieldName string, data interface{}, to interface{}) error { case SecondaryIndex: return q.findOneBySecondaryIndex(key.Indexes(q.t.meta.Name, f.Name, k), v) default: - return errors.Errorf("field %#v has invalid index type: %#v", fieldName, f.Type()) + return errors.Errorf("field %#v has invalid index type: %#v", fieldName, f.indexType()) } } diff --git a/pkg/e2db/tx.go b/pkg/e2db/tx.go index 91dd95c..358d40c 100644 --- a/pkg/e2db/tx.go +++ b/pkg/e2db/tx.go @@ -80,15 +80,15 @@ func (tx *Tx) Insert(iface interface{}) error { if err != nil { return err } - switch pk.value.Kind() { + switch pk.v.Kind() { case reflect.Int: - pk.value.Set(reflect.ValueOf(int(id))) + pk.v.Set(reflect.ValueOf(int(id))) case reflect.Int64: - pk.value.Set(reflect.ValueOf(int64(id))) + pk.v.Set(reflect.ValueOf(int64(id))) } } } - id := toString(pk.value.Interface()) + id := toString(pk.v.Interface()) if id == "" { return errors.Wrapf(ErrInvalidPrimaryKey, "cannot be empty: %#v", pk.Name) } @@ -97,19 +97,19 @@ func (tx *Tx) Insert(iface interface{}) error { for _, tag := range f.Tags { switch tag.Name { case "index": - indexes = append(indexes, key.Index(m.Name, f.Name, toString(f.value.Interface()), id)) + indexes = append(indexes, key.Index(m.Name, f.Name, toString(f.v.Interface()), id)) case "required": if f.isZero() { return errors.Wrap(ErrFieldRequired, f.Name) } case "unique": - k := key.Unique(m.Name, f.Name, toString(f.value.Interface())) + k := key.Unique(m.Name, f.Name, toString(f.v.Interface())) ok, err := tx.db.client.Exists(k) if err != nil { return err } if ok { - return errors.Wrapf(ErrUniqueConstraint, "%#v: %#v", f.Name, f.value.String()) + return errors.Wrapf(ErrUniqueConstraint, "%#v: %#v", f.Name, f.v.String()) } indexes = append(indexes, k) } @@ -138,7 +138,7 @@ func (tx *Tx) Update(iface interface{}) error { if err != nil { return err } - id := toString(pk.value.Interface()) + id := toString(pk.v.Interface()) if id == "" { return errors.Wrapf(ErrInvalidPrimaryKey, "cannot be empty: %#v", pk.Name) } @@ -155,7 +155,7 @@ func (tx *Tx) Update(iface interface{}) error { continue } dbFieldValue := dbValue.FieldByName(f.Name) - if reflect.DeepEqual(f.value.Interface(), dbFieldValue.Interface()) { + if reflect.DeepEqual(f.v.Interface(), dbFieldValue.Interface()) { continue } @@ -168,22 +168,22 @@ func (tx *Tx) Update(iface interface{}) error { switch tag.Name { case "index": oldIdx := key.Index(m.Name, f.Name, toString(dbFieldValue.Interface()), id) - newIdx := key.Index(m.Name, f.Name, toString(f.value.Interface()), id) + newIdx := key.Index(m.Name, f.Name, toString(f.v.Interface()), id) indexes[oldIdx] = newIdx case "unique": oldIdx := key.Unique(m.Name, f.Name, toString(dbFieldValue.Interface())) - newIdx := key.Unique(m.Name, f.Name, toString(f.value.Interface())) + newIdx := key.Unique(m.Name, f.Name, toString(f.v.Interface())) ok, err := tx.db.client.Exists(newIdx) if err != nil { return err } if ok { - return errors.Wrapf(ErrUniqueConstraint, "%#v: %#v", f.Name, f.value.String()) + return errors.Wrapf(ErrUniqueConstraint, "%#v: %#v", f.Name, f.v.String()) } indexes[oldIdx] = newIdx } } - dbFieldValue.Set(f.value) + dbFieldValue.Set(f.v) } data, err := tx.c.Encode(dbValue.Interface()) if err != nil { @@ -215,12 +215,12 @@ func (tx *Tx) getIndexesByPrimaryKey(pk string) ([]string, error) { } keys := []string{pk} _, id := filepath.Split(pk) - for n, f := range tx.meta.Fields { - switch f.Type() { + for _, f := range tx.meta.Fields { + switch f.indexType() { case UniqueIndex: - keys = append(keys, key.Unique(tx.meta.Name, n, toString(v.FieldByName(n).Interface()))) + keys = append(keys, key.Unique(tx.meta.Name, f.Name, toString(v.FieldByName(f.Name).Interface()))) case SecondaryIndex: - keys = append(keys, key.Index(tx.meta.Name, n, toString(v.FieldByName(n).Interface()), id)) + keys = append(keys, key.Index(tx.meta.Name, f.Name, toString(v.FieldByName(f.Name).Interface()), id)) } } return keys, nil @@ -244,7 +244,7 @@ func (tx *Tx) Delete(fieldName string, data interface{}) (int64, error) { zap.Duration("elapsed", time.Now().Sub(st)), ) }() - f, ok := tx.meta.Fields[fieldName] + f, ok := tx.meta.getFieldByName(fieldName) if !ok { return 0, errors.Errorf("invalid field name: %#v", fieldName) } @@ -252,7 +252,7 @@ func (tx *Tx) Delete(fieldName string, data interface{}) (int64, error) { pks := make([]string, 0) // get the primary key of the item(s) being deleted - switch f.Type() { + switch f.indexType() { case PrimaryKey: pks = append(pks, key.ID(tx.meta.Name, k)) case UniqueIndex: