#include "MemChecker.h"
#include "RStd.h"

#include <vector>

extern "C" {
// For SIZE_MAX.
// FIXME: Should use numeric_limits when supported.
#ifdef SunOS
#include <climits>
#else
#define __STDC_LIMIT_MACROS
#include <stdint.h>
#undef __STDC_LIMIT_MACROS
#endif
}
/*###########################################################################*/
namespace {

// Size of the guard pad placed in front of every block (and the minimum
// trailing pad). Presumably sizeof (double) was chosen so that the pointer
// handed back to callers, ALIGN_PAD bytes into a malloc'ed block, stays
// double-aligned.
const size_t ALIGN_PAD = sizeof (double);
// Poison pattern written over every word of a block when it is allocated
// (and, on the checked path, again just before it is released), so pad
// damage and stale reads are recognisable.
const size_t UGLY_CONST = 0x1badfeed;

// Pads are filled word-by-word and the user pointer sits
// ALIGN_PAD / sizeof (size_t) words into the block, so ALIGN_PAD must be a
// non-zero multiple of sizeof (size_t). C++03-compatible compile-time check:
// the array size turns negative (a compile error) when the condition fails.
typedef char align_pad_must_be_size_t_multiple
	[(ALIGN_PAD >= sizeof (size_t) && ALIGN_PAD % sizeof (size_t) == 0)
		? 1 : -1];

/*===========================================================================*/
	// Number of bytes really malloc'ed for a request of `bytes`: ALIGN_PAD in
	// front, `bytes` rounded up to a multiple of ALIGN_PAD, and ALIGN_PAD
	// behind. The result is a multiple of ALIGN_PAD and at least
	// 2 * ALIGN_PAD. The caller must ensure bytes + 3 * ALIGN_PAD - 1 does not
	// wrap around (wall_alloc checks this).
	size_t actual_bytes(size_t bytes) {
		return (bytes + ALIGN_PAD * 3 - 1) / ALIGN_PAD * ALIGN_PAD;
	}

/*===========================================================================*/
	// Allocates a padded block with malloc, fills every word of it with
	// UGLY_CONST, and returns a pointer ALIGN_PAD bytes past its start.
	// Returns 0 if the request is too large to pad or malloc fails.
	void * wall_alloc(size_t bytes) {
		// Without this guard actual_bytes() wraps around and a tiny block is
		// returned for a huge request, i.e. a guaranteed heap overflow.
		if (bytes > SIZE_MAX - (ALIGN_PAD * 3 - 1)) return 0;

		size_t total_bytes = actual_bytes(bytes);

		size_t * mem = static_cast<size_t *>(malloc(total_bytes));

		if (! mem) return 0;

		// Exact: total_bytes is a multiple of ALIGN_PAD, which is a multiple
		// of sizeof (size_t) (checked at compile time above).
		size_t size_t_size = total_bytes / sizeof (size_t);

		for (size_t x = 0; x < size_t_size; ++x) {
			mem[x] = UGLY_CONST;
		}

		return mem + ALIGN_PAD / sizeof (size_t);
	}

/*===========================================================================*/
	// Releases a block returned by wall_alloc (steps back over the front pad
	// to the pointer malloc produced). A null pointer is ignored.
	void wall_free(void * data) {
		if (! data) return;
		free(static_cast<size_t *>(data) - ALIGN_PAD / sizeof (size_t));
	}
}

// Registry of live checkers. Each constructor pushes `this` and the
// destructor pops it; mc() returns the top entry, which is the checker the
// ROB_DEBUG operator new/delete replacements route allocations through.
stack<MemChecker *> MemChecker::mc_stack_;

/*===========================================================================*/
MemChecker * MemChecker::mc() {
	// The most recently pushed checker is the active one; null when no
	// checker is registered.
	if (mc_stack_.empty()) {
		return 0;
	}

	return mc_stack_.top();
}

/*===========================================================================*/
// Creates a checker named `name` that expects no outstanding allocations
// (0 allocs, 0 bytes) whenever check() runs, and makes it the active checker
// by pushing it onto mc_stack_.
// NOTE(review): if name_ stores the raw pointer, `name` must outlive this
// checker -- confirm against the member's type in the header.
MemChecker::MemChecker(const char * name) :
	mem_map_(),
	min_alloc_(0),
	max_alloc_(0),
	min_byte_(0),
	max_byte_(0),
	name_(name)
{
	mc_stack_.push(this);
}

/*===========================================================================*/
// Creates a checker that expects between min_al and max_al outstanding
// allocations (inclusive) when check() runs. The byte count is left
// unconstrained (0 .. SIZE_MAX). The checker becomes the active one by
// being pushed onto mc_stack_.
MemChecker::MemChecker(const char * name, size_t min_al, size_t max_al) :
	mem_map_(),
	min_alloc_(min_al),
	max_alloc_(max_al),
	min_byte_(0),
	max_byte_(SIZE_MAX),
	name_(name)
{
	mc_stack_.push(this);
}

/*===========================================================================*/
// Creates a checker with explicit inclusive bounds on both the number of
// outstanding allocations (min_al .. max_al) and the number of outstanding
// requested bytes (min_b .. max_b), as evaluated by check(). The checker
// becomes the active one by being pushed onto mc_stack_.
MemChecker::MemChecker(const char * name, size_t min_al, size_t max_al,
size_t min_b, size_t max_b) :
	mem_map_(),
	min_alloc_(min_al),
	max_alloc_(max_al),
	min_byte_(min_b),
	max_byte_(max_b),
	name_(name)
{
	mc_stack_.push(this);
}

/*===========================================================================*/
// Reports any violation of the expected bounds via check(), then unregisters
// this checker from mc_stack_. Never throws.
MemChecker::~MemChecker() throw() {
	// Exceptions must not escape a destructor; a failed report is dropped.
	try {
		check();
	} catch (...) {
	}

	try {
		RASSERT(mc_stack_.size());
	} catch (...) {
	}

	// Unregister unconditionally: if this step depended on check() succeeding,
	// a throw would leave a pointer to this destroyed object on the stack and
	// mc() would keep routing allocations to it. The emptiness guard keeps
	// pop() defined even when RASSERT is compiled out.
	if (! mc_stack_.empty()) {
		mc_stack_.pop();
	}
}

/*===========================================================================*/
void MemChecker::set_expected(size_t min_al, size_t max_al) {
	// Constrain only the allocation count; bytes may be anything from 0 to
	// SIZE_MAX.
	set_expected(min_al, max_al, 0, SIZE_MAX);
}

/*===========================================================================*/
void MemChecker::set_expected(size_t min_al, size_t max_al,
	size_t min_b, size_t max_b)
{
	// Inclusive byte bounds used by check().
	min_byte_ = min_b;
	max_byte_ = max_b;

	// Inclusive allocation-count bounds used by check().
	min_alloc_ = min_al;
	max_alloc_ = max_al;
}

/*===========================================================================*/
void *
MemChecker::checked_alloc(size_t bytes) {
	void * mem = wall_alloc(bytes);

	if (! mem) {
		throw std::bad_alloc();
	}

	if (mem_map_.find(mem) != mem_map_.end()) {
		cerr << "MemChecker " << name_ <<
			": Allocation of memory that was already allocated.\n";

		Rabort();
	}

	mem_map_[mem] = bytes;

	return mem;
}

/*===========================================================================*/
// Releases a block obtained from checked_alloc. A null pointer is ignored.
// Calls Rabort when the pointer is not recorded in mem_map_, when the start
// pad has been overwritten, or when any byte after the caller's requested
// size (rounding slack plus trailing guard pad) has been overwritten.
// Otherwise the whole block is re-filled with UGLY_CONST, removed from
// mem_map_ and freed.
void MemChecker::checked_dealloc(void * data) {
	if (! data) {
		return;
	}

	if (mem_map_.empty()) {
		cerr << "MemChecker " << name_ << ": Deletion although there were "
			"no outstanding allocations.\n";

		Rabort();
	}

	MemMap::iterator mem_iter = mem_map_.find(data);

	if (mem_iter == mem_map_.end()) {
		cerr << "MemChecker " << name_ <<
			": Deletion of memory that was not allocated.\n";

		Rabort();
	}

	size_t request_size = mem_iter->second;
	size_t total_size = actual_bytes(request_size);
	size_t size_t_size = total_size / sizeof (size_t);

	if (! size_t_size) {
		cerr << "MemChecker " << name_ <<
			": Deletion of memory with a size of 0.\n";

		Rabort();
	}

	// Start of the malloc'ed block; the start pad precedes `data`.
	size_t * mem = static_cast<size_t *>(data) - ALIGN_PAD / sizeof (size_t);

	for (size_t x = 0; x < ALIGN_PAD / sizeof (size_t); x++) {
		if (mem[x] != UGLY_CONST) {
			cerr << "MemChecker " << name_ << ": Start pad damaged.\n";

			Rabort();
		}
	}

	// Byte-wise comparison against the poison pattern, since the caller's
	// data need not end on a word boundary.
	const char * segment = reinterpret_cast<const char *>(& UGLY_CONST);
	const char * char_mem = reinterpret_cast<const char *>(mem);

	// Must run all the way to total_size. The trailing ALIGN_PAD guard is the
	// only padding at all when request_size is a multiple of ALIGN_PAD, so
	// stopping at total_size - ALIGN_PAD misses overruns entirely.
	for (size_t x = ALIGN_PAD + request_size; x < total_size; x++) {
		if (char_mem[x] != segment[x % sizeof (size_t)]) {
			cerr << "MemChecker " << name_ << ": End pad damaged.\n";

			Rabort();
		}
	}

	// Re-poison before freeing so stale reads through dangling pointers are
	// easier to spot.
	for (size_t x = 0; x < size_t_size; ++x) {
		mem[x] = UGLY_CONST;
	}

	mem_map_.erase(mem_iter);
	wall_free(data);
}

/*===========================================================================*/
size_t
MemChecker::out_allocs() const {
	// mem_map_ holds exactly one entry per live allocation.
	const size_t live = mem_map_.size();
	return live;
}

/*===========================================================================*/
size_t
MemChecker::out_bytes() const {
	size_t size = 0;

	for (MemMap::const_iterator i = mem_map_.begin(); i != mem_map_.end();
		i++)
	{
		size += i->second;
	}

	return size;
}

/*===========================================================================*/
// Compares the outstanding allocation count and byte total against the
// expected inclusive bounds. On a violation, writes a report to cerr. The
// per-allocation list is added when max_print is negative (no limit) or the
// allocation count is below max_print.
void MemChecker::check(int max_print) {
	size_t oa = out_allocs();
	size_t ob = out_bytes();

	// Allocation count against the alloc bounds, byte total against the byte
	// bounds. The lower byte bound used to be compared with the count.
	if (oa < min_alloc_ || oa > max_alloc_ ||
		ob < min_byte_ || ob > max_byte_)
	{
		cerr << "MemChecker " << name_ << ": Expected min allocs: " <<
			min_alloc_ << ", max allocs: " << max_alloc_ << "\n";

		cerr << "MemChecker " << name_ << ": Expected min bytes: " <<
			min_byte_ << ", max bytes: " << max_byte_ << "\n";

		cerr << "MemChecker " << name_ << ": Outstanding number of allocs: " <<
			oa << "\n";

		cerr << "MemChecker " << name_ << ": Outstanding number of bytes: " <<
			ob << "\n";

		// max_print is non-negative here, so converting it to oa's own type
		// is lossless.
		if (max_print < 0 || oa < static_cast<size_t>(max_print)) {
			cerr << "MemChecker " << name_ <<
				": Allocation information follows: " <<
				"(position, bytes).\n";

			for (MemMap::const_iterator i = mem_map_.begin();
				i != mem_map_.end(); ++i)
			{
				cerr << "MemChecker " << name_ << ": (" << i->first << ", " <<
					i->second << ")\n";
			}

			cerr << "\n";
		}
	}
}

/*###########################################################################*/
#ifdef ROB_DEBUG

// Replacement global operator new. With an active MemChecker the block is
// allocated and recorded via checked_alloc; otherwise it is only padded and
// poisoned via wall_alloc. Throws std::bad_alloc on failure and never returns
// null.
void *
operator new(size_t bytes) throw (std::bad_alloc) {
	if (MemChecker::mc()) {
		return MemChecker::mc()->checked_alloc(bytes);
	}

	void * mem = wall_alloc(bytes);

	// A replacement operator new must signal failure by throwing, never by
	// returning null. checked_alloc already does this; match it here.
	if (! mem) {
		throw std::bad_alloc();
	}

	return mem;
}

/*===========================================================================*/
void operator delete(void * data) throw () {
	if (MemChecker::mc()) {
		MemChecker::mc()->checked_dealloc(data);
	} else {
		wall_free(data);
	}
}

/*===========================================================================*/
// Replacement global array operator new. Same policy as the scalar version:
// the block is recorded via checked_alloc when a MemChecker is active, and
// otherwise only padded via wall_alloc. Throws std::bad_alloc on failure and
// never returns null.
void *
operator new[](size_t bytes) throw (std::bad_alloc) {
	if (MemChecker::mc()) {
		return MemChecker::mc()->checked_alloc(bytes);
	}

	void * mem = wall_alloc(bytes);

	// Allocation functions report failure by throwing, never by returning
	// null.
	if (! mem) {
		throw std::bad_alloc();
	}

	return mem;
}

/*===========================================================================*/
void operator delete[](void * data) throw () {
	if (MemChecker::mc()) {
		MemChecker::mc()->checked_dealloc(data);
	} else {
		wall_free(data);
	}
}

#endif

/*###########################################################################*/
// Manual smoke test. First a scoped checker reports three deliberate leaks.
// Then, under ROB_DEBUG only, the poison pattern is shown and end-pad damage
// is provoked so that checked_dealloc aborts the program.
void MemChecker_test() {
	{
		cout << "Making new MemChecker (local)\n";
		MemChecker mc("local");

		cout << "Leaking 3 intentionally.\n";
		new int;
		new char[12];
		new double;
	}

#ifdef ROB_DEBUG
// These are too dangerous to try without the MemChecker on.
	// Pad verification and re-poisoning on delete happen only in
	// checked_dealloc, i.e. while a MemChecker is live. The checker above is
	// already destroyed, so keep one alive for the rest of this test.
	// Otherwise delete goes straight to wall_free and nothing is checked.
	MemChecker pad_checker("pads");

	long * x = new long;
	couthex << "Should be 1badfeed: " << *x << "\n";
	*x = 0;
	delete x;
	// NOTE: reading *x after delete is undefined behaviour. It is done on
	// purpose to show the re-poisoning; the allocator may overwrite it.
	couthex << "Should be 1badfeed: " << *x << "\n";

	x = new long;
	cout << "Damaging end pad intentionally.  "
		"This will also cause an accidental leak and abort the program.\n";

	// Writes just past the long, into the padding after the caller's bytes.
	*(x + 1) = 0;
	delete x;
#endif
}


