diff --git a/src/coreclr/jit/hwintrinsiccodegenarm64.cpp b/src/coreclr/jit/hwintrinsiccodegenarm64.cpp index 61ed96590534f6..e7c5a098870bb9 100644 --- a/src/coreclr/jit/hwintrinsiccodegenarm64.cpp +++ b/src/coreclr/jit/hwintrinsiccodegenarm64.cpp @@ -2429,37 +2429,36 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) } case NI_Sve2_Scatter16BitNarrowingNonTemporal: - case NI_Sve2_Scatter16BitWithByteOffsetsNarrowingNonTemporal: case NI_Sve2_Scatter32BitNarrowingNonTemporal: - case NI_Sve2_Scatter32BitWithByteOffsetsNarrowingNonTemporal: - case NI_Sve2_Scatter8BitNarrowingNonTemporal: - case NI_Sve2_Scatter8BitWithByteOffsetsNarrowingNonTemporal: case NI_Sve2_ScatterNonTemporal: - case NI_Sve2_ScatterWithByteOffsetsNonTemporal: { if (!varTypeIsSIMD(intrin.op2->gtType)) { - // Scatter...(Vector mask, T* address, Vector offsets, Vector data) + // Scatter...(Vector mask, T* address, Vector indices, Vector data) assert(intrin.numOperands == 4); - // Calculate the byte offsets if using indices. + ssize_t shift = 0; + regNumber tempReg = internalRegisters.GetSingle(node, RBM_ALLFLOAT); + if (intrin.id == NI_Sve2_Scatter16BitNarrowingNonTemporal) { - GetEmitter()->emitIns_R_R_I(INS_sve_lsl, emitSize, op3Reg, op3Reg, 1, opt); + shift = 1; } else if (intrin.id == NI_Sve2_Scatter32BitNarrowingNonTemporal) { - GetEmitter()->emitIns_R_R_I(INS_sve_lsl, emitSize, op3Reg, op3Reg, 2, opt); + shift = 2; } - else if (intrin.id == NI_Sve2_ScatterNonTemporal) + else { - assert(emitActualTypeSize(intrin.baseType) == 8); - GetEmitter()->emitIns_R_R_I(INS_sve_lsl, emitSize, op3Reg, op3Reg, 3, opt); + assert(intrin.id == NI_Sve2_ScatterNonTemporal); + shift = 3; } - // op2Reg and op3Reg are swapped - GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op3Reg, op2Reg, opt); + // The SVE2 instructions only support byte offsets. Convert indices to bytes. + GetEmitter()->emitIns_R_R_I(INS_sve_lsl, emitSize, tempReg, op3Reg, shift, opt); + + GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, tempReg, op2Reg, opt); } else { @@ -2471,6 +2470,33 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) break; } + case NI_Sve2_Scatter8BitNarrowingNonTemporal: + if (!varTypeIsSIMD(intrin.op2->gtType)) + { + // Scatter...(Vector mask, T* address, Vector indices, Vector data) + assert(intrin.numOperands == 4); + GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op3Reg, op2Reg, opt); + } + else + { + // Scatter...(Vector mask, Vector addresses, Vector data) + assert(intrin.numOperands == 3); + GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op3Reg, op1Reg, op2Reg, REG_ZR, opt); + } + break; + + case NI_Sve2_Scatter16BitWithByteOffsetsNarrowingNonTemporal: + case NI_Sve2_Scatter32BitWithByteOffsetsNarrowingNonTemporal: + case NI_Sve2_Scatter8BitWithByteOffsetsNarrowingNonTemporal: + case NI_Sve2_ScatterWithByteOffsetsNonTemporal: + // Scatter...(Vector mask, T* address, Vector offsets, Vector data) + assert(!varTypeIsSIMD(intrin.op2->gtType)); + assert(intrin.numOperands == 4); + + // op2Reg and op3Reg are swapped + GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op3Reg, op2Reg, opt); + break; + case NI_Sve_StoreNarrowing: opt = emitter::optGetSveInsOpt(emitTypeSize(intrin.baseType)); GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt); diff --git a/src/coreclr/jit/lsraarm64.cpp b/src/coreclr/jit/lsraarm64.cpp index 89c0f3c69e982c..9e3979ac161201 100644 --- a/src/coreclr/jit/lsraarm64.cpp +++ b/src/coreclr/jit/lsraarm64.cpp @@ -1365,6 +1365,22 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou // Build any immediates BuildHWIntrinsicImmediate(intrinsicTree, intrin); + // Build any additional special cases + switch (intrin.id) + { + case NI_Sve2_Scatter16BitNarrowingNonTemporal: + case NI_Sve2_Scatter32BitNarrowingNonTemporal: + case NI_Sve2_ScatterNonTemporal: + if (!varTypeIsSIMD(intrin.op2->gtType)) + { + buildInternalFloatRegisterDefForNode(intrinsicTree, internalFloatRegCandidates()); + } + break; + + default: + break; + } + // Build all Operands for (size_t opNum = 1; opNum <= intrin.numOperands; opNum++) {