Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,16 @@ func (d *Decimal) GobDecode(data []byte) error {
return d.UnmarshalBinary(data)
}

// DecodeSpanner decodes a Spanner value into a Decimal
func (d *Decimal) DecodeSpanner(val interface{}) error {
return d.Scan(val)
}

// EncodeSpanner encodes a Decimal into a Spanner value
func (d Decimal) EncodeSpanner() (interface{}, error) {
return d.String(), nil
}

// StringScaled first scales the decimal then calls .String() on it.
//
// Deprecated: buggy and unintuitive. Use StringFixed instead.
Expand Down Expand Up @@ -2117,6 +2127,32 @@ func (d NullDecimal) MarshalText() (text []byte, err error) {
return d.Decimal.MarshalText()
}

// DecodeSpanner decodes a Spanner value into a Decimal
func (d *NullDecimal) DecodeSpanner(value interface{}) error {
switch t := value.(type) {
case nil:
d.Valid = false
return nil
case *string:
if t == nil {
d.Valid = false
return nil
}
value = *t
}
d.Valid = true

return d.Decimal.Scan(value)
}

// EncodeSpanner encodes a Decimal into a Spanner value
func (d NullDecimal) EncodeSpanner() (interface{}, error) {
if !d.Valid {
return nil, nil
}
return d.Decimal.String(), nil
}

// Trig functions

// Atan returns the arctangent, in radians, of x.
Expand Down
163 changes: 163 additions & 0 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,77 @@ func TestDecimal_Value(t *testing.T) {
}
}

func decodeSpannerHelper(t *testing.T, dbval interface{}, expected Decimal) {
t.Helper()

a := Decimal{}
if err := a.DecodeSpanner(dbval); err != nil {
// DecodeSpanner failed... no need to test result value
t.Errorf("a.DecodeSpanner(%v) failed with message: %s", dbval, err)
} else if !a.Equal(expected) {
// DecodeSpanner succeeded... test resulting values
t.Errorf("%s does not equal to %s", a, expected)
}
}

type spannerDecoder interface {
DecodeSpanner(input interface{}) error
}

func TestDecimal_DecodeSpanner(t *testing.T) {
// test the DecodeSpanner method that implements spanner.Decoder interface
if _, ok := interface{}(new(Decimal)).(spannerDecoder); !ok {
t.Error("Decimal does not implement spanner.Decoder")
}

dbvalue := 54.33
expected := NewFromFloat(dbvalue)
decodeSpannerHelper(t, dbvalue, expected)

// also test uint64
dbvalueUint64 := uint64(2)
expected = New(2, 0)
decodeSpannerHelper(t, dbvalueUint64, expected)

// ensure we can handle the return of either []byte or string
valueStr := "535.666"
dbvalueStr := []byte(valueStr)
expected, err := NewFromString(valueStr)
if err != nil {
t.Fatal(err)
}
decodeSpannerHelper(t, dbvalueStr, expected)
decodeSpannerHelper(t, valueStr, expected)

type foo struct{}
a := Decimal{}
err = a.DecodeSpanner(foo{})
if err == nil {
t.Errorf("a.DecodeSpanner(Foo{}) should have thrown an error but did not")
}
}

type spannerEncoder interface {
EncodeSpanner() (interface{}, error)
}

func TestDecimal_EncodeSpanner(t *testing.T) {
// Make sure this does implement the spanner.Encoder interface
if _, ok := interface{}(Decimal{}).(spannerEncoder); !ok {
t.Error("Decimal does not implement spanner.Encoder")
}

// check that normal case is handled appropriately
a := New(1234, -2)
expected := "12.34"
value, err := a.Value()
if err != nil {
t.Errorf("Decimal(12.34).Value() failed with message: %s", err)
} else if got := value.(string); got != expected {
t.Errorf("%s does not equal to %s", a, expected)
}
}

// old tests after this line

func TestDecimal_Scale(t *testing.T) {
Expand Down Expand Up @@ -3287,6 +3358,98 @@ func TestNullDecimal_Value(t *testing.T) {
}
}

func TestNullDecimal_DecodeSpanner(t *testing.T) {
// test the DecodeSpanner method that implements the
// spanner.Decoder interface
if _, ok := interface{}(new(NullDecimal)).(spannerDecoder); !ok {
t.Error("NullDecimal does not implement spanner.Decoder")
}

// Make sure handles nil value
a := NullDecimal{}
var dbvaluePtr interface{}
err := a.DecodeSpanner(dbvaluePtr)
if err != nil {
// DecodeSpanner failed... no need to test result value
t.Errorf("a.DecodeSpanner(nil) failed with message: %s", err)
} else {
if a.Valid {
t.Errorf("%s is not null", a.Decimal)
}
}

// Make sure handles nil *string
dbvaluePtr = (*string)(nil)
if err := a.DecodeSpanner(dbvaluePtr); err != nil {
// DecodeSpanner failed... no need to test result value
t.Errorf("a.DecodeSpanner((*string)(nil)) failed with message: %s", err)
} else {
if a.Valid {
t.Errorf("%s is not null", a.Decimal)
}
}

valueStr := "535.666"
expected, err := NewFromString(valueStr)
if err != nil {
t.Fatal(err)
}

// Handle string
err = a.DecodeSpanner(valueStr)
if err != nil {
// DecodeSpanner failed... no need to test result value
t.Errorf("a.DecodeSpanner('535.666') failed with message: %s", err)
} else {
// DecodeSpanner succeeded... test resulting values
if !a.Valid {
t.Errorf("%s is null", a.Decimal)
} else if !a.Decimal.Equals(expected) {
t.Errorf("%v does not equal %v", a, expected)
}
}

// handle *string
err = a.DecodeSpanner(&valueStr)
if err != nil {
// DecodeSpanner failed... no need to test result value
t.Errorf("a.DecodeSpanner('535.666') failed with message: %s", err)
} else {
// DecodeSpanner succeeded... test resulting values
if !a.Valid {
t.Errorf("%s is null", a.Decimal)
} else if !a.Decimal.Equals(expected) {
t.Errorf("%v does not equal %v", a, expected)
}
}
}

func TestNullDecimal_EncodeSpanner(t *testing.T) {
// Make sure this does implement the spanner.Encoder interface
var nullDecimal NullDecimal
if _, ok := interface{}(nullDecimal).(spannerEncoder); !ok {
t.Error("NullDecimal does not implement spanner.Encoder")
}

// check that null is handled appropriately
value, err := nullDecimal.EncodeSpanner()
if err != nil {
t.Errorf("NullDecimal{}.Valid() failed with message: %s", err)
} else if value != nil {
t.Errorf("%v is not nil", value)
}

// check that normal case is handled appropriately
a := NullDecimal{Decimal: New(1234, -2), Valid: true}
expected := "12.34"
value, err = a.EncodeSpanner()
if err != nil {
t.Errorf("NullDecimal(12.34).EncodeSpanner() failed with message: %s", err)
} else if value.(string) != expected {
t.Errorf("%v does not equal %v", a, expected)
}
}

func TestBinary(t *testing.T) {
for _, y := range testTable {
x := y.float
Expand Down