From: Sebastian Pop Date: Tue, 8 Oct 2019 13:23:57 +0000 (+0000) Subject: fix fmls fp16 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=930fdcd02bb963cea10c9dc26ad0f61e2ba64820;p=llvm fix fmls fp16 Tim Northover remarked that the added patterns for fmls fp16 produce wrong code when the fsub instruction has a multiplication as its first operand, i.e., all the patterns FMLSv*_OP1: > define <8 x half> @test_FMLSv8f16_OP1(<8 x half> %a, <8 x half> %b, <8 x half> %c) { > ; CHECK-LABEL: test_FMLSv8f16_OP1: > ; CHECK: fmls {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h > entry: > > %mul = fmul fast <8 x half> %c, %b > %sub = fsub fast <8 x half> %mul, %a > ret <8 x half> %sub > } > > This doesn't look right to me. The exact instruction produced is "fmls > v0.8h, v2.8h, v1.8h", which I think calculates "v0 - v2*v1", but the > IR is calculating "v2*v1-v0". The equivalent <4 x float> code also > doesn't emit an fmls. This patch generates an fmla and negates the value of the second operand of the fsub. Inspecting the pattern match, I found that there was another mistake in the opcode to be selected: matching FMULv4*16 should generate FMLSv4*16 and not FMLSv2*32. Tested on aarch64-linux with make check-all. 
Differential Revision: https://reviews.llvm.org/D67990 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@374044 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/CodeGen/MachineCombinerPattern.h b/include/llvm/CodeGen/MachineCombinerPattern.h index 31056c8fdf0..50322722220 100644 --- a/include/llvm/CodeGen/MachineCombinerPattern.h +++ b/include/llvm/CodeGen/MachineCombinerPattern.h @@ -80,6 +80,7 @@ enum class MachineCombinerPattern { FMLAv4i32_indexed_OP2, FMLSv1i32_indexed_OP2, FMLSv1i64_indexed_OP2, + FMLSv4f16_OP1, FMLSv4f16_OP2, FMLSv8f16_OP1, FMLSv8f16_OP2, @@ -87,6 +88,7 @@ enum class MachineCombinerPattern { FMLSv2f32_OP2, FMLSv2f64_OP1, FMLSv2f64_OP2, + FMLSv4i16_indexed_OP1, FMLSv4i16_indexed_OP2, FMLSv8i16_indexed_OP1, FMLSv8i16_indexed_OP2, diff --git a/lib/Target/AArch64/AArch64InstrInfo.cpp b/lib/Target/AArch64/AArch64InstrInfo.cpp index 1cc3177b26a..57782862967 100644 --- a/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -3806,8 +3806,8 @@ static bool getFMAPatterns(MachineInstr &Root, Found |= Match(AArch64::FMULv4i16_indexed, 2, MCP::FMLSv4i16_indexed_OP2) || Match(AArch64::FMULv4f16, 2, MCP::FMLSv4f16_OP2); - Found |= Match(AArch64::FMULv4i16_indexed, 1, MCP::FMLSv2i32_indexed_OP1) || - Match(AArch64::FMULv4f16, 1, MCP::FMLSv2f32_OP1); + Found |= Match(AArch64::FMULv4i16_indexed, 1, MCP::FMLSv4i16_indexed_OP1) || + Match(AArch64::FMULv4f16, 1, MCP::FMLSv4f16_OP1); break; case AArch64::FSUBv8f16: Found |= Match(AArch64::FMULv8i16_indexed, 2, MCP::FMLSv8i16_indexed_OP2) || @@ -3888,6 +3888,7 @@ bool AArch64InstrInfo::isThroughputPattern( case MachineCombinerPattern::FMLAv4f32_OP2: case MachineCombinerPattern::FMLAv4i32_indexed_OP1: case MachineCombinerPattern::FMLAv4i32_indexed_OP2: + case MachineCombinerPattern::FMLSv4i16_indexed_OP1: case MachineCombinerPattern::FMLSv4i16_indexed_OP2: case MachineCombinerPattern::FMLSv8i16_indexed_OP1: case 
MachineCombinerPattern::FMLSv8i16_indexed_OP2: @@ -3895,6 +3896,7 @@ bool AArch64InstrInfo::isThroughputPattern( case MachineCombinerPattern::FMLSv1i64_indexed_OP2: case MachineCombinerPattern::FMLSv2i32_indexed_OP2: case MachineCombinerPattern::FMLSv2i64_indexed_OP2: + case MachineCombinerPattern::FMLSv4f16_OP1: case MachineCombinerPattern::FMLSv4f16_OP2: case MachineCombinerPattern::FMLSv8f16_OP1: case MachineCombinerPattern::FMLSv8f16_OP2: @@ -4497,6 +4499,26 @@ void AArch64InstrInfo::genAlternativeCodeSequence( FMAInstKind::Indexed); break; + case MachineCombinerPattern::FMLSv4f16_OP1: + case MachineCombinerPattern::FMLSv4i16_indexed_OP1: { + RC = &AArch64::FPR64RegClass; + Register NewVR = MRI.createVirtualRegister(RC); + MachineInstrBuilder MIB1 = + BuildMI(MF, Root.getDebugLoc(), TII->get(AArch64::FNEGv4f16), NewVR) + .add(Root.getOperand(2)); + InsInstrs.push_back(MIB1); + InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); + if (Pattern == MachineCombinerPattern::FMLSv4f16_OP1) { + Opc = AArch64::FMLAv4f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator, &NewVR); + } else { + Opc = AArch64::FMLAv4i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed, &NewVR); + } + break; + } case MachineCombinerPattern::FMLSv4f16_OP2: RC = &AArch64::FPR64RegClass; Opc = AArch64::FMLSv4f16; @@ -4525,18 +4547,25 @@ void AArch64InstrInfo::genAlternativeCodeSequence( break; case MachineCombinerPattern::FMLSv8f16_OP1: + case MachineCombinerPattern::FMLSv8i16_indexed_OP1: { RC = &AArch64::FPR128RegClass; - Opc = AArch64::FMLSv8f16; - MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, - FMAInstKind::Accumulator); - break; - case MachineCombinerPattern::FMLSv8i16_indexed_OP1: - RC = &AArch64::FPR128RegClass; - Opc = AArch64::FMLSv8i16_indexed; - MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, - FMAInstKind::Indexed); + Register NewVR = 
MRI.createVirtualRegister(RC); + MachineInstrBuilder MIB1 = + BuildMI(MF, Root.getDebugLoc(), TII->get(AArch64::FNEGv8f16), NewVR) + .add(Root.getOperand(2)); + InsInstrs.push_back(MIB1); + InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); + if (Pattern == MachineCombinerPattern::FMLSv8f16_OP1) { + Opc = AArch64::FMLAv8f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator, &NewVR); + } else { + Opc = AArch64::FMLAv8i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed, &NewVR); + } break; - + } case MachineCombinerPattern::FMLSv8f16_OP2: RC = &AArch64::FPR128RegClass; Opc = AArch64::FMLSv8f16; diff --git a/test/CodeGen/AArch64/fp16-fmla.ll b/test/CodeGen/AArch64/fp16-fmla.ll index 08228e25d4a..a81721afb84 100644 --- a/test/CodeGen/AArch64/fp16-fmla.ll +++ b/test/CodeGen/AArch64/fp16-fmla.ll @@ -138,6 +138,16 @@ entry: ret <8 x half> %add } +define <4 x half> @test_FMLSv4f16_OP1(<4 x half> %a, <4 x half> %b, <4 x half> %c) { +; CHECK-LABEL: test_FMLSv4f16_OP1: +; CHECK: fneg {{v[0-9]+}}.4h, {{v[0-9]+}}.4h +; CHECK: fmla {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h +entry: + %mul = fmul fast <4 x half> %c, %b + %sub = fsub fast <4 x half> %mul, %a + ret <4 x half> %sub +} + define <4 x half> @test_FMLSv4f16_OP2(<4 x half> %a, <4 x half> %b, <4 x half> %c) { ; CHECK-LABEL: test_FMLSv4f16_OP2: ; CHECK: fmls {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h @@ -149,7 +159,8 @@ entry: define <8 x half> @test_FMLSv8f16_OP1(<8 x half> %a, <8 x half> %b, <8 x half> %c) { ; CHECK-LABEL: test_FMLSv8f16_OP1: -; CHECK: fmls {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h +; CHECK: fneg {{v[0-9]+}}.8h, {{v[0-9]+}}.8h +; CHECK: fmla {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h entry: %mul = fmul fast <8 x half> %c, %b %sub = fsub fast <8 x half> %mul, %a @@ -185,7 +196,8 @@ define <8 x half> @test_FMLSv8i16_indexed_OP1(<8 x half> %a, <8 x i16> %b, <8 x ; CHECK: mul ; CHECK: fsub ; 
CHECK-FIXME: It should instead produce the following instruction: -; CHECK-FIXME: fmls {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h +; CHECK-FIXME: fneg {{v[0-9]+}}.8h, {{v[0-9]+}}.8h +; CHECK-FIXME: fmla {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h entry: %mul = mul <8 x i16> %c, %b %m = bitcast <8 x i16> %mul to <8 x half>