Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions storage/storage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Package storage handles keeping the files on disk.
package storage

import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"sync"

"github.com/letsencrypt/test-certs-site/config"
)

type version string

const (
next version = "next"
current version = "current"
)

const (
privateKeyFilename = "private.pem"
certificateFilename = "certificate.pem"
)

const (
// dirPerms rwxr-xr-x for created directories. Writable only by user, but global r & x for debugging.
dirPerms = 0o755

// keyPerms rw------- for private keys. No permissions outside of user.
keyPerms = 0o600

// certPerms rw-r--r-- for cert files. Globally readable certs for debugging.
certPerms = 0o644
)

// Storage of files for a domain.
type Storage struct {
// mu prevents simultaneous writing of files, or reading while writing.
mu sync.Mutex
dir string
}

// New storage handle.
func New(storageDir string) (*Storage, error) {
return &Storage{dir: storageDir}, nil
}

// StoreNextKey generates a new "next" key, writing it to disk.
func (s *Storage) StoreNextKey(domain string, keyType string) (crypto.Signer, error) {
var key crypto.Signer
switch keyType {
case config.KeyTypeP256:
p256Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
key = p256Key
case config.KeyTypeRSA2048:
bits := 2048
rsaKey, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
key = rsaKey
default:
// Should be unreachable due to config validation
return nil, fmt.Errorf("unknown key type: %s", keyType)
}

keyBytes, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return nil, err
}

pemBytes := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: keyBytes,
})

path := s.pathFor(domain, next, privateKeyFilename)

s.mu.Lock()
defer s.mu.Unlock()

err = os.MkdirAll(filepath.Dir(path), dirPerms)
if err != nil {
return nil, err
}

err = os.WriteFile(path, pemBytes, keyPerms)
if err != nil {
return nil, err
}

return key, nil
}

// StoreNextCert stores the next certificate for the domain.
// Certificates should be a sequence of DER certificates.
func (s *Storage) StoreNextCert(domain string, certificates [][]byte) error {
s.mu.Lock()
defer s.mu.Unlock()

certPath := s.pathFor(domain, next, certificateFilename)
cert, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, certPerms) //nolint:gosec // Arbitrary file is not a risk here
if err != nil {
return fmt.Errorf("could not open certificate file: %w", err)
}
defer cert.Close()

for _, data := range certificates {
err := pem.Encode(cert, &pem.Block{
Type: "CERTIFICATE",
Bytes: data,
})
if err != nil {
return fmt.Errorf("could not write certificate: %w", err)
}
}

return nil
}

// TakeNext overwrites the current cert/key with the next cert/key, and returns the new current values.
func (s *Storage) TakeNext(domain string) (tls.Certificate, error) {
s.mu.Lock()
defer s.mu.Unlock()

err := os.MkdirAll(s.pathFor(domain, current, ""), dirPerms)
if err != nil {
return tls.Certificate{}, err
}

// Read the next values we're about to make current.
// Doing this before renaming ensures the key and certificate match.
cert, err := s.read(domain, next)
if err != nil {
return tls.Certificate{}, fmt.Errorf("reading next certificate: %w", err)
}

for _, path := range []string{privateKeyFilename, certificateFilename} {
nextPath := s.pathFor(domain, next, path)
currPath := s.pathFor(domain, current, path)
err := os.Rename(nextPath, currPath)
if err != nil {
return tls.Certificate{}, err
}
}

return cert, nil
}

// ReadCurrent reads the current cert and key for this domain.
// Returns an error if the stored value couldn't be read or parsed.
func (s *Storage) ReadCurrent(domain string) (tls.Certificate, error) {
s.mu.Lock()
defer s.mu.Unlock()

return s.read(domain, current)
}

// ReadNext reads the next cert and key for this domain.
// Returns an error if the stored value couldn't be read or parsed.
func (s *Storage) ReadNext(domain string) (tls.Certificate, error) {
s.mu.Lock()
defer s.mu.Unlock()

return s.read(domain, next)
}

// read a cert and key. Common logic for ReadCurrent and ReadNext. Caller should hold mu.
func (s *Storage) read(domain string, ver version) (tls.Certificate, error) {
return tls.LoadX509KeyPair(s.pathFor(domain, ver, certificateFilename), s.pathFor(domain, ver, privateKeyFilename))
}

func (s *Storage) pathFor(domain string, ver version, file string) string {
return filepath.Join(s.dir, domain, string(ver), file)
}
82 changes: 82 additions & 0 deletions storage/storage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package storage

import (
"crypto"
"crypto/rand"
"crypto/x509"
"testing"

"github.com/letsencrypt/test-certs-site/config"
)

// TestStorage goes through the expected storage lifecycle.
func TestStorage(t *testing.T) {
t.Parallel()

storage, err := New(t.TempDir())
if err != nil {
t.Fatal(err)
}

const domain = "interesting.salad"

// Go through the lifecycle 3 times
for i := range 3 {
// Alternate key types to test both
keyType := config.KeyTypeP256
if i%2 == 1 {
keyType = config.KeyTypeRSA2048
}
key, err := storage.StoreNextKey(domain, keyType)
if err != nil {
t.Fatal(err)
}

// Outside of tests, this would come from a CA:
certs := testCert(t, domain, key)

err = storage.StoreNextCert(domain, certs)
if err != nil {
t.Fatal(err)
}

_, err = storage.ReadNext(domain)
if err != nil {
t.Fatal(err)
}

// A real user of the storage package would validate the certs here.
// Eg, checking if they're expired or revoked.

_, err = storage.TakeNext(domain)
if err != nil {
t.Fatal(err)
}

current, err := storage.ReadCurrent(domain)
if err != nil {
t.Fatal(err)
}

if current.Leaf.DNSNames[0] != domain {
t.Fatalf("Expected %s DNS SAN", domain)
}
}
}

// testCert returns a test self-signed cert for the given key.
func testCert(t *testing.T, domain string, key crypto.Signer) [][]byte {
t.Helper()

// Create a certificate template
template := x509.Certificate{
DNSNames: []string{domain},
}

certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, key.Public(), key)
if err != nil {
t.Fatalf("Failed to create certificate: %v", err)
}

return [][]byte{certDER}
}