Implement PICA200 compliant arm64 MUL

Adds `emitSafeMUL` to implement a PICA200 compliant multiplication that
handles the special `0 * inf = 0` case.
This commit is contained in:
Wunkolo 2024-03-14 12:16:57 -07:00
parent e13fe49bbb
commit 2b34ef4a89
2 changed files with 46 additions and 11 deletions

View file

@ -42,6 +42,9 @@ class ShaderEmitter : private oaknut::CodeBlock, public oaknut::CodeGenerator {
oaknut::Label emitLog2Func();
oaknut::Label emitExp2Func();
// Emit a PICA200-compliant multiplication that handles "0 * inf = 0"
void emitSafeMUL(oaknut::QReg src1, oaknut::QReg src2, oaknut::QReg scratch0);
template <typename T>
T getLabelPointer(const oaknut::Label& label) {
auto pointer = reinterpret_cast<u8*>(oaknut::CodeBlock::ptr()) + label.offset();
@ -123,9 +126,7 @@ class ShaderEmitter : private oaknut::CodeBlock, public oaknut::CodeGenerator {
ShaderEmitter() : oaknut::CodeBlock(allocSize), oaknut::CodeGenerator(oaknut::CodeBlock::ptr()) {}
// PC must be a valid entrypoint here. It doesn't have that much overhead in this case, so we use std::array<>::at() to assert it does
InstructionCallback getInstructionCallback(u32 pc) {
return getLabelPointer<InstructionCallback>(instructionLabels.at(pc));
}
InstructionCallback getInstructionCallback(u32 pc) { return getLabelPointer<InstructionCallback>(instructionLabels.at(pc)); }
PrologueCallback getPrologueCallback() { return prologueCb; }
void compile(const PICAShader& shaderUnit);

View file

@ -7,6 +7,9 @@ using namespace Helpers;
using namespace oaknut;
using namespace oaknut::util;
// TODO: Expose safe/unsafe optimizations to the user
constexpr bool useSafeMUL = true;
// Similar to the x64 recompiler, we use an odd internal ABI, which abuses the fact that we'll very rarely be calling C++ functions
// So to avoid pushing and popping, we'll be making use of volatile registers as much as possible
static constexpr QReg scratch1 = Q0;
@ -474,14 +477,18 @@ void ShaderEmitter::recDP3(const PICAShader& shader, u32 instruction) {
const u32 dest = getBits<21, 5>(instruction);
const u32 writeMask = getBits<0, 4>(operandDescriptor);
// TODO: Safe multiplication equivalent (Multiplication is not IEEE compliant on the PICA)
loadRegister<1>(src1_vec, shader, src1, idx, operandDescriptor);
loadRegister<2>(src2_vec, shader, src2, 0, operandDescriptor);
// Set W component of src1 to 0.0, so that the w factor of the following dp4 will become 0, making it equivalent to a dp3
INS(src1_vec.Selem()[3], WZR);
// Now do a full DP4
FMUL(src1_vec.S4(), src1_vec.S4(), src2_vec.S4()); // Do a piecewise multiplication of the vectors first
// Do a piecewise multiplication of the vectors first
if constexpr (useSafeMUL) {
emitSafeMUL(src1_vec, src2_vec, scratch1);
} else {
FMUL(src1_vec.S4(), src1_vec.S4(), src2_vec.S4());
}
FADDP(src1_vec.S4(), src1_vec.S4(), src1_vec.S4()); // Now add the adjacent components together
FADDP(src1_vec.toS(), src1_vec.toD().S2()); // Again for the bottom 2 lanes. Now the bottom lane contains the dot product
@ -500,11 +507,15 @@ void ShaderEmitter::recDP4(const PICAShader& shader, u32 instruction) {
const u32 dest = getBits<21, 5>(instruction);
const u32 writeMask = getBits<0, 4>(operandDescriptor);
// TODO: Safe multiplication equivalent (Multiplication is not IEEE compliant on the PICA)
loadRegister<1>(src1_vec, shader, src1, idx, operandDescriptor);
loadRegister<2>(src2_vec, shader, src2, 0, operandDescriptor);
FMUL(src1_vec.S4(), src1_vec.S4(), src2_vec.S4()); // Do a piecewise multiplication of the vectors first
// Do a piecewise multiplication of the vectors first
if constexpr (useSafeMUL) {
emitSafeMUL(src1_vec, src2_vec, scratch1);
} else {
FMUL(src1_vec.S4(), src1_vec.S4(), src2_vec.S4());
}
FADDP(src1_vec.S4(), src1_vec.S4(), src1_vec.S4()); // Now add the adjacent components together
FADDP(src1_vec.toS(), src1_vec.toD().S2()); // Again for the bottom 2 lanes. Now the bottom lane contains the dot product
@ -515,6 +526,20 @@ void ShaderEmitter::recDP4(const PICAShader& shader, u32 instruction) {
storeRegister(src1_vec, shader, dest, operandDescriptor);
}
void ShaderEmitter::emitSafeMUL(oaknut::QReg src1, oaknut::QReg src2, oaknut::QReg scratch0) {
// 0 * inf and inf * 0 in the PICA should return 0 instead of NaN
// This can be done by checking for NaNs before and after a multiplication
// FMULX returns 2.0 in the case of 0.0 * inf or inf * 0.0
// Both a FMUL and FMULX are done and the results are compared to each other
// In the case that the results are diferent(a 0.0*inf happened), then
// a 0.0 is written
FMULX(scratch1.S4(), src1.S4(), src2.S4());
FMUL(src1.S4(), src1.S4(), src2.S4());
CMEQ(scratch1.S4(), scratch1.S4(), src1.S4());
AND(src1.B16(), src1.B16(), scratch1.B16());
}
void ShaderEmitter::recADD(const PICAShader& shader, u32 instruction) {
const u32 operandDescriptor = shader.operandDescriptors[instruction & 0x7f];
const u32 src1 = getBits<12, 7>(instruction);
@ -561,10 +586,15 @@ void ShaderEmitter::recMUL(const PICAShader& shader, u32 instruction) {
const u32 idx = getBits<19, 2>(instruction);
const u32 dest = getBits<21, 5>(instruction);
// TODO: Safe multiplication equivalent (Multiplication is not IEEE compliant on the PICA)
loadRegister<1>(src1_vec, shader, src1, idx, operandDescriptor);
loadRegister<2>(src2_vec, shader, src2, 0, operandDescriptor);
FMUL(src1_vec.S4(), src1_vec.S4(), src2_vec.S4());
if constexpr (useSafeMUL) {
emitSafeMUL(src1_vec, src2_vec, scratch1);
} else {
FMUL(src1_vec.S4(), src1_vec.S4(), src2_vec.S4());
}
storeRegister(src1_vec, shader, dest, operandDescriptor);
}
@ -632,8 +662,12 @@ void ShaderEmitter::recMAD(const PICAShader& shader, u32 instruction) {
loadRegister<2>(src2_vec, shader, src2, isMADI ? 0 : idx, operandDescriptor);
loadRegister<3>(src3_vec, shader, src3, isMADI ? idx : 0, operandDescriptor);
// TODO: Safe PICA multiplication
FMLA(src3_vec.S4(), src1_vec.S4(), src2_vec.S4());
if constexpr (useSafeMUL) {
emitSafeMUL(src1_vec, src2_vec, scratch1);
FADD(src3_vec.S4(), src3_vec.S4(), src1_vec.S4());
} else {
FMLA(src3_vec.S4(), src1_vec.S4(), src2_vec.S4());
}
storeRegister(src3_vec, shader, dest, operandDescriptor);
}