Skip to content

Commit

Permalink
Update caches
Browse files Browse the repository at this point in the history
- Use fnv1a hash and proper init value
- Change smart pointer type to shared in PKeyCache
- pkRSAUnseal keeps a vector of shared_ptrs to ensure they are around even if there is more than the SSLSERVICES_MAX_CACHE_SIZE
  • Loading branch information
jackdelv committed Jan 3, 2025
1 parent a1b1a44 commit 3e8e840
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 69 deletions.
139 changes: 72 additions & 67 deletions plugins/sslservices/sslservices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static const char* sslservicesCompatibleVersions[] = {
CURRENT_SSLSERVICES_VERSION,
NULL };

SSLSERVICES_API bool SSLSERVICES_CALL getECLPluginDefinition(ECLPluginDefinitionBlock* pb)
SSLSERVICES_API bool getECLPluginDefinition(ECLPluginDefinitionBlock* pb)
{
if (pb->size == sizeof(ECLPluginDefinitionBlockEx))
{
Expand Down Expand Up @@ -100,10 +100,30 @@ bool isPublicKey(size32_t keyLen, const char * key)

static constexpr size32_t SSLSERVICES_MAX_CACHE_SIZE = 10;
static constexpr bool PRINT_STATS = false;

/**
* Simple cache that is used for both ciphers and digests
*
* NOTE: Should only be used with the thread-local storage class
* specifier since the results cannot be relied on if called from
* multiple threads.
*/
template <typename T>
class AlgorithmCache
{
public:
AlgorithmCache()
{
setCacheName();
}

~AlgorithmCache()
{
if (PRINT_STATS)
LOG(MCmonitorMetric, "{ \"type\": \"metric\", \"name\": \"sslServiceCache%s\", \"hits\": \"%u\", \"misses\": \"%u\" }", cacheName.c_str(), hits, misses);
cache.clear();
}

const T * checkCache(const char * algorithm_name)
{
for (auto& c : cache)
Expand All @@ -128,20 +148,9 @@ class AlgorithmCache
return newObj;
}

void printStatistics() {DBGLOG("SSLSERVICES %s CACHE STATS: HITS = %d, MISSES = %d", cacheName.c_str(), hits, misses);}

void init()
{
setCacheName();
hits = 0;
misses = 0;
}

void clear() {cache.clear();}

private:
unsigned hits;
unsigned misses;
unsigned hits = 0;
unsigned misses = 0;
std::string cacheName;
std::list<std::tuple<std::string, const T *>> cache;

Expand All @@ -150,38 +159,43 @@ class AlgorithmCache
};

template <>
void AlgorithmCache<EVP_CIPHER>::setCacheName() {cacheName = "CIPHER";}
void AlgorithmCache<EVP_CIPHER>::setCacheName() {cacheName = "Cipher";}

template <>
void AlgorithmCache<EVP_MD>::setCacheName() {cacheName = "DIGEST";}
void AlgorithmCache<EVP_MD>::setCacheName() {cacheName = "Digest";}

template <>
const EVP_CIPHER * AlgorithmCache<EVP_CIPHER>::getObjectByName(const char * name) { return EVP_get_cipherbyname(name); }

template <>
const EVP_MD * AlgorithmCache<EVP_MD>::getObjectByName(const char * name) { return EVP_get_digestbyname(name); }

typedef std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> UniquePKey;

// PEM Public/Private keys require parsing from a string
// Store the hash of the original string and parsed key
/**
* Simple cache that is used for both public and private keys.
*
* NOTE: Should only be used with the thread-local storage class
* specifier since the results cannot be relied on if called from
* multiple threads.
*/
class PKeyCache
{
public:
void init()
~PKeyCache()
{
hits = 0;
misses = 0;
if (PRINT_STATS)
LOG(MCmonitorMetric, "{ \"type\": \"metric\", \"name\": \"sslServiceCachePKey\", \"hits\": \"%u\", \"misses\": \"%u\" }", hits, misses);
cache.clear();
}

EVP_PKEY * checkCache(size32_t keyLen, const char * key, size32_t passphraseLen, const void * passphrase)
std::shared_ptr<EVP_PKEY> checkCache(size32_t keyLen, const char * key, size32_t passphraseLen, const void * passphrase)
{
hash64_t pkeyHash = hashc_fnv1a(static_cast<const byte *>(passphrase), passphraseLen, hashc_fnv1a(reinterpret_cast<const byte *>(key), keyLen, fnvInitialHash32));
for (auto& c : cache)
{
if (hashc(reinterpret_cast<const byte *>(passphrase), passphraseLen, hashc(reinterpret_cast<const byte *>(key), keyLen, 0)) == std::get<0>(c))
if (pkeyHash == std::get<0>(c))
{
hits++;
return std::get<1>(c).get();
return std::get<1>(c);
}
}

Expand All @@ -191,21 +205,20 @@ class PKeyCache
if (!bio)
failOpenSSLError("creating buffer for EVP_PKEY");

EVP_PKEY * pkey;
std::shared_ptr<EVP_PKEY> pkey;
if (isPublicKey(keyLen, key))
pkey = PEM_read_bio_PUBKEY(bio, nullptr, nullptr, nullptr);
pkey.reset(PEM_read_bio_PUBKEY(bio, nullptr, nullptr, nullptr), EVP_PKEY_free);
else
{
MemoryBuffer passphraseMB;
passphraseMB.setBuffer(passphraseLen, (void *)passphrase);
pkey = PEM_read_bio_PrivateKey(bio, nullptr, passphraseCB, static_cast<void *>(&passphraseMB));
pkey.reset(PEM_read_bio_PrivateKey(bio, nullptr, passphraseCB, static_cast<void *>(&passphraseMB)), EVP_PKEY_free);
}
BIO_free(bio);

if (pkey)
{
unsigned PkeyHash = hashc(reinterpret_cast<const byte *>(passphrase), passphraseLen, hashc(reinterpret_cast<const byte *>(key), keyLen, 0));
cache.emplace_front(PkeyHash, std::move(UniquePKey(pkey, EVP_PKEY_free)));
cache.emplace_front(pkeyHash, pkey);
if (cache.size() > SSLSERVICES_MAX_CACHE_SIZE)
cache.pop_back();
}
Expand All @@ -215,14 +228,10 @@ class PKeyCache
return pkey;
}

void clear() {cache.clear();}

void printStatistics() {DBGLOG("SSLSERVICES PKEY CACHE STATS: HITS = %d, MISSES = %d", hits, misses);}

private:
unsigned hits;
unsigned misses;
std::list<std::tuple<unsigned, UniquePKey>> cache;
unsigned hits = 0;
unsigned misses = 0;
std::list<std::tuple<hash64_t, std::shared_ptr<EVP_PKEY>>> cache;
};


Expand Down Expand Up @@ -507,12 +516,16 @@ SSLSERVICES_API void SSLSERVICES_CALL cipherDecrypt(ICodeContext *ctx, size32_t

SSLSERVICES_API void SSLSERVICES_CALL pkRSASeal(ICodeContext *ctx, size32_t & __lenResult, void * & __result, size32_t len_plaintext, const void * _plaintext, bool isAll_pem_public_keys, size32_t len_pem_public_keys, const void * _pem_public_keys, const char * _algorithm_name)
{
__result = nullptr;
__lenResult = 0;

// Initial sanity check of our arguments
if (len_pem_public_keys == 0)
rtlFail(-1, "No public keys provided");

if (!isAll_pem_public_keys && len_plaintext > 0)
{
std::vector<std::shared_ptr<EVP_PKEY>> publicKeysSP;
std::vector<EVP_PKEY *> publicKeys;
EVP_CIPHER_CTX * encryptCtx = nullptr;
byte ** encryptedKeys = nullptr;
Expand All @@ -528,7 +541,8 @@ SSLSERVICES_API void SSLSERVICES_CALL pkRSASeal(ICodeContext *ctx, size32_t & __
{
const size32_t keySize = *(reinterpret_cast<const size32_t *>(pubKeyPtr));
pubKeyPtr += sizeof(keySize);
publicKeys.push_back(pkeyCache.checkCache(keySize, pubKeyPtr, 0, nullptr));
publicKeysSP.push_back(pkeyCache.checkCache(keySize, pubKeyPtr, 0, nullptr));
publicKeys.push_back(publicKeysSP.back().get());
pubKeyPtr += keySize;
}

Expand Down Expand Up @@ -598,6 +612,8 @@ SSLSERVICES_API void SSLSERVICES_CALL pkRSASeal(ICodeContext *ctx, size32_t & __

// Cleanup
EVP_CIPHER_CTX_free(encryptCtx);
for (size_t i = 0; i < publicKeys.size(); i++)
delete [] encryptedKeys[i];
delete [] encryptedKeys;
}
catch (...)
Expand All @@ -617,6 +633,9 @@ SSLSERVICES_API void SSLSERVICES_CALL pkRSASeal(ICodeContext *ctx, size32_t & __

SSLSERVICES_API void SSLSERVICES_CALL pkRSAUnseal(ICodeContext *ctx, size32_t & __lenResult, void * & __result, size32_t len_ciphertext, const void * _ciphertext, size32_t len_passphrase, const void * _passphrase, size32_t len_pem_private_key, const char * _pem_private_key, const char * _algorithm_name)
{
__result = nullptr;
__lenResult = 0;

// Initial sanity check of our arguments
if (len_pem_private_key == 0)
rtlFail(-1, "No private key provided");
Expand All @@ -631,13 +650,13 @@ SSLSERVICES_API void SSLSERVICES_CALL pkRSAUnseal(ICodeContext *ctx, size32_t &
try
{
// Load the private key
EVP_PKEY * privateKey = pkeyCache.checkCache(len_pem_private_key, _pem_private_key, len_passphrase, _passphrase);
std::shared_ptr<EVP_PKEY> privateKey = pkeyCache.checkCache(len_pem_private_key, _pem_private_key, len_passphrase, _passphrase);

// Load the cipher
const EVP_CIPHER * cipher = cipherCache.checkCache(_algorithm_name);

// Allocate memory for the symmetric key and IV
size32_t keyLen = EVP_PKEY_size(privateKey);
size32_t keyLen = EVP_PKEY_size(privateKey.get());
symmetricKey.ensureCapacity(keyLen);
size32_t ivLen = EVP_CIPHER_iv_length(cipher);
iv.ensureCapacity(ivLen);
Expand Down Expand Up @@ -674,7 +693,7 @@ SSLSERVICES_API void SSLSERVICES_CALL pkRSAUnseal(ICodeContext *ctx, size32_t &
bool found = false;
for (auto& encryptedKey : encryptedKeys)
{
if (EVP_OpenInit(decryptCtx, cipher, reinterpret_cast<const unsigned char *>(encryptedKey.data()), encryptedKey.size(), static_cast<byte *>(iv.bufferBase()), privateKey) == 1)
if (EVP_OpenInit(decryptCtx, cipher, reinterpret_cast<const unsigned char *>(encryptedKey.data()), encryptedKey.size(), static_cast<byte *>(iv.bufferBase()), privateKey.get()) == 1)
{
found = true;
break;
Expand All @@ -698,9 +717,7 @@ SSLSERVICES_API void SSLSERVICES_CALL pkRSAUnseal(ICodeContext *ctx, size32_t &

// Copy to the ECL result buffer
__lenResult = plaintextLen;
MemoryBuffer resultBuffer(__lenResult);
resultBuffer.append(__lenResult, plaintext.bufferBase());
__result = resultBuffer.detachOwn();
__result = plaintext.detachOwn();

// Cleanup
EVP_CIPHER_CTX_free(decryptCtx);
Expand Down Expand Up @@ -733,10 +750,10 @@ SSLSERVICES_API void SSLSERVICES_CALL pkEncrypt(ICodeContext *ctx, size32_t & __
try
{
// Load key from buffer
EVP_PKEY * publicKey = pkeyCache.checkCache(len_pem_public_key, _pem_public_key, 0, nullptr);
std::shared_ptr<EVP_PKEY> publicKey = pkeyCache.checkCache(len_pem_public_key, _pem_public_key, 0, nullptr);

// Create encryption context
encryptCtx = EVP_PKEY_CTX_new(publicKey, nullptr);
encryptCtx = EVP_PKEY_CTX_new(publicKey.get(), nullptr);
if (!encryptCtx)
failOpenSSLError("publicKey");
if (EVP_PKEY_encrypt_init(encryptCtx) <= 0)
Expand Down Expand Up @@ -785,10 +802,10 @@ SSLSERVICES_API void SSLSERVICES_CALL pkDecrypt(ICodeContext *ctx, size32_t & __
try
{
// Load key from buffer
EVP_PKEY * privateKey = pkeyCache.checkCache(len_pem_private_key, _pem_private_key, len_passphrase, _passphrase);
std::shared_ptr<EVP_PKEY> privateKey = pkeyCache.checkCache(len_pem_private_key, _pem_private_key, len_passphrase, _passphrase);

// Create decryption context
decryptCtx = EVP_PKEY_CTX_new(privateKey, nullptr);
decryptCtx = EVP_PKEY_CTX_new(privateKey.get(), nullptr);
if (!decryptCtx)
failOpenSSLError("EVP_PKEY_CTX_new");
if (EVP_PKEY_decrypt_init(decryptCtx) <= 0)
Expand Down Expand Up @@ -824,11 +841,13 @@ SSLSERVICES_API void SSLSERVICES_CALL pkDecrypt(ICodeContext *ctx, size32_t & __
SSLSERVICES_API void SSLSERVICES_CALL pkSign(ICodeContext *ctx, size32_t & __lenResult, void * & __result, size32_t len_plaintext, const void * _plaintext, size32_t len_passphrase, const void * _passphrase, size32_t len_pem_private_key, const char * _pem_private_key, const char * _algorithm_name)
{
EVP_MD_CTX *mdCtx = nullptr;
__result = nullptr;
__lenResult = 0;

try
{
// Load the private key from the PEM string
EVP_PKEY * privateKey = pkeyCache.checkCache(len_pem_private_key, _pem_private_key, len_passphrase, _passphrase);
std::shared_ptr<EVP_PKEY> privateKey = pkeyCache.checkCache(len_pem_private_key, _pem_private_key, len_passphrase, _passphrase);

// Create and initialize the message digest context
mdCtx = EVP_MD_CTX_new();
Expand All @@ -837,7 +856,7 @@ SSLSERVICES_API void SSLSERVICES_CALL pkSign(ICodeContext *ctx, size32_t & __len

const EVP_MD *md = digestCache.checkCache(_algorithm_name);

if (EVP_DigestSignInit(mdCtx, nullptr, md, nullptr, privateKey) <= 0)
if (EVP_DigestSignInit(mdCtx, nullptr, md, nullptr, privateKey.get()) <= 0)
failOpenSSLError("EVP_DigestSignInit (pkSign)");

// Add plaintext to context
Expand Down Expand Up @@ -880,7 +899,7 @@ SSLSERVICES_API bool SSLSERVICES_CALL pkVerifySignature(ICodeContext *ctx, size3
try
{
// Load the public key from the PEM string
EVP_PKEY * publicKey = pkeyCache.checkCache(len_pem_public_key, _pem_public_key, 0, nullptr);
std::shared_ptr<EVP_PKEY> publicKey = pkeyCache.checkCache(len_pem_public_key, _pem_public_key, 0, nullptr);

// Create and initialize the message digest context
mdCtx = EVP_MD_CTX_new();
Expand All @@ -889,7 +908,7 @@ SSLSERVICES_API bool SSLSERVICES_CALL pkVerifySignature(ICodeContext *ctx, size3

const EVP_MD *md = digestCache.checkCache(_algorithm_name);

if (EVP_DigestVerifyInit(mdCtx, nullptr, md, nullptr, publicKey) <= 0)
if (EVP_DigestVerifyInit(mdCtx, nullptr, md, nullptr, publicKey.get()) <= 0)
failOpenSSLError("EVP_DigestVerifyInit (pkVerifySignature)");

if (EVP_DigestVerifyUpdate(mdCtx, _signedData, len_signedData) <= 0)
Expand All @@ -912,23 +931,9 @@ SSLSERVICES_API bool SSLSERVICES_CALL pkVerifySignature(ICodeContext *ctx, size3

MODULE_INIT(INIT_PRIORITY_STANDARD)
{
pkeyCache.init();
digestCache.init();
cipherCache.init();

return true;
}

MODULE_EXIT()
{
if(PRINT_STATS)
{
pkeyCache.printStatistics();
digestCache.printStatistics();
cipherCache.printStatistics();
}

pkeyCache.clear();
digestCache.clear();
cipherCache.clear();
}
7 changes: 5 additions & 2 deletions plugins/sslservices/sslservices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@
#include "eclrtl_imp.hpp"
#include "eclhelper.hpp"

extern "C++"
extern "C"
{
SSLSERVICES_API bool SSLSERVICES_CALL getECLPluginDefinition(ECLPluginDefinitionBlock *pb);
SSLSERVICES_API bool getECLPluginDefinition(ECLPluginDefinitionBlock *pb);
}

extern "C++"
{
// Digest functions
SSLSERVICES_API void SSLSERVICES_CALL digestAvailableAlgorithms(ICodeContext *ctx, size32_t & __lenResult, void * & __result);
SSLSERVICES_API void SSLSERVICES_CALL digestHash(ICodeContext *ctx, size32_t & __lenResult, void * & __result, size32_t len_indata, const void * _indata, const char * _algorithm_name);
Expand Down

0 comments on commit 3e8e840

Please sign in to comment.