diff --git a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProvider.h b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProvider.h index 32936785bdf..314bc555d2e 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProvider.h +++ b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProvider.h @@ -69,12 +69,19 @@ namespace Aws */ virtual AWSCredentials GetAWSCredentials() = 0; + /** + * Forces reloading the credentials returned by GetAWSCredentials(). + * Credentials are reloaded when they're expired or due to their reload frequency or after this function is called. + */ + virtual void SetNeedRefresh(); + protected: /** * The default implementation keeps up with the cache times and lets you know if it's time to refresh your internal caching * to aid your implementation of GetAWSCredentials. */ virtual bool IsTimeToRefresh(long reloadFrequency); + virtual bool IsSetNeedRefresh(); virtual void Reload(); mutable Aws::Utils::Threading::ReaderWriterLock m_reloadLock; private: diff --git a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h index 58bc8eb0977..4d44390b92a 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h +++ b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h @@ -29,6 +29,8 @@ namespace Aws */ virtual AWSCredentials GetAWSCredentials(); + virtual void SetNeedRefresh(); + /** * Gets all providers stored in this chain. */ diff --git a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp index e2704ced88f..f7ac1db90fe 100644 --- a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp @@ -62,6 +62,22 @@ bool AWSCredentialsProvider::IsTimeToRefresh(long reloadFrequency) return false; } +void AWSCredentialsProvider::SetNeedRefresh() +{ + ReaderLockGuard guard(m_reloadLock); + if (m_lastLoadedMs != 0) + { + guard.UpgradeToWriterLock(); + m_lastLoadedMs = 0; + } +} + +bool AWSCredentialsProvider::IsSetNeedRefresh() +{ + /// This function is called from implementations of RefreshIfExpired() at the point when m_reloadLock is locked. + return m_lastLoadedMs == 0; +} + static const char* ENVIRONMENT_LOG_TAG = "EnvironmentAWSCredentialsProvider"; diff --git a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp index fee416aa7af..71b4900747e 100644 --- a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp +++ b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp @@ -40,6 +40,19 @@ AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials() return AWSCredentials(); } +void AWSCredentialsProviderChain::SetNeedRefresh() +{ + for (auto&& credentialsProvider : m_providerChain) + credentialsProvider->SetNeedRefresh(); + + ReaderLockGuard lock(m_cachedProviderLock); + if (m_cachedProvider) + { + lock.UpgradeToWriterLock(); + m_cachedProvider.reset(); + } +} + DefaultAWSCredentialsProviderChain::DefaultAWSCredentialsProviderChain() : AWSCredentialsProviderChain() { AddProvider(Aws::MakeShared(DefaultCredentialsProviderChainTag)); diff --git a/src/aws-cpp-sdk-core/source/auth/GeneralHTTPCredentialsProvider.cpp b/src/aws-cpp-sdk-core/source/auth/GeneralHTTPCredentialsProvider.cpp index 5f3ed7eceb6..9be9ee054b4 100644 --- a/src/aws-cpp-sdk-core/source/auth/GeneralHTTPCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-core/source/auth/GeneralHTTPCredentialsProvider.cpp @@ -240,11 +240,15 @@ void GeneralHTTPCredentialsProvider::Reload() token = credentialsView.GetString("Token"); AWS_LOGSTREAM_DEBUG(GEN_HTTP_LOG_TAG, "Successfully pulled credentials from metadata service with access key " << accessKey); + auto old_credentials = m_credentials; + m_credentials.SetAWSAccessKeyId(accessKey); m_credentials.SetAWSSecretKey(secretKey); m_credentials.SetSessionToken(token); m_credentials.SetExpiration(Aws::Utils::DateTime(credentialsView.GetString("Expiration"), Aws::Utils::DateFormat::ISO_8601)); AWSCredentialsProvider::Reload(); + + AWS_LOGSTREAM_DEBUG(GEN_HTTP_LOG_TAG, "Got " << ((m_credentials == old_credentials) ? "same " : "") << "credentials from ECSCredentialService."); } void GeneralHTTPCredentialsProvider::RefreshIfExpired()