From b6c72e72e47f59b8552f9295449c7328a99df53e Mon Sep 17 00:00:00 2001
From: Samuliak <samuliak77@gmail.com>
Date: Wed, 3 Jul 2024 17:45:55 +0200
Subject: [PATCH] use sampler states

---
 include/renderer_mtl/mtl_texture.hpp   |  1 +
 include/renderer_mtl/pica_to_mtl.hpp   | 16 ++++++++++++++++
 include/renderer_mtl/renderer_mtl.hpp  |  2 +-
 src/core/renderer_mtl/mtl_texture.cpp  | 21 ++++++++++++++++++++-
 src/core/renderer_mtl/renderer_mtl.cpp | 11 ++++++-----
 src/host_shaders/metal_shaders.metal   | 13 ++++++-------
 6 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/include/renderer_mtl/mtl_texture.hpp b/include/renderer_mtl/mtl_texture.hpp
index 44ec61fa..bedf97a4 100644
--- a/include/renderer_mtl/mtl_texture.hpp
+++ b/include/renderer_mtl/mtl_texture.hpp
@@ -26,6 +26,7 @@ struct Texture {
     Interval<u32> range;
 
     MTL::Texture* texture = nullptr;
+    MTL::SamplerState* sampler = nullptr;
 
     Texture() : valid(false) {}
 
diff --git a/include/renderer_mtl/pica_to_mtl.hpp b/include/renderer_mtl/pica_to_mtl.hpp
index f4437da2..81cd831a 100644
--- a/include/renderer_mtl/pica_to_mtl.hpp
+++ b/include/renderer_mtl/pica_to_mtl.hpp
@@ -110,4 +110,20 @@ inline MTL::PrimitiveType toMTLPrimitiveType(PrimType primType) {
     }
 }
 
+inline MTL::SamplerAddressMode toMTLSamplerAddressMode(u8 addrMode) {
+    switch (addrMode) {
+    case 0: return MTL::SamplerAddressModeClampToEdge;
+    case 1: return MTL::SamplerAddressModeClampToBorderColor;
+    case 2: return MTL::SamplerAddressModeRepeat;
+    case 3: return MTL::SamplerAddressModeMirrorRepeat;
+    case 4: return MTL::SamplerAddressModeClampToEdge;
+    case 5: return MTL::SamplerAddressModeClampToBorderColor;
+    case 6: return MTL::SamplerAddressModeRepeat;
+    case 7: return MTL::SamplerAddressModeRepeat;
+    default: panic("Unknown sampler address mode %u", addrMode);
+    }
+
+    return MTL::SamplerAddressModeClampToEdge;
+}
+
 } // namespace PICA
diff --git a/include/renderer_mtl/renderer_mtl.hpp b/include/renderer_mtl/renderer_mtl.hpp
index d332f539..48b978e8 100644
--- a/include/renderer_mtl/renderer_mtl.hpp
+++ b/include/renderer_mtl/renderer_mtl.hpp
@@ -97,7 +97,7 @@ class RendererMTL final : public Renderer {
 
 	std::optional<Metal::ColorRenderTarget> getColorRenderTarget(u32 addr, PICA::ColorFmt format, u32 width, u32 height, bool createIfnotFound = true);
 	Metal::DepthStencilRenderTarget& getDepthRenderTarget();
-	MTL::Texture* getTexture(Metal::Texture& tex);
+	Metal::Texture& getTexture(Metal::Texture& tex);
 	void setupTextureEnvState(MTL::RenderCommandEncoder* encoder);
 	void bindTexturesToSlots(MTL::RenderCommandEncoder* encoder);
 };
diff --git a/src/core/renderer_mtl/mtl_texture.cpp b/src/core/renderer_mtl/mtl_texture.cpp
index 8fb255dd..53e6806c 100644
--- a/src/core/renderer_mtl/mtl_texture.cpp
+++ b/src/core/renderer_mtl/mtl_texture.cpp
@@ -1,4 +1,5 @@
 #include "renderer_mtl/mtl_texture.hpp"
+#include "renderer_mtl/pica_to_mtl.hpp"
 #include "colour.hpp"
 #include <array>
 
@@ -23,7 +24,22 @@ void Texture::allocate() {
 void Texture::setNewConfig(u32 cfg) {
     config = cfg;
 
-    // TODO: implement this
+    if (sampler) {
+        sampler->release();
+    }
+
+    const auto magFilter = (cfg & 0x2) != 0 ? MTL::SamplerMinMagFilterLinear : MTL::SamplerMinMagFilterNearest;
+    const auto minFilter = (cfg & 0x4) != 0 ? MTL::SamplerMinMagFilterLinear : MTL::SamplerMinMagFilterNearest;
+    const auto wrapT = PICA::toMTLSamplerAddressMode(getBits<8, 3>(cfg));
+    const auto wrapS = PICA::toMTLSamplerAddressMode(getBits<12, 3>(cfg));
+
+    MTL::SamplerDescriptor* samplerDescriptor = MTL::SamplerDescriptor::alloc()->init();
+    samplerDescriptor->setMinFilter(minFilter);
+    samplerDescriptor->setMagFilter(magFilter);
+    samplerDescriptor->setSAddressMode(wrapS);
+    samplerDescriptor->setTAddressMode(wrapT);
+
+    sampler = device->newSamplerState(samplerDescriptor);
 }
 
 void Texture::free() {
@@ -32,6 +48,9 @@ void Texture::free() {
 	if (texture) {
 		texture->release();
 	}
+	if (sampler) {
+        sampler->release();
+    }
 }
 
 u64 Texture::sizeInBytes() {
diff --git a/src/core/renderer_mtl/renderer_mtl.cpp b/src/core/renderer_mtl/renderer_mtl.cpp
index b5efe9e2..c63efe2f 100644
--- a/src/core/renderer_mtl/renderer_mtl.cpp
+++ b/src/core/renderer_mtl/renderer_mtl.cpp
@@ -453,17 +453,17 @@ Metal::DepthStencilRenderTarget& RendererMTL::getDepthRenderTarget() {
 	}
 }
 
-MTL::Texture* RendererMTL::getTexture(Metal::Texture& tex) {
+Metal::Texture& RendererMTL::getTexture(Metal::Texture& tex) {
 	auto buffer = textureCache.find(tex);
 
 	if (buffer.has_value()) {
-		return buffer.value().get().texture;
+		return buffer.value().get();
 	} else {
 		const auto textureData = std::span{gpu.getPointerPhys<u8>(tex.location), tex.sizeInBytes()};  // Get pointer to the texture data in 3DS memory
 		Metal::Texture& newTex = textureCache.add(tex);
 		newTex.decodeTexture(textureData);
 
-		return newTex.texture;
+		return newTex;
 	}
 }
 
@@ -518,8 +518,9 @@ void RendererMTL::bindTexturesToSlots(MTL::RenderCommandEncoder* encoder) {
 
 		if (addr != 0) [[likely]] {
 			Metal::Texture targetTex(device, addr, static_cast<PICA::TextureFmt>(format), width, height, config);
-			MTL::Texture* tex = getTexture(targetTex);
-			encoder->setFragmentTexture(tex, i);
+			auto tex = getTexture(targetTex);
+			encoder->setFragmentTexture(tex.texture, i);
+			encoder->setFragmentSamplerState(tex.sampler ? tex.sampler : basicSampler, i);
 		} else {
 			// TODO: bind a dummy texture?
 		}
diff --git a/src/host_shaders/metal_shaders.metal b/src/host_shaders/metal_shaders.metal
index 6d667329..2947e827 100644
--- a/src/host_shaders/metal_shaders.metal
+++ b/src/host_shaders/metal_shaders.metal
@@ -247,10 +247,9 @@ struct FragTEV {
     }
 };
 
-fragment float4 fragmentDraw(DrawVertexOut in [[stage_in]], constant PicaRegs& picaRegs [[buffer(0)]], constant FragTEV& tev [[buffer(1)]], texture2d<float> tex0 [[texture(0)]], texture2d<float> tex1 [[texture(1)]], texture2d<float> tex2 [[texture(2)]]) {
-    // TODO: upload this as argument
-    sampler samplr;
-
+fragment float4 fragmentDraw(DrawVertexOut in [[stage_in]], constant PicaRegs& picaRegs [[buffer(0)]], constant FragTEV& tev [[buffer(1)]],
+                             texture2d<float> tex0 [[texture(0)]], texture2d<float> tex1 [[texture(1)]], texture2d<float> tex2 [[texture(2)]],
+                             sampler samplr0 [[sampler(0)]], sampler samplr1 [[sampler(1)]], sampler samplr2 [[sampler(2)]]) {
     Globals globals;
     globals.tevSources[0] = in.color;
     // TODO: uncomment
@@ -259,9 +258,9 @@ fragment float4 fragmentDraw(DrawVertexOut in [[stage_in]], constant PicaRegs& p
 	uint textureConfig = picaRegs.read(0x80u);
 	float2 texCoord2 = (textureConfig & (1u << 13)) != 0u ? in.texCoord1 : in.texCoord2;
 
-	if ((textureConfig & 1u) != 0u) globals.tevSources[3] = tex0.sample(samplr, in.texCoord0.xy);
-	if ((textureConfig & 2u) != 0u) globals.tevSources[4] = tex1.sample(samplr, in.texCoord1);
-	if ((textureConfig & 4u) != 0u) globals.tevSources[5] = tex2.sample(samplr, texCoord2);
+	if ((textureConfig & 1u) != 0u) globals.tevSources[3] = tex0.sample(samplr0, in.texCoord0.xy);
+	if ((textureConfig & 2u) != 0u) globals.tevSources[4] = tex1.sample(samplr1, in.texCoord1);
+	if ((textureConfig & 4u) != 0u) globals.tevSources[5] = tex2.sample(samplr2, texCoord2);
 	globals.tevSources[13] = float4(0.0);  // Previous buffer
 	globals.tevSources[15] = in.color;     // Previous combiner