From 49e90eaede894427c56483d94b83792a538c8362 Mon Sep 17 00:00:00 2001 From: Scott Hamrick <2623452+cshamrick@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:31:01 -0600 Subject: [PATCH] feat(sdk): Implement pluggable assertion binding and validation framework Signed-off-by: Scott Hamrick <2623452+cshamrick@users.noreply.github.com> --- .../opentdf/platform/sdk/AssertionBinder.java | 15 ++ .../platform/sdk/AssertionRegistry.java | 44 +++++ .../platform/sdk/AssertionValidator.java | 43 +++++ .../sdk/AssertionVerificationMode.java | 7 + .../java/io/opentdf/platform/sdk/Config.java | 41 +++- .../sdk/ConfigBasedAssertionBinder.java | 42 ++++ .../platform/sdk/DEKAssertionValidator.java | 50 +++++ .../platform/sdk/KeyAssertionBinder.java | 66 +++++++ .../platform/sdk/KeyAssertionValidator.java | 73 +++++++ .../io/opentdf/platform/sdk/Manifest.java | 149 +++++++++++++-- .../java/io/opentdf/platform/sdk/SDK.java | 2 + .../sdk/SystemMetadataAssertionBinder.java | 24 +++ .../sdk/SystemMetadataAssertionValidator.java | 61 ++++++ .../java/io/opentdf/platform/sdk/TDF.java | 180 +++++++++--------- .../java/io/opentdf/platform/sdk/TDFTest.java | 5 - 15 files changed, 697 insertions(+), 105 deletions(-) create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/AssertionBinder.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/AssertionRegistry.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/AssertionValidator.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/AssertionVerificationMode.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/ConfigBasedAssertionBinder.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/DEKAssertionValidator.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionBinder.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionValidator.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionBinder.java create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionValidator.java diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/AssertionBinder.java b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionBinder.java new file mode 100644 index 00000000..fe5de922 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionBinder.java @@ -0,0 +1,15 @@ +package io.opentdf.platform.sdk; + +import io.opentdf.platform.sdk.Manifest.Assertion; + +public interface AssertionBinder { + /** + * Bind creates and signs an assertion, binding it to the given manifest. + * // The implementation is responsible for both configuring the assertion and binding it. + * + * @param manifest The manifest. + * @return The assertion. + * @throws SDK.AssertionException If an error occurs during binding. + */ + Assertion bind(Manifest manifest) throws SDK.AssertionException; +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/AssertionRegistry.java b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionRegistry.java new file mode 100644 index 00000000..e09bc349 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionRegistry.java @@ -0,0 +1,44 @@ +package io.opentdf.platform.sdk; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; + +public class AssertionRegistry { + private final List binders; + private final Map validators; + + public AssertionRegistry() { + this.binders = new CopyOnWriteArrayList<>(); + this.validators = new ConcurrentHashMap<>(); + } + + public void registerBinder(AssertionBinder binder) { + binders.add(binder); + } + + public void registerValidator(AssertionValidator validator) { + String schema = validator.getSchema(); + validators.put(schema, validator); + } + + public List getBinders() { + return Collections.unmodifiableList(binders); + } + + public void setBinders(List binders) { + this.binders.clear(); + this.binders.addAll(binders); + } + + public Map getValidators() { + return Collections.unmodifiableMap(validators); + } + + public void setValidators(Map validators) { + this.validators.clear(); + this.validators.putAll(validators); + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/AssertionValidator.java b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionValidator.java new file mode 100644 index 00000000..0ddf4fe7 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionValidator.java @@ -0,0 +1,43 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JOSEException; +import io.opentdf.platform.sdk.Manifest.Assertion; + +import java.io.IOException; +import java.text.ParseException; + +public interface AssertionValidator { + /** + * // Schema returns the schema URI this validator handles. + * // The schema identifies the assertion format and version. + * // Examples: "urn:opentdf:system:metadata:v1", "urn:opentdf:key:assertion:v1" + * + * @return The schema URI. + */ + String getSchema(); + + void setVerificationMode(AssertionVerificationMode verificationMode); + + /** + * // Verify checks the assertion's cryptographic binding. + * // + * // Example: + * // assertionHash, _ := a.GetHash() + * // manifest := r.Manifest() + * // expectedSig, _ := manifest.ComputeAssertionSignature(assertionHash) + * + * @param assertion The assertion to verify. + * @param manifest The manifest. + * @throws SDK.AssertionException If the verification fails. + */ + void verify(Assertion assertion, Manifest manifest) throws SDK.AssertionException; + + /** + * // Validate checks the assertion's policy and trust requirements + * + * @param assertion The assertion to validate. + * @param reader The TDF reader. + * @throws SDK.AssertionException If the validation fails. + */ + void validate(Assertion assertion, TDFReader reader) throws SDK.AssertionException; +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/AssertionVerificationMode.java b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionVerificationMode.java new file mode 100644 index 00000000..e451b325 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionVerificationMode.java @@ -0,0 +1,7 @@ +package io.opentdf.platform.sdk; + +public enum AssertionVerificationMode { + PERMISSIVE, + FAIL_FAST, + STRICT +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java index ea49d074..1f4f70a2 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Config.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Config.java @@ -132,6 +132,22 @@ public static class TDFReaderConfig { KeyType sessionKeyType; Set kasAllowlist; boolean ignoreKasAllowlist; + private AssertionVerificationMode assertionVerificationMode = AssertionVerificationMode.FAIL_FAST; + private final AssertionRegistry assertionRegistry = new AssertionRegistry(); + + public AssertionVerificationMode getAssertionVerificationMode() { + return assertionVerificationMode; + } + + public void setAssertionVerificationMode(AssertionVerificationMode assertionVerificationMode) { + this.assertionVerificationMode = assertionVerificationMode; + } + + public AssertionRegistry getAssertionRegistry() { + return assertionRegistry; + } + + } @SafeVarargs @@ -148,7 +164,18 @@ public static TDFReaderConfig newTDFReaderConfig(Consumer... op public static Consumer withAssertionVerificationKeys( AssertionVerificationKeys assertionVerificationKeys) { - return (TDFReaderConfig config) -> config.assertionVerificationKeys = assertionVerificationKeys; + return (TDFReaderConfig config) -> { + config.assertionVerificationKeys = assertionVerificationKeys; + + // ONLY register wildcard validator if assertion verification is enabled + // This maintains backward compatibility with the disableAssertionVerification flag + if (!config.disableAssertionVerification) { + // Register a wildcard KeyAssertionValidator that handles any schema + // when verification keys are provided + KeyAssertionValidator keyAssertionValidator = new KeyAssertionValidator(assertionVerificationKeys); + config.getAssertionRegistry().registerValidator(keyAssertionValidator); + } + }; } public static Consumer withDisableAssertionVerification(boolean disable) { @@ -195,6 +222,7 @@ public static class TDFConfig { public boolean hexEncodeRootAndSegmentHashes; public boolean renderVersionInfoInManifest; public boolean systemMetadataAssertion; + private AssertionRegistry assertionRegistry; public TDFConfig() { this.autoconfigure = true; @@ -212,6 +240,11 @@ public TDFConfig() { this.hexEncodeRootAndSegmentHashes = false; this.renderVersionInfoInManifest = true; this.systemMetadataAssertion = false; + this.assertionRegistry = new AssertionRegistry(); + } + + public AssertionRegistry getAssertionRegistry() { + return assertionRegistry; } } @@ -289,7 +322,13 @@ public static Consumer withSplitPlan(Autoconfigure.KeySplitStep... p) public static Consumer withAssertionConfig(io.opentdf.platform.sdk.AssertionConfig... assertionList) { return (TDFConfig config) -> { + // add to assertionConfigList for backward compatibility Collections.addAll(config.assertionConfigList, assertionList); + // register a binder for each assertionConfig + for (AssertionConfig assertionConfig : assertionList) { + ConfigBasedAssertionBinder binder = new ConfigBasedAssertionBinder(assertionConfig); + config.getAssertionRegistry().registerBinder(binder); + } }; } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/ConfigBasedAssertionBinder.java b/sdk/src/main/java/io/opentdf/platform/sdk/ConfigBasedAssertionBinder.java new file mode 100644 index 00000000..fad1dfda --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/ConfigBasedAssertionBinder.java @@ -0,0 +1,42 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.KeyLengthException; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +public class ConfigBasedAssertionBinder implements AssertionBinder { + private final AssertionConfig assertionConfig; + + public ConfigBasedAssertionBinder(AssertionConfig assertionConfig) { + this.assertionConfig = assertionConfig; + } + + @Override + public Manifest.Assertion bind(Manifest manifest) throws SDK.AssertionException { + Manifest.Assertion assertion = new Manifest.Assertion(); + assertion.id = assertionConfig.id; + assertion.type = assertionConfig.type.toString(); + assertion.scope = assertionConfig.scope.toString(); + assertion.statement = assertionConfig.statement; + assertion.appliesToState = assertionConfig.appliesToState.toString(); + + try { + ByteArrayOutputStream aggregateHash = Manifest.computeAggregateHash(manifest.encryptionInformation.integrityInformation.segments, manifest.payload.isEncrypted); + boolean hexEncodeRootAndSegmentHashes = manifest.tdfVersion == null || manifest.tdfVersion.isEmpty(); + Manifest.Assertion.HashValues hashValues = Manifest.Assertion.calculateAssertionHashValues(aggregateHash, assertion, hexEncodeRootAndSegmentHashes); + if (assertionConfig.signingKey != null && assertionConfig.signingKey.isDefined()) { + assertion.sign(hashValues, assertionConfig.signingKey); + } + // otherwise no explicit signing key provided - use the payload key (DEK) + // this is handled by passing the payload key from the TDF creation context + // for now, return the unsigned assertion - it will be signed by a DEK-based binder + } catch (IOException e) { + throw new SDK.AssertionException("error reading assertion hash", assertionConfig.id); + } catch (KeyLengthException e) { + throw new SDK.AssertionException("error signing assertion", assertionConfig.id); + } + return assertion; + } + +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/DEKAssertionValidator.java b/sdk/src/main/java/io/opentdf/platform/sdk/DEKAssertionValidator.java new file mode 100644 index 00000000..a2b78cd6 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/DEKAssertionValidator.java @@ -0,0 +1,50 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JOSEException; + +import javax.annotation.Nonnull; +import java.io.IOException; +import java.text.ParseException; +import java.util.Objects; + +public class DEKAssertionValidator implements AssertionValidator { + + private AssertionVerificationMode verificationMode = AssertionVerificationMode.FAIL_FAST; + + private AssertionConfig.AssertionKey dekKey; + + public DEKAssertionValidator(AssertionConfig.AssertionKey dekKey) { + this.dekKey = dekKey; + } + + @Override + public String getSchema() { + return ""; + } + + @Override + public void setVerificationMode(@Nonnull AssertionVerificationMode verificationMode) { + this.verificationMode = verificationMode; + } + + @Override + public void verify(Manifest.Assertion assertion, Manifest manifest) throws SDK.AssertionException { + try { + Manifest.Assertion.HashValues hashValues = assertion.verify(dekKey); + var hashOfAssertionAsHex = assertion.hash(); + if (!Objects.equals(hashOfAssertionAsHex, hashValues.getAssertionHash())) { + throw new SDK.AssertionException("assertion hash mismatch", assertion.id); + } + } catch (JOSEException e) { + throw new SDKException("error validating assertion hash", e); + } catch (ParseException e) { + throw new SDK.AssertionException("error parsing assertion hash", assertion.id); + } catch (IOException e) { + throw new SDK.AssertionException("error reading assertion hash", assertion.id); + } + } + + // Validate does nothing - DEK-based validation doesn't check trust/policy. + @Override + public void validate(Manifest.Assertion assertion, TDFReader reader) throws SDK.AssertionException {} +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionBinder.java b/sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionBinder.java new file mode 100644 index 00000000..e2ad3bac --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionBinder.java @@ -0,0 +1,66 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.Algorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.KeyLengthException; +import com.nimbusds.jose.jwk.RSAKey; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.security.interfaces.RSAPublicKey; +import java.util.Optional; + + +public class KeyAssertionBinder implements AssertionBinder { + + public static final String KEY_ASSERTION_ID = "assertion-key"; + public static final String KEY_ASSERTION_SCHEMA = "urn:opentdf:key:assertion:v1"; + + private final AssertionConfig.AssertionKey privateKey; + private final AssertionConfig.AssertionKey publicKey; + private final String statementValue; + + public KeyAssertionBinder(AssertionConfig.AssertionKey privateKey, AssertionConfig.AssertionKey publicKey, String statementValue) { + this.privateKey = privateKey; + this.publicKey = publicKey; + this.statementValue = statementValue; + } + + @Override + public Manifest.Assertion bind(Manifest manifest) throws SDK.AssertionException { + Manifest.Assertion assertion = new Manifest.Assertion(); + assertion.id = KEY_ASSERTION_ID; + assertion.type = "other"; + assertion.scope = "payload"; + assertion.statement = new AssertionConfig.Statement(); + assertion.statement.format = "json"; + assertion.statement.schema = KEY_ASSERTION_SCHEMA; + assertion.statement.value = statementValue; + assertion.appliesToState = "unencrypted"; + + RSAKey publicKeyJwk = new RSAKey.Builder((RSAPublicKey) publicKey.key) + .algorithm(Algorithm.parse(publicKey.alg.toString())) + .build(); + + var protectedHeaders = new java.util.HashMap(); + // set key id to public key algorithm in protected headers + protectedHeaders.put("kid", publicKey.alg.toString()); + // set jwk as a protected header + protectedHeaders.put("jwk", publicKeyJwk.toJSONObject()); + + try { + ByteArrayOutputStream aggregateHash = Manifest.computeAggregateHash(manifest.encryptionInformation.integrityInformation.segments, manifest.payload.isEncrypted); + boolean hexEncodeRootAndSegmentHashes = manifest.tdfVersion == null || manifest.tdfVersion.isEmpty(); + Manifest.Assertion.HashValues hashValues = Manifest.Assertion.calculateAssertionHashValues(aggregateHash, assertion, hexEncodeRootAndSegmentHashes); + try { + assertion.sign(hashValues, privateKey, Optional.of(protectedHeaders)); + } catch (KeyLengthException e) { + throw new SDK.AssertionException("error signing assertion hash", assertion.id); + } + } catch (IOException e) { + throw new SDK.AssertionException("error calculating assertion hash", assertion.id); + } + + return assertion; + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionValidator.java b/sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionValidator.java new file mode 100644 index 00000000..0857cc29 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KeyAssertionValidator.java @@ -0,0 +1,73 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JOSEException; + +import javax.annotation.Nonnull; +import java.io.IOException; +import java.text.ParseException; +import java.util.Objects; + +public class KeyAssertionValidator implements AssertionValidator { + + private final Config.AssertionVerificationKeys assertionVerificationKeys; + + private AssertionVerificationMode assertionVerificationMode = AssertionVerificationMode.FAIL_FAST; + + public KeyAssertionValidator(Config.AssertionVerificationKeys assertionVerificationKeys) { + this.assertionVerificationKeys = assertionVerificationKeys; + } + + @Override + public String getSchema() { + return "*"; + } + + @Override + public void verify(Manifest.Assertion assertion, Manifest manifest) throws SDK.AssertionException { + + if (Objects.equals(assertion.binding.signature, "")) { + throw new SDK.AssertionException("assertion has no cryptographic binding", assertion.id); + } + + if (assertionVerificationKeys.isEmpty()) { + if (Objects.requireNonNull(assertionVerificationMode) == AssertionVerificationMode.PERMISSIVE) { + return; + } + throw new SDK.AssertionException("no verification keys configured for assertion validation", assertion.id); + } + AssertionConfig.AssertionKey assertionKey = assertionVerificationKeys.getKey(assertion.id); + + try { + Manifest.Assertion.HashValues hashValues = assertion.verify(assertionKey); + + if(hashValues.getSchema() != null && !Objects.equals(hashValues.getSchema(), assertion.statement.schema)) { + throw new SDK.AssertionException("Assertion schema mismatch", assertion.id); + } + + if (!Objects.equals(assertion.hash(), hashValues.getAssertionHash())) { + throw new SDK.AssertionException("Assertion hash mismatch", assertion.id); + } + + Manifest.Assertion.verifyAssertionSignatureFormat(hashValues.getSignature(), assertion, manifest); + } catch (ParseException | JOSEException | IOException e) { + throw new SDK.AssertionException("failed to verify assertion signature", assertion.id); + } + } + + @Override + public void validate(Manifest.Assertion assertion, TDFReader reader) throws SDK.AssertionException { + if (assertionVerificationKeys.isEmpty()) { + throw new SDK.AssertionException("no verification keys are trusted", assertion.id); + } + + AssertionConfig.AssertionKey assertionKey = assertionVerificationKeys.getKey(assertion.id); + + if (assertionKey == null) { + throw new SDK.AssertionException("no verification keys are trusted", assertion.id); + } + } + + public void setVerificationMode(@Nonnull AssertionVerificationMode assertionVerificationMode) { + this.assertionVerificationMode = assertionVerificationMode; + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java b/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java index 9cd94aa1..d033e27e 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java @@ -23,9 +23,11 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import io.opentdf.platform.sdk.SDK.AssertionException; +import org.apache.commons.codec.DecoderException; import org.apache.commons.codec.binary.Hex; import org.erdtman.jcs.JsonCanonicalizer; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; @@ -34,10 +36,7 @@ import java.security.PrivateKey; import java.security.interfaces.RSAPublicKey; import java.text.ParseException; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; -import java.util.Objects; +import java.util.*; /** * The Manifest class represents a detailed structure encapsulating various @@ -49,6 +48,7 @@ public class Manifest { private static final String kAssertionHash = "assertionHash"; private static final String kAssertionSignature = "assertionSig"; + private static final String kAssertionSchema = "assertionSchema"; private static final Gson gson = new GsonBuilder() .registerTypeAdapter(AssertionConfig.Statement.class, new AssertionValueAdapter()) @@ -324,10 +324,12 @@ static public class Assertion { static public class HashValues { private final String assertionHash; private final String signature; + private final String schema; - public HashValues(String assertionHash, String signature) { + public HashValues(String assertionHash, String signature, String schema) { this.assertionHash = assertionHash; this.signature = signature; + this.schema = schema; } public String getAssertionHash() { @@ -337,6 +339,10 @@ public String getAssertionHash() { public String getSignature() { return signature; } + + public String getSchema() { + return schema; + } } @Override @@ -357,6 +363,14 @@ public int hashCode() { } public String hash() throws IOException { + return Hex.encodeHexString(this.hashAsBytes()); + } + + public String hashAsHexEncodedString() throws IOException { + return Hex.encodeHexString(this.hashAsBytes()); + } + + public byte[] hashAsBytes() throws IOException { MessageDigest digest; try { digest = MessageDigest.getInstance("SHA-256"); @@ -366,7 +380,7 @@ public String hash() throws IOException { var assertionAsJson = gson.toJson(this); JsonCanonicalizer jc = new JsonCanonicalizer(assertionAsJson); - return Hex.encodeHexString(digest.digest(jc.getEncodedUTF8())); + return digest.digest(jc.getEncodedUTF8()); } // Sign the assertion with the given hash and signature using the key. @@ -374,6 +388,11 @@ public String hash() throws IOException { // The assertion binding is updated with the method and the signature. public void sign(final HashValues hashValues, final AssertionConfig.AssertionKey assertionKey) throws KeyLengthException { + sign(hashValues, assertionKey, Optional.empty()); + } + + public void sign(final HashValues hashValues, final AssertionConfig.AssertionKey assertionKey, final Optional> protectedHeaders) + throws KeyLengthException { // Build JWT claims final JWTClaimsSet claims = new JWTClaimsSet.Builder() .claim(kAssertionHash, hashValues.assertionHash) @@ -381,7 +400,7 @@ public void sign(final HashValues hashValues, final AssertionConfig.AssertionKey .build(); // Prepare for signing - SignedJWT signedJWT = createSignedJWT(claims, assertionKey); + SignedJWT signedJWT = createSignedJWT(claims, assertionKey, protectedHeaders); try { // Sign the JWT @@ -418,25 +437,49 @@ public Assertion.HashValues verify(AssertionConfig.AssertionKey assertionKey) JWTClaimsSet claimsSet = signedJWT.getJWTClaimsSet(); String assertionHash = claimsSet.getStringClaim(kAssertionHash); String signature = claimsSet.getStringClaim(kAssertionSignature); + String schema = claimsSet.getStringClaim(kAssertionSchema); - return new Assertion.HashValues(assertionHash, signature); + return new Assertion.HashValues(assertionHash, signature, schema); } - private SignedJWT createSignedJWT(final JWTClaimsSet claims, final AssertionConfig.AssertionKey assertionKey) + public static Manifest.Assertion.HashValues calculateAssertionHashValues(ByteArrayOutputStream aggregateHash, Manifest.Assertion assertion, boolean hexEncodeRootAndSegmentHashes) throws IOException { + var hashOfAssertionAsHex = assertion.hash(); + byte[] assertionHash; + if (hexEncodeRootAndSegmentHashes) { + assertionHash = hashOfAssertionAsHex.getBytes(StandardCharsets.UTF_8); + } else { + try { + assertionHash = Hex.decodeHex(hashOfAssertionAsHex); + } catch (DecoderException e) { + throw new SDKException("error decoding assertion hash", e); + } + } + byte[] completeHash = new byte[aggregateHash.size() + assertionHash.length]; + System.arraycopy(aggregateHash.toByteArray(), 0, completeHash, 0, aggregateHash.size()); + System.arraycopy(assertionHash, 0, completeHash, aggregateHash.size(), assertionHash.length); + + var encodedHash = Base64.getEncoder().encodeToString(completeHash); + + return new Manifest.Assertion.HashValues(hashOfAssertionAsHex, encodedHash, null); + } + + private SignedJWT createSignedJWT(final JWTClaimsSet claims, final AssertionConfig.AssertionKey assertionKey, final Optional> protectedHeaders) throws SDKException { - final JWSHeader jwsHeader; + final JWSHeader.Builder jwsHeaderBuilder; switch (assertionKey.alg) { case RS256: - jwsHeader = new JWSHeader.Builder(JWSAlgorithm.RS256).build(); + jwsHeaderBuilder = new JWSHeader.Builder(JWSAlgorithm.RS256); break; case HS256: - jwsHeader = new JWSHeader.Builder(JWSAlgorithm.HS256).build(); + jwsHeaderBuilder = new JWSHeader.Builder(JWSAlgorithm.HS256); break; default: throw new SDKException("Unknown assertion key algorithm, error signing assertion"); } - return new SignedJWT(jwsHeader, claims); + protectedHeaders.ifPresent(headers -> headers.forEach(jwsHeaderBuilder::customParam)); + + return new SignedJWT(jwsHeaderBuilder.build(), claims); } private JWSSigner createSigner(final AssertionConfig.AssertionKey assertionKey) @@ -467,6 +510,36 @@ private JWSVerifier createVerifier(AssertionConfig.AssertionKey assertionKey) th throw new SDKException("Unknown verify key, unable to verify assertion signature"); } } + + // VerifyAssertionSignatureFormat validates that the assertion signature matches the expected format. + // This is the standard format used across all SDKs: base64(aggregateHash + assertionHash). + // + // This function is a convenience helper that: + // 1. Computes the aggregate hash from manifest segments + // 2. Determines the encoding format (hex vs raw bytes) from the TDF version + // 3. Computes the expected signature using the standard format + // 4. Compares it against the verified signature from the JWT + // + // Parameters: + // - verifiedSignature: The signature claim extracted from the verified JWT + // - assertion: The assertion + // - manifest: The TDF manifest containing segments and version info + // + // Throws an exception if the signature format is invalid (tampering detected) + // + // This function is used by custom AssertionValidator implementations to verify + // assertion signatures after JWT verification. + public static void verifyAssertionSignatureFormat(String verifiedSignature, Assertion assertion, Manifest manifest) throws IOException { + ByteArrayOutputStream aggregateHash = Manifest.computeAggregateHash(manifest.encryptionInformation.integrityInformation.segments, manifest.payload.isEncrypted); + + boolean useHex = manifest.tdfVersion == null || manifest.tdfVersion.isEmpty(); + + HashValues hashValues = Assertion.calculateAssertionHashValues(aggregateHash, assertion, useHex); + + if (!hashValues.signature.equals(verifiedSignature)) { + throw new AssertionException("failed integrity check on assertion signature", assertion.id); + } + } } public static class AssertionValueAdapter implements JsonDeserializer { @@ -545,4 +618,54 @@ static PolicyObject decodePolicyObject(Manifest manifest) { return gson.fromJson(policyJson, PolicyObject.class); } + + public static ByteArrayOutputStream computeAggregateHash(List segments, boolean isEncrypted) throws IOException { + ByteArrayOutputStream aggregateHash = new ByteArrayOutputStream(); + for (Manifest.Segment segment : segments) { + if (isEncrypted) { + byte[] decodedHash = Base64.getDecoder().decode(segment.hash); + aggregateHash.write(decodedHash); + } else { + aggregateHash.write(segment.hash.getBytes()); + } + } + return aggregateHash; + } + + public ByteArrayOutputStream computeAggregateHash() { + ByteArrayOutputStream aggregateHash = new ByteArrayOutputStream(); + for (Manifest.Segment segment : this.encryptionInformation.integrityInformation.segments) { + byte[] decodedHash = Base64.getDecoder().decode(segment.hash); + try { + aggregateHash.write(decodedHash); + } catch (IOException e) { + throw new SDKException("failed to decode segment hash"); + } + } + return aggregateHash; + } + + public Assertion.HashValues computeAssertionSignature(String assertionHash) { + ByteArrayOutputStream aggregateHash = this.computeAggregateHash(); + + // use hex if this.tdfVersion is null or empty + boolean useHex = this.tdfVersion == null || this.tdfVersion.isEmpty(); + + byte[] hashToUse; + if (useHex) { + hashToUse = assertionHash.getBytes(StandardCharsets.UTF_8); + } else { + try { + hashToUse = Hex.decodeHex(assertionHash); + } catch (DecoderException e) { + throw new SDKException("error decoding assertion hash", e); + } + } + byte[] completeHash = new byte[aggregateHash.size() + hashToUse.length]; + System.arraycopy(aggregateHash.toByteArray(), 0, completeHash, 0, aggregateHash.size()); + System.arraycopy(hashToUse, 0, completeHash, aggregateHash.size(), hashToUse.length); + + String signature = Base64.getEncoder().encodeToString(completeHash); + return new Manifest.Assertion.HashValues(assertionHash, signature, null); + } } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java index ba1c8082..69f44810 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDK.java @@ -3,6 +3,7 @@ import com.connectrpc.Interceptor; import com.connectrpc.impl.ProtocolClient; +import com.nimbusds.jose.JOSEException; import io.opentdf.platform.authorization.AuthorizationServiceClientInterface; import io.opentdf.platform.policy.attributes.AttributesServiceClientInterface; import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClientInterface; @@ -17,6 +18,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.channels.SeekableByteChannel; +import java.text.ParseException; import java.util.Optional; /** diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionBinder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionBinder.java new file mode 100644 index 00000000..e6c8a2ae --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionBinder.java @@ -0,0 +1,24 @@ +package io.opentdf.platform.sdk; + +import static io.opentdf.platform.sdk.TDF.TDF_SPEC_VERSION; + +public class SystemMetadataAssertionBinder implements AssertionBinder { + public static final String SYSTEM_METADATA_SCHEMA_V1 = "system-metadata-v1"; + + @Override + public Manifest.Assertion bind(Manifest manifest) { + AssertionConfig assertionConfig = AssertionConfig.getSystemMetadataAssertionConfig(TDF_SPEC_VERSION); + + assertionConfig.statement.schema = SYSTEM_METADATA_SCHEMA_V1; + + Manifest.Assertion assertion = new Manifest.Assertion(); + assertion.id = assertionConfig.id; + assertion.type = assertionConfig.type.toString(); + assertion.scope = assertionConfig.scope.toString(); + assertion.statement = assertionConfig.statement; + assertion.appliesToState = assertionConfig.appliesToState.toString(); + + return assertion; + } + +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionValidator.java b/sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionValidator.java new file mode 100644 index 00000000..092b766c --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SystemMetadataAssertionValidator.java @@ -0,0 +1,61 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JOSEException; + +import javax.annotation.Nonnull; +import java.io.IOException; +import java.text.ParseException; +import java.util.Objects; + +public class SystemMetadataAssertionValidator implements AssertionValidator { + + private final byte[] payloadKey; + + private AssertionVerificationMode verificationMode = AssertionVerificationMode.FAIL_FAST; + + public SystemMetadataAssertionValidator(byte[] payloadKey) { + this.payloadKey = payloadKey; + } + + @Override + public String getSchema() { + return SystemMetadataAssertionBinder.SYSTEM_METADATA_SCHEMA_V1; + } + + @Override + public void setVerificationMode(@Nonnull AssertionVerificationMode verificationMode) { + this.verificationMode = verificationMode; + } + + @Override + public void verify(Manifest.Assertion assertion, Manifest manifest) throws SDK.AssertionException { + boolean isValidSchema = Objects.equals(assertion.statement.schema, SystemMetadataAssertionBinder.SYSTEM_METADATA_SCHEMA_V1) || + Objects.equals(assertion.statement.schema, ""); + + if (!isValidSchema) { + throw new SDK.AssertionException("System Metadata assertion schema is invalid", assertion.id); + } + + AssertionConfig.AssertionKey assertionKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, + payloadKey); + + try { + Manifest.Assertion.HashValues hashValues = assertion.verify(assertionKey); + var hashOfAssertionAsHex = assertion.hash(); + if (!Objects.equals(hashOfAssertionAsHex, hashValues.getAssertionHash())) { + throw new SDK.AssertionException("assertion hash mismatch", assertion.id); + } + } catch (JOSEException e) { + throw new SDK.AssertionException("error validating assertion hash", assertion.id); + } catch (ParseException e) { + throw new SDK.AssertionException("error parsing assertion hash", assertion.id); + } catch (IOException e) { + throw new SDK.AssertionException("error reading assertion hash", assertion.id); + } + } + + // Validate does nothing. + @Override + public void validate(Manifest.Assertion assertion, TDFReader reader) throws SDK.AssertionException {} + +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java index 2ae08f5b..89a65062 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java @@ -371,8 +371,8 @@ TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFCo // Add System Metadata Assertion if configured if (tdfConfig.systemMetadataAssertion) { - AssertionConfig systemAssertion = AssertionConfig.getSystemMetadataAssertionConfig(TDF_SPEC_VERSION); - tdfConfig.assertionConfigList.add(systemAssertion); + SystemMetadataAssertionBinder systemAssertionBinder = new SystemMetadataAssertionBinder(); + tdfConfig.getAssertionRegistry().registerBinder(systemAssertionBinder); } TDFObject tdfObject = new TDFObject(); @@ -458,50 +458,41 @@ TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFCo tdfObject.manifest.payload.url = TDFWriter.TDF_PAYLOAD_FILE_NAME; tdfObject.manifest.payload.isEncrypted = true; - List signedAssertions = new ArrayList<>(tdfConfig.assertionConfigList.size()); + AssertionConfig.AssertionKey dekKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, + tdfObject.payloadKey); - for (var assertionConfig : tdfConfig.assertionConfigList) { - var assertion = new Manifest.Assertion(); - assertion.id = assertionConfig.id; - assertion.type = assertionConfig.type.toString(); - assertion.scope = assertionConfig.scope.toString(); - assertion.statement = assertionConfig.statement; - assertion.appliesToState = assertionConfig.appliesToState.toString(); + for (AssertionBinder binder : tdfConfig.getAssertionRegistry().getBinders()) { + Manifest.Assertion assertion = binder.bind(tdfObject.manifest); + if (assertion.binding == null) { - var assertionHashAsHex = assertion.hash(); - byte[] assertionHash; - if (tdfConfig.hexEncodeRootAndSegmentHashes) { - assertionHash = assertionHashAsHex.getBytes(StandardCharsets.UTF_8); - } else { - try { - assertionHash = Hex.decodeHex(assertionHashAsHex); - } catch (DecoderException e) { - throw new SDKException("error decoding assertion hash", e); - } - } - byte[] completeHash = new byte[aggregateHash.size() + assertionHash.length]; - System.arraycopy(aggregateHash.toByteArray(), 0, completeHash, 0, aggregateHash.size()); - System.arraycopy(assertionHash, 0, completeHash, aggregateHash.size(), assertionHash.length); + boolean useHex = tdfObject.manifest.tdfVersion == null || tdfObject.manifest.tdfVersion.isEmpty(); - var encodedHash = Base64.getEncoder().encodeToString(completeHash); + var assertionHashAsHex = assertion.hashAsHexEncodedString(); + byte[] assertionHashBytes; + if (useHex) { + assertionHashBytes = assertionHashAsHex.getBytes(StandardCharsets.UTF_8); + } else { + try { + assertionHashBytes = Hex.decodeHex(assertionHashAsHex); + } catch (DecoderException e) { + throw new SDKException("error decoding assertion hash", e); + } + } + byte[] completeHash = new byte[aggregateHash.size() + assertionHashBytes.length]; + System.arraycopy(aggregateHash.toByteArray(), 0, completeHash, 0, aggregateHash.size()); + System.arraycopy(assertionHashBytes, 0, completeHash, aggregateHash.size(), assertionHashBytes.length); - var assertionSigningKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, - tdfObject.aesGcm.getKey()); - if (assertionConfig.signingKey != null && assertionConfig.signingKey.isDefined()) { - assertionSigningKey = assertionConfig.signingKey; - } - var hashValues = new Manifest.Assertion.HashValues( - assertionHashAsHex, - encodedHash); - try { - assertion.sign(hashValues, assertionSigningKey); - } catch (KeyLengthException e) { - throw new SDKException("error signing assertion hash", e); + var encodedHash = Base64.getEncoder().encodeToString(completeHash); + var hashValues = new Manifest.Assertion.HashValues(assertionHashAsHex, encodedHash, null); + try { + assertion.sign(hashValues, dekKey); + } catch (KeyLengthException e) { + throw new SDKException("error signing assertion hash", e); + } } - signedAssertions.add(assertion); + tdfObject.manifest.assertions.add(assertion); } - tdfObject.manifest.assertions = signedAssertions; String manifestAsStr = gson.toJson(tdfObject.manifest); tdfWriter.appendManifest(manifestAsStr); @@ -510,7 +501,6 @@ TDFObject createTDF(InputStream payload, OutputStream outputStream, Config.TDFCo return tdfObject; } - Reader loadTDF(SeekableByteChannel tdf, String platformUrl) throws SDKException, IOException { return loadTDF(tdf, Config.newTDFReaderConfig(), platformUrl); } @@ -628,15 +618,7 @@ Reader loadTDF(SeekableByteChannel tdf, Config.TDFReaderConfig tdfReaderConfig) String rootAlgorithm = manifest.encryptionInformation.integrityInformation.rootSignature.algorithm; String rootSignature = manifest.encryptionInformation.integrityInformation.rootSignature.signature; - ByteArrayOutputStream aggregateHash = new ByteArrayOutputStream(); - for (Manifest.Segment segment : manifest.encryptionInformation.integrityInformation.segments) { - if (manifest.payload.isEncrypted) { - byte[] decodedHash = Base64.getDecoder().decode(segment.hash); - aggregateHash.write(decodedHash); - } else { - aggregateHash.write(segment.hash.getBytes()); - } - } + ByteArrayOutputStream aggregateHash = Manifest.computeAggregateHash(manifest.encryptionInformation.integrityInformation.segments, manifest.payload.isEncrypted); String rootSigValue; boolean isLegacyTdf = manifest.tdfVersion == null || manifest.tdfVersion.isEmpty(); @@ -675,52 +657,78 @@ Reader loadTDF(SeekableByteChannel tdf, Config.TDFReaderConfig tdfReaderConfig) } var aggregateHashByteArrayBytes = aggregateHash.toByteArray(); - // Validate assertions - for (var assertion : manifest.assertions) { - // Skip assertion verification if disabled - if (tdfReaderConfig.disableAssertionVerification) { - break; - } - // Set default to HS256 - var assertionKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, payloadKey); - Config.AssertionVerificationKeys assertionVerificationKeys = tdfReaderConfig.assertionVerificationKeys; - if (!assertionVerificationKeys.isEmpty()) { - var keyForAssertion = assertionVerificationKeys.getKey(assertion.id); - if (keyForAssertion != null) { - assertionKey = keyForAssertion; - } - } + if(tdfReaderConfig.disableAssertionVerification) { + // skip all assertion verification + return new Reader(tdfReader, manifest, payloadKey, unencryptedMetadata); + } - Manifest.Assertion.HashValues hashValues; - try { - hashValues = assertion.verify(assertionKey); - } catch (ParseException | JOSEException e) { - throw new SDKException("error validating assertion hash", e); - } - var hashOfAssertionAsHex = assertion.hash(); + // Propagate verification mode to all registered validators + // This ensures validators respect the configured verification mode + for (AssertionValidator validator : tdfReaderConfig.getAssertionRegistry().getValidators().values()){ + validator.setVerificationMode(tdfReaderConfig.getAssertionVerificationMode()); + } + + // Register system metadata assertion validator + SystemMetadataAssertionValidator systemMetadataAssertionValidator = new SystemMetadataAssertionValidator(payloadKey); + systemMetadataAssertionValidator.setVerificationMode(tdfReaderConfig.getAssertionVerificationMode()); + tdfReaderConfig.getAssertionRegistry().registerValidator(systemMetadataAssertionValidator); + + // Create DEK-based validator for fallback verification (not registered with wildcard) + // This will be used as a last resort for unknown assertions that might be DEK-signed + AssertionConfig.AssertionKey dekKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, payloadKey); + DEKAssertionValidator dekValidator = new DEKAssertionValidator(dekKey); + dekValidator.setVerificationMode(tdfReaderConfig.getAssertionVerificationMode()); - if (!Objects.equals(hashOfAssertionAsHex, hashValues.getAssertionHash())) { - throw new SDK.AssertionException("assertion hash mismatch", assertion.id); + // Validate assertions based on configured verification mode + for (var assertion : manifest.assertions) { + // SECURITY: Assertions without cryptographic bindings cannot be verified and must fail + // This prevents unsigned assertions from being tampered with + // Unsigned assertions represent a security risk and should not be accepted + if (assertion.binding.signature == null || assertion.binding.signature.isEmpty()) { + throw new SDK.AssertionException("assertion has no cryptographic binding - unsigned assertions are not allowed", assertion.id); } - byte[] hashOfAssertion; - if (isLegacyTdf) { - hashOfAssertion = hashOfAssertionAsHex.getBytes(StandardCharsets.UTF_8); - } else { + AssertionValidator validator = null; + boolean dekVerified = false; // Flag to track if DEK validator successfully verified + + if (assertion.statement.schema != null && tdfReaderConfig.getAssertionRegistry().getValidators().containsKey(assertion.statement.schema)) { + validator = tdfReaderConfig.getAssertionRegistry().getValidators().get(assertion.statement.schema); + } else if (tdfReaderConfig.assertionVerificationKeys.isEmpty()) { + // No schema-specific validator found, and no explicit verification keys provided + // Try DEK-based verification as a fallback (for assertions signed with DEK during encryption) try { - hashOfAssertion = Hex.decodeHex(hashOfAssertionAsHex); - } catch (DecoderException e) { - throw new SDKException("error decoding assertion hash", e); + dekValidator.verify(assertion, manifest); + dekVerified = true; // DEK verification succeeded + validator = dekValidator; // Assign dekValidator as the effective validator + } catch (SDKException e) { + if (e.getMessage().equals("Unable to verify assertion signature")) { + // JWT signature verification failed with DEK - assertion not signed with DEK + // Treat as unknown assertion (forward compatibility) + validator = null; + } else { + // DEK verification failed for other reason (hash mismatch, binding mismatch, etc.) + // This indicates tampering of a DEK-signed assertion - FAIL immediately + throw e; + } } } - var signature = new byte[aggregateHashByteArrayBytes.length + hashOfAssertion.length]; - System.arraycopy(aggregateHashByteArrayBytes, 0, signature, 0, aggregateHashByteArrayBytes.length); - System.arraycopy(hashOfAssertion, 0, signature, aggregateHashByteArrayBytes.length, hashOfAssertion.length); - var encodeSignature = Base64.getEncoder().encodeToString(signature); - if (!Objects.equals(encodeSignature, hashValues.getSignature())) { - throw new SDK.AssertionException("failed integrity check on assertion signature", assertion.id); + if (validator == null){ + switch (tdfReaderConfig.getAssertionVerificationMode()) { + case STRICT: + throw new SDK.AssertionException("unknown assertion type in strict mode", assertion.id); + case PERMISSIVE: + case FAIL_FAST: + continue; + } + } else { + // If it was already DEK verified, no need to call verify again. + // Otherwise, call verify for the schema-specific validator. + if (!dekVerified) { + validator.verify(assertion, manifest); + } + validator.validate(assertion, tdfReader); } } diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java index c28bd2bd..f5bf5a2a 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java @@ -350,11 +350,6 @@ void testWithAssertionVerificationDisabled() throws Exception { .setKeyAccessServerRegistryService(kasRegistryService).build()); tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - var assertionVerificationKeys = new Config.AssertionVerificationKeys(); - assertionVerificationKeys.keys.put(assertion1Id, - new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, - keypair.getPublic())); - var unwrappedData = new ByteArrayOutputStream(); var dataToUnwrap = new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()); var emptyConfig = Config.newTDFReaderConfig();