Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datahub-frontend/app/auth/AuthUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ public class AuthUtils {
public static final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm";
public static final String PREFERRED_JWS_ALGORITHM_2 = "preferredJwsAlgorithm2";

// Private Key JWT (certificate-based) authentication constants
public static final String PRIVATE_KEY_FILE_PATH = "privateKeyFilePath";
public static final String PUBLIC_KEY_FILE_PATH = "publicKeyFilePath";
public static final String PRIVATE_KEY_PASSWORD = "privateKeyPassword";
public static final String PRIVATE_KEY_JWT_ALGORITHM = "privateKeyJwtAlgorithm";

/**
* Determines whether the inbound request should be forward to downstream Metadata Service. Today,
* this simply checks for the presence of an "Authorization" header or the presence of a valid
Expand Down
72 changes: 65 additions & 7 deletions datahub-frontend/app/auth/sso/oidc/OidcConfigs.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ public class OidcConfigs extends SsoConfigs {
public static final String OIDC_HTTP_RETRY_ATTEMPTS = "auth.oidc.httpRetryAttempts";
public static final String OIDC_HTTP_RETRY_DELAY = "auth.oidc.httpRetryDelay";

// Private Key JWT (certificate-based) authentication configs
public static final String OIDC_PRIVATE_KEY_FILE_PATH = "auth.oidc.privateKeyFilePath";
public static final String OIDC_PUBLIC_KEY_FILE_PATH = "auth.oidc.publicKeyFilePath";
public static final String OIDC_PRIVATE_KEY_PASSWORD = "auth.oidc.privateKeyPassword";
public static final String OIDC_PRIVATE_KEY_JWT_ALGORITHM = "auth.oidc.privateKeyJwtAlgorithm";

/** Default values */
private static final String DEFAULT_OIDC_USERNAME_CLAIM = "email";

Expand All @@ -64,6 +70,8 @@ public class OidcConfigs extends SsoConfigs {
private static final String DEFAULT_OIDC_CONNECT_TIMEOUT = "1000";
private static final String DEFAULT_OIDC_HTTP_RETRY_ATTEMPTS = "3";
private static final String DEFAULT_OIDC_HTTP_RETRY_DELAY = "1000";
private static final String DEFAULT_OIDC_PRIVATE_KEY_JWT_ALGORITHM = "RS256";
private static final String PRIVATE_KEY_JWT_METHOD = "private_key_jwt";

private final String clientId;
private final String clientSecret;
Expand All @@ -90,6 +98,12 @@ public class OidcConfigs extends SsoConfigs {
private final String httpRetryAttempts;
private final String httpRetryDelay;

// Private Key JWT authentication fields
private final Optional<String> privateKeyFilePath;
private final Optional<String> publicKeyFilePath;
private final Optional<String> privateKeyPassword;
private final String privateKeyJwtAlgorithm;

public OidcConfigs(Builder builder) {
super(builder);
this.clientId = builder.clientId;
Expand All @@ -116,6 +130,10 @@ public OidcConfigs(Builder builder) {
this.grantType = builder.grantType;
this.httpRetryAttempts = builder.httpRetryAttempts;
this.httpRetryDelay = builder.httpRetryDelay;
this.privateKeyFilePath = builder.privateKeyFilePath;
this.publicKeyFilePath = builder.publicKeyFilePath;
this.privateKeyPassword = builder.privateKeyPassword;
this.privateKeyJwtAlgorithm = builder.privateKeyJwtAlgorithm;
}

public String getHttpRetryAttempts() {
Expand Down Expand Up @@ -154,24 +172,30 @@ public static class Builder extends SsoConfigs.Builder<Builder> {
private Optional<String> acrValues = Optional.empty();
private String httpRetryAttempts = DEFAULT_OIDC_HTTP_RETRY_ATTEMPTS;
private String httpRetryDelay = DEFAULT_OIDC_HTTP_RETRY_DELAY;
private Optional<String> privateKeyFilePath = Optional.empty();
private Optional<String> publicKeyFilePath = Optional.empty();
private Optional<String> privateKeyPassword = Optional.empty();
private String privateKeyJwtAlgorithm = DEFAULT_OIDC_PRIVATE_KEY_JWT_ALGORITHM;

public Builder from(final com.typesafe.config.Config configs) {
super.from(configs);
clientId = getRequired(configs, OIDC_CLIENT_ID_CONFIG_PATH);
clientSecret = getRequired(configs, OIDC_CLIENT_SECRET_CONFIG_PATH);
discoveryUri = getRequired(configs, OIDC_DISCOVERY_URI_CONFIG_PATH);

clientAuthenticationMethod =
getOptional(
configs,
OIDC_CLIENT_AUTHENTICATION_METHOD_CONFIG_PATH,
DEFAULT_OIDC_CLIENT_AUTHENTICATION_METHOD);

clientSecret = getOptional(configs, OIDC_CLIENT_SECRET_CONFIG_PATH, null);
userNameClaim =
getOptional(configs, OIDC_USERNAME_CLAIM_CONFIG_PATH, DEFAULT_OIDC_USERNAME_CLAIM);
userNameClaimRegex =
getOptional(
configs, OIDC_USERNAME_CLAIM_REGEX_CONFIG_PATH, DEFAULT_OIDC_USERNAME_CLAIM_REGEX);
scope = getOptional(configs, OIDC_SCOPE_CONFIG_PATH, DEFAULT_OIDC_SCOPE);
clientName = getOptional(configs, OIDC_CLIENT_NAME_CONFIG_PATH, DEFAULT_OIDC_CLIENT_NAME);
clientAuthenticationMethod =
getOptional(
configs,
OIDC_CLIENT_AUTHENTICATION_METHOD_CONFIG_PATH,
DEFAULT_OIDC_CLIENT_AUTHENTICATION_METHOD);
jitProvisioningEnabled =
Boolean.parseBoolean(
getOptional(
Expand Down Expand Up @@ -206,6 +230,15 @@ public Builder from(final com.typesafe.config.Config configs) {
httpRetryAttempts =
getOptional(configs, OIDC_HTTP_RETRY_ATTEMPTS, DEFAULT_OIDC_HTTP_RETRY_ATTEMPTS);
httpRetryDelay = getOptional(configs, OIDC_HTTP_RETRY_DELAY, DEFAULT_OIDC_HTTP_RETRY_DELAY);

// Private Key JWT authentication configs
privateKeyFilePath = getOptional(configs, OIDC_PRIVATE_KEY_FILE_PATH);
publicKeyFilePath = getOptional(configs, OIDC_PUBLIC_KEY_FILE_PATH);
privateKeyPassword = getOptional(configs, OIDC_PRIVATE_KEY_PASSWORD);
privateKeyJwtAlgorithm =
getOptional(
configs, OIDC_PRIVATE_KEY_JWT_ALGORITHM, DEFAULT_OIDC_PRIVATE_KEY_JWT_ALGORITHM);

return this;
}

Expand Down Expand Up @@ -276,16 +309,41 @@ public Builder from(final com.typesafe.config.Config configs, final String ssoSe
grantType = Optional.ofNullable(getOptional(configs, OIDC_GRANT_TYPE, null));
acrValues = Optional.ofNullable(getOptional(configs, OIDC_ACR_VALUES, null));

if (jsonNode.has(PRIVATE_KEY_FILE_PATH)) {
privateKeyFilePath = Optional.of(jsonNode.get(PRIVATE_KEY_FILE_PATH).asText());
}
if (jsonNode.has(PUBLIC_KEY_FILE_PATH)) {
publicKeyFilePath = Optional.of(jsonNode.get(PUBLIC_KEY_FILE_PATH).asText());
}
if (jsonNode.has(PRIVATE_KEY_PASSWORD)) {
privateKeyPassword = Optional.of(jsonNode.get(PRIVATE_KEY_PASSWORD).asText());
}
if (jsonNode.has(PRIVATE_KEY_JWT_ALGORITHM)) {
privateKeyJwtAlgorithm = jsonNode.get(PRIVATE_KEY_JWT_ALGORITHM).asText();
}

return this;
}

public OidcConfigs build() {
Objects.requireNonNull(oidcEnabled, "oidcEnabled is required");
Objects.requireNonNull(clientId, "clientId is required");
Objects.requireNonNull(clientSecret, "clientSecret is required");
Objects.requireNonNull(discoveryUri, "discoveryUri is required");
Objects.requireNonNull(authBaseUrl, "authBaseUrl is required");

if (PRIVATE_KEY_JWT_METHOD.equals(clientAuthenticationMethod)) {
if (privateKeyFilePath.isEmpty()) {
throw new IllegalArgumentException(
"privateKeyFilePath is required when using private_key_jwt authentication");
}
if (publicKeyFilePath.isEmpty()) {
throw new IllegalArgumentException(
"publicKeyFilePath is required when using private_key_jwt authentication");
}
} else {
Objects.requireNonNull(clientSecret, "clientSecret is required");
}

return new OidcConfigs(this);
}
}
Expand Down
155 changes: 155 additions & 0 deletions datahub-frontend/app/auth/sso/oidc/PrivateKeyJwtUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package auth.sso.oidc;

import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Security;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.util.Base64;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMDecryptorProvider;
import org.bouncycastle.openssl.PEMEncryptedKeyPair;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder;

/**
* Utility class for loading private keys and certificates for private_key_jwt client
* authentication (RFC 7523).
*/
public final class PrivateKeyJwtUtils {

static {
// Register BouncyCastle as a JCA security provider for encrypted key support
if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) {
Security.addProvider(new BouncyCastleProvider());
}
}

private PrivateKeyJwtUtils() {}

/**
* Loads an RSA private key from a PEM file.
*
* <p>Supports both encrypted (password-protected) and unencrypted PEM files. The PEM file can
* contain:
*
* <ul>
* <li>PKCS#8 format (BEGIN PRIVATE KEY / BEGIN ENCRYPTED PRIVATE KEY)
* <li>Traditional RSA format (BEGIN RSA PRIVATE KEY)
* </ul>
*
* @param filePath Path to the PEM file containing the private key
* @param password Password for encrypted keys, or null for unencrypted keys
* @return The RSA private key
* @throws IOException If the file cannot be read
* @throws IllegalArgumentException If the key format is invalid or unsupported
*/
public static RSAPrivateKey loadPrivateKey(@Nonnull String filePath, @Nullable String password)
throws IOException {
try (PEMParser pemParser = new PEMParser(new FileReader(filePath))) {
Object pemObject = pemParser.readObject();

if (pemObject == null) {
throw new IllegalArgumentException(
"No PEM object found in file: " + filePath);
}

JcaPEMKeyConverter converter = new JcaPEMKeyConverter();
PrivateKeyInfo privateKeyInfo;

if (pemObject instanceof PEMEncryptedKeyPair) {
// Encrypted traditional format (e.g., BEGIN RSA PRIVATE KEY with encryption)
if (password == null || password.isEmpty()) {
throw new IllegalArgumentException(
"Private key is encrypted but no password was provided");
}
PEMDecryptorProvider decryptor =
new JcePEMDecryptorProviderBuilder().build(password.toCharArray());
PEMKeyPair keyPair = ((PEMEncryptedKeyPair) pemObject).decryptKeyPair(decryptor);
privateKeyInfo = keyPair.getPrivateKeyInfo();

} else if (pemObject instanceof PEMKeyPair) {
// Unencrypted traditional format (BEGIN RSA PRIVATE KEY)
privateKeyInfo = ((PEMKeyPair) pemObject).getPrivateKeyInfo();

} else if (pemObject instanceof PrivateKeyInfo) {
// PKCS#8 format (BEGIN PRIVATE KEY)
privateKeyInfo = (PrivateKeyInfo) pemObject;

} else if (pemObject instanceof org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo) {
// Encrypted PKCS#8 format (BEGIN ENCRYPTED PRIVATE KEY)
if (password == null || password.isEmpty()) {
throw new IllegalArgumentException(
"Private key is encrypted but no password was provided");
}
org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo encryptedInfo =
(org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo) pemObject;
org.bouncycastle.operator.InputDecryptorProvider decryptorProvider =
new org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder()
.build(password.toCharArray());
privateKeyInfo = encryptedInfo.decryptPrivateKeyInfo(decryptorProvider);

} else {
throw new IllegalArgumentException(
"Unsupported PEM object type: " + pemObject.getClass().getName());
}

return (RSAPrivateKey) converter.getPrivateKey(privateKeyInfo);

} catch (org.bouncycastle.operator.OperatorCreationException
| org.bouncycastle.pkcs.PKCSException e) {
throw new IllegalArgumentException("Failed to decrypt private key: " + e.getMessage(), e);
}
}

/**
* Loads an X.509 certificate from a PEM file.
*
* @param filePath Path to the PEM file containing the certificate
* @return The X.509 certificate
* @throws IOException If the file cannot be read
* @throws CertificateException If the certificate is invalid
*/
public static X509Certificate loadCertificate(@Nonnull String filePath)
throws IOException, CertificateException {
CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
try (var inputStream = Files.newInputStream(Path.of(filePath))) {
return (X509Certificate) certFactory.generateCertificate(inputStream);
}
}

/**
* Computes the SHA-256 thumbprint of an X.509 certificate, Base64URL encoded.
*
* <p>This is used as the key ID (kid) in JWT headers for private_key_jwt authentication. Azure AD
* and other IdPs use this format to identify which certificate was used to sign the JWT.
*
* @param certificate The X.509 certificate
* @return The SHA-256 thumbprint, Base64URL encoded (no padding)
* @throws CertificateEncodingException If the certificate cannot be encoded
*/
public static String computeThumbprint(@Nonnull X509Certificate certificate)
throws CertificateEncodingException {
try {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
byte[] hash = digest.digest(certificate.getEncoded());
return Base64.getUrlEncoder().withoutPadding().encodeToString(hash);
} catch (NoSuchAlgorithmException e) {
// SHA-256 is always available in Java
throw new RuntimeException("SHA-256 algorithm not available", e);
}
}

}
Loading
Loading