Shader decompiler: More control flow handling

This commit is contained in:
wheremyfoodat 2024-07-26 22:02:03 +03:00
parent ff3afd436a
commit 5eb15de431
2 changed files with 206 additions and 16 deletions

View file

@ -4,6 +4,7 @@
#include <tuple>
#include <map>
#include <vector>
#include <utility>
#include "PICA/shader.hpp"
#include "PICA/shader_gen_types.hpp"
@ -95,7 +96,8 @@ namespace PICA::ShaderGen {
Language language;
void compileInstruction(u32& pc, bool& finished);
void compileRange(const AddressRange& range);
// Compiles the address range "range" and returns the PC at which compilation stopped, paired with whether we're "finished" with the program (i.e. an END instruction was compiled)
std::pair<u32, bool> compileRange(const AddressRange& range);
void callFunction(const Function& function);
const Function* findFunction(const AddressRange& range);
@ -105,6 +107,7 @@ namespace PICA::ShaderGen {
std::string getDest(u32 dest) const;
std::string getSwizzlePattern(u32 swizzle) const;
std::string getDestSwizzle(u32 destinationMask) const;
const char* getCondition(u32 cond, u32 refX, u32 refY);
void setDest(u32 operandDescriptor, const std::string& dest, const std::string& value);
// Returns if the instruction uses the typical register encodings most instructions use

View file

@ -2,6 +2,9 @@
#include <fmt/format.h>
#include <array>
#include <cassert>
#include "config.hpp"
using namespace PICA;
@ -20,6 +23,40 @@ void ControlFlow::analyze(const PICAShader& shader, u32 entrypoint) {
}
}
// Helpers for merging parallel/series exit methods from Citra
// Combines the exit modes of two branches that run as alternatives of one another.
// An Unknown side defers to the other side, two identical modes are kept as they are,
// and any other pairing is reported as Conditional.
static ExitMode exitParallel(ExitMode a, ExitMode b) {
	if (a == ExitMode::Unknown) return b;
	if (b == ExitMode::Unknown || a == b) return a;

	return ExitMode::Conditional;
}
// Combines the exit modes of two pieces of code that run back to back ("a" first, then "b").
// Callers must not pass AlwaysEnd as "a"; this is only checked through an assert.
static ExitMode exitSeries(ExitMode a, ExitMode b) {
	assert(a != ExitMode::AlwaysEnd);

	// An unknown first part makes the whole sequence unknown
	if (a == ExitMode::Unknown) return a;
	// If the first part always returns, the second part decides the outcome
	if (a == ExitMode::AlwaysReturn) return b;

	// For any other "a": the sequence always ends if "b" is Unknown or AlwaysEnd, otherwise it is Conditional
	const bool secondPartEnds = (b == ExitMode::AlwaysEnd || b == ExitMode::Unknown);
	return secondPartEnds ? ExitMode::AlwaysEnd : ExitMode::Conditional;
}
// Walks the shader instructions starting at "start" until "end" (or PICAShader::maxInstructionCount) is reached and
// works out how that address range exits. Jump destinations that are found get inserted into "labels".
// The result for the range is stored in exitMap through the "it" iterator whenever setExitMode is used.
// NOTE(review): parts of this function are not visible in this view, so this only describes the statements shown here.
ExitMode ControlFlow::analyzeFunction(const PICAShader& shader, u32 start, u32 end, Function::Labels& labels) {
// Initialize exit mode to unknown by default, in order to detect things like unending loops
auto [it, inserted] = exitMap.emplace(AddressRange(start, end), ExitMode::Unknown);
@ -32,17 +69,63 @@ ExitMode ControlFlow::analyzeFunction(const PICAShader& shader, u32 start, u32 e
for (u32 pc = start; pc < PICAShader::maxInstructionCount && pc != end; pc++) {
const u32 instruction = shader.loadedShader[pc];
// The opcode is taken from the bits above bit 25 of the instruction word
const u32 opcode = instruction >> 26;
// Helper that writes "mode" into this range's exitMap entry and hands the stored value back, so it can be used in a return statement
auto setExitMode = [&it](ExitMode mode) {
it->second = mode;
return it->second;
};
switch (opcode) {
case ShaderOpcodes::JMPC: Helpers::panic("Unimplemented control flow operation (JMPC)"); break;
case ShaderOpcodes::JMPU: Helpers::panic("Unimplemented control flow operation (JMPU)"); break;
case ShaderOpcodes::IFU: Helpers::panic("Unimplemented control flow operation (IFU)"); break;
case ShaderOpcodes::IFC: Helpers::panic("Unimplemented control flow operation (IFC)"); break;
case ShaderOpcodes::JMPC:
case ShaderOpcodes::JMPU: {
// Presumably the jump destination field of the instruction - confirm against the getBits helper
const u32 dest = getBits<10, 12>(instruction);
// Register this jump address to our outLabels set
labels.insert(dest);
// This opens up 2 parallel paths of execution
// The taken path is analyzed from dest to end, the not-taken path from pc + 1 to dest
auto branchTakenExit = analyzeFunction(shader, dest, end, labels);
auto branchNotTakenExit = analyzeFunction(shader, pc + 1, dest, labels);
return setExitMode(exitParallel(branchTakenExit, branchNotTakenExit));
}
case ShaderOpcodes::IFU:
case ShaderOpcodes::IFC: {
// NOTE(review): this panic comes before the rest of the case. If Helpers::panic does not return, everything below in this case is
// unreachable - confirm this is intentional
Helpers::panic("IFC/IFU");
// "num" is the low 8 bits of the instruction; below, [dest, dest + num) is registered as the not-taken function when num is non-zero
const u32 num = instruction & 0xff;
const u32 dest = getBits<10, 12>(instruction);
// The taken branch covers pc + 1 up to dest
const Function* branchTakenFunc = addFunction(shader, pc + 1, dest);
// Check if analysis of the branch taken func failed and return unknown if it did
if (analysisFailed) {
return setExitMode(ExitMode::Unknown);
}
// Next analyze the not taken func
// With no not-taken function (num == 0), the not-taken path counts as AlwaysReturn
ExitMode branchNotTakenExitMode = ExitMode::AlwaysReturn;
if (num != 0) {
const Function* branchNotTakenFunc = addFunction(shader, dest, dest + num);
// Check if analysis failed and return unknown if it did
if (analysisFailed) {
return setExitMode(ExitMode::Unknown);
}
branchNotTakenExitMode = branchNotTakenFunc->exitMode;
}
auto parallel = exitParallel(branchTakenFunc->exitMode, branchNotTakenExitMode);
// Both branches of the if/else end, so there's nothing after the call
if (parallel == ExitMode::AlwaysEnd) {
return setExitMode(parallel);
} else {
// NOTE(review): this restarts the analysis at pc + 1, which is also the start of the branch-taken range registered above.
// Verify whether the code following the conditional should start at dest + num instead
ExitMode afterConditional = analyzeFunction(shader, pc + 1, end, labels);
ExitMode conditionalExitMode = exitSeries(parallel, afterConditional);
return setExitMode(conditionalExitMode);
}
// Both arms of the if/else above return, so this break is never reached
break;
}
case ShaderOpcodes::CALL: Helpers::panic("Unimplemented control flow operation (CALL)"); break;
case ShaderOpcodes::CALLC: Helpers::panic("Unimplemented control flow operation (CALLC)"); break;
case ShaderOpcodes::CALLU: Helpers::panic("Unimplemented control flow operation (CALLU)"); break;
case ShaderOpcodes::LOOP: Helpers::panic("Unimplemented control flow operation (LOOP)"); break;
// NOTE(review): two END labels are shown back to back here; both record AlwaysEnd for this range and return it
case ShaderOpcodes::END: it->second = ExitMode::AlwaysEnd; return it->second;
case ShaderOpcodes::END: return setExitMode(ExitMode::AlwaysEnd);
default: break;
}
@ -52,7 +135,7 @@ ExitMode ControlFlow::analyzeFunction(const PICAShader& shader, u32 start, u32 e
// Returned without going through setExitMode, so the exitMap entry is not updated by this statement
return ExitMode::AlwaysReturn;
}
// Compiles the instructions of "range" one at a time through compileInstruction.
// Compilation starts at range.start and stops once the PC reaches the end of the range or compileInstruction sets "finished".
// If range.end is below range.start, PICAShader::maxInstructionCount is used as the end instead.
// The pair-returning version hands back the PC at which compilation stopped together with the "finished" flag.
void ShaderDecompiler::compileRange(const AddressRange& range) {
std::pair<u32, bool> ShaderDecompiler::compileRange(const AddressRange& range) {
u32 pc = range.start;
const u32 end = range.end >= range.start ? range.end : PICAShader::maxInstructionCount;
bool finished = false;
@ -60,6 +143,8 @@ void ShaderDecompiler::compileRange(const AddressRange& range) {
// pc and finished are both passed by reference, so compileInstruction is what advances the loop
while (pc < end && !finished) {
compileInstruction(pc, finished);
}
return std::make_pair(pc, finished);
}
const Function* ShaderDecompiler::findFunction(const AddressRange& range) {
@ -84,6 +169,7 @@ void ShaderDecompiler::writeAttributes() {
vec4 tmp_regs[16];
vec4 out_regs[8];
vec4 dummy_vec = vec4(0.0);
bvec2 cmp_reg = bvec2(false);
)";
}
@ -124,14 +210,45 @@ std::string ShaderDecompiler::decompile() {
callFunction(*findFunction(mainFunctionRange));
decompiledShader += "}\n";
for (auto& func : controlFlow.functions) {
if (func.outLabels.size() > 0) {
Helpers::panic("Function with out labels");
}
for (const Function& func : controlFlow.functions) {
if (func.outLabels.empty()) {
decompiledShader += fmt::format("void {}() {{\n", func.getIdentifier());
compileRange(AddressRange(func.start, func.end));
decompiledShader += "}\n";
} else {
auto labels = func.outLabels;
labels.insert(func.start);
decompiledShader += "void " + func.getIdentifier() + "() {\n";
compileRange(AddressRange(func.start, func.end));
decompiledShader += "}\n";
// If a function has jumps and "labels", this needs to be emulated using a switch-case, with the variable being switched on being the
// current PC
decompiledShader += fmt::format("void {}() {{\n", func.getIdentifier());
decompiledShader += fmt::format("uint pc = {}u;\n", func.start);
decompiledShader += "while(true){\nswitch(pc){\n";
for (u32 label : labels) {
decompiledShader += fmt::format("case {}u: {{", label);
// Fetch the next label whose address > label
auto it = labels.lower_bound(label + 1);
u32 next = (it == labels.end()) ? func.end : *it;
auto [endPC, finished] = compileRange(AddressRange(label, next));
if (endPC > next && !finished) {
labels.insert(endPC);
decompiledShader += fmt::format("pc = {}u; break;", endPC);
}
// Fallthrough to next label
decompiledShader += "}\n";
}
decompiledShader += "default: return;\n";
// Exit the switch and loop
decompiledShader += "} }\n";
// Exit the function
decompiledShader += "return;\n";
decompiledShader += "}\n";
}
}
return decompiledShader;
@ -272,6 +389,33 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
case ShaderOpcodes::DP3: setDest(operandDescriptor, dest, fmt::format("vec4(dot({}.xyz, {}.xyz))", src1, src2)); break;
case ShaderOpcodes::DP4: setDest(operandDescriptor, dest, fmt::format("vec4(dot({}, {}))", src1, src2)); break;
case ShaderOpcodes::RSQ: setDest(operandDescriptor, dest, fmt::format("vec4(inversesqrt({}.x))", src1)); break;
case ShaderOpcodes::RCP: setDest(operandDescriptor, dest, fmt::format("vec4(1.0 / {}.x)", src1)); break;
case ShaderOpcodes::CMP1:
case ShaderOpcodes::CMP2: {
static constexpr std::array<const char*, 8> operators = {
// The last 2 operators always return true and are handled specially
"==", "!=", "<", "<=", ">", ">=", "", "",
};
const u32 cmpY = getBits<21, 3>(instruction);
const u32 cmpX = getBits<24, 3>(instruction);
// Compare x first
if (cmpX >= 6) {
decompiledShader += "cmp_reg.x = true;\n";
} else {
decompiledShader += fmt::format("cmp_reg.x = {}.x {} {}.x;\n", src1, operators[cmpX], src2);
}
// Then compare Y
if (cmpY >= 6) {
decompiledShader += "cmp_reg.y = true;\n";
} else {
decompiledShader += fmt::format("cmp_reg.y = {}.y {} {}.y;\n", src1, operators[cmpY], src2);
}
break;
}
default: Helpers::panic("GLSL recompiler: Unknown common opcode: %X", opcode); break;
}
@ -315,7 +459,20 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
setDest(operandDescriptor, dest, src1 + " * " + src2 + " + " + src3);
} else {
switch (opcode) {
case ShaderOpcodes::END: finished = true; return;
case ShaderOpcodes::JMPC: {
const u32 dest = getBits<10, 12>(instruction);
const u32 condOp = getBits<22, 2>(instruction);
const uint refY = getBit<24>(instruction);
const uint refX = getBit<25>(instruction);
const char* condition = getCondition(condOp, refX, refY);
decompiledShader += fmt::format("if ({}) {{ pc = {}u; break; }}", condition, dest);
break;
}
case ShaderOpcodes::END:
decompiledShader += "return;\n";
finished = true;
return;
default: Helpers::panic("GLSL recompiler: Unknown opcode: %X", opcode); break;
}
}
@ -323,7 +480,6 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
pc++;
}
bool ShaderDecompiler::usesCommonEncoding(u32 instruction) const {
const u32 opcode = instruction >> 26;
switch (opcode) {
@ -360,3 +516,34 @@ std::string ShaderGen::decompileShader(PICAShader& shader, EmulatorConfig& confi
return decompiler.decompile();
}
// Returns the GLSL boolean expression over cmp_reg that corresponds to a flow control condition.
// Only the low 2 bits of "cond" are used. Going by the table below, they select between:
//   0: cmp_reg.x matches refX OR cmp_reg.y matches refY
//   1: cmp_reg.x matches refX AND cmp_reg.y matches refY
//   2: only cmp_reg.x is checked against refX
//   3: only cmp_reg.y is checked against refY
// refX and refY are not masked before being shifted into the table index.
const char* ShaderDecompiler::getCondition(u32 cond, u32 refX, u32 refY) {
	// One row per (refY, refX) pair, one column per condition type
	static constexpr std::array<const char*, 16> expressions = {
		// refY = 0, refX = 0
		"!all(cmp_reg)", "all(not(cmp_reg))", "!cmp_reg.x", "!cmp_reg.y",
		// refY = 0, refX = 1
		"cmp_reg.x || !cmp_reg.y", "cmp_reg.x && !cmp_reg.y", "cmp_reg.x", "!cmp_reg.y",
		// refY = 1, refX = 0
		"!cmp_reg.x || cmp_reg.y", "!cmp_reg.x && cmp_reg.y", "!cmp_reg.x", "cmp_reg.y",
		// refY = 1, refX = 1
		"any(cmp_reg)", "all(cmp_reg)", "cmp_reg.x", "cmp_reg.y",
	};

	// Table index layout: bit 3 = refY, bit 2 = refX, bits 1-0 = condition type
	const u32 index = (refY << 3) | (refX << 2) | (cond & 3u);
	return expressions[index];
}