diff --git a/acme/service.go b/acme/service.go index 492d1a2..ccc2c91 100644 --- a/acme/service.go +++ b/acme/service.go @@ -1,6 +1,9 @@ package acme -import "github.com/ImageWare/TLSential/model" +import ( + "github.com/ImageWare/TLSential/model" + lregistration "github.com/go-acme/lego/v3/registration" +) // Service implements the ability to trigger a new certificate request, or Renew // a certificate. Renewal presumes a certificate has already been issued. @@ -11,4 +14,5 @@ type Service interface { RequestRenew(id string) bool GetAutoRenewChannel() chan string GetIssueChannel() chan string + Register(lregistration.User) (*lregistration.Resource, error) } diff --git a/api/certificate.go b/api/certificate.go index 6624c3e..d5999f2 100644 --- a/api/certificate.go +++ b/api/certificate.go @@ -245,6 +245,15 @@ func (h *certHandler) Post() http.HandlerFunc { return } + reg, err := h.acme.Register(c) + + if err != nil { + log.Printf("api CertHandler POST, acme.Register(), %s", err.Error()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + c.ACMERegistration = reg + //TODO: Should probably decide valid range for client supplied RenewAt value //For instance we may not want them to be able to specify 0 or less, as that would //cause the cert to never auto renew. Although maybe thats a valid use case? diff --git a/model/certificate.go b/model/certificate.go index a018c24..0ed68f0 100644 --- a/model/certificate.go +++ b/model/certificate.go @@ -14,8 +14,6 @@ import ( "golang.org/x/net/idna" "github.com/ImageWare/TLSential/auth" - "github.com/go-acme/lego/v3/certcrypto" - "github.com/go-acme/lego/v3/lego" "github.com/go-acme/lego/v3/registration" "github.com/segmentio/ksuid" ) @@ -109,23 +107,6 @@ func NewCertificate(domains []string, email string) (*Certificate, error) { ACMEKey: privateKey, } - config := lego.NewConfig(c) - - config.CADirURL = CADirURL - config.Certificate.KeyType = certcrypto.RSA2048 - - client, err := lego.NewClient(config) - if err != nil { - return nil, err - } - - // TODO: Move this to acme Service so we can mock here - reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) - if err != nil { - return nil, err - } - c.ACMERegistration = reg - return c, nil } diff --git a/model/certificate_test.go b/model/certificate_test.go index 0a05a63..acca257 100644 --- a/model/certificate_test.go +++ b/model/certificate_test.go @@ -26,12 +26,6 @@ func TestNewCertificate(t *testing.T) { "test@notexample.com", ErrInvalidDomains.Error(), }, - { - "email at example.com", - []string{"example.com"}, - "test@example.com", - "acme: error: 400 :: POST :: https://acme-v02.api.letsencrypt.org/acme/new-acct :: urn:ietf:params:acme:error:invalidEmail :: Error creating new account :: invalid contact domain. Contact emails @example.com are forbidden, url: ", - }, { "wildcard domain", []string{"*.example.com"}, @@ -121,11 +115,6 @@ func TestNewCertificate(t *testing.T) { t.Error("email mismatch") } - // TODO: Test ACMERegistration values, like Status, ToS, etc. - if c.ACMERegistration == nil { - t.Error("acme registration shouldn't be nil") - } - if c.ACMEKey == nil { t.Error("acme key should not be nil") } diff --git a/service/acme.go b/service/acme.go index e21766b..3bb7d1e 100644 --- a/service/acme.go +++ b/service/acme.go @@ -11,6 +11,7 @@ import ( "github.com/go-acme/lego/v3/certcrypto" lcert "github.com/go-acme/lego/v3/certificate" "github.com/go-acme/lego/v3/lego" + lregistration "github.com/go-acme/lego/v3/registration" ) var certAutoRenewChan chan string @@ -19,8 +20,15 @@ var certIssueChan chan string type acmeService struct { certService cert.Service challService challenge_config.Service + registrar UserRegistrar } +type UserRegistrar interface { + Register(u lregistration.User) (*lregistration.Resource, error) +} + +type legoRegistrar struct{} + func CreateChannelsAndListeners(buffSize int, listeners int, cs cert.Service, as acme.Service) { certAutoRenewChan = make(chan string, buffSize) certIssueChan = make(chan string) @@ -56,8 +64,14 @@ func handleCertChannels(cs cert.Service, as acme.Service) { } } +//Create a new acme.Service with a default LEGO registrar func NewAcmeService(cts cert.Service, chs challenge_config.Service) acme.Service { - return &acmeService{certService: cts, challService: chs} + return NewAcmeServiceWithRegistrar(cts, chs, &legoRegistrar{}) +} + +//Create a new acme.Service that uses the supplied UserRegistrar. registrar must not be nil +func NewAcmeServiceWithRegistrar(cts cert.Service, chs challenge_config.Service, registrar UserRegistrar) acme.Service { + return &acmeService{certService: cts, challService: chs, registrar: registrar} } //RequestRenew will try to send to the CertAutoRenewChan channel, but won't block if the channel is full. @@ -221,6 +235,10 @@ func (s *acmeService) Renew(c *model.Certificate) { } +func (s *acmeService) Register(u lregistration.User) (*lregistration.Resource, error) { + return s.registrar.Register(u) +} + func getExpiry(c *model.Certificate) time.Time { x509Cert, err := certcrypto.ParsePEMCertificate(c.Certificate) if err != nil { @@ -229,3 +247,19 @@ func getExpiry(c *model.Certificate) time.Time { return x509Cert.NotAfter } + +func (l *legoRegistrar) Register(u lregistration.User) (*lregistration.Resource, error) { + config := lego.NewConfig(u) + + config.CADirURL = model.CADirURL + config.Certificate.KeyType = certcrypto.RSA2048 + + c, err := lego.NewClient(config) + + if err != nil { + return nil, err + } + + reg, err := c.Registration.Register(lregistration.RegisterOptions{TermsOfServiceAgreed: true}) + return reg, err +} diff --git a/service/acme_test.go b/service/acme_test.go new file mode 100644 index 0000000..076fbe2 --- /dev/null +++ b/service/acme_test.go @@ -0,0 +1,126 @@ +package service + +import ( + "errors" + "testing" + + "github.com/ImageWare/TLSential/acme" + "github.com/ImageWare/TLSential/model" + lregistration "github.com/go-acme/lego/v3/registration" +) + +type certTest struct { + testName string + domains []string + email string + registrar UserRegistrar + expectedError string +} + +type justReturnRegistrar struct { + resource *lregistration.Resource + err error +} + +func (r *justReturnRegistrar) Register(u lregistration.User) (*lregistration.Resource, error) { + return r.resource, r.err +} + +func TestRegister(t *testing.T) { + passThruError := errors.New("This is the expected error") + certTests := []certTest{ + { + "happy path", + []string{"example.com", "example2.com"}, + "test@notexample.com", + &justReturnRegistrar{nil, nil}, + "", + }, + { + //This test makes sure the registrar is actually being called + "return error", + []string{"somestuff.com"}, + "test@aurl.com", + &justReturnRegistrar{nil, passThruError}, + passThruError.Error(), + }, + } + + for _, ct := range certTests { + t.Run(ct.testName, func(t *testing.T) { + + c, err := model.NewCertificate(ct.domains, ct.email) + + if err != nil { + t.Error("Error creating certificate", err) + return + } + var a acme.Service + if ct.registrar == nil { + a = NewAcmeService(nil, nil) + } else { + a = NewAcmeServiceWithRegistrar(nil, nil, ct.registrar) + } + reg, err := a.Register(c) + + if err != nil { + c.ACMERegistration = reg + } + + if err == nil { + if ct.expectedError != "" { + t.Error("no error returned when expected") + return + } + } + + if err != nil { + if err.Error() != ct.expectedError { + t.Errorf("error mismatch: got %s, expected %s\n", err.Error(), ct.expectedError) + } + return + } + + if c.LastError != nil { + t.Error("last error shouldn't be set") + } + + if c.ACMEEmail != ct.email { + t.Error("email mismatch") + } + + if c.ACMEKey == nil { + t.Error("acme key should not be nil") + } + }) + } +} + +func TestChannels(t *testing.T) { + t.Run("request_issue", func(t *testing.T) { + CreateChannelsAndListeners(1, 0, nil, nil) + + a := NewAcmeServiceWithRegistrar(nil, nil, nil) + + if !a.RequestRenew("id") { + t.Error("Should not have blocked yet") + return + } + + if a.RequestIssue("id2") { + t.Error("Should have blocked") + return + } + + select { + case id := <-a.GetAutoRenewChannel(): + if id != "id" { + t.Errorf("expected 'id' but got '%s'", id) + } + break + default: + t.Error("Could not read from channel") + } + }) + +}