Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #612 from kperelygin/upstream
Browse files Browse the repository at this point in the history
Defined a new my_memory_system system class for memory.cu
  • Loading branch information
jaredhoberock committed Dec 16, 2014
2 parents bb903c0 + 86e24e8 commit b1098f0
Showing 1 changed file with 52 additions and 14 deletions.
66 changes: 52 additions & 14 deletions testing/memory.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,44 @@
#include <thrust/sequence.h>
#include <thrust/reverse.h>

// Define a new system class, as the my_system one is already used with a thrust::sort template definition
// that calls back into sort.cu
class my_memory_system : public thrust::device_execution_policy<my_memory_system>
{
public:
my_memory_system(int)
: correctly_dispatched(false),
num_copies(0)
{}

my_memory_system(const my_memory_system &other)
: correctly_dispatched(false),
num_copies(other.num_copies + 1)
{}

void validate_dispatch()
{
correctly_dispatched = (num_copies == 0);
}

bool is_valid()
{
return correctly_dispatched;
}

private:
bool correctly_dispatched;

// count the number of copies so that we can validate
// that dispatch does not introduce any
unsigned int num_copies;


// disallow default construction
my_memory_system();
};


template<typename T1, typename T2>
bool are_same(const T1 &, const T2 &)
{
Expand All @@ -27,7 +65,7 @@ void TestSelectSystemDifferentTypes()
{
using thrust::system::detail::generic::select_system;

my_system my_sys(0);
my_memory_system my_sys(0);
thrust::device_system_tag device_sys;

// select_system(my_system, device_system_tag) should return device_system_tag (the minimum tag)
Expand All @@ -45,7 +83,7 @@ void TestSelectSystemSameTypes()
{
using thrust::system::detail::generic::select_system;

my_system my_sys(0);
my_memory_system my_sys(0);
thrust::device_system_tag device_sys;
thrust::host_system_tag host_sys;

Expand Down Expand Up @@ -106,20 +144,20 @@ void TestMalloc()
DECLARE_UNITTEST(TestMalloc);


thrust::pointer<void,my_system>
malloc(my_system &system, std::size_t)
thrust::pointer<void,my_memory_system>
malloc(my_memory_system &system, std::size_t)
{
system.validate_dispatch();

return thrust::pointer<void,my_system>();
return thrust::pointer<void,my_memory_system>();
}


void TestMallocDispatchExplicit()
{
const size_t n = 0;

my_system sys(0);
my_memory_system sys(0);
thrust::malloc(sys, n);

ASSERT_EQUAL(true, sys.is_valid());
Expand All @@ -128,17 +166,17 @@ DECLARE_UNITTEST(TestMallocDispatchExplicit);


template<typename Pointer>
void free(my_system &system, Pointer)
void free(my_memory_system &system, Pointer)
{
system.validate_dispatch();
}


void TestFreeDispatchExplicit()
{
thrust::pointer<my_system,void> ptr;
thrust::pointer<my_memory_system,void> ptr;

my_system sys(0);
my_memory_system sys(0);
thrust::free(sys, ptr);

ASSERT_EQUAL(true, sys.is_valid());
Expand All @@ -147,14 +185,14 @@ DECLARE_UNITTEST(TestFreeDispatchExplicit);


template<typename T>
thrust::pair<thrust::pointer<T,my_system>, std::ptrdiff_t>
get_temporary_buffer(my_system &system, std::ptrdiff_t n)
thrust::pair<thrust::pointer<T,my_memory_system>, std::ptrdiff_t>
get_temporary_buffer(my_memory_system &system, std::ptrdiff_t n)
{
system.validate_dispatch();

thrust::device_system_tag device_sys;
thrust::pair<thrust::pointer<T, thrust::device_system_tag>, std::ptrdiff_t> result = thrust::get_temporary_buffer<T>(device_sys, n);
return thrust::make_pair(thrust::pointer<T,my_system>(result.first.get()), result.second);
return thrust::make_pair(thrust::pointer<T,my_memory_system>(result.first.get()), result.second);
}


Expand All @@ -166,7 +204,7 @@ void TestGetTemporaryBufferDispatchExplicit()
#else
const size_t n = 9001;

my_system sys(0);
my_memory_system sys(0);
typedef thrust::pointer<int, thrust::device_system_tag> pointer;
thrust::pair<pointer, std::ptrdiff_t> ptr_and_sz = thrust::get_temporary_buffer<int>(sys, n);

Expand Down Expand Up @@ -205,7 +243,7 @@ void TestGetTemporaryBufferDispatchImplicit()
thrust::reverse(vec.begin(), vec.end());

// call something we know will invoke get_temporary_buffer
my_system sys(0);
my_memory_system sys(0);
thrust::sort(sys, vec.begin(), vec.end());

ASSERT_EQUAL(true, thrust::is_sorted(vec.begin(), vec.end()));
Expand Down

0 comments on commit b1098f0

Please sign in to comment.