Skip to content

Commit b671626

Browse files
committed
age: add ExtractHeader, DecryptHeader, and NewInjectedFileKeyIdentity
1 parent 0447d8d commit b671626

File tree

2 files changed

+110
-10
lines changed

2 files changed

+110
-10
lines changed

age.go

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
package age
4747

4848
import (
49+
"bytes"
4950
"crypto/hmac"
5051
"crypto/rand"
5152
"errors"
@@ -207,22 +208,37 @@ func (*NoIdentityMatchError) Error() string {
207208
// If no identity matches the encrypted file, the returned error will be of type
208209
// [NoIdentityMatchError].
209210
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
210-
if len(identities) == 0 {
211-
return nil, errors.New("no identities specified")
212-
}
213-
214211
hdr, payload, err := format.Parse(src)
215212
if err != nil {
216213
return nil, fmt.Errorf("failed to read header: %w", err)
217214
}
218215

216+
fileKey, err := decryptHdr(hdr, identities...)
217+
if err != nil {
218+
return nil, err
219+
}
220+
221+
nonce := make([]byte, streamNonceSize)
222+
if _, err := io.ReadFull(payload, nonce); err != nil {
223+
return nil, fmt.Errorf("failed to read nonce: %w", err)
224+
}
225+
226+
return stream.NewReader(streamKey(fileKey, nonce), payload)
227+
}
228+
229+
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
230+
if len(identities) == 0 {
231+
return nil, errors.New("no identities specified")
232+
}
233+
219234
stanzas := make([]*Stanza, 0, len(hdr.Recipients))
220235
for _, s := range hdr.Recipients {
221236
stanzas = append(stanzas, (*Stanza)(s))
222237
}
223238
errNoMatch := &NoIdentityMatchError{}
224239
var fileKey []byte
225240
for _, id := range identities {
241+
var err error
226242
fileKey, err = id.Unwrap(stanzas)
227243
if errors.Is(err, ErrIncorrectIdentity) {
228244
errNoMatch.Errors = append(errNoMatch.Errors, err)
@@ -244,12 +260,7 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
244260
return nil, errors.New("bad header MAC")
245261
}
246262

247-
nonce := make([]byte, streamNonceSize)
248-
if _, err := io.ReadFull(payload, nonce); err != nil {
249-
return nil, fmt.Errorf("failed to read nonce: %w", err)
250-
}
251-
252-
return stream.NewReader(streamKey(fileKey, nonce), payload)
263+
return fileKey, nil
253264
}
254265

255266
// multiUnwrap is a helper that implements Identity.Unwrap in terms of a
@@ -270,3 +281,49 @@ func multiUnwrap(unwrap func(*Stanza) ([]byte, error), stanzas []*Stanza) ([]byt
270281
}
271282
return nil, ErrIncorrectIdentity
272283
}
284+
285+
// ExtractHeader returns a detched header from the src file.
286+
//
287+
// The detached header can be decrypted with [DecryptHeader] and then the file
288+
// key can be used with [NewInjectedFileKeyIdentity].
289+
func ExtractHeader(src io.Reader) ([]byte, error) {
290+
hdr, _, err := format.Parse(src)
291+
if err != nil {
292+
return nil, fmt.Errorf("failed to read header: %w", err)
293+
}
294+
buf := &bytes.Buffer{}
295+
if err := hdr.Marshal(buf); err != nil {
296+
return nil, fmt.Errorf("failed to serialize header: %w", err)
297+
}
298+
return buf.Bytes(), nil
299+
}
300+
301+
// DecryptHeader decrypts a detached header and returns a file key.
302+
//
303+
// The detached header can be produced by [ExtractHeader], and the
304+
// returned file key can be used with [NewInjectedFileKeyIdentity].
305+
//
306+
// It is the caller's responsibility to keep track of what file the
307+
// returned file key decrypts, and to ensure the file key is not used
308+
// for any other purpose.
309+
func DecryptHeader(header []byte, identities ...Identity) ([]byte, error) {
310+
hdr, _, err := format.Parse(bytes.NewReader(header))
311+
if err != nil {
312+
return nil, fmt.Errorf("failed to read header: %w", err)
313+
}
314+
return decryptHdr(hdr, identities...)
315+
}
316+
317+
type injectedFileKeyIdentity struct {
318+
fileKey []byte
319+
}
320+
321+
// NewInjectedFileKeyIdentity returns an [Identity] that always produces
322+
// a fixed file key, such as one returned by [DecryptHeader].
323+
func NewInjectedFileKeyIdentity(fileKey []byte) Identity {
324+
return injectedFileKeyIdentity{fileKey}
325+
}
326+
327+
func (i injectedFileKeyIdentity) Unwrap(stanzas []*Stanza) (fileKey []byte, err error) {
328+
return i.fileKey, nil
329+
}

age_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,46 @@ func TestLabels(t *testing.T) {
284284
t.Errorf("expected pqc+foo mixed with foo+pqc to work, got %v", err)
285285
}
286286
}
287+
288+
func TestDetachedHeader(t *testing.T) {
289+
i, err := age.GenerateX25519Identity()
290+
if err != nil {
291+
t.Fatal(err)
292+
}
293+
294+
buf := &bytes.Buffer{}
295+
w, err := age.Encrypt(buf, i.Recipient())
296+
if err != nil {
297+
t.Fatal(err)
298+
}
299+
if _, err := io.WriteString(w, helloWorld); err != nil {
300+
t.Fatal(err)
301+
}
302+
if err := w.Close(); err != nil {
303+
t.Fatal(err)
304+
}
305+
encrypted := buf.Bytes()
306+
307+
header, err := age.ExtractHeader(bytes.NewReader(encrypted))
308+
if err != nil {
309+
t.Fatal(err)
310+
}
311+
312+
fileKey, err := age.DecryptHeader(header, i)
313+
if err != nil {
314+
t.Fatal(err)
315+
}
316+
317+
identity := age.NewInjectedFileKeyIdentity(fileKey)
318+
out, err := age.Decrypt(bytes.NewReader(encrypted), identity)
319+
if err != nil {
320+
t.Fatal(err)
321+
}
322+
outBytes, err := io.ReadAll(out)
323+
if err != nil {
324+
t.Fatal(err)
325+
}
326+
if string(outBytes) != helloWorld {
327+
t.Errorf("wrong data: %q, expected %q", outBytes, helloWorld)
328+
}
329+
}

0 commit comments

Comments
 (0)