Skip to content

Commit c85f632

Browse files
author
tjroach
committed
fix token serialization
1 parent 7253c7e commit c85f632

File tree

2 files changed

+173
-8
lines changed
  • aws-auth-cognito/src
    • main/java/com/amplifyframework/statemachine/codegen/data
    • test/java/com/amplifyframework/auth/cognito/data

2 files changed

+173
-8
lines changed

aws-auth-cognito/src/main/java/com/amplifyframework/statemachine/codegen/data/Tokens.kt

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ import java.time.Instant
2222
import kotlin.text.Charsets.UTF_8
2323
import kotlinx.serialization.Serializable
2424
import org.json.JSONObject
25+
import kotlinx.serialization.KSerializer
26+
import kotlinx.serialization.SerializationException
27+
import kotlinx.serialization.builtins.serializer
28+
import kotlinx.serialization.encoding.Decoder
29+
import kotlinx.serialization.encoding.Encoder
30+
import kotlinx.serialization.json.JsonDecoder
31+
import kotlinx.serialization.json.JsonObject
32+
import kotlinx.serialization.json.JsonPrimitive
33+
import kotlinx.serialization.json.jsonPrimitive
2534

2635
internal abstract class Jwt {
2736
abstract val tokenValue: String
@@ -72,7 +81,7 @@ internal abstract class Jwt {
7281
}
7382

7483
// See https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-the-id-token.html
75-
@Serializable
84+
@Serializable(with = IdTokenAsStringSerializer::class)
7685
internal class IdToken(override val tokenValue: String) : Jwt() {
7786
val userSub: String?
7887
get() = getClaim(Claim.UserSub)
@@ -84,7 +93,7 @@ internal class IdToken(override val tokenValue: String) : Jwt() {
8493
}
8594

8695
// See https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-the-access-token.html
87-
@Serializable
96+
@Serializable(with = AccessTokenAsStringSerializer::class)
8897
internal class AccessToken(override val tokenValue: String) : Jwt() {
8998
val tokenRevocationId: String?
9099
get() = getClaim(Claim.TokenRevocationId)
@@ -98,7 +107,7 @@ internal class AccessToken(override val tokenValue: String) : Jwt() {
98107
}
99108

100109
// Refresh token is just an opaque base64 string
101-
@Serializable
110+
@Serializable(with = RefreshTokenAsStringSerializer::class)
102111
@JvmInline
103112
internal value class RefreshToken(val tokenValue: String) {
104113
override fun toString() = tokenValue.mask()
@@ -142,3 +151,46 @@ internal data class CognitoUserPoolTokens(
142151
idToken == other.idToken && accessToken == other.accessToken && refreshToken == other.refreshToken
143152
}
144153
}
154+
155+
/**
156+
* Helper function to extract token value from either flat or nested format
157+
*/
158+
private fun extractTokenValue(decoder: Decoder, tokenType: String): String {
159+
return if (decoder is JsonDecoder) {
160+
when (val element = decoder.decodeJsonElement()) {
161+
is JsonPrimitive -> element.content // Flat format: "token": "value"
162+
is JsonObject -> element["tokenValue"]?.jsonPrimitive?.content
163+
?: throw SerializationException("Missing tokenValue in nested $tokenType")
164+
else -> throw SerializationException("Expected string or object for $tokenType")
165+
}
166+
} else {
167+
decoder.decodeString() // Fallback for non-JSON decoders
168+
}
169+
}
170+
171+
/**
172+
* Serializer for IdToken that maintains string serialization format
173+
*/
174+
internal object IdTokenAsStringSerializer : KSerializer<IdToken> {
175+
override val descriptor = String.serializer().descriptor
176+
override fun serialize(encoder: Encoder, value: IdToken) = encoder.encodeString(value.tokenValue)
177+
override fun deserialize(decoder: Decoder) = IdToken(extractTokenValue(decoder, "IdToken"))
178+
}
179+
180+
/**
181+
* Serializer for AccessToken that maintains string serialization format
182+
*/
183+
internal object AccessTokenAsStringSerializer : KSerializer<AccessToken> {
184+
override val descriptor = String.serializer().descriptor
185+
override fun serialize(encoder: Encoder, value: AccessToken) = encoder.encodeString(value.tokenValue)
186+
override fun deserialize(decoder: Decoder) = AccessToken(extractTokenValue(decoder, "AccessToken"))
187+
}
188+
189+
/**
190+
* Serializer for RefreshToken that maintains string serialization format
191+
*/
192+
internal object RefreshTokenAsStringSerializer : KSerializer<RefreshToken> {
193+
override val descriptor = String.serializer().descriptor
194+
override fun serialize(encoder: Encoder, value: RefreshToken) = encoder.encodeString(value.tokenValue)
195+
override fun deserialize(decoder: Decoder) = RefreshToken(extractTokenValue(decoder, "RefreshToken"))
196+
}

aws-auth-cognito/src/test/java/com/amplifyframework/auth/cognito/data/TokensTest.kt

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ import com.amplifyframework.statemachine.codegen.data.asIdToken
2121
import com.amplifyframework.statemachine.codegen.data.asRefreshToken
2222
import io.kotest.matchers.shouldBe
2323
import java.time.Instant
24+
import kotlinx.serialization.encodeToString
25+
import kotlinx.serialization.json.Json
2426
import org.junit.Test
2527

2628
class TokensTest {
2729
private val tokenString =
2830
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwidXNlcm5hbWUiOiJqZG" +
29-
"9lIiwiaWF0IjoxNzU2OTk4Mjc4LCJleHAiOjE3NTY5OTg1NzgsIm9yaWdpbl9qdGkiOiJhYWFhYWFhYS1iYmJiLWNjY2MtZGRkZC1lZ" +
30-
"WVlZWVlZWVlZWUifQ.3Mvd5WVi1z1GpQ37hEoev6DzYNv9lWNL-fGfQTxUYx4"
31+
"9lIiwiaWF0IjoxNzU2OTk4Mjc4LCJleHAiOjE3NTY5OTg1NzgsIm9yaWdpbl9qdGkiOiJhYWFhYWFhYS1iYmJiLWNjY2MtZGRkZC1lZ" +
32+
"WVlZWVlZWVlZWUifQ.3Mvd5WVi1z1GpQ37hEoev6DzYNv9lWNL-fGfQTxUYx4"
3133

3234
@Test
3335
fun `identity token returns expiry`() {
@@ -86,7 +88,118 @@ class TokensTest {
8688
expiration = null
8789
)
8890
cognitoTokens.toString() shouldBe
89-
"CognitoUserPoolTokens(idToken=eyJh***, accessToken=eyJh***, " +
90-
"refreshToken=eyJh***, expiration=null)"
91+
"CognitoUserPoolTokens(idToken=eyJh***, accessToken=eyJh***, " +
92+
"refreshToken=eyJh***, expiration=null)"
9193
}
92-
}
94+
95+
@Test
96+
fun `non nested tokens are parsed correctly`() {
97+
val flatFormatJson = """
98+
{
99+
"idToken": "$tokenString",
100+
"accessToken": "$tokenString",
101+
"refreshToken": "refresh_token_value",
102+
"expiration": 1756998578
103+
}
104+
""".trimIndent()
105+
106+
val tokens = Json.decodeFromString<CognitoUserPoolTokens>(flatFormatJson)
107+
108+
tokens.idToken?.tokenValue shouldBe tokenString
109+
tokens.accessToken?.tokenValue shouldBe tokenString
110+
tokens.refreshToken?.tokenValue shouldBe "refresh_token_value"
111+
tokens.expiration shouldBe 1756998578L
112+
113+
// Verify JWT parsing still works
114+
tokens.accessToken?.userSub shouldBe "1234567890"
115+
tokens.accessToken?.username shouldBe "jdoe"
116+
}
117+
118+
@Test
119+
fun `nested tokens are read correctly`() {
120+
val nestedFormatJson = """
121+
{
122+
"idToken": {"tokenValue": "$tokenString"},
123+
"accessToken": {"tokenValue": "$tokenString"},
124+
"refreshToken": {"tokenValue": "refresh_token_value"},
125+
"expiration": 1756998578
126+
}
127+
""".trimIndent()
128+
129+
val tokens = Json.decodeFromString<CognitoUserPoolTokens>(nestedFormatJson)
130+
131+
tokens.idToken?.tokenValue shouldBe tokenString
132+
tokens.accessToken?.tokenValue shouldBe tokenString
133+
tokens.refreshToken?.tokenValue shouldBe "refresh_token_value"
134+
tokens.expiration shouldBe 1756998578L
135+
136+
// Verify JWT parsing still works after extracting from nested format
137+
tokens.accessToken?.userSub shouldBe "1234567890"
138+
tokens.accessToken?.username shouldBe "jdoe"
139+
}
140+
141+
@Test
142+
fun `nested tokens are saved as non nested`() {
143+
// Start with nested format
144+
val nestedFormatJson = """
145+
{
146+
"idToken": {"tokenValue": "$tokenString"},
147+
"accessToken": {"tokenValue": "$tokenString"},
148+
"refreshToken": {"tokenValue": "refresh_token_value"},
149+
"expiration": 1756998578
150+
}
151+
""".trimIndent()
152+
153+
// Deserialize nested format
154+
val tokens = Json.decodeFromString<CognitoUserPoolTokens>(nestedFormatJson)
155+
156+
// Serialize back to JSON
157+
val serializedJson = Json.encodeToString(tokens)
158+
159+
// Should now be in flat format
160+
val expectedFlatJson =
161+
"""{"idToken":"$tokenString","accessToken":"$tokenString","refreshToken":"refresh_token_value","expiration":1756998578}"""
162+
serializedJson shouldBe expectedFlatJson
163+
164+
// Verify we can deserialize the flat format again
165+
val tokensFromFlat = Json.decodeFromString<CognitoUserPoolTokens>(serializedJson)
166+
tokensFromFlat.idToken?.tokenValue shouldBe tokenString
167+
tokensFromFlat.accessToken?.tokenValue shouldBe tokenString
168+
tokensFromFlat.refreshToken?.tokenValue shouldBe "refresh_token_value"
169+
}
170+
171+
@Test
172+
fun `flat and nested formats produce identical results`() {
173+
val flatFormatJson = """
174+
{
175+
"idToken": "$tokenString",
176+
"accessToken": "$tokenString",
177+
"refreshToken": "refresh_token_value",
178+
"expiration": 1756998578
179+
}
180+
""".trimIndent()
181+
182+
val nestedFormatJson = """
183+
{
184+
"idToken": {"tokenValue": "$tokenString"},
185+
"accessToken": {"tokenValue": "$tokenString"},
186+
"refreshToken": {"tokenValue": "refresh_token_value"},
187+
"expiration": 1756998578
188+
}
189+
""".trimIndent()
190+
191+
val tokensFromFlat = Json.decodeFromString<CognitoUserPoolTokens>(flatFormatJson)
192+
val tokensFromNested = Json.decodeFromString<CognitoUserPoolTokens>(nestedFormatJson)
193+
194+
// Both should have identical token values
195+
tokensFromFlat.idToken?.tokenValue shouldBe tokensFromNested.idToken?.tokenValue
196+
tokensFromFlat.accessToken?.tokenValue shouldBe tokensFromNested.accessToken?.tokenValue
197+
tokensFromFlat.refreshToken?.tokenValue shouldBe tokensFromNested.refreshToken?.tokenValue
198+
tokensFromFlat.expiration shouldBe tokensFromNested.expiration
199+
200+
// Both should serialize to the same flat format
201+
val serializedFlat = Json.encodeToString(tokensFromFlat)
202+
val serializedNested = Json.encodeToString(tokensFromNested)
203+
serializedFlat shouldBe serializedNested
204+
}
205+
}

0 commit comments

Comments
 (0)