#include "stdafx.h"
#include "AsyncIO.h"
#include "SQLite.h"
#include "OS/IORequest.h"
#include "SQL/Exception.h"

#if defined(LINUX)
#include <sys/sysmacros.h>
#endif

namespace sql {

	// xSleep replacement installed on the VFS. SQLite asks for a delay in microseconds; we round
	// that up to whole milliseconds, suspend through os::UThread::sleep, and report the rounded
	// duration back in microseconds. The VFS argument is not needed.
	static int stormSleep(sqlite3_vfs *vfs, int microseconds) {
		const int ms = (microseconds + 999) / 1000;
		os::UThread::sleep(ms);
		return ms * 1000;
	}

// Replace the system call 'name' of the VFS 'vfs' with 'fn'. Throws an SQLError if SQLite
// rejects the replacement (xSetSystemCall returns non-zero, e.g. SQLITE_NOTFOUND for an
// unknown name). 'name' is presumably required to be a string literal, since it is also
// passed through S() to build the error message.
// The body is wrapped in do { } while (false) so that the macro expands to exactly one
// statement and stays correct inside unbraced if/else constructs.
#define SET_SYSCALL(vfs, name, fn)										\
	do {																\
		if ((vfs)->xSetSystemCall((vfs), name, reinterpret_cast<sqlite3_syscall_ptr>(fn))) { \
			storm::Engine &e = runtime::someEngine();					\
			throw new (e) SQLError(TO_S(e, S("Failed to update SQLite syscall: ") << S(name))); \
		}																\
	} while (false)


#if defined(WINDOWS)

	// Copy the caller's OVERLAPPED fields (internal status, offset and event) into 'to'.
	// When 'overlapped' is NULL, 'to' is left untouched.
	static void init(os::IORequest &to, OVERLAPPED *overlapped) {
		if (!overlapped)
			return;

		to.Internal = overlapped->Internal;
		to.InternalHigh = overlapped->InternalHigh;
		to.Offset = overlapped->Offset;
		to.OffsetHigh = overlapped->OffsetHigh;
		to.hEvent = overlapped->hEvent;
	}

	// Prepare 'to' for an I/O call. A caller-supplied OVERLAPPED structure is copied as-is.
	// Without one, the request starts at the handle's current file pointer (or at offset 0 if
	// the pointer cannot be queried), mimicking I/O on a non-overlapped handle.
	static void init(os::IORequest &to, OVERLAPPED *overlapped, HANDLE handle) {
		if (overlapped) {
			init(to, overlapped);
			return;
		}

		LARGE_INTEGER zero, current;
		zero.QuadPart = 0;
		if (!SetFilePointerEx(handle, zero, &current, FILE_CURRENT))
			current.QuadPart = 0;

		to.Offset = current.LowPart;
		to.OffsetHigh = current.HighPart;
	}

	// Once a request without a caller-supplied OVERLAPPED has completed, move the handle's file
	// pointer forward by the 'bytes' transferred. Handles opened with FILE_FLAG_OVERLAPPED do not
	// maintain the file pointer themselves. Callers that pass an OVERLAPPED manage offsets.
	static void advance(OVERLAPPED *overlapped, HANDLE handle, nat bytes) {
		if (overlapped)
			return;

		LARGE_INTEGER delta;
		delta.QuadPart = bytes;
		SetFilePointerEx(handle, delta, NULL, FILE_CURRENT);
	}

	static void WINAPI wrapSleep(DWORD ms) {
		// Replacement for the "Sleep" syscall of the Windows VFS, which exists alongside xSleep
		// for unclear reasons. Sleeping through os::UThread::sleep presumably lets other UThreads
		// run meanwhile instead of blocking the whole OS thread.
		os::UThread::sleep(ms);
	}

	// Note: If we are on WinRT, we might need to support CreateFile2 as well.

	// CreateFileA replacement. It always opens the file with FILE_FLAG_OVERLAPPED so that later
	// reads and writes can complete asynchronously. Successfully opened handles are attached to
	// the current os::Thread (presumably its completion port). Returns what CreateFileA returns.
	static HANDLE WINAPI wrapCreateFileA(LPCSTR name, DWORD access, DWORD share, LPSECURITY_ATTRIBUTES attrs,
								DWORD disposition, DWORD flags, HANDLE hTemplate) {
		const DWORD asyncFlags = flags | FILE_FLAG_OVERLAPPED;
		HANDLE opened = CreateFileA(name, access, share, attrs, disposition, asyncFlags, hTemplate);
		if (opened == INVALID_HANDLE_VALUE)
			return opened;

		os::Thread::current().attach(os::Handle(opened));
		return opened;
	}

	// CreateFileW replacement. It adds FILE_FLAG_OVERLAPPED so that I/O on the handle can
	// complete asynchronously, and attaches successfully opened handles to the current
	// os::Thread. Returns what CreateFileW returns.
	// 'name' is LPCWSTR to match the CreateFileW prototype, which is also the function type
	// SQLite invokes this syscall through.
	static HANDLE WINAPI wrapCreateFileW(LPCWSTR name, DWORD access, DWORD share, LPSECURITY_ATTRIBUTES attrs,
								DWORD disposition, DWORD flags, HANDLE hTemplate) {
		HANDLE r = CreateFileW(name, access, share, attrs, disposition, flags | FILE_FLAG_OVERLAPPED, hTemplate);
		if (r != INVALID_HANDLE_VALUE) {
			os::Thread::current().attach(os::Handle(r));
		}
		return r;
	}

	// Note: FlushFileBuffers would be very nice to have, but there does not seem to be an async
	// version of it. Perhaps we should simply have a thread that handles that and close?

	// Async replacement for LockFileEx. The lock is requested with our IORequest as the
	// OVERLAPPED structure, and the UThread waits on 'request.wake' for the result.
	// Returns TRUE if the lock was acquired. On failure, returns FALSE with the error code
	// available through GetLastError(), just like LockFileEx.
	static BOOL WINAPI wrapLockFileEx(HANDLE file, DWORD flags, DWORD reserved,
						DWORD countLow, DWORD countHigh, LPOVERLAPPED overlapped) {

		os::IORequest request(os::Handle(file), os::Thread::current());
		init(request, overlapped);

		int error = 0;
		if (!LockFileEx(file, flags, reserved, countLow, countHigh, &request))
			error = GetLastError();

		if (error == ERROR_IO_PENDING || error == 0) {
			// Queued or completed immediately: wait for the completion to report the result.
			request.wake.down();
			if (request.error) {
				SetLastError(request.error);
				return FALSE;
			} else {
				return TRUE;
			}
		} else {
			// LockFileEx failed without queuing the request, and the last error is still the one
			// it set. Report the failure: SQLite treats a successful return as "lock acquired",
			// so returning TRUE here would let it proceed without actually holding the lock.
			return FALSE;
		}
	}

	static BOOL WINAPI wrapUnlockFileEx(HANDLE file, DWORD reserved, DWORD countLow, DWORD countHigh,
								LPOVERLAPPED overlapped) {
		// Async replacement for UnlockFileEx. The unlock is issued with our IORequest as the
		// OVERLAPPED structure; if it is queued, the UThread waits on 'wake' for the result.
		os::IORequest req(os::Handle(file), os::Thread::current());
		init(req, overlapped);

		DWORD status = ERROR_SUCCESS;
		if (!UnlockFileEx(file, reserved, countLow, countHigh, &req))
			status = GetLastError();

		// Completed right away. It seems no completion is delivered in this case, so don't wait.
		if (status == ERROR_SUCCESS)
			return TRUE;

		// Failed without being queued; the last error is the one UnlockFileEx set.
		if (status != ERROR_IO_PENDING)
			return FALSE;

		req.wake.down();
		if (!req.error)
			return TRUE;

		SetLastError(req.error);
		return FALSE;
	}

	// Note: The plain LockFile and UnlockFile are API functions, but they won't be called on
	// Windows NT, where LockFileEx is available.

	static BOOL WINAPI wrapReadFile(HANDLE file, LPVOID buffer,
									DWORD toRead, LPDWORD numRead,
									LPOVERLAPPED overlapped) {
		// Async replacement for ReadFile. Without a caller-supplied OVERLAPPED, the read starts at
		// the current file pointer, and the pointer is advanced afterwards, as for a synchronous read.
		os::IORequest req(os::Handle(file), os::Thread::current());
		init(req, overlapped, file);

		DWORD status = ERROR_SUCCESS;
		if (!ReadFile(file, buffer, toRead, numRead, &req))
			status = GetLastError();

		// Failed without being queued; the last error is the one ReadFile set.
		if (status != ERROR_SUCCESS && status != ERROR_IO_PENDING)
			return FALSE;

		// Queued or finished immediately: the completion is awaited in both cases.
		req.wake.down();

		advance(overlapped, file, req.bytes);
		if (numRead)
			*numRead = req.bytes;

		if (!req.error)
			return TRUE;

		SetLastError(req.error);
		return FALSE;
	}

	static BOOL WINAPI wrapWriteFile(HANDLE file, LPCVOID buffer,
									DWORD toWrite, LPDWORD numWritten,
									LPOVERLAPPED overlapped) {
		// Async replacement for WriteFile. Without a caller-supplied OVERLAPPED, the write starts
		// at the current file pointer, and the pointer is advanced afterwards, as for a synchronous write.
		os::IORequest req(os::Handle(file), os::Thread::current());
		init(req, overlapped, file);

		DWORD status = ERROR_SUCCESS;
		if (!WriteFile(file, buffer, toWrite, numWritten, &req))
			status = GetLastError();

		// Failed without being queued; the last error is the one WriteFile set.
		if (status != ERROR_SUCCESS && status != ERROR_IO_PENDING)
			return FALSE;

		// Queued or finished immediately: the completion is awaited in both cases.
		req.wake.down();

		advance(overlapped, file, req.bytes);
		if (numWritten)
			*numWritten = req.bytes;

		if (!req.error)
			return TRUE;

		SetLastError(req.error);
		return FALSE;
	}

	// Select the VFS to use on Windows and route its blocking syscalls through the async
	// wrappers above. Returns the patched VFS, or NULL if SQLite has no VFS at all.
	static sqlite3_vfs *initializeAsync() {
		// On Windows, we want to enable long file names. So we re-register the win32-longpath
		// VFS and make it the default. Note: re-registering like this is sanctioned by the
		// documentation.
		// sqlite3_vfs_find returns NULL if this SQLite build lacks win32-longpath. In that case,
		// fall back to the current default VFS instead of registering and patching NULL.
		sqlite3_vfs *vfs = sqlite3_vfs_find("win32-longpath");
		if (vfs)
			sqlite3_vfs_register(vfs, 1);
		else
			vfs = sqlite3_vfs_find(NULL);

		// No VFS at all: there is nothing to patch.
		if (!vfs)
			return vfs;

		// Note: Unclear why this is here as well. SQLite actually has an implementation of Sleep
		// using WaitForSingleObject for some reason. Perhaps it is to be able to cancel long sleeps
		// when SQLite is terminated?
		SET_SYSCALL(vfs, "Sleep", wrapSleep);

		// Note: LockFile and UnlockFile are API functions, but they are not used on Windows NT,
		// where LockFileEx and UnlockFileEx are available.
		SET_SYSCALL(vfs, "LockFileEx", wrapLockFileEx);
		SET_SYSCALL(vfs, "UnlockFileEx", wrapUnlockFileEx);

		SET_SYSCALL(vfs, "CreateFileA", wrapCreateFileA);
		SET_SYSCALL(vfs, "CreateFileW", wrapCreateFileW);
		SET_SYSCALL(vfs, "ReadFile", wrapReadFile);
		SET_SYSCALL(vfs, "WriteFile", wrapWriteFile);

		// Note: WaitForSingleObject[,Ex] are API functions, but they only seem to be used to
		// implement Sleep.

		return vfs;
	}

#elif defined(LINUX_IO_URING)

	// Submit 'r' and wait for it. io_uring reports failures as -errno; convert those to the libc
	// convention of returning -1 with errno set. Non-negative results are passed through.
	static int submitWithErrno(os::IORequest &r) {
		const int rc = r.submit();
		if (rc >= 0)
			return rc;

		errno = -rc;
		return -1;
	}

	static int wrapOpen(const char *file, int flags, int mode) {
		os::IORequest r(os::Handle(AT_FDCWD), os::Thread::current());
		r.request.opcode = IORING_OP_OPENAT;
		r.request.addr = reinterpret_cast<size_t>(file);
		r.request.open_flags = flags;
		r.request.len = mode;
		return submitWithErrno(r);
	}

	static int wrapClose(int fd) {
		// Note: We could cancel outstanding requests. However, we assume that sqlite is fairly well-behaved.
		os::IORequest r(os::Handle(fd), os::Thread::current());
		r.request.opcode = IORING_OP_CLOSE;
		return submitWithErrno(r);
	}

	// Convert a 'struct statx' (as filled in by an io_uring STATX request) into the
	// 'struct stat' that SQLite expects from its stat/fstat/lstat syscalls.
	//
	// st_dev is the device that contains the file (stx_dev_*), not the device a special file
	// represents (stx_rdev_*, which is zero for regular files). SQLite's unix VFS identifies open
	// files by (st_dev, st_ino), so taking st_dev from rdev could make files on different
	// filesystems that share an inode number look like the same file.
	static void toStat(struct statx *in, struct stat *out) {
		out->st_dev = makedev(in->stx_dev_major, in->stx_dev_minor);
		out->st_ino = in->stx_ino;
		out->st_mode = in->stx_mode;
		out->st_nlink = in->stx_nlink;
		out->st_uid = in->stx_uid;
		out->st_gid = in->stx_gid;
		out->st_rdev = makedev(in->stx_rdev_major, in->stx_rdev_minor);
		out->st_size = in->stx_size;
		out->st_blksize = in->stx_blksize;
		out->st_blocks = in->stx_blocks;

		// Whole-second timestamps (on glibc these alias the tv_sec fields set below).
		out->st_atime = in->stx_atime.tv_sec;
		out->st_mtime = in->stx_mtime.tv_sec;
		out->st_ctime = in->stx_ctime.tv_sec;

		out->st_atim.tv_sec = in->stx_atime.tv_sec;
		out->st_atim.tv_nsec = in->stx_atime.tv_nsec;
		out->st_mtim.tv_sec = in->stx_mtime.tv_sec;
		out->st_mtim.tv_nsec = in->stx_mtime.tv_nsec;
		out->st_ctim.tv_sec = in->stx_ctime.tv_sec;
		out->st_ctim.tv_nsec = in->stx_ctime.tv_nsec;
	}

	static int xStat(int fd, const char *file, int flags, struct stat *out) {
		struct statx xbuffer;

		os::IORequest r(os::Handle(fd), os::Thread::current());
		r.request.opcode = IORING_OP_STATX;
		r.request.addr = reinterpret_cast<size_t>(file);
		r.request.statx_flags = AT_STATX_SYNC_AS_STAT | flags;
		r.request.len = STATX_BASIC_STATS;
		r.request.off = reinterpret_cast<size_t>(&xbuffer);

		int result = submitWithErrno(r);
		if (result < 0)
			return result;

		toStat(&xbuffer, out);
		return result;
	}

	static int wrapStat(const char *file, struct stat *out) {
		// stat(): 'file' is resolved against the working directory, and symlinks are followed.
		const int noFlags = 0;
		return xStat(AT_FDCWD, file, noFlags, out);
	}

	static int wrapFstat(int fd, struct stat *out) {
		// fstat(): an empty path plus AT_EMPTY_PATH makes statx describe 'fd' itself.
		const char *emptyPath = "";
		return xStat(fd, emptyPath, AT_EMPTY_PATH, out);
	}

	static int wrapLstat(const char *file, struct stat *out) {
		// lstat(): like stat(), but a trailing symlink is described rather than followed.
		const int noFollow = AT_SYMLINK_NOFOLLOW;
		return xStat(AT_FDCWD, file, noFollow, out);
	}

#ifdef IORING_OP_FTRUNCATE
	// ftruncate(fd, length) through io_uring; the new length is passed in the 'off' field.
	// NOTE(review): io_uring opcodes are normally enumerators in <linux/io_uring.h>, not macros,
	// in which case this #ifdef is never true and SQLite keeps using the plain ftruncate.
	// Confirm whether that is intended (older kernels do not support this opcode at all).
	static int wrapFtruncate(int fd, off_t length) {
		os::IORequest req(os::Handle(fd), os::Thread::current());
		req.request.opcode = IORING_OP_FTRUNCATE;
		req.request.off = length;
		return submitWithErrno(req);
	}
#endif

	// read() replacement: an io_uring READ at the file's current position (off = -1).
	// Returns the number of bytes read, or -1 with errno set.
	// Returns ssize_t to match read() and the 'ssize_t(*)(int, void *, size_t)' type SQLite
	// calls this syscall through. An int return would be called through a mismatched function type.
	static ssize_t wrapRead(int fd, void *buffer, size_t size) {
		os::IORequest r(os::Handle(fd), os::Thread::current());
		r.request.opcode = IORING_OP_READ;
		r.request.addr = reinterpret_cast<size_t>(buffer);
		r.request.len = size;
		r.request.off = -1;
		return submitWithErrno(r);
	}

	// pread() replacement: an io_uring READ at the explicit 'offset'.
	// Returns the number of bytes read, or -1 with errno set.
	// Returns ssize_t to match pread() and the 'ssize_t(*)(int, void *, size_t, off_t)' type
	// SQLite calls this syscall through.
	static ssize_t wrapPread(int fd, void *buffer, size_t size, off_t offset) {
		os::IORequest r(os::Handle(fd), os::Thread::current());
		r.request.opcode = IORING_OP_READ;
		r.request.addr = reinterpret_cast<size_t>(buffer);
		r.request.len = size;
		r.request.off = offset;
		return submitWithErrno(r);
	}

	// write() replacement: an io_uring WRITE at the file's current position (off = -1).
	// Returns the number of bytes written, or -1 with errno set.
	// Returns ssize_t to match write() and the 'ssize_t(*)(int, const void *, size_t)' type
	// SQLite calls this syscall through.
	static ssize_t wrapWrite(int fd, const void *buffer, size_t size) {
		os::IORequest r(os::Handle(fd), os::Thread::current());
		r.request.opcode = IORING_OP_WRITE;
		r.request.addr = reinterpret_cast<size_t>(buffer);
		r.request.len = size;
		r.request.off = -1;
		return submitWithErrno(r);
	}

	// pwrite() replacement: an io_uring WRITE at the explicit 'offset'.
	// Returns the number of bytes written, or -1 with errno set.
	// Returns ssize_t to match pwrite() and the 'ssize_t(*)(int, const void *, size_t, off_t)'
	// type SQLite calls this syscall through.
	static ssize_t wrapPwrite(int fd, const void *buffer, size_t size, off_t offset) {
		os::IORequest r(os::Handle(fd), os::Thread::current());
		r.request.opcode = IORING_OP_WRITE;
		r.request.addr = reinterpret_cast<size_t>(buffer);
		r.request.len = size;
		r.request.off = offset;
		return submitWithErrno(r);
	}

	static int wrapUnlink(const char *file) {
		os::IORequest r(os::Handle(AT_FDCWD), os::Thread::current());
		r.request.opcode = IORING_OP_UNLINKAT;
		r.request.addr = reinterpret_cast<size_t>(file);
		r.request.unlink_flags = 0;
		return submitWithErrno(r);
	}

	static int wrapMkdir(const char *file, mode_t mode) {
		os::IORequest r(os::Handle(AT_FDCWD), os::Thread::current());
		r.request.opcode = IORING_OP_MKDIRAT;
		r.request.addr = reinterpret_cast<size_t>(file);
		r.request.len = mode;
		return submitWithErrno(r);
	}

	static int wrapRmdir(const char *file) {
		os::IORequest r(os::Handle(AT_FDCWD), os::Thread::current());
		r.request.opcode = IORING_OP_UNLINKAT;
		r.request.addr = reinterpret_cast<size_t>(file);
		r.request.unlink_flags = AT_REMOVEDIR;
		return submitWithErrno(r);
	}

	typedef int (*FsyncPtr)(int);

	// Find the C library's definition of 'name' ("fsync" or "fdatasync"), skipping 'self' (our
	// own override below). RTLD_NEXT searches the objects after this one, which finds the C
	// library when our definition takes precedence in the lookup order. RTLD_DEFAULT is the
	// fallback for when it does not. A bare RTLD_DEFAULT lookup can resolve to our own override
	// and recurse forever. Returns NULL if no other definition exists.
	static FsyncPtr findOriginalSync(const char *name, FsyncPtr self) {
		FsyncPtr found = reinterpret_cast<FsyncPtr>(dlsym(RTLD_NEXT, name));
		if (!found || found == self)
			found = reinterpret_cast<FsyncPtr>(dlsym(RTLD_DEFAULT, name));
		if (found == self)
			found = NULL;
		return found;
	}

	// fsync is not part of SQLite's overridable syscall table, so it is intercepted by defining
	// it with C linkage. On a thread known to os::Thread, the sync is performed through io_uring.
	// Otherwise, the call is forwarded to the C library; if that cannot be located, the call
	// fails with ENOSYS instead of dereferencing NULL.
	extern "C" int fsync(int fd) {
		os::Thread thread = os::Thread::current();
		if (thread != os::Thread::invalid) {
			os::IORequest r(os::Handle(fd), thread);
			r.request.opcode = IORING_OP_FSYNC;
			r.request.fsync_flags = 0;
			return submitWithErrno(r);
		}

		// Function-local static: looked up exactly once, with thread-safe initialization.
		static FsyncPtr original = findOriginalSync("fsync", &fsync);
		if (!original) {
			errno = ENOSYS;
			return -1;
		}
		return (*original)(fd);
	}

	// Same as fsync above, but for fdatasync (IORING_FSYNC_DATASYNC on the io_uring path).
	extern "C" int fdatasync(int fd) {
		os::Thread thread = os::Thread::current();
		if (thread != os::Thread::invalid) {
			os::IORequest r(os::Handle(fd), thread);
			r.request.opcode = IORING_OP_FSYNC;
			r.request.fsync_flags = IORING_FSYNC_DATASYNC;
			return submitWithErrno(r);
		}

		// Function-local static: looked up exactly once, with thread-safe initialization.
		static FsyncPtr original = findOriginalSync("fdatasync", &fdatasync);
		if (!original) {
			errno = ENOSYS;
			return -1;
		}
		return (*original)(fd);
	}

	static sqlite3_vfs *initializeAsync() {
		// Patch the syscalls of SQLite's default VFS (normally "unix") so that blocking file
		// operations are submitted through io_uring instead.
		sqlite3_vfs *unixVfs = sqlite3_vfs_find(NULL);

		// Opening and closing files.
		SET_SYSCALL(unixVfs, "open", wrapOpen);
		SET_SYSCALL(unixVfs, "close", wrapClose);

		// Reading and writing data.
		SET_SYSCALL(unixVfs, "read", wrapRead);
		SET_SYSCALL(unixVfs, "pread", wrapPread);
		SET_SYSCALL(unixVfs, "write", wrapWrite);
		SET_SYSCALL(unixVfs, "pwrite", wrapPwrite);

		// Creating and removing directory entries.
		SET_SYSCALL(unixVfs, "unlink", wrapUnlink);
		SET_SYSCALL(unixVfs, "mkdir", wrapMkdir);
		SET_SYSCALL(unixVfs, "rmdir", wrapRmdir);

		// File metadata.
		SET_SYSCALL(unixVfs, "stat", wrapStat);
		SET_SYSCALL(unixVfs, "fstat", wrapFstat);
		SET_SYSCALL(unixVfs, "lstat", wrapLstat);

#ifdef IORING_OP_FTRUNCATE
		SET_SYSCALL(unixVfs, "ftruncate", wrapFtruncate);
#endif

		// fsync/fdatasync are not in the syscall table; they are intercepted by the C-linkage
		// overrides defined above instead.

		// Note: fcntl could be intercepted as well to handle advisory file locking, but those
		// locks do not wait, so they cause no problems as plain syscalls.

		return unixVfs;
	}

#else

	static sqlite3_vfs *initializeAsync() {
		// No async backend on this platform (plain poll-based IO is not supported yet), so
		// SQLite's default VFS is returned unmodified.
		sqlite3_vfs *defaultVfs = sqlite3_vfs_find(NULL);
		return defaultVfs;
	}

#endif

	// Initialization state for sqlite3EnableAsync:
	// 0 = not initialized, 1 = initialization in progress, 2 = done.
	static size_t initialized = 0;

	// Configure SQLite for async IO: serialized threading mode, platform-specific syscall
	// replacements, and a UThread-aware xSleep. May be called from several threads. Exactly one
	// performs the initialization while the others spin until it has finished.
	// If initialization throws (e.g. SET_SYSCALL failing), the state is reset to 0 before the
	// exception propagates. Waiting threads then retry instead of spinning forever on state 1.
	void sqlite3EnableAsync() {
		while (true) {
			size_t state = atomicCAS(initialized, 0, 1);
			if (state == 2)
				return;

			if (state == 0) {
				// We claimed the initialization.
				try {
					// Make sure SQLite uses locks internally. Otherwise, it might be confused due to async IO.
					sqlite3_config(SQLITE_CONFIG_SERIALIZED);

					sqlite3_vfs *vfs = initializeAsync();

					// Replace 'sleep' (if SQLite has a VFS to patch at all):
					if (vfs)
						vfs->xSleep = stormSleep;
				} catch (...) {
					// Release the claim so that nobody is left waiting for us forever.
					atomicWrite(initialized, 0);
					throw;
				}

				atomicWrite(initialized, 2);
				return;
			}

			// Another thread is initializing. A simple spinlock is fine, initialization is quite
			// cheap. Wait until it either finishes (2) or gives up (0), then check again.
			while (atomicRead(initialized) == 1) {
			}
		}
	}

}
