diff --git a/CMakeLists.txt b/CMakeLists.txt index bae21128..a0de7271 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -416,7 +416,7 @@ set(HEADER_FILES include/emulator.hpp include/helpers.hpp include/termcolor.hpp include/fs/archive_twl_sound.hpp include/fs/archive_card_spi.hpp include/services/ns.hpp include/audio/audio_device.hpp include/audio/audio_device_interface.hpp include/audio/libretro_audio_device.hpp include/services/ir/ir_types.hpp include/services/ir/ir_device.hpp include/services/ir/circlepad_pro.hpp include/services/service_intercept.hpp - include/screen_layout.hpp + include/screen_layout.hpp include/services/service_map.hpp ) if(IOS) diff --git a/include/lua_manager.hpp b/include/lua_manager.hpp index 7a79fa60..c16292d6 100644 --- a/include/lua_manager.hpp +++ b/include/lua_manager.hpp @@ -48,7 +48,8 @@ class LuaManager { } } - bool signalInterceptedService(const std::string& service, u32 function, u32 messagePointer); + bool signalInterceptedService(const std::string& service, u32 function, u32 messagePointer, int callbackRef); + void removeInterceptedService(const std::string& service, u32 function, int callbackRef); }; #else // Lua not enabled, Lua manager does nothing @@ -62,6 +63,7 @@ class LuaManager { void loadString(const std::string& code) {} void reset() {} void signalEvent(LuaEvent e) {} - bool signalInterceptedService(const std::string& service, u32 function, u32 messagePointer) { return false; } + bool signalInterceptedService(const std::string& service, u32 function, u32 messagePointer, int callbackRef) { return false; } + void removeInterceptedService(const std::string& service, u32 function, int callbackRef) {} }; #endif diff --git a/include/services/service_manager.hpp b/include/services/service_manager.hpp index 315e1163..f51ddf2d 100644 --- a/include/services/service_manager.hpp +++ b/include/services/service_manager.hpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include "kernel_types.hpp" #include "logger.hpp" @@ -94,8 +94,10 @@ class ServiceManager { // For example, if we want to intercept dsp::DSP ReadPipe (Header: 0x000E00C0), the "serviceName" field would be "dsp::DSP" // and the "function" field would be 0x000E00C0 LuaManager& lua; - std::unordered_set interceptedServices = {}; - // Calling std::unordered_set::size() compiles to a fairly non-trivial function call on Clang, so we store this + + // Map from service intercept entries to their corresponding Lua callbacks + std::unordered_map interceptedServices = {}; + // Calling std::unordered_map::size() compiles to a non-trivial function call on Clang, so we store this // separately and check it on service calls, for performance reasons bool haveServiceIntercepts = false; @@ -134,12 +136,23 @@ class ServiceManager { Y2RService& getY2R() { return y2r; } IRUserService& getIRUser() { return ir_user; } - void addServiceIntercept(const std::string& service, u32 function) { - interceptedServices.insert(InterceptedService(service, function)); + void addServiceIntercept(const std::string& service, u32 function, int callbackRef) { + auto success = interceptedServices.try_emplace(InterceptedService(service, function), callbackRef); + if (!success.second) { + // An intercept for this service function already exists + // Remove the old callback and set the new one + lua.removeInterceptedService(service, function, success.first->second); + success.first->second = callbackRef; + } + haveServiceIntercepts = true; } void clearServiceIntercepts() { + for (const auto& [interceptedService, callbackRef] : interceptedServices) { + lua.removeInterceptedService(interceptedService.serviceName, interceptedService.function, callbackRef); + } + interceptedServices.clear(); haveServiceIntercepts = false; } diff --git a/include/services/service_map.hpp b/include/services/service_map.hpp new file mode 100644 index 00000000..92f6a596 --- /dev/null +++ b/include/services/service_map.hpp @@ -0,0 +1,29 @@ +#pragma once +#include +#include +#include + +#include "handles.hpp" + +// Helpers for constructing std::maps to look up OS services. +// We want to be able to map both service names -> services (Used for OS emulation) +// And service handles -> services (For Lua service call intercepts) +using ServiceMapEntry = std::pair; + +// Comparator for constructing a name->handle service map +struct ServiceMapByNameComparator { + // The comparators must be transparent, as our search key is different from our set key + // Our set key is a ServiceMapEntry, while the comparator each time is either the name or the service handle + using is_transparent = std::true_type; + bool operator()(const ServiceMapEntry& lhs, std::string_view rhs) const { return lhs.first < rhs; } + bool operator()(std::string_view lhs, const ServiceMapEntry& rhs) const { return lhs < rhs.first; } + bool operator()(const ServiceMapEntry& lhs, const ServiceMapEntry& rhs) const { return lhs.first < rhs.first; } +}; + +// Comparator for constructing a handle->name service map +struct ServiceMapByHandleComparator { + using is_transparent = std::true_type; + bool operator()(const ServiceMapEntry& lhs, HorizonHandle rhs) const { return lhs.second < rhs; } + bool operator()(HorizonHandle lhs, const ServiceMapEntry& rhs) const { return lhs < rhs.second; } + bool operator()(const ServiceMapEntry& lhs, const ServiceMapEntry& rhs) const { return lhs.second < rhs.second; } +}; diff --git a/src/core/services/service_manager.cpp b/src/core/services/service_manager.cpp index 2c959493..bd4488e9 100644 --- a/src/core/services/service_manager.cpp +++ b/src/core/services/service_manager.cpp @@ -1,9 +1,10 @@ #include "services/service_manager.hpp" -#include +#include #include "ipc.hpp" #include "kernel.hpp" +#include "services/service_map.hpp" ServiceManager::ServiceManager( std::span regs, Memory& mem, GPU& gpu, u32& currentPID, Kernel& kernel, const EmulatorConfig& config, LuaManager& lua @@ -98,7 +99,7 @@ void ServiceManager::registerClient(u32 messagePointer) { } // clang-format off -static std::map serviceMap = { +static const ServiceMapEntry serviceMapArray[] = { { "ac:u", KernelHandles::AC }, { "ac:i", KernelHandles::AC }, { "act:a", KernelHandles::ACT }, @@ -148,6 +149,9 @@ static std::map serviceMap = { }; // clang-format on +static std::set serviceMapByName{std::begin(serviceMapArray), std::end(serviceMapArray)}; +static std::set serviceMapByHandle{std::begin(serviceMapArray), std::end(serviceMapArray)}; + // https://www.3dbrew.org/wiki/SRV:GetServiceHandle void ServiceManager::getServiceHandle(u32 messagePointer) { u32 nameLength = mem.read32(messagePointer + 12); @@ -158,7 +162,7 @@ void ServiceManager::getServiceHandle(u32 messagePointer) { log("srv::getServiceHandle (Service: %s, nameLength: %d, flags: %d)\n", service.c_str(), nameLength, flags); // Look up service handle in map, panic if it does not exist - if (auto search = serviceMap.find(service); search != serviceMap.end()) + if (auto search = serviceMapByName.find(service); search != serviceMapByName.end()) handle = search->second; else Helpers::panic("srv: GetServiceHandle with unknown service %s", service.c_str()); @@ -271,16 +275,13 @@ bool ServiceManager::checkForIntercept(u32 messagePointer, Handle handle) { // Check if there's a Lua handler for this function and call it const u32 function = mem.read32(messagePointer); - for (auto [serviceName, serviceHandle] : serviceMap) { - if (serviceHandle == handle) { - auto intercept = InterceptedService(std::string(serviceName), function); - if (interceptedServices.contains(intercept)) { - // If the Lua handler returns true, it means the service is handled entirely - // From Lua, and we shouldn't do anything else here. - return lua.signalInterceptedService(intercept.serviceName, function, messagePointer); - } + if (auto service_it = serviceMapByHandle.find(handle); service_it != serviceMapByHandle.end()) { + auto intercept = InterceptedService(service_it->first, function); - break; + if (auto intercept_it = interceptedServices.find(intercept); intercept_it != interceptedServices.end()) { + // If the Lua handler returns true, it means the service is handled entirely + // From Lua, and we shouldn't do anything else here. + return lua.signalInterceptedService(intercept.serviceName, function, messagePointer, intercept_it->second); } } diff --git a/src/lua.cpp b/src/lua.cpp index c23cfbc6..8d3f7980 100644 --- a/src/lua.cpp +++ b/src/lua.cpp @@ -101,14 +101,14 @@ void LuaManager::signalEventInternal(LuaEvent e) { lua_pcall(L, 1, 0, 0); } -// Calls the "interceptService" function, if it exists, when a service call is intercepted +// Calls the callback passed to the addServiceIntercept function when a service call is intercepted // It passes the service name, the function header, and a pointer to the call's TLS buffer as parameters -// interceptService is expected to return a bool, which indicates whether the C++ code should proceed to handle the service call +// The callback is expected to return a bool, indicating whether the C++ code should proceed to handle the service call // or if the Lua code handles it entirely. -// If the bool is true, the Lua code handles the service call entirely and the C++ code doesn't do anything extra -// Otherwise, then the C++ code calls its service call handling code as usual. -bool LuaManager::signalInterceptedService(const std::string& service, u32 function, u32 messagePointer) { - lua_getglobal(L, "interceptService"); +// If the bool is true, the Lua code handles the service call entirely and the C++ side doesn't do anything extra +// Otherwise, the C++ side calls its service call handling code as usual. +bool LuaManager::signalInterceptedService(const std::string& service, u32 function, u32 messagePointer, int callbackRef) { + lua_rawgeti(L, LUA_REGISTRYINDEX, callbackRef); lua_pushstring(L, service.c_str()); // Push service name lua_pushinteger(L, function); // Push function header lua_pushinteger(L, messagePointer); // Push pointer to TLS buffer @@ -129,6 +129,12 @@ bool LuaManager::signalInterceptedService(const std::string& service, u32 functi return ret; } +// Removes a reference from the callback value in the registry +// Prevents memory leaks, otherwise the function object would stay forever +void LuaManager::removeInterceptedService(const std::string& service, u32 function, int callbackRef) { + luaL_unref(L, LUA_REGISTRYINDEX, callbackRef); +} + void LuaManager::reset() { // Reset scripts haveScript = false; @@ -238,15 +244,20 @@ static int loadROMThunk(lua_State* L) { static int addServiceInterceptThunk(lua_State* L) { // Service name argument is invalid, report that loading failed and exit if (lua_type(L, 1) != LUA_TSTRING) { - lua_pushboolean(L, 0); - lua_error(L); - return 2; + return luaL_error(L, "Argument 1 (service name) is not a string"); } if (lua_type(L, 2) != LUA_TNUMBER) { - lua_pushboolean(L, 0); - lua_error(L); - return 2; + return luaL_error(L, "Argument 2 (function id) is not a number"); + } + + // Callback is not a function object directly, fail and exit + // Objects with a __call metamethod are not allowed (tables, userdata) + // Good: addServiceIntercept(serviceName, func, myLuaFunction) + // Good: addServiceIntercept(serviceName, func, function (service, func, buffer) ... end) + // Bad: addServiceIntercept(serviceName, func, obj:method) + if (lua_type(L, 3) != LUA_TFUNCTION) { + return luaL_error(L, "Argument 3 (callback) is not a function"); } // Get the name of the service we want to intercept, as well as the header of the function to intercept @@ -254,8 +265,13 @@ static int addServiceInterceptThunk(lua_State* L) { const char* const str = lua_tolstring(L, 1, &nameLength); const u32 function = (u32)lua_tointeger(L, 2); const auto serviceName = std::string(str, nameLength); - LuaManager::g_emulator->getServiceManager().addServiceIntercept(serviceName, function); - return 2; + + // Stores a reference to the callback function object in the registry for later use + // Must be freed with lua_unref later, in order to avoid memory leaks + lua_pushvalue(L, 3); + const int callbackRef = luaL_ref(L, LUA_REGISTRYINDEX); + LuaManager::g_emulator->getServiceManager().addServiceIntercept(serviceName, function, callbackRef); + return 0; } static int clearServiceInterceptsThunk(lua_State* L) { @@ -391,7 +407,7 @@ void LuaManager::initializeThunks() { disassembleARM = function(pc, instruction) return GLOBALS.__disassembleARM(pc, instruction) end, disassembleTeak = function(opcode, exp) return GLOBALS.__disassembleTeak(opcode, exp or 0) end, - addServiceIntercept = function(service, func) return GLOBALS.__addServiceIntercept(service, func) end, + addServiceIntercept = function(service, func, cb) return GLOBALS.__addServiceIntercept(service, func, cb) end, clearServiceIntercepts = function() return GLOBALS.__clearServiceIntercepts() end, Frame = __Frame,