/*
 * Decompiled with CFR 0.152.
 */
package cryptix.jce.provider.rsa;

import cryptix.jce.provider.rsa.RSAAlgorithm;
import cryptix.jce.provider.util.Util;
import java.math.BigInteger;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.SignatureSpi;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;

/**
 * RSASSA-PSS signature engine: EMSA-PSS encoding with MGF1 built on the same
 * message digest, a salt as long as the digest output, and the 0xBC trailer.
 *
 * <p>Concrete subclasses choose the digest by passing its JCA name to
 * {@link #RSASignature_PSS(String)}.
 */
public abstract class RSASignature_PSS
extends SignatureSpi {
    /** MASK[k] keeps the low (8 - k) bits of a byte, i.e. clears its k leftmost bits. */
    private static final byte[] MASK = new byte[]{-1, 127, 63, 31, 15, 7, 3, 1};
    /** EMSA-PSS trailer field, 0xBC. */
    private static final byte TRAILER = (byte)0xBC;

    private final MessageDigest md;
    /** Digest output length in bytes. */
    private final int hLen;
    /** Salt length in bytes; always equal to hLen in this implementation. */
    private final int sLen;
    /** Debug-only fixed salt for the next sign operation, or null. */
    private byte[] presetSalt;
    /** Encoded-message length in bytes: ceil(emBits / 8). */
    private int emLen;
    /** Encoded-message length in bits: modulus bit length - 1. */
    private int emBits;
    private BigInteger exp;
    private BigInteger n;
    /** CRT components; null when the key is not an RSAPrivateCrtKey or when verifying. */
    private BigInteger p;
    private BigInteger q;
    private BigInteger u;
    /** Salt generator; null while initialized for verification. */
    private SecureRandom rng;

    /**
     * Creates the engine around the named message digest.
     *
     * @param hashName JCA name of the digest used for both the message hash and MGF1
     * @throws InternalError if the digest algorithm is not available
     */
    public RSASignature_PSS(String hashName) {
        try {
            this.md = MessageDigest.getInstance(hashName);
        }
        catch (NoSuchAlgorithmException ex) {
            InternalError err = new InternalError("MessageDigest not found! (" + hashName + "): " + ex.toString());
            err.initCause(ex);
            throw err;
        }
        int len = this.md.getDigestLength();
        if (len == 0) {
            // getDigestLength() returns 0 when the provider cannot report it;
            // measure the output of a (throw-away) digest instead.
            len = this.md.digest().length;
        }
        this.hLen = len;
        this.sLen = len;
    }

    /**
     * Parameter retrieval is not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    protected Object engineGetParameter(String a) {
        throw new UnsupportedOperationException("NYI");
    }

    /**
     * Initializes for signing with an RSA private key. CRT components are
     * captured when the key exposes them.
     *
     * @param key    must be an {@link RSAPrivateKey}
     * @param random salt source; when null a new {@link SecureRandom} is used
     * @throws InvalidKeyException if the key is not an RSA private key or its
     *                             modulus is too short for this digest
     */
    protected void engineInitSign(PrivateKey key, SecureRandom random) throws InvalidKeyException {
        if (!(key instanceof RSAPrivateKey)) {
            throw new InvalidKeyException("Not an RSA private key");
        }
        RSAPrivateKey rsa = (RSAPrivateKey)key;
        this.n = rsa.getModulus();
        this.exp = rsa.getPrivateExponent();
        if (key instanceof RSAPrivateCrtKey) {
            RSAPrivateCrtKey crt = (RSAPrivateCrtKey)key;
            this.p = crt.getPrimeP();
            this.q = crt.getPrimeQ();
            this.u = crt.getCrtCoefficient();
        } else {
            this.u = null;
            this.q = null;
            this.p = null;
        }
        // engineSign() draws the salt from rng, so it must never be null.
        this.rng = random != null ? random : new SecureRandom();
        this.initCommon();
    }

    /** Initializes for signing, using a fresh {@link SecureRandom} for salts. */
    protected void engineInitSign(PrivateKey privateKey) throws InvalidKeyException {
        this.engineInitSign(privateKey, new SecureRandom());
    }

    /**
     * Initializes for verification with an RSA public key.
     *
     * @throws InvalidKeyException if the key is not an RSA public key or its
     *                             modulus is too short for this digest
     */
    protected void engineInitVerify(PublicKey key) throws InvalidKeyException {
        if (!(key instanceof RSAPublicKey)) {
            throw new InvalidKeyException("Not an RSA public key");
        }
        RSAPublicKey rsa = (RSAPublicKey)key;
        this.n = rsa.getModulus();
        this.exp = rsa.getPublicExponent();
        this.u = null;
        this.q = null;
        this.p = null;
        this.rng = null;
        this.initCommon();
    }

    /**
     * Derives emBits/emLen from the modulus, rejects moduli too short to hold
     * hash + salt + 2 bytes, and resets the digest.
     */
    private void initCommon() throws InvalidKeyException {
        this.emBits = this.getModulusBitLen() - 1;
        this.emLen = (this.emBits + 7) / 8;
        // Equivalent to emLen < hLen + sLen + 2.
        if (this.emBits < 8 * this.hLen + 8 * this.sLen + 9) {
            throw new InvalidKeyException("Signer's key modulus too short.");
        }
        this.md.reset();
    }

    /**
     * Accepts one debug-only parameter, {@code "CryptixDebugFixedSalt"}
     * (case-insensitive), whose value is a {@code byte[]} salt used for the next
     * {@link #engineSign()} call only. Any other name or value type is ignored.
     */
    protected void engineSetParameter(String name, Object param) {
        if (name.equalsIgnoreCase("CryptixDebugFixedSalt") && param instanceof byte[]) {
            // Copy so later changes to the caller's array cannot alter the salt.
            this.presetSalt = (byte[])((byte[])param).clone();
        }
    }

    /**
     * Hashes the accumulated message, EMSA-PSS encodes it and applies the RSA
     * private-key operation.
     *
     * @return the signature, {@link #getModulusLen()} bytes long
     * @throws Error         if a preset debug salt does not have length sLen
     * @throws InternalError if the encoded message is not below the modulus
     */
    protected byte[] engineSign() {
        byte[] mHash = this.md.digest();
        byte[] salt;
        if (this.presetSalt == null) {
            salt = new byte[this.sLen];
            this.rng.nextBytes(salt);
        } else {
            if (this.sLen != this.presetSalt.length) {
                throw new Error("Invalid presetSalt, size mismatch!");
            }
            salt = this.presetSalt;
            this.presetSalt = null;
            System.err.println("Using preset salt: " + cryptix.jce.util.Util.toString((byte[])salt) + "!");
        }
        // H = Hash(0x00 * 8 || mHash || salt)
        this.md.update(new byte[8]);
        this.md.update(mHash);
        byte[] H = this.md.digest(salt);
        // DB = PS || 0x01 || salt, with PS all zero bytes.
        int dbLen = this.emLen - this.hLen - 1;
        byte[] PS = new byte[dbLen - this.sLen - 1];
        byte[] DB = this.concat(PS, new byte[]{1}, salt);
        byte[] maskedDB = RSASignature_PSS.xor(DB, this.mgf1(H, dbLen));
        // Clear the 8*emLen - emBits leftmost bits so EM has at most emBits
        // significant bits, hence is smaller than the modulus.
        maskedDB[0] = (byte)(maskedDB[0] & MASK[8 * this.emLen - this.emBits]);
        byte[] EM = this.concat(maskedDB, H, new byte[]{TRAILER});
        BigInteger m = new BigInteger(1, EM);
        if (m.compareTo(this.n) >= 0) {
            throw new InternalError("message > modulus!");
        }
        // RSAAlgorithm.rsa presumably uses CRT when p/q/u are non-null — see its source.
        BigInteger s = RSAAlgorithm.rsa(m, this.n, this.exp, this.p, this.q, this.u);
        return Util.toFixedLenByteArray(s, this.getModulusLen());
    }

    /** Modulus length in bytes. */
    private int getModulusLen() {
        return (this.n.bitLength() + 7) / 8;
    }

    /** Modulus length in bits. */
    private int getModulusBitLen() {
        return this.n.bitLength();
    }

    /**
     * Returns a new array holding a[i] ^ b[i].
     *
     * @throws InternalError if the arrays differ in length
     */
    private static byte[] xor(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new InternalError("a.len != b.len");
        }
        byte[] res = new byte[a.length];
        for (int i = 0; i < res.length; ++i) {
            res[i] = (byte)(a[i] ^ b[i]);
        }
        return res;
    }

    /**
     * MGF1 mask generation: concatenates Hash(seed || counter) for
     * counter = 0, 1, ... and truncates the result to len bytes.
     */
    private byte[] mgf1(byte[] seed, int len) {
        byte[] mask = new byte[len];
        int off = 0;
        for (int counter = 0; off < len; ++counter) {
            byte[] block = this.mgf1Hash(seed, counter);
            int chunk = Math.min(block.length, len - off);
            System.arraycopy(block, 0, mask, off, chunk);
            off += chunk;
        }
        return mask;
    }

    /** Computes Hash(seed || counter), with counter encoded as 4 big-endian bytes. */
    private byte[] mgf1Hash(byte[] seed, int counter) {
        this.md.update(seed);
        this.md.update((byte)(counter >>> 24));
        this.md.update((byte)(counter >>> 16));
        this.md.update((byte)(counter >>> 8));
        this.md.update((byte)counter);
        return this.md.digest();
    }

    /** Returns a || b as a new array. */
    private byte[] concat(byte[] a, byte[] b) {
        byte[] res = new byte[a.length + b.length];
        System.arraycopy(a, 0, res, 0, a.length);
        System.arraycopy(b, 0, res, a.length, b.length);
        return res;
    }

    /** Returns a || b || c as a new array. */
    private byte[] concat(byte[] a, byte[] b, byte[] c) {
        return this.concat(a, this.concat(b, c));
    }

    /** Feeds one message byte into the digest. */
    protected void engineUpdate(byte b) {
        this.md.update(b);
    }

    /** Feeds len message bytes starting at buf[off] into the digest. */
    protected void engineUpdate(byte[] buf, int off, int len) {
        this.md.update(buf, off, len);
    }

    /**
     * Verifies an RSASSA-PSS signature over the accumulated message.
     *
     * @param signature candidate signature; must be exactly the modulus length
     * @return true if the signature is valid
     */
    protected boolean engineVerify(byte[] signature) {
        // Finish the digest first: this resets it, so a signature rejected by
        // the early checks below cannot leave stale message data behind for
        // the next verification on this object.
        byte[] mHash = this.md.digest();
        if (signature.length != this.getModulusLen()) {
            return false;
        }
        // s is non-negative by construction (signum 1), so only the upper bound matters.
        BigInteger s = new BigInteger(1, signature);
        if (s.compareTo(this.n) >= 0) {
            return false;
        }
        BigInteger m = RSAAlgorithm.rsa(s, this.n, this.exp, this.p, this.q, this.u);
        if (m.bitLength() > this.emLen * 8) {
            return false;
        }
        byte[] em = Util.toFixedLenByteArray(m, this.emLen);
        return this.pssVerify(mHash, em, this.getModulusBitLen() - 1);
    }

    /**
     * EMSA-PSS consistency check of an encoded message against a message hash.
     *
     * @param mHash  digest of the signed message
     * @param em     encoded message, emLen bytes
     * @param emBits encoded-message length in bits
     * @return true if em is a valid encoding of mHash
     */
    private boolean pssVerify(byte[] mHash, byte[] em, int emBits) {
        if (emBits < 8 * this.hLen + 8 * this.sLen + 9) {
            return false;
        }
        if (em[em.length - 1] != TRAILER) {
            return false;
        }
        int maskedDbLen = this.emLen - this.hLen - 1;
        byte[] maskedDb = new byte[maskedDbLen];
        System.arraycopy(em, 0, maskedDb, 0, maskedDbLen);
        byte[] H = new byte[this.hLen];
        System.arraycopy(em, maskedDbLen, H, 0, this.hLen);
        byte topMask = MASK[8 * this.emLen - emBits];
        // The bits that engineSign() clears must be zero in a valid encoding.
        if ((maskedDb[0] & ~topMask) != 0) {
            return false;
        }
        byte[] DB = RSASignature_PSS.xor(maskedDb, this.mgf1(H, maskedDbLen));
        DB[0] = (byte)(DB[0] & topMask);
        // DB must be PS (zero bytes) || 0x01 || salt.
        int psLen = this.emLen - this.hLen - this.sLen - 2;
        for (int i = 0; i < psLen; ++i) {
            if (DB[i] != 0) {
                return false;
            }
        }
        if (DB[psLen] != 1) {
            return false;
        }
        byte[] salt = new byte[this.sLen];
        System.arraycopy(DB, DB.length - this.sLen, salt, 0, this.sLen);
        this.md.reset();
        this.md.update(new byte[8]);
        this.md.update(mHash);
        byte[] H1 = this.md.digest(salt);
        return MessageDigest.isEqual(H1, H);
    }
}

