// #ifdef CPU_KVM
#include "cpu_kvm.hpp"

// Host-side pointer into the guest's MMU tables (lives in the hypervisor
// data region); set up in CPU::romLoaded() before the tables are filled in.
MmuTables* mmuTables = nullptr;

// ARMv6 MMU supports up to two levels of address lookup with 4KiB pages.
// The top level is called the level 1 table. It contains 4096 entries of 4 bytes each (16KiB total).
// The bottom level is called level 2, which contains 256 entries of 4 bytes each (1KiB total).
// The level 1 table supports 3 kind of entries: Pages, Sections and Supersections each corresponding to a page size.
// Pages are for 4KiB pages, Sections are for 1MiB pages and Supersections are for 16MiB pages.

// Sections and supersections don't use the level 2 table at all.
// This is because with a 32 bit vaddr and 4 KiB pages, the offset is 12 bits,
// the level 2 index is 8 bits and the level 1 index is 12 bits -> 12 + 8 + 12 = 32 for the vaddr
// However for sections, the offset is 20 bits, so you can only use
// the level 1 table (up to 4096 entries) because 20 for offset + 12 for level 1 -> 20 + 12 = 32 for the vaddr
// For supersections, you need a 24 bit offset, so the level 1 table actually has up to 256 entries because
// you're left with 8 bits -> 24 + 8 = 32 for the vaddr

// Level 2 entries
// Bits:  31-12  11  10  9      8-6       5-4   3  2  1  0
// Value: BADDR  nG  S   APX    TEX[2:0]  AP    C  B  1  XN

// Access permission table:
/*
    APX	AP	Privileged	Unprivileged	Description
    0	00	No access	No access	    Permission fault
    0	01	Read/Write	No access	    Privileged Access only
    0	10	Read/Write	Read	        No user-mode write
    0	11	Read/Write	Read/Write	    Full access
    1	00	-	        -	            Reserved
    1	01	Read	    No access	    Privileged Read only
    1	10	Read	    Read	        Read only
    1	11	-	        -	            Reserved
*/

// Access-permission bits of a level 2 entry (see the table above):
// APX is bit 9, AP[1:0] are bits 5-4.
constexpr u32 APX = 1 << 9;
constexpr u32 AP0 = 1 << 4;
constexpr u32 AP1 = 1 << 5;

// Flag bits for a level 2 (4KiB small page) descriptor, matching the
// bit layout documented above.
enum Level2Flags : u32
{
    Level2Flags_ExecuteNever = 1 << 0,  // XN: instruction fetches from this page fault
    Level2Flags_Bufferable = 1 << 2,    // B bit
    Level2Flags_Cacheable = 1 << 3,     // C bit
    Level2Flags_Shared = 1 << 10,       // S bit: page is shared
    Level2Flags_AP_NoUserModeWrite = AP1,   // APX=0 AP=10: privileged RW, user read-only
    Level2Flags_AP_FullAccess = AP1 | AP0,  // APX=0 AP=11: RW for privileged and user
};

// Generated by passing the following code to godbolt:
// Thanks libn3ds
/*
    // FCSE PID Register (FCSE PID = 0)
	// Note: This must be 0 before disabling the MMU otherwise UB
    __asm__ volatile ("mcr p15, 0, %0, c13, c0, 0" : : "r"(0));
    
    // Context ID Register (ASID = 0, PROCID = 0)
    __asm__ volatile ("mcr p15, 0, %0, c13, c0, 1" : : "r"(0));

    // TTBR0 = table address | shared page table walk, outer cacheable write-through, no allocate on write
    uint32_t ttbr0 = mmuTableAddress | 0x12;
    __asm__ volatile ("mcr p15, 0, %0, c2, c0, 0" : : "r" (ttbr0) : "memory");

    // Use the 16 KiB L1 table only
    __asm__ volatile ("mcr p15, 0, %0, c2, c0, 2" : : "r"(0));

    // Domain 0 = client, remaining domains all = no access
    __asm__ volatile("mcr p15, 0, %0, c3, c0, 0" : : "r"(1));

    uint32_t* d = (uint32_t*)hypervisorCodeAddress;
    *d = hypervisorCodeAddress;
*/
// Hand-assembled ARM machine code for the C snippet above (little-endian).
// The final 4 bytes are the literal pool word the `ldr r2, [pc, #0x14]`
// loads into TTBR0: mmuTableAddress | 0x12 (shared table walk, outer
// write-through cacheable — see the snippet's ttbr0 line).
constexpr u8 mmuCodeBefore[] = {
    0x00, 0x30, 0xb0, 0xe3, // movs r3, #0
    0x10, 0x3f, 0x0d, 0xee, // mcr  p15, #0, r3, c13, c0, #0
    0x30, 0x3f, 0x0d, 0xee, // mcr  p15, #0, r3, c13, c0, #1
    0x14, 0x20, 0x9f, 0xe5, // ldr  r2, [pc, #0x14]
    0x10, 0x2f, 0x02, 0xee, // mcr  p15, #0, r2, c2, c0, #0
    0x50, 0x3f, 0x02, 0xee, // mcr  p15, #0, r3, c2, c0, #2
    0x01, 0x30, 0xb0, 0xe3, // movs r3, #1
    0x10, 0x3f, 0x03, 0xee, // mcr  p15, #0, r3, c3, c0, #0
    0x0d, 0x32, 0xa0, 0xe3, // mov  r3, #-0x30000000 TODO: instead jump to exit code
    0x00, 0x30, 0x83, 0xe5, // str  r3, [r3]  (write triggers the VM exit)
    (mmuTableAddress & 0xFF) | 0x12, (mmuTableAddress >> 8) & 0xFF, (mmuTableAddress >> 16) & 0xFF, (mmuTableAddress >> 24) & 0xFF,
};

// Generated by passing the following code to godbolt:
// Thanks libn3ds
/*
    // Invalidate TLB
    __asm__ volatile("mcr p15, 0, %0, c8, c7, 0" : : "r"(0));
    __asm__ volatile("dsb");

    // Get ACR
    uint32_t reg;
    __asm__ volatile("mrc p15, 0, %0, c1, c0, 1" : "=r"(reg));
    // Enable Return stack, Dynamic branch prediction, Static branch prediction,
	// Instruction folding and SMP mode: the CPU is taking part in coherency
    reg |= 0x2F;
    __asm__ volatile("mcr p15, 0, %0, c1, c0, 1" : : "r"(reg));

    // Get CR
    __asm__ volatile("mrc p15, 0, %0, c1, c0, 0" : "=r"(reg));
    // Enable MMU, D-Cache, Program flow prediction,
	// I-Cache, high exception vectors, Unaligned data access,
	// subpage AP bits disabled
    reg |= 0xC03805;
    __asm__ volatile("mcr p15, 0, %0, c1, c0, 0" : : "r"(reg));

    // Invalidate both caches
    __asm__ volatile("mcr p15, 0, %0, c7, c7, 0" : : "r" (0) : "memory");
    __asm__ volatile("dsb");
    __asm__ volatile("isb");

    uint32_t* d = (uint32_t*)hypervisorCodeAddress;
    *d = hypervisorCodeAddress;
*/
// Hand-assembled ARM machine code for the C snippet above (little-endian):
// invalidates the TLB, sets the auxiliary/main control registers (enabling
// the MMU and caches), invalidates both caches, then exits the VM.
constexpr u8 mmuCodeAfter[] = {
    0x00, 0x00, 0xb0, 0xe3, // movs r0, #0
    0x17, 0x0f, 0x08, 0xee, // mcr  p15, #0, r0, c8, c7, #0
    0x4f, 0xf0, 0x7f, 0xf5, // dsb  sy
    0x30, 0x3f, 0x11, 0xee, // mrc  p15, #0, r3, c1, c0, #1
    0x2f, 0x30, 0x83, 0xe3, // orr  r3, r3, #0x2f
    0x30, 0x3f, 0x01, 0xee, // mcr  p15, #0, r3, c1, c0, #1
    0x10, 0x2f, 0x11, 0xee, // mrc  p15, #0, r2, c1, c0, #0
    0x05, 0x38, 0x03, 0xe3, // movw r3, #0x3805
    0xc0, 0x30, 0x40, 0xe3, // movt r3, #0xc0
    0x02, 0x30, 0x93, 0xe1, // orrs r3, r3, r2
    0x10, 0x3f, 0x01, 0xee, // mcr  p15, #0, r3, c1, c0, #0
    0x17, 0x0f, 0x07, 0xee, // mcr  p15, #0, r0, c7, c7, #0
    0x4f, 0xf0, 0x7f, 0xf5, // dsb  sy
    0x6f, 0xf0, 0x7f, 0xf5, // isb  sy
    0x0d, 0x32, 0xa0, 0xe3, // mov  r3, #-0x30000000 TODO: instead jump to exit code
    0x00, 0x30, 0x83, 0xe5, // str  r3, [r3]  (write triggers the VM exit)
};

// Store the CPU state and exit the VM, then return from SVC
// Generated from the following ARM32 assembly
/*
    push {r0}
    ldr r0, GuestStateAddr + 4
    stmfd r0, {r1-r12, sp, lr, pc}^
    pop {r0}

    push {r1}
    ldr r1, GuestStateAddr
    str r0, [r1]

    // Exit the VM
    ldr r1, CodeAddr
    str r1, [r1]

    pop {r1}

    CodeAddr:
        .word   0xD0000000
    GuestStateAddr:
        .word   0xE0200000
*/
// TODO(review): the assembly above has not been assembled into bytes yet —
// this array is intentionally empty until the handler is finished.
constexpr u8 svcHandlerCode[] = {
};

/// Level 1, page table entry
/// Bits:  31-10  9    8-5     4    3   2    1  0
/// Value: BADDR  IMP  Domain  SBZ  NS  PXN  0  1
/// We don't use domains, so we can set it to 0
/// Level 1, page table entry
/// Bits:  31-10  9    8-5     4    3   2    1  0
/// Value: BADDR  IMP  Domain  SBZ  NS  PXN  0  1
/// We don't use domains, so we can set it to 0
u32 pageTableEntry(u32 level2Address)
{
    // A level 2 table holds 256 entries of 4 bytes, so its base must sit
    // on a 1KiB boundary (low 10 bits clear).
    constexpr u32 alignmentMask = 0x3FF;
    if (level2Address & alignmentMask) {
        Helpers::panic("level2Address is not aligned to 1KiB");
    }

    // Bit 0 set marks this level 1 descriptor as a page table pointer.
    return level2Address | 1u;
}

// Build a level 2 small-page (4KiB) descriptor: page base address in the
// top 20 bits, bit 1 set to mark it as a small page, plus the given flags.
u32 level2Entry(u32 physicalAddress, Level2Flags flags)
{
    const u32 baseAddress = physicalAddress & 0xFFFFF000;
    const u32 smallPageBit = 0b10;
    return baseAddress | smallPageBit | static_cast<u32>(flags);
}

// Map `pageCount` 4KiB pages starting at virtualAddress to the physical
// pages starting at physicalAddress with the given level 2 flags, writing
// the level 2 entries and the level 1 slot for that virtual address.
//
// NOTE(review): all mappings share the single level2SectionTables array and
// only the level 1 slot of the first MiB is written, so a region crossing a
// 1MiB boundary is not mapped correctly — confirm callers never do that.
void mapPageTables(u32 virtualAddress, u32 physicalAddress, u8 pageCount, Level2Flags flags)
{
    // 4KiB alignment means the low 12 bits must be clear. The previous check
    // masked with 0xFFFFF000 (the HIGH bits), which would panic on every
    // address above 0xFFF instead of testing alignment.
    if ((virtualAddress & 0xFFF) != 0) {
        Helpers::panic("virtualAddress is not aligned to 4KiB");
    }

    if ((physicalAddress & 0xFFF) != 0) {
        Helpers::panic("physicalAddress is not aligned to 4KiB");
    }

    for (u32 i = 0; i < u32(pageCount) * 4096; i += 4096)
    {
        // Bits 19-12 of the vaddr select the entry within the level 2 table
        u8 level2Index = ((virtualAddress + i) >> 12) & 0xFF;
        mmuTables->level2SectionTables[level2Index] = level2Entry(physicalAddress + i, flags);
    }

    // Point the level 1 entry for this MiB at the guest-side address of the level 2 table
    u32 level2TableAddressVm = mmuTableAddress + offsetof(MmuTables, level2SectionTables);
    mmuTables->level1[virtualAddress >> 20] = pageTableEntry(level2TableAddressVm);
}

// Construct the KVM-backed CPU: keep a reference to guest memory and
// initialize the KVM environment with the memory and kernel objects.
CPU::CPU(Memory& mem, Kernel& kernel)
: mem(mem), env(mem, kernel)
{
}

// Called once a ROM is loaded: runs the pre-MMU guest setup code, fills in
// the MMU tables identity-mapping the executable's segments, then runs the
// post-MMU code which enables the MMU and caches inside the guest.
void CPU::romLoaded()
{
    NCCH* ncch = mem.getCXI();
    if (!ncch) {
        // TODO: what to do here?
        Helpers::panic("Alber has decided to panic!");
    }

    // Map the VM exit code which stores all registers to shared hypervisor memory
    // and exits the VM by writing to read-only memory.
    // We map it at the start of hypervisorCodeAddress.
    env.mapHypervisorCode(std::vector<u8>(vmExitCode, vmExitCode + sizeof(vmExitCode)), 0);

    // Run mmuCodeBefore in the guest: programs FCSE PID/Context ID, TTBR0
    // and the domain register before the tables are written.
    printf("Debug: Running pre mmu table code\n");
    env.mapHypervisorCode(std::vector<u8>(mmuCodeBefore, mmuCodeBefore + sizeof(mmuCodeBefore)), customEntryOffset);
    env.setPC(hypervisorCodeAddress + customEntryOffset);
    env.run();

    const auto& text = ncch->text;
    const auto& rodata = ncch->rodata;
    const auto& data = ncch->data;

    // The MMU tables live inside the hypervisor data region; write them from
    // the host through this pointer (guest sees them at mmuTableAddress).
    mmuTables = (MmuTables*)((uintptr_t)env.hypervisorDataRegion + mmuTableOffset);
    printf("Debug: level2sectionTables is at %p in host, %08x in guest\n", mmuTables->level2SectionTables, mmuTableAddress + offsetof(MmuTables, level2SectionTables));
    // .text: identity-mapped, privileged RW / user read (no user-mode write)
    mapPageTables(
        text.address,
        text.address,
        text.pageCount,
        (Level2Flags)(Level2Flags_Shared |
        Level2Flags_Bufferable |
        Level2Flags_Cacheable |
        Level2Flags_AP_NoUserModeWrite)
    );
    // .rodata: identity-mapped, read-only for user and never executable
    mapPageTables(
        rodata.address,
        rodata.address,
        rodata.pageCount,
        (Level2Flags)(Level2Flags_Shared |
        Level2Flags_Bufferable |
        Level2Flags_Cacheable |
        Level2Flags_AP_NoUserModeWrite |
        Level2Flags_ExecuteNever)
    );
    // .data: identity-mapped, read/write for everyone.
    // NOTE(review): .data is not marked ExecuteNever — confirm that's intended.
    mapPageTables(
        data.address,
        data.address,
        data.pageCount,
        (Level2Flags)(Level2Flags_Shared |
        Level2Flags_Bufferable |
        Level2Flags_Cacheable |
        Level2Flags_AP_FullAccess)
    );

    // Run mmuCodeAfter in the guest: invalidates TLB/caches and enables the MMU.
    printf("Debug: Running post mmu table code\n");
    env.mapHypervisorCode(std::vector<u8>(mmuCodeAfter, mmuCodeAfter + sizeof(mmuCodeAfter)), customEntryOffset);
    env.setPC(hypervisorCodeAddress + customEntryOffset);
    env.run();
    printf("Done\n");
}

// #endif