diff --git a/constraint.go b/constraint.go index c526fd3..97c5067 100644 --- a/constraint.go +++ b/constraint.go @@ -3,23 +3,8 @@ package hm import "fmt" // A Constraint is well.. a constraint that says a must equal to b. It's used mainly in the constraint generation process. -type Constraint struct { - a, b Type -} +type Constraint Pair -func (c Constraint) Apply(sub Subs) Substitutable { - c.a = c.a.Apply(sub).(Type) - c.b = c.b.Apply(sub).(Type) - return c -} - -func (c Constraint) FreeTypeVar() TypeVarSet { - var retVal TypeVarSet - retVal = c.a.FreeTypeVar().Union(retVal) - retVal = c.b.FreeTypeVar().Union(retVal) - return retVal -} - -func (c Constraint) Format(state fmt.State, r rune) { - fmt.Fprintf(state, "{%v = %v}", c.a, c.b) -} +func (c Constraint) Apply(sub Subs) Substitutable { return Constraint(*(*Pair)(&c).Apply(sub)) } +func (c Constraint) FreeTypeVar() TypeVarSet { return Pair(c).FreeTypeVar() } +func (c Constraint) Format(state fmt.State, r rune) { fmt.Fprintf(state, "{%v = %v}", c.A, c.B) } diff --git a/constraint_test.go b/constraint_test.go index 5894a3e..10a9506 100644 --- a/constraint_test.go +++ b/constraint_test.go @@ -4,8 +4,8 @@ import "testing" func TestConstraint(t *testing.T) { c := Constraint{ - a: TypeVariable('a'), - b: NewFnType(TypeVariable('b'), TypeVariable('c')), + A: TypeVariable('a'), + B: NewFnType(TypeVariable('b'), TypeVariable('c')), } ftv := c.FreeTypeVar() @@ -20,11 +20,11 @@ func TestConstraint(t *testing.T) { } c = c.Apply(subs).(Constraint) - if !c.a.Eq(NewFnType(proton, proton)) { + if !c.A.Eq(NewFnType(proton, proton)) { t.Errorf("c.a: %v", c) } - if !c.b.Eq(NewFnType(proton, neutron)) { + if !c.B.Eq(NewFnType(proton, neutron)) { t.Errorf("c.b: %v", c) } } diff --git a/debug.go b/debug.go index 76fe63c..cbd4fe4 100644 --- a/debug.go +++ b/debug.go @@ -11,7 +11,7 @@ import ( ) // DEBUG returns true when it's in debug mode -const DEBUG = false +const DEBUG = true var tabcount uint32 diff --git a/hm.go b/hm.go index 5d8cc8b..ef03f7a 100644 --- a/hm.go +++ b/hm.go @@ -99,7 +99,7 @@ func (infer *inferer) consGen(expr Expression) (err error) { tv := infer.Fresh() cs := append(fnCs, bodyCs...) - cs = append(cs, Constraint{fnType, NewFnType(bodyType, tv)}) + cs = append(cs, Constraint{NewFnType(bodyType, tv), fnType}) infer.t = tv infer.cs = cs @@ -334,14 +334,14 @@ func Unify(a, b Type) (sub Subs, err error) { switch at := a.(type) { case TypeVariable: - return bind(at, b) + return Bind(at, b) default: if a.Eq(b) { return nil, nil } if btv, ok := b.(TypeVariable); ok { - return bind(btv, a) + return Bind(btv, a) } atypes := a.Types() btypes := b.Types() @@ -385,7 +385,7 @@ func unifyMany(a, b Types) (sub Subs, err error) { if sub == nil { sub = s2 } else { - sub2 := compose(sub, s2) + sub2 := Compose(sub, s2) defer ReturnSubs(s2) if sub2 != sub { defer ReturnSubs(sub) @@ -396,11 +396,12 @@ func unifyMany(a, b Types) (sub Subs, err error) { return } -func bind(tv TypeVariable, t Type) (sub Subs, err error) { +// Bind binds a TypeVariable to a Type. It returns a substitution list. +func Bind(tv TypeVariable, t Type) (sub Subs, err error) { logf("Binding %v to %v", tv, t) switch { // case tv == t: - case occurs(tv, t): + case Occurs(tv, t): err = errors.Errorf("recursive unification") default: ssub := BorrowSSubs(1) @@ -411,7 +412,8 @@ func bind(tv TypeVariable, t Type) (sub Subs, err error) { return } -func occurs(tv TypeVariable, s Substitutable) bool { +// Occurs checks if a TypeVariable exists in any Substitutable (type, scheme, map etc). +func Occurs(tv TypeVariable, s Substitutable) bool { ftv := s.FreeTypeVar() defer ReturnTypeVarSet(ftv) diff --git a/perf.go b/perf.go index c324932..e5e50da 100644 --- a/perf.go +++ b/perf.go @@ -1,6 +1,8 @@ package hm -import "sync" +import ( + "sync" +) const ( poolSize = 4 @@ -160,3 +162,19 @@ func ReturnFnType(fnt *FunctionType) { fnt.b = nil fnTypePool.Put(fnt) } + +var pairPool = &sync.Pool{ + New: func() interface{} { return new(Pair) }, +} + +// BorrowPair allows access to this package's pair pool +func BorrowPair() *Pair { + return pairPool.Get().(*Pair) +} + +// ReturnPair allows accesso this package's pair pool +func ReturnPair(p *Pair) { + p.A = nil + p.B = nil + pairPool.Put(p) +} diff --git a/scheme.go b/scheme.go index 23aface..615ae4c 100644 --- a/scheme.go +++ b/scheme.go @@ -52,6 +52,11 @@ func (s *Scheme) Clone() *Scheme { } func (s *Scheme) Format(state fmt.State, c rune) { + if s == nil { + state.Write([]byte("∀[∅].∅")) + return + } + state.Write([]byte("∀[")) for i, tv := range s.tvs { if i < len(s.tvs)-1 { @@ -82,7 +87,7 @@ func (s *Scheme) Normalize() (err error) { defer ReturnTypeVarSet(tfv) ord := BorrowTypeVarSet(len(tfv)) for i := range tfv { - ord[i] = TypeVariable(letters[i]) + ord[i] = TypeVariable('a' + i) } s.t, err = s.t.Normalize(tfv, ord) diff --git a/solver.go b/solver.go index 80a1142..8bea50e 100644 --- a/solver.go +++ b/solver.go @@ -24,11 +24,12 @@ func (s *solver) solve(cs Constraints) { default: var sub Subs c := cs[0] - sub, s.err = Unify(c.a, c.b) + sub, s.err = Unify(c.A, c.B) defer ReturnSubs(s.sub) - s.sub = compose(sub, s.sub) + s.sub = Compose(sub, s.sub) cs = cs[1:].Apply(s.sub).(Constraints) + s.solve(cs) } diff --git a/solver_test.go b/solver_test.go index d206fce..12a1a17 100644 --- a/solver_test.go +++ b/solver_test.go @@ -38,6 +38,26 @@ var solverTest = []struct { }, mSubs{'a': neutron, 'b': proton}, false, }, + + // (a -> a) and (b -> c) + { + Constraints{ + { + NewFnType(TypeVariable('b'), TypeVariable('c')), + NewFnType(TypeVariable('a'), TypeVariable('a')), + }, + }, + mSubs{'b': TypeVariable('a'), 'c': TypeVariable('a')}, false, + }, + { + Constraints{ + { + NewFnType(TypeVariable('a'), TypeVariable('a')), + NewFnType(TypeVariable('b'), TypeVariable('c')), + }, + }, + mSubs{'b': TypeVariable('c'), 'a': TypeVariable('b')}, false, + }, } func TestSolver(t *testing.T) { diff --git a/struct.go b/struct.go new file mode 100644 index 0000000..0b93860 --- /dev/null +++ b/struct.go @@ -0,0 +1,97 @@ +package hm + +// this file provides a common structural abstraction + +// Pair is a convenient structural abstraction for types that are composed of two types. +// Depending on use cases, it may be useful to embed Pair, or define a new type base on *Pair. +// +// Pair partially implements Type, as the intention is merely for syntactic abstraction +// +// It has very specific semantics - +// it's useful for a small subset of types like function types, or supertypes. +// See the documentation for Apply and FreeTypeVar. +type Pair struct { + A, B Type +} + +// Apply applies a substitution on both the first and second types of the Pair. +func (t *Pair) Apply(sub Subs) *Pair { + retVal := t.Clone() + retVal.UnsafeApply(sub) + return retVal +} + +// UnsafeApply is an unsafe application of the substitution. +func (t *Pair) UnsafeApply(sub Subs) { + t.A = t.A.Apply(sub).(Type) + t.B = t.B.Apply(sub).(Type) +} + +// Types returns all the types of the Pair's constituents +func (t Pair) Types() Types { + retVal := BorrowTypes(2) + retVal[0] = t.A + retVal[1] = t.B + return retVal +} + +// FreeTypeVar returns a set of free (unbound) type variables. +func (t Pair) FreeTypeVar() TypeVarSet { return t.A.FreeTypeVar().Union(t.B.FreeTypeVar()) } + +// Clone implements Cloner +func (t *Pair) Clone() *Pair { + retVal := BorrowPair() + + if ac, ok := t.A.(Cloner); ok { + retVal.A = ac.Clone().(Type) + } else { + retVal.A = t.A + } + + if bc, ok := t.B.(Cloner); ok { + retVal.B = bc.Clone().(Type) + } else { + retVal.B = t.B + } + return retVal +} + +// Monuple is a convenient structural abstraction for types that are composed of one type. +// +// Monuple implements Substitutable, but with very specific semantics - +// It's useful for singly polymorphic types like arrays, linear types, reference types, etc +type Monuple struct { + T Type +} + +// Apply applies a substitution to the monuple type. +func (t Monuple) Apply(subs Subs) Monuple { + t.T = t.T.Apply(subs).(Type) + return t +} + +// FreeTypeVar returns the set of free type variables in the monuple. +func (t Monuple) FreeTypeVar() TypeVarSet { return t.T.FreeTypeVar() } + +// Normalize is the method to normalize all type variables +func (t Monuple) Normalize(k, v TypeVarSet) (Monuple, error) { + var t2 Type + var err error + if t2, err = t.T.Normalize(k, v); err != nil { + return Monuple{}, err + } + t.T = t2 + return t, nil +} + +// Pairer is any type that can be represented by a Pair +type Pairer interface { + Type + AsPair() *Pair +} + +// Monupler is any type that can be represented by a Monuple +type Monupler interface { + Type + AsMonuple() Monuple +} diff --git a/substitutables_test.go b/substitutables_test.go index 484cdec..aebdbfa 100644 --- a/substitutables_test.go +++ b/substitutables_test.go @@ -25,17 +25,17 @@ func TestConstraints(t *testing.T) { } cs = cs.Apply(sub).(Constraints) - if cs[0].a != neutron { + if cs[0].A != neutron { t.Error("Expected neutron") } - if cs[0].b != proton { + if cs[0].B != proton { t.Error("Expected proton") } - if cs[1].a != TypeVariable('b') { + if cs[1].A != TypeVariable('b') { t.Error("There was nothing to substitute b with") } - if cs[1].b != proton { + if cs[1].B != proton { t.Error("Expected proton") } diff --git a/substitutions.go b/substitutions.go index a21dd11..3a85e25 100644 --- a/substitutions.go +++ b/substitutions.go @@ -14,6 +14,15 @@ type Subs interface { Clone() Subs } +// MakeSubs is a utility function to help make substitution lists. +// This is useful for cases where there isn't a real need to implement Subs +func MakeSubs(n int) Subs { + if n >= 30 { + return make(mSubs) + } + return newSliceSubs(n) +} + // A Substitution is a tuple representing the TypeVariable and the replacement Type type Substitution struct { Tv TypeVariable @@ -116,7 +125,8 @@ func (s mSubs) Clone() Subs { return retVal } -func compose(a, b Subs) (retVal Subs) { +// Compose composes two substitution lists together. +func Compose(a, b Subs) (retVal Subs) { if b == nil { return a } diff --git a/substitutions_test.go b/substitutions_test.go index aabbd14..9e467a0 100644 --- a/substitutions_test.go +++ b/substitutions_test.go @@ -131,7 +131,7 @@ var composeTests = []struct { func TestCompose(t *testing.T) { for i, cts := range composeTests { - subs := compose(cts.a, cts.b) + subs := Compose(cts.a, cts.b) for _, v := range cts.expected.Iter() { if T, ok := subs.Get(v.Tv); !ok { diff --git a/type.go b/type.go index 6d2a1bc..384687e 100644 --- a/type.go +++ b/type.go @@ -7,10 +7,10 @@ import ( // Type represents all the possible type constructors. type Type interface { Substitutable - Name() string // Name is the name of the constructor - Normalize(TypeVarSet, TypeVarSet) (Type, error) // Normalize normalizes all the type variable names in the type - Types() Types // If the type is made up of smaller types, then it will return them - Eq(Type) bool // equality operation + Name() string // Name is the name of the constructor + Normalize(k TypeVarSet, v TypeVarSet) (Type, error) // Normalize normalizes all the type variable names in the type + Types() Types // If the type is made up of smaller types, then it will return them + Eq(Type) bool // equality operation fmt.Formatter fmt.Stringer diff --git a/typeVariable.go b/typeVariable.go index 2d33945..1fc383e 100644 --- a/typeVariable.go +++ b/typeVariable.go @@ -9,7 +9,7 @@ import ( // TypeVariable is a variable that ranges over the types - that is to say it can take any type. type TypeVariable rune -func (t TypeVariable) Name() string { return string(t) } +func (t TypeVariable) Name() string { return fmt.Sprintf("%v", t) } func (t TypeVariable) Apply(sub Subs) Substitutable { if sub == nil { return t @@ -30,7 +30,13 @@ func (t TypeVariable) Normalize(k, v TypeVarSet) (Type, error) { return nil, errors.Errorf("Type Variable %v not in signature", t) } -func (t TypeVariable) Types() Types { return nil } -func (t TypeVariable) String() string { return string(t) } -func (t TypeVariable) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%c", rune(t)) } -func (t TypeVariable) Eq(other Type) bool { return other == t } +func (t TypeVariable) Types() Types { return nil } +func (t TypeVariable) String() string { return fmt.Sprintf("%v", t) } +func (t TypeVariable) Format(s fmt.State, c rune) { + if t >= 'a' && t <= 'z' { + fmt.Fprintf(s, "%c", rune(t)) + return + } + fmt.Fprintf(s, "<%d>", rune(t)) +} +func (t TypeVariable) Eq(other Type) bool { return other == t } diff --git a/types/function.go b/types/function.go new file mode 100644 index 0000000..27d3942 --- /dev/null +++ b/types/function.go @@ -0,0 +1,105 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +// Function is a type constructor that builds function types. +type Function hm.Pair + +// NewFunction creates a new FunctionType. Functions are by default right associative. This: +// NewFunction(a, a, a) +// is short hand for this: +// NewFunction(a, NewFunction(a, a)) +func NewFunction(ts ...hm.Type) *Function { + if len(ts) < 2 { + panic("Expected at least 2 input types") + } + + retVal := borrowFn() + retVal.A = ts[0] + + if len(ts) > 2 { + retVal.B = NewFunction(ts[1:]...) + } else { + retVal.B = ts[1] + } + return retVal +} + +func (t *Function) Name() string { return "→" } +func (t *Function) Apply(sub hm.Subs) hm.Substitutable { return (*Function)((*hm.Pair)(t).Apply(sub)) } +func (t *Function) FreeTypeVar() hm.TypeVarSet { return ((*hm.Pair)(t)).FreeTypeVar() } +func (t *Function) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v → %v", t.A, t.B) } +func (t *Function) String() string { return fmt.Sprintf("%v", t) } +func (t *Function) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + var a, b hm.Type + var err error + if a, err = t.A.Normalize(k, v); err != nil { + return nil, err + } + + if b, err = t.B.Normalize(k, v); err != nil { + return nil, err + } + + return NewFunction(a, b), nil +} +func (t *Function) Types() hm.Types { return ((*hm.Pair)(t)).Types() } + +func (t *Function) Eq(other hm.Type) bool { + if ot, ok := other.(*Function); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +// Other methods (accessors mainly) + +// Arg returns the type of the function argument +func (t *Function) Arg() hm.Type { return t.A } + +// Ret returns the return type of a function. If recursive is true, it will get the final return type +func (t *Function) Ret(recursive bool) hm.Type { + if !recursive { + return t.B + } + + if fnt, ok := t.B.(*Function); ok { + return fnt.Ret(recursive) + } + + return t.B +} + +// FlatTypes returns the types in FunctionTypes as a flat slice of types. This allows for easier iteration in some applications +func (t *Function) FlatTypes() hm.Types { + retVal := hm.BorrowTypes(8) // start with 8. Can always grow + retVal = retVal[:0] + + if a, ok := t.A.(*Function); ok { + ft := a.FlatTypes() + retVal = append(retVal, ft...) + hm.ReturnTypes(ft) + } else { + retVal = append(retVal, t.A) + } + + if b, ok := t.B.(*Function); ok { + ft := b.FlatTypes() + retVal = append(retVal, ft...) + hm.ReturnTypes(ft) + } else { + retVal = append(retVal, t.B) + } + return retVal +} + +// Clone implenents cloner +func (t *Function) Clone() interface{} { + p := (*hm.Pair)(t) + cloned := p.Clone() + return (*Function)(cloned) +} diff --git a/types/function_test.go b/types/function_test.go new file mode 100644 index 0000000..b8a973b --- /dev/null +++ b/types/function_test.go @@ -0,0 +1,99 @@ +package hmtypes + +import ( + "testing" + + "github.com/chewxy/hm" + "github.com/stretchr/testify/assert" +) + +func TestFunctionTypeBasics(t *testing.T) { + fnType := NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a')) + if fnType.Name() != "→" { + t.Errorf("FunctionType should have \"→\" as a name. Got %q instead", fnType.Name()) + } + + if fnType.String() != "a → a → a" { + t.Errorf("Expected \"a → a → a\". Got %q instead", fnType.String()) + } + + if !fnType.Arg().Eq(hm.TypeVariable('a')) { + t.Error("Expected arg of function to be 'a'") + } + + if !fnType.Ret(false).Eq(NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a'))) { + t.Error("Expected ret(false) to be a → a") + } + + if !fnType.Ret(true).Eq(hm.TypeVariable('a')) { + t.Error("Expected final return type to be 'a'") + } + + // a very simple fn + fnType = NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a')) + if !fnType.Ret(true).Eq(hm.TypeVariable('a')) { + t.Error("Expected final return type to be 'a'") + } + + ftv := fnType.FreeTypeVar() + if len(ftv) != 1 { + t.Errorf("Expected only one free type var") + } + + for _, fas := range fnApplyTests { + fn := fas.fn.Apply(fas.sub).(*Function) + if !fn.Eq(fas.expected) { + t.Errorf("Expected %v. Got %v instead", fas.expected, fn) + } + } + + // bad shit + f := func() { + NewFunction(hm.TypeVariable('a')) + } + assert.Panics(t, f) +} + +var fnApplyTests = []struct { + fn *Function + sub hm.Subs + + expected *Function +}{ + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, proton)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, neutron)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'c': proton, 'd': neutron}, NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'))}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'a': proton, 'c': neutron}, NewFunction(proton, hm.TypeVariable('b'))}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'c': proton, 'b': neutron}, NewFunction(hm.TypeVariable('a'), neutron)}, + {NewFunction(electron, proton), mSubs{'a': proton, 'b': neutron}, NewFunction(electron, proton)}, + + // a -> (b -> c) + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, neutron, proton)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, proton, neutron)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, neutron, hm.TypeVariable('c'))}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('c'), hm.TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, hm.TypeVariable('c'), neutron)}, + + // (a -> b) -> c + {NewFunction(NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), hm.TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFunction(NewFunction(proton, neutron), proton)}, +} + +func TestFunctionType_FlatTypes(t *testing.T) { + fnType := NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')) + ts := fnType.FlatTypes() + correct := hm.Types{hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')} + assert.Equal(t, ts, correct) + + fnType2 := NewFunction(fnType, hm.TypeVariable('d')) + correct = append(correct, hm.TypeVariable('d')) + ts = fnType2.FlatTypes() + assert.Equal(t, ts, correct) +} + +func TestFunctionType_Clone(t *testing.T) { + fnType := NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')) + assert.Equal(t, fnType.Clone(), fnType) + + rec := NewTupleType("", hm.TypeVariable('a'), NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), hm.TypeVariable('c')) + fnType = NewFunction(rec, rec) + assert.Equal(t, fnType.Clone(), fnType) +} diff --git a/types/interfaces.go b/types/interfaces.go new file mode 100644 index 0000000..b0afe8a --- /dev/null +++ b/types/interfaces.go @@ -0,0 +1,5 @@ +package hmtypes + +type Cloner interface { + Clone() interface{} +} diff --git a/types/monuples.go b/types/monuples.go new file mode 100644 index 0000000..e85d196 --- /dev/null +++ b/types/monuples.go @@ -0,0 +1,88 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +// Slice is the type of a Slice/List +type Slice hm.Monuple + +func (t Slice) Name() string { return "List" } +func (t Slice) Apply(subs hm.Subs) hm.Substitutable { return Slice(hm.Monuple(t).Apply(subs)) } +func (t Slice) FreeTypeVar() hm.TypeVarSet { return hm.Monuple(t).FreeTypeVar() } +func (t Slice) Format(s fmt.State, c rune) { fmt.Fprintf(s, "[]%v", t.T) } +func (t Slice) String() string { return fmt.Sprintf("%v", t) } +func (t Slice) Types() hm.Types { return hm.Types{t.T} } + +func (t Slice) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + t2, err := hm.Monuple(t).Normalize(k, v) + if err != nil { + return nil, err + } + return Slice(t2), nil +} + +func (t Slice) Eq(other hm.Type) bool { + if ot, ok := other.(Slice); ok { + return ot.T.Eq(t.T) + } + return false +} + +func (t Slice) Monuple() hm.Monuple { return hm.Monuple(t) } + +// Linear is a linear type (i.e types that can only appear once) +type Linear hm.Monuple + +func (t Linear) Name() string { return "Linear" } +func (t Linear) Apply(subs hm.Subs) hm.Substitutable { return Linear(hm.Monuple(t).Apply(subs)) } +func (t Linear) FreeTypeVar() hm.TypeVarSet { return hm.Monuple(t).FreeTypeVar() } +func (t Linear) Format(s fmt.State, c rune) { fmt.Fprintf(s, "Linear[%v]", t.T) } +func (t Linear) String() string { return fmt.Sprintf("%v", t) } +func (t Linear) Types() hm.Types { return hm.Types{t.T} } + +func (t Linear) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + t2, err := hm.Monuple(t).Normalize(k, v) + if err != nil { + return nil, err + } + return Linear(t2), nil +} + +func (t Linear) Eq(other hm.Type) bool { + if ot, ok := other.(Linear); ok { + return ot.T.Eq(t.T) + } + return false +} + +func (t Linear) Monuple() hm.Monuple { return hm.Monuple(t) } + +// Ref is a reference type (think pointers) +type Ref hm.Monuple + +func (t Ref) Name() string { return "Ref" } +func (t Ref) Apply(subs hm.Subs) hm.Substitutable { return Ref(hm.Monuple(t).Apply(subs)) } +func (t Ref) FreeTypeVar() hm.TypeVarSet { return hm.Monuple(t).FreeTypeVar() } +func (t Ref) Format(s fmt.State, c rune) { fmt.Fprintf(s, "*%v", t.T) } +func (t Ref) String() string { return fmt.Sprintf("%v", t) } +func (t Ref) Types() hm.Types { return hm.Types{t.T} } + +func (t Ref) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + t2, err := hm.Monuple(t).Normalize(k, v) + if err != nil { + return nil, err + } + return Ref(t2), nil +} + +func (t Ref) Eq(other hm.Type) bool { + if ot, ok := other.(Ref); ok { + return ot.T.Eq(t.T) + } + return false +} + +func (t Ref) Monuple() hm.Monuple { return hm.Monuple(t) } diff --git a/types/pairs.go b/types/pairs.go new file mode 100644 index 0000000..eb46bcc --- /dev/null +++ b/types/pairs.go @@ -0,0 +1,119 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +var ( + _ hm.Type = &Choice{} + _ hm.Type = &Super{} + _ hm.Type = &Application{} +) + +// pair types + +// Choice is the type of choice of algorithm to use within a class method. +// +// Imagine how one would implement a class in an OOP language. +// Then imagine how one would implement method overloading for the class. +// The typical approach is name mangling followed by having a jump table. +// +// Now consider OOP classes and the ability to override methods, based on subclassing ability. +// The typical approach to this is to use a Vtable. +// +// Both overloading and overriding have a general notion: a jump table of sorts. +// How does one type such a table? +// +// By using Choice. +// +// The first type is the key of either the vtable or the name mangled table. +// The second type is the value of the table. +type Choice hm.Pair + +func (t *Choice) Name() string { return ":" } +func (t *Choice) Apply(sub hm.Subs) hm.Substitutable { ((*hm.Pair)(t)).Apply(sub); return t } +func (t *Choice) FreeTypeVar() hm.TypeVarSet { return ((*hm.Pair)(t)).FreeTypeVar() } +func (t *Choice) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v : %v", t.A, t.B) } +func (t *Choice) String() string { return fmt.Sprintf("%v", t) } + +func (t *Choice) Normalize(k hm.TypeVarSet, v hm.TypeVarSet) (hm.Type, error) { + panic("not implemented") +} + +func (t *Choice) Types() hm.Types { return ((*hm.Pair)(t)).Types() } + +func (t *Choice) Eq(other hm.Type) bool { + if ot, ok := other.(*Choice); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +func (t *Choice) Clone() interface{} { return (*Choice)((*hm.Pair)(t).Clone()) } + +func (t *Choice) Pair() *hm.Pair { return (*hm.Pair)(t) } + +// Super is the inverse of Choice. It allows for supertyping functions. +// +// Supertyping is typically implemented as a adding an entry to the vtable/mangled table. +// But there needs to be a separate accounting structure to keep account of the types. +// +// This is where Super comes in. +type Super hm.Pair + +func (t *Super) Name() string { return "§" } +func (t *Super) Apply(sub hm.Subs) hm.Substitutable { ((*hm.Pair)(t)).Apply(sub); return t } +func (t *Super) FreeTypeVar() hm.TypeVarSet { return ((*hm.Pair)(t)).FreeTypeVar() } +func (t *Super) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v §: %v", t.A, t.B) } +func (t *Super) String() string { return fmt.Sprintf("%v", t) } + +func (t *Super) Normalize(k hm.TypeVarSet, v hm.TypeVarSet) (hm.Type, error) { + panic("not implemented") +} + +func (t *Super) Types() hm.Types { return ((*hm.Pair)(t)).Types() } + +func (t *Super) Eq(other hm.Type) bool { + if ot, ok := other.(*Super); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +func (t *Super) Clone() interface{} { return (*Super)((*hm.Pair)(t).Clone()) } + +func (t *Super) Pair() *hm.Pair { return (*hm.Pair)(t) } + +// Application is the pre-unified type for a function application. +// In a simple HM system this would not be needed as the type of an +// application expression would be found during the unification phase of +// the expression. +// +// In advanced systems where unification may be done concurrently, this would +// be required, as a "thunk" of sorts for the type system. +type Application hm.Pair + +func (t *Application) Name() string { return "•" } +func (t *Application) Apply(sub hm.Subs) hm.Substitutable { ((*hm.Pair)(t)).Apply(sub); return t } +func (t *Application) FreeTypeVar() hm.TypeVarSet { return ((*hm.Pair)(t)).FreeTypeVar() } +func (t *Application) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v • %v", t.A, t.B) } +func (t *Application) String() string { return fmt.Sprintf("%v", t) } + +func (t *Application) Normalize(k hm.TypeVarSet, v hm.TypeVarSet) (hm.Type, error) { + panic("not implemented") +} + +func (t *Application) Types() hm.Types { return ((*hm.Pair)(t)).Types() } + +func (t *Application) Eq(other hm.Type) bool { + if ot, ok := other.(*Application); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +func (t *Application) Clone() interface{} { return (*Application)((*hm.Pair)(t).Clone()) } + +func (t *Application) Pair() *hm.Pair { return (*hm.Pair)(t) } diff --git a/types/perf.go b/types/perf.go new file mode 100644 index 0000000..5902304 --- /dev/null +++ b/types/perf.go @@ -0,0 +1,28 @@ +package hmtypes + +import ( + "unsafe" + + "github.com/chewxy/hm" +) + +func borrowFn() *Function { + got := hm.BorrowPair() + return (*Function)(unsafe.Pointer(got)) +} + +// ReturnFn returns a *FunctionType to the pool. NewFnType automatically borrows from the pool. USE WITH CAUTION +func ReturnFn(fnt *Function) { + if a, ok := fnt.A.(*Function); ok { + ReturnFn(a) + } + + if b, ok := fnt.B.(*Function); ok { + ReturnFn(b) + } + + fnt.A = nil + fnt.B = nil + p := (*hm.Pair)(unsafe.Pointer(fnt)) + hm.ReturnPair(p) +} diff --git a/types/perf_test.go b/types/perf_test.go new file mode 100644 index 0000000..7d82844 --- /dev/null +++ b/types/perf_test.go @@ -0,0 +1,19 @@ +package hmtypes + +import "testing" + +func TestFnTypePool(t *testing.T) { + f := borrowFn() + f.A = NewFunction(proton, electron) + f.B = NewFunction(proton, neutron) + + ReturnFn(f) + f = borrowFn() + if f.A != nil { + t.Error("FunctionType not cleaned up: a is not nil") + } + if f.B != nil { + t.Error("FunctionType not cleaned up: b is not nil") + } + +} diff --git a/types/quantified.go b/types/quantified.go new file mode 100644 index 0000000..a54ae75 --- /dev/null +++ b/types/quantified.go @@ -0,0 +1,9 @@ +package hmtypes + +import "github.com/chewxy/hm" + +// Quantified is essentially a replacement scheme that is made into a Type +// TODO: implement hm.Type +type Quantified struct { + hm.Scheme +} diff --git a/types/record.go b/types/record.go new file mode 100644 index 0000000..5ebfb45 --- /dev/null +++ b/types/record.go @@ -0,0 +1,207 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +// Tuple is a basic tuple type. It takes an optional name +type Tuple struct { + ts []hm.Type + name string +} + +// NewTupleType creates a new Tuple +func NewTupleType(name string, ts ...hm.Type) *Tuple { + return &Tuple{ + ts: ts, + name: name, + } +} + +func (t *Tuple) Apply(subs hm.Subs) hm.Substitutable { + ts := t.apply(subs) + return NewTupleType(t.name, ts...) +} + +func (t *Tuple) FreeTypeVar() hm.TypeVarSet { + var tvs hm.TypeVarSet + for _, v := range t.ts { + tvs = v.FreeTypeVar().Union(tvs) + } + return tvs +} + +func (t *Tuple) Name() string { + if t.name != "" { + return t.name + } + return t.String() +} + +func (t *Tuple) Normalize(k, v hm.TypeVarSet) (T hm.Type, err error) { + var ts []hm.Type + if ts, err = t.normalize(k, v); err != nil { + return nil, err + } + return NewTupleType(t.name, ts...), nil +} + +func (t *Tuple) Types() hm.Types { + ts := hm.BorrowTypes(len(t.ts)) + copy(ts, t.ts) + return ts +} + +func (t *Tuple) Format(f fmt.State, c rune) { + f.Write([]byte("(")) + for i, v := range t.ts { + if i < len(t.ts)-1 { + fmt.Fprintf(f, "%v, ", v) + } else { + fmt.Fprintf(f, "%v)", v) + } + } +} + +func (t *Tuple) String() string { return fmt.Sprintf("%v", t) } + +func (t *Tuple) Eq(other hm.Type) bool { + if ot, ok := other.(*Tuple); ok { + if len(ot.ts) != len(t.ts) { + return false + } + for i, v := range t.ts { + if !v.Eq(ot.ts[i]) { + return false + } + } + return true + } + return false +} + +// Clone implements Cloner +func (t *Tuple) Clone() interface{} { + retVal := new(Tuple) + ts := hm.BorrowTypes(len(t.ts)) + for i, tt := range t.ts { + if c, ok := tt.(Cloner); ok { + ts[i] = c.Clone().(hm.Type) + } else { + ts[i] = tt + } + } + retVal.ts = ts + retVal.name = t.name + + return retVal +} + +// internal function to be used by Tuple.Apply and Record.Apply +func (t *Tuple) apply(subs hm.Subs) []hm.Type { + ts := make([]hm.Type, len(t.ts)) + for i, v := range t.ts { + ts[i] = v.Apply(subs).(hm.Type) + } + return ts +} + +// internal function to be used by Tuple.Normalize and Record.Normalize +func (t *Tuple) normalize(k, v hm.TypeVarSet) ([]hm.Type, error) { + ts := make([]hm.Type, len(t.ts)) + var err error + for i, tt := range t.ts { + if ts[i], err = tt.Normalize(k, v); err != nil { + return nil, err + } + } + return ts, nil +} + +// Field is a name-type pair. +type Field struct { + Name string + Type hm.Type +} + +// Record is a basic record type. It's like Tuple except there are named fields. It takes an optional name. +type Record struct { + Tuple + ns []string // field names +} + +// NewRecordType creates a new Record hm.Type +func NewRecordType(name string, fields ...Field) *Record { + ts := make([]hm.Type, len(fields)) + ns := make([]string, len(fields)) + for i := range fields { + ns[i] = fields[i].Name + ts[i] = fields[i].Type + } + return &Record{ + Tuple: Tuple{ + ts: ts, + name: name, + }, + ns: ns, + } +} + +func (t *Record) Apply(subs hm.Subs) hm.Substitutable { + ts := t.apply(subs) + return &Record{ + Tuple: Tuple{ + ts: ts, + name: t.name, + }, + ns: t.ns, + } +} + +func (t *Record) Normalize(k, v hm.TypeVarSet) (T hm.Type, err error) { + var ts []hm.Type + if ts, err = t.normalize(k, v); err != nil { + return nil, err + } + return &Record{ + Tuple: Tuple{ + ts: ts, + name: t.name, + }, + ns: t.ns, + }, nil +} + +func (t *Record) Format(f fmt.State, c rune) { + if t.name != "" { + f.Write([]byte(t.name)) + } + f.Write([]byte("{")) + for i, v := range t.ts { + if i < len(t.ts)-1 { + fmt.Fprintf(f, "%v: %v, ", t.ns[i], v) + } else { + fmt.Fprintf(f, "%v: %v}", t.ns[i], v) + } + } +} + +func (t *Record) Eq(other hm.Type) bool { + if ot, ok := other.(*Record); ok { + if len(ot.ts) != len(t.ts) { + return false + } + for i, v := range t.ts { + if t.ns[i] != ot.ns[i] { + return false + } + if !v.Eq(ot.ts[i]) { + return false + } + } + return true + } + return false +} diff --git a/types/test_test.go b/types/test_test.go new file mode 100644 index 0000000..c3a67ba --- /dev/null +++ b/types/test_test.go @@ -0,0 +1,42 @@ +package hmtypes + +import "github.com/chewxy/hm" + +const ( + proton hm.TypeConst = "proton" + neutron hm.TypeConst = "neutron" + quark hm.TypeConst = "quark" + + electron hm.TypeConst = "electron" + positron hm.TypeConst = "positron" + muon hm.TypeConst = "muon" + + photon hm.TypeConst = "photon" + higgs hm.TypeConst = "higgs" +) + +// useful copy pasta from the hm package +type mSubs map[hm.TypeVariable]hm.Type + +func (s mSubs) Get(tv hm.TypeVariable) (hm.Type, bool) { retVal, ok := s[tv]; return retVal, ok } +func (s mSubs) Add(tv hm.TypeVariable, t hm.Type) hm.Subs { s[tv] = t; return s } +func (s mSubs) Remove(tv hm.TypeVariable) hm.Subs { delete(s, tv); return s } + +func (s mSubs) Iter() []hm.Substitution { + retVal := make([]hm.Substitution, len(s)) + var i int + for k, v := range s { + retVal[i] = hm.Substitution{k, v} + i++ + } + return retVal +} + +func (s mSubs) Size() int { return len(s) } +func (s mSubs) Clone() hm.Subs { + retVal := make(mSubs) + for k, v := range s { + retVal[k] = v + } + return retVal +}