#pragma once

#include "pica_to_mtl.hpp"

using namespace PICA;

namespace Metal {

// Describes one depth/stencil configuration. DepthStencilCache packs these fields into a
// 64-bit cache key and decodes them when building the corresponding MTL::DepthStencilState.
struct DepthStencilHash {
    // Enables depth writes. When stencil testing is enabled it also gates stencil writes:
    // if false, the stencil write mask is forced to 0.
    bool depthStencilWrite;
    // Depth compare function, converted with toMTLCompareFunc().
    u8 depthFunc;
    // Packed stencil-test configuration as decoded by DepthStencilCache:
    //   bit 0      = stencil test enable
    //   bits 4-6   = stencil compare function
    //   bits 8-15  = stencil write mask
    //   bits 16-23 = reference value (signed 8-bit)
    //   bits 24-31 = stencil read mask
    // NOTE(review): presumably the raw PICA stencil-test register value; confirm against the caller.
    u32 stencilConfig;
    // Packed stencil operations: bits 0-2 = stencil-fail op, bits 4-6 = depth-fail op,
    // bits 8-10 = depth/stencil-pass op (each converted with toMTLStencilOperation()).
    u16 stencilOpConfig;
};

class DepthStencilCache {
public:
    DepthStencilCache() = default;

    ~DepthStencilCache() {
        clear();
    }

    void set(MTL::Device* dev) {
        device = dev;
    }

    MTL::DepthStencilState* get(DepthStencilHash hash) {
        u64 intHash = ((u64)hash.depthStencilWrite << 56) | ((u64)hash.depthFunc << 48) | ((u64)hash.stencilConfig << 16) | (u64)hash.stencilOpConfig;
        auto& depthStencilState = depthStencilCache[intHash];
        if (!depthStencilState) {
            MTL::DepthStencilDescriptor* desc = MTL::DepthStencilDescriptor::alloc()->init();
            desc->setDepthWriteEnabled(hash.depthStencilWrite);
            desc->setDepthCompareFunction(toMTLCompareFunc(hash.depthFunc));

            const bool stencilEnable = Helpers::getBit<0>(hash.stencilConfig);
            MTL::StencilDescriptor* stencilDesc = nullptr;
            if (stencilEnable) {
               	const u8 stencilFunc = Helpers::getBits<4, 3>(hash.stencilConfig);
               	const s8 reference = s8(Helpers::getBits<16, 8>(hash.stencilConfig)); // Signed reference value
               	const u8 stencilRefMask = Helpers::getBits<24, 8>(hash.stencilConfig);

               	const u32 stencilBufferMask = hash.depthStencilWrite ? Helpers::getBits<8, 8>(hash.stencilConfig) : 0;

               	const u8 stencilFailOp = Helpers::getBits<0, 3>(hash.stencilOpConfig);
               	const u8 depthFailOp = Helpers::getBits<4, 3>(hash.stencilOpConfig);
               	const u8 passOp = Helpers::getBits<8, 3>(hash.stencilOpConfig);

                stencilDesc = MTL::StencilDescriptor::alloc()->init();
                stencilDesc->setStencilFailureOperation(toMTLStencilOperation(stencilFailOp));
                stencilDesc->setDepthFailureOperation(toMTLStencilOperation(depthFailOp));
                stencilDesc->setDepthStencilPassOperation(toMTLStencilOperation(passOp));
                stencilDesc->setStencilCompareFunction(toMTLCompareFunc(stencilFunc));
                stencilDesc->setReadMask(stencilRefMask);
                stencilDesc->setWriteMask(stencilBufferMask);
                // TODO: Set reference value

                desc->setFrontFaceStencil(stencilDesc);
                desc->setBackFaceStencil(stencilDesc);
            }

            depthStencilState = device->newDepthStencilState(desc);

            desc->release();
            if (stencilDesc) {
                stencilDesc->release();
            }
        }

        return depthStencilState;
    }

    void clear() {
        for (auto& pair : depthStencilCache) {
            pair.second->release();
        }
        depthStencilCache.clear();
    }

private:
    std::unordered_map<u64, MTL::DepthStencilState*> depthStencilCache;

    MTL::Device* device;
};

} // namespace Metal