]> granicus.if.org Git - llvm/commitdiff
fix fmls fp16
authorSebastian Pop <spop@amazon.com>
Tue, 8 Oct 2019 13:23:57 +0000 (13:23 +0000)
committerSebastian Pop <spop@amazon.com>
Tue, 8 Oct 2019 13:23:57 +0000 (13:23 +0000)
Tim Northover remarked that the added patterns for fmls fp16
produce wrong code in case the fsub instruction has a
multiplication as its first operand, i.e., all the patterns FMLSv*_OP1:

> define <8 x half> @test_FMLSv8f16_OP1(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
> ; CHECK-LABEL: test_FMLSv8f16_OP1:
> ; CHECK: fmls {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
> entry:
>
>   %mul = fmul fast <8 x half> %c, %b
>   %sub = fsub fast <8 x half> %mul, %a
>   ret <8 x half> %sub
> }
>
> This doesn't look right to me. The exact instruction produced is "fmls
> v0.8h, v2.8h, v1.8h", which I think calculates "v0 - v2*v1", but the
> IR is calculating "v2*v1-v0". The equivalent <4 x float> code also
> doesn't emit an fmls.

This patch generates an fmla and negates the value of the operand2 of the fsub.

Inspecting the pattern match, I found that there was another mistake in the
opcode to be selected: matching FMULv4*16 should generate FMLSv4*16
and not FMLSv2*32.

Tested on aarch64-linux with make check-all.

Differential Revision: https://reviews.llvm.org/D67990

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@374044 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/MachineCombinerPattern.h
lib/Target/AArch64/AArch64InstrInfo.cpp
test/CodeGen/AArch64/fp16-fmla.ll

index 31056c8fdf0f19adf5ae0c1e6b4052b46c6bbf7e..503227222207f7406632e2b07ce1fed9cc0859e6 100644 (file)
@@ -80,6 +80,7 @@ enum class MachineCombinerPattern {
   FMLAv4i32_indexed_OP2,
   FMLSv1i32_indexed_OP2,
   FMLSv1i64_indexed_OP2,
+  FMLSv4f16_OP1,
   FMLSv4f16_OP2,
   FMLSv8f16_OP1,
   FMLSv8f16_OP2,
@@ -87,6 +88,7 @@ enum class MachineCombinerPattern {
   FMLSv2f32_OP2,
   FMLSv2f64_OP1,
   FMLSv2f64_OP2,
+  FMLSv4i16_indexed_OP1,
   FMLSv4i16_indexed_OP2,
   FMLSv8i16_indexed_OP1,
   FMLSv8i16_indexed_OP2,
index 1cc3177b26a7f400b82354ecb855ca559cd45d0c..57782862967d7c48e5abd6f6ec4b9392edf57292 100644 (file)
@@ -3806,8 +3806,8 @@ static bool getFMAPatterns(MachineInstr &Root,
     Found |= Match(AArch64::FMULv4i16_indexed, 2, MCP::FMLSv4i16_indexed_OP2) ||
              Match(AArch64::FMULv4f16, 2, MCP::FMLSv4f16_OP2);
 
-    Found |= Match(AArch64::FMULv4i16_indexed, 1, MCP::FMLSv2i32_indexed_OP1) ||
-             Match(AArch64::FMULv4f16, 1, MCP::FMLSv2f32_OP1);
+    Found |= Match(AArch64::FMULv4i16_indexed, 1, MCP::FMLSv4i16_indexed_OP1) ||
+             Match(AArch64::FMULv4f16, 1, MCP::FMLSv4f16_OP1);
     break;
   case AArch64::FSUBv8f16:
     Found |= Match(AArch64::FMULv8i16_indexed, 2, MCP::FMLSv8i16_indexed_OP2) ||
@@ -3888,6 +3888,7 @@ bool AArch64InstrInfo::isThroughputPattern(
   case MachineCombinerPattern::FMLAv4f32_OP2:
   case MachineCombinerPattern::FMLAv4i32_indexed_OP1:
   case MachineCombinerPattern::FMLAv4i32_indexed_OP2:
+  case MachineCombinerPattern::FMLSv4i16_indexed_OP1:
   case MachineCombinerPattern::FMLSv4i16_indexed_OP2:
   case MachineCombinerPattern::FMLSv8i16_indexed_OP1:
   case MachineCombinerPattern::FMLSv8i16_indexed_OP2:
@@ -3895,6 +3896,7 @@ bool AArch64InstrInfo::isThroughputPattern(
   case MachineCombinerPattern::FMLSv1i64_indexed_OP2:
   case MachineCombinerPattern::FMLSv2i32_indexed_OP2:
   case MachineCombinerPattern::FMLSv2i64_indexed_OP2:
+  case MachineCombinerPattern::FMLSv4f16_OP1:
   case MachineCombinerPattern::FMLSv4f16_OP2:
   case MachineCombinerPattern::FMLSv8f16_OP1:
   case MachineCombinerPattern::FMLSv8f16_OP2:
@@ -4497,6 +4499,26 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
                            FMAInstKind::Indexed);
     break;
 
+  case MachineCombinerPattern::FMLSv4f16_OP1:
+  case MachineCombinerPattern::FMLSv4i16_indexed_OP1: {
+    RC = &AArch64::FPR64RegClass;
+    Register NewVR = MRI.createVirtualRegister(RC);
+    MachineInstrBuilder MIB1 =
+        BuildMI(MF, Root.getDebugLoc(), TII->get(AArch64::FNEGv4f16), NewVR)
+            .add(Root.getOperand(2));
+    InsInstrs.push_back(MIB1);
+    InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+    if (Pattern == MachineCombinerPattern::FMLSv4f16_OP1) {
+      Opc = AArch64::FMLAv4f16;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Accumulator, &NewVR);
+    } else {
+      Opc = AArch64::FMLAv4i16_indexed;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Indexed, &NewVR);
+    }
+    break;
+  }
   case MachineCombinerPattern::FMLSv4f16_OP2:
     RC = &AArch64::FPR64RegClass;
     Opc = AArch64::FMLSv4f16;
@@ -4525,18 +4547,25 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     break;
 
   case MachineCombinerPattern::FMLSv8f16_OP1:
+  case MachineCombinerPattern::FMLSv8i16_indexed_OP1: {
     RC = &AArch64::FPR128RegClass;
-    Opc = AArch64::FMLSv8f16;
-    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
-                           FMAInstKind::Accumulator);
-    break;
-  case MachineCombinerPattern::FMLSv8i16_indexed_OP1:
-    RC = &AArch64::FPR128RegClass;
-    Opc = AArch64::FMLSv8i16_indexed;
-    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
-                           FMAInstKind::Indexed);
+    Register NewVR = MRI.createVirtualRegister(RC);
+    MachineInstrBuilder MIB1 =
+        BuildMI(MF, Root.getDebugLoc(), TII->get(AArch64::FNEGv8f16), NewVR)
+            .add(Root.getOperand(2));
+    InsInstrs.push_back(MIB1);
+    InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+    if (Pattern == MachineCombinerPattern::FMLSv8f16_OP1) {
+      Opc = AArch64::FMLAv8f16;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Accumulator, &NewVR);
+    } else {
+      Opc = AArch64::FMLAv8i16_indexed;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Indexed, &NewVR);
+    }
     break;
-
+  }
   case MachineCombinerPattern::FMLSv8f16_OP2:
     RC = &AArch64::FPR128RegClass;
     Opc = AArch64::FMLSv8f16;
index 08228e25d4aad7be706e132706eed9a391906318..a81721afb8453cad7afaa2c5a465b5a097dc9d3e 100644 (file)
@@ -138,6 +138,16 @@ entry:
   ret <8 x half> %add
 }
 
+define <4 x half> @test_FMLSv4f16_OP1(<4 x half> %a, <4 x half> %b, <4 x half> %c) {
+; CHECK-LABEL: test_FMLSv4f16_OP1:
+; CHECK: fneg    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+; CHECK: fmla    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = fmul fast <4 x half> %c, %b
+  %sub = fsub fast <4 x half> %mul, %a
+  ret <4 x half> %sub
+}
+
 define <4 x half> @test_FMLSv4f16_OP2(<4 x half> %a, <4 x half> %b, <4 x half> %c) {
 ; CHECK-LABEL: test_FMLSv4f16_OP2:
 ; CHECK: fmls    {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
@@ -149,7 +159,8 @@ entry:
 
 define <8 x half> @test_FMLSv8f16_OP1(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: test_FMLSv8f16_OP1:
-; CHECK: fmls    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK: fneg    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK: fmla    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
 entry:
   %mul = fmul fast <8 x half> %c, %b
   %sub = fsub fast <8 x half> %mul, %a
@@ -185,7 +196,8 @@ define <8 x half> @test_FMLSv8i16_indexed_OP1(<8 x half> %a, <8 x i16> %b, <8 x
 ; CHECK: mul
 ; CHECK: fsub
 ; CHECK-FIXME: It should instead produce the following instruction:
-; CHECK-FIXME: fmls    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK-FIXME: fneg    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK-FIXME: fmla    {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
 entry:
   %mul = mul <8 x i16> %c, %b
   %m = bitcast <8 x i16> %mul to <8 x half>