Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix potential race condition in GetResultOrRunClassInitialize #4555

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 74 additions & 62 deletions src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Extensions;
Expand Down Expand Up @@ -254,6 +254,7 @@ public void RunClassInitialize(TestContext testContext)
// If no class initialize and no base class initialize, return
if (ClassInitializeMethod is null && BaseClassInitMethods.Count == 0)
{
DebugEx.Assert(false, "Caller shouldn't call us if nothing to execute");
IsClassInitializeExecuted = true;
return;
}
Expand All @@ -270,45 +271,37 @@ public void RunClassInitialize(TestContext testContext)
string? failedClassInitializeMethodName = string.Empty;

// If class initialization is not done, then do it.
DebugEx.Assert(!IsClassInitializeExecuted, "Caller shouldn't call us if it was executed.");
if (!IsClassInitializeExecuted)
{
// Acquiring a lock is usually a costly operation which does not need to be
// performed every time if the class initialization is already executed.
lock (_testClassExecuteSyncObject)
try
{
// Perform a check again.
if (!IsClassInitializeExecuted)
// We have discovered the methods from bottom (most derived) to top (less derived) but we want to execute
// from top to bottom.
for (int i = BaseClassInitMethods.Count - 1; i >= 0; i--)
{
try
initializeMethod = BaseClassInitMethods[i];
ClassInitializationException = InvokeInitializeMethod(initializeMethod, testContext);
if (ClassInitializationException is not null)
{
// We have discovered the methods from bottom (most derived) to top (less derived) but we want to execute
// from top to bottom.
for (int i = BaseClassInitMethods.Count - 1; i >= 0; i--)
{
initializeMethod = BaseClassInitMethods[i];
ClassInitializationException = InvokeInitializeMethod(initializeMethod, testContext);
if (ClassInitializationException is not null)
{
break;
}
}

if (ClassInitializationException is null)
{
initializeMethod = ClassInitializeMethod;
ClassInitializationException = InvokeInitializeMethod(ClassInitializeMethod, testContext);
}
}
catch (Exception ex)
{
ClassInitializationException = ex;
failedClassInitializeMethodName = initializeMethod?.Name ?? ClassInitializeMethod?.Name;
}
finally
{
IsClassInitializeExecuted = true;
break;
}
}

if (ClassInitializationException is null)
{
initializeMethod = ClassInitializeMethod;
ClassInitializationException = InvokeInitializeMethod(ClassInitializeMethod, testContext);
}
}
catch (Exception ex)
{
ClassInitializationException = ex;
failedClassInitializeMethodName = initializeMethod?.Name ?? ClassInitializeMethod?.Name;
}
finally
{
IsClassInitializeExecuted = true;
}
}

Expand Down Expand Up @@ -385,8 +378,6 @@ internal UnitTestResult GetResultOrRunClassInitialize(ITestContext testContext,
return clonedInitializeResult;
}

DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available.");

// For optimization purposes, return right away if there is nothing to execute.
// For STA, this avoids starting a thread when we know it will do nothing.
// But we still return early even not STA.
Expand All @@ -396,41 +387,62 @@ internal UnitTestResult GetResultOrRunClassInitialize(ITestContext testContext,
return _classInitializeResult = new(ObjectModelUnitTestOutcome.Passed, null);
}

bool isSTATestClass = AttributeComparer.IsDerived<STATestClassAttribute>(ClassAttribute);
bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows);
if (isSTATestClass
&& isWindowsOS
&& Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
// At this point, maybe class initialize was executed by another thread such
// that TryGetClonedCachedClassInitializeResult would return non-null.
// Now, we need to check again, but under a lock.
// Note that we are duplicating the logic above.
// We could keep the logic in lock only and not duplicate, but we don't want to pay
// the lock cost unnecessarily for a common case.
// We also need to lock to avoid concurrency issues and guarantee that class init is called only once.
lock (_testClassExecuteSyncObject)
{
UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete");
Thread entryPointThread = new(() => result = DoRun())
{
Name = "MSTest STATestClass ClassInitialize",
};

entryPointThread.SetApartmentState(ApartmentState.STA);
entryPointThread.Start();
clonedInitializeResult = TryGetClonedCachedClassInitializeResult();

try
// Optimization: If we already ran before and know the result, return it.
if (clonedInitializeResult is not null)
{
entryPointThread.Join();
return result;
DebugEx.Assert(IsClassInitializeExecuted, "Class initialize result should be available if and only if class initialize was executed");
return clonedInitializeResult;
}
catch (Exception ex)

DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available.");

bool isSTATestClass = AttributeComparer.IsDerived<STATestClassAttribute>(ClassAttribute);
bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows);
if (isSTATestClass
&& isWindowsOS
&& Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString());
return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation()));
UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete");
Thread entryPointThread = new(() => result = DoRun())
{
Name = "MSTest STATestClass ClassInitialize",
};

entryPointThread.SetApartmentState(ApartmentState.STA);
entryPointThread.Start();

try
{
entryPointThread.Join();
return result;
}
catch (Exception ex)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString());
return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation()));
}
}
}
else
{
// If the requested apartment state is STA and the OS is not Windows, then warn the user.
if (!isWindowsOS && isSTATestClass)
else
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning);
}
// If the requested apartment state is STA and the OS is not Windows, then warn the user.
if (!isWindowsOS && isSTATestClass)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning);
}

return DoRun();
return DoRun();
}
}

// Local functions
Expand Down
Loading