Shader decompiler: Implement safe multiplication

This commit is contained in:
wheremyfoodat 2024-08-08 00:38:52 +03:00
parent 370aa8ee5c
commit c7371e3bf4

View file

@ -241,7 +241,7 @@ std::string ShaderDecompiler::decompile() {
decompiledShader += R"(
vec4 safe_mul(vec4 a, vec4 b) {
vec4 res = a * b;
return mix(res, mix(mix(vec4(0.0), res, isnan(rhs)), product, isnan(lhs)), isnan(res));
return mix(res, mix(mix(vec4(0.0), res, isnan(b)), res, isnan(a)), isnan(res));
}
)";
}
@ -423,12 +423,32 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
switch (opcode) {
case ShaderOpcodes::MOV: setDest(operandDescriptor, dest, src1); break;
case ShaderOpcodes::ADD: setDest(operandDescriptor, dest, fmt::format("{} + {}", src1, src2)); break;
case ShaderOpcodes::MUL: setDest(operandDescriptor, dest, fmt::format("{} * {}", src1, src2)); break;
case ShaderOpcodes::MUL:
if (!config.accurateShaderMul) {
setDest(operandDescriptor, dest, fmt::format("{} * {}", src1, src2));
} else {
setDest(operandDescriptor, dest, fmt::format("safe_mul({}, {})", src1, src2));
}
break;
case ShaderOpcodes::MAX: setDest(operandDescriptor, dest, fmt::format("max({}, {})", src1, src2)); break;
case ShaderOpcodes::MIN: setDest(operandDescriptor, dest, fmt::format("min({}, {})", src1, src2)); break;
case ShaderOpcodes::DP3: setDest(operandDescriptor, dest, fmt::format("vec4(dot({}.xyz, {}.xyz))", src1, src2)); break;
case ShaderOpcodes::DP4: setDest(operandDescriptor, dest, fmt::format("vec4(dot({}, {}))", src1, src2)); break;
case ShaderOpcodes::DP3:
if (!config.accurateShaderMul) {
setDest(operandDescriptor, dest, fmt::format("vec4(dot({}.xyz, {}.xyz))", src1, src2));
} else {
// A dot product between a and b is equivalent to the per-lane multiplication of a and b followed by a dot product with vec3(1.0)
setDest(operandDescriptor, dest, fmt::format("vec4(dot(safe_mul({}, {}).xyz, vec3(1.0)))", src1, src2));
}
break;
case ShaderOpcodes::DP4:
if (!config.accurateShaderMul) {
setDest(operandDescriptor, dest, fmt::format("vec4(dot({}, {}))", src1, src2));
} else {
// A dot product between a and b is equivalent to the per-lane multiplication of a and b followed by a dot product with vec4(1.0)
setDest(operandDescriptor, dest, fmt::format("vec4(dot(safe_mul({}, {}), vec4(1.0)))", src1, src2));
}
break;
case ShaderOpcodes::FLR: setDest(operandDescriptor, dest, fmt::format("floor({})", src1)); break;
case ShaderOpcodes::RSQ: setDest(operandDescriptor, dest, fmt::format("vec4(inversesqrt({}.x))", src1)); break;
case ShaderOpcodes::RCP: setDest(operandDescriptor, dest, fmt::format("vec4(1.0 / {}.x)", src1)); break;
@ -441,7 +461,13 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
case ShaderOpcodes::DPH:
case ShaderOpcodes::DPHI:
setDest(operandDescriptor, dest, fmt::format("vec4(dot(vec4({}.xyz, 1.0), {}))", src1, src2)); break;
if (!config.accurateShaderMul) {
setDest(operandDescriptor, dest, fmt::format("vec4(dot(vec4({}.xyz, 1.0), {}))", src1, src2));
} else {
// A dot product between a and b is equivalent to the per-lane multiplication of a and b followed by a dot product with vec4(1.0)
setDest(operandDescriptor, dest, fmt::format("vec4(dot(safe_mul(vec4({}.xyz, 1.0), {}), vec4(1.0)))", src1, src2));
}
break;
case ShaderOpcodes::CMP1:
case ShaderOpcodes::CMP2: {
@ -517,7 +543,11 @@ void ShaderDecompiler::compileInstruction(u32& pc, bool& finished) {
src3 += getSwizzlePattern(swizzle3);
std::string dest = getDest(destIndex);
setDest(operandDescriptor, dest, src1 + " * " + src2 + " + " + src3);
if (!config.accurateShaderMul) {
setDest(operandDescriptor, dest, fmt::format("{} * {} + {}", src1, src2, src3));
} else {
setDest(operandDescriptor, dest, fmt::format("safe_mul({}, {}) + {}", src1, src2, src3));
}
} else {
switch (opcode) {
case ShaderOpcodes::JMPC: {