/** $Id$
 * Author: Markus Cozowicz (eisber)
 * Date: 26/04/2005
 *
 * Purpose: Educational implementation of Strassen/Winograd's fast matrix multiplication
 *
 * Todo: Optimize temporary object creation (maybe const)
 * 	 Document the less obvious parts
 */

#include <iostream>
#include <iomanip>
#include <time.h>
#include <conio.h>

#include "Matrix.h"

using namespace std;

// Descriptor of a square matrix, or of a window ("view") onto another
// matrix's element buffer. Neither the side length nor the row stride of the
// buffer is stored here; the routines in this file receive both as template
// arguments.
struct Matrix
{
	// Row / column of the top-left element inside m_data. allocate<len>()
	// sets both to 0; the view overloads of allocate() store the offset they
	// are given.
	int m_row;
	int m_col;

	//int m_viewlen;
	//int m_len;
	// Element buffer. For matrices created by allocate<len>() the data is
	// just behind this structure; views point at the buffer they were given.
	int* m_data;
};

// Memory Mgmt
// All tables below are created by init() and are indexed directly by the
// matrix side length.
void** memory; // temporary results: next free block per size (advanced by allocate, reset by deallocate)
void** memoryBounds; // end address of each block pool, computed in init()
int* currentBounds; // zero-filled by init(); otherwise only referenced by commented-out debug accounting
int* maxBounds; // zero-filled by init(); otherwise only referenced by commented-out debug accounting

Matrix** memoryViews; // next free view descriptor per size (advanced by allocate, reset by deallocate_view)
Matrix** memoryViewBounds; // end of each view-descriptor pool, computed in init()
int* currentViewBounds; // zero-filled by init(); otherwise only referenced by commented-out debug accounting
int* maxViewBounds; // zero-filled by init(); otherwise only referenced by commented-out debug accounting

// Creates the lookup tables and the per-size pools used by the allocate()
// and deallocate() families.
//
// maxSize: largest matrix side length that will be used (expected 2^n).
//          A pool is created for every size maxSize, maxSize/2, ..., 2.
//
// Terminates the program with exit(1) if any allocation fails. Prints the
// total pool size to cout.
void init(size_t maxSize)
{
	// waste mem to enable fast lookup: the tables are indexed by the matrix
	// size itself, and maxSize is a valid index, hence maxSize + 1 slots.
	// calloc leaves the slots of unused sizes zeroed instead of uninitialised.
	const size_t slots = maxSize + 1;

	memory           = (void**)calloc(slots, sizeof(void*));
	memoryBounds     = (void**)calloc(slots, sizeof(void*));
	memoryViews      = (Matrix**)calloc(slots, sizeof(Matrix*));
	memoryViewBounds = (Matrix**)calloc(slots, sizeof(Matrix*));

	maxBounds         = (int*)calloc(slots, sizeof(int));
	currentBounds     = (int*)calloc(slots, sizeof(int));
	maxViewBounds     = (int*)calloc(slots, sizeof(int));
	currentViewBounds = (int*)calloc(slots, sizeof(int));

	if(!memory || !memoryBounds || !memoryViews || !memoryViewBounds ||
	   !maxBounds || !currentBounds || !maxViewBounds || !currentViewBounds)
	{
		std::cerr << "Could not allocate the pool lookup tables" << std::endl;
		exit(1);
	}

	// Heuristic pool sizing: i is the matrix size of the level, j the number
	// of blocks reserved for it. j starts at 1 for the largest size and is
	// multiplied by fac when moving to the next smaller size.
	size_t overall = 0;
	int fac = 8;
	for(size_t i=maxSize,j=1;i>1;i>>=1,j*=fac)
	{
		// The second level gets 18 blocks instead of 8.
		// NOTE(review): presumably sized for the half-size temporaries one
		// mult() call keeps alive - confirm.
		if(j == 8)
			j = 18;

		// data pool: j blocks of (header + i*i ints)
		size_t m = (i*i*sizeof(int)+sizeof(Matrix))*j;
		overall += m;

		memory[i] = malloc(m);
		if(!memory[i])
		{
			std::cerr << "Could not allocate " << m << " bytes for matrices of size " << i << std::endl;
			exit(1);
		}
		memoryBounds[i] = ((char*)memory[i]) + m;

		// view pool: j descriptors
		m = sizeof(Matrix)*j;
		overall += m;
		memoryViews[i] = (Matrix*)malloc(m);
		if(!memoryViews[i])
		{
			std::cerr << "Could not allocate " << m << " bytes for views of size " << i << std::endl;
			exit(1);
		}
		memoryViewBounds[i] = memoryViews[i] + j;
#ifdef _DEBUG
		printf("ALLOC_VIEW[%6zu]: %15zu base: %p bounds: %p\n", i, j, (void*)memoryViews[i], (void*)memoryViewBounds[i]);
#endif
		// Shrink the growth factor by one per level down to 1. The original
		// "fac -= 0.3" on an int truncated to exactly this sequence
		// (8, 7, 6, ..., 1), so the pool sizes are unchanged.
		if(j >= 18 && fac >= 2)
			--fac;
	}
	cout << "allocated " << (overall/1024/1024) << "mb" << endl;
}

// Takes the next block from the pool for len x len matrices and returns it
// as a zero-filled matrix with offset (0, 0).
//
// The block consists of the Matrix header immediately followed by len*len
// ints. Release it with deallocate<len>().
//
// In _DEBUG builds the program stops with a message when the block does not
// fit into the pool any more.
template<int len>
Matrix* allocate()
{
	// size_t arithmetic so that len*len cannot overflow int
	const size_t dataBytes  = (size_t)len * len * sizeof(int);
	const size_t blockBytes = sizeof(Matrix) + dataBytes;

#ifdef _DEBUG
	// The whole block must end at or before the pool bound; testing only the
	// start pointer would let the last allocation run past the end.
	if(((char*)memory[len]) + blockBytes > (char*)memoryBounds[len])
	{
		printf("not enough memory in %d memory: base: %p bounds: %p\n", len, memory[len], memoryBounds[len]);
		getch();
		exit(1);
	}
//	currentBounds[len]++;
//	if(maxBounds[len] < currentBounds[len])
//		maxBounds[len]++;
#endif

	Matrix* m = (Matrix*)memory[len];
	memory[len] = ((char*)m) + blockBytes;

	m->m_row = m->m_col = 0;

	// data is just behind this structure
	m->m_data = (int*)(m + 1);
	// zero everything
	memset(m->m_data, 0, dataBytes);

	//printf("allocate %d: %p %p\n", len, m, m->m_data);
	return m;
}

// Specialisation for 2x2 matrices: takes the next block from the size-2 pool
// and returns it as a zero-filled matrix with offset (0, 0).
//
// In _DEBUG builds the program stops with a message when the block does not
// fit into the pool any more (same check as the generic version).
template<>
Matrix* allocate<2>()
{
	const size_t dataBytes  = 2*2*sizeof(int);
	const size_t blockBytes = sizeof(Matrix) + dataBytes;

#ifdef _DEBUG
	if(((char*)memory[2]) + blockBytes > (char*)memoryBounds[2])
	{
		printf("not enough memory in %d memory: base: %p bounds: %p\n", 2, memory[2], memoryBounds[2]);
		getch();
		exit(1);
	}
#endif

	Matrix* m = (Matrix*)memory[2];
	memory[2] = ((char*)m) + blockBytes;

	m->m_row = m->m_col = 0;
	// data is just behind this structure
	m->m_data = (int*)(m + 1);

	// Clear the four ints with memset. Writing them through a long long*
	// breaks the aliasing rules and assumes sizeof(long long) == 2*sizeof(int);
	// compilers turn a fixed-size memset into the same wide stores anyway.
	memset(m->m_data, 0, dataBytes);

	return m;
}


// Declared ahead of its definition so that the call below also resolves on
// compilers with standard two-phase name lookup: the arguments are built-in
// types, so argument-dependent lookup cannot find an overload that is only
// declared later in the file.
template<int viewlen>
Matrix* allocate(int* m_data, int row, int col);

// Creates a viewlen x viewlen view onto the buffer of another matrix or view.
//
// source:   matrix or view whose element buffer is shared
// row, col: offset of the new view relative to source's own offset
//
// Returns a view descriptor from the size-viewlen view pool; release it with
// deallocate_view<viewlen>().
template<int viewlen>
Matrix* allocate(Matrix* source, int row, int col)
{
	return allocate<viewlen>(source->m_data, row + source->m_row, col + source->m_col);
}
// Creates a viewlen x viewlen view onto an existing element buffer.
//
// m_data:   buffer the view refers to (not copied, not owned)
// row, col: offset of the view's top-left element inside that buffer
//
// Returns the next descriptor of the size-viewlen view pool; release it with
// deallocate_view<viewlen>(). In _DEBUG builds the program stops with a
// message when the pool has no free descriptor left.
template<int viewlen>
Matrix* allocate(int* m_data, int row, int col)
{
#ifdef _DEBUG
	// The bound is one past the last descriptor, so reaching it already
	// means the pool is exhausted (">" would allow one write past the end).
	if(memoryViews[viewlen] >= memoryViewBounds[viewlen])
	{
		printf("not enough memory in %d memoryViews: base: %p bounds: %p\n", viewlen, (void*)memoryViews[viewlen], (void*)memoryViewBounds[viewlen]);
		getch();
		exit(1);
	}
//	currentViewBounds[viewlen]++;
//	if(maxViewBounds[viewlen] < currentViewBounds[viewlen])
//		maxViewBounds[viewlen]++;
//
#endif

	Matrix* m = memoryViews[viewlen];
	memoryViews[viewlen]++;

	m->m_row     = row;
	m->m_col     = col;
	//m->m_viewlen = viewlen;
	//m->m_len     = len;

	m->m_data = m_data;

	//printf("allocate_view %d: %p\n", viewlen, m);

	return m;
}

// Resets the pool for len x len matrices to topMost: that block and every
// block handed out after it become available again.
//
// topMost: address of the earliest block that should be released
template<int len>
void deallocate(void* topMost)
{
	void*& nextFree = memory[len];
	nextFree = topMost;
}

// Resets the view pool for size len to topMost: that descriptor and every
// descriptor handed out after it become available again.
//
// topMost: earliest view descriptor that should be released
template<int len>
void deallocate_view(Matrix* topMost)
{
	Matrix*& nextFree = memoryViews[len];
	nextFree = topMost;
}


//Matrix::Matrix(int* init_data, int len) : m_len(len), m_viewlen(len), m_row(0), m_col(0)
//{
//	//printf("assign  %p %d\n", init_data, len);
//	m_data = init_data;
//	ref_count = new int;
//	*ref_count = 1; // don't deallocate
//	m_doDelete = true;
//}
//
//Matrix::Matrix(const Matrix& d)
//{
//	m_data = d.m_data;
//	m_doDelete = d.m_doDelete;
//	if(m_doDelete)
//	{
//		ref_count = d.ref_count;
//		(*ref_count)++;
//	}
//
//	m_len =     d.m_len;
//	m_viewlen = d.m_viewlen;
//	m_row =     d.m_row;
//	m_col =     d.m_col;
//	
//	//printf("copy    %p %p %d\n", m_data, this, *ref_count); 
//}
//
//Matrix::Matrix(int len) : m_len(len), m_viewlen(len), m_row(0), m_col(0)
//{
//	m_data = new int[len*len];
//	for(int i=0;i<len*len;i++)
//		m_data[i] = 0;
//	if(m_data == NULL)
//	{
//		cerr << "Could not allocate memory of size " << (sizeof(int)*len*len) << endl;
//		exit(1);
//	}
//	//printf("created %p %p of size %d\n", m_data, this, len*len*sizeof(int)); 
//	ref_count  = new int;
//	*ref_count = 1;
//	m_doDelete = true;
//}
//
//Matrix::Matrix(const Matrix* source, int row, int col, int viewlen)
//{
//	m_data = source->m_data;
//	m_doDelete = false;
//
//	m_len = source->m_len;	
//	
//	m_viewlen = viewlen;
//	m_row     = row + source->m_row;
//	m_col     = col + source->m_col;
//
//	//printf("view    %p %p %d\n", m_data, this, *ref_count); 
//}
//
//Matrix::~Matrix()
//{
//	//printf("dealloc %p %p %d\n", m_data, this, *ref_count); 
//	if(m_doDelete)
//	{
//		if(*ref_count == 1)
//		{
//			//printf("freeing %p\n", m_data);
//
//			delete[] m_data;
//			delete ref_count;
//
//			
//			//m_data = NULL;
//			//printf("freeing.2 %p\n", m_data);
//
//			//ref_count = NULL;
//			//printf("freeing.3 %p\n", m_data);
//		}
//		else
//			(*ref_count)--;
//	}
//	//printf("freeing.4 %p\n", m_data);
//}

// Element-wise sum of two viewlen x viewlen operands.
//
// len:     row stride shared by the buffers behind a and b
// viewlen: side length of the operands and of the result
//
// Returns a new matrix taken from allocate<viewlen>() containing a + b.
template<int len, int viewlen>
Matrix* add(Matrix* a, Matrix* b)
{
	Matrix* sum = allocate<viewlen>();

	// index of the top-left element of each operand inside its buffer
	const int lhsStart = a->m_row * len + a->m_col;
	const int rhsStart = b->m_row * len + b->m_col;

	for(int r = 0; r < viewlen; ++r)
	{
		const int* lhs = a->m_data + lhsStart + r * len;
		const int* rhs = b->m_data + rhsStart + r * len;
		int*       out = sum->m_data + r * viewlen; // result is densely packed

		for(int k = 0; k < viewlen; ++k)
			out[k] = lhs[k] + rhs[k];
	}
	return sum;
}

// Element-wise difference of two viewlen x viewlen operands.
//
// len:     row stride shared by the buffers behind a and b
// viewlen: side length of the operands and of the result
//
// Returns a new matrix taken from allocate<viewlen>() containing a - b.
template<int len, int viewlen>
Matrix* sub(Matrix* a, Matrix* b)
{
	Matrix* diff = allocate<viewlen>();

	// index of the top-left element of each operand inside its buffer
	const int lhsStart = a->m_row * len + a->m_col;
	const int rhsStart = b->m_row * len + b->m_col;

	for(int r = 0; r < viewlen; ++r)
	{
		const int* lhs = a->m_data + lhsStart + r * len;
		const int* rhs = b->m_data + rhsStart + r * len;
		int*       out = diff->m_data + r * viewlen; // result is densely packed

		for(int k = 0; k < viewlen; ++k)
			out[k] = lhs[k] - rhs[k];
	}
	return diff;
}

// Adds the viewlen x viewlen block of b onto the viewlen x viewlen block of a.
//
// a_len, b_len: row stride of the buffers behind a and b
// viewlen:      side length of the block that is processed
// a:            destination; may be a plain matrix (offset 0,0) or a view
// b:            source; may be a plain matrix or a view
//
// The destination offset (m_row, m_col) is honoured. Starting at index 0
// regardless of the view would make every quadrant view of a larger matrix
// write into that matrix's top-left quadrant.
template<int a_len, int b_len, int viewlen>
void addInPlace(Matrix* a, Matrix* b)
{
	int ra = a->m_row * a_len + a->m_col;
	int rb = b->m_row * b_len + b->m_col;

	for(int row=0;row<viewlen;row++)
	{
		for(int col=0;col<viewlen;col++)
		{
			//cout << " (" << a->m_data[ra + col] << " + " << b->m_data[rb + col] << ")";
			a->m_data[ra + col] += b->m_data[rb + col];
		}
		//cout << endl;
		ra += a_len;
		rb += b_len;
	}
}

// Subtracts the viewlen x viewlen block of b from the viewlen x viewlen
// block of a.
//
// a_len, b_len: row stride of the buffers behind a and b
// viewlen:      side length of the block that is processed
// a:            destination; may be a plain matrix (offset 0,0) or a view
// b:            source; may be a plain matrix or a view
//
// The destination offset (m_row, m_col) is honoured. Starting at index 0
// regardless of the view would make every quadrant view of a larger matrix
// write into that matrix's top-left quadrant.
template<int a_len, int b_len, int viewlen>
void subInPlace(Matrix* a, Matrix* b)
{
	int ra = a->m_row * a_len + a->m_col;
	int rb = b->m_row * b_len + b->m_col;

	for(int row=0;row<viewlen;row++)
	{
		for(int col=0;col<viewlen;col++)
			a->m_data[ra + col] -= b->m_data[rb + col];
		ra += a_len;
		rb += b_len;
	}
}

// Strassen multiplication of two viewlen x viewlen operands.
//
// Template parameters:
//   len_a, len_b - row stride of the buffers behind a and b
//   viewlen      - side length of the operands and of the result
// Parameters:
//   a, b - operands (plain matrices or views); their m_row/m_col offsets are
//          applied when reading
//
// Returns the matrix obtained from allocate<viewlen>() that the result is
// written to. It is not released here; the caller has to deallocate it.
template<int len_a, int len_b, int viewlen>
Matrix* mult(Matrix* a, Matrix* b)
{
//#ifdef _DEBUG
//	if(m_viewlen != b.m_viewlen)
//	{
//		cout << "Trying to multiply " << m_viewlen << " by " << b.m_viewlen << endl;
//		getch();
//		exit(1);
//	}
//#endif
	// viewlen is a template argument, so this test is constant per
	// instantiation.
	if(viewlen == 2)
	{
		// conquer
		// strassen fast matrix mult on scalars: 7 multiplications
		Matrix* c = allocate<2>();

		// positions of the four elements of a inside its buffer
		int indices_00 = a->m_row*len_a + a->m_col;
		int indices_01 = indices_00 + 1;
		int indices_10 = indices_00 + len_a;
		int indices_11 = indices_10 + 1;


		// positions of the four elements of b inside its buffer
		int b_indices_00 = b->m_row*len_b + b->m_col;
		int b_indices_01 = b_indices_00 + 1;
		int b_indices_10 = b_indices_00 + len_b;
		int b_indices_11 = b_indices_10 + 1;

		int* a_data = a->m_data;
		int* b_data = b->m_data;

		//printf("m0 = (%d + %d) * (%d + %d)\n", a_data[indices_00], a_data[indices_11], b_data[b_indices_00], b_data[b_indices_11]);
		int m0 = (a_data[indices_00] +  a_data[indices_11])  * (b_data[b_indices_00] + b_data[b_indices_11]);
		int m1 = (a_data[indices_01] -  a_data[indices_11])  * (b_data[b_indices_10] + b_data[b_indices_11]);
		int m2 = (a_data[indices_00] -  a_data[indices_10])  * (b_data[b_indices_00] + b_data[b_indices_01]);
		int m3 = (a_data[indices_00] +  a_data[indices_01])  *  b_data[b_indices_11];
		int m4 =  a_data[indices_00] * (b_data[b_indices_01] -  b_data[b_indices_11]);
		int m5 =  a_data[indices_11] * (b_data[b_indices_10] -  b_data[b_indices_00]);
		int m6 = (a_data[indices_10] +  a_data[indices_11])  *  b_data[b_indices_00];

		// c is a freshly allocated 2x2 matrix, so its data is packed row-major
		int* c_data = c->m_data;

		//printf("\n %d + %d - %d + %d\n", m0, m1, m3, m5);
		c_data[0] = m0 + m1 - m3 + m5;

		c_data[1] = m3 + m4;
		c_data[2] = m5 + m6;
		c_data[3] = m0 - m2 + m4 - m6;

		//cout << "a" << endl;
		//print<len_a, 2>(a);
		//cout << "b" << endl;
		//print<len_b, 2>(b);
		//cout << "a*b" << endl;
		//print<2, 2>(c);
		//cout << "?" << endl << endl;

		return c;	
	}
	else
	{
		// divide
		// c is the (zeroed) result; x11..x22 are views onto the four
		// quadrants of a, b and c. a11 is the first view taken from the
		// half-size view pool, which is what the cleanup below relies on.
		Matrix* c = allocate<viewlen>();
		Matrix* a11 = allocate<viewlen/2>(a, 0, 0);
		Matrix* b11 = allocate<viewlen/2>(b, 0, 0);
		Matrix* c11 = allocate<viewlen/2>(c, 0, 0);

		Matrix* a12 = allocate<viewlen/2>(a, 0, viewlen/2);
		Matrix* b12 = allocate<viewlen/2>(b, 0, viewlen/2);
		Matrix* c12 = allocate<viewlen/2>(c, 0, viewlen/2);

		Matrix* a21 = allocate<viewlen/2>(a, viewlen/2, 0);
		Matrix* b21 = allocate<viewlen/2>(b, viewlen/2, 0);
		Matrix* c21 = allocate<viewlen/2>(c, viewlen/2, 0);

		Matrix* a22 = allocate<viewlen/2>(a, viewlen/2, viewlen/2);
		Matrix* b22 = allocate<viewlen/2>(b, viewlen/2, viewlen/2);
		Matrix* c22 = allocate<viewlen/2>(c, viewlen/2, viewlen/2);

		//http://www.brpreiss.com/books/opus5/html/page457.html

		// The seven half-size products. Sums and differences are fresh
		// matrices with stride viewlen/2; quadrant views that are passed on
		// directly keep the stride of their parent (len_a or len_b).

		// (a11 + a22) * (b11 + b22)
		// m0a is the first half-size matrix allocated in this call; it marks
		// the point the half-size pool is reset to at the end.
		Matrix* m0a = add<len_a, viewlen/2>(a11, a22);
		Matrix* m0  = mult<viewlen/2, viewlen/2, viewlen/2>(m0a, add<len_b, viewlen/2>(b11, b22));

		// (a12 - a22) * (b21 + b22)
		Matrix* m1 = mult<viewlen/2, viewlen/2, viewlen/2>(sub<len_a, viewlen/2>(a12, a22), add<len_b, viewlen/2>(b21, b22));
		// (a11 - a21) * (b11 + b12)
		Matrix* m2 = mult<viewlen/2, viewlen/2, viewlen/2>(sub<len_a, viewlen/2>(a11, a21), add<len_b, viewlen/2>(b11, b12));

		// (a11 + a12) * b22
		Matrix* m3 = mult<viewlen/2, len_b, viewlen/2>(add<len_a, viewlen/2>(a11, a12), b22);
		// (a21 + a22) * b11
		Matrix* m6 = mult<viewlen/2, len_b, viewlen/2>(add<len_a, viewlen/2>(a21, a22), b11);

		// a11 * (b12 - b22)
		Matrix* m4 = mult<len_a, viewlen/2, viewlen/2>(a11, sub<len_b, viewlen/2>(b12, b22));
		// a22 * (b21 - b11)
		Matrix* m5 = mult<len_a, viewlen/2, viewlen/2>(a22, sub<len_b, viewlen/2>(b21, b11));

		//cout << "mx" << endl;
		//print<viewlen/2, viewlen/2>(m5);
		//cout << "a/b/c" << endl;
		//print<len_a, viewlen/2>(a22);

		// Combine the products through the quadrant views of c (destination
		// stride viewlen, source stride viewlen/2):
		// c11 = m0 + m1 - m3 + m5
		addInPlace<viewlen, viewlen/2, viewlen/2>(c11, m0);
		addInPlace<viewlen, viewlen/2, viewlen/2>(c11, m1);
		subInPlace<viewlen, viewlen/2, viewlen/2>(c11, m3);
		addInPlace<viewlen, viewlen/2, viewlen/2>(c11, m5);

		// c12 = m3 + m4
		addInPlace<viewlen, viewlen/2, viewlen/2>(c12, m3);
		addInPlace<viewlen, viewlen/2, viewlen/2>(c12, m4);

		// c21 = m5 + m6
		addInPlace<viewlen, viewlen/2, viewlen/2>(c21, m5);
		addInPlace<viewlen, viewlen/2, viewlen/2>(c21, m6);

		// c22 = m0 - m2 + m4 - m6
		addInPlace<viewlen, viewlen/2, viewlen/2>(c22, m0);
		subInPlace<viewlen, viewlen/2, viewlen/2>(c22, m2);
		addInPlace<viewlen, viewlen/2, viewlen/2>(c22, m4);
		subInPlace<viewlen, viewlen/2, viewlen/2>(c22, m6);

		//cout << "c" << endl;
		//print<viewlen, viewlen>(c);

		// Reset both half-size pools to their first allocation of this call:
		// this releases m0a and every half-size matrix created after it, and
		// a11 and every half-size view created after it.
		deallocate<viewlen/2>(m0a);
		deallocate_view<viewlen/2>(a11);

		return c;
	}
}

// Writes a viewlen x viewlen matrix or view to cout, one row per line, each
// value right-aligned in a field of width 7 and followed by a blank.
//
// len:     row stride of the buffer behind a
// viewlen: side length of the block that is printed
template<int len, int viewlen>
void print(Matrix* a)
{
	for(int r = 0; r < viewlen; ++r)
	{
		// index of the first element of this row inside the buffer
		const int rowStart = (a->m_row + r) * len + a->m_col;

		for(int k = 0; k < viewlen; ++k)
		{
			cout << setw(7) << a->m_data[rowStart + k];
			cout << " ";
		}
		cout << endl;
	}
}

#define SAMPLE_LEN 4

template<int n> 
void test()
{
	test<n/2>();

	int *a = new int[n*n];
	int *b = new int[n*n];
	int *c = new int[n*n]; 

	//cout << "Size " << n;

	for(int i=0;i<n;i++)
	{
		int row = i*n;
		for(int j=0;j<n;j++)
		{
			a[row+j] = rand() % 100;
			b[row+j] = rand() % 100;
			c[row+j] = 0;
		}
	}
	//cout << "." << endl;
	
	// CLASSIC
	clock_t begin = clock();
	for(int row=0;row<n;row++)
	{
		int r = row*n;
		for(int col=0;col<n;col++)
			for(int i=0;i<n;i++)
				c[r + col] += a[r + i] * b[i*n + col];
	}
	//cout << "Cl " << (clock() - begin) << endl;
	clock_t classic = (clock() - begin);

	//// STRASSEN
	Matrix* ma = allocate<n>(a,0,0);
	Matrix* mb = allocate<n>(b,0,0);
	
	begin = clock();
					
	//cout << "a" << endl;
	//print<n, n>(ma);
	//cout << "b" << endl;
	//print<n, n>(mb);

	Matrix* mc = mult<n,n,n>(ma, mb);
	clock_t strassen = (clock() - begin);

	cout << setw(6) << n << " " << setw(8) << classic << " " << setw(8) << strassen << endl;
//#ifdef _DEBUG
//	for(int row=0;row<n;row++)
//		for(int col=0;col<n;col++)
//				if(c[row*n + col] != mc->m_data[row*n + col])
//				{
//					cout << "Strassen" << endl;
//					cout << "a" << endl;
//					print<n, n>(ma);
//					cout << "b" << endl;
//					print<n, n>(mb);
//					cout << "a*b" << endl;
//					print<n, n>(mc);
//
//					cout << "Classic" << endl;
//					for(int i=0;i<n;i++)
//					{
//						for(int j=0;j<n;j++)
//							cout << setw(7) << c[i*n+j] << " ";
//						cout << endl;
//					}
//					cerr << "differed at " << row << ", " << col << endl;
//					getch();
//					exit(1);
//				}
//
//	for(int i=2;i<=n;i<<=1)
//		printf("\tmax %d mem: %d view: %d\n", i, maxBounds[i], maxViewBounds[i]);
//
//	cout << endl;
//#endif	
	delete[] c;
}

// Explicit specialisation that does nothing: it ends the test<n> ->
// test<n/2> recursion, so no benchmark is run for size 2.
template<>
void test<2>()
{ } 

void validate()
{
	Matrix* a = allocate<2>();
	Matrix* b = allocate<2>();
	int i = 0;
	for(int row=0;row<2;row++)
		for(int col=0;col<2;col++)
			a->m_data[row*2 + col] = i++;

	cout << "a" << endl;
	print<2, 2>(a);
	cout << "b" << endl;
	print<2, 2>(b);

	addInPlace<2, 2, 2>(b, a);

	i = 0;
	for(int row=0;row<2;row++)
		for(int col=0;col<2;col++)
			if(b->m_data[row*2 + col] != i++)
			{
				cerr << "add failed" << endl;
				getch();
				exit(1);
			}

	cout << "a" << endl;
	print<2, 2>(a);
	cout << "b" << endl;
	print<2, 2>(b);

	cout << "a+b" << endl;
	print<2, 2>(sub<2,2>(a,b));
}

// Entry point: seeds rand(), creates the pools and runs the benchmark for
// all sizes up to the largest one, printing a header line first.
// The command line arguments are not used.
int main(int argc, char* argv[])
{
	// Largest matrix side length; sizes the pools and starts the test chain.
	const int largest = 2048;

	// The timing columns are clock() ticks; see CLOCKS_PER_SEC.
	// cout << CLOCKS_PER_SEC << endl; // 1000 on my box

	srand(123); // fixed seed, same input data on every run
	init(largest);

	//validate();

	cout << setw(6) << "size" << " ";
	cout << setw(8) << "classic" << " ";
	cout << setw(8) << "strassen" << endl;

	test<largest>();
	// cout << "done" << endl; getch();

	return 0;
}

