diff --git a/age.go b/age.go index a4c2ad3d..bb1c1ccc 100644 --- a/age.go +++ b/age.go @@ -46,6 +46,7 @@ package age import ( + "bytes" "crypto/hmac" "crypto/rand" "errors" @@ -207,15 +208,29 @@ func (*NoIdentityMatchError) Error() string { // If no identity matches the encrypted file, the returned error will be of type // [NoIdentityMatchError]. func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { - if len(identities) == 0 { - return nil, errors.New("no identities specified") - } - hdr, payload, err := format.Parse(src) if err != nil { return nil, fmt.Errorf("failed to read header: %w", err) } + fileKey, err := decryptHdr(hdr, identities...) + if err != nil { + return nil, err + } + + nonce := make([]byte, streamNonceSize) + if _, err := io.ReadFull(payload, nonce); err != nil { + return nil, fmt.Errorf("failed to read nonce: %w", err) + } + + return stream.NewReader(streamKey(fileKey, nonce), payload) +} + +func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) { + if len(identities) == 0 { + return nil, errors.New("no identities specified") + } + stanzas := make([]*Stanza, 0, len(hdr.Recipients)) for _, s := range hdr.Recipients { stanzas = append(stanzas, (*Stanza)(s)) @@ -223,6 +238,7 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { errNoMatch := &NoIdentityMatchError{} var fileKey []byte for _, id := range identities { + var err error fileKey, err = id.Unwrap(stanzas) if errors.Is(err, ErrIncorrectIdentity) { errNoMatch.Errors = append(errNoMatch.Errors, err) @@ -244,12 +260,7 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { return nil, errors.New("bad header MAC") } - nonce := make([]byte, streamNonceSize) - if _, err := io.ReadFull(payload, nonce); err != nil { - return nil, fmt.Errorf("failed to read nonce: %w", err) - } - - return stream.NewReader(streamKey(fileKey, nonce), payload) + return fileKey, nil } // multiUnwrap is a helper that implements Identity.Unwrap in terms of a @@ -270,3 +281,49 @@ func multiUnwrap(unwrap func(*Stanza) ([]byte, error), stanzas []*Stanza) ([]byt } return nil, ErrIncorrectIdentity } + +// ExtractHeader returns a detched header from the src file. +// +// The detached header can be decrypted with [DecryptHeader] and then the file +// key can be used with [NewInjectedFileKeyIdentity]. +func ExtractHeader(src io.Reader) ([]byte, error) { + hdr, _, err := format.Parse(src) + if err != nil { + return nil, fmt.Errorf("failed to read header: %w", err) + } + buf := &bytes.Buffer{} + if err := hdr.Marshal(buf); err != nil { + return nil, fmt.Errorf("failed to serialize header: %w", err) + } + return buf.Bytes(), nil +} + +// DecryptHeader decrypts a detached header and returns a file key. +// +// The detached header can be produced by [ExtractHeader], and the +// returned file key can be used with [NewInjectedFileKeyIdentity]. +// +// It is the caller's responsibility to keep track of what file the +// returned file key decrypts, and to ensure the file key is not used +// for any other purpose. +func DecryptHeader(header []byte, identities ...Identity) ([]byte, error) { + hdr, _, err := format.Parse(bytes.NewReader(header)) + if err != nil { + return nil, fmt.Errorf("failed to read header: %w", err) + } + return decryptHdr(hdr, identities...) +} + +type injectedFileKeyIdentity struct { + fileKey []byte +} + +// NewInjectedFileKeyIdentity returns an [Identity] that always produces +// a fixed file key, such as one returned by [DecryptHeader]. +func NewInjectedFileKeyIdentity(fileKey []byte) Identity { + return injectedFileKeyIdentity{fileKey} +} + +func (i injectedFileKeyIdentity) Unwrap(stanzas []*Stanza) (fileKey []byte, err error) { + return i.fileKey, nil +} diff --git a/age_test.go b/age_test.go index 8cf68670..ef870d47 100644 --- a/age_test.go +++ b/age_test.go @@ -284,3 +284,46 @@ func TestLabels(t *testing.T) { t.Errorf("expected pqc+foo mixed with foo+pqc to work, got %v", err) } } + +func TestDetachedHeader(t *testing.T) { + i, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := age.Encrypt(buf, i.Recipient()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(w, helloWorld); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + encrypted := buf.Bytes() + + header, err := age.ExtractHeader(bytes.NewReader(encrypted)) + if err != nil { + t.Fatal(err) + } + + fileKey, err := age.DecryptHeader(header, i) + if err != nil { + t.Fatal(err) + } + + identity := age.NewInjectedFileKeyIdentity(fileKey) + out, err := age.Decrypt(bytes.NewReader(encrypted), identity) + if err != nil { + t.Fatal(err) + } + outBytes, err := io.ReadAll(out) + if err != nil { + t.Fatal(err) + } + if string(outBytes) != helloWorld { + t.Errorf("wrong data: %q, expected %q", outBytes, helloWorld) + } +}