@@ -17,6 +17,7 @@ import (
1717 "github.com/ProtonMail/go-crypto/openpgp/ecdh"
1818 "github.com/ProtonMail/go-crypto/openpgp/elgamal"
1919 "github.com/ProtonMail/go-crypto/openpgp/errors"
20+ "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
2021 "github.com/ProtonMail/go-crypto/openpgp/internal/encoding"
2122 "github.com/ProtonMail/go-crypto/openpgp/mlkem_ecdh"
2223 "github.com/ProtonMail/go-crypto/openpgp/symmetric"
@@ -35,12 +36,15 @@ type EncryptedKey struct {
3536 CipherFunc CipherFunction // only valid after a successful Decrypt for a v3 packet
3637 Key []byte // only valid after a successful Decrypt
3738
38- encryptedMPI1 encoding.Field // Only valid in RSA, Elgamal, ECDH, AEAD and PQC keys
39+ encryptedMPI1 encoding.Field // Only valid in RSA, Elgamal, ECDH, and PQC keys
3940 encryptedMPI2 encoding.Field // Only valid in Elgamal, ECDH and PQC keys
4041 encryptedMPI3 encoding.Field // Only valid in PQC keys
4142 ephemeralPublicX25519 * x25519.PublicKey // used for x25519
4243 ephemeralPublicX448 * x448.PublicKey // used for x448
4344 encryptedSession []byte // used for x25519 and x448
45+
46+ nonce []byte
47+ aeadMode algorithm.AEADMode
4448}
4549
4650func (e * EncryptedKey ) parse (r io.Reader ) (err error ) {
@@ -138,11 +142,20 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) {
138142 return
139143 }
140144 case ExperimentalPubKeyAlgoAEAD :
141- ivAndCiphertext , err := io .ReadAll (r )
142- if err != nil {
143- return err
145+ var aeadMode [1 ]byte
146+ if _ , err = readFull (r , aeadMode [:]); err != nil {
147+ return
148+ }
149+ e .aeadMode = algorithm .AEADMode (aeadMode [0 ])
150+ nonceLength := e .aeadMode .NonceLength ()
151+ e .nonce = make ([]byte , nonceLength )
152+ if _ , err = readFull (r , e .nonce ); err != nil {
153+ return
154+ }
155+ e .encryptedMPI1 = new (encoding.ShortByteString )
156+ if _ , err = e .encryptedMPI1 .ReadFrom (r ); err != nil {
157+ return
144158 }
145- e .encryptedMPI1 = encoding .NewOctetArray (ivAndCiphertext )
146159 case PubKeyAlgoMlkem768X25519 :
147160 if e .encryptedMPI1 , e .encryptedMPI2 , e .encryptedMPI3 , cipherFunction , err = mlkem_ecdh .DecodeFields (r , 32 , 1088 , e .Version == 6 ); err != nil {
148161 return err
@@ -211,7 +224,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error {
211224 b , err = x448 .Decrypt (priv .PrivateKey .(* x448.PrivateKey ), e .ephemeralPublicX448 , e .encryptedSession )
212225 case ExperimentalPubKeyAlgoAEAD :
213226 priv := priv .PrivateKey .(* symmetric.AEADPrivateKey )
214- b , err = priv .Decrypt (e .encryptedMPI1 .Bytes (), priv . PublicKey . AEADMode )
227+ b , err = priv .Decrypt (e .nonce , e . encryptedMPI1 .Bytes (), e . aeadMode )
215228 case PubKeyAlgoMlkem768X25519 , PubKeyAlgoMlkem1024X448 :
216229 ecE := e .encryptedMPI1 .Bytes ()
217230 kE := e .encryptedMPI2 .Bytes ()
@@ -453,7 +466,7 @@ func SerializeEncryptedKeyAEADwithHiddenOption(w io.Writer, pub *PublicKey, ciph
453466 case PubKeyAlgoX448 :
454467 return serializeEncryptedKeyX448 (w , config .Random (), buf [:lenHeaderWritten ], pub .PublicKey .(* x448.PublicKey ), keyBlock , byte (cipherFunc ), version )
455468 case ExperimentalPubKeyAlgoAEAD :
456- return serializeEncryptedKeyAEAD (w , config .Random (), buf [:lenHeaderWritten ], pub .PublicKey .(* symmetric.AEADPublicKey ), keyBlock )
469+ return serializeEncryptedKeyAEAD (w , config .Random (), buf [:lenHeaderWritten ], pub .PublicKey .(* symmetric.AEADPublicKey ), keyBlock , config . AEAD () )
457470 case PubKeyAlgoMlkem768X25519 , PubKeyAlgoMlkem1024X448 :
458471 return serializeEncryptedKeyMlkem (w , config .Random (), buf [:lenHeaderWritten ], pub .PublicKey .(* mlkem_ecdh.PublicKey ), keyBlock , byte (cipherFunc ), version )
459472 case PubKeyAlgoDSA , PubKeyAlgoRSASignOnly , ExperimentalPubKeyAlgoHMAC :
@@ -627,16 +640,20 @@ func serializeEncryptedKeyX448(w io.Writer, rand io.Reader, header []byte, pub *
627640 return x448 .EncodeFields (w , ephemeralPublicX448 , ciphertext , cipherFunc , version == 6 )
628641}
629642
630- func serializeEncryptedKeyAEAD (w io.Writer , rand io.Reader , header []byte , pub * symmetric.AEADPublicKey , keyBlock []byte ) error {
631- mode := pub .AEADMode
632- iv , ciphertext , err := pub .Encrypt (rand , keyBlock , mode )
643+ func serializeEncryptedKeyAEAD (w io.Writer , rand io.Reader , header []byte , pub * symmetric.AEADPublicKey , keyBlock []byte , config * AEADConfig ) error {
644+ mode := algorithm .AEADMode ( config . Mode ())
645+ iv , ciphertextRaw , err := pub .Encrypt (rand , keyBlock , mode )
633646 if err != nil {
634647 return errors .InvalidArgumentError ("AEAD encryption failed: " + err .Error ())
635648 }
636649
650+ ciphertextShortByteString := encoding .NewShortByteString (ciphertextRaw )
651+
652+ buffer := append ([]byte {byte (mode )}, iv ... )
653+ buffer = append (buffer , ciphertextShortByteString .EncodedBytes ()... )
654+
637655 packetLen := len (header ) /* header length */
638- packetLen += int (len (iv ))
639- packetLen += int (len (ciphertext ))
656+ packetLen += int (len (buffer ))
640657
641658 err = serializeHeader (w , packetTypeEncryptedKey , packetLen )
642659 if err != nil {
@@ -648,12 +665,7 @@ func serializeEncryptedKeyAEAD(w io.Writer, rand io.Reader, header []byte, pub *
648665 return err
649666 }
650667
651- _ , err = w .Write (iv [:])
652- if err != nil {
653- return err
654- }
655-
656- _ , err = w .Write (ciphertext )
668+ _ , err = w .Write (buffer )
657669 return err
658670}
659671
0 commit comments