#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <fcntl.h>
#include <signal.h>
#include <locale.h>
#include <sys/poll.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <asm/types.h>

/*
 * Directory that main() stat()s at startup and that start_trace() uses as
 * the prefix of the relay file name.
 */
static char relay_path[] = "/relay/";

/*
 * devfd:   descriptor of the device named on the command line; target of
 *          the BLKSTARTTRACE/BLKSTOPTRACE ioctls.
 * tracefd: descriptor of the relay file opened by start_trace().
 */
static int devfd, tracefd;
/* mapping of the relay file; only created when use_mmap is non-zero */
static void *tracebuf;

/*
 * Size of one sub-buffer and number of sub-buffers. Both are handed to the
 * kernel in start_trace() and also used to size and index the mapping.
 */
static int buf_size = 128*1024;		/* must be a power-of-2 */
static int buf_nr = 4;

/* non-zero: fetch events from the mmap()ed buffer; zero: use read() */
static int use_mmap = 0;

/*
 * Plain ternary min/max. Each argument may be evaluated twice, so do not
 * pass expressions with side effects.
 */
#define min(a, b)	((a) < (b) ? (a) : (b))
#define max(a, b)	((a) > (b) ? (a) : (b))

/*
 * Trace category bits. They live in the upper 16 bits of
 * blk_io_trace.action (see BLK_TC_ACT()).
 */
enum {
        BLK_TC_READ     = 1 << 0,       /* reads */
        BLK_TC_WRITE    = 1 << 1,       /* writes */
        BLK_TC_BARRIER  = 1 << 2,       /* barrier */
        BLK_TC_SYNC     = 1 << 3,       /* sync */
        BLK_TC_QUEUE    = 1 << 4,       /* queueing/merging */
        BLK_TC_REQUEUE  = 1 << 5,       /* requeueing */
        BLK_TC_ISSUE    = 1 << 6,       /* issue */
        BLK_TC_COMPLETE = 1 << 7,       /* completions */
        BLK_TC_FS       = 1 << 8,       /* fs requests */
        BLK_TC_PC       = 1 << 9,       /* pc requests */

        BLK_TC_END      = 1 << 15,      /* only 16-bits, reminder */
};

/*
 * Category bits are shifted into the upper half of the 32-bit action word;
 * the lower 16 bits carry the basic action number.
 */
#define BLK_TC_SHIFT		(16)
#define BLK_TC_ACT(act)		((act) << BLK_TC_SHIFT)

/*
 * Basic trace actions. These occupy the low 16 bits of
 * blk_io_trace.action; dump_trace_pc() and dump_trace_fs() switch on
 * (action & 0xffff) to pick the output letter.
 */
enum {
	__BLK_TA_QUEUE = 1,		/* queued */
	__BLK_TA_BACKMERGE,		/* back merged to existing rq */
	__BLK_TA_FRONTMERGE,		/* front merge to existing rq */
	__BLK_TA_GETRQ,			/* allocated new request */
	__BLK_TA_SLEEPRQ,		/* sleeping on rq allocation */
	__BLK_TA_REQUEUE,		/* request requeued */
	__BLK_TA_ISSUE,			/* sent to driver */
	__BLK_TA_COMPLETE,		/* completed by driver */
};

/*
 * Trace actions in full: the basic action number combined with the
 * category bit it belongs to. Additionally, read or write is masked
 * in by the producer.
 *
 * NOTE(review): BLK_TA_REQUEUE is tagged with BLK_TC_QUEUE rather than
 * BLK_TC_REQUEUE; confirm this matches the kernel side.
 */
#define BLK_TA_QUEUE		(__BLK_TA_QUEUE | BLK_TC_ACT(BLK_TC_QUEUE))
#define BLK_TA_BACKMERGE	(__BLK_TA_BACKMERGE | BLK_TC_ACT(BLK_TC_QUEUE))
#define BLK_TA_FRONTMERGE	(__BLK_TA_FRONTMERGE | BLK_TC_ACT(BLK_TC_QUEUE))
#define	BLK_TA_GETRQ		(__BLK_TA_GETRQ | BLK_TC_ACT(BLK_TC_QUEUE))
#define	BLK_TA_SLEEPRQ		(__BLK_TA_SLEEPRQ | BLK_TC_ACT(BLK_TC_QUEUE))
#define	BLK_TA_REQUEUE		(__BLK_TA_REQUEUE | BLK_TC_ACT(BLK_TC_QUEUE))
#define BLK_TA_ISSUE		(__BLK_TA_ISSUE | BLK_TC_ACT(BLK_TC_ISSUE))
#define BLK_TA_COMPLETE		(__BLK_TA_COMPLETE| BLK_TC_ACT(BLK_TC_COMPLETE))

/*
 * The upper 24 bits of blk_io_trace.magic must equal BLK_IO_TRACE_MAGIC;
 * the low byte is the format version and is compared against
 * SUPPORTED_VERSION separately.
 */
#define BLK_IO_TRACE_MAGIC	(0x65617400)
#define CHECK_MAGIC(t)		(((t)->magic & 0xffffff00) == BLK_IO_TRACE_MAGIC)
#define SUPPORTED_VERSION	(0x02)

/*
 * Fixed-size header of one trace event as read from the relay file.
 * An optional payload of pdu_len bytes follows it (consumed by log_pc()).
 */
struct blk_io_trace {
	__u32 magic;		/* BLK_IO_TRACE_MAGIC | version in the low byte */
	__u32 sequence;		/* event sequence number, checked for gaps in read_events() */
	__u64 time;		/* event timestamp; rebased and divided by 1000 in get_next_event() */
	__u64 sector;		/* first sector of the request */
	__u32 bytes;		/* request size in bytes */
	__u32 action;		/* low 16 bits: basic action, high 16 bits: category mask */
	__u32 pid;		/* pid associated with the event */
	__u16 error;		/* error value, printed for completion events */
	__u16 pdu_len;		/* length of the payload that follows this header */
};

/*
 * Argument block of the BLKSTARTTRACE ioctl. start_trace() fills in the
 * buffer geometry and afterwards reads `name` to locate the relay file.
 */
struct blk_user_trace_setup {
	char name[32];		/* read back after the ioctl to build the relay file path */
	__u16 act_mask;		/* category mask (BLK_TC_* bits) — presumably selects what is traced; TODO confirm */
	__u32 buf_size;		/* size of one sub-buffer */
	__u32 buf_nr;		/* number of sub-buffers */
};


/*
 * Fallback definitions of the trace ioctls for when the included headers
 * do not provide them.
 */
#ifndef BLKSTARTTRACE
#define BLKSTARTTRACE	_IOWR(0x12,115,struct blk_user_trace_setup)
#define BLKSTOPTRACE	_IO(0x12,116)
#endif

/* event counters: q = queued, c = completed, m = merged; split by read/write */
static unsigned long qreads, qwrites, creads, cwrites, mreads, mwrites;
/* accumulated sizes in KiB (bytes >> 10) for queued and completed requests */
static unsigned long long qread_kb, qwrite_kb, cread_kb, cwrite_kb;
/* events: events successfully dumped; missed_events: accumulated sequence gaps */
static unsigned long long events, missed_events;
/* sequence number of the most recently processed event */
static unsigned long last_sequence;

/* read cursor for mmap mode: current sub-buffer index and offset within it */
static unsigned int cur_sub_buf;
static unsigned long cur_sub_buf_offset;

/* timestamp of the first event seen; all-ones (-1) until then */
static unsigned long long start_time = -1;

/*
 * Advance the mmap read cursor by len bytes after a record has been
 * consumed. Does nothing in read() mode, where the file offset advances
 * by itself.
 *
 * Deliberately not declared `inline`: under C99/C11 a non-static inline
 * definition provides no external definition, so any call the compiler
 * chooses not to inline would fail to link.
 */
void update_buf(int len)
{
	if (!use_mmap)
		return;

	cur_sub_buf_offset += len;
}

/*
 * Copy len bytes at the current mmap cursor into buf.
 *
 * If the requested bytes would run past the end of the current
 * sub-buffer, the cursor first moves to the start of the next sub-buffer
 * (wrapping to sub-buffer 0 after the last one). The cursor is not
 * advanced past the copied bytes here; update_buf() does that.
 *
 * Always returns 0.
 */
int fill_buf_mmap(void *buf, int len)
{
	const char *base = tracebuf;

	if (len + cur_sub_buf_offset > buf_size) {
		cur_sub_buf_offset = 0;
		if (++cur_sub_buf == buf_nr)
			cur_sub_buf = 0;
	}

	memcpy(buf, base + buf_size * cur_sub_buf + cur_sub_buf_offset, len);
	return 0;
}

/*
 * Read exactly len bytes from the relay file into buf.
 *
 * Returns 0 once buf holds len bytes. Returns 1 if no data was available
 * at all (nothing consumed, caller may retry later) or if read() failed.
 *
 * A short read is not treated as failure: the bytes already taken from
 * the file cannot be pushed back, so giving up mid-record would drop
 * them and leave the stream misaligned for the next header. Instead the
 * remainder is awaited.
 */
int fill_buf_read(void *buf, int len)
{
	char *dst = buf;
	int done = 0;

	while (done < len) {
		ssize_t ret = read(tracefd, dst + done, len - done);

		if (ret < 0) {
			perror("read");
			return 1;
		}

		if (ret == 0) {
			struct pollfd pfd = { .fd = tracefd, .events = POLLIN };

			/* nothing consumed yet: report "no data" as before */
			if (!done)
				return 1;

			/* partial record: wait briefly for the rest, then retry */
			poll(&pfd, 1, 10);
			continue;
		}

		done += ret;
	}

	return 0;
}

/*
 * Fetch len bytes of trace data into buf using whichever backend is
 * selected by use_mmap. Returns 0 on success, non-zero otherwise.
 */
int fill_buf(void *buf, int len)
{
	return use_mmap ? fill_buf_mmap(buf, len) : fill_buf_read(buf, len);
}

/*
 * Account one merge event. rw is non-zero for writes. The merge counter
 * is bumped and the merged size (in KiB) is added to the *queued* total
 * for the same direction.
 */
static inline void account_m(int rw, unsigned int bytes)
{
	unsigned long *count = rw ? &mwrites : &mreads;
	unsigned long long *kb = rw ? &qwrite_kb : &qread_kb;

	(*count)++;
	*kb += bytes >> 10;
}

/*
 * Account one queue event. rw is non-zero for writes. Bumps the queued
 * counter and adds the request size (in KiB) to the queued total.
 */
static inline void account_q(int rw, unsigned int bytes)
{
	unsigned long *count = rw ? &qwrites : &qreads;
	unsigned long long *kb = rw ? &qwrite_kb : &qread_kb;

	(*count)++;
	*kb += bytes >> 10;
}

/*
 * Account one completion event. rw is non-zero for writes. Bumps the
 * completed counter and adds the request size (in KiB) to the completed
 * total.
 */
static inline void account_c(int rw, unsigned int bytes)
{
	unsigned long *count = rw ? &cwrites : &creads;
	unsigned long long *kb = rw ? &cwrite_kb : &cread_kb;

	(*count)++;
	*kb += bytes >> 10;
}

/*
 * Sanity-check an event header.
 *
 * Returns 0 if the magic and the version byte are as expected; otherwise
 * prints a diagnostic to stderr and returns 1.
 *
 * Deliberately not declared `inline`: under C99/C11 a non-static inline
 * definition provides no external definition, so any call the compiler
 * chooses not to inline would fail to link.
 */
int verify_trace(struct blk_io_trace *t)
{
	if (!CHECK_MAGIC(t)) {
		fprintf(stderr, "bad trace magic %x\n", t->magic);
		return 1;
	}

	/* the low byte of the magic word carries the format version */
	if ((t->magic & 0xff) != SUPPORTED_VERSION) {
		fprintf(stderr, "unsupported trace version %x\n", t->magic & 0xff);
		return 1;
	}

	return 0;
}

void log_complete(struct blk_io_trace *t, char *rwbs, char act)
{
	printf("%12Lu %5u %c %3s %Lu-%Lu [%d]\n", (unsigned long long) t->time, t->pid, act, rwbs, (unsigned long long) t->sector, (unsigned long long) t->sector + (t->bytes >> 9), t->error);
}

/*
 * Print a queue-type event: time, pid, action letter, rwbs tag and the
 * sector range (start to start + bytes/512).
 */
void log_queue(struct blk_io_trace *t, char *rwbs, char act)
{
	/* "ll" is the standard length modifier for long long; "L" is not */
	printf("%12llu %5u %c %3s %llu-%llu\n",
	       (unsigned long long) t->time, t->pid, act, rwbs,
	       (unsigned long long) t->sector,
	       (unsigned long long) t->sector + (t->bytes >> 9));
}

/*
 * Print an issue event: time, pid, action letter, rwbs tag and the
 * sector range (start to start + bytes/512).
 */
void log_issue(struct blk_io_trace *t, char *rwbs, char act)
{
	/* "ll" is the standard length modifier for long long; "L" is not */
	printf("%12llu %5u %c %3s %llu-%llu\n",
	       (unsigned long long) t->time, t->pid, act, rwbs,
	       (unsigned long long) t->sector,
	       (unsigned long long) t->sector + (t->bytes >> 9));
}

/*
 * Print a merge event: time, pid, action letter, rwbs tag and the
 * sector range (start to start + bytes/512) enclosed in brackets.
 */
void log_merge(struct blk_io_trace *t, char *rwbs, char act)
{
	/* "ll" is the standard length modifier for long long; "L" is not */
	printf("%12llu %5u %c %3s [%llu-%llu]\n",
	       (unsigned long long) t->time, t->pid, act, rwbs,
	       (unsigned long long) t->sector,
	       (unsigned long long) t->sector + (t->bytes >> 9));
}

/*
 * Print an event that carries no sector information: time, pid, action
 * letter and rwbs tag only.
 */
void log_generic(struct blk_io_trace *t, char *rwbs, char act)
{
	/* "ll" is the standard length modifier for long long; "L" is not */
	printf("%12llu %5u %c %3s\n",
	       (unsigned long long) t->time, t->pid, act, rwbs);
}

/*
 * Print a pc request event together with its payload.
 *
 * The pdu_len bytes following the header are fetched and printed as hex
 * bytes; for completions (act == 'C') the error value is appended in
 * brackets.
 *
 * A payload larger than the local buffer is reported on stderr and not
 * printed, but it is still consumed: leaving it in the stream would make
 * the next header read start in the middle of this payload. Every path
 * terminates the output line that was started with the event prefix.
 */
void log_pc(struct blk_io_trace *t, char *rwbs, char act)
{
	unsigned char buf[64];
	int i;

	/* "ll" is the standard length modifier for long long; "L" is not */
	printf("%12llu %5u %c %3s ", (unsigned long long) t->time, t->pid, act, rwbs);

	if (t->pdu_len > sizeof(buf)) {
		unsigned int left = t->pdu_len;

		fprintf(stderr, "Payload too large %d\n", t->pdu_len);

		/* skip over the payload in buffer-sized pieces */
		while (left) {
			unsigned int chunk = left < sizeof(buf) ? left : sizeof(buf);

			if (fill_buf(buf, chunk))
				break;
			update_buf(chunk);
			left -= chunk;
		}

		printf("\n");
		return;
	}

	if (fill_buf(buf, t->pdu_len)) {
		printf("\n");
		return;
	}

	update_buf(t->pdu_len);

	for (i = 0; i < t->pdu_len; i++)
		printf("%02x ", buf[i]);

	if (act == 'C')
		printf("[%d]", t->error);

	printf("\n");
}

/*
 * Decode and print one pc request event.
 *
 * Builds the rwbs tag (R or W, followed by B and/or S when the barrier
 * and sync category bits are set), maps the basic action to its output
 * letter and prints it. Issue and completion events are printed with
 * their payload via log_pc(); the others via log_generic(). An unknown
 * action is reported on stderr and not counted.
 */
void dump_trace_pc(struct blk_io_trace *t)
{
	char rwbs[4];
	char *p = rwbs;
	char letter;

	*p++ = (t->action & BLK_TC_ACT(BLK_TC_WRITE)) ? 'W' : 'R';
	if (t->action & BLK_TC_ACT(BLK_TC_BARRIER))
		*p++ = 'B';
	if (t->action & BLK_TC_ACT(BLK_TC_SYNC))
		*p++ = 'S';
	*p = '\0';

	switch (t->action & 0xffff) {
	case __BLK_TA_QUEUE:
		letter = 'Q';
		break;
	case __BLK_TA_GETRQ:
		letter = 'G';
		break;
	case __BLK_TA_SLEEPRQ:
		letter = 'S';
		break;
	case __BLK_TA_REQUEUE:
		letter = 'R';
		break;
	case __BLK_TA_ISSUE:
		letter = 'D';
		break;
	case __BLK_TA_COMPLETE:
		letter = 'C';
		break;
	default:
		fprintf(stderr, "Bad pc action %x\n", t->action);
		return;
	}

	/* only issue and completion carry a payload worth dumping */
	if (letter == 'D' || letter == 'C')
		log_pc(t, rwbs, letter);
	else
		log_generic(t, rwbs, letter);

	events++;
}

/*
 * Decode and print one fs request event and update the statistics.
 *
 * Builds the rwbs tag (R or W, followed by B and/or S when the barrier
 * and sync category bits are set), then dispatches on the basic action:
 * queue, merge and completion events are accounted before being printed,
 * the remaining ones are only printed. An unknown action is reported on
 * stderr and not counted.
 */
void dump_trace_fs(struct blk_io_trace *t)
{
	const int is_write = t->action & BLK_TC_ACT(BLK_TC_WRITE);
	char rwbs[4];
	char *p = rwbs;

	*p++ = is_write ? 'W' : 'R';
	if (t->action & BLK_TC_ACT(BLK_TC_BARRIER))
		*p++ = 'B';
	if (t->action & BLK_TC_ACT(BLK_TC_SYNC))
		*p++ = 'S';
	*p = '\0';

	switch (t->action & 0xffff) {
	case __BLK_TA_QUEUE:
		account_q(is_write, t->bytes);
		log_queue(t, rwbs, 'Q');
		break;
	case __BLK_TA_BACKMERGE:
		account_m(is_write, t->bytes);
		log_merge(t, rwbs, 'M');
		break;
	case __BLK_TA_FRONTMERGE:
		account_m(is_write, t->bytes);
		log_merge(t, rwbs, 'F');
		break;
	case __BLK_TA_GETRQ:
		log_generic(t, rwbs, 'G');
		break;
	case __BLK_TA_SLEEPRQ:
		log_generic(t, rwbs, 'S');
		break;
	case __BLK_TA_REQUEUE:
		log_queue(t, rwbs, 'R');
		break;
	case __BLK_TA_ISSUE:
		log_issue(t, rwbs, 'D');
		break;
	case __BLK_TA_COMPLETE:
		account_c(is_write, t->bytes);
		log_complete(t, rwbs, 'C');
		break;
	default:
		fprintf(stderr, "Bad fs action %x\n", t->action);
		return;
	}

	events++;
}

/*
 * Print one event, routing it to the pc or fs decoder depending on
 * whether the BLK_TC_PC category bit is set in its action word.
 */
void dump_trace(struct blk_io_trace *t)
{
	if (!(t->action & BLK_TC_ACT(BLK_TC_PC))) {
		dump_trace_fs(t);
		return;
	}

	dump_trace_pc(t);
}

/*
 * Back off briefly before the caller polls for new events again.
 * Sleeps for 10000 microseconds and always returns 1.
 */
int wait_for_events(void)
{
	enum { POLL_INTERVAL_USEC = 10000 };

	usleep(POLL_INTERVAL_USEC);
	return 1;
}

/*
 * Fetch the next event header into *t.
 *
 * Returns 1 when a valid header was obtained, 0 when no data could be
 * fetched, the magic did not match, or the version is unsupported (the
 * latter is also reported on stderr).
 *
 * On success the timestamp is rebased to the first event ever seen and
 * divided by 1000, keeping the internal resolution in usec, and the mmap
 * cursor is advanced past the header.
 */
int get_next_event(struct blk_io_trace *t)
{
	if (fill_buf(t, sizeof(*t)) || !CHECK_MAGIC(t))
		return 0;

	if ((t->magic & 0xff) != SUPPORTED_VERSION) {
		fprintf(stderr, "Unsupported trace version %d\n", t->magic & 0xff);
		return 0;
	}

	/* the very first event defines time zero */
	if (start_time == -1)
		start_time = t->time;

	t->time = (t->time - start_time) / 1000;

	update_buf(sizeof(*t));
	return 1;
}

/*
 * Main event loop: fetch, verify and print events until a header fails
 * verification.
 *
 * When no event is available the loop sleeps via wait_for_events() and
 * tries again. Any event whose sequence number is not exactly one above
 * the previous one is reported on stderr and the distance between the
 * two numbers is added to missed_events. A progress line is written to
 * stderr every 32768 events.
 */
void read_events(void)
{
	struct blk_io_trace t;
	unsigned long seen = 0;

	for (;;) {
		if (!get_next_event(&t)) {
			wait_for_events();
			continue;
		}

		if (verify_trace(&t))
			return;

		dump_trace(&t);

		if (t.sequence != last_sequence + 1) {
			fprintf(stderr, "seq %u, last %lu\n", t.sequence, last_sequence);
			if (t.sequence > last_sequence)
				missed_events += t.sequence - last_sequence;
			else
				missed_events += last_sequence - t.sequence;
		}
		last_sequence = t.sequence;

		if (!(++seen % 32768))
			fprintf(stderr, "%lu events\n", seen);
	}
}

/*
 * Stop tracing on the device and release the relay file resources.
 *
 * Safe to call when start_trace() failed part-way and safe to call more
 * than once: the mapping is only unmapped if it was really established,
 * and the relay descriptor is only closed if it was really opened.
 */
void stop_trace(void)
{
	if (ioctl(devfd, BLKSTOPTRACE) < 0)
		perror("BLKSTOPTRACE");

	/* a failed mmap() leaves MAP_FAILED behind, which is not NULL */
	if (tracebuf && tracebuf != MAP_FAILED)
		munmap(tracebuf, buf_size * buf_nr);
	tracebuf = NULL;

	/*
	 * tracefd is 0 until start_trace() opens the relay file and -1 if
	 * that open failed; closing it blindly would close stdin.
	 */
	if (tracefd > 0)
		close(tracefd);
	tracefd = -1;
}

/*
 * Start tracing on the device behind devfd and open the matching relay
 * file.
 *
 * The buffer geometry (buf_size, buf_nr) is passed to the kernel through
 * BLKSTARTTRACE; the name found in the setup block afterwards is used to
 * build the relay file path "<relay_path><name>0". With use_mmap set, the
 * relay file is additionally mapped into tracebuf.
 *
 * name is currently unused.
 *
 * Returns 0 on success, 1 on failure (after printing a diagnostic).
 */
int start_trace(char *name)
{
	struct blk_user_trace_setup buts;
	char p[64];

	(void) name;

	/*
	 * memset() takes (dest, value, size). Zero the whole block so that
	 * name and act_mask do not reach the kernel as stack garbage.
	 * NOTE(review): act_mask is therefore 0 — confirm the kernel side
	 * treats that as "trace everything".
	 */
	memset(&buts, 0, sizeof(buts));
	buts.buf_size = buf_size;
	buts.buf_nr = buf_nr;

	if (ioctl(devfd, BLKSTARTTRACE, &buts) < 0) {
		perror("BLKSTARTTRACE");
		return 1;
	}

	/* never run past the name field, whatever it contains */
	buts.name[sizeof(buts.name) - 1] = '\0';
	snprintf(p, sizeof(p), "%s%s0", relay_path, buts.name);
	printf("relay name: %s\n", p);

	tracefd = open(p, O_RDONLY);
	if (tracefd == -1) {
		perror("open relay file");
		return 1;
	}

	if (!use_mmap)
		return 0;

	tracebuf = mmap(NULL, buf_size * buf_nr, PROT_READ, MAP_PRIVATE | MAP_POPULATE, tracefd, 0);
	if (tracebuf == MAP_FAILED) {
		perror("mmap");
		/* keep "no mapping" represented as NULL for stop_trace() */
		tracebuf = NULL;
		return 1;
	}

	return 0;
}

/*
 * Print the accumulated read/write statistics, the number of events
 * dumped and the number of missed events to stdout. Numbers use the
 * locale's thousands grouping (the ' flag).
 */
void show_stats(void)
{
	/* "ll" is the standard length modifier for long long; "L" is not */
	printf("Reads:");
	printf("\tQueued:    %'8lu, %'8lluKiB\n", qreads, qread_kb);
	printf("\tCompleted: %'8lu, %'8lluKiB\n", creads, cread_kb);
	printf("\tMerges:    %'8lu\n", mreads);

	printf("Writes:");
	printf("\tQueued:    %'8lu, %'8lluKiB\n", qwrites, qwrite_kb);
	printf("\tCompleted: %'8lu, %'8lluKiB\n", cwrites, cwrite_kb);
	printf("\tMerges:    %'8lu\n", mwrites);

	printf("Events: %'llu\n", events);
	printf("Missed events: %'llu\n", missed_events);
}

/*
 * SIGINT handler installed by main(): flushes pending output, stops the
 * trace, prints the statistics and terminates with exit status 0.
 * The signal number is not used.
 *
 * NOTE(review): fflush(), printf() (reached through show_stats()) and
 * exit() are not on the POSIX async-signal-safe list; confirm that
 * running them from a handler is acceptable for this tool.
 */
void handle_sigint(int sig)
{
	fflush(stdout);
	stop_trace();
	show_stats();
	exit(0);
}

/*
 * Usage: <prog> <dev>
 *
 * Checks that relay_path exists and is a directory, opens the device,
 * installs the SIGINT handler, starts tracing and then prints events
 * until the event loop ends, finishing with the statistics.
 *
 * Exit status: 0 on normal termination, 1 for usage errors or a missing
 * relay directory, 2 if the device cannot be opened, 3 if tracing cannot
 * be started.
 */
int main(int argc, char *argv[])
{
	struct stat st;
	char name[64] = "";	/* handed to start_trace(); keep it defined */

	if (argc < 2) {
		fprintf(stderr, "Usage: %s <dev>\n", argv[0]);
		return 1;
	}

	if (stat(relay_path, &st) < 0) {
		perror("stat");
		fprintf(stderr, "Is relayfs mounted on %s?\n", relay_path);
		return 1;
	}
	if (!S_ISDIR(st.st_mode)) {
		fprintf(stderr, "%s doesn't appear to be a dir\n", relay_path);
		return 1;
	}

	devfd = open(argv[1], O_RDONLY);
	if (devfd == -1) {
		perror("open");
		return 2;
	}

	/* installed before tracing starts so that Ctrl-C always stops the trace */
	signal(SIGINT, handle_sigint);

	if (start_trace(name)) {
		fprintf(stderr, "Failed to start trace\n");
		stop_trace();
		close(devfd);
		return 3;
	}

	/* selects the grouping used by the ' flag in show_stats() */
	setlocale(LC_NUMERIC, "en_US");

	read_events();
	stop_trace();
	close(devfd);
	show_stats();
	return 0;
}