From 0225de5acb474a3c269cd48c3d6bb9c0a6e64019 Mon Sep 17 00:00:00 2001 From: Joseph Watson Date: Thu, 23 Jan 2025 16:48:11 -0500 Subject: [PATCH 1/2] feature: Add spanner support for Encode and Decode --- decimal.go | 38 +++++++++++ decimal_test.go | 163 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+) diff --git a/decimal.go b/decimal.go index a37a230..c921a02 100644 --- a/decimal.go +++ b/decimal.go @@ -1898,6 +1898,16 @@ func (d *Decimal) GobDecode(data []byte) error { return d.UnmarshalBinary(data) } +// DecodeSpanner decodes a Spanner value into a Decimal +func (d *Decimal) DecodeSpanner(val interface{}) error { + return d.Scan(val) +} + +// EncodeSpanner encodes a Decimal into a Spanner value +func (d Decimal) EncodeSpanner() (interface{}, error) { + return d.String(), nil +} + // StringScaled first scales the decimal then calls .String() on it. // // Deprecated: buggy and unintuitive. Use StringFixed instead. @@ -2117,6 +2127,34 @@ func (d NullDecimal) MarshalText() (text []byte, err error) { return d.Decimal.MarshalText() } +// DecodeSpanner decodes a Spanner value into a Decimal +func (d *NullDecimal) DecodeSpanner(value interface{}) error { + switch t := value.(type) { + case nil: + d.Valid = false + + return nil + case *string: + if t == nil { + d.Valid = false + + return nil + } + value = *t + } + d.Valid = true + + return d.Decimal.Scan(value) +} + +// EncodeSpanner encodes a Decimal into a Spanner value +func (d NullDecimal) EncodeSpanner() (interface{}, error) { + if !d.Valid { + return nil, nil + } + return d.Decimal.String(), nil +} + // Trig functions // Atan returns the arctangent, in radians, of x. diff --git a/decimal_test.go b/decimal_test.go index d398f2d..d1a8a55 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -2491,6 +2491,77 @@ func TestDecimal_Value(t *testing.T) { } } +func decodeSpannerHelper(t *testing.T, dbval interface{}, expected Decimal) { + t.Helper() + + a := Decimal{} + if err := a.DecodeSpanner(dbval); err != nil { + // DecodeSpanner failed... no need to test result value + t.Errorf("a.DecodeSpanner(%v) failed with message: %s", dbval, err) + } else if !a.Equal(expected) { + // DecodeSpanner succeeded... test resulting values + t.Errorf("%s does not equal to %s", a, expected) + } +} + +type spannerDecoder interface { + DecodeSpanner(input interface{}) error +} + +func TestDecimal_DecodeSpanner(t *testing.T) { + // test the DecodeSpanner method that implements spanner.Decoder interface + if _, ok := interface{}(new(Decimal)).(spannerDecoder); !ok { + t.Error("Decimal does not implement spanner.Decoder") + } + + dbvalue := 54.33 + expected := NewFromFloat(dbvalue) + decodeSpannerHelper(t, dbvalue, expected) + + // also test uint64 + dbvalueUint64 := uint64(2) + expected = New(2, 0) + decodeSpannerHelper(t, dbvalueUint64, expected) + + // ensure we can handle the return of either []byte or string + valueStr := "535.666" + dbvalueStr := []byte(valueStr) + expected, err := NewFromString(valueStr) + if err != nil { + t.Fatal(err) + } + decodeSpannerHelper(t, dbvalueStr, expected) + decodeSpannerHelper(t, valueStr, expected) + + type foo struct{} + a := Decimal{} + err = a.DecodeSpanner(foo{}) + if err == nil { + t.Errorf("a.DecodeSpanner(Foo{}) should have thrown an error but did not") + } +} + +type spannerEncoder interface { + EncodeSpanner() (interface{}, error) +} + +func TestDecimal_EncodeSpanner(t *testing.T) { + // Make sure this does implement the spanner.Encoder interface + if _, ok := interface{}(Decimal{}).(spannerEncoder); !ok { + t.Error("Decimal does not implement spanner.Encoder") + } + + // check that normal case is handled appropriately + a := New(1234, -2) + expected := "12.34" + value, err := a.Value() + if err != nil { + t.Errorf("Decimal(12.34).Value() failed with message: %s", err) + } else if got := value.(string); got != expected { + t.Errorf("%s does not equal to %s", a, expected) + } +} + // old tests after this line func TestDecimal_Scale(t *testing.T) { @@ -3287,6 +3358,98 @@ func TestNullDecimal_Value(t *testing.T) { } } +func TestNullDecimal_DecodeSpanner(t *testing.T) { + // test the DecodeSpanner method that implements the + // spanner.Decoder interface + if _, ok := interface{}(new(NullDecimal)).(spannerDecoder); !ok { + t.Error("NullDecimal does not implement spanner.Decoder") + } + + // Make sure handles nil value + a := NullDecimal{} + var dbvaluePtr interface{} + err := a.DecodeSpanner(dbvaluePtr) + if err != nil { + // DecodeSpanner failed... no need to test result value + t.Errorf("a.DecodeSpanner(nil) failed with message: %s", err) + } else { + if a.Valid { + t.Errorf("%s is not null", a.Decimal) + } + } + + // Make sure handles nil *string + dbvaluePtr = (*string)(nil) + if err := a.DecodeSpanner(dbvaluePtr); err != nil { + // DecodeSpanner failed... no need to test result value + t.Errorf("a.DecodeSpanner((*string)(nil)) failed with message: %s", err) + } else { + if a.Valid { + t.Errorf("%s is not null", a.Decimal) + } + } + + valueStr := "535.666" + expected, err := NewFromString(valueStr) + if err != nil { + t.Fatal(err) + } + + // Handle string + err = a.DecodeSpanner(valueStr) + if err != nil { + // DecodeSpanner failed... no need to test result value + t.Errorf("a.DecodeSpanner('535.666') failed with message: %s", err) + } else { + // DecodeSpanner succeeded... test resulting values + if !a.Valid { + t.Errorf("%s is null", a.Decimal) + } else if !a.Decimal.Equals(expected) { + t.Errorf("%v does not equal %v", a, expected) + } + } + + // handle *string + err = a.DecodeSpanner(&valueStr) + if err != nil { + // DecodeSpanner failed... no need to test result value + t.Errorf("a.DecodeSpanner('535.666') failed with message: %s", err) + } else { + // DecodeSpanner succeeded... test resulting values + if !a.Valid { + t.Errorf("%s is null", a.Decimal) + } else if !a.Decimal.Equals(expected) { + t.Errorf("%v does not equal %v", a, expected) + } + } +} + +func TestNullDecimal_EncodeSpanner(t *testing.T) { + // Make sure this does implement the spanner.Encoder interface + var nullDecimal NullDecimal + if _, ok := interface{}(nullDecimal).(spannerEncoder); !ok { + t.Error("NullDecimal does not implement spanner.Encoder") + } + + // check that null is handled appropriately + value, err := nullDecimal.EncodeSpanner() + if err != nil { + t.Errorf("NullDecimal{}.Valid() failed with message: %s", err) + } else if value != nil { + t.Errorf("%v is not nil", value) + } + + // check that normal case is handled appropriately + a := NullDecimal{Decimal: New(1234, -2), Valid: true} + expected := "12.34" + value, err = a.EncodeSpanner() + if err != nil { + t.Errorf("NullDecimal(12.34).EncodeSpanner() failed with message: %s", err) + } else if value.(string) != expected { + t.Errorf("%v does not equal %v", a, expected) + } +} + func TestBinary(t *testing.T) { for _, y := range testTable { x := y.float From 7d968187da83ca9aec9fe3aec8a0b3098a4c4c8c Mon Sep 17 00:00:00 2001 From: Joseph Watson Date: Thu, 23 Jan 2025 17:04:30 -0500 Subject: [PATCH 2/2] cleanup --- decimal.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/decimal.go b/decimal.go index c921a02..6cc8042 100644 --- a/decimal.go +++ b/decimal.go @@ -2132,12 +2132,10 @@ func (d *NullDecimal) DecodeSpanner(value interface{}) error { switch t := value.(type) { case nil: d.Valid = false - return nil case *string: if t == nil { d.Valid = false - return nil } value = *t