1111
1212using System ;
1313using System . Collections . Generic ;
14+ using System . Threading . Tasks ;
1415using static libsignalservice . messages . SignalServiceDataMessage ;
1516using static libsignalservice . push . DataMessage ;
1617
@@ -56,84 +57,135 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, byte[] unp
5657 /// Decrypt a received <see cref="SignalServiceEnvelope"/>
5758 /// </summary>
5859 /// <param name="envelope">The received SignalServiceEnvelope</param>
60+ /// <param name="callback">Optional callback to call during the decrypt process before it is acked</param>
5961 /// <returns>a decrypted SignalServiceContent</returns>
60- public SignalServiceContent Decrypt ( SignalServiceEnvelope envelope )
62+ public async Task < SignalServiceContent > Decrypt ( SignalServiceEnvelope envelope , Func < SignalServiceContent , Task > callback = null )
6163 {
64+ Func < byte [ ] , Task > callback_func = null ;
65+ if ( callback != null )
66+ {
67+ callback_func = async ( data ) => await callback ( await DecryptComplete ( envelope , data ) ) ;
68+ }
6269 try
6370 {
64- SignalServiceContent content = new SignalServiceContent ( ) ;
65-
71+ byte [ ] decrypted_data = null ;
6672 if ( envelope . HasLegacyMessage ( ) )
6773 {
68- DataMessage message = DataMessage . Parser . ParseFrom ( Decrypt ( envelope , envelope . GetLegacyMessage ( ) ) ) ;
69- content = new SignalServiceContent ( )
70- {
71- Message = CreateSignalServiceMessage ( envelope , message )
72- } ;
74+ decrypted_data = await Decrypt ( envelope , envelope . GetLegacyMessage ( ) , callback_func ) ;
7375 }
7476 else if ( envelope . HasContent ( ) )
7577 {
76- Content message = Content . Parser . ParseFrom ( Decrypt ( envelope , envelope . GetContent ( ) ) ) ;
78+ decrypted_data = await Decrypt ( envelope , envelope . GetContent ( ) , callback_func ) ;
79+ }
80+ if ( callback_func != null )
81+ {
82+ return null ;
83+ }
84+ return await DecryptComplete ( envelope , decrypted_data ) ;
85+ }
86+ catch ( InvalidProtocolBufferException e )
87+ {
88+ throw new InvalidMessageException ( e ) ;
89+ }
90+ }
91+ private Task < SignalServiceContent > DecryptComplete ( SignalServiceEnvelope envelope , byte [ ] decrypted_data )
92+ {
93+ SignalServiceContent content = new SignalServiceContent ( ) ;
94+
95+ if ( envelope . HasLegacyMessage ( ) )
96+ {
97+ DataMessage message = DataMessage . Parser . ParseFrom ( decrypted_data ) ;
98+ content = new SignalServiceContent ( )
99+ {
100+ Message = CreateSignalServiceMessage ( envelope , message )
101+ } ;
102+ }
103+ else if ( envelope . HasContent ( ) )
104+ {
105+ Content message = Content . Parser . ParseFrom ( decrypted_data ) ;
77106
78- if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
107+ if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
108+ {
109+ content = new SignalServiceContent ( )
79110 {
80- content = new SignalServiceContent ( )
81- {
82- Message = CreateSignalServiceMessage ( envelope , message . DataMessage )
83- } ;
84- }
85- else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage && LocalAddress . E164number == envelope . GetSource ( ) )
111+ Message = CreateSignalServiceMessage ( envelope , message . DataMessage )
112+ } ;
113+ }
114+ else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage && LocalAddress . E164number == envelope . GetSource ( ) )
115+ {
116+ content = new SignalServiceContent ( )
86117 {
87- content = new SignalServiceContent ( )
88- {
89- SynchronizeMessage = CreateSynchronizeMessage ( envelope , message . SyncMessage )
90- } ;
91- }
92- else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
118+ SynchronizeMessage = CreateSynchronizeMessage ( envelope , message . SyncMessage )
119+ } ;
120+ }
121+ else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
122+ {
123+ content = new SignalServiceContent ( )
93124 {
94- content = new SignalServiceContent ( )
95- {
96- CallMessage = CreateCallMessage ( message . CallMessage )
97- } ;
98- }
99- else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
125+ CallMessage = CreateCallMessage ( message . CallMessage )
126+ } ;
127+ }
128+ else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
129+ {
130+ content = new SignalServiceContent ( )
100131 {
101- content = new SignalServiceContent ( )
102- {
103- ReadMessage = CreateReceiptMessage ( envelope , message . ReceiptMessage )
104- } ;
105- }
132+ ReadMessage = CreateReceiptMessage ( envelope , message . ReceiptMessage )
133+ } ;
106134 }
107-
108- return content ;
109135 }
110- catch ( InvalidProtocolBufferException e )
136+
137+ return Task . FromResult ( content ) ;
138+ }
139+ private class DecryptionCallbackHandler : DecryptionCallback
140+ {
141+ public Task handlePlaintext ( byte [ ] plaintext , SessionRecord sessionRecord )
111142 {
112- throw new InvalidMessageException ( e ) ;
143+ return callback ( GetStrippedMessage ( sessionRecord , plaintext ) ) ;
113144 }
145+ public SessionCipher sessionCipher ;
146+ public Func < byte [ ] , Task > callback ;
114147 }
115-
116- private byte [ ] Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext )
148+ private async Task < byte [ ] > Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext , Func < byte [ ] , Task > callback = null )
117149
118150 {
119151 SignalProtocolAddress sourceAddress = new SignalProtocolAddress ( envelope . GetSource ( ) , ( uint ) envelope . GetSourceDevice ( ) ) ;
120152 SessionCipher sessionCipher = new SessionCipher ( SignalProtocolStore , sourceAddress ) ;
121153
122154 byte [ ] paddedMessage ;
123-
155+ DecryptionCallbackHandler callback_handler = null ;
156+ if ( callback != null )
157+ callback_handler = new DecryptionCallbackHandler { callback = callback , sessionCipher = sessionCipher } ;
124158 if ( envelope . IsPreKeySignalMessage ( ) )
125159 {
160+ if ( callback_handler != null )
161+ {
162+ await sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) , callback_handler ) ;
163+ return null ;
164+ }
126165 paddedMessage = sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) ) ;
127166 }
128167 else if ( envelope . IsSignalMessage ( ) )
129168 {
169+ if ( callback_handler != null )
170+ {
171+ await sessionCipher . decrypt ( new SignalMessage ( ciphertext ) , callback_handler ) ;
172+ return null ;
173+ }
130174 paddedMessage = sessionCipher . decrypt ( new SignalMessage ( ciphertext ) ) ;
131175 }
132176 else
133177 {
134178 throw new InvalidMessageException ( "Unknown type: " + envelope . GetEnvelopeType ( ) + " from " + envelope . GetSource ( ) ) ;
135179 }
136-
180+ return GetStrippedMessage ( sessionCipher , paddedMessage ) ;
181+ }
182+ private static byte [ ] GetStrippedMessage ( SessionRecord sessionRecord , byte [ ] paddedMessage )
183+ {
184+ PushTransportDetails transportDetails = new PushTransportDetails ( sessionRecord . getSessionState ( ) . getSessionVersion ( ) ) ;
185+ return transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
186+ }
187+ private static byte [ ] GetStrippedMessage ( SessionCipher sessionCipher , byte [ ] paddedMessage )
188+ {
137189 PushTransportDetails transportDetails = new PushTransportDetails ( sessionCipher . getSessionVersion ( ) ) ;
138190 return transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
139191 }
@@ -152,7 +204,7 @@ private SignalServiceDataMessage CreateSignalServiceMessage(SignalServiceEnvelop
152204 attachments . Add ( CreateAttachmentPointer ( envelope . GetRelay ( ) , pointer ) ) ;
153205 }
154206
155- if ( content . TimestampOneofCase == DataMessage . TimestampOneofOneofCase . Timestamp && ( long ) content . Timestamp != envelope . GetTimestamp ( ) )
207+ if ( content . TimestampOneofCase == DataMessage . TimestampOneofOneofCase . Timestamp && ( long ) content . Timestamp != envelope . GetTimestamp ( ) )
156208 {
157209 throw new InvalidMessageException ( "Timestamps don't match: " + content . Timestamp + " vs " + envelope . GetTimestamp ( ) ) ;
158210 }
@@ -290,7 +342,7 @@ private SignalServiceCallMessage CreateCallMessage(CallMessage content)
290342 var l = new List < IceUpdateMessage > ( ) ;
291343 foreach ( var u in content . IceUpdate )
292344 {
293- l . Add ( new IceUpdateMessage ( )
345+ l . Add ( new IceUpdateMessage ( )
294346 {
295347 Id = u . Id ,
296348 SdpMid = u . SdpMid ,
@@ -374,7 +426,7 @@ private SignalServiceDataMessage.SignalServiceQuote CreateQuote(SignalServiceEnv
374426 pointer . ThumbnailOneofCase == Types . Quote . Types . QuotedAttachment . ThumbnailOneofOneofCase . Thumbnail ? CreateAttachmentPointer ( envelope . GetRelay ( ) , pointer . Thumbnail ) : null ) ) ;
375427 }
376428
377- return new SignalServiceDataMessage . SignalServiceQuote ( ( long ) content . Quote . Id ,
429+ return new SignalServiceDataMessage . SignalServiceQuote ( ( long ) content . Quote . Id ,
378430 new SignalServiceAddress ( content . Quote . Author ) ,
379431 content . Quote . Text ,
380432 attachments ) ;
0 commit comments