Radix sort

01.05.2009 @ 01:50:08 by Rafał Kozik | programowanie C++

Dawno, dawno temu na slashdot pojawił się wpis Sort Linked Lists 10X Faster Than MergeSort. Okazało się, że twórca algorytmu wykazał się brakiem wiedzy i zaproponował algorytm Radix Sort.

Dzisiaj sam zaimplementowałem ten algorytm (okazało się to prostsze niż się spodziewałem) chcąc sprawdzić jak wypada w porównaniu do qsort z cstdlib i sort z STL.

Zrobiłem testy losując sporą ilość intów, a następnie sortując je każdym z algorytmów. Na koniec sprawdzam czy wynik działania wszystkich algorytmów jest taki sam. Wyniki okazały się zaskakujące -- nie spodziewałem się aż tak dużych różnic czasu. Wyniki poniżej.



Obrazek

Pozioma oś to rozmiar danych testowych, a pionowa to czas wykonania w milisekundach. Jak widać Radix sort sprawdził się wyśmienicie. Kod kompilowany w VC++ 2005 EE z /O2.

Oczywiście nie jest to algorytm idealny. Po pierwsze robi duże założenia co do budowy danych. Po drugie nie działa w miejscu i wymaga drugie tyle pamięci co dane wejściowe. Drugi problem da się pewnie rozwiązać, ale domyślam się, że byłoby z tym trochę zabawy.

Program na którym zostały wykonane testy (tak, wiem, nie ustawiam seeda dla randa, więc zawsze działa tak samo ;)):

#include <iostream>
#include <algorithm>
#include <windows.h>
using namespace std;

// komparator dla qsort
int comparator(const void* a, const void* b)
{
	return *(int*)a - *(int*)b;
}

// radix napisany na szybko
void radix(int* data, const int count)
{
	int countsInternal[257];
	int* counts = &countsInternal[1];
	
	int* buff = new int[count];

	int* in = data;
	int* out = buff;

	for (int b=0; b < 4; b++)
	{
		const int shift = b * 8;

		memset(countsInternal, 0, sizeof(countsInternal));

		for (int i = count - 1; i >= 0; i--)
		{
			counts[(in[i] >> shift) & 255]++;
		}

		for (int i = 0; i < 256; i++) counts[i] += counts[i-1];

		for (int i = 0; i < count; i++)
		{
			int v = (in[i] >> shift) & 255;
			out[countsInternal[v]++] = in[i];
		}

		int* tmp = in;
		in = out;
		out = tmp;
	}

	delete[] buff;
}

class SimpleProfiler
{
	const char* text;
	int start;
public:
	SimpleProfiler(const char* text) : text(text)
	{
		start = timeGetTime();
	}

	~SimpleProfiler()
	{
		int diff = timeGetTime() - start;
		cout << text << diff << endl;
	}
};

bool compareResults(int** data, const int nTests, const int testSize)
{
	for (int i = 0; i < testSize; i++)
	{
		for (int t = 1; t < nTests; t++)
			if (data[t][i] != data[t-1][i]) return false;
	}

	return true;
}

bool doTest(const int testSize = 1000000)
{
	const int nTests = 3;

	cout << "Test size: " << testSize << endl;

	int** data = new int*[nTests];
	
	for (int i = 0; i < nTests; i++) data[i] = new int[testSize];

	// RAND_MAX w VC++ był 32768 
	for (int i = 0; i< testSize; i++)
		data[0][i] = (rand() << 16) | rand();

	for (int i = 1; i < nTests; i++)
		memcpy(data[i], data[0], testSize * sizeof(int));

	int test = 0;

	{
		SimpleProfiler profiler("qsort - stdlib:\t ");
		qsort(data[test], testSize, sizeof(int), comparator);
		test++;
	}

	{
		SimpleProfiler profiler("sort - stl:\t ");
		sort(data[test], data[test] + testSize);
		test++;
	}

	{
		SimpleProfiler profiler("radix sort:\t ");
		radix(data[test], testSize);
		test++;
	}

	bool result = compareResults(data, nTests, testSize);

	for (int i = 0; i < nTests; i++) delete[] data[i];
	delete[] data;

	return result;
}

int main()
{
	for (int n = 500000; n <= 10000000; n += 500000)
	{
		cout << (doTest(n) ? "ok" : "fail") << endl << endl;
	}
}

Komentowanie zostało tymczasowo wyłączone.