Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion acme/service.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package acme

import "github.com/ImageWare/TLSential/model"
import (
"github.com/ImageWare/TLSential/model"
lregistration "github.com/go-acme/lego/v3/registration"
)

// Service implements the ability to trigger a new certificate request, or Renew
// a certificate. Renewal presumes a certificate has already been issued.
Expand All @@ -11,4 +14,5 @@ type Service interface {
RequestRenew(id string) bool
GetAutoRenewChannel() chan string
GetIssueChannel() chan string
Register(lregistration.User) (*lregistration.Resource, error)
}
9 changes: 9 additions & 0 deletions api/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,15 @@ func (h *certHandler) Post() http.HandlerFunc {
return
}

reg, err := h.acme.Register(c)

if err != nil {
log.Printf("api CertHandler POST, acme.Register(), %s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
c.ACMERegistration = reg

//TODO: Should probably decide valid range for client supplied RenewAt value
//For instance we may not want them to be able to specify 0 or less, as that would
//cause the cert to never auto renew. Although maybe thats a valid use case?
Expand Down
19 changes: 0 additions & 19 deletions model/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"golang.org/x/net/idna"

"github.com/ImageWare/TLSential/auth"
"github.com/go-acme/lego/v3/certcrypto"
"github.com/go-acme/lego/v3/lego"
"github.com/go-acme/lego/v3/registration"
"github.com/segmentio/ksuid"
)
Expand Down Expand Up @@ -109,23 +107,6 @@ func NewCertificate(domains []string, email string) (*Certificate, error) {
ACMEKey: privateKey,
}

config := lego.NewConfig(c)

config.CADirURL = CADirURL
config.Certificate.KeyType = certcrypto.RSA2048

client, err := lego.NewClient(config)
if err != nil {
return nil, err
}

// TODO: Move this to acme Service so we can mock here
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return nil, err
}
c.ACMERegistration = reg

return c, nil
}

Expand Down
11 changes: 0 additions & 11 deletions model/certificate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ func TestNewCertificate(t *testing.T) {
"test@notexample.com",
ErrInvalidDomains.Error(),
},
{
"email at example.com",
[]string{"example.com"},
"test@example.com",
"acme: error: 400 :: POST :: https://acme-v02.api.letsencrypt.org/acme/new-acct :: urn:ietf:params:acme:error:invalidEmail :: Error creating new account :: invalid contact domain. Contact emails @example.com are forbidden, url: ",
},
{
"wildcard domain",
[]string{"*.example.com"},
Expand Down Expand Up @@ -121,11 +115,6 @@ func TestNewCertificate(t *testing.T) {
t.Error("email mismatch")
}

// TODO: Test ACMERegistration values, like Status, ToS, etc.
if c.ACMERegistration == nil {
t.Error("acme registration shouldn't be nil")
}

if c.ACMEKey == nil {
t.Error("acme key should not be nil")
}
Expand Down
36 changes: 35 additions & 1 deletion service/acme.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/go-acme/lego/v3/certcrypto"
lcert "github.com/go-acme/lego/v3/certificate"
"github.com/go-acme/lego/v3/lego"
lregistration "github.com/go-acme/lego/v3/registration"
)

var certAutoRenewChan chan string
Expand All @@ -19,8 +20,15 @@ var certIssueChan chan string
type acmeService struct {
certService cert.Service
challService challenge_config.Service
registrar UserRegistrar
}

type UserRegistrar interface {
Register(u lregistration.User) (*lregistration.Resource, error)
}

type legoRegistrar struct{}

func CreateChannelsAndListeners(buffSize int, listeners int, cs cert.Service, as acme.Service) {
certAutoRenewChan = make(chan string, buffSize)
certIssueChan = make(chan string)
Expand Down Expand Up @@ -56,8 +64,14 @@ func handleCertChannels(cs cert.Service, as acme.Service) {
}
}

//Create a new acme.Service with a default LEGO registrar
func NewAcmeService(cts cert.Service, chs challenge_config.Service) acme.Service {
return &acmeService{certService: cts, challService: chs}
return NewAcmeServiceWithRegistrar(cts, chs, &legoRegistrar{})
}

//Create a new acme.Service that uses the supplied UserRegistrar. registrar must not be nil
func NewAcmeServiceWithRegistrar(cts cert.Service, chs challenge_config.Service, registrar UserRegistrar) acme.Service {
return &acmeService{certService: cts, challService: chs, registrar: registrar}
}

//RequestRenew will try to send to the CertAutoRenewChan channel, but won't block if the channel is full.
Expand Down Expand Up @@ -221,6 +235,10 @@ func (s *acmeService) Renew(c *model.Certificate) {

}

func (s *acmeService) Register(u lregistration.User) (*lregistration.Resource, error) {
return s.registrar.Register(u)
}

func getExpiry(c *model.Certificate) time.Time {
x509Cert, err := certcrypto.ParsePEMCertificate(c.Certificate)
if err != nil {
Expand All @@ -229,3 +247,19 @@ func getExpiry(c *model.Certificate) time.Time {

return x509Cert.NotAfter
}

func (l *legoRegistrar) Register(u lregistration.User) (*lregistration.Resource, error) {
config := lego.NewConfig(u)

config.CADirURL = model.CADirURL
config.Certificate.KeyType = certcrypto.RSA2048

c, err := lego.NewClient(config)

if err != nil {
return nil, err
}

reg, err := c.Registration.Register(lregistration.RegisterOptions{TermsOfServiceAgreed: true})
return reg, err
}
126 changes: 126 additions & 0 deletions service/acme_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package service

import (
"errors"
"testing"

"github.com/ImageWare/TLSential/acme"
"github.com/ImageWare/TLSential/model"
lregistration "github.com/go-acme/lego/v3/registration"
)

type certTest struct {
testName string
domains []string
email string
registrar UserRegistrar
expectedError string
}

type justReturnRegistrar struct {
resource *lregistration.Resource
err error
}

func (r *justReturnRegistrar) Register(u lregistration.User) (*lregistration.Resource, error) {
return r.resource, r.err
}

func TestRegister(t *testing.T) {
passThruError := errors.New("This is the expected error")
certTests := []certTest{
{
"happy path",
[]string{"example.com", "example2.com"},
"test@notexample.com",
&justReturnRegistrar{nil, nil},
"",
},
{
//This test makes sure the registrar is actually being called
"return error",
[]string{"somestuff.com"},
"test@aurl.com",
&justReturnRegistrar{nil, passThruError},
passThruError.Error(),
},
}

for _, ct := range certTests {
t.Run(ct.testName, func(t *testing.T) {

c, err := model.NewCertificate(ct.domains, ct.email)

if err != nil {
t.Error("Error creating certificate", err)
return
}
var a acme.Service
if ct.registrar == nil {
a = NewAcmeService(nil, nil)
} else {
a = NewAcmeServiceWithRegistrar(nil, nil, ct.registrar)
}
reg, err := a.Register(c)

if err != nil {
c.ACMERegistration = reg
}

if err == nil {
if ct.expectedError != "" {
t.Error("no error returned when expected")
return
}
}

if err != nil {
if err.Error() != ct.expectedError {
t.Errorf("error mismatch: got %s, expected %s\n", err.Error(), ct.expectedError)
}
return
}

if c.LastError != nil {
t.Error("last error shouldn't be set")
}

if c.ACMEEmail != ct.email {
t.Error("email mismatch")
}

if c.ACMEKey == nil {
t.Error("acme key should not be nil")
}
})
}
}

func TestChannels(t *testing.T) {
t.Run("request_issue", func(t *testing.T) {
CreateChannelsAndListeners(1, 0, nil, nil)

a := NewAcmeServiceWithRegistrar(nil, nil, nil)

if !a.RequestRenew("id") {
t.Error("Should not have blocked yet")
return
}

if a.RequestIssue("id2") {
t.Error("Should have blocked")
return
}

select {
case id := <-a.GetAutoRenewChannel():
if id != "id" {
t.Errorf("expected 'id' but got '%s'", id)
}
break
default:
t.Error("Could not read from channel")
}
})

}