author    | Mark Kettenis <kettenis@cvs.openbsd.org> | 2023-11-19 16:42:06 +0000
committer | Mark Kettenis <kettenis@cvs.openbsd.org> | 2023-11-19 16:42:06 +0000
commit    | fbbaeee97ea3eeaef4c93581904230ddca260c76 (patch)
tree      | f6747b84b5387e4353ba0fb03d11474c85b115e6 /gnu/llvm
parent    | 146e625cc798fc52a10e9527e9f065896c7d93f6 (diff)
Disable LOAD_STACK_GUARD on OpenBSD/armv7. It seems the implementation
is incomplete, resulting in SIGSEGV with the OpenBSD default options.
ok deraadt@, jsg@
Diffstat (limited to 'gnu/llvm')
-rw-r--r-- | gnu/llvm/llvm/lib/Target/ARM/ARMISelLowering.cpp | 7520
1 file changed, 6014 insertions, 1506 deletions
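
The hunk that actually turns LOAD_STACK_GUARD off is not visible in the excerpt below — the diff is dominated by the surrounding LLVM import — but in the ARM backend this decision normally goes through the `useLoadStackGuardNode()` hook of `ARMTargetLowering`: when it returns false, stack-protector lowering falls back to the generic `__stack_chk_guard` load instead of the LOAD_STACK_GUARD pseudo. The following is a minimal sketch of what such a change could look like, assuming an OS check via `Triple::isOSOpenBSD()`; it is an illustration, not the committed hunk, and the non-OpenBSD fallback condition is a placeholder.

```cpp
// Sketch only: opt OpenBSD/armv7 out of the LOAD_STACK_GUARD pseudo so the
// stack protector uses the plain __stack_chk_guard load sequence instead.
// The OpenBSD check and the fallback condition are illustrative assumptions,
// not the exact code in the OpenBSD tree.
bool ARMTargetLowering::useLoadStackGuardNode() const {
  // The LOAD_STACK_GUARD lowering is reported incomplete on OpenBSD/armv7
  // and faults with the default stack-protector options, so disable it there.
  if (Subtarget->getTargetTriple().isOSOpenBSD())
    return false;
  return Subtarget->isTargetMachO(); // placeholder for the upstream default
}
```
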
diff --git a/gnu/llvm/llvm/lib/Target/ARM/ARMISelLowering.cpp b/gnu/llvm/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 66f3f418d06..200d450537b 100644
--- a/gnu/llvm/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/gnu/llvm/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -21,6 +21,7 @@
 #include "ARMRegisterInfo.h"
 #include "ARMSelectionDAGInfo.h"
 #include "ARMSubtarget.h"
+#include "ARMTargetTransformInfo.h"
 #include "MCTargetDesc/ARMAddressingModes.h"
 #include "MCTargetDesc/ARMBaseInfo.h"
 #include "Utils/ARMBaseInfo.h"
@@ -54,6 +55,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/RuntimeLibcalls.h"
 #include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/CodeGen/TargetLowering.h"
@@ -108,6 +110,7 @@
 #include <cstdlib>
 #include <iterator>
 #include <limits>
+#include <optional>
 #include <string>
 #include <tuple>
 #include <utility>
@@ -143,7 +146,7 @@ static cl::opt<unsigned> ConstpoolPromotionMaxTotal(
     cl::desc("Maximum size of ALL constants to promote into a constant pool"),
     cl::init(128));
 
-static cl::opt<unsigned>
+cl::opt<unsigned>
 MVEMaxSupportedInterleaveFactor("mve-max-interleave-factor", cl::Hidden,
   cl::desc("Maximum interleave factor for MVE VLDn to generate."),
   cl::init(2));
@@ -153,8 +156,7 @@ static const MCPhysReg GPRArgRegs[] = {
   ARM::R0, ARM::R1, ARM::R2, ARM::R3
 };
 
-void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT,
-                                       MVT PromotedBitwiseVT) {
+void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT) {
   if (VT != PromotedLdStVT) {
     setOperationAction(ISD::LOAD, VT, Promote);
     AddPromotedToType (ISD::LOAD, VT, PromotedLdStVT);
@@ -193,16 +195,6 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT,
     setOperationAction(ISD::SRL, VT, Custom);
   }
 
-  // Promote all bit-wise operations.
-  if (VT.isInteger() && VT != PromotedBitwiseVT) {
-    setOperationAction(ISD::AND, VT, Promote);
-    AddPromotedToType (ISD::AND, VT, PromotedBitwiseVT);
-    setOperationAction(ISD::OR, VT, Promote);
-    AddPromotedToType (ISD::OR, VT, PromotedBitwiseVT);
-    setOperationAction(ISD::XOR, VT, Promote);
-    AddPromotedToType (ISD::XOR, VT, PromotedBitwiseVT);
-  }
-
   // Neon does not support vector divide/remainder operations.
setOperationAction(ISD::SDIV, VT, Expand); setOperationAction(ISD::UDIV, VT, Expand); @@ -210,6 +202,8 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT, setOperationAction(ISD::SREM, VT, Expand); setOperationAction(ISD::UREM, VT, Expand); setOperationAction(ISD::FREM, VT, Expand); + setOperationAction(ISD::SDIVREM, VT, Expand); + setOperationAction(ISD::UDIVREM, VT, Expand); if (!VT.isFloatingPoint() && VT != MVT::v2i64 && VT != MVT::v1i64) @@ -222,12 +216,12 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT, void ARMTargetLowering::addDRTypeForNEON(MVT VT) { addRegisterClass(VT, &ARM::DPRRegClass); - addTypeForNEON(VT, MVT::f64, MVT::v2i32); + addTypeForNEON(VT, MVT::f64); } void ARMTargetLowering::addQRTypeForNEON(MVT VT) { addRegisterClass(VT, &ARM::DPairRegClass); - addTypeForNEON(VT, MVT::v2f64, MVT::v4i32); + addTypeForNEON(VT, MVT::v2f64); } void ARMTargetLowering::setAllExpand(MVT VT) { @@ -278,13 +272,23 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setOperationAction(ISD::UADDSAT, VT, Legal); setOperationAction(ISD::SSUBSAT, VT, Legal); setOperationAction(ISD::USUBSAT, VT, Legal); + setOperationAction(ISD::ABDS, VT, Legal); + setOperationAction(ISD::ABDU, VT, Legal); + setOperationAction(ISD::AVGFLOORS, VT, Legal); + setOperationAction(ISD::AVGFLOORU, VT, Legal); + setOperationAction(ISD::AVGCEILS, VT, Legal); + setOperationAction(ISD::AVGCEILU, VT, Legal); // No native support for these. setOperationAction(ISD::UDIV, VT, Expand); setOperationAction(ISD::SDIV, VT, Expand); setOperationAction(ISD::UREM, VT, Expand); setOperationAction(ISD::SREM, VT, Expand); + setOperationAction(ISD::UDIVREM, VT, Expand); + setOperationAction(ISD::SDIVREM, VT, Expand); setOperationAction(ISD::CTPOP, VT, Expand); + setOperationAction(ISD::SELECT, VT, Expand); + setOperationAction(ISD::SELECT_CC, VT, Expand); // Vector reductions setOperationAction(ISD::VECREDUCE_ADD, VT, Legal); @@ -292,12 +296,19 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setOperationAction(ISD::VECREDUCE_UMAX, VT, Legal); setOperationAction(ISD::VECREDUCE_SMIN, VT, Legal); setOperationAction(ISD::VECREDUCE_UMIN, VT, Legal); + setOperationAction(ISD::VECREDUCE_MUL, VT, Custom); + setOperationAction(ISD::VECREDUCE_AND, VT, Custom); + setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); if (!HasMVEFP) { setOperationAction(ISD::SINT_TO_FP, VT, Expand); setOperationAction(ISD::UINT_TO_FP, VT, Expand); setOperationAction(ISD::FP_TO_SINT, VT, Expand); setOperationAction(ISD::FP_TO_UINT, VT, Expand); + } else { + setOperationAction(ISD::FP_TO_SINT_SAT, VT, Custom); + setOperationAction(ISD::FP_TO_UINT_SAT, VT, Custom); } // Pre and Post inc are supported on loads and stores @@ -327,6 +338,8 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::MLOAD, VT, Custom); setOperationAction(ISD::MSTORE, VT, Legal); + setOperationAction(ISD::SELECT, VT, Expand); + setOperationAction(ISD::SELECT_CC, VT, Expand); // Pre and Post inc are supported on loads and stores for (unsigned im = (unsigned)ISD::PRE_INC; @@ -341,6 +354,10 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setOperationAction(ISD::FMINNUM, VT, Legal); setOperationAction(ISD::FMAXNUM, VT, Legal); setOperationAction(ISD::FROUND, VT, Legal); + setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMUL, VT, Custom); + 
setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); // No native support for these. setOperationAction(ISD::FDIV, VT, Expand); @@ -358,6 +375,17 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { } } + // Custom Expand smaller than legal vector reductions to prevent false zero + // items being added. + setOperationAction(ISD::VECREDUCE_FADD, MVT::v4f16, Custom); + setOperationAction(ISD::VECREDUCE_FMUL, MVT::v4f16, Custom); + setOperationAction(ISD::VECREDUCE_FMIN, MVT::v4f16, Custom); + setOperationAction(ISD::VECREDUCE_FMAX, MVT::v4f16, Custom); + setOperationAction(ISD::VECREDUCE_FADD, MVT::v2f16, Custom); + setOperationAction(ISD::VECREDUCE_FMUL, MVT::v2f16, Custom); + setOperationAction(ISD::VECREDUCE_FMIN, MVT::v2f16, Custom); + setOperationAction(ISD::VECREDUCE_FMAX, MVT::v2f16, Custom); + // We 'support' these types up to bitcast/load/store level, regardless of // MVE integer-only / float support. Only doing FP data processing on the FP // vector types is inhibited at integer-only level. @@ -368,7 +396,11 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::BUILD_VECTOR, VT, Custom); + setOperationAction(ISD::VSELECT, VT, Legal); + setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); } + setOperationAction(ISD::SCALAR_TO_VECTOR, MVT::v2f64, Legal); + // We can do bitwise operations on v2i64 vectors setOperationAction(ISD::AND, MVT::v2i64, Legal); setOperationAction(ISD::OR, MVT::v2i64, Legal); @@ -403,7 +435,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { } // Predicate types - const MVT pTypes[] = {MVT::v16i1, MVT::v8i1, MVT::v4i1}; + const MVT pTypes[] = {MVT::v16i1, MVT::v8i1, MVT::v4i1, MVT::v2i1}; for (auto VT : pTypes) { addRegisterClass(VT, &ARM::VCCRRegClass); setOperationAction(ISD::BUILD_VECTOR, VT, Custom); @@ -416,7 +448,36 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Expand); setOperationAction(ISD::LOAD, VT, Custom); setOperationAction(ISD::STORE, VT, Custom); + setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::VSELECT, VT, Expand); + setOperationAction(ISD::SELECT, VT, Expand); + setOperationAction(ISD::SELECT_CC, VT, Expand); + + if (!HasMVEFP) { + setOperationAction(ISD::SINT_TO_FP, VT, Expand); + setOperationAction(ISD::UINT_TO_FP, VT, Expand); + setOperationAction(ISD::FP_TO_SINT, VT, Expand); + setOperationAction(ISD::FP_TO_UINT, VT, Expand); + } } + setOperationAction(ISD::SETCC, MVT::v2i1, Expand); + setOperationAction(ISD::TRUNCATE, MVT::v2i1, Expand); + setOperationAction(ISD::AND, MVT::v2i1, Expand); + setOperationAction(ISD::OR, MVT::v2i1, Expand); + setOperationAction(ISD::XOR, MVT::v2i1, Expand); + setOperationAction(ISD::SINT_TO_FP, MVT::v2i1, Expand); + setOperationAction(ISD::UINT_TO_FP, MVT::v2i1, Expand); + setOperationAction(ISD::FP_TO_SINT, MVT::v2i1, Expand); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i1, Expand); + + setOperationAction(ISD::SIGN_EXTEND, MVT::v8i32, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v16i16, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v8i32, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v16i16, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom); + setOperationAction(ISD::TRUNCATE, MVT::v8i32, Custom); + 
setOperationAction(ISD::TRUNCATE, MVT::v16i16, Custom); } ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, @@ -429,7 +490,7 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, setBooleanVectorContents(ZeroOrNegativeOneBooleanContent); if (!Subtarget->isTargetDarwin() && !Subtarget->isTargetIOS() && - !Subtarget->isTargetWatchOS()) { + !Subtarget->isTargetWatchOS() && !Subtarget->isTargetDriverKit()) { bool IsHFTarget = TM.Options.FloatABIType == FloatABI::Hard; for (int LCID = 0; LCID < RTLIB::UNKNOWN_LIBCALL; ++LCID) setLibcallCallingConv(static_cast<RTLIB::Libcall>(LCID), @@ -511,6 +572,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, setLibcallName(RTLIB::SHL_I128, nullptr); setLibcallName(RTLIB::SRL_I128, nullptr); setLibcallName(RTLIB::SRA_I128, nullptr); + setLibcallName(RTLIB::MUL_I128, nullptr); + setLibcallName(RTLIB::MULO_I64, nullptr); + setLibcallName(RTLIB::MULO_I128, nullptr); // RTLIB if (Subtarget->isAAPCS_ABI() && @@ -708,6 +772,12 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, Subtarget->hasFPRegs()) { addRegisterClass(MVT::f32, &ARM::SPRRegClass); addRegisterClass(MVT::f64, &ARM::DPRRegClass); + + setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Custom); + setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i32, Custom); + setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom); + setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom); + if (!Subtarget->hasVFP2Base()) setAllExpand(MVT::f32); if (!Subtarget->hasFP64()) @@ -717,22 +787,26 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, if (Subtarget->hasFullFP16()) { addRegisterClass(MVT::f16, &ARM::HPRRegClass); setOperationAction(ISD::BITCAST, MVT::i16, Custom); - setOperationAction(ISD::BITCAST, MVT::i32, Custom); setOperationAction(ISD::BITCAST, MVT::f16, Custom); setOperationAction(ISD::FMINNUM, MVT::f16, Legal); setOperationAction(ISD::FMAXNUM, MVT::f16, Legal); } + if (Subtarget->hasBF16()) { + addRegisterClass(MVT::bf16, &ARM::HPRRegClass); + setAllExpand(MVT::bf16); + if (!Subtarget->hasFullFP16()) + setOperationAction(ISD::BITCAST, MVT::bf16, Custom); + } + for (MVT VT : MVT::fixedlen_vector_valuetypes()) { for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) { setTruncStoreAction(VT, InnerVT, Expand); addAllExtLoads(VT, InnerVT, Expand); } - setOperationAction(ISD::MULHS, VT, Expand); setOperationAction(ISD::SMUL_LOHI, VT, Expand); - setOperationAction(ISD::MULHU, VT, Expand); setOperationAction(ISD::UMUL_LOHI, VT, Expand); setOperationAction(ISD::BSWAP, VT, Expand); @@ -749,8 +823,7 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, // Combine low-overhead loop intrinsics so that we can lower i1 types. 
if (Subtarget->hasLOB()) { - setTargetDAGCombine(ISD::BRCOND); - setTargetDAGCombine(ISD::BR_CC); + setTargetDAGCombine({ISD::BRCOND, ISD::BR_CC}); } if (Subtarget->hasNEON()) { @@ -771,6 +844,11 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, addQRTypeForNEON(MVT::v8f16); addDRTypeForNEON(MVT::v4f16); } + + if (Subtarget->hasBF16()) { + addQRTypeForNEON(MVT::v8bf16); + addDRTypeForNEON(MVT::v4bf16); + } } if (Subtarget->hasMVEIntegerOps() || Subtarget->hasNEON()) { @@ -906,22 +984,19 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v4i32, Custom); setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v2i64, Custom); + for (MVT VT : MVT::fixedlen_vector_valuetypes()) { + setOperationAction(ISD::MULHS, VT, Expand); + setOperationAction(ISD::MULHU, VT, Expand); + } + // NEON only has FMA instructions as of VFP4. if (!Subtarget->hasVFP4Base()) { setOperationAction(ISD::FMA, MVT::v2f32, Expand); setOperationAction(ISD::FMA, MVT::v4f32, Expand); } - setTargetDAGCombine(ISD::INTRINSIC_VOID); - setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN); - setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); - setTargetDAGCombine(ISD::SHL); - setTargetDAGCombine(ISD::SRL); - setTargetDAGCombine(ISD::SRA); - setTargetDAGCombine(ISD::FP_TO_SINT); - setTargetDAGCombine(ISD::FP_TO_UINT); - setTargetDAGCombine(ISD::FDIV); - setTargetDAGCombine(ISD::LOAD); + setTargetDAGCombine({ISD::SHL, ISD::SRL, ISD::SRA, ISD::FP_TO_SINT, + ISD::FP_TO_UINT, ISD::FDIV, ISD::LOAD}); // It is legal to extload from v4i8 to v4i16 or v4i32. for (MVT Ty : {MVT::v8i8, MVT::v4i8, MVT::v2i8, MVT::v4i16, MVT::v2i16, @@ -935,13 +1010,20 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, } if (Subtarget->hasNEON() || Subtarget->hasMVEIntegerOps()) { - setTargetDAGCombine(ISD::BUILD_VECTOR); - setTargetDAGCombine(ISD::VECTOR_SHUFFLE); - setTargetDAGCombine(ISD::INSERT_VECTOR_ELT); - setTargetDAGCombine(ISD::STORE); - setTargetDAGCombine(ISD::SIGN_EXTEND); - setTargetDAGCombine(ISD::ZERO_EXTEND); - setTargetDAGCombine(ISD::ANY_EXTEND); + setTargetDAGCombine( + {ISD::BUILD_VECTOR, ISD::VECTOR_SHUFFLE, ISD::INSERT_SUBVECTOR, + ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT, + ISD::SIGN_EXTEND_INREG, ISD::STORE, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND, + ISD::ANY_EXTEND, ISD::INTRINSIC_WO_CHAIN, ISD::INTRINSIC_W_CHAIN, + ISD::INTRINSIC_VOID, ISD::VECREDUCE_ADD, ISD::ADD, ISD::BITCAST}); + } + if (Subtarget->hasMVEIntegerOps()) { + setTargetDAGCombine({ISD::SMIN, ISD::UMIN, ISD::SMAX, ISD::UMAX, + ISD::FP_EXTEND, ISD::SELECT, ISD::SELECT_CC, + ISD::SETCC}); + } + if (Subtarget->hasMVEFloatOps()) { + setTargetDAGCombine(ISD::FADD); } if (!Subtarget->hasFP64()) { @@ -1049,6 +1131,10 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SSUBSAT, MVT::i8, Custom); setOperationAction(ISD::SADDSAT, MVT::i16, Custom); setOperationAction(ISD::SSUBSAT, MVT::i16, Custom); + setOperationAction(ISD::UADDSAT, MVT::i8, Custom); + setOperationAction(ISD::USUBSAT, MVT::i8, Custom); + setOperationAction(ISD::UADDSAT, MVT::i16, Custom); + setOperationAction(ISD::USUBSAT, MVT::i16, Custom); } if (Subtarget->hasBaseDSP()) { setOperationAction(ISD::SADDSAT, MVT::i32, Legal); @@ -1073,6 +1159,8 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SRA, MVT::i64, Custom); setOperationAction(ISD::INTRINSIC_VOID, MVT::Other, Custom); setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i64, Custom); + setOperationAction(ISD::LOAD, 
MVT::i64, Custom); + setOperationAction(ISD::STORE, MVT::i64, Custom); // MVE lowers 64 bit shifts to lsll and lsrl // assuming that ISD::SRL and SRA of i64 are already marked custom @@ -1269,6 +1357,32 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, } } + // Compute supported atomic widths. + if (Subtarget->isTargetLinux() || + (!Subtarget->isMClass() && Subtarget->hasV6Ops())) { + // For targets where __sync_* routines are reliably available, we use them + // if necessary. + // + // ARM Linux always supports 64-bit atomics through kernel-assisted atomic + // routines (kernel 3.1 or later). FIXME: Not with compiler-rt? + // + // ARMv6 targets have native instructions in ARM mode. For Thumb mode, + // such targets should provide __sync_* routines, which use the ARM mode + // instructions. (ARMv6 doesn't have dmb, but it has an equivalent + // encoding; see ARMISD::MEMBARRIER_MCR.) + setMaxAtomicSizeInBitsSupported(64); + } else if ((Subtarget->isMClass() && Subtarget->hasV8MBaselineOps()) || + Subtarget->hasForced32BitAtomics()) { + // Cortex-M (besides Cortex-M0) have 32-bit atomics. + setMaxAtomicSizeInBitsSupported(32); + } else { + // We can't assume anything about other targets; just use libatomic + // routines. + setMaxAtomicSizeInBitsSupported(0); + } + + setMaxDivRemBitWidthSupported(64); + setOperationAction(ISD::PREFETCH, MVT::Other, Custom); // Requires SXTB/SXTH, available on v6 and up in both ARM and Thumb modes. @@ -1283,7 +1397,8 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, // Turn f64->i64 into VMOVRRD, i64 -> f64 to VMOVDRR // iff target supports vfp2. setOperationAction(ISD::BITCAST, MVT::i64, Custom); - setOperationAction(ISD::FLT_ROUNDS_, MVT::i32, Custom); + setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom); + setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom); } // We want to custom lower some of our intrinsics. @@ -1419,12 +1534,16 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, } if (Subtarget->hasNEON()) { - // vmin and vmax aren't available in a scalar form, so we use - // a NEON instruction with an undef lane instead. - setOperationAction(ISD::FMINIMUM, MVT::f16, Legal); - setOperationAction(ISD::FMAXIMUM, MVT::f16, Legal); - setOperationAction(ISD::FMINIMUM, MVT::f32, Legal); - setOperationAction(ISD::FMAXIMUM, MVT::f32, Legal); + // vmin and vmax aren't available in a scalar form, so we can use + // a NEON instruction with an undef lane instead. This has a performance + // penalty on some cores, so we don't do this unless we have been + // asked to by the core tuning model. 
+ if (Subtarget->useNEONForSinglePrecisionFP()) { + setOperationAction(ISD::FMINIMUM, MVT::f32, Legal); + setOperationAction(ISD::FMAXIMUM, MVT::f32, Legal); + setOperationAction(ISD::FMINIMUM, MVT::f16, Legal); + setOperationAction(ISD::FMAXIMUM, MVT::f16, Legal); + } setOperationAction(ISD::FMINIMUM, MVT::v2f32, Legal); setOperationAction(ISD::FMAXIMUM, MVT::v2f32, Legal); setOperationAction(ISD::FMINIMUM, MVT::v4f32, Legal); @@ -1445,17 +1564,21 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, // We have target-specific dag combine patterns for the following nodes: // ARMISD::VMOVRRD - No need to call setTargetDAGCombine - setTargetDAGCombine(ISD::ADD); - setTargetDAGCombine(ISD::SUB); - setTargetDAGCombine(ISD::MUL); - setTargetDAGCombine(ISD::AND); - setTargetDAGCombine(ISD::OR); - setTargetDAGCombine(ISD::XOR); + setTargetDAGCombine( + {ISD::ADD, ISD::SUB, ISD::MUL, ISD::AND, ISD::OR, ISD::XOR}); + + if (Subtarget->hasMVEIntegerOps()) + setTargetDAGCombine(ISD::VSELECT); if (Subtarget->hasV6Ops()) setTargetDAGCombine(ISD::SRL); if (Subtarget->isThumb1Only()) setTargetDAGCombine(ISD::SHL); + // Attempt to lower smin/smax to ssat/usat + if ((!Subtarget->isThumb() && Subtarget->hasV6Ops()) || + Subtarget->isThumb2()) { + setTargetDAGCombine({ISD::SMIN, ISD::SMAX}); + } setStackPointerRegisterToSaveRestore(ARM::SP); @@ -1541,170 +1664,216 @@ ARMTargetLowering::findRepresentativeClass(const TargetRegisterInfo *TRI, } const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const { +#define MAKE_CASE(V) \ + case V: \ + return #V; switch ((ARMISD::NodeType)Opcode) { - case ARMISD::FIRST_NUMBER: break; - case ARMISD::Wrapper: return "ARMISD::Wrapper"; - case ARMISD::WrapperPIC: return "ARMISD::WrapperPIC"; - case ARMISD::WrapperJT: return "ARMISD::WrapperJT"; - case ARMISD::COPY_STRUCT_BYVAL: return "ARMISD::COPY_STRUCT_BYVAL"; - case ARMISD::CALL: return "ARMISD::CALL"; - case ARMISD::CALL_PRED: return "ARMISD::CALL_PRED"; - case ARMISD::CALL_NOLINK: return "ARMISD::CALL_NOLINK"; - case ARMISD::BRCOND: return "ARMISD::BRCOND"; - case ARMISD::BR_JT: return "ARMISD::BR_JT"; - case ARMISD::BR2_JT: return "ARMISD::BR2_JT"; - case ARMISD::RET_FLAG: return "ARMISD::RET_FLAG"; - case ARMISD::INTRET_FLAG: return "ARMISD::INTRET_FLAG"; - case ARMISD::PIC_ADD: return "ARMISD::PIC_ADD"; - case ARMISD::CMP: return "ARMISD::CMP"; - case ARMISD::CMN: return "ARMISD::CMN"; - case ARMISD::CMPZ: return "ARMISD::CMPZ"; - case ARMISD::CMPFP: return "ARMISD::CMPFP"; - case ARMISD::CMPFPE: return "ARMISD::CMPFPE"; - case ARMISD::CMPFPw0: return "ARMISD::CMPFPw0"; - case ARMISD::CMPFPEw0: return "ARMISD::CMPFPEw0"; - case ARMISD::BCC_i64: return "ARMISD::BCC_i64"; - case ARMISD::FMSTAT: return "ARMISD::FMSTAT"; - - case ARMISD::CMOV: return "ARMISD::CMOV"; - case ARMISD::SUBS: return "ARMISD::SUBS"; - - case ARMISD::SSAT: return "ARMISD::SSAT"; - case ARMISD::USAT: return "ARMISD::USAT"; - - case ARMISD::ASRL: return "ARMISD::ASRL"; - case ARMISD::LSRL: return "ARMISD::LSRL"; - case ARMISD::LSLL: return "ARMISD::LSLL"; - - case ARMISD::SRL_FLAG: return "ARMISD::SRL_FLAG"; - case ARMISD::SRA_FLAG: return "ARMISD::SRA_FLAG"; - case ARMISD::RRX: return "ARMISD::RRX"; - - case ARMISD::ADDC: return "ARMISD::ADDC"; - case ARMISD::ADDE: return "ARMISD::ADDE"; - case ARMISD::SUBC: return "ARMISD::SUBC"; - case ARMISD::SUBE: return "ARMISD::SUBE"; - case ARMISD::LSLS: return "ARMISD::LSLS"; - - case ARMISD::VMOVRRD: return "ARMISD::VMOVRRD"; - case ARMISD::VMOVDRR: return "ARMISD::VMOVDRR"; - 
case ARMISD::VMOVhr: return "ARMISD::VMOVhr"; - case ARMISD::VMOVrh: return "ARMISD::VMOVrh"; - case ARMISD::VMOVSR: return "ARMISD::VMOVSR"; - - case ARMISD::EH_SJLJ_SETJMP: return "ARMISD::EH_SJLJ_SETJMP"; - case ARMISD::EH_SJLJ_LONGJMP: return "ARMISD::EH_SJLJ_LONGJMP"; - case ARMISD::EH_SJLJ_SETUP_DISPATCH: return "ARMISD::EH_SJLJ_SETUP_DISPATCH"; - - case ARMISD::TC_RETURN: return "ARMISD::TC_RETURN"; - - case ARMISD::THREAD_POINTER:return "ARMISD::THREAD_POINTER"; - - case ARMISD::DYN_ALLOC: return "ARMISD::DYN_ALLOC"; - - case ARMISD::MEMBARRIER_MCR: return "ARMISD::MEMBARRIER_MCR"; - - case ARMISD::PRELOAD: return "ARMISD::PRELOAD"; - - case ARMISD::WIN__CHKSTK: return "ARMISD::WIN__CHKSTK"; - case ARMISD::WIN__DBZCHK: return "ARMISD::WIN__DBZCHK"; - - case ARMISD::PREDICATE_CAST: return "ARMISD::PREDICATE_CAST"; - case ARMISD::VCMP: return "ARMISD::VCMP"; - case ARMISD::VCMPZ: return "ARMISD::VCMPZ"; - case ARMISD::VTST: return "ARMISD::VTST"; - - case ARMISD::VSHLs: return "ARMISD::VSHLs"; - case ARMISD::VSHLu: return "ARMISD::VSHLu"; - case ARMISD::VSHLIMM: return "ARMISD::VSHLIMM"; - case ARMISD::VSHRsIMM: return "ARMISD::VSHRsIMM"; - case ARMISD::VSHRuIMM: return "ARMISD::VSHRuIMM"; - case ARMISD::VRSHRsIMM: return "ARMISD::VRSHRsIMM"; - case ARMISD::VRSHRuIMM: return "ARMISD::VRSHRuIMM"; - case ARMISD::VRSHRNIMM: return "ARMISD::VRSHRNIMM"; - case ARMISD::VQSHLsIMM: return "ARMISD::VQSHLsIMM"; - case ARMISD::VQSHLuIMM: return "ARMISD::VQSHLuIMM"; - case ARMISD::VQSHLsuIMM: return "ARMISD::VQSHLsuIMM"; - case ARMISD::VQSHRNsIMM: return "ARMISD::VQSHRNsIMM"; - case ARMISD::VQSHRNuIMM: return "ARMISD::VQSHRNuIMM"; - case ARMISD::VQSHRNsuIMM: return "ARMISD::VQSHRNsuIMM"; - case ARMISD::VQRSHRNsIMM: return "ARMISD::VQRSHRNsIMM"; - case ARMISD::VQRSHRNuIMM: return "ARMISD::VQRSHRNuIMM"; - case ARMISD::VQRSHRNsuIMM: return "ARMISD::VQRSHRNsuIMM"; - case ARMISD::VSLIIMM: return "ARMISD::VSLIIMM"; - case ARMISD::VSRIIMM: return "ARMISD::VSRIIMM"; - case ARMISD::VGETLANEu: return "ARMISD::VGETLANEu"; - case ARMISD::VGETLANEs: return "ARMISD::VGETLANEs"; - case ARMISD::VMOVIMM: return "ARMISD::VMOVIMM"; - case ARMISD::VMVNIMM: return "ARMISD::VMVNIMM"; - case ARMISD::VMOVFPIMM: return "ARMISD::VMOVFPIMM"; - case ARMISD::VDUP: return "ARMISD::VDUP"; - case ARMISD::VDUPLANE: return "ARMISD::VDUPLANE"; - case ARMISD::VEXT: return "ARMISD::VEXT"; - case ARMISD::VREV64: return "ARMISD::VREV64"; - case ARMISD::VREV32: return "ARMISD::VREV32"; - case ARMISD::VREV16: return "ARMISD::VREV16"; - case ARMISD::VZIP: return "ARMISD::VZIP"; - case ARMISD::VUZP: return "ARMISD::VUZP"; - case ARMISD::VTRN: return "ARMISD::VTRN"; - case ARMISD::VTBL1: return "ARMISD::VTBL1"; - case ARMISD::VTBL2: return "ARMISD::VTBL2"; - case ARMISD::VMOVN: return "ARMISD::VMOVN"; - case ARMISD::VMULLs: return "ARMISD::VMULLs"; - case ARMISD::VMULLu: return "ARMISD::VMULLu"; - case ARMISD::UMAAL: return "ARMISD::UMAAL"; - case ARMISD::UMLAL: return "ARMISD::UMLAL"; - case ARMISD::SMLAL: return "ARMISD::SMLAL"; - case ARMISD::SMLALBB: return "ARMISD::SMLALBB"; - case ARMISD::SMLALBT: return "ARMISD::SMLALBT"; - case ARMISD::SMLALTB: return "ARMISD::SMLALTB"; - case ARMISD::SMLALTT: return "ARMISD::SMLALTT"; - case ARMISD::SMULWB: return "ARMISD::SMULWB"; - case ARMISD::SMULWT: return "ARMISD::SMULWT"; - case ARMISD::SMLALD: return "ARMISD::SMLALD"; - case ARMISD::SMLALDX: return "ARMISD::SMLALDX"; - case ARMISD::SMLSLD: return "ARMISD::SMLSLD"; - case ARMISD::SMLSLDX: return "ARMISD::SMLSLDX"; - case ARMISD::SMMLAR: 
return "ARMISD::SMMLAR"; - case ARMISD::SMMLSR: return "ARMISD::SMMLSR"; - case ARMISD::QADD16b: return "ARMISD::QADD16b"; - case ARMISD::QSUB16b: return "ARMISD::QSUB16b"; - case ARMISD::QADD8b: return "ARMISD::QADD8b"; - case ARMISD::QSUB8b: return "ARMISD::QSUB8b"; - case ARMISD::BUILD_VECTOR: return "ARMISD::BUILD_VECTOR"; - case ARMISD::BFI: return "ARMISD::BFI"; - case ARMISD::VORRIMM: return "ARMISD::VORRIMM"; - case ARMISD::VBICIMM: return "ARMISD::VBICIMM"; - case ARMISD::VBSL: return "ARMISD::VBSL"; - case ARMISD::MEMCPY: return "ARMISD::MEMCPY"; - case ARMISD::VLD1DUP: return "ARMISD::VLD1DUP"; - case ARMISD::VLD2DUP: return "ARMISD::VLD2DUP"; - case ARMISD::VLD3DUP: return "ARMISD::VLD3DUP"; - case ARMISD::VLD4DUP: return "ARMISD::VLD4DUP"; - case ARMISD::VLD1_UPD: return "ARMISD::VLD1_UPD"; - case ARMISD::VLD2_UPD: return "ARMISD::VLD2_UPD"; - case ARMISD::VLD3_UPD: return "ARMISD::VLD3_UPD"; - case ARMISD::VLD4_UPD: return "ARMISD::VLD4_UPD"; - case ARMISD::VLD2LN_UPD: return "ARMISD::VLD2LN_UPD"; - case ARMISD::VLD3LN_UPD: return "ARMISD::VLD3LN_UPD"; - case ARMISD::VLD4LN_UPD: return "ARMISD::VLD4LN_UPD"; - case ARMISD::VLD1DUP_UPD: return "ARMISD::VLD1DUP_UPD"; - case ARMISD::VLD2DUP_UPD: return "ARMISD::VLD2DUP_UPD"; - case ARMISD::VLD3DUP_UPD: return "ARMISD::VLD3DUP_UPD"; - case ARMISD::VLD4DUP_UPD: return "ARMISD::VLD4DUP_UPD"; - case ARMISD::VST1_UPD: return "ARMISD::VST1_UPD"; - case ARMISD::VST2_UPD: return "ARMISD::VST2_UPD"; - case ARMISD::VST3_UPD: return "ARMISD::VST3_UPD"; - case ARMISD::VST4_UPD: return "ARMISD::VST4_UPD"; - case ARMISD::VST2LN_UPD: return "ARMISD::VST2LN_UPD"; - case ARMISD::VST3LN_UPD: return "ARMISD::VST3LN_UPD"; - case ARMISD::VST4LN_UPD: return "ARMISD::VST4LN_UPD"; - case ARMISD::WLS: return "ARMISD::WLS"; - case ARMISD::LE: return "ARMISD::LE"; - case ARMISD::LOOP_DEC: return "ARMISD::LOOP_DEC"; - case ARMISD::CSINV: return "ARMISD::CSINV"; - case ARMISD::CSNEG: return "ARMISD::CSNEG"; - case ARMISD::CSINC: return "ARMISD::CSINC"; + case ARMISD::FIRST_NUMBER: + break; + MAKE_CASE(ARMISD::Wrapper) + MAKE_CASE(ARMISD::WrapperPIC) + MAKE_CASE(ARMISD::WrapperJT) + MAKE_CASE(ARMISD::COPY_STRUCT_BYVAL) + MAKE_CASE(ARMISD::CALL) + MAKE_CASE(ARMISD::CALL_PRED) + MAKE_CASE(ARMISD::CALL_NOLINK) + MAKE_CASE(ARMISD::tSECALL) + MAKE_CASE(ARMISD::t2CALL_BTI) + MAKE_CASE(ARMISD::BRCOND) + MAKE_CASE(ARMISD::BR_JT) + MAKE_CASE(ARMISD::BR2_JT) + MAKE_CASE(ARMISD::RET_FLAG) + MAKE_CASE(ARMISD::SERET_FLAG) + MAKE_CASE(ARMISD::INTRET_FLAG) + MAKE_CASE(ARMISD::PIC_ADD) + MAKE_CASE(ARMISD::CMP) + MAKE_CASE(ARMISD::CMN) + MAKE_CASE(ARMISD::CMPZ) + MAKE_CASE(ARMISD::CMPFP) + MAKE_CASE(ARMISD::CMPFPE) + MAKE_CASE(ARMISD::CMPFPw0) + MAKE_CASE(ARMISD::CMPFPEw0) + MAKE_CASE(ARMISD::BCC_i64) + MAKE_CASE(ARMISD::FMSTAT) + MAKE_CASE(ARMISD::CMOV) + MAKE_CASE(ARMISD::SUBS) + MAKE_CASE(ARMISD::SSAT) + MAKE_CASE(ARMISD::USAT) + MAKE_CASE(ARMISD::ASRL) + MAKE_CASE(ARMISD::LSRL) + MAKE_CASE(ARMISD::LSLL) + MAKE_CASE(ARMISD::SRL_FLAG) + MAKE_CASE(ARMISD::SRA_FLAG) + MAKE_CASE(ARMISD::RRX) + MAKE_CASE(ARMISD::ADDC) + MAKE_CASE(ARMISD::ADDE) + MAKE_CASE(ARMISD::SUBC) + MAKE_CASE(ARMISD::SUBE) + MAKE_CASE(ARMISD::LSLS) + MAKE_CASE(ARMISD::VMOVRRD) + MAKE_CASE(ARMISD::VMOVDRR) + MAKE_CASE(ARMISD::VMOVhr) + MAKE_CASE(ARMISD::VMOVrh) + MAKE_CASE(ARMISD::VMOVSR) + MAKE_CASE(ARMISD::EH_SJLJ_SETJMP) + MAKE_CASE(ARMISD::EH_SJLJ_LONGJMP) + MAKE_CASE(ARMISD::EH_SJLJ_SETUP_DISPATCH) + MAKE_CASE(ARMISD::TC_RETURN) + MAKE_CASE(ARMISD::THREAD_POINTER) + MAKE_CASE(ARMISD::DYN_ALLOC) + 
MAKE_CASE(ARMISD::MEMBARRIER_MCR) + MAKE_CASE(ARMISD::PRELOAD) + MAKE_CASE(ARMISD::LDRD) + MAKE_CASE(ARMISD::STRD) + MAKE_CASE(ARMISD::WIN__CHKSTK) + MAKE_CASE(ARMISD::WIN__DBZCHK) + MAKE_CASE(ARMISD::PREDICATE_CAST) + MAKE_CASE(ARMISD::VECTOR_REG_CAST) + MAKE_CASE(ARMISD::MVESEXT) + MAKE_CASE(ARMISD::MVEZEXT) + MAKE_CASE(ARMISD::MVETRUNC) + MAKE_CASE(ARMISD::VCMP) + MAKE_CASE(ARMISD::VCMPZ) + MAKE_CASE(ARMISD::VTST) + MAKE_CASE(ARMISD::VSHLs) + MAKE_CASE(ARMISD::VSHLu) + MAKE_CASE(ARMISD::VSHLIMM) + MAKE_CASE(ARMISD::VSHRsIMM) + MAKE_CASE(ARMISD::VSHRuIMM) + MAKE_CASE(ARMISD::VRSHRsIMM) + MAKE_CASE(ARMISD::VRSHRuIMM) + MAKE_CASE(ARMISD::VRSHRNIMM) + MAKE_CASE(ARMISD::VQSHLsIMM) + MAKE_CASE(ARMISD::VQSHLuIMM) + MAKE_CASE(ARMISD::VQSHLsuIMM) + MAKE_CASE(ARMISD::VQSHRNsIMM) + MAKE_CASE(ARMISD::VQSHRNuIMM) + MAKE_CASE(ARMISD::VQSHRNsuIMM) + MAKE_CASE(ARMISD::VQRSHRNsIMM) + MAKE_CASE(ARMISD::VQRSHRNuIMM) + MAKE_CASE(ARMISD::VQRSHRNsuIMM) + MAKE_CASE(ARMISD::VSLIIMM) + MAKE_CASE(ARMISD::VSRIIMM) + MAKE_CASE(ARMISD::VGETLANEu) + MAKE_CASE(ARMISD::VGETLANEs) + MAKE_CASE(ARMISD::VMOVIMM) + MAKE_CASE(ARMISD::VMVNIMM) + MAKE_CASE(ARMISD::VMOVFPIMM) + MAKE_CASE(ARMISD::VDUP) + MAKE_CASE(ARMISD::VDUPLANE) + MAKE_CASE(ARMISD::VEXT) + MAKE_CASE(ARMISD::VREV64) + MAKE_CASE(ARMISD::VREV32) + MAKE_CASE(ARMISD::VREV16) + MAKE_CASE(ARMISD::VZIP) + MAKE_CASE(ARMISD::VUZP) + MAKE_CASE(ARMISD::VTRN) + MAKE_CASE(ARMISD::VTBL1) + MAKE_CASE(ARMISD::VTBL2) + MAKE_CASE(ARMISD::VMOVN) + MAKE_CASE(ARMISD::VQMOVNs) + MAKE_CASE(ARMISD::VQMOVNu) + MAKE_CASE(ARMISD::VCVTN) + MAKE_CASE(ARMISD::VCVTL) + MAKE_CASE(ARMISD::VIDUP) + MAKE_CASE(ARMISD::VMULLs) + MAKE_CASE(ARMISD::VMULLu) + MAKE_CASE(ARMISD::VQDMULH) + MAKE_CASE(ARMISD::VADDVs) + MAKE_CASE(ARMISD::VADDVu) + MAKE_CASE(ARMISD::VADDVps) + MAKE_CASE(ARMISD::VADDVpu) + MAKE_CASE(ARMISD::VADDLVs) + MAKE_CASE(ARMISD::VADDLVu) + MAKE_CASE(ARMISD::VADDLVAs) + MAKE_CASE(ARMISD::VADDLVAu) + MAKE_CASE(ARMISD::VADDLVps) + MAKE_CASE(ARMISD::VADDLVpu) + MAKE_CASE(ARMISD::VADDLVAps) + MAKE_CASE(ARMISD::VADDLVApu) + MAKE_CASE(ARMISD::VMLAVs) + MAKE_CASE(ARMISD::VMLAVu) + MAKE_CASE(ARMISD::VMLAVps) + MAKE_CASE(ARMISD::VMLAVpu) + MAKE_CASE(ARMISD::VMLALVs) + MAKE_CASE(ARMISD::VMLALVu) + MAKE_CASE(ARMISD::VMLALVps) + MAKE_CASE(ARMISD::VMLALVpu) + MAKE_CASE(ARMISD::VMLALVAs) + MAKE_CASE(ARMISD::VMLALVAu) + MAKE_CASE(ARMISD::VMLALVAps) + MAKE_CASE(ARMISD::VMLALVApu) + MAKE_CASE(ARMISD::VMINVu) + MAKE_CASE(ARMISD::VMINVs) + MAKE_CASE(ARMISD::VMAXVu) + MAKE_CASE(ARMISD::VMAXVs) + MAKE_CASE(ARMISD::UMAAL) + MAKE_CASE(ARMISD::UMLAL) + MAKE_CASE(ARMISD::SMLAL) + MAKE_CASE(ARMISD::SMLALBB) + MAKE_CASE(ARMISD::SMLALBT) + MAKE_CASE(ARMISD::SMLALTB) + MAKE_CASE(ARMISD::SMLALTT) + MAKE_CASE(ARMISD::SMULWB) + MAKE_CASE(ARMISD::SMULWT) + MAKE_CASE(ARMISD::SMLALD) + MAKE_CASE(ARMISD::SMLALDX) + MAKE_CASE(ARMISD::SMLSLD) + MAKE_CASE(ARMISD::SMLSLDX) + MAKE_CASE(ARMISD::SMMLAR) + MAKE_CASE(ARMISD::SMMLSR) + MAKE_CASE(ARMISD::QADD16b) + MAKE_CASE(ARMISD::QSUB16b) + MAKE_CASE(ARMISD::QADD8b) + MAKE_CASE(ARMISD::QSUB8b) + MAKE_CASE(ARMISD::UQADD16b) + MAKE_CASE(ARMISD::UQSUB16b) + MAKE_CASE(ARMISD::UQADD8b) + MAKE_CASE(ARMISD::UQSUB8b) + MAKE_CASE(ARMISD::BUILD_VECTOR) + MAKE_CASE(ARMISD::BFI) + MAKE_CASE(ARMISD::VORRIMM) + MAKE_CASE(ARMISD::VBICIMM) + MAKE_CASE(ARMISD::VBSP) + MAKE_CASE(ARMISD::MEMCPY) + MAKE_CASE(ARMISD::VLD1DUP) + MAKE_CASE(ARMISD::VLD2DUP) + MAKE_CASE(ARMISD::VLD3DUP) + MAKE_CASE(ARMISD::VLD4DUP) + MAKE_CASE(ARMISD::VLD1_UPD) + MAKE_CASE(ARMISD::VLD2_UPD) + 
MAKE_CASE(ARMISD::VLD3_UPD) + MAKE_CASE(ARMISD::VLD4_UPD) + MAKE_CASE(ARMISD::VLD1x2_UPD) + MAKE_CASE(ARMISD::VLD1x3_UPD) + MAKE_CASE(ARMISD::VLD1x4_UPD) + MAKE_CASE(ARMISD::VLD2LN_UPD) + MAKE_CASE(ARMISD::VLD3LN_UPD) + MAKE_CASE(ARMISD::VLD4LN_UPD) + MAKE_CASE(ARMISD::VLD1DUP_UPD) + MAKE_CASE(ARMISD::VLD2DUP_UPD) + MAKE_CASE(ARMISD::VLD3DUP_UPD) + MAKE_CASE(ARMISD::VLD4DUP_UPD) + MAKE_CASE(ARMISD::VST1_UPD) + MAKE_CASE(ARMISD::VST2_UPD) + MAKE_CASE(ARMISD::VST3_UPD) + MAKE_CASE(ARMISD::VST4_UPD) + MAKE_CASE(ARMISD::VST1x2_UPD) + MAKE_CASE(ARMISD::VST1x3_UPD) + MAKE_CASE(ARMISD::VST1x4_UPD) + MAKE_CASE(ARMISD::VST2LN_UPD) + MAKE_CASE(ARMISD::VST3LN_UPD) + MAKE_CASE(ARMISD::VST4LN_UPD) + MAKE_CASE(ARMISD::WLS) + MAKE_CASE(ARMISD::WLSSETUP) + MAKE_CASE(ARMISD::LE) + MAKE_CASE(ARMISD::LOOP_DEC) + MAKE_CASE(ARMISD::CSINV) + MAKE_CASE(ARMISD::CSNEG) + MAKE_CASE(ARMISD::CSINC) + MAKE_CASE(ARMISD::MEMCPYLOOP) + MAKE_CASE(ARMISD::MEMSETLOOP) +#undef MAKE_CASE } return nullptr; } @@ -1715,8 +1884,11 @@ EVT ARMTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &, return getPointerTy(DL); // MVE has a predicate register. - if (Subtarget->hasMVEIntegerOps() && - (VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8)) + if ((Subtarget->hasMVEIntegerOps() && + (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 || + VT == MVT::v16i8)) || + (Subtarget->hasMVEFloatOps() && + (VT == MVT::v2f64 || VT == MVT::v4f32 || VT == MVT::v8f16))) return MVT::getVectorVT(MVT::i1, VT.getVectorElementCount()); return VT.changeVectorElementTypeToInteger(); } @@ -1730,12 +1902,18 @@ ARMTargetLowering::getRegClassFor(MVT VT, bool isDivergent) const { // v8i64 to QQQQ registers. v4i64 and v8i64 are only used for REG_SEQUENCE to // load / store 4 to 8 consecutive NEON D registers, or 2 to 4 consecutive // MVE Q registers. - if (Subtarget->hasNEON() || Subtarget->hasMVEIntegerOps()) { + if (Subtarget->hasNEON()) { if (VT == MVT::v4i64) return &ARM::QQPRRegClass; if (VT == MVT::v8i64) return &ARM::QQQQPRRegClass; } + if (Subtarget->hasMVEIntegerOps()) { + if (VT == MVT::v4i64) + return &ARM::MQQPRRegClass; + if (VT == MVT::v8i64) + return &ARM::MQQQQPRRegClass; + } return TargetLowering::getRegClassFor(VT); } @@ -1743,13 +1921,14 @@ ARMTargetLowering::getRegClassFor(MVT VT, bool isDivergent) const { // source/dest is aligned and the copy size is large enough. We therefore want // to align such objects passed to memory intrinsics. bool ARMTargetLowering::shouldAlignPointerArgs(CallInst *CI, unsigned &MinSize, - unsigned &PrefAlign) const { + Align &PrefAlign) const { if (!isa<MemIntrinsic>(CI)) return false; MinSize = 8; // On ARM11 onwards (excluding M class) 8-byte aligned LDM is typically 1 // cycle faster than 4-byte aligned LDM. - PrefAlign = (Subtarget->hasV6Ops() && !Subtarget->isMClass() ? 8 : 4); + PrefAlign = + (Subtarget->hasV6Ops() && !Subtarget->isMClass() ? Align(8) : Align(4)); return true; } @@ -1896,8 +2075,10 @@ ARMTargetLowering::getEffectiveCallingConv(CallingConv::ID CC, return CallingConv::PreserveMost; case CallingConv::ARM_AAPCS_VFP: case CallingConv::Swift: + case CallingConv::SwiftTail: return isVarArg ? 
CallingConv::ARM_AAPCS : CallingConv::ARM_AAPCS_VFP; case CallingConv::C: + case CallingConv::Tail: if (!Subtarget->isAAPCS_ABI()) return CallingConv::ARM_APCS; else if (Subtarget->hasVFP2Base() && !Subtarget->isThumb1Only() && @@ -1955,6 +2136,35 @@ CCAssignFn *ARMTargetLowering::CCAssignFnForNode(CallingConv::ID CC, } } +SDValue ARMTargetLowering::MoveToHPR(const SDLoc &dl, SelectionDAG &DAG, + MVT LocVT, MVT ValVT, SDValue Val) const { + Val = DAG.getNode(ISD::BITCAST, dl, MVT::getIntegerVT(LocVT.getSizeInBits()), + Val); + if (Subtarget->hasFullFP16()) { + Val = DAG.getNode(ARMISD::VMOVhr, dl, ValVT, Val); + } else { + Val = DAG.getNode(ISD::TRUNCATE, dl, + MVT::getIntegerVT(ValVT.getSizeInBits()), Val); + Val = DAG.getNode(ISD::BITCAST, dl, ValVT, Val); + } + return Val; +} + +SDValue ARMTargetLowering::MoveFromHPR(const SDLoc &dl, SelectionDAG &DAG, + MVT LocVT, MVT ValVT, + SDValue Val) const { + if (Subtarget->hasFullFP16()) { + Val = DAG.getNode(ARMISD::VMOVrh, dl, + MVT::getIntegerVT(LocVT.getSizeInBits()), Val); + } else { + Val = DAG.getNode(ISD::BITCAST, dl, + MVT::getIntegerVT(ValVT.getSizeInBits()), Val); + Val = DAG.getNode(ISD::ZERO_EXTEND, dl, + MVT::getIntegerVT(LocVT.getSizeInBits()), Val); + } + return DAG.getNode(ISD::BITCAST, dl, LocVT, Val); +} + /// LowerCallResult - Lower the result values of a call into the /// appropriate copies out of appropriate physical registers. SDValue ARMTargetLowering::LowerCallResult( @@ -1982,7 +2192,8 @@ SDValue ARMTargetLowering::LowerCallResult( } SDValue Val; - if (VA.needsCustom()) { + if (VA.needsCustom() && + (VA.getLocVT() == MVT::f64 || VA.getLocVT() == MVT::v2f64)) { // Handle f64 or half of a v2f64. SDValue Lo = DAG.getCopyFromReg(Chain, dl, VA.getLocReg(), MVT::i32, InFlag); @@ -2031,25 +2242,44 @@ SDValue ARMTargetLowering::LowerCallResult( break; } + // f16 arguments have their size extended to 4 bytes and passed as if they + // had been copied to the LSBs of a 32-bit register. + // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI) + if (VA.needsCustom() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) + Val = MoveToHPR(dl, DAG, VA.getLocVT(), VA.getValVT(), Val); + InVals.push_back(Val); } return Chain; } -/// LowerMemOpCallTo - Store the argument to the stack. 
-SDValue ARMTargetLowering::LowerMemOpCallTo(SDValue Chain, SDValue StackPtr, - SDValue Arg, const SDLoc &dl, - SelectionDAG &DAG, - const CCValAssign &VA, - ISD::ArgFlagsTy Flags) const { - unsigned LocMemOffset = VA.getLocMemOffset(); - SDValue PtrOff = DAG.getIntPtrConstant(LocMemOffset, dl); - PtrOff = DAG.getNode(ISD::ADD, dl, getPointerTy(DAG.getDataLayout()), - StackPtr, PtrOff); - return DAG.getStore( - Chain, dl, Arg, PtrOff, - MachinePointerInfo::getStack(DAG.getMachineFunction(), LocMemOffset)); +std::pair<SDValue, MachinePointerInfo> ARMTargetLowering::computeAddrForCallArg( + const SDLoc &dl, SelectionDAG &DAG, const CCValAssign &VA, SDValue StackPtr, + bool IsTailCall, int SPDiff) const { + SDValue DstAddr; + MachinePointerInfo DstInfo; + int32_t Offset = VA.getLocMemOffset(); + MachineFunction &MF = DAG.getMachineFunction(); + + if (IsTailCall) { + Offset += SPDiff; + auto PtrVT = getPointerTy(DAG.getDataLayout()); + int Size = VA.getLocVT().getFixedSizeInBits() / 8; + int FI = MF.getFrameInfo().CreateFixedObject(Size, Offset, true); + DstAddr = DAG.getFrameIndex(FI, PtrVT); + DstInfo = + MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI); + } else { + SDValue PtrOff = DAG.getIntPtrConstant(Offset, dl); + DstAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(DAG.getDataLayout()), + StackPtr, PtrOff); + DstInfo = + MachinePointerInfo::getStack(DAG.getMachineFunction(), Offset); + } + + return std::make_pair(DstAddr, DstInfo); } void ARMTargetLowering::PassF64ArgInRegs(const SDLoc &dl, SelectionDAG &DAG, @@ -2058,7 +2288,8 @@ void ARMTargetLowering::PassF64ArgInRegs(const SDLoc &dl, SelectionDAG &DAG, CCValAssign &VA, CCValAssign &NextVA, SDValue &StackPtr, SmallVectorImpl<SDValue> &MemOpChains, - ISD::ArgFlagsTy Flags) const { + bool IsTailCall, + int SPDiff) const { SDValue fmrrd = DAG.getNode(ARMISD::VMOVRRD, dl, DAG.getVTList(MVT::i32, MVT::i32), Arg); unsigned id = Subtarget->isLittle() ? 0 : 1; @@ -2072,12 +2303,20 @@ void ARMTargetLowering::PassF64ArgInRegs(const SDLoc &dl, SelectionDAG &DAG, StackPtr = DAG.getCopyFromReg(Chain, dl, ARM::SP, getPointerTy(DAG.getDataLayout())); - MemOpChains.push_back(LowerMemOpCallTo(Chain, StackPtr, fmrrd.getValue(1-id), - dl, DAG, NextVA, - Flags)); + SDValue DstAddr; + MachinePointerInfo DstInfo; + std::tie(DstAddr, DstInfo) = + computeAddrForCallArg(dl, DAG, NextVA, StackPtr, IsTailCall, SPDiff); + MemOpChains.push_back( + DAG.getStore(Chain, dl, fmrrd.getValue(1 - id), DstAddr, DstInfo)); } } +static bool canGuaranteeTCO(CallingConv::ID CC, bool GuaranteeTailCalls) { + return (CC == CallingConv::Fast && GuaranteeTailCalls) || + CC == CallingConv::Tail || CC == CallingConv::SwiftTail; +} + /// LowerCall - Lowering a call into a callseq_start <- /// ARMISD:CALL <- callseq_end chain. Also add input and output parameter /// nodes. @@ -2097,22 +2336,41 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, bool isVarArg = CLI.IsVarArg; MachineFunction &MF = DAG.getMachineFunction(); + ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>(); MachineFunction::CallSiteInfo CSInfo; bool isStructRet = (Outs.empty()) ? false : Outs[0].Flags.isSRet(); bool isThisReturn = false; + bool isCmseNSCall = false; + bool isSibCall = false; bool PreferIndirect = false; + bool GuardWithBTI = false; + + // Lower 'returns_twice' calls to a pseudo-instruction. 
+ if (CLI.CB && CLI.CB->getAttributes().hasFnAttr(Attribute::ReturnsTwice) && + !Subtarget->noBTIAtReturnTwice()) + GuardWithBTI = AFI->branchTargetEnforcement(); + + // Determine whether this is a non-secure function call. + if (CLI.CB && CLI.CB->getAttributes().hasFnAttr("cmse_nonsecure_call")) + isCmseNSCall = true; // Disable tail calls if they're not supported. if (!Subtarget->supportsTailCall()) isTailCall = false; + // For both the non-secure calls and the returns from a CMSE entry function, + // the function needs to do some extra work afte r the call, or before the + // return, respectively, thus it cannot end with atail call + if (isCmseNSCall || AFI->isCmseNSEntryFunction()) + isTailCall = false; + if (isa<GlobalAddressSDNode>(Callee)) { // If we're optimizing for minimum size and the function is called three or // more times in this block, we can improve codesize by calling indirectly // as BLXr has a 16-bit encoding. auto *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); - if (CLI.CS) { - auto *BB = CLI.CS.getParent(); + if (CLI.CB) { + auto *BB = CLI.CB->getParent(); PreferIndirect = Subtarget->isThumb() && Subtarget->hasMinSize() && count_if(GV->users(), [&BB](const User *U) { return isa<Instruction>(U) && @@ -2126,15 +2384,20 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Callee, CallConv, isVarArg, isStructRet, MF.getFunction().hasStructRetAttr(), Outs, OutVals, Ins, DAG, PreferIndirect); - if (!isTailCall && CLI.CS && CLI.CS.isMustTailCall()) - report_fatal_error("failed to perform tail call elimination on a call " - "site marked musttail"); + + if (isTailCall && !getTargetMachine().Options.GuaranteedTailCallOpt && + CallConv != CallingConv::Tail && CallConv != CallingConv::SwiftTail) + isSibCall = true; + // We don't support GuaranteedTailCallOpt for ARM, only automatically // detected sibcalls. if (isTailCall) ++NumTailCalls; } + if (!isTailCall && CLI.CB && CLI.CB->isMustTailCall()) + report_fatal_error("failed to perform tail call elimination on a call " + "site marked musttail"); // Analyze operands of the call, assigning locations to each operand. SmallVector<CCValAssign, 16> ArgLocs; CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), ArgLocs, @@ -2144,13 +2407,40 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Get a count of how many bytes are to be pushed on the stack. unsigned NumBytes = CCInfo.getNextStackOffset(); - if (isTailCall) { - // For tail calls, memory operands are available in our caller's stack. + // SPDiff is the byte offset of the call's argument area from the callee's. + // Stores to callee stack arguments will be placed in FixedStackSlots offset + // by this amount for a tail call. In a sibling call it must be 0 because the + // caller will deallocate the entire stack and the callee still expects its + // arguments to begin at SP+0. Completely unused for non-tail calls. + int SPDiff = 0; + + if (isTailCall && !isSibCall) { + auto FuncInfo = MF.getInfo<ARMFunctionInfo>(); + unsigned NumReusableBytes = FuncInfo->getArgumentStackSize(); + + // Since callee will pop argument stack as a tail call, we must keep the + // popped size 16-byte aligned. + Align StackAlign = DAG.getDataLayout().getStackAlignment(); + NumBytes = alignTo(NumBytes, StackAlign); + + // SPDiff will be negative if this tail call requires more space than we + // would automatically have in our incoming argument space. Positive if we + // can actually shrink the stack. 
+ SPDiff = NumReusableBytes - NumBytes; + + // If this call requires more stack than we have available from + // LowerFormalArguments, tell FrameLowering to reserve space for it. + if (SPDiff < 0 && AFI->getArgRegsSaveSize() < (unsigned)-SPDiff) + AFI->setArgRegsSaveSize(-SPDiff); + } + + if (isSibCall) { + // For sibling tail calls, memory operands are available in our caller's stack. NumBytes = 0; } else { // Adjust the stack pointer for the new arguments... // These operations are automatically eliminated by the prolog/epilog pass - Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, dl); + Chain = DAG.getCALLSEQ_START(Chain, isTailCall ? 0 : NumBytes, 0, dl); } SDValue StackPtr = @@ -2159,6 +2449,13 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, RegsToPassVector RegsToPass; SmallVector<SDValue, 8> MemOpChains; + // During a tail call, stores to the argument area must happen after all of + // the function's incoming arguments have been loaded because they may alias. + // This is done by folding in a TokenFactor from LowerFormalArguments, but + // there's no point in doing so repeatedly so this tracks whether that's + // happened yet. + bool AfterFormalArgLoads = false; + // Walk the register/memloc assignments, inserting copies/loads. In the case // of tail call optimization, arguments are handled later. for (unsigned i = 0, realArgIdx = 0, e = ArgLocs.size(); @@ -2187,31 +2484,57 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, break; } - // f64 and v2f64 might be passed in i32 pairs and must be split into pieces - if (VA.needsCustom()) { - if (VA.getLocVT() == MVT::v2f64) { - SDValue Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, Arg, - DAG.getConstant(0, dl, MVT::i32)); - SDValue Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, Arg, - DAG.getConstant(1, dl, MVT::i32)); - - PassF64ArgInRegs(dl, DAG, Chain, Op0, RegsToPass, - VA, ArgLocs[++i], StackPtr, MemOpChains, Flags); + if (isTailCall && VA.isMemLoc() && !AfterFormalArgLoads) { + Chain = DAG.getStackArgumentTokenFactor(Chain); + AfterFormalArgLoads = true; + } - VA = ArgLocs[++i]; // skip ahead to next loc - if (VA.isRegLoc()) { - PassF64ArgInRegs(dl, DAG, Chain, Op1, RegsToPass, - VA, ArgLocs[++i], StackPtr, MemOpChains, Flags); - } else { - assert(VA.isMemLoc()); + // f16 arguments have their size extended to 4 bytes and passed as if they + // had been copied to the LSBs of a 32-bit register. + // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI) + if (VA.needsCustom() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) { + Arg = MoveFromHPR(dl, DAG, VA.getLocVT(), VA.getValVT(), Arg); + } else { + // f16 arguments could have been extended prior to argument lowering. + // Mask them arguments if this is a CMSE nonsecure call. 
+ auto ArgVT = Outs[realArgIdx].ArgVT; + if (isCmseNSCall && (ArgVT == MVT::f16)) { + auto LocBits = VA.getLocVT().getSizeInBits(); + auto MaskValue = APInt::getLowBitsSet(LocBits, ArgVT.getSizeInBits()); + SDValue Mask = + DAG.getConstant(MaskValue, dl, MVT::getIntegerVT(LocBits)); + Arg = DAG.getNode(ISD::BITCAST, dl, MVT::getIntegerVT(LocBits), Arg); + Arg = DAG.getNode(ISD::AND, dl, MVT::getIntegerVT(LocBits), Arg, Mask); + Arg = DAG.getNode(ISD::BITCAST, dl, VA.getLocVT(), Arg); + } + } - MemOpChains.push_back(LowerMemOpCallTo(Chain, StackPtr, Op1, - dl, DAG, VA, Flags)); - } + // f64 and v2f64 might be passed in i32 pairs and must be split into pieces + if (VA.needsCustom() && VA.getLocVT() == MVT::v2f64) { + SDValue Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, Arg, + DAG.getConstant(0, dl, MVT::i32)); + SDValue Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, Arg, + DAG.getConstant(1, dl, MVT::i32)); + + PassF64ArgInRegs(dl, DAG, Chain, Op0, RegsToPass, VA, ArgLocs[++i], + StackPtr, MemOpChains, isTailCall, SPDiff); + + VA = ArgLocs[++i]; // skip ahead to next loc + if (VA.isRegLoc()) { + PassF64ArgInRegs(dl, DAG, Chain, Op1, RegsToPass, VA, ArgLocs[++i], + StackPtr, MemOpChains, isTailCall, SPDiff); } else { - PassF64ArgInRegs(dl, DAG, Chain, Arg, RegsToPass, VA, ArgLocs[++i], - StackPtr, MemOpChains, Flags); + assert(VA.isMemLoc()); + SDValue DstAddr; + MachinePointerInfo DstInfo; + std::tie(DstAddr, DstInfo) = + computeAddrForCallArg(dl, DAG, VA, StackPtr, isTailCall, SPDiff); + MemOpChains.push_back(DAG.getStore(Chain, dl, Op1, DstAddr, DstInfo)); } + } else if (VA.needsCustom() && VA.getLocVT() == MVT::f64) { + PassF64ArgInRegs(dl, DAG, Chain, Arg, RegsToPass, VA, ArgLocs[++i], + StackPtr, MemOpChains, isTailCall, SPDiff); } else if (VA.isRegLoc()) { if (realArgIdx == 0 && Flags.isReturned() && !Flags.isSwiftSelf() && Outs[0].VT == MVT::i32) { @@ -2222,7 +2545,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, isThisReturn = true; } const TargetOptions &Options = DAG.getTarget().Options; - if (Options.EnableDebugEntryValues) + if (Options.EmitCallSiteInfo) CSInfo.emplace_back(VA.getLocReg(), i); RegsToPass.push_back(std::make_pair(VA.getLocReg(), Arg)); } else if (isByVal) { @@ -2245,9 +2568,9 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, for (i = 0, j = RegBegin; j < RegEnd; i++, j++) { SDValue Const = DAG.getConstant(4*i, dl, MVT::i32); SDValue AddArg = DAG.getNode(ISD::ADD, dl, PtrVT, Arg, Const); - SDValue Load = DAG.getLoad(PtrVT, dl, Chain, AddArg, - MachinePointerInfo(), - DAG.InferPtrAlignment(AddArg)); + SDValue Load = + DAG.getLoad(PtrVT, dl, Chain, AddArg, MachinePointerInfo(), + DAG.InferPtrAlign(AddArg)); MemOpChains.push_back(Load.getValue(1)); RegsToPass.push_back(std::make_pair(j, Load)); } @@ -2261,26 +2584,31 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, if (Flags.getByValSize() > 4*offset) { auto PtrVT = getPointerTy(DAG.getDataLayout()); - unsigned LocMemOffset = VA.getLocMemOffset(); - SDValue StkPtrOff = DAG.getIntPtrConstant(LocMemOffset, dl); - SDValue Dst = DAG.getNode(ISD::ADD, dl, PtrVT, StackPtr, StkPtrOff); + SDValue Dst; + MachinePointerInfo DstInfo; + std::tie(Dst, DstInfo) = + computeAddrForCallArg(dl, DAG, VA, StackPtr, isTailCall, SPDiff); SDValue SrcOffset = DAG.getIntPtrConstant(4*offset, dl); SDValue Src = DAG.getNode(ISD::ADD, dl, PtrVT, Arg, SrcOffset); SDValue SizeNode = DAG.getConstant(Flags.getByValSize() - 4*offset, dl, MVT::i32); - SDValue 
AlignNode = DAG.getConstant(Flags.getByValAlign(), dl, - MVT::i32); + SDValue AlignNode = + DAG.getConstant(Flags.getNonZeroByValAlign().value(), dl, MVT::i32); SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue Ops[] = { Chain, Dst, Src, SizeNode, AlignNode}; MemOpChains.push_back(DAG.getNode(ARMISD::COPY_STRUCT_BYVAL, dl, VTs, Ops)); } - } else if (!isTailCall) { + } else { assert(VA.isMemLoc()); + SDValue DstAddr; + MachinePointerInfo DstInfo; + std::tie(DstAddr, DstInfo) = + computeAddrForCallArg(dl, DAG, VA, StackPtr, isTailCall, SPDiff); - MemOpChains.push_back(LowerMemOpCallTo(Chain, StackPtr, Arg, - dl, DAG, VA, Flags)); + SDValue Store = DAG.getStore(Chain, dl, Arg, DstAddr, DstInfo); + MemOpChains.push_back(Store); } } @@ -2303,15 +2631,14 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const TargetMachine &TM = getTargetMachine(); const Module *Mod = MF.getFunction().getParent(); - const GlobalValue *GV = nullptr; + const GlobalValue *GVal = nullptr; if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) - GV = G->getGlobal(); + GVal = G->getGlobal(); bool isStub = - !TM.shouldAssumeDSOLocal(*Mod, GV) && Subtarget->isTargetMachO(); + !TM.shouldAssumeDSOLocal(*Mod, GVal) && Subtarget->isTargetMachO(); bool isARMFunc = !Subtarget->isThumb() || (isStub && !Subtarget->isMClass()); bool isLocalARMFunc = false; - ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>(); auto PtrVt = getPointerTy(DAG.getDataLayout()); if (Subtarget->genLongCalls()) { @@ -2321,36 +2648,58 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // those, the target's already in a register, so we don't need to do // anything extra. if (isa<GlobalAddressSDNode>(Callee)) { - // Create a constant pool entry for the callee address - unsigned ARMPCLabelIndex = AFI->createPICLabelUId(); - ARMConstantPoolValue *CPV = - ARMConstantPoolConstant::Create(GV, ARMPCLabelIndex, ARMCP::CPValue, 0); - - // Get the address of the callee into a register - SDValue CPAddr = DAG.getTargetConstantPool(CPV, PtrVt, 4); - CPAddr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, CPAddr); - Callee = DAG.getLoad( - PtrVt, dl, DAG.getEntryNode(), CPAddr, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction())); + // When generating execute-only code we use movw movt pair. + // Currently execute-only is only available for architectures that + // support movw movt, so we are safe to assume that. 
+ if (Subtarget->genExecuteOnly()) { + assert(Subtarget->useMovt() && + "long-calls with execute-only requires movt and movw!"); + ++NumMovwMovt; + Callee = DAG.getNode(ARMISD::Wrapper, dl, PtrVt, + DAG.getTargetGlobalAddress(GVal, dl, PtrVt)); + } else { + // Create a constant pool entry for the callee address + unsigned ARMPCLabelIndex = AFI->createPICLabelUId(); + ARMConstantPoolValue *CPV = ARMConstantPoolConstant::Create( + GVal, ARMPCLabelIndex, ARMCP::CPValue, 0); + + // Get the address of the callee into a register + SDValue Addr = DAG.getTargetConstantPool(CPV, PtrVt, Align(4)); + Addr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, Addr); + Callee = DAG.getLoad( + PtrVt, dl, DAG.getEntryNode(), Addr, + MachinePointerInfo::getConstantPool(DAG.getMachineFunction())); + } } else if (ExternalSymbolSDNode *S=dyn_cast<ExternalSymbolSDNode>(Callee)) { const char *Sym = S->getSymbol(); - // Create a constant pool entry for the callee address - unsigned ARMPCLabelIndex = AFI->createPICLabelUId(); - ARMConstantPoolValue *CPV = - ARMConstantPoolSymbol::Create(*DAG.getContext(), Sym, - ARMPCLabelIndex, 0); - // Get the address of the callee into a register - SDValue CPAddr = DAG.getTargetConstantPool(CPV, PtrVt, 4); - CPAddr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, CPAddr); - Callee = DAG.getLoad( - PtrVt, dl, DAG.getEntryNode(), CPAddr, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction())); + // When generating execute-only code we use movw movt pair. + // Currently execute-only is only available for architectures that + // support movw movt, so we are safe to assume that. + if (Subtarget->genExecuteOnly()) { + assert(Subtarget->useMovt() && + "long-calls with execute-only requires movt and movw!"); + ++NumMovwMovt; + Callee = DAG.getNode(ARMISD::Wrapper, dl, PtrVt, + DAG.getTargetGlobalAddress(GVal, dl, PtrVt)); + } else { + // Create a constant pool entry for the callee address + unsigned ARMPCLabelIndex = AFI->createPICLabelUId(); + ARMConstantPoolValue *CPV = ARMConstantPoolSymbol::Create( + *DAG.getContext(), Sym, ARMPCLabelIndex, 0); + + // Get the address of the callee into a register + SDValue Addr = DAG.getTargetConstantPool(CPV, PtrVt, Align(4)); + Addr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, Addr); + Callee = DAG.getLoad( + PtrVt, dl, DAG.getEntryNode(), Addr, + MachinePointerInfo::getConstantPool(DAG.getMachineFunction())); + } } } else if (isa<GlobalAddressSDNode>(Callee)) { if (!PreferIndirect) { isDirect = true; - bool isDef = GV->isStrongDefinitionForLinker(); + bool isDef = GVal->isStrongDefinitionForLinker(); // ARM call to a local ARM function is predicable. 
isLocalARMFunc = !Subtarget->isThumb() && (isDef || !ARMInterworking); @@ -2359,21 +2708,21 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, assert(Subtarget->isTargetMachO() && "WrapperPIC use on non-MachO?"); Callee = DAG.getNode( ARMISD::WrapperPIC, dl, PtrVt, - DAG.getTargetGlobalAddress(GV, dl, PtrVt, 0, ARMII::MO_NONLAZY)); + DAG.getTargetGlobalAddress(GVal, dl, PtrVt, 0, ARMII::MO_NONLAZY)); Callee = DAG.getLoad( PtrVt, dl, DAG.getEntryNode(), Callee, - MachinePointerInfo::getGOT(DAG.getMachineFunction()), - /* Alignment = */ 0, MachineMemOperand::MODereferenceable | - MachineMemOperand::MOInvariant); + MachinePointerInfo::getGOT(DAG.getMachineFunction()), MaybeAlign(), + MachineMemOperand::MODereferenceable | + MachineMemOperand::MOInvariant); } else if (Subtarget->isTargetCOFF()) { assert(Subtarget->isTargetWindows() && "Windows is the only supported COFF target"); unsigned TargetFlags = ARMII::MO_NO_FLAG; - if (GV->hasDLLImportStorageClass()) + if (GVal->hasDLLImportStorageClass()) TargetFlags = ARMII::MO_DLLIMPORT; - else if (!TM.shouldAssumeDSOLocal(*GV->getParent(), GV)) + else if (!TM.shouldAssumeDSOLocal(*GVal->getParent(), GVal)) TargetFlags = ARMII::MO_COFFSTUB; - Callee = DAG.getTargetGlobalAddress(GV, dl, PtrVt, /*offset=*/0, + Callee = DAG.getTargetGlobalAddress(GVal, dl, PtrVt, /*offset=*/0, TargetFlags); if (TargetFlags & (ARMII::MO_DLLIMPORT | ARMII::MO_COFFSTUB)) Callee = @@ -2381,7 +2730,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, DAG.getNode(ARMISD::Wrapper, dl, PtrVt, Callee), MachinePointerInfo::getGOT(DAG.getMachineFunction())); } else { - Callee = DAG.getTargetGlobalAddress(GV, dl, PtrVt, 0, 0); + Callee = DAG.getTargetGlobalAddress(GVal, dl, PtrVt, 0, 0); } } } else if (ExternalSymbolSDNode *S = dyn_cast<ExternalSymbolSDNode>(Callee)) { @@ -2393,7 +2742,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, ARMConstantPoolValue *CPV = ARMConstantPoolSymbol::Create(*DAG.getContext(), Sym, ARMPCLabelIndex, 4); - SDValue CPAddr = DAG.getTargetConstantPool(CPV, PtrVt, 4); + SDValue CPAddr = DAG.getTargetConstantPool(CPV, PtrVt, Align(4)); CPAddr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, CPAddr); Callee = DAG.getLoad( PtrVt, dl, DAG.getEntryNode(), CPAddr, @@ -2405,10 +2754,33 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } } + if (isCmseNSCall) { + assert(!isARMFunc && !isDirect && + "Cannot handle call to ARM function or direct call"); + if (NumBytes > 0) { + DiagnosticInfoUnsupported Diag(DAG.getMachineFunction().getFunction(), + "call to non-secure function would " + "require passing arguments on stack", + dl.getDebugLoc()); + DAG.getContext()->diagnose(Diag); + } + if (isStructRet) { + DiagnosticInfoUnsupported Diag( + DAG.getMachineFunction().getFunction(), + "call to non-secure function would return value through pointer", + dl.getDebugLoc()); + DAG.getContext()->diagnose(Diag); + } + } + // FIXME: handle tail calls differently. unsigned CallOpc; if (Subtarget->isThumb()) { - if ((!isDirect || isARMFunc) && !Subtarget->hasV5TOps()) + if (GuardWithBTI) + CallOpc = ARMISD::t2CALL_BTI; + else if (isCmseNSCall) + CallOpc = ARMISD::tSECALL; + else if ((!isDirect || isARMFunc) && !Subtarget->hasV5TOps()) CallOpc = ARMISD::CALL_NOLINK; else CallOpc = ARMISD::CALL; @@ -2424,10 +2796,23 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, CallOpc = isLocalARMFunc ? 
ARMISD::CALL_PRED : ARMISD::CALL; } + // We don't usually want to end the call-sequence here because we would tidy + // the frame up *after* the call, however in the ABI-changing tail-call case + // we've carefully laid out the parameters so that when sp is reset they'll be + // in the correct location. + if (isTailCall && !isSibCall) { + Chain = DAG.getCALLSEQ_END(Chain, 0, 0, InFlag, dl); + InFlag = Chain.getValue(1); + } + std::vector<SDValue> Ops; Ops.push_back(Chain); Ops.push_back(Callee); + if (isTailCall) { + Ops.push_back(DAG.getTargetConstant(SPDiff, dl, MVT::i32)); + } + // Add argument registers to the end of the list so that they are known live // into the call. for (unsigned i = 0, e = RegsToPass.size(); i != e; ++i) @@ -2435,25 +2820,23 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, RegsToPass[i].second.getValueType())); // Add a register mask operand representing the call-preserved registers. - if (!isTailCall) { - const uint32_t *Mask; - const ARMBaseRegisterInfo *ARI = Subtarget->getRegisterInfo(); - if (isThisReturn) { - // For 'this' returns, use the R0-preserving mask if applicable - Mask = ARI->getThisReturnPreservedMask(MF, CallConv); - if (!Mask) { - // Set isThisReturn to false if the calling convention is not one that - // allows 'returned' to be modeled in this way, so LowerCallResult does - // not try to pass 'this' straight through - isThisReturn = false; - Mask = ARI->getCallPreservedMask(MF, CallConv); - } - } else + const uint32_t *Mask; + const ARMBaseRegisterInfo *ARI = Subtarget->getRegisterInfo(); + if (isThisReturn) { + // For 'this' returns, use the R0-preserving mask if applicable + Mask = ARI->getThisReturnPreservedMask(MF, CallConv); + if (!Mask) { + // Set isThisReturn to false if the calling convention is not one that + // allows 'returned' to be modeled in this way, so LowerCallResult does + // not try to pass 'this' straight through + isThisReturn = false; Mask = ARI->getCallPreservedMask(MF, CallConv); + } + } else + Mask = ARI->getCallPreservedMask(MF, CallConv); - assert(Mask && "Missing call preserved mask for calling convention"); - Ops.push_back(DAG.getRegisterMask(Mask)); - } + assert(Mask && "Missing call preserved mask for calling convention"); + Ops.push_back(DAG.getRegisterMask(Mask)); if (InFlag.getNode()) Ops.push_back(InFlag); @@ -2468,11 +2851,18 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Returns a chain and a flag for retval copy to use. Chain = DAG.getNode(CallOpc, dl, NodeTys, Ops); + DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge); InFlag = Chain.getValue(1); DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo)); - Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(NumBytes, dl, true), - DAG.getIntPtrConstant(0, dl, true), InFlag, dl); + // If we're guaranteeing tail-calls will be honoured, the callee must + // pop its own argument stack on return. But this call is *not* a tail call so + // we need to undo that after it returns to restore the status-quo. + bool TailCallOpt = getTargetMachine().Options.GuaranteedTailCallOpt; + uint64_t CalleePopBytes = + canGuaranteeTCO(CallConv, TailCallOpt) ? alignTo(NumBytes, 16) : -1ULL; + + Chain = DAG.getCALLSEQ_END(Chain, NumBytes, CalleePopBytes, InFlag, dl); if (!Ins.empty()) InFlag = Chain.getValue(1); @@ -2488,15 +2878,15 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, /// and then confiscate the rest of the parameter registers to insure /// this. 
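// --- Illustrative sketch, not part of the committed diff above ---
// With GuaranteedTailCallOpt the callee pops its own argument area, so the
// callee-pop amount is the outgoing-argument size rounded up to the 16-byte
// stack alignment (alignTo(NumBytes, 16) above). Standalone model of that
// round-up; the helper name is invented for illustration.
#include <cstdint>
#include <cassert>

static uint64_t alignToPow2(uint64_t Value, uint64_t Align) {
  return (Value + Align - 1) & ~(Align - 1); // Align must be a power of two
}

int main() {
  assert(alignToPow2(0, 16) == 0);
  assert(alignToPow2(20, 16) == 32); // 20 bytes of args -> callee pops 32
  assert(alignToPow2(32, 16) == 32);
}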
void ARMTargetLowering::HandleByVal(CCState *State, unsigned &Size, - unsigned Align) const { + Align Alignment) const { // Byval (as with any stack) slots are always at least 4 byte aligned. - Align = std::max(Align, 4U); + Alignment = std::max(Alignment, Align(4)); unsigned Reg = State->AllocateReg(GPRArgRegs); if (!Reg) return; - unsigned AlignInRegs = Align / 4; + unsigned AlignInRegs = Alignment.value() / 4; unsigned Waste = (ARM::R4 - Reg) % AlignInRegs; for (unsigned i = 0; i < Waste; ++i) Reg = State->AllocateReg(GPRArgRegs); @@ -2547,8 +2937,8 @@ bool MatchingStackOffset(SDValue Arg, unsigned Offset, ISD::ArgFlagsTy Flags, unsigned Bytes = Arg.getValueSizeInBits() / 8; int FI = std::numeric_limits<int>::max(); if (Arg.getOpcode() == ISD::CopyFromReg) { - unsigned VR = cast<RegisterSDNode>(Arg.getOperand(1))->getReg(); - if (!Register::isVirtualRegister(VR)) + Register VR = cast<RegisterSDNode>(Arg.getOperand(1))->getReg(); + if (!VR.isVirtual()) return false; MachineInstr *Def = MRI->getVRegDef(VR); if (!Def) @@ -2600,9 +2990,17 @@ bool ARMTargetLowering::IsEligibleForTailCallOptimization( // Indirect tail calls cannot be optimized for Thumb1 if the args // to the call take up r0-r3. The reason is that there are no legal registers // left to hold the pointer to the function to be called. - if (Subtarget->isThumb1Only() && Outs.size() >= 4 && - (!isa<GlobalAddressSDNode>(Callee.getNode()) || isIndirect)) - return false; + // Similarly, if the function uses return address sign and authentication, + // r12 is needed to hold the PAC and is not available to hold the callee + // address. + if (Outs.size() >= 4 && + (!isa<GlobalAddressSDNode>(Callee.getNode()) || isIndirect)) { + if (Subtarget->isThumb1Only()) + return false; + // Conservatively assume the function spills LR. + if (MF.getInfo<ARMFunctionInfo>()->shouldSignReturnAddress(true)) + return false; + } // Look for obvious safe cases to perform tail call optimization that do not // require ABI changes. This is what gcc calls sibcall. @@ -2613,6 +3011,9 @@ bool ARMTargetLowering::IsEligibleForTailCallOptimization( if (CallerF.hasFnAttribute("interrupt")) return false; + if (canGuaranteeTCO(CalleeCC, getTargetMachine().Options.GuaranteedTailCallOpt)) + return CalleeCC == CallerCC; + // Also avoid sibcall optimization if either caller or callee uses struct // return semantics. if (isCalleeStructRet || isCallerStructRet) @@ -2635,9 +3036,11 @@ bool ARMTargetLowering::IsEligibleForTailCallOptimization( // Check that the call results are passed in the same way. LLVMContext &C = *DAG.getContext(); - if (!CCState::resultsCompatible(CalleeCC, CallerCC, MF, C, Ins, - CCAssignFnForReturn(CalleeCC, isVarArg), - CCAssignFnForReturn(CallerCC, isVarArg))) + if (!CCState::resultsCompatible( + getEffectiveCallingConv(CalleeCC, isVarArg), + getEffectiveCallingConv(CallerCC, CallerF.isVarArg()), MF, C, Ins, + CCAssignFnForReturn(CalleeCC, isVarArg), + CCAssignFnForReturn(CallerCC, CallerF.isVarArg()))) return false; // The callee has to preserve all registers the caller needs to preserve. const ARMBaseRegisterInfo *TRI = Subtarget->getRegisterInfo(); @@ -2678,7 +3081,7 @@ bool ARMTargetLowering::IsEligibleForTailCallOptimization( ISD::ArgFlagsTy Flags = Outs[realArgIdx].Flags; if (VA.getLocInfo() == CCValAssign::Indirect) return false; - if (VA.needsCustom()) { + if (VA.needsCustom() && (RegVT == MVT::f64 || RegVT == MVT::v2f64)) { // f64 and vector types are split into multiple registers or // register/stack-slot combinations. 
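// --- Illustrative sketch, not part of the committed diff above ---
// HandleByVal (above) measures byval alignment in 4-byte GPR units and skips
// argument registers until the byval starts on a suitably aligned register.
// A standalone model using indices 0..3 for r0..r3; names are illustrative.
#include <cassert>

static unsigned byvalWasteRegs(unsigned FirstFreeReg, unsigned AlignBytes) {
  unsigned Clamped = AlignBytes < 4 ? 4 : AlignBytes; // at least 4, as above
  unsigned AlignInRegs = Clamped / 4;
  const unsigned R4 = 4;                              // one past r3
  return (R4 - FirstFreeReg) % AlignInRegs;           // the Waste count above
}

int main() {
  assert(byvalWasteRegs(1, 8) == 1); // r1 free, 8-byte byval: skip r1, use r2+r3
  assert(byvalWasteRegs(0, 8) == 0); // an r0/r1 pair is already aligned
  assert(byvalWasteRegs(2, 4) == 0); // 4-byte alignment never wastes a register
}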
The types will not match // the registers; give up on memory f64 refs until we figure @@ -2777,6 +3180,17 @@ ARMTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>(); AFI->setReturnRegsCount(RVLocs.size()); + // Report error if cmse entry function returns structure through first ptr arg. + if (AFI->isCmseNSEntryFunction() && MF.getFunction().hasStructRetAttr()) { + // Note: using an empty SDLoc(), as the first line of the function is a + // better place to report than the last line. + DiagnosticInfoUnsupported Diag( + DAG.getMachineFunction().getFunction(), + "secure entry function would return value through pointer", + SDLoc().getDebugLoc()); + DAG.getContext()->diagnose(Diag); + } + // Copy the result values into the output registers. for (unsigned i = 0, realRVLocIdx = 0; i != RVLocs.size(); @@ -2819,7 +3233,24 @@ ARMTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, break; } - if (VA.needsCustom()) { + // Mask f16 arguments if this is a CMSE nonsecure entry. + auto RetVT = Outs[realRVLocIdx].ArgVT; + if (AFI->isCmseNSEntryFunction() && (RetVT == MVT::f16)) { + if (VA.needsCustom() && VA.getValVT() == MVT::f16) { + Arg = MoveFromHPR(dl, DAG, VA.getLocVT(), VA.getValVT(), Arg); + } else { + auto LocBits = VA.getLocVT().getSizeInBits(); + auto MaskValue = APInt::getLowBitsSet(LocBits, RetVT.getSizeInBits()); + SDValue Mask = + DAG.getConstant(MaskValue, dl, MVT::getIntegerVT(LocBits)); + Arg = DAG.getNode(ISD::BITCAST, dl, MVT::getIntegerVT(LocBits), Arg); + Arg = DAG.getNode(ISD::AND, dl, MVT::getIntegerVT(LocBits), Arg, Mask); + Arg = DAG.getNode(ISD::BITCAST, dl, VA.getLocVT(), Arg); + } + } + + if (VA.needsCustom() && + (VA.getLocVT() == MVT::v2f64 || VA.getLocVT() == MVT::f64)) { if (VA.getLocVT() == MVT::v2f64) { // Extract the first half and return it in two registers. SDValue Half = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, Arg, @@ -2827,15 +3258,15 @@ ARMTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, SDValue HalfGPRs = DAG.getNode(ARMISD::VMOVRRD, dl, DAG.getVTList(MVT::i32, MVT::i32), Half); - Chain = DAG.getCopyToReg(Chain, dl, VA.getLocReg(), - HalfGPRs.getValue(isLittleEndian ? 0 : 1), - Flag); + Chain = + DAG.getCopyToReg(Chain, dl, VA.getLocReg(), + HalfGPRs.getValue(isLittleEndian ? 0 : 1), Flag); Flag = Chain.getValue(1); RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT())); VA = RVLocs[++i]; // skip ahead to next loc - Chain = DAG.getCopyToReg(Chain, dl, VA.getLocReg(), - HalfGPRs.getValue(isLittleEndian ? 1 : 0), - Flag); + Chain = + DAG.getCopyToReg(Chain, dl, VA.getLocReg(), + HalfGPRs.getValue(isLittleEndian ? 1 : 0), Flag); Flag = Chain.getValue(1); RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT())); VA = RVLocs[++i]; // skip ahead to next loc @@ -2849,22 +3280,20 @@ ARMTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, SDValue fmrrd = DAG.getNode(ARMISD::VMOVRRD, dl, DAG.getVTList(MVT::i32, MVT::i32), Arg); Chain = DAG.getCopyToReg(Chain, dl, VA.getLocReg(), - fmrrd.getValue(isLittleEndian ? 0 : 1), - Flag); + fmrrd.getValue(isLittleEndian ? 0 : 1), Flag); Flag = Chain.getValue(1); RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT())); VA = RVLocs[++i]; // skip ahead to next loc Chain = DAG.getCopyToReg(Chain, dl, VA.getLocReg(), - fmrrd.getValue(isLittleEndian ? 1 : 0), - Flag); + fmrrd.getValue(isLittleEndian ? 
1 : 0), Flag); } else Chain = DAG.getCopyToReg(Chain, dl, VA.getLocReg(), Arg, Flag); // Guarantee that all emitted copies are // stuck together, avoiding something bad. Flag = Chain.getValue(1); - RetOps.push_back(DAG.getRegister(VA.getLocReg(), - ReturnF16 ? MVT::f16 : VA.getLocVT())); + RetOps.push_back(DAG.getRegister( + VA.getLocReg(), ReturnF16 ? Arg.getValueType() : VA.getLocVT())); } const ARMBaseRegisterInfo *TRI = Subtarget->getRegisterInfo(); const MCPhysReg *I = @@ -2898,7 +3327,9 @@ ARMTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, return LowerInterruptReturn(RetOps, dl, DAG); } - return DAG.getNode(ARMISD::RET_FLAG, dl, MVT::Other, RetOps); + ARMISD::NodeType RetNode = AFI->isCmseNSEntryFunction() ? ARMISD::SERET_FLAG : + ARMISD::RET_FLAG; + return DAG.getNode(RetNode, dl, MVT::Other, RetOps); } bool ARMTargetLowering::isUsedByReturnOnly(SDNode *N, SDValue &Chain) const { @@ -2919,26 +3350,24 @@ bool ARMTargetLowering::isUsedByReturnOnly(SDNode *N, SDValue &Chain) const { SDNode *VMov = Copy; // f64 returned in a pair of GPRs. SmallPtrSet<SDNode*, 2> Copies; - for (SDNode::use_iterator UI = VMov->use_begin(), UE = VMov->use_end(); - UI != UE; ++UI) { - if (UI->getOpcode() != ISD::CopyToReg) + for (SDNode *U : VMov->uses()) { + if (U->getOpcode() != ISD::CopyToReg) return false; - Copies.insert(*UI); + Copies.insert(U); } if (Copies.size() > 2) return false; - for (SDNode::use_iterator UI = VMov->use_begin(), UE = VMov->use_end(); - UI != UE; ++UI) { - SDValue UseChain = UI->getOperand(0); + for (SDNode *U : VMov->uses()) { + SDValue UseChain = U->getOperand(0); if (Copies.count(UseChain.getNode())) // Second CopyToReg - Copy = *UI; + Copy = U; else { // We are at the top of this chain. // If the copy has a glue operand, we conservatively assume it // isn't safe to perform a tail call. - if (UI->getOperand(UI->getNumOperands()-1).getValueType() == MVT::Glue) + if (U->getOperand(U->getNumOperands() - 1).getValueType() == MVT::Glue) return false; // First CopyToReg TCChain = UseChain; @@ -2961,10 +3390,9 @@ bool ARMTargetLowering::isUsedByReturnOnly(SDNode *N, SDValue &Chain) const { } bool HasRet = false; - for (SDNode::use_iterator UI = Copy->use_begin(), UE = Copy->use_end(); - UI != UE; ++UI) { - if (UI->getOpcode() != ARMISD::RET_FLAG && - UI->getOpcode() != ARMISD::INTRET_FLAG) + for (const SDNode *U : Copy->uses()) { + if (U->getOpcode() != ARMISD::RET_FLAG && + U->getOpcode() != ARMISD::INTRET_FLAG) return false; HasRet = true; } @@ -3039,12 +3467,16 @@ SDValue ARMTargetLowering::LowerConstantPool(SDValue Op, return LowerGlobalAddress(GA, DAG); } + // The 16-bit ADR instruction can only encode offsets that are multiples of 4, + // so we need to align to at least 4 bytes when we don't have 32-bit ADR. 
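// --- Illustrative sketch, not part of the committed diff above ---
// The CMSE nonsecure-entry return path above builds
// APInt::getLowBitsSet(LocBits, 16) and ANDs it in so only the half-precision
// payload leaves in the low 16 bits of the return register. Scalar model:
#include <cstdint>
#include <cassert>

static uint32_t maskF16Return(uint32_t RetReg) {
  const uint32_t Mask = (1u << 16) - 1; // getLowBitsSet(32, 16) == 0x0000FFFF
  return RetReg & Mask;
}

int main() {
  uint32_t Reg = 0xDEAD3C00u;                // upper half holds stale bits
  assert(maskF16Return(Reg) == 0x00003C00u); // 0x3C00 is half-precision 1.0
}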
+ Align CPAlign = CP->getAlign(); + if (Subtarget->isThumb1Only()) + CPAlign = std::max(CPAlign, Align(4)); if (CP->isMachineConstantPoolEntry()) - Res = DAG.getTargetConstantPool(CP->getMachineCPVal(), PtrVT, - CP->getAlignment()); + Res = + DAG.getTargetConstantPool(CP->getMachineCPVal(), PtrVT, CPAlign); else - Res = DAG.getTargetConstantPool(CP->getConstVal(), PtrVT, - CP->getAlignment()); + Res = DAG.getTargetConstantPool(CP->getConstVal(), PtrVT, CPAlign); return DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, Res); } @@ -3063,14 +3495,14 @@ SDValue ARMTargetLowering::LowerBlockAddress(SDValue Op, SDValue CPAddr; bool IsPositionIndependent = isPositionIndependent() || Subtarget->isROPI(); if (!IsPositionIndependent) { - CPAddr = DAG.getTargetConstantPool(BA, PtrVT, 4); + CPAddr = DAG.getTargetConstantPool(BA, PtrVT, Align(4)); } else { unsigned PCAdj = Subtarget->isThumb() ? 4 : 8; ARMPCLabelIndex = AFI->createPICLabelUId(); ARMConstantPoolValue *CPV = ARMConstantPoolConstant::Create(BA, ARMPCLabelIndex, ARMCP::CPBlockAddress, PCAdj); - CPAddr = DAG.getTargetConstantPool(CPV, PtrVT, 4); + CPAddr = DAG.getTargetConstantPool(CPV, PtrVT, Align(4)); } CPAddr = DAG.getNode(ARMISD::Wrapper, DL, PtrVT, CPAddr); SDValue Result = DAG.getLoad( @@ -3122,8 +3554,7 @@ ARMTargetLowering::LowerGlobalTLSAddressDarwin(SDValue Op, SDValue Chain = DAG.getEntryNode(); SDValue FuncTLVGet = DAG.getLoad( MVT::i32, DL, Chain, DescAddr, - MachinePointerInfo::getGOT(DAG.getMachineFunction()), - /* Alignment = */ 4, + MachinePointerInfo::getGOT(DAG.getMachineFunction()), Align(4), MachineMemOperand::MONonTemporal | MachineMemOperand::MODereferenceable | MachineMemOperand::MOInvariant); Chain = FuncTLVGet.getValue(1); @@ -3199,8 +3630,9 @@ ARMTargetLowering::LowerGlobalTLSAddressWindows(SDValue Op, const auto *GA = cast<GlobalAddressSDNode>(Op); auto *CPV = ARMConstantPoolConstant::Create(GA->getGlobal(), ARMCP::SECREL); SDValue Offset = DAG.getLoad( - PtrVT, DL, Chain, DAG.getNode(ARMISD::Wrapper, DL, MVT::i32, - DAG.getTargetConstantPool(CPV, PtrVT, 4)), + PtrVT, DL, Chain, + DAG.getNode(ARMISD::Wrapper, DL, MVT::i32, + DAG.getTargetConstantPool(CPV, PtrVT, Align(4))), MachinePointerInfo::getConstantPool(DAG.getMachineFunction())); return DAG.getNode(ISD::ADD, DL, PtrVT, TLS, Offset); @@ -3219,7 +3651,7 @@ ARMTargetLowering::LowerToTLSGeneralDynamicModel(GlobalAddressSDNode *GA, ARMConstantPoolValue *CPV = ARMConstantPoolConstant::Create(GA->getGlobal(), ARMPCLabelIndex, ARMCP::CPValue, PCAdj, ARMCP::TLSGD, true); - SDValue Argument = DAG.getTargetConstantPool(CPV, PtrVT, 4); + SDValue Argument = DAG.getTargetConstantPool(CPV, PtrVT, Align(4)); Argument = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, Argument); Argument = DAG.getLoad( PtrVT, dl, DAG.getEntryNode(), Argument, @@ -3270,7 +3702,7 @@ ARMTargetLowering::LowerToTLSExecModels(GlobalAddressSDNode *GA, ARMConstantPoolConstant::Create(GA->getGlobal(), ARMPCLabelIndex, ARMCP::CPValue, PCAdj, ARMCP::GOTTPOFF, true); - Offset = DAG.getTargetConstantPool(CPV, PtrVT, 4); + Offset = DAG.getTargetConstantPool(CPV, PtrVT, Align(4)); Offset = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, Offset); Offset = DAG.getLoad( PtrVT, dl, Chain, Offset, @@ -3288,7 +3720,7 @@ ARMTargetLowering::LowerToTLSExecModels(GlobalAddressSDNode *GA, assert(model == TLSModel::LocalExec); ARMConstantPoolValue *CPV = ARMConstantPoolConstant::Create(GV, ARMCP::TPOFF); - Offset = DAG.getTargetConstantPool(CPV, PtrVT, 4); + Offset = DAG.getTargetConstantPool(CPV, PtrVT, Align(4)); Offset = 
DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, Offset); Offset = DAG.getLoad( PtrVT, dl, Chain, Offset, @@ -3330,14 +3762,11 @@ ARMTargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const { /// Return true if all users of V are within function F, looking through /// ConstantExprs. static bool allUsersAreInFunction(const Value *V, const Function *F) { - SmallVector<const User*,4> Worklist; - for (auto *U : V->users()) - Worklist.push_back(U); + SmallVector<const User*,4> Worklist(V->users()); while (!Worklist.empty()) { auto *U = Worklist.pop_back_val(); if (isa<ConstantExpr>(U)) { - for (auto *UU : U->users()) - Worklist.push_back(UU); + append_range(Worklist, U->users()); continue; } @@ -3380,7 +3809,7 @@ static SDValue promoteToConstantPool(const ARMTargetLowering *TLI, // from .data to .text. This is not allowed in position-independent code. auto *Init = GVar->getInitializer(); if ((TLI->isPositionIndependent() || TLI->getSubtarget()->isROPI()) && - Init->needsRelocation()) + Init->needsDynamicRelocation()) return SDValue(); // The constant islands pass can only really deal with alignment requests @@ -3391,11 +3820,11 @@ static SDValue promoteToConstantPool(const ARMTargetLowering *TLI, // that are strings for simplicity. auto *CDAInit = dyn_cast<ConstantDataArray>(Init); unsigned Size = DAG.getDataLayout().getTypeAllocSize(Init->getType()); - unsigned Align = DAG.getDataLayout().getPreferredAlignment(GVar); + Align PrefAlign = DAG.getDataLayout().getPreferredAlign(GVar); unsigned RequiredPadding = 4 - (Size % 4); bool PaddingPossible = RequiredPadding == 4 || (CDAInit && CDAInit->isString()); - if (!PaddingPossible || Align > 4 || Size > ConstpoolPromotionMaxSize || + if (!PaddingPossible || PrefAlign > 4 || Size > ConstpoolPromotionMaxSize || Size == 0) return SDValue(); @@ -3434,8 +3863,7 @@ static SDValue promoteToConstantPool(const ARMTargetLowering *TLI, } auto CPVal = ARMConstantPoolConstant::Create(GVar, Init); - SDValue CPAddr = - DAG.getTargetConstantPool(CPVal, PtrVT, /*Align=*/4); + SDValue CPAddr = DAG.getTargetConstantPool(CPVal, PtrVT, Align(4)); if (!AFI->getGlobalsPromotedToConstantPool().count(GVar)) { AFI->markGlobalAsPromotedToConstantPool(GVar); AFI->setPromotedConstpoolIncrease(AFI->getPromotedConstpoolIncrease() + @@ -3447,7 +3875,7 @@ static SDValue promoteToConstantPool(const ARMTargetLowering *TLI, bool ARMTargetLowering::isReadOnly(const GlobalValue *GV) const { if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(GV)) - if (!(GV = GA->getBaseObject())) + if (!(GV = GA->getAliaseeObject())) return false; if (const auto *V = dyn_cast<GlobalVariable>(GV)) return V->isConstant(); @@ -3505,7 +3933,7 @@ SDValue ARMTargetLowering::LowerGlobalAddressELF(SDValue Op, } else { // use literal pool for address constant ARMConstantPoolValue *CPV = ARMConstantPoolConstant::Create(GV, ARMCP::SBREL); - SDValue CPAddr = DAG.getTargetConstantPool(CPV, PtrVT, 4); + SDValue CPAddr = DAG.getTargetConstantPool(CPV, PtrVT, Align(4)); CPAddr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, CPAddr); RelAddr = DAG.getLoad( PtrVT, dl, DAG.getEntryNode(), CPAddr, @@ -3525,7 +3953,7 @@ SDValue ARMTargetLowering::LowerGlobalAddressELF(SDValue Op, return DAG.getNode(ARMISD::Wrapper, dl, PtrVT, DAG.getTargetGlobalAddress(GV, dl, PtrVT)); } else { - SDValue CPAddr = DAG.getTargetConstantPool(GV, PtrVT, 4); + SDValue CPAddr = DAG.getTargetConstantPool(GV, PtrVT, Align(4)); CPAddr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, CPAddr); return DAG.getLoad( PtrVT, dl, DAG.getEntryNode(), 
CPAddr, @@ -3633,10 +4061,10 @@ SDValue ARMTargetLowering::LowerINTRINSIC_VOID( ARI->getCallPreservedMask(DAG.getMachineFunction(), CallingConv::C); assert(Mask && "Missing call preserved mask for calling convention"); // Mark LR an implicit live-in. - unsigned Reg = MF.addLiveIn(ARM::LR, getRegClassFor(MVT::i32)); + Register Reg = MF.addLiveIn(ARM::LR, getRegClassFor(MVT::i32)); SDValue ReturnAddress = DAG.getCopyFromReg(DAG.getEntryNode(), dl, Reg, PtrVT); - std::vector<EVT> ResultTys = {MVT::Other, MVT::Glue}; + constexpr EVT ResultTys[] = {MVT::Other, MVT::Glue}; SDValue Callee = DAG.getTargetExternalSymbol("\01__gnu_mcount_nc", PtrVT, 0); SDValue RegisterMask = DAG.getRegisterMask(Mask); @@ -3720,7 +4148,7 @@ ARMTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG, ARMConstantPoolValue *CPV = ARMConstantPoolConstant::Create(&MF.getFunction(), ARMPCLabelIndex, ARMCP::CPLSDA, PCAdj); - CPAddr = DAG.getTargetConstantPool(CPV, PtrVT, 4); + CPAddr = DAG.getTargetConstantPool(CPV, PtrVT, Align(4)); CPAddr = DAG.getNode(ARMISD::Wrapper, dl, MVT::i32, CPAddr); SDValue Result = DAG.getLoad( PtrVT, dl, DAG.getEntryNode(), CPAddr, @@ -3782,6 +4210,15 @@ ARMTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG, case Intrinsic::arm_mve_pred_v2i: return DAG.getNode(ARMISD::PREDICATE_CAST, SDLoc(Op), Op.getValueType(), Op.getOperand(1)); + case Intrinsic::arm_mve_vreinterpretq: + return DAG.getNode(ARMISD::VECTOR_REG_CAST, SDLoc(Op), Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::arm_mve_lsll: + return DAG.getNode(ARMISD::LSLL, SDLoc(Op), Op->getVTList(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::arm_mve_asrl: + return DAG.getNode(ARMISD::ASRL, SDLoc(Op), Op->getVTList(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); } } @@ -3878,7 +4315,7 @@ SDValue ARMTargetLowering::GetF64FormalArgument(CCValAssign &VA, RC = &ARM::GPRRegClass; // Transform the arguments stored in physical registers into virtual ones. - unsigned Reg = MF.addLiveIn(VA.getLocReg(), RC); + Register Reg = MF.addLiveIn(VA.getLocReg(), RC); SDValue ArgValue = DAG.getCopyFromReg(Root, dl, Reg, MVT::i32); SDValue ArgValue2; @@ -3948,7 +4385,7 @@ int ARMTargetLowering::StoreByValRegs(CCState &CCInfo, SelectionDAG &DAG, AFI->isThumb1OnlyFunction() ? 
&ARM::tGPRRegClass : &ARM::GPRRegClass; for (unsigned Reg = RBegin, i = 0; Reg < REnd; ++Reg, ++i) { - unsigned VReg = MF.addLiveIn(Reg, RC); + Register VReg = MF.addLiveIn(Reg, RC); SDValue Val = DAG.getCopyFromReg(Chain, dl, VReg, MVT::i32); SDValue Store = DAG.getStore(Val.getValue(1), dl, Val, FIN, MachinePointerInfo(OrigArg, 4 * i)); @@ -3982,6 +4419,42 @@ void ARMTargetLowering::VarArgStyleRegisters(CCState &CCInfo, SelectionDAG &DAG, AFI->setVarArgsFrameIndex(FrameIndex); } +bool ARMTargetLowering::splitValueIntoRegisterParts( + SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts, + unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const { + bool IsABIRegCopy = CC.has_value(); + EVT ValueVT = Val.getValueType(); + if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) && + PartVT == MVT::f32) { + unsigned ValueBits = ValueVT.getSizeInBits(); + unsigned PartBits = PartVT.getSizeInBits(); + Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val); + Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::getIntegerVT(PartBits), Val); + Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val); + Parts[0] = Val; + return true; + } + return false; +} + +SDValue ARMTargetLowering::joinRegisterPartsIntoValue( + SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts, + MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const { + bool IsABIRegCopy = CC.has_value(); + if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) && + PartVT == MVT::f32) { + unsigned ValueBits = ValueVT.getSizeInBits(); + unsigned PartBits = PartVT.getSizeInBits(); + SDValue Val = Parts[0]; + + Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(PartBits), Val); + Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::getIntegerVT(ValueBits), Val); + Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val); + return Val; + } + return SDValue(); +} + SDValue ARMTargetLowering::LowerFormalArguments( SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl, @@ -4035,7 +4508,7 @@ SDValue ARMTargetLowering::LowerFormalArguments( int lastInsIndex = -1; if (isVarArg && MFI.hasVAStart()) { unsigned RegIdx = CCInfo.getFirstUnallocated(GPRArgRegs); - if (RegIdx != array_lengthof(GPRArgRegs)) + if (RegIdx != std::size(GPRArgRegs)) ArgRegBegin = std::min(ArgRegBegin, (unsigned)GPRArgRegs[RegIdx]); } @@ -4054,44 +4527,41 @@ SDValue ARMTargetLowering::LowerFormalArguments( if (VA.isRegLoc()) { EVT RegVT = VA.getLocVT(); - if (VA.needsCustom()) { + if (VA.needsCustom() && VA.getLocVT() == MVT::v2f64) { // f64 and vector types are split up into multiple registers or // combinations of registers and stack slots. 
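// --- Illustrative sketch, not part of the committed diff above ---
// splitValueIntoRegisterParts / joinRegisterPartsIntoValue (above) carry an
// f16/bf16 value in an f32 part by widening the 16-bit container into the low
// bits of a 32-bit one and bitcasting. Host-side model using memcpy as the
// bitcast; the any-extend is modelled as a zero-extend.
#include <cstdint>
#include <cstring>
#include <cassert>

static float splitHalfIntoF32Part(uint16_t HalfBits) {
  uint32_t Wide = HalfBits;           // ANY_EXTEND i16 -> i32
  float Part;
  std::memcpy(&Part, &Wide, 4);       // BITCAST i32 -> f32
  return Part;
}

static uint16_t joinF32PartIntoHalf(float Part) {
  uint32_t Wide;
  std::memcpy(&Wide, &Part, 4);       // BITCAST f32 -> i32
  return static_cast<uint16_t>(Wide); // TRUNCATE i32 -> i16
}

int main() {
  uint16_t Half = 0x3C00;             // half-precision 1.0
  assert(joinF32PartIntoHalf(splitHalfIntoF32Part(Half)) == Half);
}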
- if (VA.getLocVT() == MVT::v2f64) { - SDValue ArgValue1 = GetF64FormalArgument(VA, ArgLocs[++i], - Chain, DAG, dl); - VA = ArgLocs[++i]; // skip ahead to next loc - SDValue ArgValue2; - if (VA.isMemLoc()) { - int FI = MFI.CreateFixedObject(8, VA.getLocMemOffset(), true); - SDValue FIN = DAG.getFrameIndex(FI, PtrVT); - ArgValue2 = DAG.getLoad(MVT::f64, dl, Chain, FIN, - MachinePointerInfo::getFixedStack( - DAG.getMachineFunction(), FI)); - } else { - ArgValue2 = GetF64FormalArgument(VA, ArgLocs[++i], - Chain, DAG, dl); - } - ArgValue = DAG.getNode(ISD::UNDEF, dl, MVT::v2f64); - ArgValue = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v2f64, - ArgValue, ArgValue1, - DAG.getIntPtrConstant(0, dl)); - ArgValue = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v2f64, - ArgValue, ArgValue2, - DAG.getIntPtrConstant(1, dl)); - } else - ArgValue = GetF64FormalArgument(VA, ArgLocs[++i], Chain, DAG, dl); + SDValue ArgValue1 = + GetF64FormalArgument(VA, ArgLocs[++i], Chain, DAG, dl); + VA = ArgLocs[++i]; // skip ahead to next loc + SDValue ArgValue2; + if (VA.isMemLoc()) { + int FI = MFI.CreateFixedObject(8, VA.getLocMemOffset(), true); + SDValue FIN = DAG.getFrameIndex(FI, PtrVT); + ArgValue2 = DAG.getLoad( + MVT::f64, dl, Chain, FIN, + MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI)); + } else { + ArgValue2 = GetF64FormalArgument(VA, ArgLocs[++i], Chain, DAG, dl); + } + ArgValue = DAG.getNode(ISD::UNDEF, dl, MVT::v2f64); + ArgValue = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v2f64, ArgValue, + ArgValue1, DAG.getIntPtrConstant(0, dl)); + ArgValue = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v2f64, ArgValue, + ArgValue2, DAG.getIntPtrConstant(1, dl)); + } else if (VA.needsCustom() && VA.getLocVT() == MVT::f64) { + ArgValue = GetF64FormalArgument(VA, ArgLocs[++i], Chain, DAG, dl); } else { const TargetRegisterClass *RC; - - if (RegVT == MVT::f16) + if (RegVT == MVT::f16 || RegVT == MVT::bf16) RC = &ARM::HPRRegClass; else if (RegVT == MVT::f32) RC = &ARM::SPRRegClass; - else if (RegVT == MVT::f64 || RegVT == MVT::v4f16) + else if (RegVT == MVT::f64 || RegVT == MVT::v4f16 || + RegVT == MVT::v4bf16) RC = &ARM::DPRRegClass; - else if (RegVT == MVT::v2f64 || RegVT == MVT::v8f16) + else if (RegVT == MVT::v2f64 || RegVT == MVT::v8f16 || + RegVT == MVT::v8bf16) RC = &ARM::QPRRegClass; else if (RegVT == MVT::i32) RC = AFI->isThumb1OnlyFunction() ? &ARM::tGPRRegClass @@ -4100,7 +4570,7 @@ SDValue ARMTargetLowering::LowerFormalArguments( llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering"); // Transform the arguments in physical registers into virtual ones. - unsigned Reg = MF.addLiveIn(VA.getLocReg(), RC); + Register Reg = MF.addLiveIn(VA.getLocReg(), RC); ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT); // If this value is passed in r0 and has the returned attribute (e.g. @@ -4131,9 +4601,16 @@ SDValue ARMTargetLowering::LowerFormalArguments( break; } + // f16 arguments have their size extended to 4 bytes and passed as if they + // had been copied to the LSBs of a 32-bit register. + // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI) + if (VA.needsCustom() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) + ArgValue = MoveToHPR(dl, DAG, VA.getLocVT(), VA.getValVT(), ArgValue); + InVals.push_back(ArgValue); } else { // VA.isRegLoc() - // sanity check + // Only arguments passed on the stack should make it here. 
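// --- Illustrative sketch, not part of the committed diff above ---
// The custom f64/v2f64 argument handling above splits each double across a
// pair of i32 locations (GetF64FormalArgument) and rebuilds v2f64 from two
// halves. Bit-level model of one f64 split and reassembly, lo/hi ordered as
// in the little-endian case:
#include <cstdint>
#include <cstring>
#include <cassert>

static void splitF64(double D, uint32_t &Lo, uint32_t &Hi) {
  uint64_t Bits;
  std::memcpy(&Bits, &D, 8);
  Lo = static_cast<uint32_t>(Bits);
  Hi = static_cast<uint32_t>(Bits >> 32);
}

static double joinF64(uint32_t Lo, uint32_t Hi) {
  uint64_t Bits = (uint64_t(Hi) << 32) | Lo;
  double D;
  std::memcpy(&D, &Bits, 8);
  return D;
}

int main() {
  uint32_t Lo, Hi;
  splitF64(1.0, Lo, Hi);
  assert(joinF64(Lo, Hi) == 1.0);
}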
assert(VA.isMemLoc()); assert(VA.getValVT() != MVT::i64 && "i64 should already be lowered"); @@ -4176,12 +4653,35 @@ SDValue ARMTargetLowering::LowerFormalArguments( } // varargs - if (isVarArg && MFI.hasVAStart()) - VarArgStyleRegisters(CCInfo, DAG, dl, Chain, - CCInfo.getNextStackOffset(), + if (isVarArg && MFI.hasVAStart()) { + VarArgStyleRegisters(CCInfo, DAG, dl, Chain, CCInfo.getNextStackOffset(), TotalArgRegsSaveSize); + if (AFI->isCmseNSEntryFunction()) { + DiagnosticInfoUnsupported Diag( + DAG.getMachineFunction().getFunction(), + "secure entry function must not be variadic", dl.getDebugLoc()); + DAG.getContext()->diagnose(Diag); + } + } + + unsigned StackArgSize = CCInfo.getNextStackOffset(); + bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt; + if (canGuaranteeTCO(CallConv, TailCallOpt)) { + // The only way to guarantee a tail call is if the callee restores its + // argument area, but it must also keep the stack aligned when doing so. + const DataLayout &DL = DAG.getDataLayout(); + StackArgSize = alignTo(StackArgSize, DL.getStackAlignment()); - AFI->setArgumentStackSize(CCInfo.getNextStackOffset()); + AFI->setArgumentStackToRestore(StackArgSize); + } + AFI->setArgumentStackSize(StackArgSize); + + if (CCInfo.getNextStackOffset() > 0 && AFI->isCmseNSEntryFunction()) { + DiagnosticInfoUnsupported Diag( + DAG.getMachineFunction().getFunction(), + "secure entry function requires arguments on stack", dl.getDebugLoc()); + DAG.getContext()->diagnose(Diag); + } return Chain; } @@ -4546,24 +5046,49 @@ SDValue ARMTargetLowering::LowerUnsignedALUO(SDValue Op, return DAG.getNode(ISD::MERGE_VALUES, dl, VTs, Value, Overflow); } -static SDValue LowerSADDSUBSAT(SDValue Op, SelectionDAG &DAG, - const ARMSubtarget *Subtarget) { +static SDValue LowerADDSUBSAT(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { EVT VT = Op.getValueType(); - if (!Subtarget->hasDSP()) + if (!Subtarget->hasV6Ops() || !Subtarget->hasDSP()) return SDValue(); if (!VT.isSimple()) return SDValue(); unsigned NewOpcode; - bool IsAdd = Op->getOpcode() == ISD::SADDSAT; switch (VT.getSimpleVT().SimpleTy) { default: return SDValue(); case MVT::i8: - NewOpcode = IsAdd ? ARMISD::QADD8b : ARMISD::QSUB8b; + switch (Op->getOpcode()) { + case ISD::UADDSAT: + NewOpcode = ARMISD::UQADD8b; + break; + case ISD::SADDSAT: + NewOpcode = ARMISD::QADD8b; + break; + case ISD::USUBSAT: + NewOpcode = ARMISD::UQSUB8b; + break; + case ISD::SSUBSAT: + NewOpcode = ARMISD::QSUB8b; + break; + } break; case MVT::i16: - NewOpcode = IsAdd ? ARMISD::QADD16b : ARMISD::QSUB16b; + switch (Op->getOpcode()) { + case ISD::UADDSAT: + NewOpcode = ARMISD::UQADD16b; + break; + case ISD::SADDSAT: + NewOpcode = ARMISD::QADD16b; + break; + case ISD::USUBSAT: + NewOpcode = ARMISD::UQSUB16b; + break; + case ISD::SSUBSAT: + NewOpcode = ARMISD::QSUB16b; + break; + } break; } @@ -4743,16 +5268,6 @@ static bool isLowerSaturate(const SDValue LHS, const SDValue RHS, ((K == RHS && K == TrueVal) || (K == LHS && K == FalseVal))); } -// Similar to isLowerSaturate(), but checks for upper-saturating conditions. -static bool isUpperSaturate(const SDValue LHS, const SDValue RHS, - const SDValue TrueVal, const SDValue FalseVal, - const ISD::CondCode CC, const SDValue K) { - return (isGTorGE(CC) && - ((K == RHS && K == TrueVal) || (K == LHS && K == FalseVal))) || - (isLTorLE(CC) && - ((K == LHS && K == TrueVal) || (K == RHS && K == FalseVal))); -} - // Check if two chained conditionals could be converted into SSAT or USAT. 
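// --- Illustrative sketch, not part of the committed diff above ---
// LowerADDSUBSAT (above) maps the i8/i16 ISD::[US]ADDSAT / [US]SUBSAT nodes
// onto the DSP-style QADD8b/UQADD8b/QSUB8b/UQSUB8b operations. Scalar model of
// what one saturating lane computes; helper names are invented here.
#include <algorithm>
#include <cstdint>
#include <cassert>

static int8_t saturatingAddS8(int8_t A, int8_t B) {
  int Sum = int(A) + int(B);
  return static_cast<int8_t>(std::clamp(Sum, -128, 127)); // signed, QADD8-style
}

static uint8_t saturatingSubU8(uint8_t A, uint8_t B) {
  int Diff = int(A) - int(B);
  return static_cast<uint8_t>(std::clamp(Diff, 0, 255));  // unsigned, UQSUB8-style
}

int main() {
  assert(saturatingAddS8(100, 100) == 127); // would wrap to -56 without saturation
  assert(saturatingSubU8(10, 20) == 0);     // would wrap to 246 without saturation
}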
// // SSAT can replace a set of two conditional selectors that bound a number to an @@ -4764,101 +5279,68 @@ static bool isUpperSaturate(const SDValue LHS, const SDValue RHS, // x < k ? (x < -k ? -k : x) : k // etc. // -// USAT works similarily to SSAT but bounds on the interval [0, k] where k + 1 is -// a power of 2. +// LLVM canonicalizes these to either a min(max()) or a max(min()) +// pattern. This function tries to match one of these and will return a SSAT +// node if successful. // -// It returns true if the conversion can be done, false otherwise. -// Additionally, the variable is returned in parameter V, the constant in K and -// usat is set to true if the conditional represents an unsigned saturation -static bool isSaturatingConditional(const SDValue &Op, SDValue &V, - uint64_t &K, bool &usat) { - SDValue LHS1 = Op.getOperand(0); - SDValue RHS1 = Op.getOperand(1); +// USAT works similarily to SSAT but bounds on the interval [0, k] where k + 1 +// is a power of 2. +static SDValue LowerSaturatingConditional(SDValue Op, SelectionDAG &DAG) { + EVT VT = Op.getValueType(); + SDValue V1 = Op.getOperand(0); + SDValue K1 = Op.getOperand(1); SDValue TrueVal1 = Op.getOperand(2); SDValue FalseVal1 = Op.getOperand(3); ISD::CondCode CC1 = cast<CondCodeSDNode>(Op.getOperand(4))->get(); const SDValue Op2 = isa<ConstantSDNode>(TrueVal1) ? FalseVal1 : TrueVal1; if (Op2.getOpcode() != ISD::SELECT_CC) - return false; + return SDValue(); - SDValue LHS2 = Op2.getOperand(0); - SDValue RHS2 = Op2.getOperand(1); + SDValue V2 = Op2.getOperand(0); + SDValue K2 = Op2.getOperand(1); SDValue TrueVal2 = Op2.getOperand(2); SDValue FalseVal2 = Op2.getOperand(3); ISD::CondCode CC2 = cast<CondCodeSDNode>(Op2.getOperand(4))->get(); - // Find out which are the constants and which are the variables - // in each conditional - SDValue *K1 = isa<ConstantSDNode>(LHS1) ? &LHS1 : isa<ConstantSDNode>(RHS1) - ? &RHS1 - : nullptr; - SDValue *K2 = isa<ConstantSDNode>(LHS2) ? &LHS2 : isa<ConstantSDNode>(RHS2) - ? &RHS2 - : nullptr; - SDValue K2Tmp = isa<ConstantSDNode>(TrueVal2) ? TrueVal2 : FalseVal2; - SDValue V1Tmp = (K1 && *K1 == LHS1) ? RHS1 : LHS1; - SDValue V2Tmp = (K2 && *K2 == LHS2) ? RHS2 : LHS2; - SDValue V2 = (K2Tmp == TrueVal2) ? FalseVal2 : TrueVal2; - - // We must detect cases where the original operations worked with 16- or - // 8-bit values. In such case, V2Tmp != V2 because the comparison operations - // must work with sign-extended values but the select operations return - // the original non-extended value. - SDValue V2TmpReg = V2Tmp; - if (V2Tmp->getOpcode() == ISD::SIGN_EXTEND_INREG) - V2TmpReg = V2Tmp->getOperand(0); - - // Check that the registers and the constants have the correct values - // in both conditionals - if (!K1 || !K2 || *K1 == Op2 || *K2 != K2Tmp || V1Tmp != V2Tmp || - V2TmpReg != V2) - return false; + SDValue V1Tmp = V1; + SDValue V2Tmp = V2; - // Figure out which conditional is saturating the lower/upper bound. - const SDValue *LowerCheckOp = - isLowerSaturate(LHS1, RHS1, TrueVal1, FalseVal1, CC1, *K1) - ? &Op - : isLowerSaturate(LHS2, RHS2, TrueVal2, FalseVal2, CC2, *K2) - ? &Op2 - : nullptr; - const SDValue *UpperCheckOp = - isUpperSaturate(LHS1, RHS1, TrueVal1, FalseVal1, CC1, *K1) - ? &Op - : isUpperSaturate(LHS2, RHS2, TrueVal2, FalseVal2, CC2, *K2) - ? 
&Op2 - : nullptr; - - if (!UpperCheckOp || !LowerCheckOp || LowerCheckOp == UpperCheckOp) - return false; + // Check that the registers and the constants match a max(min()) or min(max()) + // pattern + if (V1Tmp != TrueVal1 || V2Tmp != TrueVal2 || K1 != FalseVal1 || + K2 != FalseVal2 || + !((isGTorGE(CC1) && isLTorLE(CC2)) || (isLTorLE(CC1) && isGTorGE(CC2)))) + return SDValue(); // Check that the constant in the lower-bound check is // the opposite of the constant in the upper-bound check // in 1's complement. - int64_t Val1 = cast<ConstantSDNode>(*K1)->getSExtValue(); - int64_t Val2 = cast<ConstantSDNode>(*K2)->getSExtValue(); + if (!isa<ConstantSDNode>(K1) || !isa<ConstantSDNode>(K2)) + return SDValue(); + + int64_t Val1 = cast<ConstantSDNode>(K1)->getSExtValue(); + int64_t Val2 = cast<ConstantSDNode>(K2)->getSExtValue(); int64_t PosVal = std::max(Val1, Val2); int64_t NegVal = std::min(Val1, Val2); - if (((Val1 > Val2 && UpperCheckOp == &Op) || - (Val1 < Val2 && UpperCheckOp == &Op2)) && - isPowerOf2_64(PosVal + 1)) { - - // Handle the difference between USAT (unsigned) and SSAT (signed) saturation - if (Val1 == ~Val2) - usat = false; - else if (NegVal == 0) - usat = true; - else - return false; - - V = V2; - K = (uint64_t)PosVal; // At this point, PosVal is guaranteed to be positive + if (!((Val1 > Val2 && isLTorLE(CC1)) || (Val1 < Val2 && isLTorLE(CC2))) || + !isPowerOf2_64(PosVal + 1)) + return SDValue(); - return true; - } + // Handle the difference between USAT (unsigned) and SSAT (signed) + // saturation + // At this point, PosVal is guaranteed to be positive + uint64_t K = PosVal; + SDLoc dl(Op); + if (Val1 == ~Val2) + return DAG.getNode(ARMISD::SSAT, dl, VT, V2Tmp, + DAG.getConstant(countTrailingOnes(K), dl, VT)); + if (NegVal == 0) + return DAG.getNode(ARMISD::USAT, dl, VT, V2Tmp, + DAG.getConstant(countTrailingOnes(K), dl, VT)); - return false; + return SDValue(); } // Check if a condition of the type x < k ? k : x can be converted into a @@ -4918,18 +5400,9 @@ SDValue ARMTargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); // Try to convert two saturating conditional selects into a single SSAT - SDValue SatValue; - uint64_t SatConstant; - bool SatUSat; - if (((!Subtarget->isThumb() && Subtarget->hasV6Ops()) || Subtarget->isThumb2()) && - isSaturatingConditional(Op, SatValue, SatConstant, SatUSat)) { - if (SatUSat) - return DAG.getNode(ARMISD::USAT, dl, VT, SatValue, - DAG.getConstant(countTrailingOnes(SatConstant), dl, VT)); - else - return DAG.getNode(ARMISD::SSAT, dl, VT, SatValue, - DAG.getConstant(countTrailingOnes(SatConstant), dl, VT)); - } + if ((!Subtarget->isThumb() && Subtarget->hasV6Ops()) || Subtarget->isThumb2()) + if (SDValue SatValue = LowerSaturatingConditional(Op, DAG)) + return SatValue; // Try to convert expressions of the form x < k ? k : x (and similar forms) // into more efficient bit operations, which is possible when k is 0 or -1 @@ -4938,6 +5411,7 @@ SDValue ARMTargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const { // instructions. 
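// --- Illustrative sketch, not part of the committed diff above ---
// LowerSaturatingConditional (above) only fires when the chained selects clamp
// to [~K, K] with K+1 a power of two (signed, SSAT) or to [0, K] (unsigned,
// USAT). Model of the clamp semantics being recognised; names are illustrative.
#include <algorithm>
#include <cstdint>
#include <cassert>

static int32_t ssatClamp(int32_t X, int64_t K) {     // requires K + 1 == 2^B
  return static_cast<int32_t>(std::clamp<int64_t>(X, ~K /* == -(K+1) */, K));
}

static int32_t usatClamp(int32_t X, int64_t K) {     // requires K + 1 == 2^B
  return static_cast<int32_t>(std::clamp<int64_t>(X, 0, K));
}

int main() {
  // K = 127: K + 1 = 128 = 2^7 and the signed lower bound ~127 == -128.
  assert(ssatClamp(300, 127) == 127 && ssatClamp(-300, 127) == -128);
  // K = 255: unsigned clamp to [0, 255].
  assert(usatClamp(-5, 255) == 0 && usatClamp(300, 255) == 255);
}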
// Only allow this transformation on full-width (32-bit) operations SDValue LowerSatConstant; + SDValue SatValue; if (VT == MVT::i32 && isLowerSaturatingConditional(Op, SatValue, LowerSatConstant)) { SDValue ShiftV = DAG.getNode(ISD::SRA, dl, VT, SatValue, @@ -4995,8 +5469,6 @@ SDValue ARMTargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const { std::swap(TVal, FVal); CC = ISD::getSetCCInverse(CC, LHS.getValueType()); } - if (TVal == 0) - TrueVal = DAG.getRegister(ARM::ZR, MVT::i32); // Drops F's value because we can get it by inverting/negating TVal. FalseVal = TrueVal; @@ -5118,7 +5590,7 @@ static SDValue bitcastf32Toi32(SDValue Op, SelectionDAG &DAG) { if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Op)) return DAG.getLoad(MVT::i32, SDLoc(Op), Ld->getChain(), Ld->getBasePtr(), - Ld->getPointerInfo(), Ld->getAlignment(), + Ld->getPointerInfo(), Ld->getAlign(), Ld->getMemOperand()->getFlags()); llvm_unreachable("Unknown VFP cmp argument!"); @@ -5138,14 +5610,14 @@ static void expandf64Toi32(SDValue Op, SelectionDAG &DAG, SDValue Ptr = Ld->getBasePtr(); RetVal1 = DAG.getLoad(MVT::i32, dl, Ld->getChain(), Ptr, Ld->getPointerInfo(), - Ld->getAlignment(), Ld->getMemOperand()->getFlags()); + Ld->getAlign(), Ld->getMemOperand()->getFlags()); EVT PtrType = Ptr.getValueType(); - unsigned NewAlign = MinAlign(Ld->getAlignment(), 4); SDValue NewPtr = DAG.getNode(ISD::ADD, dl, PtrType, Ptr, DAG.getConstant(4, dl, PtrType)); RetVal2 = DAG.getLoad(MVT::i32, dl, Ld->getChain(), NewPtr, - Ld->getPointerInfo().getWithOffset(4), NewAlign, + Ld->getPointerInfo().getWithOffset(4), + commonAlignment(Ld->getAlign(), 4), Ld->getMemOperand()->getFlags()); return; } @@ -5372,8 +5844,7 @@ static SDValue LowerVectorFP_TO_INT(SDValue Op, SelectionDAG &DAG) { return DAG.UnrollVectorOp(Op.getNode()); } - const bool HasFullFP16 = - static_cast<const ARMSubtarget&>(DAG.getSubtarget()).hasFullFP16(); + const bool HasFullFP16 = DAG.getSubtarget<ARMSubtarget>().hasFullFP16(); EVT NewTy; const EVT OpTy = Op.getOperand(0).getValueType(); @@ -5432,6 +5903,43 @@ SDValue ARMTargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const { return Op; } +static SDValue LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + EVT VT = Op.getValueType(); + EVT ToVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); + EVT FromVT = Op.getOperand(0).getValueType(); + + if (VT == MVT::i32 && ToVT == MVT::i32 && FromVT == MVT::f32) + return Op; + if (VT == MVT::i32 && ToVT == MVT::i32 && FromVT == MVT::f64 && + Subtarget->hasFP64()) + return Op; + if (VT == MVT::i32 && ToVT == MVT::i32 && FromVT == MVT::f16 && + Subtarget->hasFullFP16()) + return Op; + if (VT == MVT::v4i32 && ToVT == MVT::i32 && FromVT == MVT::v4f32 && + Subtarget->hasMVEFloatOps()) + return Op; + if (VT == MVT::v8i16 && ToVT == MVT::i16 && FromVT == MVT::v8f16 && + Subtarget->hasMVEFloatOps()) + return Op; + + if (FromVT != MVT::v4f32 && FromVT != MVT::v8f16) + return SDValue(); + + SDLoc DL(Op); + bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT_SAT; + unsigned BW = ToVT.getScalarSizeInBits() - IsSigned; + SDValue CVT = DAG.getNode(Op.getOpcode(), DL, VT, Op.getOperand(0), + DAG.getValueType(VT.getScalarType())); + SDValue Max = DAG.getNode(IsSigned ? 
ISD::SMIN : ISD::UMIN, DL, VT, CVT, + DAG.getConstant((1 << BW) - 1, DL, VT)); + if (IsSigned) + Max = DAG.getNode(ISD::SMAX, DL, VT, Max, + DAG.getConstant(-(1 << BW), DL, VT)); + return Max; +} + static SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) { EVT VT = Op.getValueType(); SDLoc dl(Op); @@ -5446,8 +5954,7 @@ static SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) { Op.getOperand(0).getValueType() == MVT::v8i16) && "Invalid type for custom lowering!"); - const bool HasFullFP16 = - static_cast<const ARMSubtarget&>(DAG.getSubtarget()).hasFullFP16(); + const bool HasFullFP16 = DAG.getSubtarget<ARMSubtarget>().hasFullFP16(); EVT DestVecType; if (VT == MVT::v4f32) @@ -5599,7 +6106,7 @@ SDValue ARMTargetLowering::LowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const{ } // Return LR, which contains the return address. Mark it an implicit live-in. - unsigned Reg = MF.addLiveIn(ARM::LR, getRegClassFor(MVT::i32)); + Register Reg = MF.addLiveIn(ARM::LR, getRegClassFor(MVT::i32)); return DAG.getCopyFromReg(DAG.getEntryNode(), dl, Reg, VT); } @@ -5709,85 +6216,27 @@ static SDValue CombineVMOVDRRCandidateWithVecOp(const SDNode *BC, /// use a VMOVDRR or VMOVRRD node. This should not be done when the non-i64 /// operand type is illegal (e.g., v2f32 for a target that doesn't support /// vectors), since the legalizer won't know what to do with that. -static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG, - const ARMSubtarget *Subtarget) { +SDValue ARMTargetLowering::ExpandBITCAST(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) const { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); SDLoc dl(N); SDValue Op = N->getOperand(0); - // This function is only supposed to be called for i64 types, either as the - // source or destination of the bit convert. + // This function is only supposed to be called for i16 and i64 types, either + // as the source or destination of the bit convert. EVT SrcVT = Op.getValueType(); EVT DstVT = N->getValueType(0); - const bool HasFullFP16 = Subtarget->hasFullFP16(); - - if (SrcVT == MVT::f32 && DstVT == MVT::i32) { - // FullFP16: half values are passed in S-registers, and we don't - // need any of the bitcast and moves: - // - // t2: f32,ch = CopyFromReg t0, Register:f32 %0 - // t5: i32 = bitcast t2 - // t18: f16 = ARMISD::VMOVhr t5 - if (Op.getOpcode() != ISD::CopyFromReg || - Op.getValueType() != MVT::f32) - return SDValue(); - - auto Move = N->use_begin(); - if (Move->getOpcode() != ARMISD::VMOVhr) - return SDValue(); - - SDValue Ops[] = { Op.getOperand(0), Op.getOperand(1) }; - SDValue Copy = DAG.getNode(ISD::CopyFromReg, SDLoc(Op), MVT::f16, Ops); - DAG.ReplaceAllUsesWith(*Move, &Copy); - return Copy; - } - - if (SrcVT == MVT::i16 && DstVT == MVT::f16) { - if (!HasFullFP16) - return SDValue(); - // SoftFP: read half-precision arguments: - // - // t2: i32,ch = ... 
- // t7: i16 = truncate t2 <~~~~ Op - // t8: f16 = bitcast t7 <~~~~ N - // - if (Op.getOperand(0).getValueType() == MVT::i32) - return DAG.getNode(ARMISD::VMOVhr, SDLoc(Op), - MVT::f16, Op.getOperand(0)); - return SDValue(); - } + if ((SrcVT == MVT::i16 || SrcVT == MVT::i32) && + (DstVT == MVT::f16 || DstVT == MVT::bf16)) + return MoveToHPR(SDLoc(N), DAG, MVT::i32, DstVT.getSimpleVT(), + DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), MVT::i32, Op)); - // Half-precision return values - if (SrcVT == MVT::f16 && DstVT == MVT::i16) { - if (!HasFullFP16) - return SDValue(); - // - // t11: f16 = fadd t8, t10 - // t12: i16 = bitcast t11 <~~~ SDNode N - // t13: i32 = zero_extend t12 - // t16: ch,glue = CopyToReg t0, Register:i32 %r0, t13 - // t17: ch = ARMISD::RET_FLAG t16, Register:i32 %r0, t16:1 - // - // transform this into: - // - // t20: i32 = ARMISD::VMOVrh t11 - // t16: ch,glue = CopyToReg t0, Register:i32 %r0, t20 - // - auto ZeroExtend = N->use_begin(); - if (N->use_size() != 1 || ZeroExtend->getOpcode() != ISD::ZERO_EXTEND || - ZeroExtend->getValueType(0) != MVT::i32) - return SDValue(); - - auto Copy = ZeroExtend->use_begin(); - if (Copy->getOpcode() == ISD::CopyToReg && - Copy->use_begin()->getOpcode() == ARMISD::RET_FLAG) { - SDValue Cvt = DAG.getNode(ARMISD::VMOVrh, SDLoc(Op), MVT::i32, Op); - DAG.ReplaceAllUsesWith(*ZeroExtend, &Cvt); - return Cvt; - } - return SDValue(); - } + if ((DstVT == MVT::i16 || DstVT == MVT::i32) && + (SrcVT == MVT::f16 || SrcVT == MVT::bf16)) + return DAG.getNode( + ISD::TRUNCATE, SDLoc(N), DstVT, + MoveFromHPR(SDLoc(N), DAG, MVT::i32, SrcVT.getSimpleVT(), Op)); if (!(SrcVT == MVT::i64 || DstVT == MVT::i64)) return SDValue(); @@ -5923,23 +6372,69 @@ SDValue ARMTargetLowering::LowerShiftLeftParts(SDValue Op, return DAG.getMergeValues(Ops, dl); } -SDValue ARMTargetLowering::LowerFLT_ROUNDS_(SDValue Op, - SelectionDAG &DAG) const { +SDValue ARMTargetLowering::LowerGET_ROUNDING(SDValue Op, + SelectionDAG &DAG) const { // The rounding mode is in bits 23:22 of the FPSCR. // The ARM rounding mode value to FLT_ROUNDS mapping is 0->1, 1->2, 2->3, 3->0 // The formula we use to implement this is (((FPSCR + 1 << 22) >> 22) & 3) // so that the shift + and get folded into a bitfield extract. SDLoc dl(Op); - SDValue Ops[] = { DAG.getEntryNode(), - DAG.getConstant(Intrinsic::arm_get_fpscr, dl, MVT::i32) }; + SDValue Chain = Op.getOperand(0); + SDValue Ops[] = {Chain, + DAG.getConstant(Intrinsic::arm_get_fpscr, dl, MVT::i32)}; - SDValue FPSCR = DAG.getNode(ISD::INTRINSIC_W_CHAIN, dl, MVT::i32, Ops); + SDValue FPSCR = + DAG.getNode(ISD::INTRINSIC_W_CHAIN, dl, {MVT::i32, MVT::Other}, Ops); + Chain = FPSCR.getValue(1); SDValue FltRounds = DAG.getNode(ISD::ADD, dl, MVT::i32, FPSCR, DAG.getConstant(1U << 22, dl, MVT::i32)); SDValue RMODE = DAG.getNode(ISD::SRL, dl, MVT::i32, FltRounds, DAG.getConstant(22, dl, MVT::i32)); - return DAG.getNode(ISD::AND, dl, MVT::i32, RMODE, - DAG.getConstant(3, dl, MVT::i32)); + SDValue And = DAG.getNode(ISD::AND, dl, MVT::i32, RMODE, + DAG.getConstant(3, dl, MVT::i32)); + return DAG.getMergeValues({And, Chain}, dl); +} + +SDValue ARMTargetLowering::LowerSET_ROUNDING(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue Chain = Op->getOperand(0); + SDValue RMValue = Op->getOperand(1); + + // The rounding mode is in bits 23:22 of the FPSCR. + // The llvm.set.rounding argument value to ARM rounding mode value mapping + // is 0->3, 1->0, 2->1, 3->2. The formula we use to implement this is + // ((arg - 1) & 3) << 22). 
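// --- Illustrative sketch, not part of the committed diff above ---
// GET_ROUNDING remaps FPSCR[23:22] (ARM: 0=nearest, 1=+inf, 2=-inf, 3=zero) to
// the FLT_ROUNDS encoding (1, 2, 3, 0) with ((FPSCR + (1 << 22)) >> 22) & 3,
// and SET_ROUNDING inverts it with ((arg - 1) & 3) << 22. A standalone check
// that the two formulas round-trip:
#include <cstdint>
#include <cassert>

static uint32_t fltRoundsFromFPSCR(uint32_t FPSCR) {
  return ((FPSCR + (1u << 22)) >> 22) & 3u;
}

static uint32_t fpscrRMFieldFromFltRounds(uint32_t Arg) {
  return ((Arg - 1u) & 3u) << 22;
}

int main() {
  for (uint32_t RM = 0; RM < 4; ++RM) {
    uint32_t FPSCR = (RM << 22) | 0x13u;   // unrelated status bits set
    uint32_t Flt = fltRoundsFromFPSCR(FPSCR);
    assert(Flt == ((RM + 1) & 3));         // 0->1, 1->2, 2->3, 3->0
    assert(fpscrRMFieldFromFltRounds(Flt) == (RM << 22)); // inverse mapping
  }
}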
+ // + // It is expected that the argument of llvm.set.rounding is within the + // segment [0, 3], so NearestTiesToAway (4) is not handled here. It is + // responsibility of the code generated llvm.set.rounding to ensure this + // condition. + + // Calculate new value of FPSCR[23:22]. + RMValue = DAG.getNode(ISD::SUB, DL, MVT::i32, RMValue, + DAG.getConstant(1, DL, MVT::i32)); + RMValue = DAG.getNode(ISD::AND, DL, MVT::i32, RMValue, + DAG.getConstant(0x3, DL, MVT::i32)); + RMValue = DAG.getNode(ISD::SHL, DL, MVT::i32, RMValue, + DAG.getConstant(ARM::RoundingBitsPos, DL, MVT::i32)); + + // Get current value of FPSCR. + SDValue Ops[] = {Chain, + DAG.getConstant(Intrinsic::arm_get_fpscr, DL, MVT::i32)}; + SDValue FPSCR = + DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i32, MVT::Other}, Ops); + Chain = FPSCR.getValue(1); + FPSCR = FPSCR.getValue(0); + + // Put new rounding mode into FPSCR[23:22]. + const unsigned RMMask = ~(ARM::Rounding::rmMask << ARM::RoundingBitsPos); + FPSCR = DAG.getNode(ISD::AND, DL, MVT::i32, FPSCR, + DAG.getConstant(RMMask, DL, MVT::i32)); + FPSCR = DAG.getNode(ISD::OR, DL, MVT::i32, FPSCR, RMValue); + SDValue Ops2[] = { + Chain, DAG.getConstant(Intrinsic::arm_set_fpscr, DL, MVT::i32), FPSCR}; + return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2); } static SDValue LowerCTTZ(SDNode *N, SelectionDAG &DAG, @@ -6271,23 +6766,23 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG, if (ST->hasMVEFloatOps()) { Opc = ARMCC::NE; break; } else { - Invert = true; LLVM_FALLTHROUGH; + Invert = true; [[fallthrough]]; } case ISD::SETOEQ: case ISD::SETEQ: Opc = ARMCC::EQ; break; case ISD::SETOLT: - case ISD::SETLT: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETLT: Swap = true; [[fallthrough]]; case ISD::SETOGT: case ISD::SETGT: Opc = ARMCC::GT; break; case ISD::SETOLE: - case ISD::SETLE: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETLE: Swap = true; [[fallthrough]]; case ISD::SETOGE: case ISD::SETGE: Opc = ARMCC::GE; break; - case ISD::SETUGE: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETUGE: Swap = true; [[fallthrough]]; case ISD::SETULE: Invert = true; Opc = ARMCC::GT; break; - case ISD::SETUGT: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETUGT: Swap = true; [[fallthrough]]; case ISD::SETULT: Invert = true; Opc = ARMCC::GE; break; - case ISD::SETUEQ: Invert = true; LLVM_FALLTHROUGH; + case ISD::SETUEQ: Invert = true; [[fallthrough]]; case ISD::SETONE: { // Expand this to (OLT | OGT). SDValue TmpOp0 = DAG.getNode(ARMISD::VCMP, dl, CmpVT, Op1, Op0, @@ -6299,7 +6794,7 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG, Result = DAG.getNOT(dl, Result, VT); return Result; } - case ISD::SETUO: Invert = true; LLVM_FALLTHROUGH; + case ISD::SETUO: Invert = true; [[fallthrough]]; case ISD::SETO: { // Expand this to (OLT | OGE). 
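// --- Illustrative sketch, not part of the committed diff above ---
// The FP SETCC lowering above expands SETONE as (OLT | OGT) and SETO as
// (OLT | OGE), then reuses the inverted result for SETUEQ / SETUO. Scalar
// model of those predicates, including the NaN (unordered) cases:
#include <cassert>
#include <cmath>

static bool setONE(float A, float B) { return (A < B) || (A > B); }  // OLT | OGT
static bool setO(float A, float B)   { return (A < B) || (A >= B); } // OLT | OGE

int main() {
  float NaN = std::nanf("");
  assert(setONE(1.0f, 2.0f) && !setONE(1.0f, 1.0f) && !setONE(NaN, 1.0f));
  assert(setO(1.0f, 2.0f) && !setO(NaN, 1.0f) && !setO(1.0f, NaN));
  bool UEQ = !setONE(NaN, 1.0f); // SETUEQ: inverted "one", true on unordered
  bool UO  = !setO(NaN, 1.0f);   // SETUO: inverted "o", true on unordered
  assert(UEQ && UO);
}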
SDValue TmpOp0 = DAG.getNode(ARMISD::VCMP, dl, CmpVT, Op1, Op0, @@ -6320,16 +6815,16 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG, if (ST->hasMVEIntegerOps()) { Opc = ARMCC::NE; break; } else { - Invert = true; LLVM_FALLTHROUGH; + Invert = true; [[fallthrough]]; } case ISD::SETEQ: Opc = ARMCC::EQ; break; - case ISD::SETLT: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETLT: Swap = true; [[fallthrough]]; case ISD::SETGT: Opc = ARMCC::GT; break; - case ISD::SETLE: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETLE: Swap = true; [[fallthrough]]; case ISD::SETGE: Opc = ARMCC::GE; break; - case ISD::SETULT: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETULT: Swap = true; [[fallthrough]]; case ISD::SETUGT: Opc = ARMCC::HI; break; - case ISD::SETULE: Swap = true; LLVM_FALLTHROUGH; + case ISD::SETULE: Swap = true; [[fallthrough]]; case ISD::SETUGE: Opc = ARMCC::HS; break; } @@ -6361,25 +6856,25 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG, // If one of the operands is a constant vector zero, attempt to fold the // comparison to a specialized compare-against-zero form. - SDValue SingleOp; - if (ISD::isBuildVectorAllZeros(Op1.getNode())) - SingleOp = Op0; - else if (ISD::isBuildVectorAllZeros(Op0.getNode())) { + if (ISD::isBuildVectorAllZeros(Op0.getNode()) && + (Opc == ARMCC::GE || Opc == ARMCC::GT || Opc == ARMCC::EQ || + Opc == ARMCC::NE)) { if (Opc == ARMCC::GE) Opc = ARMCC::LE; else if (Opc == ARMCC::GT) Opc = ARMCC::LT; - SingleOp = Op1; + std::swap(Op0, Op1); } SDValue Result; - if (SingleOp.getNode()) { - Result = DAG.getNode(ARMISD::VCMPZ, dl, CmpVT, SingleOp, + if (ISD::isBuildVectorAllZeros(Op1.getNode()) && + (Opc == ARMCC::GE || Opc == ARMCC::GT || Opc == ARMCC::LE || + Opc == ARMCC::LT || Opc == ARMCC::NE || Opc == ARMCC::EQ)) + Result = DAG.getNode(ARMISD::VCMPZ, dl, CmpVT, Op0, DAG.getConstant(Opc, dl, MVT::i32)); - } else { + else Result = DAG.getNode(ARMISD::VCMP, dl, CmpVT, Op0, Op1, DAG.getConstant(Opc, dl, MVT::i32)); - } Result = DAG.getSExtOrTrunc(Result, dl, VT); @@ -6424,9 +6919,10 @@ static SDValue LowerSETCCCARRY(SDValue Op, SelectionDAG &DAG) { /// immediate" operand (e.g., VMOV). If so, return the encoded value. static SDValue isVMOVModifiedImm(uint64_t SplatBits, uint64_t SplatUndef, unsigned SplatBitSize, SelectionDAG &DAG, - const SDLoc &dl, EVT &VT, bool is128Bits, + const SDLoc &dl, EVT &VT, EVT VectorVT, VMOVModImmType type) { unsigned OpCmode, Imm; + bool is128Bits = VectorVT.is128BitVector(); // SplatBitSize is set to the smallest size that splats the vector, so a // zero vector will always have SplatBitSize == 8. However, NEON modified @@ -6530,12 +7026,10 @@ static SDValue isVMOVModifiedImm(uint64_t SplatBits, uint64_t SplatUndef, return SDValue(); // NEON has a 64-bit VMOV splat where each byte is either 0 or 0xff. uint64_t BitMask = 0xff; - uint64_t Val = 0; unsigned ImmMask = 1; Imm = 0; for (int ByteNum = 0; ByteNum < 8; ++ByteNum) { if (((SplatBits | SplatUndef) & BitMask) == BitMask) { - Val |= BitMask; Imm |= ImmMask; } else if ((SplatBits & BitMask) != 0) { return SDValue(); @@ -6544,9 +7038,18 @@ static SDValue isVMOVModifiedImm(uint64_t SplatBits, uint64_t SplatUndef, ImmMask <<= 1; } - if (DAG.getDataLayout().isBigEndian()) - // swap higher and lower 32 bit word - Imm = ((Imm & 0xf) << 4) | ((Imm & 0xf0) >> 4); + if (DAG.getDataLayout().isBigEndian()) { + // Reverse the order of elements within the vector. 
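// --- Illustrative sketch, not part of the committed diff above ---
// The 64-bit NEON VMOV splat handled above encodes a constant whose bytes are
// all 0x00 or 0xFF as one immediate bit per byte (Op=1, Cmode=1110).
// Simplified model that ignores the undef-bits handling of the real code:
#include <cstdint>
#include <optional>
#include <cassert>

static std::optional<uint8_t> encodeByteSplat(uint64_t Splat) {
  uint8_t Imm = 0;
  for (int ByteNum = 0; ByteNum < 8; ++ByteNum) {
    uint8_t Byte = (Splat >> (8 * ByteNum)) & 0xFF;
    if (Byte == 0xFF)
      Imm |= uint8_t(1u << ByteNum);
    else if (Byte != 0x00)
      return std::nullopt;        // not representable in this VMOV form
  }
  return Imm;
}

int main() {
  assert(encodeByteSplat(0x00FF00FF00FF00FFull) == uint8_t(0x55));
  assert(!encodeByteSplat(0x00FF00FF00FF00FEull).has_value());
}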
+ unsigned BytesPerElem = VectorVT.getScalarSizeInBits() / 8; + unsigned Mask = (1 << BytesPerElem) - 1; + unsigned NumElems = 8 / BytesPerElem; + unsigned NewImm = 0; + for (unsigned ElemNum = 0; ElemNum < NumElems; ++ElemNum) { + unsigned Elem = ((Imm >> ElemNum * BytesPerElem) & Mask); + NewImm |= Elem << (NumElems - ElemNum - 1) * BytesPerElem; + } + Imm = NewImm; + } // Op=1, Cmode=1110. OpCmode = 0x1e; @@ -6585,8 +7088,6 @@ SDValue ARMTargetLowering::LowerConstantFP(SDValue Op, SelectionDAG &DAG, case MVT::f64: { SDValue Lo = DAG.getConstant(INTVal.trunc(32), DL, MVT::i32); SDValue Hi = DAG.getConstant(INTVal.lshr(32).trunc(32), DL, MVT::i32); - if (!ST->isLittle()) - std::swap(Lo, Hi); return DAG.getNode(ARMISD::VMOVDRR, DL, MVT::f64, Lo, Hi); } case MVT::f32: @@ -6639,7 +7140,7 @@ SDValue ARMTargetLowering::LowerConstantFP(SDValue Op, SelectionDAG &DAG, // Try a VMOV.i32 (FIXME: i8, i16, or i64 could work too). SDValue NewVal = isVMOVModifiedImm(iVal & 0xffffffffU, 0, 32, DAG, SDLoc(Op), - VMovVT, false, VMOVModImm); + VMovVT, VT, VMOVModImm); if (NewVal != SDValue()) { SDLoc DL(Op); SDValue VecConstant = DAG.getNode(ARMISD::VMOVIMM, DL, VMovVT, @@ -6656,7 +7157,7 @@ SDValue ARMTargetLowering::LowerConstantFP(SDValue Op, SelectionDAG &DAG, // Finally, try a VMVN.i32 NewVal = isVMOVModifiedImm(~iVal & 0xffffffffU, 0, 32, DAG, SDLoc(Op), VMovVT, - false, VMVNModImm); + VT, VMVNModImm); if (NewVal != SDValue()) { SDLoc DL(Op); SDValue VecConstant = DAG.getNode(ARMISD::VMVNIMM, DL, VMovVT, NewVal); @@ -6740,35 +7241,6 @@ static bool isVEXTMask(ArrayRef<int> M, EVT VT, return true; } -/// isVREVMask - Check if a vector shuffle corresponds to a VREV -/// instruction with the specified blocksize. (The order of the elements -/// within each block of the vector is reversed.) -static bool isVREVMask(ArrayRef<int> M, EVT VT, unsigned BlockSize) { - assert((BlockSize==16 || BlockSize==32 || BlockSize==64) && - "Only possible block sizes for VREV are: 16, 32, 64"); - - unsigned EltSz = VT.getScalarSizeInBits(); - if (EltSz == 64) - return false; - - unsigned NumElts = VT.getVectorNumElements(); - unsigned BlockElts = M[0] + 1; - // If the first shuffle index is UNDEF, be optimistic. - if (M[0] < 0) - BlockElts = BlockSize / EltSz; - - if (BlockSize <= EltSz || BlockSize != BlockElts * EltSz) - return false; - - for (unsigned i = 0; i < NumElts; ++i) { - if (M[i] < 0) continue; // ignore UNDEF indices - if ((unsigned) M[i] != (i - i%BlockElts) + (BlockElts - 1 - i%BlockElts)) - return false; - } - - return true; -} - static bool isVTBLMask(ArrayRef<int> M, EVT VT) { // We can handle <8 x i8> vector shuffles. If the index in the mask is out of // range, then 0 is placed into the resulting vector. So pretty much any mask @@ -7041,11 +7513,33 @@ static bool isReverseMask(ArrayRef<int> M, EVT VT) { return true; } -static bool isVMOVNMask(ArrayRef<int> M, EVT VT, bool Top) { +static bool isTruncMask(ArrayRef<int> M, EVT VT, bool Top, bool SingleSource) { unsigned NumElts = VT.getVectorNumElements(); // Make sure the mask has the right size. if (NumElts != M.size() || (VT != MVT::v8i16 && VT != MVT::v16i8)) + return false; + + // Half-width truncation patterns (e.g. v4i32 -> v8i16): + // !Top && SingleSource: <0, 2, 4, 6, 0, 2, 4, 6> + // !Top && !SingleSource: <0, 2, 4, 6, 8, 10, 12, 14> + // Top && SingleSource: <1, 3, 5, 7, 1, 3, 5, 7> + // Top && !SingleSource: <1, 3, 5, 7, 9, 11, 13, 15> + int Ofs = Top ? 1 : 0; + int Upper = SingleSource ? 
0 : NumElts; + for (int i = 0, e = NumElts / 2; i != e; ++i) { + if (M[i] >= 0 && M[i] != (i * 2) + Ofs) return false; + if (M[i + e] >= 0 && M[i + e] != (i * 2) + Ofs + Upper) + return false; + } + return true; +} + +static bool isVMOVNMask(ArrayRef<int> M, EVT VT, bool Top, bool SingleSource) { + unsigned NumElts = VT.getVectorNumElements(); + // Make sure the mask has the right size. + if (NumElts != M.size() || (VT != MVT::v8i16 && VT != MVT::v16i8)) + return false; // If Top // Look for <0, N, 2, N+2, 4, N+4, ..>. @@ -7054,16 +7548,137 @@ static bool isVMOVNMask(ArrayRef<int> M, EVT VT, bool Top) { // Look for <0, N+1, 2, N+3, 4, N+5, ..> // This inserts Input1 into Input2 unsigned Offset = Top ? 0 : 1; - for (unsigned i = 0; i < NumElts; i+=2) { + unsigned N = SingleSource ? 0 : NumElts; + for (unsigned i = 0; i < NumElts; i += 2) { if (M[i] >= 0 && M[i] != (int)i) return false; - if (M[i+1] >= 0 && M[i+1] != (int)(NumElts + i + Offset)) + if (M[i + 1] >= 0 && M[i + 1] != (int)(N + i + Offset)) return false; } return true; } +static bool isVMOVNTruncMask(ArrayRef<int> M, EVT ToVT, bool rev) { + unsigned NumElts = ToVT.getVectorNumElements(); + if (NumElts != M.size()) + return false; + + // Test if the Trunc can be convertable to a VMOVN with this shuffle. We are + // looking for patterns of: + // !rev: 0 N/2 1 N/2+1 2 N/2+2 ... + // rev: N/2 0 N/2+1 1 N/2+2 2 ... + + unsigned Off0 = rev ? NumElts / 2 : 0; + unsigned Off1 = rev ? 0 : NumElts / 2; + for (unsigned i = 0; i < NumElts; i += 2) { + if (M[i] >= 0 && M[i] != (int)(Off0 + i / 2)) + return false; + if (M[i + 1] >= 0 && M[i + 1] != (int)(Off1 + i / 2)) + return false; + } + + return true; +} + +// Reconstruct an MVE VCVT from a BuildVector of scalar fptrunc, all extracted +// from a pair of inputs. For example: +// BUILDVECTOR(FP_ROUND(EXTRACT_ELT(X, 0), +// FP_ROUND(EXTRACT_ELT(Y, 0), +// FP_ROUND(EXTRACT_ELT(X, 1), +// FP_ROUND(EXTRACT_ELT(Y, 1), ...) +static SDValue LowerBuildVectorOfFPTrunc(SDValue BV, SelectionDAG &DAG, + const ARMSubtarget *ST) { + assert(BV.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!"); + if (!ST->hasMVEFloatOps()) + return SDValue(); + + SDLoc dl(BV); + EVT VT = BV.getValueType(); + if (VT != MVT::v8f16) + return SDValue(); + + // We are looking for a buildvector of fptrunc elements, where all the + // elements are interleavingly extracted from two sources. Check the first two + // items are valid enough and extract some info from them (they are checked + // properly in the loop below). + if (BV.getOperand(0).getOpcode() != ISD::FP_ROUND || + BV.getOperand(0).getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT || + BV.getOperand(0).getOperand(0).getConstantOperandVal(1) != 0) + return SDValue(); + if (BV.getOperand(1).getOpcode() != ISD::FP_ROUND || + BV.getOperand(1).getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT || + BV.getOperand(1).getOperand(0).getConstantOperandVal(1) != 0) + return SDValue(); + SDValue Op0 = BV.getOperand(0).getOperand(0).getOperand(0); + SDValue Op1 = BV.getOperand(1).getOperand(0).getOperand(0); + if (Op0.getValueType() != MVT::v4f32 || Op1.getValueType() != MVT::v4f32) + return SDValue(); + + // Check all the values in the BuildVector line up with our expectations. 
+ for (unsigned i = 1; i < 4; i++) { + auto Check = [](SDValue Trunc, SDValue Op, unsigned Idx) { + return Trunc.getOpcode() == ISD::FP_ROUND && + Trunc.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT && + Trunc.getOperand(0).getOperand(0) == Op && + Trunc.getOperand(0).getConstantOperandVal(1) == Idx; + }; + if (!Check(BV.getOperand(i * 2 + 0), Op0, i)) + return SDValue(); + if (!Check(BV.getOperand(i * 2 + 1), Op1, i)) + return SDValue(); + } + + SDValue N1 = DAG.getNode(ARMISD::VCVTN, dl, VT, DAG.getUNDEF(VT), Op0, + DAG.getConstant(0, dl, MVT::i32)); + return DAG.getNode(ARMISD::VCVTN, dl, VT, N1, Op1, + DAG.getConstant(1, dl, MVT::i32)); +} + +// Reconstruct an MVE VCVT from a BuildVector of scalar fpext, all extracted +// from a single input on alternating lanes. For example: +// BUILDVECTOR(FP_EXTEND(EXTRACT_ELT(X, 0), +// FP_EXTEND(EXTRACT_ELT(X, 2), +// FP_EXTEND(EXTRACT_ELT(X, 4), ...) +static SDValue LowerBuildVectorOfFPExt(SDValue BV, SelectionDAG &DAG, + const ARMSubtarget *ST) { + assert(BV.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!"); + if (!ST->hasMVEFloatOps()) + return SDValue(); + + SDLoc dl(BV); + EVT VT = BV.getValueType(); + if (VT != MVT::v4f32) + return SDValue(); + + // We are looking for a buildvector of fpext elements, where all the + // elements are alternating lanes from a single source. For example <0,2,4,6> + // or <1,3,5,7>. Check the first two items are valid enough and extract some + // info from them (they are checked properly in the loop below). + if (BV.getOperand(0).getOpcode() != ISD::FP_EXTEND || + BV.getOperand(0).getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + SDValue Op0 = BV.getOperand(0).getOperand(0).getOperand(0); + int Offset = BV.getOperand(0).getOperand(0).getConstantOperandVal(1); + if (Op0.getValueType() != MVT::v8f16 || (Offset != 0 && Offset != 1)) + return SDValue(); + + // Check all the values in the BuildVector line up with our expectations. + for (unsigned i = 1; i < 4; i++) { + auto Check = [](SDValue Trunc, SDValue Op, unsigned Idx) { + return Trunc.getOpcode() == ISD::FP_EXTEND && + Trunc.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT && + Trunc.getOperand(0).getOperand(0) == Op && + Trunc.getOperand(0).getConstantOperandVal(1) == Idx; + }; + if (!Check(BV.getOperand(i), Op0, 2 * i + Offset)) + return SDValue(); + } + + return DAG.getNode(ARMISD::VCVTL, dl, VT, Op0, + DAG.getConstant(Offset, dl, MVT::i32)); +} + // If N is an integer constant that can be moved into a register in one // instruction, return an SDValue of such a constant (will become a MOV // instruction). Otherwise return null.
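As an illustration of the lane patterns the two matchers above look for, here is a standalone sketch on plain arrays. It is not part of the patch; float stands in for the f16 half-lanes and the actual rounding/extension is ignored.

// Sketch only: models the lane layout matched by LowerBuildVectorOfFPTrunc
// (interleave two narrowed sources) and LowerBuildVectorOfFPExt (take every
// second lane of one wider-typed source).
#include <array>
#include <cassert>

// VCVTN-style narrowing: X feeds the even ("bottom") lanes, Y the odd ("top")
// lanes, i.e. BUILD_VECTOR(X[0], Y[0], X[1], Y[1], ...).
static std::array<float, 8> narrowPair(const std::array<float, 4> &X,
                                       const std::array<float, 4> &Y) {
  std::array<float, 8> Out{};
  for (unsigned I = 0; I < 4; ++I) {
    Out[2 * I] = X[I];
    Out[2 * I + 1] = Y[I];
  }
  return Out;
}

// VCVTL-style widening: take alternating lanes of the source, starting at
// Offset 0 (bottom) or 1 (top), i.e. BUILD_VECTOR(X[Offset], X[Offset+2], ...).
static std::array<float, 4> widenAlternate(const std::array<float, 8> &X,
                                           unsigned Offset) {
  assert(Offset == 0 || Offset == 1);
  std::array<float, 4> Out{};
  for (unsigned I = 0; I < 4; ++I)
    Out[I] = X[2 * I + Offset];
  return Out;
}

int main() {
  std::array<float, 4> X{0, 1, 2, 3}, Y{10, 11, 12, 13};
  auto N = narrowPair(X, Y);     // 0 10 1 11 2 12 3 13
  auto B = widenAlternate(N, 0); // recovers X (bottom lanes)
  auto T = widenAlternate(N, 1); // recovers Y (top lanes)
  assert(B == X && T == Y);
  return 0;
}

The two directions are inverses of each other, which is why a pair of VCVTN inserts followed by a VCVTL with the matching offset round-trips the lanes.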
@@ -7094,7 +7709,10 @@ static SDValue LowerBUILD_VECTOR_i1(SDValue Op, SelectionDAG &DAG, unsigned NumElts = VT.getVectorNumElements(); unsigned BoolMask; unsigned BitsPerBool; - if (NumElts == 4) { + if (NumElts == 2) { + BitsPerBool = 8; + BoolMask = 0xff; + } else if (NumElts == 4) { BitsPerBool = 4; BoolMask = 0xf; } else if (NumElts == 8) { @@ -7110,10 +7728,9 @@ static SDValue LowerBUILD_VECTOR_i1(SDValue Op, SelectionDAG &DAG, // extend that single value SDValue FirstOp = Op.getOperand(0); if (!isa<ConstantSDNode>(FirstOp) && - std::all_of(std::next(Op->op_begin()), Op->op_end(), - [&FirstOp](SDUse &U) { - return U.get().isUndef() || U.get() == FirstOp; - })) { + llvm::all_of(llvm::drop_begin(Op->ops()), [&FirstOp](const SDUse &U) { + return U.get().isUndef() || U.get() == FirstOp; + })) { SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, MVT::i32, FirstOp, DAG.getValueType(MVT::i1)); return DAG.getNode(ARMISD::PREDICATE_CAST, dl, Op.getValueType(), Ext); @@ -7144,6 +7761,79 @@ static SDValue LowerBUILD_VECTOR_i1(SDValue Op, SelectionDAG &DAG, return Base; } +static SDValue LowerBUILD_VECTORToVIDUP(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *ST) { + if (!ST->hasMVEIntegerOps()) + return SDValue(); + + // We are looking for a buildvector where each element is Op[0] + i*N + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + unsigned NumElts = VT.getVectorNumElements(); + + // Get the increment value from operand 1 + SDValue Op1 = Op.getOperand(1); + if (Op1.getOpcode() != ISD::ADD || Op1.getOperand(0) != Op0 || + !isa<ConstantSDNode>(Op1.getOperand(1))) + return SDValue(); + unsigned N = Op1.getConstantOperandVal(1); + if (N != 1 && N != 2 && N != 4 && N != 8) + return SDValue(); + + // Check that each other operand matches + for (unsigned I = 2; I < NumElts; I++) { + SDValue OpI = Op.getOperand(I); + if (OpI.getOpcode() != ISD::ADD || OpI.getOperand(0) != Op0 || + !isa<ConstantSDNode>(OpI.getOperand(1)) || + OpI.getConstantOperandVal(1) != I * N) + return SDValue(); + } + + SDLoc DL(Op); + return DAG.getNode(ARMISD::VIDUP, DL, DAG.getVTList(VT, MVT::i32), Op0, + DAG.getConstant(N, DL, MVT::i32)); +} + +// Returns true if the operation N can be treated as qr instruction variant at +// operand Op. +static bool IsQRMVEInstruction(const SDNode *N, const SDNode *Op) { + switch (N->getOpcode()) { + case ISD::ADD: + case ISD::MUL: + case ISD::SADDSAT: + case ISD::UADDSAT: + return true; + case ISD::SUB: + case ISD::SSUBSAT: + case ISD::USUBSAT: + return N->getOperand(1).getNode() == Op; + case ISD::INTRINSIC_WO_CHAIN: + switch (N->getConstantOperandVal(0)) { + case Intrinsic::arm_mve_add_predicated: + case Intrinsic::arm_mve_mul_predicated: + case Intrinsic::arm_mve_qadd_predicated: + case Intrinsic::arm_mve_vhadd: + case Intrinsic::arm_mve_hadd_predicated: + case Intrinsic::arm_mve_vqdmulh: + case Intrinsic::arm_mve_qdmulh_predicated: + case Intrinsic::arm_mve_vqrdmulh: + case Intrinsic::arm_mve_qrdmulh_predicated: + case Intrinsic::arm_mve_vqdmull: + case Intrinsic::arm_mve_vqdmull_predicated: + return true; + case Intrinsic::arm_mve_sub_predicated: + case Intrinsic::arm_mve_qsub_predicated: + case Intrinsic::arm_mve_vhsub: + case Intrinsic::arm_mve_hsub_predicated: + return N->getOperand(2).getNode() == Op; + default: + return false; + } + default: + return false; + } +} + // If this is a case we can't handle, return null and let the default // expansion code take care of it. 
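For reference, the arithmetic progression that LowerBUILD_VECTORToVIDUP above matches can be sketched on plain integers as follows. This is illustrative only; looksLikeVIDUP is a made-up helper, not part of the patch, and the real lowering additionally requires each lane to literally be an ADD of lane 0 and a constant.

// Sketch only: a buildvector qualifies for VIDUP when lane I equals
// Base + I * Step with Step one of 1, 2, 4 or 8.
#include <cassert>
#include <cstdint>
#include <vector>

static bool looksLikeVIDUP(const std::vector<uint32_t> &Lanes, unsigned &Step) {
  if (Lanes.size() < 2)
    return false;
  uint32_t Base = Lanes[0];
  Step = Lanes[1] - Base;
  if (Step != 1 && Step != 2 && Step != 4 && Step != 8)
    return false;
  for (unsigned I = 2; I < Lanes.size(); ++I)
    if (Lanes[I] != Base + I * Step)
      return false;
  return true;
}

int main() {
  unsigned Step = 0;
  assert(looksLikeVIDUP({7, 9, 11, 13}, Step) && Step == 2); // VIDUP base 7, inc 2
  assert(!looksLikeVIDUP({7, 10, 13, 16}, Step));            // increment 3 unsupported
  return 0;
}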
SDValue ARMTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, @@ -7155,21 +7845,37 @@ SDValue ARMTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, if (ST->hasMVEIntegerOps() && VT.getScalarSizeInBits() == 1) return LowerBUILD_VECTOR_i1(Op, DAG, ST); + if (SDValue R = LowerBUILD_VECTORToVIDUP(Op, DAG, ST)) + return R; + APInt SplatBits, SplatUndef; unsigned SplatBitSize; bool HasAnyUndefs; if (BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize, HasAnyUndefs)) { - if (SplatUndef.isAllOnesValue()) + if (SplatUndef.isAllOnes()) return DAG.getUNDEF(VT); + // If all the users of this constant splat are qr instruction variants, + // generate a vdup of the constant. + if (ST->hasMVEIntegerOps() && VT.getScalarSizeInBits() == SplatBitSize && + (SplatBitSize == 8 || SplatBitSize == 16 || SplatBitSize == 32) && + all_of(BVN->uses(), + [BVN](const SDNode *U) { return IsQRMVEInstruction(U, BVN); })) { + EVT DupVT = SplatBitSize == 32 ? MVT::v4i32 + : SplatBitSize == 16 ? MVT::v8i16 + : MVT::v16i8; + SDValue Const = DAG.getConstant(SplatBits.getZExtValue(), dl, MVT::i32); + SDValue VDup = DAG.getNode(ARMISD::VDUP, dl, DupVT, Const); + return DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, VT, VDup); + } + if ((ST->hasNEON() && SplatBitSize <= 64) || - (ST->hasMVEIntegerOps() && SplatBitSize <= 32)) { + (ST->hasMVEIntegerOps() && SplatBitSize <= 64)) { // Check if an immediate VMOV works. EVT VmovVT; - SDValue Val = isVMOVModifiedImm(SplatBits.getZExtValue(), - SplatUndef.getZExtValue(), SplatBitSize, - DAG, dl, VmovVT, VT.is128BitVector(), - VMOVModImm); + SDValue Val = + isVMOVModifiedImm(SplatBits.getZExtValue(), SplatUndef.getZExtValue(), + SplatBitSize, DAG, dl, VmovVT, VT, VMOVModImm); if (Val.getNode()) { SDValue Vmov = DAG.getNode(ARMISD::VMOVIMM, dl, VmovVT, Val); @@ -7179,9 +7885,8 @@ SDValue ARMTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, // Try an immediate VMVN. uint64_t NegatedImm = (~SplatBits).getZExtValue(); Val = isVMOVModifiedImm( - NegatedImm, SplatUndef.getZExtValue(), SplatBitSize, - DAG, dl, VmovVT, VT.is128BitVector(), - ST->hasMVEIntegerOps() ? MVEVMVNModImm : VMVNModImm); + NegatedImm, SplatUndef.getZExtValue(), SplatBitSize, DAG, dl, VmovVT, + VT, ST->hasMVEIntegerOps() ? MVEVMVNModImm : VMVNModImm); if (Val.getNode()) { SDValue Vmov = DAG.getNode(ARMISD::VMVNIMM, dl, VmovVT, Val); return DAG.getNode(ISD::BITCAST, dl, VT, Vmov); @@ -7195,6 +7900,18 @@ SDValue ARMTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ARMISD::VMOVFPIMM, dl, VT, Val); } } + + // If we are under MVE, generate a VDUP(constant), bitcast to the original + // type. + if (ST->hasMVEIntegerOps() && + (SplatBitSize == 8 || SplatBitSize == 16 || SplatBitSize == 32)) { + EVT DupVT = SplatBitSize == 32 ? MVT::v4i32 + : SplatBitSize == 16 ? MVT::v8i16 + : MVT::v16i8; + SDValue Const = DAG.getConstant(SplatBits.getZExtValue(), dl, MVT::i32); + SDValue VDup = DAG.getNode(ARMISD::VDUP, dl, DupVT, Const); + return DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, VT, VDup); + } } } @@ -7321,12 +8038,19 @@ SDValue ARMTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, if (isConstant) return SDValue(); - // Empirical tests suggest this is rarely worth it for vectors of length <= 2. - if (NumElts >= 4) { - SDValue shuffle = ReconstructShuffle(Op, DAG); - if (shuffle != SDValue()) + // Reconstruct the BUILDVECTOR to one of the legal shuffles (such as vext and + // vmovn). 
Empirical tests suggest this is rarely worth it for vectors of + // length <= 2. + if (NumElts >= 4) + if (SDValue shuffle = ReconstructShuffle(Op, DAG)) return shuffle; - } + + // Attempt to turn a buildvector of scalar fptrunc's or fpext's back into + // VCVT's + if (SDValue VCVT = LowerBuildVectorOfFPTrunc(Op, DAG, Subtarget)) + return VCVT; + if (SDValue VCVT = LowerBuildVectorOfFPExt(Op, DAG, Subtarget)) + return VCVT; if (ST->hasNEON() && VT.is128BitVector() && VT != MVT::v2f64 && VT != MVT::v4f32) { // If we haven't found an efficient lowering, try splitting a 128-bit vector @@ -7334,12 +8058,11 @@ SDValue ARMTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElts); EVT ExtVT = VT.getVectorElementType(); EVT HVT = EVT::getVectorVT(*DAG.getContext(), ExtVT, NumElts / 2); - SDValue Lower = - DAG.getBuildVector(HVT, dl, makeArrayRef(&Ops[0], NumElts / 2)); + SDValue Lower = DAG.getBuildVector(HVT, dl, ArrayRef(&Ops[0], NumElts / 2)); if (Lower.getOpcode() == ISD::BUILD_VECTOR) Lower = LowerBUILD_VECTOR(Lower, DAG, ST); - SDValue Upper = DAG.getBuildVector( - HVT, dl, makeArrayRef(&Ops[NumElts / 2], NumElts / 2)); + SDValue Upper = + DAG.getBuildVector(HVT, dl, ArrayRef(&Ops[NumElts / 2], NumElts / 2)); if (Upper.getOpcode() == ISD::BUILD_VECTOR) Upper = LowerBUILD_VECTOR(Upper, DAG, ST); if (Lower && Upper) @@ -7464,17 +8187,19 @@ SDValue ARMTargetLowering::ReconstructShuffle(SDValue Op, for (auto &Src : Sources) { EVT SrcVT = Src.ShuffleVec.getValueType(); - if (SrcVT.getSizeInBits() == VT.getSizeInBits()) + uint64_t SrcVTSize = SrcVT.getFixedSizeInBits(); + uint64_t VTSize = VT.getFixedSizeInBits(); + if (SrcVTSize == VTSize) continue; // This stage of the search produces a source with the same element type as // the original, but with a total width matching the BUILD_VECTOR output. EVT EltVT = SrcVT.getVectorElementType(); - unsigned NumSrcElts = VT.getSizeInBits() / EltVT.getSizeInBits(); + unsigned NumSrcElts = VTSize / EltVT.getFixedSizeInBits(); EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NumSrcElts); - if (SrcVT.getSizeInBits() < VT.getSizeInBits()) { - if (2 * SrcVT.getSizeInBits() != VT.getSizeInBits()) + if (SrcVTSize < VTSize) { + if (2 * SrcVTSize != VTSize) return SDValue(); // We can pad out the smaller vector for free, so if it's part of a // shuffle... @@ -7484,7 +8209,7 @@ SDValue ARMTargetLowering::ReconstructShuffle(SDValue Op, continue; } - if (SrcVT.getSizeInBits() != 2 * VT.getSizeInBits()) + if (SrcVTSize != 2 * VTSize) return SDValue(); if (Src.MaxElt - Src.MinElt >= NumSrcElts) { @@ -7527,12 +8252,12 @@ SDValue ARMTargetLowering::ReconstructShuffle(SDValue Op, if (SrcEltTy == SmallestEltTy) continue; assert(ShuffleVT.getVectorElementType() == SmallestEltTy); - Src.ShuffleVec = DAG.getNode(ISD::BITCAST, dl, ShuffleVT, Src.ShuffleVec); + Src.ShuffleVec = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, ShuffleVT, Src.ShuffleVec); Src.WindowScale = SrcEltTy.getSizeInBits() / SmallestEltTy.getSizeInBits(); Src.WindowBase *= Src.WindowScale; } - // Final sanity check before we try to actually produce a shuffle. + // Final check before we try to actually produce a shuffle. LLVM_DEBUG(for (auto Src : Sources) assert(Src.ShuffleVec.getValueType() == ShuffleVT);); @@ -7552,7 +8277,7 @@ SDValue ARMTargetLowering::ReconstructShuffle(SDValue Op, // trunc. So only std::min(SrcBits, DestBits) actually get defined in this // segment. 
EVT OrigEltTy = Entry.getOperand(0).getValueType().getVectorElementType(); - int BitsDefined = std::min(OrigEltTy.getSizeInBits(), + int BitsDefined = std::min(OrigEltTy.getScalarSizeInBits(), VT.getScalarSizeInBits()); int LanesDefined = BitsDefined / BitsPerShuffleLane; @@ -7579,7 +8304,7 @@ SDValue ARMTargetLowering::ReconstructShuffle(SDValue Op, ShuffleOps[1], Mask, DAG); if (!Shuffle) return SDValue(); - return DAG.getNode(ISD::BITCAST, dl, VT, Shuffle); + return DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, VT, Shuffle); } enum ShuffleOpCodes { @@ -7655,11 +8380,17 @@ bool ARMTargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { isVTBLMask(M, VT) || isNEONTwoResultShuffleMask(M, VT, WhichResult, isV_UNDEF))) return true; - else if (Subtarget->hasNEON() && (VT == MVT::v8i16 || VT == MVT::v16i8) && + else if ((VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i8) && isReverseMask(M, VT)) return true; else if (Subtarget->hasMVEIntegerOps() && - (isVMOVNMask(M, VT, 0) || isVMOVNMask(M, VT, 1))) + (isVMOVNMask(M, VT, true, false) || + isVMOVNMask(M, VT, false, false) || isVMOVNMask(M, VT, true, true))) + return true; + else if (Subtarget->hasMVEIntegerOps() && + (isTruncMask(M, VT, false, false) || + isTruncMask(M, VT, false, true) || + isTruncMask(M, VT, true, false) || isTruncMask(M, VT, true, true))) return true; else return false; @@ -7689,14 +8420,13 @@ static SDValue GeneratePerfectShuffle(unsigned PFEntry, SDValue LHS, default: llvm_unreachable("Unknown shuffle opcode!"); case OP_VREV: // VREV divides the vector in half and swaps within the half. - if (VT.getVectorElementType() == MVT::i32 || - VT.getVectorElementType() == MVT::f32) + if (VT.getScalarSizeInBits() == 32) return DAG.getNode(ARMISD::VREV64, dl, VT, OpLHS); // vrev <4 x i16> -> VREV32 - if (VT.getVectorElementType() == MVT::i16) + if (VT.getScalarSizeInBits() == 16) return DAG.getNode(ARMISD::VREV32, dl, VT, OpLHS); // vrev <4 x i8> -> VREV16 - assert(VT.getVectorElementType() == MVT::i8); + assert(VT.getScalarSizeInBits() == 8); return DAG.getNode(ARMISD::VREV16, dl, VT, OpLHS); case OP_VDUP0: case OP_VDUP1: @@ -7734,9 +8464,8 @@ static SDValue LowerVECTOR_SHUFFLEv8i8(SDValue Op, SDLoc DL(Op); SmallVector<SDValue, 8> VTBLMask; - for (ArrayRef<int>::iterator - I = ShuffleMask.begin(), E = ShuffleMask.end(); I != E; ++I) - VTBLMask.push_back(DAG.getConstant(*I, DL, MVT::i32)); + for (int I : ShuffleMask) + VTBLMask.push_back(DAG.getConstant(I, DL, MVT::i32)); if (V2.getNode()->isUndef()) return DAG.getNode(ARMISD::VTBL1, DL, MVT::v8i8, V1, @@ -7746,25 +8475,29 @@ static SDValue LowerVECTOR_SHUFFLEv8i8(SDValue Op, DAG.getBuildVector(MVT::v8i8, DL, VTBLMask)); } -static SDValue LowerReverse_VECTOR_SHUFFLEv16i8_v8i16(SDValue Op, - SelectionDAG &DAG) { +static SDValue LowerReverse_VECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) { SDLoc DL(Op); - SDValue OpLHS = Op.getOperand(0); - EVT VT = OpLHS.getValueType(); + EVT VT = Op.getValueType(); - assert((VT == MVT::v8i16 || VT == MVT::v16i8) && + assert((VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i8) && "Expect an v8i16/v16i8 type"); - OpLHS = DAG.getNode(ARMISD::VREV64, DL, VT, OpLHS); - // For a v16i8 type: After the VREV, we have got <8, ...15, 8, ..., 0>. Now, + SDValue OpLHS = DAG.getNode(ARMISD::VREV64, DL, VT, Op.getOperand(0)); + // For a v16i8 type: After the VREV, we have got <7, ..., 0, 15, ..., 8>. Now, // extract the first 8 bytes into the top double word and the last 8 bytes - // into the bottom double word. The v8i16 case is similar. 
- unsigned ExtractNum = (VT == MVT::v16i8) ? 8 : 4; - return DAG.getNode(ARMISD::VEXT, DL, VT, OpLHS, OpLHS, - DAG.getConstant(ExtractNum, DL, MVT::i32)); + // into the bottom double word, through a new vector shuffle that will be + // turned into a VEXT on Neon, or a couple of VMOVDs on MVE. + std::vector<int> NewMask; + for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) + NewMask.push_back(VT.getVectorNumElements() / 2 + i); + for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) + NewMask.push_back(i); + return DAG.getVectorShuffle(VT, DL, OpLHS, OpLHS, NewMask); } static EVT getVectorTyFromPredicateVector(EVT VT) { switch (VT.getSimpleVT().SimpleTy) { + case MVT::v2i1: + return MVT::v2f64; case MVT::v4i1: return MVT::v4i32; case MVT::v8i1: @@ -7821,6 +8554,7 @@ static SDValue LowerVECTOR_SHUFFLE_i1(SDValue Op, SelectionDAG &DAG, "No support for vector shuffle of boolean predicates"); SDValue V1 = Op.getOperand(0); + SDValue V2 = Op.getOperand(1); SDLoc dl(Op); if (isReverseMask(ShuffleMask, VT)) { SDValue cast = DAG.getNode(ARMISD::PREDICATE_CAST, dl, MVT::i32, V1); @@ -7838,15 +8572,26 @@ static SDValue LowerVECTOR_SHUFFLE_i1(SDValue Op, SelectionDAG &DAG, // many cases the generated code might be even better than scalar code // operating on bits. Just imagine trying to shuffle 8 arbitrary 2-bit // fields in a register into 8 other arbitrary 2-bit fields! - SDValue PredAsVector = PromoteMVEPredVector(dl, V1, VT, DAG); - EVT NewVT = PredAsVector.getValueType(); + SDValue PredAsVector1 = PromoteMVEPredVector(dl, V1, VT, DAG); + EVT NewVT = PredAsVector1.getValueType(); + SDValue PredAsVector2 = V2.isUndef() ? DAG.getUNDEF(NewVT) + : PromoteMVEPredVector(dl, V2, VT, DAG); + assert(PredAsVector2.getValueType() == NewVT && + "Expected identical vector type in expanded i1 shuffle!"); // Do the shuffle! - SDValue Shuffled = DAG.getVectorShuffle(NewVT, dl, PredAsVector, - DAG.getUNDEF(NewVT), ShuffleMask); + SDValue Shuffled = DAG.getVectorShuffle(NewVT, dl, PredAsVector1, + PredAsVector2, ShuffleMask); // Now return the result of comparing the shuffled vector with zero, - // which will generate a real predicate, i.e. v4i1, v8i1 or v16i1. + // which will generate a real predicate, i.e. v4i1, v8i1 or v16i1. For a v2i1 + // we convert to a v4i1 compare to fill in the two halves of the i64 as i32s. + if (VT == MVT::v2i1) { + SDValue BC = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, MVT::v4i32, Shuffled); + SDValue Cmp = DAG.getNode(ARMISD::VCMPZ, dl, MVT::v4i1, BC, + DAG.getConstant(ARMCC::NE, dl, MVT::i32)); + return DAG.getNode(ARMISD::PREDICATE_CAST, dl, MVT::v2i1, Cmp); + } return DAG.getNode(ARMISD::VCMPZ, dl, VT, Shuffled, DAG.getConstant(ARMCC::NE, dl, MVT::i32)); } @@ -7904,8 +8649,8 @@ static SDValue LowerVECTOR_SHUFFLEUsingMovs(SDValue Op, Input = Op->getOperand(1); Elt -= 4; } - SDValue BitCast = DAG.getBitcast(MVT::v4i32, Input); - Parts[Part] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, BitCast, + SDValue BitCast = DAG.getBitcast(MVT::v4f32, Input); + Parts[Part] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f32, BitCast, DAG.getConstant(Elt, dl, MVT::i32)); } } @@ -7924,19 +8669,70 @@ static SDValue LowerVECTOR_SHUFFLEUsingMovs(SDValue Op, Parts[Part] ? 
-1 : ShuffleMask[Part * QuarterSize + i]); SDValue NewShuffle = DAG.getVectorShuffle( VT, dl, Op->getOperand(0), Op->getOperand(1), NewShuffleMask); - SDValue BitCast = DAG.getBitcast(MVT::v4i32, NewShuffle); + SDValue BitCast = DAG.getBitcast(MVT::v4f32, NewShuffle); for (int Part = 0; Part < 4; ++Part) if (!Parts[Part]) - Parts[Part] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, + Parts[Part] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f32, BitCast, DAG.getConstant(Part, dl, MVT::i32)); } // Build a vector out of the various parts and bitcast it back to the original // type. - SDValue NewVec = DAG.getBuildVector(MVT::v4i32, dl, Parts); + SDValue NewVec = DAG.getNode(ARMISD::BUILD_VECTOR, dl, MVT::v4f32, Parts); return DAG.getBitcast(VT, NewVec); } +static SDValue LowerVECTOR_SHUFFLEUsingOneOff(SDValue Op, + ArrayRef<int> ShuffleMask, + SelectionDAG &DAG) { + SDValue V1 = Op.getOperand(0); + SDValue V2 = Op.getOperand(1); + EVT VT = Op.getValueType(); + unsigned NumElts = VT.getVectorNumElements(); + + // A One-Off Identity mask is one that is mostly an identity mask from a + // single source but contains a single element out-of-place, either from a + // different vector or from another position in the same vector. As opposed to + // lowering this via an ARMISD::BUILD_VECTOR we can generate an extract/insert + // pair directly. + auto isOneOffIdentityMask = [](ArrayRef<int> Mask, EVT VT, int BaseOffset, + int &OffElement) { + OffElement = -1; + int NonUndef = 0; + for (int i = 0, NumMaskElts = Mask.size(); i < NumMaskElts; ++i) { + if (Mask[i] == -1) + continue; + NonUndef++; + if (Mask[i] != i + BaseOffset) { + if (OffElement == -1) + OffElement = i; + else + return false; + } + } + return NonUndef > 2 && OffElement != -1; + }; + int OffElement; + SDValue VInput; + if (isOneOffIdentityMask(ShuffleMask, VT, 0, OffElement)) + VInput = V1; + else if (isOneOffIdentityMask(ShuffleMask, VT, NumElts, OffElement)) + VInput = V2; + else + return SDValue(); + + SDLoc dl(Op); + EVT SVT = VT.getScalarType() == MVT::i8 || VT.getScalarType() == MVT::i16 + ? MVT::i32 + : VT.getScalarType(); + SDValue Elt = DAG.getNode( + ISD::EXTRACT_VECTOR_ELT, dl, SVT, + ShuffleMask[OffElement] < (int)NumElts ?
V1 : V2, + DAG.getVectorIdxConstant(ShuffleMask[OffElement] % NumElts, dl)); + return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, VInput, Elt, + DAG.getVectorIdxConstant(OffElement % NumElts, dl)); +} + static SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, const ARMSubtarget *ST) { SDValue V1 = Op.getOperand(0); @@ -8023,12 +8819,15 @@ static SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, } } if (ST->hasMVEIntegerOps()) { - if (isVMOVNMask(ShuffleMask, VT, 0)) + if (isVMOVNMask(ShuffleMask, VT, false, false)) return DAG.getNode(ARMISD::VMOVN, dl, VT, V2, V1, DAG.getConstant(0, dl, MVT::i32)); - if (isVMOVNMask(ShuffleMask, VT, 1)) + if (isVMOVNMask(ShuffleMask, VT, true, false)) return DAG.getNode(ARMISD::VMOVN, dl, VT, V1, V2, DAG.getConstant(1, dl, MVT::i32)); + if (isVMOVNMask(ShuffleMask, VT, true, true)) + return DAG.getNode(ARMISD::VMOVN, dl, VT, V1, V1, + DAG.getConstant(1, dl, MVT::i32)); } // Also check for these shuffles through CONCAT_VECTORS: we canonicalize @@ -8070,6 +8869,29 @@ static SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, } } + if (ST->hasMVEIntegerOps() && EltSize <= 32) { + if (SDValue V = LowerVECTOR_SHUFFLEUsingOneOff(Op, ShuffleMask, DAG)) + return V; + + for (bool Top : {false, true}) { + for (bool SingleSource : {false, true}) { + if (isTruncMask(ShuffleMask, VT, Top, SingleSource)) { + MVT FromSVT = MVT::getIntegerVT(EltSize * 2); + MVT FromVT = MVT::getVectorVT(FromSVT, ShuffleMask.size() / 2); + SDValue Lo = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, FromVT, V1); + SDValue Hi = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, FromVT, + SingleSource ? V1 : V2); + if (Top) { + SDValue Amt = DAG.getConstant(EltSize, dl, FromVT); + Lo = DAG.getNode(ISD::SRL, dl, FromVT, Lo, Amt); + Hi = DAG.getNode(ISD::SRL, dl, FromVT, Hi, Amt); + } + return DAG.getNode(ARMISD::MVETRUNC, dl, VT, Lo, Hi); + } + } + } + } + // If the shuffle is not directly supported and it has 4 elements, use // the PerfectShuffle-generated table to synthesize it from other shuffles. 
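Going back to the one-off identity path handled above: the mask test and the rewrite it enables can be sketched on plain arrays like this. This is a standalone illustration, not the SelectionDAG code.

// Sketch only: a mask such as <0, 1, 6, 3> over two 4-element vectors is an
// identity of V1 except for lane 2, so it can be lowered as
// insert(V1, extract(V2, 2), 2) instead of a full shuffle.
#include <cassert>
#include <vector>

static bool isOneOffIdentity(const std::vector<int> &Mask, int BaseOffset,
                             int &OffElement) {
  OffElement = -1;
  int NonUndef = 0;
  for (int I = 0, E = (int)Mask.size(); I != E; ++I) {
    if (Mask[I] == -1) // undef lane, ignored
      continue;
    ++NonUndef;
    if (Mask[I] != I + BaseOffset) {
      if (OffElement != -1) // a second out-of-place lane disqualifies the mask
        return false;
      OffElement = I;
    }
  }
  return NonUndef > 2 && OffElement != -1;
}

int main() {
  int Off = -1;
  assert(isOneOffIdentity({0, 1, 6, 3}, /*BaseOffset=*/0, Off) && Off == 2);
  assert(!isOneOffIdentity({0, 5, 6, 3}, /*BaseOffset=*/0, Off)); // two lanes moved
  return 0;
}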
unsigned NumElts = VT.getVectorNumElements(); @@ -8124,8 +8946,9 @@ static SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::BITCAST, dl, VT, Val); } - if (ST->hasNEON() && (VT == MVT::v8i16 || VT == MVT::v16i8) && isReverseMask(ShuffleMask, VT)) - return LowerReverse_VECTOR_SHUFFLEv16i8_v8i16(Op, DAG); + if ((VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i8) && + isReverseMask(ShuffleMask, VT)) + return LowerReverse_VECTOR_SHUFFLE(Op, DAG); if (ST->hasNEON() && VT == MVT::v8i8) if (SDValue NewOp = LowerVECTOR_SHUFFLEv8i8(Op, ShuffleMask, DAG)) @@ -8242,54 +9065,75 @@ static SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG, static SDValue LowerCONCAT_VECTORS_i1(SDValue Op, SelectionDAG &DAG, const ARMSubtarget *ST) { - SDValue V1 = Op.getOperand(0); - SDValue V2 = Op.getOperand(1); SDLoc dl(Op); - EVT VT = Op.getValueType(); - EVT Op1VT = V1.getValueType(); - EVT Op2VT = V2.getValueType(); - unsigned NumElts = VT.getVectorNumElements(); - - assert(Op1VT == Op2VT && "Operand types don't match!"); - assert(VT.getScalarSizeInBits() == 1 && + assert(Op.getValueType().getScalarSizeInBits() == 1 && + "Unexpected custom CONCAT_VECTORS lowering"); + assert(isPowerOf2_32(Op.getNumOperands()) && "Unexpected custom CONCAT_VECTORS lowering"); assert(ST->hasMVEIntegerOps() && "CONCAT_VECTORS lowering only supported for MVE"); - SDValue NewV1 = PromoteMVEPredVector(dl, V1, Op1VT, DAG); - SDValue NewV2 = PromoteMVEPredVector(dl, V2, Op2VT, DAG); - - // We now have Op1 + Op2 promoted to vectors of integers, where v8i1 gets - // promoted to v8i16, etc. - - MVT ElType = getVectorTyFromPredicateVector(VT).getScalarType().getSimpleVT(); - - // Extract the vector elements from Op1 and Op2 one by one and truncate them - // to be the right size for the destination. For example, if Op1 is v4i1 then - // the promoted vector is v4i32. The result of concatentation gives a v8i1, - // which when promoted is v8i16. That means each i32 element from Op1 needs - // truncating to i16 and inserting in the result. - EVT ConcatVT = MVT::getVectorVT(ElType, NumElts); - SDValue ConVec = DAG.getNode(ISD::UNDEF, dl, ConcatVT); - auto ExractInto = [&DAG, &dl](SDValue NewV, SDValue ConVec, unsigned &j) { - EVT NewVT = NewV.getValueType(); - EVT ConcatVT = ConVec.getValueType(); - for (unsigned i = 0, e = NewVT.getVectorNumElements(); i < e; i++, j++) { - SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, NewV, - DAG.getIntPtrConstant(i, dl)); - ConVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ConcatVT, ConVec, Elt, - DAG.getConstant(j, dl, MVT::i32)); + auto ConcatPair = [&](SDValue V1, SDValue V2) { + EVT Op1VT = V1.getValueType(); + EVT Op2VT = V2.getValueType(); + assert(Op1VT == Op2VT && "Operand types don't match!"); + EVT VT = Op1VT.getDoubleNumVectorElementsVT(*DAG.getContext()); + + SDValue NewV1 = PromoteMVEPredVector(dl, V1, Op1VT, DAG); + SDValue NewV2 = PromoteMVEPredVector(dl, V2, Op2VT, DAG); + + // We now have Op1 + Op2 promoted to vectors of integers, where v8i1 gets + // promoted to v8i16, etc. + MVT ElType = + getVectorTyFromPredicateVector(VT).getScalarType().getSimpleVT(); + unsigned NumElts = 2 * Op1VT.getVectorNumElements(); + + // Extract the vector elements from Op1 and Op2 one by one and truncate them + // to be the right size for the destination. For example, if Op1 is v4i1 + // then the promoted vector is v4i32. The result of concatenation gives a + // v8i1, which when promoted is v8i16. 
That means each i32 element from Op1 + // needs truncating to i16 and inserting in the result. + EVT ConcatVT = MVT::getVectorVT(ElType, NumElts); + SDValue ConVec = DAG.getNode(ISD::UNDEF, dl, ConcatVT); + auto ExtractInto = [&DAG, &dl](SDValue NewV, SDValue ConVec, unsigned &j) { + EVT NewVT = NewV.getValueType(); + EVT ConcatVT = ConVec.getValueType(); + for (unsigned i = 0, e = NewVT.getVectorNumElements(); i < e; i++, j++) { + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, NewV, + DAG.getIntPtrConstant(i, dl)); + ConVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ConcatVT, ConVec, Elt, + DAG.getConstant(j, dl, MVT::i32)); + } + return ConVec; + }; + unsigned j = 0; + ConVec = ExtractInto(NewV1, ConVec, j); + ConVec = ExtractInto(NewV2, ConVec, j); + + // Now return the result of comparing the subvector with zero, which will + // generate a real predicate, i.e. v4i1, v8i1 or v16i1. For a v2i1 we + // convert to a v4i1 compare to fill in the two halves of the i64 as i32s. + if (VT == MVT::v2i1) { + SDValue BC = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, MVT::v4i32, ConVec); + SDValue Cmp = DAG.getNode(ARMISD::VCMPZ, dl, MVT::v4i1, BC, + DAG.getConstant(ARMCC::NE, dl, MVT::i32)); + return DAG.getNode(ARMISD::PREDICATE_CAST, dl, MVT::v2i1, Cmp); } - return ConVec; + return DAG.getNode(ARMISD::VCMPZ, dl, VT, ConVec, + DAG.getConstant(ARMCC::NE, dl, MVT::i32)); }; - unsigned j = 0; - ConVec = ExractInto(NewV1, ConVec, j); - ConVec = ExractInto(NewV2, ConVec, j); - // Now return the result of comparing the subvector with zero, - // which will generate a real predicate, i.e. v4i1, v8i1 or v16i1. - return DAG.getNode(ARMISD::VCMPZ, dl, VT, ConVec, - DAG.getConstant(ARMCC::NE, dl, MVT::i32)); + // Concat each pair of subvectors and pack into the lower half of the array. + SmallVector<SDValue> ConcatOps(Op->op_begin(), Op->op_end()); + while (ConcatOps.size() > 1) { + for (unsigned I = 0, E = ConcatOps.size(); I != E; I += 2) { + SDValue V1 = ConcatOps[I]; + SDValue V2 = ConcatOps[I + 1]; + ConcatOps[I / 2] = ConcatPair(V1, V2); + } + ConcatOps.resize(ConcatOps.size() / 2); + } + return ConcatOps[0]; } static SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG, @@ -8339,6 +9183,22 @@ static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG, MVT ElType = getVectorTyFromPredicateVector(VT).getScalarType().getSimpleVT(); + if (NumElts == 2) { + EVT SubVT = MVT::v4i32; + SDValue SubVec = DAG.getNode(ISD::UNDEF, dl, SubVT); + for (unsigned i = Index, j = 0; i < (Index + NumElts); i++, j += 2) { + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, NewV1, + DAG.getIntPtrConstant(i, dl)); + SubVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, SubVT, SubVec, Elt, + DAG.getConstant(j, dl, MVT::i32)); + SubVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, SubVT, SubVec, Elt, + DAG.getConstant(j + 1, dl, MVT::i32)); + } + SDValue Cmp = DAG.getNode(ARMISD::VCMPZ, dl, MVT::v4i1, SubVec, + DAG.getConstant(ARMCC::NE, dl, MVT::i32)); + return DAG.getNode(ARMISD::PREDICATE_CAST, dl, MVT::v2i1, Cmp); + } + EVT SubVT = MVT::getVectorVT(ElType, NumElts); SDValue SubVec = DAG.getNode(ISD::UNDEF, dl, SubVT); for (unsigned i = Index, j = 0; i < (Index + NumElts); i++, j++) { @@ -8354,6 +9214,116 @@ static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG, DAG.getConstant(ARMCC::NE, dl, MVT::i32)); } +// Turn a truncate into a predicate (an i1 vector) into icmp(and(x, 1), 0). 
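The equivalence used by the comment above (and by LowerTruncatei1, which follows) can be checked with a scalar analogue. This is only an illustration of the identity, not the lowering itself.

// Sketch only: truncating an integer to one bit keeps bit 0, which is exactly
// icmp ne (and x, 1), 0.
#include <cassert>
#include <cstdint>

static bool truncToI1(uint32_t X) { return (X & 1u) == 1u; }  // trunc iN -> i1
static bool viaAndICmp(uint32_t X) { return (X & 1u) != 0u; } // icmp(and(x, 1), 0)

int main() {
  for (uint32_t X : {0u, 1u, 2u, 3u, 0x80000001u, 0xFFFFFFFEu})
    assert(truncToI1(X) == viaAndICmp(X));
  return 0;
}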
+static SDValue LowerTruncatei1(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *ST) { + assert(ST->hasMVEIntegerOps() && "Expected MVE!"); + EVT VT = N->getValueType(0); + assert((VT == MVT::v16i1 || VT == MVT::v8i1 || VT == MVT::v4i1) && + "Expected a vector i1 type!"); + SDValue Op = N->getOperand(0); + EVT FromVT = Op.getValueType(); + SDLoc DL(N); + + SDValue And = + DAG.getNode(ISD::AND, DL, FromVT, Op, DAG.getConstant(1, DL, FromVT)); + return DAG.getNode(ISD::SETCC, DL, VT, And, DAG.getConstant(0, DL, FromVT), + DAG.getCondCode(ISD::SETNE)); +} + +static SDValue LowerTruncate(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + if (!Subtarget->hasMVEIntegerOps()) + return SDValue(); + + EVT ToVT = N->getValueType(0); + if (ToVT.getScalarType() == MVT::i1) + return LowerTruncatei1(N, DAG, Subtarget); + + // MVE does not have a single instruction to perform the truncation of a v4i32 + // into the lower half of a v8i16, in the same way that a NEON vmovn would. + // Most of the instructions in MVE follow the 'Beats' system, where moving + // values from different lanes is usually something that the instructions + // avoid. + // + // Instead it has top/bottom instructions such as VMOVLT/B and VMOVNT/B, + // which take the top/bottom half of a larger lane and extend it (or do the + // opposite, truncating into the top/bottom lane from a larger lane). Note + // that because of the way we widen lanes, a v4i16 is really a v4i32 using the + // bottom 16bits from each vector lane. This works really well with T/B + // instructions, but that doesn't extend to v8i32->v8i16 where the lanes need + // to move order. + // + // But truncates and sext/zext are always going to be fairly common from llvm. + // We have several options for how to deal with them: + // - Wherever possible combine them into an instruction that makes them + // "free". This includes loads/stores, which can perform the trunc as part + // of the memory operation. Or certain shuffles that can be turned into + // VMOVN/VMOVL. + // - Lane Interleaving to transform blocks surrounded by ext/trunc. So + // trunc(mul(sext(a), sext(b))) may become + // VMOVNT(VMUL(VMOVLB(a), VMOVLB(b)), VMUL(VMOVLT(a), VMOVLT(b))). (Which in + // this case can use VMULL). This is performed in the + // MVELaneInterleavingPass. + // - Otherwise we have an option. By default we would expand the + // zext/sext/trunc into a series of lane extract/inserts going via GPR + // registers. One for each vector lane in the vector. This can obviously be + // very expensive. + // - The other option is to use the fact that loads/stores can extend/truncate + // to turn a trunc into two truncating stack stores and a stack reload. This + // becomes 3 back-to-back memory operations, but at least that is less than + // all the insert/extracts. + // + // In order to do the last, we convert certain trunc's into MVETRUNC, which + // are either optimized where they can be, or eventually lowered into stack + // stores/loads. This prevents us from splitting a v8i16 trunc into two stores + // too early, where other instructions would be better, and stops us from + // having to reconstruct multiple buildvector shuffles into loads/stores.
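The bottom/top layout described in the comment above is easiest to see on plain arrays. The following is a rough, simplified model of the VMOVN-style moves (not part of the patch), showing why two of them give an interleaved rather than a linear truncate.

// Sketch only: each "move" writes the truncated 32-bit lanes into the even
// (bottom) or odd (top) 16-bit half-lanes, so narrowing two v4i32 halves this
// way is a permutation of what a linear trunc(v8i32 -> v8i16) wants.
#include <array>
#include <cstdint>
#include <cstdio>

using V4i32 = std::array<uint32_t, 4>;
using V8i16 = std::array<uint16_t, 8>;

static V8i16 narrowBottom(V8i16 Acc, const V4i32 &Src) { // even half-lanes
  for (unsigned I = 0; I < 4; ++I)
    Acc[2 * I] = static_cast<uint16_t>(Src[I]);
  return Acc;
}

static V8i16 narrowTop(V8i16 Acc, const V4i32 &Src) { // odd half-lanes
  for (unsigned I = 0; I < 4; ++I)
    Acc[2 * I + 1] = static_cast<uint16_t>(Src[I]);
  return Acc;
}

int main() {
  V4i32 Lo{0, 1, 2, 3}, Hi{4, 5, 6, 7};
  V8i16 R = narrowTop(narrowBottom({}, Lo), Hi);
  // Prints 0 4 1 5 2 6 3 7: a lane permutation of the 0..7 a linear truncate
  // would produce, hence the MVETRUNC / stack-store-and-reload fallback.
  for (uint16_t V : R)
    std::printf("%u ", static_cast<unsigned>(V));
  std::printf("\n");
  return 0;
}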
+ if (ToVT != MVT::v8i16 && ToVT != MVT::v16i8) + return SDValue(); + EVT FromVT = N->getOperand(0).getValueType(); + if (FromVT != MVT::v8i32 && FromVT != MVT::v16i16) + return SDValue(); + + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0); + SDLoc DL(N); + return DAG.getNode(ARMISD::MVETRUNC, DL, ToVT, Lo, Hi); +} + +static SDValue LowerVectorExtend(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + if (!Subtarget->hasMVEIntegerOps()) + return SDValue(); + + // See LowerTruncate above for an explanation of MVEEXT/MVETRUNC. + + EVT ToVT = N->getValueType(0); + if (ToVT != MVT::v16i32 && ToVT != MVT::v8i32 && ToVT != MVT::v16i16) + return SDValue(); + SDValue Op = N->getOperand(0); + EVT FromVT = Op.getValueType(); + if (FromVT != MVT::v8i16 && FromVT != MVT::v16i8) + return SDValue(); + + SDLoc DL(N); + EVT ExtVT = ToVT.getHalfNumVectorElementsVT(*DAG.getContext()); + if (ToVT.getScalarType() == MVT::i32 && FromVT.getScalarType() == MVT::i8) + ExtVT = MVT::v8i16; + + unsigned Opcode = + N->getOpcode() == ISD::SIGN_EXTEND ? ARMISD::MVESEXT : ARMISD::MVEZEXT; + SDValue Ext = DAG.getNode(Opcode, DL, DAG.getVTList(ExtVT, ExtVT), Op); + SDValue Ext1 = Ext.getValue(1); + + if (ToVT.getScalarType() == MVT::i32 && FromVT.getScalarType() == MVT::i8) { + Ext = DAG.getNode(N->getOpcode(), DL, MVT::v8i32, Ext); + Ext1 = DAG.getNode(N->getOpcode(), DL, MVT::v8i32, Ext1); + } + + return DAG.getNode(ISD::CONCAT_VECTORS, DL, ToVT, Ext, Ext1); +} + /// isExtendedBUILD_VECTOR - Check if N is a constant BUILD_VECTOR where each /// element has been zero/sign-extended, depending on the isSigned parameter, /// from an integer type half its size. @@ -8379,7 +9349,7 @@ static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG, Hi1->getSExtValue() == Lo1->getSExtValue() >> 32) return true; } else { - if (Hi0->isNullValue() && Hi1->isNullValue()) + if (Hi0->isZero() && Hi1->isZero()) return true; } return false; @@ -8418,10 +9388,11 @@ static bool isSignExtended(SDNode *N, SelectionDAG &DAG) { return false; } -/// isZeroExtended - Check if a node is a vector value that is zero-extended -/// or a constant BUILD_VECTOR with zero-extended elements. +/// isZeroExtended - Check if a node is a vector value that is zero-extended (or +/// any-extended) or a constant BUILD_VECTOR with zero-extended elements. static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) { - if (N->getOpcode() == ISD::ZERO_EXTEND || ISD::isZEXTLoad(N)) + if (N->getOpcode() == ISD::ZERO_EXTEND || N->getOpcode() == ISD::ANY_EXTEND || + ISD::isZEXTLoad(N)) return true; if (isExtendedBUILD_VECTOR(N, DAG, false)) return true; @@ -8476,26 +9447,27 @@ static SDValue SkipLoadExtensionForVMULL(LoadSDNode *LD, SelectionDAG& DAG) { // The load already has the right type. if (ExtendedTy == LD->getMemoryVT()) return DAG.getLoad(LD->getMemoryVT(), SDLoc(LD), LD->getChain(), - LD->getBasePtr(), LD->getPointerInfo(), - LD->getAlignment(), LD->getMemOperand()->getFlags()); + LD->getBasePtr(), LD->getPointerInfo(), LD->getAlign(), + LD->getMemOperand()->getFlags()); // We need to create a zextload/sextload. We cannot just create a load // followed by a zext/zext node because LowerMUL is also run during normal // operation legalization where we can't create illegal types. 
return DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD), ExtendedTy, LD->getChain(), LD->getBasePtr(), LD->getPointerInfo(), - LD->getMemoryVT(), LD->getAlignment(), + LD->getMemoryVT(), LD->getAlign(), LD->getMemOperand()->getFlags()); } /// SkipExtensionForVMULL - For a node that is a SIGN_EXTEND, ZERO_EXTEND, -/// extending load, or BUILD_VECTOR with extended elements, return the -/// unextended value. The unextended vector should be 64 bits so that it can +/// ANY_EXTEND, extending load, or BUILD_VECTOR with extended elements, return +/// the unextended value. The unextended vector should be 64 bits so that it can /// be used as an operand to a VMULL instruction. If the original vector size /// before extension is less than 64 bits we add a an extension to resize /// the vector to 64 bits. static SDValue SkipExtensionForVMULL(SDNode *N, SelectionDAG &DAG) { - if (N->getOpcode() == ISD::SIGN_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND) + if (N->getOpcode() == ISD::SIGN_EXTEND || + N->getOpcode() == ISD::ZERO_EXTEND || N->getOpcode() == ISD::ANY_EXTEND) return AddRequiredExtensionForVMULL(N->getOperand(0), DAG, N->getOperand(0)->getValueType(0), N->getValueType(0), @@ -8892,7 +9864,7 @@ SDValue ARMTargetLowering::LowerFSINCOS(SDValue Op, SelectionDAG &DAG) const { if (ShouldUseSRet) { // Create stack object for sret. const uint64_t ByteSize = DL.getTypeAllocSize(RetTy); - const unsigned StackAlign = DL.getPrefTypeAlignment(RetTy); + const Align StackAlign = DL.getPrefTypeAlign(RetTy); int FrameIdx = MFI.CreateStackObject(ByteSize, StackAlign, false); SRet = DAG.getFrameIndex(FrameIdx, TLI.getPointerTy(DL)); @@ -8992,7 +9964,7 @@ ARMTargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, if (N->getOpcode() != ISD::SDIV) return SDValue(); - const auto &ST = static_cast<const ARMSubtarget&>(DAG.getSubtarget()); + const auto &ST = DAG.getSubtarget<ARMSubtarget>(); const bool MinSize = ST.hasMinSize(); const bool HasDivide = ST.isThumb() ? ST.hasDivideInThumbMode() : ST.hasDivideInARMMode(); @@ -9067,69 +10039,136 @@ void ARMTargetLowering::ExpandDIV_Windows( DAG.getConstant(32, dl, TLI.getPointerTy(DL))); Upper = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Upper); - Results.push_back(Lower); - Results.push_back(Upper); + Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Lower, Upper)); } static SDValue LowerPredicateLoad(SDValue Op, SelectionDAG &DAG) { LoadSDNode *LD = cast<LoadSDNode>(Op.getNode()); EVT MemVT = LD->getMemoryVT(); - assert((MemVT == MVT::v4i1 || MemVT == MVT::v8i1 || MemVT == MVT::v16i1) && + assert((MemVT == MVT::v2i1 || MemVT == MVT::v4i1 || MemVT == MVT::v8i1 || + MemVT == MVT::v16i1) && "Expected a predicate type!"); assert(MemVT == Op.getValueType()); assert(LD->getExtensionType() == ISD::NON_EXTLOAD && "Expected a non-extending load"); assert(LD->isUnindexed() && "Expected a unindexed load"); - // The basic MVE VLDR on a v4i1/v8i1 actually loads the entire 16bit + // The basic MVE VLDR on a v2i1/v4i1/v8i1 actually loads the entire 16bit // predicate, with the "v4i1" bits spread out over the 16 bits loaded. We - // need to make sure that 8/4 bits are actually loaded into the correct + // need to make sure that 8/4/2 bits are actually loaded into the correct // place, which means loading the value and then shuffling the values into // the bottom bits of the predicate. // Equally, VLDR for an v16i1 will actually load 32bits (so will be incorrect // for BE). 
+ // Speaking of BE, apparently the rest of llvm will assume a reverse order to + // a natural VMSR(load), so needs to be reversed. SDLoc dl(Op); SDValue Load = DAG.getExtLoad( ISD::EXTLOAD, dl, MVT::i32, LD->getChain(), LD->getBasePtr(), EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()), LD->getMemOperand()); - SDValue Pred = DAG.getNode(ARMISD::PREDICATE_CAST, dl, MVT::v16i1, Load); + SDValue Val = Load; + if (DAG.getDataLayout().isBigEndian()) + Val = DAG.getNode(ISD::SRL, dl, MVT::i32, + DAG.getNode(ISD::BITREVERSE, dl, MVT::i32, Load), + DAG.getConstant(32 - MemVT.getSizeInBits(), dl, MVT::i32)); + SDValue Pred = DAG.getNode(ARMISD::PREDICATE_CAST, dl, MVT::v16i1, Val); if (MemVT != MVT::v16i1) Pred = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MemVT, Pred, DAG.getConstant(0, dl, MVT::i32)); return DAG.getMergeValues({Pred, Load.getValue(1)}, dl); } +void ARMTargetLowering::LowerLOAD(SDNode *N, SmallVectorImpl<SDValue> &Results, + SelectionDAG &DAG) const { + LoadSDNode *LD = cast<LoadSDNode>(N); + EVT MemVT = LD->getMemoryVT(); + assert(LD->isUnindexed() && "Loads should be unindexed at this point."); + + if (MemVT == MVT::i64 && Subtarget->hasV5TEOps() && + !Subtarget->isThumb1Only() && LD->isVolatile()) { + SDLoc dl(N); + SDValue Result = DAG.getMemIntrinsicNode( + ARMISD::LDRD, dl, DAG.getVTList({MVT::i32, MVT::i32, MVT::Other}), + {LD->getChain(), LD->getBasePtr()}, MemVT, LD->getMemOperand()); + SDValue Lo = Result.getValue(DAG.getDataLayout().isLittleEndian() ? 0 : 1); + SDValue Hi = Result.getValue(DAG.getDataLayout().isLittleEndian() ? 1 : 0); + SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Lo, Hi); + Results.append({Pair, Result.getValue(2)}); + } +} + static SDValue LowerPredicateStore(SDValue Op, SelectionDAG &DAG) { StoreSDNode *ST = cast<StoreSDNode>(Op.getNode()); EVT MemVT = ST->getMemoryVT(); - assert((MemVT == MVT::v4i1 || MemVT == MVT::v8i1 || MemVT == MVT::v16i1) && + assert((MemVT == MVT::v2i1 || MemVT == MVT::v4i1 || MemVT == MVT::v8i1 || + MemVT == MVT::v16i1) && "Expected a predicate type!"); assert(MemVT == ST->getValue().getValueType()); assert(!ST->isTruncatingStore() && "Expected a non-extending store"); assert(ST->isUnindexed() && "Expected a unindexed store"); - // Only store the v4i1 or v8i1 worth of bits, via a buildvector with top bits - // unset and a scalar store. + // Only store the v2i1 or v4i1 or v8i1 worth of bits, via a buildvector with + // top bits unset and a scalar store. SDLoc dl(Op); SDValue Build = ST->getValue(); if (MemVT != MVT::v16i1) { SmallVector<SDValue, 16> Ops; - for (unsigned I = 0; I < MemVT.getVectorNumElements(); I++) + for (unsigned I = 0; I < MemVT.getVectorNumElements(); I++) { + unsigned Elt = DAG.getDataLayout().isBigEndian() + ? 
MemVT.getVectorNumElements() - I - 1 + : I; Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, Build, - DAG.getConstant(I, dl, MVT::i32))); + DAG.getConstant(Elt, dl, MVT::i32))); + } for (unsigned I = MemVT.getVectorNumElements(); I < 16; I++) Ops.push_back(DAG.getUNDEF(MVT::i32)); Build = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v16i1, Ops); } SDValue GRP = DAG.getNode(ARMISD::PREDICATE_CAST, dl, MVT::i32, Build); + if (MemVT == MVT::v16i1 && DAG.getDataLayout().isBigEndian()) + GRP = DAG.getNode(ISD::SRL, dl, MVT::i32, + DAG.getNode(ISD::BITREVERSE, dl, MVT::i32, GRP), + DAG.getConstant(16, dl, MVT::i32)); return DAG.getTruncStore( ST->getChain(), dl, GRP, ST->getBasePtr(), EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()), ST->getMemOperand()); } +static SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + StoreSDNode *ST = cast<StoreSDNode>(Op.getNode()); + EVT MemVT = ST->getMemoryVT(); + assert(ST->isUnindexed() && "Stores should be unindexed at this point."); + + if (MemVT == MVT::i64 && Subtarget->hasV5TEOps() && + !Subtarget->isThumb1Only() && ST->isVolatile()) { + SDNode *N = Op.getNode(); + SDLoc dl(N); + + SDValue Lo = DAG.getNode( + ISD::EXTRACT_ELEMENT, dl, MVT::i32, ST->getValue(), + DAG.getTargetConstant(DAG.getDataLayout().isLittleEndian() ? 0 : 1, dl, + MVT::i32)); + SDValue Hi = DAG.getNode( + ISD::EXTRACT_ELEMENT, dl, MVT::i32, ST->getValue(), + DAG.getTargetConstant(DAG.getDataLayout().isLittleEndian() ? 1 : 0, dl, + MVT::i32)); + + return DAG.getMemIntrinsicNode(ARMISD::STRD, dl, DAG.getVTList(MVT::Other), + {ST->getChain(), Lo, Hi, ST->getBasePtr()}, + MemVT, ST->getMemOperand()); + } else if (Subtarget->hasMVEIntegerOps() && + ((MemVT == MVT::v2i1 || MemVT == MVT::v4i1 || MemVT == MVT::v8i1 || + MemVT == MVT::v16i1))) { + return LowerPredicateStore(Op, DAG); + } + + return SDValue(); +} + static bool isZeroVector(SDValue N) { return (ISD::isBuildVectorAllZeros(N.getNode()) || (N->getOpcode() == ARMISD::VMOVIMM && @@ -9155,15 +10194,89 @@ static SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) { N->getMemoryVT(), N->getMemOperand(), N->getAddressingMode(), N->getExtensionType(), N->isExpandingLoad()); SDValue Combo = NewLoad; - if (!PassThru.isUndef() && - (PassThru.getOpcode() != ISD::BITCAST || - !isZeroVector(PassThru->getOperand(0)))) + bool PassThruIsCastZero = (PassThru.getOpcode() == ISD::BITCAST || + PassThru.getOpcode() == ARMISD::VECTOR_REG_CAST) && + isZeroVector(PassThru->getOperand(0)); + if (!PassThru.isUndef() && !PassThruIsCastZero) Combo = DAG.getNode(ISD::VSELECT, dl, VT, Mask, NewLoad, PassThru); return DAG.getMergeValues({Combo, NewLoad.getValue(1)}, dl); } +static SDValue LowerVecReduce(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *ST) { + if (!ST->hasMVEIntegerOps()) + return SDValue(); + + SDLoc dl(Op); + unsigned BaseOpcode = 0; + switch (Op->getOpcode()) { + default: llvm_unreachable("Expected VECREDUCE opcode"); + case ISD::VECREDUCE_FADD: BaseOpcode = ISD::FADD; break; + case ISD::VECREDUCE_FMUL: BaseOpcode = ISD::FMUL; break; + case ISD::VECREDUCE_MUL: BaseOpcode = ISD::MUL; break; + case ISD::VECREDUCE_AND: BaseOpcode = ISD::AND; break; + case ISD::VECREDUCE_OR: BaseOpcode = ISD::OR; break; + case ISD::VECREDUCE_XOR: BaseOpcode = ISD::XOR; break; + case ISD::VECREDUCE_FMAX: BaseOpcode = ISD::FMAXNUM; break; + case ISD::VECREDUCE_FMIN: BaseOpcode = ISD::FMINNUM; break; + } + + SDValue Op0 = Op->getOperand(0); + EVT VT = Op0.getValueType(); + EVT EltVT = 
VT.getVectorElementType(); + unsigned NumElts = VT.getVectorNumElements(); + unsigned NumActiveLanes = NumElts; + + assert((NumActiveLanes == 16 || NumActiveLanes == 8 || NumActiveLanes == 4 || + NumActiveLanes == 2) && + "Only expected a power 2 vector size"); + + // Use Mul(X, Rev(X)) until 4 items remain. Going down to 4 vector elements + // allows us to easily extract vector elements from the lanes. + while (NumActiveLanes > 4) { + unsigned RevOpcode = NumActiveLanes == 16 ? ARMISD::VREV16 : ARMISD::VREV32; + SDValue Rev = DAG.getNode(RevOpcode, dl, VT, Op0); + Op0 = DAG.getNode(BaseOpcode, dl, VT, Op0, Rev); + NumActiveLanes /= 2; + } + + SDValue Res; + if (NumActiveLanes == 4) { + // The remaining 4 elements are summed sequentially + SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0, + DAG.getConstant(0 * NumElts / 4, dl, MVT::i32)); + SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0, + DAG.getConstant(1 * NumElts / 4, dl, MVT::i32)); + SDValue Ext2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0, + DAG.getConstant(2 * NumElts / 4, dl, MVT::i32)); + SDValue Ext3 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0, + DAG.getConstant(3 * NumElts / 4, dl, MVT::i32)); + SDValue Res0 = DAG.getNode(BaseOpcode, dl, EltVT, Ext0, Ext1, Op->getFlags()); + SDValue Res1 = DAG.getNode(BaseOpcode, dl, EltVT, Ext2, Ext3, Op->getFlags()); + Res = DAG.getNode(BaseOpcode, dl, EltVT, Res0, Res1, Op->getFlags()); + } else { + SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0, + DAG.getConstant(0, dl, MVT::i32)); + SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0, + DAG.getConstant(1, dl, MVT::i32)); + Res = DAG.getNode(BaseOpcode, dl, EltVT, Ext0, Ext1, Op->getFlags()); + } + + // Result type may be wider than element type. + if (EltVT != Op->getValueType(0)) + Res = DAG.getNode(ISD::ANY_EXTEND, dl, Op->getValueType(0), Res); + return Res; +} + +static SDValue LowerVecReduceF(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *ST) { + if (!ST->hasMVEFloatOps()) + return SDValue(); + return LowerVecReduce(Op, DAG, ST); +} + static SDValue LowerAtomicLoadStore(SDValue Op, SelectionDAG &DAG) { - if (isStrongerThanMonotonic(cast<AtomicSDNode>(Op)->getOrdering())) + if (isStrongerThanMonotonic(cast<AtomicSDNode>(Op)->getSuccessOrdering())) // Acquire/Release load/store is not legal for targets without a dmb or // equivalent available. return SDValue(); @@ -9231,12 +10344,13 @@ static void ReplaceCMP_SWAP_64Results(SDNode *N, bool isBigEndian = DAG.getDataLayout().isBigEndian(); - Results.push_back( + SDValue Lo = DAG.getTargetExtractSubreg(isBigEndian ? ARM::gsub_1 : ARM::gsub_0, - SDLoc(N), MVT::i32, SDValue(CmpSwap, 0))); - Results.push_back( + SDLoc(N), MVT::i32, SDValue(CmpSwap, 0)); + SDValue Hi = DAG.getTargetExtractSubreg(isBigEndian ? 
ARM::gsub_0 : ARM::gsub_1, - SDLoc(N), MVT::i32, SDValue(CmpSwap, 0))); + SDLoc(N), MVT::i32, SDValue(CmpSwap, 0)); + Results.push_back(DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i64, Lo, Hi)); Results.push_back(SDValue(CmpSwap, 2)); } @@ -9285,6 +10399,15 @@ SDValue ARMTargetLowering::LowerFSETCC(SDValue Op, SelectionDAG &DAG) const { return DAG.getMergeValues({Result, Chain}, dl); } +SDValue ARMTargetLowering::LowerSPONENTRY(SDValue Op, SelectionDAG &DAG) const { + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); + + EVT VT = getPointerTy(DAG.getDataLayout()); + SDLoc DL(Op); + int FI = MFI.CreateFixedObject(4, 0, false); + return DAG.getFrameIndex(FI, VT); +} + SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { LLVM_DEBUG(dbgs() << "Lowering node: "; Op.dump()); switch (Op.getOpcode()) { @@ -9308,6 +10431,8 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::STRICT_FP_TO_UINT: case ISD::FP_TO_SINT: case ISD::FP_TO_UINT: return LowerFP_TO_INT(Op, DAG); + case ISD::FP_TO_SINT_SAT: + case ISD::FP_TO_UINT_SAT: return LowerFP_TO_INT_SAT(Op, DAG, Subtarget); case ISD::FCOPYSIGN: return LowerFCOPYSIGN(Op, DAG); case ISD::RETURNADDR: return LowerRETURNADDR(Op, DAG); case ISD::FRAMEADDR: return LowerFRAMEADDR(Op, DAG); @@ -9338,7 +10463,11 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG); case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG, Subtarget); case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, DAG, Subtarget); - case ISD::FLT_ROUNDS_: return LowerFLT_ROUNDS_(Op, DAG); + case ISD::TRUNCATE: return LowerTruncate(Op.getNode(), DAG, Subtarget); + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: return LowerVectorExtend(Op.getNode(), DAG, Subtarget); + case ISD::GET_ROUNDING: return LowerGET_ROUNDING(Op, DAG); + case ISD::SET_ROUNDING: return LowerSET_ROUNDING(Op, DAG); case ISD::MUL: return LowerMUL(Op, DAG); case ISD::SDIV: if (Subtarget->isTargetWindows() && !Op.getValueType().isVector()) @@ -9358,13 +10487,25 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { return LowerUnsignedALUO(Op, DAG); case ISD::SADDSAT: case ISD::SSUBSAT: - return LowerSADDSUBSAT(Op, DAG, Subtarget); + case ISD::UADDSAT: + case ISD::USUBSAT: + return LowerADDSUBSAT(Op, DAG, Subtarget); case ISD::LOAD: return LowerPredicateLoad(Op, DAG); case ISD::STORE: - return LowerPredicateStore(Op, DAG); + return LowerSTORE(Op, DAG, Subtarget); case ISD::MLOAD: return LowerMLOAD(Op, DAG); + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + return LowerVecReduce(Op, DAG, Subtarget); + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_FMIN: + case ISD::VECREDUCE_FMAX: + return LowerVecReduceF(Op, DAG, Subtarget); case ISD::ATOMIC_LOAD: case ISD::ATOMIC_STORE: return LowerAtomicLoadStore(Op, DAG); case ISD::FSINCOS: return LowerFSINCOS(Op, DAG); @@ -9380,6 +10521,8 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG); case ISD::STRICT_FSETCC: case ISD::STRICT_FSETCCS: return LowerFSETCC(Op, DAG); + case ISD::SPONENTRY: + return LowerSPONENTRY(Op, DAG); case ARMISD::WIN__DBZCHK: return SDValue(); } } @@ -9411,8 +10554,8 @@ static void ReplaceLongIntrinsic(SDNode *N, SmallVectorImpl<SDValue> &Results, DAG.getVTList(MVT::i32, 
MVT::i32), N->getOperand(1), N->getOperand(2), Lo, Hi); - Results.push_back(LongMul.getValue(0)); - Results.push_back(LongMul.getValue(1)); + Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, + LongMul.getValue(0), LongMul.getValue(1))); } /// ReplaceNodeResults - Replace the results of node with an illegal result @@ -9448,7 +10591,9 @@ void ARMTargetLowering::ReplaceNodeResults(SDNode *N, return; case ISD::SADDSAT: case ISD::SSUBSAT: - Res = LowerSADDSUBSAT(SDValue(N, 0), DAG, Subtarget); + case ISD::UADDSAT: + case ISD::USUBSAT: + Res = LowerADDSUBSAT(SDValue(N, 0), DAG, Subtarget); break; case ISD::READCYCLECOUNTER: ReplaceREADCYCLECOUNTER(N, Results, DAG, Subtarget); @@ -9463,10 +10608,20 @@ void ARMTargetLowering::ReplaceNodeResults(SDNode *N, return; case ISD::INTRINSIC_WO_CHAIN: return ReplaceLongIntrinsic(N, Results, DAG); - case ISD::ABS: - lowerABS(N, Results, DAG); - return ; - + case ISD::LOAD: + LowerLOAD(N, Results, DAG); + break; + case ISD::TRUNCATE: + Res = LowerTruncate(N, DAG, Subtarget); + break; + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + Res = LowerVectorExtend(N, DAG, Subtarget); + break; + case ISD::FP_TO_SINT_SAT: + case ISD::FP_TO_UINT_SAT: + Res = LowerFP_TO_INT_SAT(SDValue(N, 0), DAG, Subtarget); + break; } if (Res.getNode()) Results.push_back(Res); @@ -9499,7 +10654,7 @@ void ARMTargetLowering::SetupEntryBlockForSjLj(MachineInstr &MI, unsigned PCAdj = (isThumb || isThumb2) ? 4 : 8; ARMConstantPoolValue *CPV = ARMConstantPoolMBB::Create(F.getContext(), DispatchBB, PCLabelId, PCAdj); - unsigned CPI = MCP->getConstantPoolIndex(CPV, 4); + unsigned CPI = MCP->getConstantPoolIndex(CPV, Align(4)); const TargetRegisterClass *TRC = isThumb ? &ARM::tGPRRegClass : &ARM::GPRRegClass; @@ -9507,11 +10662,11 @@ void ARMTargetLowering::SetupEntryBlockForSjLj(MachineInstr &MI, // Grab constant pool and fixed stack memory operands. MachineMemOperand *CPMMO = MF->getMachineMemOperand(MachinePointerInfo::getConstantPool(*MF), - MachineMemOperand::MOLoad, 4, 4); + MachineMemOperand::MOLoad, 4, Align(4)); MachineMemOperand *FIMMOSt = MF->getMachineMemOperand(MachinePointerInfo::getFixedStack(*MF, FI), - MachineMemOperand::MOStore, 4, 4); + MachineMemOperand::MOStore, 4, Align(4)); // Load the address of the dispatch MBB into the jump buffer. if (isThumb2) { @@ -9622,25 +10777,23 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, // associated with. DenseMap<unsigned, SmallVector<MachineBasicBlock*, 2>> CallSiteNumToLPad; unsigned MaxCSNum = 0; - for (MachineFunction::iterator BB = MF->begin(), E = MF->end(); BB != E; - ++BB) { - if (!BB->isEHPad()) continue; + for (MachineBasicBlock &BB : *MF) { + if (!BB.isEHPad()) + continue; // FIXME: We should assert that the EH_LABEL is the first MI in the landing // pad. 
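// Illustrative aside (not part of the diff; helper name made up): both
// ReplaceCMP_SWAP_64Results and ReplaceLongIntrinsic further up now hand back
// a single i64 built from two i32 halves via ISD::BUILD_PAIR(Lo, Hi). The
// sketch below only restates what that pairing means for the value:
#include <cstdint>
constexpr uint64_t build_pair_model(uint32_t lo, uint32_t hi) {
  return (static_cast<uint64_t>(hi) << 32) | lo;  // Lo fills bits 0..31, Hi bits 32..63
}
static_assert(build_pair_model(0xDDCCBBAAu, 0x11223344u) == 0x11223344DDCCBBAAull);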
- for (MachineBasicBlock::iterator - II = BB->begin(), IE = BB->end(); II != IE; ++II) { - if (!II->isEHLabel()) continue; + for (MachineInstr &II : BB) { + if (!II.isEHLabel()) + continue; - MCSymbol *Sym = II->getOperand(0).getMCSymbol(); + MCSymbol *Sym = II.getOperand(0).getMCSymbol(); if (!MF->hasCallSiteLandingPad(Sym)) continue; SmallVectorImpl<unsigned> &CallSiteIdxs = MF->getCallSiteLandingPad(Sym); - for (SmallVectorImpl<unsigned>::iterator - CSI = CallSiteIdxs.begin(), CSE = CallSiteIdxs.end(); - CSI != CSE; ++CSI) { - CallSiteNumToLPad[*CSI].push_back(&*BB); - MaxCSNum = std::max(MaxCSNum, *CSI); + for (unsigned Idx : CallSiteIdxs) { + CallSiteNumToLPad[Idx].push_back(&BB); + MaxCSNum = std::max(MaxCSNum, Idx); } break; } @@ -9652,10 +10805,9 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, LPadList.reserve(CallSiteNumToLPad.size()); for (unsigned I = 1; I <= MaxCSNum; ++I) { SmallVectorImpl<MachineBasicBlock*> &MBBList = CallSiteNumToLPad[I]; - for (SmallVectorImpl<MachineBasicBlock*>::iterator - II = MBBList.begin(), IE = MBBList.end(); II != IE; ++II) { - LPadList.push_back(*II); - InvokeBBs.insert((*II)->pred_begin(), (*II)->pred_end()); + for (MachineBasicBlock *MBB : MBBList) { + LPadList.push_back(MBB); + InvokeBBs.insert(MBB->pred_begin(), MBB->pred_end()); } } @@ -9697,7 +10849,7 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, MachineMemOperand *FIMMOLd = MF->getMachineMemOperand( MachinePointerInfo::getFixedStack(*MF, FI), - MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile, 4, 4); + MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile, 4, Align(4)); MachineInstrBuilder MIB; MIB = BuildMI(DispatchBB, dl, TII->get(ARM::Int_eh_sjlj_dispatchsetup)); @@ -9788,10 +10940,8 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, const Constant *C = ConstantInt::get(Int32Ty, NumLPads); // MachineConstantPool wants an explicit alignment. - unsigned Align = MF->getDataLayout().getPrefTypeAlignment(Int32Ty); - if (Align == 0) - Align = MF->getDataLayout().getTypeAllocSize(C->getType()); - unsigned Idx = ConstantPool->getConstantPoolIndex(C, Align); + Align Alignment = MF->getDataLayout().getPrefTypeAlign(Int32Ty); + unsigned Idx = ConstantPool->getConstantPoolIndex(C, Alignment); Register VReg1 = MRI->createVirtualRegister(TRC); BuildMI(DispatchBB, dl, TII->get(ARM::tLDRpci)) @@ -9828,8 +10978,9 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, .addReg(NewVReg3) .add(predOps(ARMCC::AL)); - MachineMemOperand *JTMMOLd = MF->getMachineMemOperand( - MachinePointerInfo::getJumpTable(*MF), MachineMemOperand::MOLoad, 4, 4); + MachineMemOperand *JTMMOLd = + MF->getMachineMemOperand(MachinePointerInfo::getJumpTable(*MF), + MachineMemOperand::MOLoad, 4, Align(4)); Register NewVReg5 = MRI->createVirtualRegister(TRC); BuildMI(DispContBB, dl, TII->get(ARM::tLDRi), NewVReg5) @@ -9889,10 +11040,8 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, const Constant *C = ConstantInt::get(Int32Ty, NumLPads); // MachineConstantPool wants an explicit alignment. 
- unsigned Align = MF->getDataLayout().getPrefTypeAlignment(Int32Ty); - if (Align == 0) - Align = MF->getDataLayout().getTypeAllocSize(C->getType()); - unsigned Idx = ConstantPool->getConstantPoolIndex(C, Align); + Align Alignment = MF->getDataLayout().getPrefTypeAlign(Int32Ty); + unsigned Idx = ConstantPool->getConstantPoolIndex(C, Alignment); Register VReg1 = MRI->createVirtualRegister(TRC); BuildMI(DispatchBB, dl, TII->get(ARM::LDRcp)) @@ -9922,8 +11071,9 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, .addJumpTableIndex(MJTI) .add(predOps(ARMCC::AL)); - MachineMemOperand *JTMMOLd = MF->getMachineMemOperand( - MachinePointerInfo::getJumpTable(*MF), MachineMemOperand::MOLoad, 4, 4); + MachineMemOperand *JTMMOLd = + MF->getMachineMemOperand(MachinePointerInfo::getJumpTable(*MF), + MachineMemOperand::MOLoad, 4, Align(4)); Register NewVReg5 = MRI->createVirtualRegister(TRC); BuildMI(DispContBB, dl, TII->get(ARM::LDRrs), NewVReg5) .addReg(NewVReg3, RegState::Kill) @@ -9946,9 +11096,7 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, // Add the jump table entries as successors to the MBB. SmallPtrSet<MachineBasicBlock*, 8> SeenMBBs; - for (std::vector<MachineBasicBlock*>::iterator - I = LPadList.begin(), E = LPadList.end(); I != E; ++I) { - MachineBasicBlock *CurMBB = *I; + for (MachineBasicBlock *CurMBB : LPadList) { if (SeenMBBs.insert(CurMBB).second) DispContBB->addSuccessor(CurMBB); } @@ -9960,8 +11108,7 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, // Remove the landing pad successor from the invoke block and replace it // with the new dispatch block. - SmallVector<MachineBasicBlock*, 4> Successors(BB->succ_begin(), - BB->succ_end()); + SmallVector<MachineBasicBlock*, 4> Successors(BB->successors()); while (!Successors.empty()) { MachineBasicBlock *SMBB = Successors.pop_back_val(); if (SMBB->isEHPad()) { @@ -10011,9 +11158,8 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, // Mark all former landing pads as non-landing pads. The dispatch is the only // landing pad now. - for (SmallVectorImpl<MachineBasicBlock*>::iterator - I = MBBLPads.begin(), E = MBBLPads.end(); I != E; ++I) - (*I)->setIsEHPad(false); + for (MachineBasicBlock *MBBLPad : MBBLPads) + MBBLPad->setIsEHPad(false); // The instruction is gone now. MI.eraseFromParent(); @@ -10021,10 +11167,9 @@ void ARMTargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, static MachineBasicBlock *OtherSucc(MachineBasicBlock *MBB, MachineBasicBlock *Succ) { - for (MachineBasicBlock::succ_iterator I = MBB->succ_begin(), - E = MBB->succ_end(); I != E; ++I) - if (*I != Succ) - return *I; + for (MachineBasicBlock *S : MBB->successors()) + if (S != Succ) + return S; llvm_unreachable("Expecting a BB with two successors!"); } @@ -10162,7 +11307,7 @@ ARMTargetLowering::EmitStructByval(MachineInstr &MI, Register dest = MI.getOperand(0).getReg(); Register src = MI.getOperand(1).getReg(); unsigned SizeVal = MI.getOperand(2).getImm(); - unsigned Align = MI.getOperand(3).getImm(); + unsigned Alignment = MI.getOperand(3).getImm(); DebugLoc dl = MI.getDebugLoc(); MachineFunction *MF = BB->getParent(); @@ -10175,17 +11320,17 @@ ARMTargetLowering::EmitStructByval(MachineInstr &MI, bool IsThumb2 = Subtarget->isThumb2(); bool IsThumb = Subtarget->isThumb(); - if (Align & 1) { + if (Alignment & 1) { UnitSize = 1; - } else if (Align & 2) { + } else if (Alignment & 2) { UnitSize = 2; } else { // Check whether we can use NEON instructions. 
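// Illustrative sketch (not part of the diff; the helper name and the
// word-sized fallback are assumptions, since that branch lies outside the
// hunk shown): how the byval-copy expansion at this point picks its
// per-iteration unit from the renamed Alignment operand. NEON-sized units are
// used only when both alignment and total size allow it, otherwise the code
// falls back to progressively narrower scalar accesses.
unsigned byval_unit_size(unsigned alignment, unsigned size, bool canUseNEON) {
  if (alignment & 1) return 1;                                   // odd address: byte copies
  if (alignment & 2) return 2;                                   // 2-byte aligned: halfwords
  if (canUseNEON && alignment % 16 == 0 && size >= 16) return 16;
  if (canUseNEON && alignment % 8 == 0 && size >= 8) return 8;
  return 4;                                                      // assumed word-sized fallback
}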
if (!MF->getFunction().hasFnAttribute(Attribute::NoImplicitFloat) && Subtarget->hasNEON()) { - if ((Align % 16 == 0) && SizeVal >= 16) + if ((Alignment % 16 == 0) && SizeVal >= 16) UnitSize = 16; - else if ((Align % 8 == 0) && SizeVal >= 8) + else if ((Alignment % 8 == 0) && SizeVal >= 8) UnitSize = 8; } // Can't use NEON instructions. @@ -10291,13 +11436,11 @@ ARMTargetLowering::EmitStructByval(MachineInstr &MI, const Constant *C = ConstantInt::get(Int32Ty, LoopSize); // MachineConstantPool wants an explicit alignment. - unsigned Align = MF->getDataLayout().getPrefTypeAlignment(Int32Ty); - if (Align == 0) - Align = MF->getDataLayout().getTypeAllocSize(C->getType()); - unsigned Idx = ConstantPool->getConstantPoolIndex(C, Align); + Align Alignment = MF->getDataLayout().getPrefTypeAlign(Int32Ty); + unsigned Idx = ConstantPool->getConstantPoolIndex(C, Alignment); MachineMemOperand *CPMMO = MF->getMachineMemOperand(MachinePointerInfo::getConstantPool(*MF), - MachineMemOperand::MOLoad, 4, 4); + MachineMemOperand::MOLoad, 4, Align(4)); if (IsThumb) BuildMI(*BB, MI, dl, TII->get(ARM::tLDRpci)) @@ -10447,7 +11590,7 @@ ARMTargetLowering::EmitLowered__chkstk(MachineInstr &MI, BuildMI(*MBB, MI, DL, TII.get(ARM::t2MOVi32imm), Reg) .addExternalSymbol("__chkstk"); - BuildMI(*MBB, MI, DL, TII.get(ARM::tBLXr)) + BuildMI(*MBB, MI, DL, TII.get(gettBLXrOpcode(*MBB->getParent()))) .add(predOps(ARMCC::AL)) .addReg(Reg, RegState::Kill) .addReg(ARM::R4, RegState::Implicit | RegState::Kill) @@ -10524,13 +11667,9 @@ static bool checkAndUpdateCPSRKill(MachineBasicBlock::iterator SelectItr, // If we hit the end of the block, check whether CPSR is live into a // successor. if (miI == BB->end()) { - for (MachineBasicBlock::succ_iterator sItr = BB->succ_begin(), - sEnd = BB->succ_end(); - sItr != sEnd; ++sItr) { - MachineBasicBlock* succ = *sItr; - if (succ->isLiveIn(ARM::CPSR)) + for (MachineBasicBlock *Succ : BB->successors()) + if (Succ->isLiveIn(ARM::CPSR)) return false; - } } // We found a def, or hit the end of the basic block and CPSR wasn't live @@ -10539,6 +11678,148 @@ static bool checkAndUpdateCPSRKill(MachineBasicBlock::iterator SelectItr, return true; } +/// Adds logic in loop entry MBB to calculate loop iteration count and adds +/// t2WhileLoopSetup and t2WhileLoopStart to generate WLS loop +static Register genTPEntry(MachineBasicBlock *TpEntry, + MachineBasicBlock *TpLoopBody, + MachineBasicBlock *TpExit, Register OpSizeReg, + const TargetInstrInfo *TII, DebugLoc Dl, + MachineRegisterInfo &MRI) { + // Calculates loop iteration count = ceil(n/16) = (n + 15) >> 4. + Register AddDestReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + BuildMI(TpEntry, Dl, TII->get(ARM::t2ADDri), AddDestReg) + .addUse(OpSizeReg) + .addImm(15) + .add(predOps(ARMCC::AL)) + .addReg(0); + + Register LsrDestReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + BuildMI(TpEntry, Dl, TII->get(ARM::t2LSRri), LsrDestReg) + .addUse(AddDestReg, RegState::Kill) + .addImm(4) + .add(predOps(ARMCC::AL)) + .addReg(0); + + Register TotalIterationsReg = MRI.createVirtualRegister(&ARM::GPRlrRegClass); + BuildMI(TpEntry, Dl, TII->get(ARM::t2WhileLoopSetup), TotalIterationsReg) + .addUse(LsrDestReg, RegState::Kill); + + BuildMI(TpEntry, Dl, TII->get(ARM::t2WhileLoopStart)) + .addUse(TotalIterationsReg) + .addMBB(TpExit); + + BuildMI(TpEntry, Dl, TII->get(ARM::t2B)) + .addMBB(TpLoopBody) + .add(predOps(ARMCC::AL)); + + return TotalIterationsReg; +} + +/// Adds logic in the loopBody MBB to generate MVE_VCTP, t2DoLoopDec and +/// t2DoLoopEnd. 
These are used by later passes to generate tail predicated +/// loops. +static void genTPLoopBody(MachineBasicBlock *TpLoopBody, + MachineBasicBlock *TpEntry, MachineBasicBlock *TpExit, + const TargetInstrInfo *TII, DebugLoc Dl, + MachineRegisterInfo &MRI, Register OpSrcReg, + Register OpDestReg, Register ElementCountReg, + Register TotalIterationsReg, bool IsMemcpy) { + // First insert 4 PHI nodes for: Current pointer to Src (if memcpy), Dest + // array, loop iteration counter, predication counter. + + Register SrcPhiReg, CurrSrcReg; + if (IsMemcpy) { + // Current position in the src array + SrcPhiReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + CurrSrcReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + BuildMI(TpLoopBody, Dl, TII->get(ARM::PHI), SrcPhiReg) + .addUse(OpSrcReg) + .addMBB(TpEntry) + .addUse(CurrSrcReg) + .addMBB(TpLoopBody); + } + + // Current position in the dest array + Register DestPhiReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + Register CurrDestReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + BuildMI(TpLoopBody, Dl, TII->get(ARM::PHI), DestPhiReg) + .addUse(OpDestReg) + .addMBB(TpEntry) + .addUse(CurrDestReg) + .addMBB(TpLoopBody); + + // Current loop counter + Register LoopCounterPhiReg = MRI.createVirtualRegister(&ARM::GPRlrRegClass); + Register RemainingLoopIterationsReg = + MRI.createVirtualRegister(&ARM::GPRlrRegClass); + BuildMI(TpLoopBody, Dl, TII->get(ARM::PHI), LoopCounterPhiReg) + .addUse(TotalIterationsReg) + .addMBB(TpEntry) + .addUse(RemainingLoopIterationsReg) + .addMBB(TpLoopBody); + + // Predication counter + Register PredCounterPhiReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + Register RemainingElementsReg = MRI.createVirtualRegister(&ARM::rGPRRegClass); + BuildMI(TpLoopBody, Dl, TII->get(ARM::PHI), PredCounterPhiReg) + .addUse(ElementCountReg) + .addMBB(TpEntry) + .addUse(RemainingElementsReg) + .addMBB(TpLoopBody); + + // Pass predication counter to VCTP + Register VccrReg = MRI.createVirtualRegister(&ARM::VCCRRegClass); + BuildMI(TpLoopBody, Dl, TII->get(ARM::MVE_VCTP8), VccrReg) + .addUse(PredCounterPhiReg) + .addImm(ARMVCC::None) + .addReg(0) + .addReg(0); + + BuildMI(TpLoopBody, Dl, TII->get(ARM::t2SUBri), RemainingElementsReg) + .addUse(PredCounterPhiReg) + .addImm(16) + .add(predOps(ARMCC::AL)) + .addReg(0); + + // VLDRB (only if memcpy) and VSTRB instructions, predicated using VPR + Register SrcValueReg; + if (IsMemcpy) { + SrcValueReg = MRI.createVirtualRegister(&ARM::MQPRRegClass); + BuildMI(TpLoopBody, Dl, TII->get(ARM::MVE_VLDRBU8_post)) + .addDef(CurrSrcReg) + .addDef(SrcValueReg) + .addReg(SrcPhiReg) + .addImm(16) + .addImm(ARMVCC::Then) + .addUse(VccrReg) + .addReg(0); + } else + SrcValueReg = OpSrcReg; + + BuildMI(TpLoopBody, Dl, TII->get(ARM::MVE_VSTRBU8_post)) + .addDef(CurrDestReg) + .addUse(SrcValueReg) + .addReg(DestPhiReg) + .addImm(16) + .addImm(ARMVCC::Then) + .addUse(VccrReg) + .addReg(0); + + // Add the pseudoInstrs for decrementing the loop counter and marking the + // end:t2DoLoopDec and t2DoLoopEnd + BuildMI(TpLoopBody, Dl, TII->get(ARM::t2LoopDec), RemainingLoopIterationsReg) + .addUse(LoopCounterPhiReg) + .addImm(1); + + BuildMI(TpLoopBody, Dl, TII->get(ARM::t2LoopEnd)) + .addUse(RemainingLoopIterationsReg) + .addMBB(TpLoopBody); + + BuildMI(TpLoopBody, Dl, TII->get(ARM::t2B)) + .addMBB(TpExit) + .add(predOps(ARMCC::AL)); +} + MachineBasicBlock * ARMTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MachineBasicBlock *BB) const { @@ -10565,6 +11846,98 @@ 
ARMTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, return BB; } + case ARM::MVE_MEMCPYLOOPINST: + case ARM::MVE_MEMSETLOOPINST: { + + // Transformation below expands MVE_MEMCPYLOOPINST/MVE_MEMSETLOOPINST Pseudo + // into a Tail Predicated (TP) Loop. It adds the instructions to calculate + // the iteration count =ceil(size_in_bytes/16)) in the TP entry block and + // adds the relevant instructions in the TP loop Body for generation of a + // WLSTP loop. + + // Below is relevant portion of the CFG after the transformation. + // The Machine Basic Blocks are shown along with branch conditions (in + // brackets). Note that TP entry/exit MBBs depict the entry/exit of this + // portion of the CFG and may not necessarily be the entry/exit of the + // function. + + // (Relevant) CFG after transformation: + // TP entry MBB + // | + // |-----------------| + // (n <= 0) (n > 0) + // | | + // | TP loop Body MBB<--| + // | | | + // \ |___________| + // \ / + // TP exit MBB + + MachineFunction *MF = BB->getParent(); + MachineFunctionProperties &Properties = MF->getProperties(); + MachineRegisterInfo &MRI = MF->getRegInfo(); + + Register OpDestReg = MI.getOperand(0).getReg(); + Register OpSrcReg = MI.getOperand(1).getReg(); + Register OpSizeReg = MI.getOperand(2).getReg(); + + // Allocate the required MBBs and add to parent function. + MachineBasicBlock *TpEntry = BB; + MachineBasicBlock *TpLoopBody = MF->CreateMachineBasicBlock(); + MachineBasicBlock *TpExit; + + MF->push_back(TpLoopBody); + + // If any instructions are present in the current block after + // MVE_MEMCPYLOOPINST or MVE_MEMSETLOOPINST, split the current block and + // move the instructions into the newly created exit block. If there are no + // instructions add an explicit branch to the FallThrough block and then + // split. + // + // The split is required for two reasons: + // 1) A terminator(t2WhileLoopStart) will be placed at that site. + // 2) Since a TPLoopBody will be added later, any phis in successive blocks + // need to be updated. splitAt() already handles this. + TpExit = BB->splitAt(MI, false); + if (TpExit == BB) { + assert(BB->canFallThrough() && "Exit Block must be Fallthrough of the " + "block containing memcpy/memset Pseudo"); + TpExit = BB->getFallThrough(); + BuildMI(BB, dl, TII->get(ARM::t2B)) + .addMBB(TpExit) + .add(predOps(ARMCC::AL)); + TpExit = BB->splitAt(MI, false); + } + + // Add logic for iteration count + Register TotalIterationsReg = + genTPEntry(TpEntry, TpLoopBody, TpExit, OpSizeReg, TII, dl, MRI); + + // Add the vectorized (and predicated) loads/store instructions + bool IsMemcpy = MI.getOpcode() == ARM::MVE_MEMCPYLOOPINST; + genTPLoopBody(TpLoopBody, TpEntry, TpExit, TII, dl, MRI, OpSrcReg, + OpDestReg, OpSizeReg, TotalIterationsReg, IsMemcpy); + + // Required to avoid conflict with the MachineVerifier during testing. + Properties.reset(MachineFunctionProperties::Property::NoPHIs); + + // Connect the blocks + TpEntry->addSuccessor(TpLoopBody); + TpLoopBody->addSuccessor(TpLoopBody); + TpLoopBody->addSuccessor(TpExit); + + // Reorder for a more natural layout + TpLoopBody->moveAfter(TpEntry); + TpExit->moveAfter(TpLoopBody); + + // Finally, remove the memcpy Psuedo Instruction + MI.eraseFromParent(); + + // Return the exit block as it may contain other instructions requiring a + // custom inserter + return TpExit; + } + // The Thumb2 pre-indexed stores have the same MI operands, they just // define them differently in the .td files from the isel patterns, so // they need pseudos. 
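// Illustrative scalar model (not part of the diff; the function name is made
// up): the custom inserter above expands the MVE memcpy/memset pseudos into a
// tail-predicated loop. Each iteration moves up to 16 bytes, the WLS trip
// count is ceil(n/16) computed as (n + 15) >> 4 in genTPEntry, and MVE_VCTP8
// masks off the surplus lanes of the final iteration, so no scalar cleanup
// loop is needed.
#include <cstddef>
#include <cstdint>
void tail_predicated_memcpy_model(uint8_t *dst, const uint8_t *src, std::size_t n) {
  std::size_t trip_count = (n + 15) >> 4;                    // genTPEntry: ceil(n / 16)
  std::size_t remaining = n;                                 // predication-counter PHI
  for (std::size_t it = 0; it < trip_count; ++it) {
    std::size_t active = remaining < 16 ? remaining : 16;    // lanes MVE_VCTP8 enables
    for (std::size_t lane = 0; lane < active; ++lane)        // predicated VLDRB/VSTRB pair
      dst[it * 16 + lane] = src[it * 16 + lane];
    remaining = remaining > 16 ? remaining - 16 : 0;         // t2SUBri #16 on the counter
  }
}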
@@ -10612,8 +11985,8 @@ ARMTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case ARM::STRH_preidx: NewOpc = ARM::STRH_PRE; break; } MachineInstrBuilder MIB = BuildMI(*BB, MI, dl, TII->get(NewOpc)); - for (unsigned i = 0; i < MI.getNumOperands(); ++i) - MIB.add(MI.getOperand(i)); + for (const MachineOperand &MO : MI.operands()) + MIB.add(MO); MI.eraseFromParent(); return BB; } @@ -10893,7 +12266,7 @@ void ARMTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, if (Subtarget->isThumb1Only()) { for (unsigned c = MCID->getNumOperands() - 4; c--;) { MI.addOperand(MI.getOperand(1)); - MI.RemoveOperand(1); + MI.removeOperand(1); } // Restore the ties @@ -10916,7 +12289,7 @@ void ARMTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, // Any ARM instruction that sets the 's' bit should specify an optional // "cc_out" operand in the last operand position. - if (!MI.hasOptionalDef() || !MCID->OpInfo[ccOutIdx].isOptionalDef()) { + if (!MI.hasOptionalDef() || !MCID->operands()[ccOutIdx].isOptionalDef()) { assert(!NewOpc && "Optional cc_out operand required"); return; } @@ -10931,7 +12304,7 @@ void ARMTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, definesCPSR = true; if (MO.isDead()) deadCPSR = true; - MI.RemoveOperand(i); + MI.removeOperand(i); break; } } @@ -11002,7 +12375,7 @@ static bool isConditionalZeroOrAllOnes(SDNode *N, bool AllOnes, // (zext cc) can never be the all ones value. if (AllOnes) return false; - LLVM_FALLTHROUGH; + [[fallthrough]]; case ISD::SIGN_EXTEND: { SDLoc dl(N); EVT VT = N->getValueType(0); @@ -11018,8 +12391,7 @@ static bool isConditionalZeroOrAllOnes(SDNode *N, bool AllOnes, // When looking for a 0 constant, N can be zext or sext. OtherOp = DAG.getConstant(1, dl, VT); else - OtherOp = DAG.getConstant(APInt::getAllOnesValue(VT.getSizeInBits()), dl, - VT); + OtherOp = DAG.getAllOnesConstant(dl, VT); return true; } } @@ -11611,7 +12983,7 @@ static SDValue PerformAddcSubcCombine(SDNode *N, const ARMSubtarget *Subtarget) { SelectionDAG &DAG(DCI.DAG); - if (N->getOpcode() == ARMISD::SUBC) { + if (N->getOpcode() == ARMISD::SUBC && N->hasAnyUseOfValue(1)) { // (SUBC (ADDE 0, 0, C), 1) -> C SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); @@ -11667,20 +13039,333 @@ static SDValue PerformAddeSubeCombine(SDNode *N, return SDValue(); } +static SDValue PerformSELECTCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { + if (!Subtarget->hasMVEIntegerOps()) + return SDValue(); + + SDLoc dl(N); + SDValue SetCC; + SDValue LHS; + SDValue RHS; + ISD::CondCode CC; + SDValue TrueVal; + SDValue FalseVal; + + if (N->getOpcode() == ISD::SELECT && + N->getOperand(0)->getOpcode() == ISD::SETCC) { + SetCC = N->getOperand(0); + LHS = SetCC->getOperand(0); + RHS = SetCC->getOperand(1); + CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get(); + TrueVal = N->getOperand(1); + FalseVal = N->getOperand(2); + } else if (N->getOpcode() == ISD::SELECT_CC) { + LHS = N->getOperand(0); + RHS = N->getOperand(1); + CC = cast<CondCodeSDNode>(N->getOperand(4))->get(); + TrueVal = N->getOperand(2); + FalseVal = N->getOperand(3); + } else { + return SDValue(); + } + + unsigned int Opcode = 0; + if ((TrueVal->getOpcode() == ISD::VECREDUCE_UMIN || + FalseVal->getOpcode() == ISD::VECREDUCE_UMIN) && + (CC == ISD::SETULT || CC == ISD::SETUGT)) { + Opcode = ARMISD::VMINVu; + if (CC == ISD::SETUGT) + std::swap(TrueVal, FalseVal); + } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_SMIN || + FalseVal->getOpcode() == 
ISD::VECREDUCE_SMIN) && + (CC == ISD::SETLT || CC == ISD::SETGT)) { + Opcode = ARMISD::VMINVs; + if (CC == ISD::SETGT) + std::swap(TrueVal, FalseVal); + } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_UMAX || + FalseVal->getOpcode() == ISD::VECREDUCE_UMAX) && + (CC == ISD::SETUGT || CC == ISD::SETULT)) { + Opcode = ARMISD::VMAXVu; + if (CC == ISD::SETULT) + std::swap(TrueVal, FalseVal); + } else if ((TrueVal->getOpcode() == ISD::VECREDUCE_SMAX || + FalseVal->getOpcode() == ISD::VECREDUCE_SMAX) && + (CC == ISD::SETGT || CC == ISD::SETLT)) { + Opcode = ARMISD::VMAXVs; + if (CC == ISD::SETLT) + std::swap(TrueVal, FalseVal); + } else + return SDValue(); + + // Normalise to the right hand side being the vector reduction + switch (TrueVal->getOpcode()) { + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_SMAX: + std::swap(LHS, RHS); + std::swap(TrueVal, FalseVal); + break; + } + + EVT VectorType = FalseVal->getOperand(0).getValueType(); + + if (VectorType != MVT::v16i8 && VectorType != MVT::v8i16 && + VectorType != MVT::v4i32) + return SDValue(); + + EVT VectorScalarType = VectorType.getVectorElementType(); + + // The values being selected must also be the ones being compared + if (TrueVal != LHS || FalseVal != RHS) + return SDValue(); + + EVT LeftType = LHS->getValueType(0); + EVT RightType = RHS->getValueType(0); + + // The types must match the reduced type too + if (LeftType != VectorScalarType || RightType != VectorScalarType) + return SDValue(); + + // Legalise the scalar to an i32 + if (VectorScalarType != MVT::i32) + LHS = DCI.DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, LHS); + + // Generate the reduction as an i32 for legalisation purposes + auto Reduction = + DCI.DAG.getNode(Opcode, dl, MVT::i32, LHS, RHS->getOperand(0)); + + // The result isn't actually an i32 so truncate it back to its original type + if (VectorScalarType != MVT::i32) + Reduction = DCI.DAG.getNode(ISD::TRUNCATE, dl, VectorScalarType, Reduction); + + return Reduction; +} + +// A special combine for the vqdmulh family of instructions. This is one of the +// potential set of patterns that could patch this instruction. The base pattern +// you would expect to be min(max(ashr(mul(mul(sext(x), 2), sext(y)), 16))). +// This matches the different min(max(ashr(mul(mul(sext(x), sext(y)), 2), 16))), +// which llvm will have optimized to min(ashr(mul(sext(x), sext(y)), 15))) as +// the max is unnecessary. +static SDValue PerformVQDMULHCombine(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + SDValue Shft; + ConstantSDNode *Clamp; + + if (!VT.isVector() || VT.getScalarSizeInBits() > 64) + return SDValue(); + + if (N->getOpcode() == ISD::SMIN) { + Shft = N->getOperand(0); + Clamp = isConstOrConstSplat(N->getOperand(1)); + } else if (N->getOpcode() == ISD::VSELECT) { + // Detect a SMIN, which for an i64 node will be a vselect/setcc, not a smin. 
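// Illustrative scalar model (not part of the diff; the function name is made
// up): vqdmulh is "saturating doubling multiply, keep the high half", so for
// i16 lanes it behaves like (a*b) >> 15 clamped to the signed range. That is
// why the clamp constants matched below are (1<<7)-1, (1<<15)-1 and (1<<31)-1.
#include <cstdint>
constexpr int16_t vqdmulh_i16_model(int16_t a, int16_t b) {
  int32_t prod = (int32_t(a) * int32_t(b)) >> 15;  // same value as (2*a*b) >> 16
  if (prod > INT16_MAX)
    prod = INT16_MAX;                              // only -32768 * -32768 overflows
  return static_cast<int16_t>(prod);
}
static_assert(vqdmulh_i16_model(16384, 16384) == 8192);     // 0.5 * 0.5 = 0.25 in Q15
static_assert(vqdmulh_i16_model(-32768, -32768) == 32767);  // saturated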
+ SDValue Cmp = N->getOperand(0); + if (Cmp.getOpcode() != ISD::SETCC || + cast<CondCodeSDNode>(Cmp.getOperand(2))->get() != ISD::SETLT || + Cmp.getOperand(0) != N->getOperand(1) || + Cmp.getOperand(1) != N->getOperand(2)) + return SDValue(); + Shft = N->getOperand(1); + Clamp = isConstOrConstSplat(N->getOperand(2)); + } else + return SDValue(); + + if (!Clamp) + return SDValue(); + + MVT ScalarType; + int ShftAmt = 0; + switch (Clamp->getSExtValue()) { + case (1 << 7) - 1: + ScalarType = MVT::i8; + ShftAmt = 7; + break; + case (1 << 15) - 1: + ScalarType = MVT::i16; + ShftAmt = 15; + break; + case (1ULL << 31) - 1: + ScalarType = MVT::i32; + ShftAmt = 31; + break; + default: + return SDValue(); + } + + if (Shft.getOpcode() != ISD::SRA) + return SDValue(); + ConstantSDNode *N1 = isConstOrConstSplat(Shft.getOperand(1)); + if (!N1 || N1->getSExtValue() != ShftAmt) + return SDValue(); + + SDValue Mul = Shft.getOperand(0); + if (Mul.getOpcode() != ISD::MUL) + return SDValue(); + + SDValue Ext0 = Mul.getOperand(0); + SDValue Ext1 = Mul.getOperand(1); + if (Ext0.getOpcode() != ISD::SIGN_EXTEND || + Ext1.getOpcode() != ISD::SIGN_EXTEND) + return SDValue(); + EVT VecVT = Ext0.getOperand(0).getValueType(); + if (!VecVT.isPow2VectorType() || VecVT.getVectorNumElements() == 1) + return SDValue(); + if (Ext1.getOperand(0).getValueType() != VecVT || + VecVT.getScalarType() != ScalarType || + VT.getScalarSizeInBits() < ScalarType.getScalarSizeInBits() * 2) + return SDValue(); + + SDLoc DL(Mul); + unsigned LegalLanes = 128 / (ShftAmt + 1); + EVT LegalVecVT = MVT::getVectorVT(ScalarType, LegalLanes); + // For types smaller than legal vectors extend to be legal and only use needed + // lanes. + if (VecVT.getSizeInBits() < 128) { + EVT ExtVecVT = + MVT::getVectorVT(MVT::getIntegerVT(128 / VecVT.getVectorNumElements()), + VecVT.getVectorNumElements()); + SDValue Inp0 = + DAG.getNode(ISD::ANY_EXTEND, DL, ExtVecVT, Ext0.getOperand(0)); + SDValue Inp1 = + DAG.getNode(ISD::ANY_EXTEND, DL, ExtVecVT, Ext1.getOperand(0)); + Inp0 = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, LegalVecVT, Inp0); + Inp1 = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, LegalVecVT, Inp1); + SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, LegalVecVT, Inp0, Inp1); + SDValue Trunc = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, ExtVecVT, VQDMULH); + Trunc = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Trunc); + return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Trunc); + } + + // For larger types, split into legal sized chunks. + assert(VecVT.getSizeInBits() % 128 == 0 && "Expected a power2 type"); + unsigned NumParts = VecVT.getSizeInBits() / 128; + SmallVector<SDValue> Parts; + for (unsigned I = 0; I < NumParts; ++I) { + SDValue Inp0 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LegalVecVT, Ext0.getOperand(0), + DAG.getVectorIdxConstant(I * LegalLanes, DL)); + SDValue Inp1 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LegalVecVT, Ext1.getOperand(0), + DAG.getVectorIdxConstant(I * LegalLanes, DL)); + SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, LegalVecVT, Inp0, Inp1); + Parts.push_back(VQDMULH); + } + return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, + DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Parts)); +} + +static SDValue PerformVSELECTCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { + if (!Subtarget->hasMVEIntegerOps()) + return SDValue(); + + if (SDValue V = PerformVQDMULHCombine(N, DCI.DAG)) + return V; + + // Transforms vselect(not(cond), lhs, rhs) into vselect(cond, rhs, lhs). 
+ // + // We need to re-implement this optimization here as the implementation in the + // Target-Independent DAGCombiner does not handle the kind of constant we make + // (it calls isConstOrConstSplat with AllowTruncation set to false - and for + // good reason, allowing truncation there would break other targets). + // + // Currently, this is only done for MVE, as it's the only target that benefits + // from this transformation (e.g. VPNOT+VPSEL becomes a single VPSEL). + if (N->getOperand(0).getOpcode() != ISD::XOR) + return SDValue(); + SDValue XOR = N->getOperand(0); + + // Check if the XOR's RHS is either a 1, or a BUILD_VECTOR of 1s. + // It is important to check with truncation allowed as the BUILD_VECTORs we + // generate in those situations will truncate their operands. + ConstantSDNode *Const = + isConstOrConstSplat(XOR->getOperand(1), /*AllowUndefs*/ false, + /*AllowTruncation*/ true); + if (!Const || !Const->isOne()) + return SDValue(); + + // Rewrite into vselect(cond, rhs, lhs). + SDValue Cond = XOR->getOperand(0); + SDValue LHS = N->getOperand(1); + SDValue RHS = N->getOperand(2); + EVT Type = N->getValueType(0); + return DCI.DAG.getNode(ISD::VSELECT, SDLoc(N), Type, Cond, RHS, LHS); +} + +// Convert vsetcc([0,1,2,..], splat(n), ult) -> vctp n +static SDValue PerformVSetCCToVCTPCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get(); + EVT VT = N->getValueType(0); + + if (!Subtarget->hasMVEIntegerOps() || + !DCI.DAG.getTargetLoweringInfo().isTypeLegal(VT)) + return SDValue(); + + if (CC == ISD::SETUGE) { + std::swap(Op0, Op1); + CC = ISD::SETULT; + } + + if (CC != ISD::SETULT || VT.getScalarSizeInBits() != 1 || + Op0.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + + // Check first operand is BuildVector of 0,1,2,... 
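// Illustrative sketch (not part of the diff; the function name is made up):
// this combine folds setcc(build_vector(0,1,2,...), splat(n), ult) into a
// VCTP, whose result is simply "the first n lanes are active":
#include <bitset>
#include <cstddef>
std::bitset<16> vctp8_model(std::size_t n) {
  std::bitset<16> mask;
  for (std::size_t lane = 0; lane < 16; ++lane)
    mask[lane] = lane < n;  // lane i passes the ult compare exactly when i < n
  return mask;
}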
+ for (unsigned I = 0; I < VT.getVectorNumElements(); I++) { + if (!Op0.getOperand(I).isUndef() && + !(isa<ConstantSDNode>(Op0.getOperand(I)) && + Op0.getConstantOperandVal(I) == I)) + return SDValue(); + } + + // The second is a Splat of Op1S + SDValue Op1S = DCI.DAG.getSplatValue(Op1); + if (!Op1S) + return SDValue(); + + unsigned Opc; + switch (VT.getVectorNumElements()) { + case 2: + Opc = Intrinsic::arm_mve_vctp64; + break; + case 4: + Opc = Intrinsic::arm_mve_vctp32; + break; + case 8: + Opc = Intrinsic::arm_mve_vctp16; + break; + case 16: + Opc = Intrinsic::arm_mve_vctp8; + break; + default: + return SDValue(); + } + + SDLoc DL(N); + return DCI.DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DCI.DAG.getConstant(Opc, DL, MVT::i32), + DCI.DAG.getZExtOrTrunc(Op1S, DL, MVT::i32)); +} + static SDValue PerformABSCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - const ARMSubtarget *Subtarget) { - SDValue res; + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { SelectionDAG &DAG = DCI.DAG; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (TLI.isOperationLegal(N->getOpcode(), N->getValueType(0))) return SDValue(); - if (!TLI.expandABS(N, res, DAG)) - return SDValue(); - - return res; + return TLI.expandABS(N, DAG); } /// PerformADDECombine - Target-specific dag combine transform from @@ -11724,9 +13409,248 @@ static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, return SDValue(); } +static SDValue TryDistrubutionADDVecReduce(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + SDLoc dl(N); + + auto IsVecReduce = [](SDValue Op) { + switch (Op.getOpcode()) { + case ISD::VECREDUCE_ADD: + case ARMISD::VADDVs: + case ARMISD::VADDVu: + case ARMISD::VMLAVs: + case ARMISD::VMLAVu: + return true; + } + return false; + }; + + auto DistrubuteAddAddVecReduce = [&](SDValue N0, SDValue N1) { + // Distribute add(X, add(vecreduce(Y), vecreduce(Z))) -> + // add(add(X, vecreduce(Y)), vecreduce(Z)) + // to make better use of vaddva style instructions. + if (VT == MVT::i32 && N1.getOpcode() == ISD::ADD && !IsVecReduce(N0) && + IsVecReduce(N1.getOperand(0)) && IsVecReduce(N1.getOperand(1)) && + !isa<ConstantSDNode>(N0) && N1->hasOneUse()) { + SDValue Add0 = DAG.getNode(ISD::ADD, dl, VT, N0, N1.getOperand(0)); + return DAG.getNode(ISD::ADD, dl, VT, Add0, N1.getOperand(1)); + } + // And turn add(add(A, reduce(B)), add(C, reduce(D))) -> + // add(add(add(A, C), reduce(B)), reduce(D)) + if (VT == MVT::i32 && N0.getOpcode() == ISD::ADD && + N1.getOpcode() == ISD::ADD && N0->hasOneUse() && N1->hasOneUse()) { + unsigned N0RedOp = 0; + if (!IsVecReduce(N0.getOperand(N0RedOp))) { + N0RedOp = 1; + if (!IsVecReduce(N0.getOperand(N0RedOp))) + return SDValue(); + } + + unsigned N1RedOp = 0; + if (!IsVecReduce(N1.getOperand(N1RedOp))) + N1RedOp = 1; + if (!IsVecReduce(N1.getOperand(N1RedOp))) + return SDValue(); + + SDValue Add0 = DAG.getNode(ISD::ADD, dl, VT, N0.getOperand(1 - N0RedOp), + N1.getOperand(1 - N1RedOp)); + SDValue Add1 = + DAG.getNode(ISD::ADD, dl, VT, Add0, N0.getOperand(N0RedOp)); + return DAG.getNode(ISD::ADD, dl, VT, Add1, N1.getOperand(N1RedOp)); + } + return SDValue(); + }; + if (SDValue R = DistrubuteAddAddVecReduce(N0, N1)) + return R; + if (SDValue R = DistrubuteAddAddVecReduce(N1, N0)) + return R; + + // Distribute add(vecreduce(load(Y)), vecreduce(load(Z))) + // Or add(add(X, vecreduce(load(Y))), vecreduce(load(Z))) + // by ascending load offsets. 
This can help cores prefetch if the order of + // loads is more predictable. + auto DistrubuteVecReduceLoad = [&](SDValue N0, SDValue N1, bool IsForward) { + // Check if two reductions are known to load data where one is before/after + // another. Return negative if N0 loads data before N1, positive if N1 is + // before N0 and 0 otherwise if nothing is known. + auto IsKnownOrderedLoad = [&](SDValue N0, SDValue N1) { + // Look through to the first operand of a MUL, for the VMLA case. + // Currently only looks at the first operand, in the hope they are equal. + if (N0.getOpcode() == ISD::MUL) + N0 = N0.getOperand(0); + if (N1.getOpcode() == ISD::MUL) + N1 = N1.getOperand(0); + + // Return true if the two operands are loads to the same object and the + // offset of the first is known to be less than the offset of the second. + LoadSDNode *Load0 = dyn_cast<LoadSDNode>(N0); + LoadSDNode *Load1 = dyn_cast<LoadSDNode>(N1); + if (!Load0 || !Load1 || Load0->getChain() != Load1->getChain() || + !Load0->isSimple() || !Load1->isSimple() || Load0->isIndexed() || + Load1->isIndexed()) + return 0; + + auto BaseLocDecomp0 = BaseIndexOffset::match(Load0, DAG); + auto BaseLocDecomp1 = BaseIndexOffset::match(Load1, DAG); + + if (!BaseLocDecomp0.getBase() || + BaseLocDecomp0.getBase() != BaseLocDecomp1.getBase() || + !BaseLocDecomp0.hasValidOffset() || !BaseLocDecomp1.hasValidOffset()) + return 0; + if (BaseLocDecomp0.getOffset() < BaseLocDecomp1.getOffset()) + return -1; + if (BaseLocDecomp0.getOffset() > BaseLocDecomp1.getOffset()) + return 1; + return 0; + }; + + SDValue X; + if (N0.getOpcode() == ISD::ADD && N0->hasOneUse()) { + if (IsVecReduce(N0.getOperand(0)) && IsVecReduce(N0.getOperand(1))) { + int IsBefore = IsKnownOrderedLoad(N0.getOperand(0).getOperand(0), + N0.getOperand(1).getOperand(0)); + if (IsBefore < 0) { + X = N0.getOperand(0); + N0 = N0.getOperand(1); + } else if (IsBefore > 0) { + X = N0.getOperand(1); + N0 = N0.getOperand(0); + } else + return SDValue(); + } else if (IsVecReduce(N0.getOperand(0))) { + X = N0.getOperand(1); + N0 = N0.getOperand(0); + } else if (IsVecReduce(N0.getOperand(1))) { + X = N0.getOperand(0); + N0 = N0.getOperand(1); + } else + return SDValue(); + } else if (IsForward && IsVecReduce(N0) && IsVecReduce(N1) && + IsKnownOrderedLoad(N0.getOperand(0), N1.getOperand(0)) < 0) { + // Note this is backward to how you would expect. 
We create + // add(reduce(load + 16), reduce(load + 0)) so that the + // add(reduce(load+16), X) is combined into VADDVA(X, load+16)), leaving + // the X as VADDV(load + 0) + return DAG.getNode(ISD::ADD, dl, VT, N1, N0); + } else + return SDValue(); + + if (!IsVecReduce(N0) || !IsVecReduce(N1)) + return SDValue(); + + if (IsKnownOrderedLoad(N1.getOperand(0), N0.getOperand(0)) >= 0) + return SDValue(); + + // Switch from add(add(X, N0), N1) to add(add(X, N1), N0) + SDValue Add0 = DAG.getNode(ISD::ADD, dl, VT, X, N1); + return DAG.getNode(ISD::ADD, dl, VT, Add0, N0); + }; + if (SDValue R = DistrubuteVecReduceLoad(N0, N1, true)) + return R; + if (SDValue R = DistrubuteVecReduceLoad(N1, N0, false)) + return R; + return SDValue(); +} + +static SDValue PerformADDVecReduce(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + if (!Subtarget->hasMVEIntegerOps()) + return SDValue(); + + if (SDValue R = TryDistrubutionADDVecReduce(N, DAG)) + return R; + + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + SDLoc dl(N); + + if (VT != MVT::i64) + return SDValue(); + + // We are looking for a i64 add of a VADDLVx. Due to these being i64's, this + // will look like: + // t1: i32,i32 = ARMISD::VADDLVs x + // t2: i64 = build_pair t1, t1:1 + // t3: i64 = add t2, y + // Otherwise we try to push the add up above VADDLVAx, to potentially allow + // the add to be simplified seperately. + // We also need to check for sext / zext and commutitive adds. + auto MakeVecReduce = [&](unsigned Opcode, unsigned OpcodeA, SDValue NA, + SDValue NB) { + if (NB->getOpcode() != ISD::BUILD_PAIR) + return SDValue(); + SDValue VecRed = NB->getOperand(0); + if ((VecRed->getOpcode() != Opcode && VecRed->getOpcode() != OpcodeA) || + VecRed.getResNo() != 0 || + NB->getOperand(1) != SDValue(VecRed.getNode(), 1)) + return SDValue(); + + if (VecRed->getOpcode() == OpcodeA) { + // add(NA, VADDLVA(Inp), Y) -> VADDLVA(add(NA, Inp), Y) + SDValue Inp = DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, + VecRed.getOperand(0), VecRed.getOperand(1)); + NA = DAG.getNode(ISD::ADD, dl, MVT::i64, Inp, NA); + } + + SmallVector<SDValue, 4> Ops; + Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA, + DAG.getConstant(0, dl, MVT::i32))); + Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA, + DAG.getConstant(1, dl, MVT::i32))); + unsigned S = VecRed->getOpcode() == OpcodeA ? 
2 : 0; + for (unsigned I = S, E = VecRed.getNumOperands(); I < E; I++) + Ops.push_back(VecRed->getOperand(I)); + SDValue Red = + DAG.getNode(OpcodeA, dl, DAG.getVTList({MVT::i32, MVT::i32}), Ops); + return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Red, + SDValue(Red.getNode(), 1)); + }; + + if (SDValue M = MakeVecReduce(ARMISD::VADDLVs, ARMISD::VADDLVAs, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VADDLVu, ARMISD::VADDLVAu, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VADDLVs, ARMISD::VADDLVAs, N1, N0)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VADDLVu, ARMISD::VADDLVAu, N1, N0)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VADDLVps, ARMISD::VADDLVAps, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VADDLVpu, ARMISD::VADDLVApu, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VADDLVps, ARMISD::VADDLVAps, N1, N0)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VADDLVpu, ARMISD::VADDLVApu, N1, N0)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVs, ARMISD::VMLALVAs, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVu, ARMISD::VMLALVAu, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVs, ARMISD::VMLALVAs, N1, N0)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVu, ARMISD::VMLALVAu, N1, N0)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVps, ARMISD::VMLALVAps, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVpu, ARMISD::VMLALVApu, N0, N1)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVps, ARMISD::VMLALVAps, N1, N0)) + return M; + if (SDValue M = MakeVecReduce(ARMISD::VMLALVpu, ARMISD::VMLALVApu, N1, N0)) + return M; + return SDValue(); +} + bool ARMTargetLowering::isDesirableToCommuteWithShift(const SDNode *N, CombineLevel Level) const { + assert((N->getOpcode() == ISD::SHL || N->getOpcode() == ISD::SRA || + N->getOpcode() == ISD::SRL) && + "Expected shift op"); + if (Level == BeforeLegalizeTypes) return true; @@ -11760,8 +13684,38 @@ ARMTargetLowering::isDesirableToCommuteWithShift(const SDNode *N, return false; } +bool ARMTargetLowering::isDesirableToCommuteXorWithShift( + const SDNode *N) const { + assert(N->getOpcode() == ISD::XOR && + (N->getOperand(0).getOpcode() == ISD::SHL || + N->getOperand(0).getOpcode() == ISD::SRL) && + "Expected XOR(SHIFT) pattern"); + + // Only commute if the entire NOT mask is a hidden shifted mask. 
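// Illustrative sketch (not part of the diff; the helper name is made up): for
// xor(shl(x, amt), C) the check below only commutes when C is one contiguous
// run of ones covering exactly the bits a left shift by amt can still
// produce, i.e. bits [amt, 32):
#include <cstdint>
constexpr bool xor_mask_matches_shl(uint32_t c, unsigned amt) {
  uint32_t shiftable_bits = amt < 32 ? 0xFFFFFFFFu << amt : 0u;
  return c == shiftable_bits;             // MaskIdx == amt, MaskLen == 32 - amt
}
static_assert(xor_mask_matches_shl(0xFFFFFF00u, 8));   // commutes
static_assert(!xor_mask_matches_shl(0x0000FF00u, 8));  // not the whole upper run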
+ auto *XorC = dyn_cast<ConstantSDNode>(N->getOperand(1)); + auto *ShiftC = dyn_cast<ConstantSDNode>(N->getOperand(0).getOperand(1)); + if (XorC && ShiftC) { + unsigned MaskIdx, MaskLen; + if (XorC->getAPIntValue().isShiftedMask(MaskIdx, MaskLen)) { + unsigned ShiftAmt = ShiftC->getZExtValue(); + unsigned BitWidth = N->getValueType(0).getScalarSizeInBits(); + if (N->getOperand(0).getOpcode() == ISD::SHL) + return MaskIdx == ShiftAmt && MaskLen == (BitWidth - ShiftAmt); + return MaskIdx == 0 && MaskLen == (BitWidth - ShiftAmt); + } + } + + return false; +} + bool ARMTargetLowering::shouldFoldConstantShiftPairToMask( const SDNode *N, CombineLevel Level) const { + assert(((N->getOpcode() == ISD::SHL && + N->getOperand(0).getOpcode() == ISD::SRL) || + (N->getOpcode() == ISD::SRL && + N->getOperand(0).getOpcode() == ISD::SHL)) && + "Expected shift-shift mask"); + if (!Subtarget->isThumb1Only()) return true; @@ -11780,6 +13734,26 @@ bool ARMTargetLowering::preferIncOfAddToSubOfNot(EVT VT) const { return VT.isScalarInteger(); } +bool ARMTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT, + EVT VT) const { + if (!isOperationLegalOrCustom(Op, VT) || !FPVT.isSimple()) + return false; + + switch (FPVT.getSimpleVT().SimpleTy) { + case MVT::f16: + return Subtarget->hasVFP2Base(); + case MVT::f32: + return Subtarget->hasVFP2Base(); + case MVT::f64: + return Subtarget->hasFP64(); + case MVT::v4f32: + case MVT::v8f16: + return Subtarget->hasMVEFloatOps(); + default: + return false; + } +} + static SDValue PerformSHLSimplify(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *ST) { @@ -11807,7 +13781,7 @@ static SDValue PerformSHLSimplify(SDNode *N, return SDValue(); // Check that all the users could perform the shl themselves. - for (auto U : N->uses()) { + for (auto *U : N->uses()) { switch(U->getOpcode()) { default: return SDValue(); @@ -11849,10 +13823,13 @@ static SDValue PerformSHLSimplify(SDNode *N, APInt C2Int = C2->getAPIntValue(); APInt C1Int = C1ShlC2->getAPIntValue(); + unsigned C2Width = C2Int.getBitWidth(); + if (C2Int.uge(C2Width)) + return SDValue(); + uint64_t C2Value = C2Int.getZExtValue(); // Check that performing a lshr will not lose any information. - APInt Mask = APInt::getHighBitsSet(C2Int.getBitWidth(), - C2Int.getBitWidth() - C2->getZExtValue()); + APInt Mask = APInt::getHighBitsSet(C2Width, C2Width - C2Value); if ((C1Int & Mask) != C1Int) return SDValue(); @@ -11895,6 +13872,9 @@ static SDValue PerformADDCombine(SDNode *N, if (SDValue Result = PerformSHLSimplify(N, DCI, Subtarget)) return Result; + if (SDValue Result = PerformADDVecReduce(N, DCI.DAG, Subtarget)) + return Result; + // First try with the default operand order. if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget)) return Result; @@ -11903,6 +13883,26 @@ static SDValue PerformADDCombine(SDNode *N, return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget); } +// Combine (sub 0, (csinc X, Y, CC)) -> (csinv -X, Y, CC) +// providing -X is as cheap as X (currently, just a constant). 
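// Illustrative aside (not part of the diff; the helper name is made up): the
// combine defined next leans on the two's-complement identity -(y + 1) == ~y,
// which is what lets the negation of a CSINC-style value (cc ? x : y + 1) be
// re-expressed through a bitwise NOT, i.e. as a CSINV (cc ? -x : ~y).
#include <cstdint>
constexpr bool neg_of_inc_is_not(uint32_t y) { return 0u - (y + 1u) == ~y; }
static_assert(neg_of_inc_is_not(0u) && neg_of_inc_is_not(41u) &&
              neg_of_inc_is_not(0xFFFFFFFFu));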
+static SDValue PerformSubCSINCCombine(SDNode *N, SelectionDAG &DAG) { + if (N->getValueType(0) != MVT::i32 || !isNullConstant(N->getOperand(0))) + return SDValue(); + SDValue CSINC = N->getOperand(1); + if (CSINC.getOpcode() != ARMISD::CSINC || !CSINC.hasOneUse()) + return SDValue(); + + ConstantSDNode *X = dyn_cast<ConstantSDNode>(CSINC.getOperand(0)); + if (!X) + return SDValue(); + + return DAG.getNode(ARMISD::CSINV, SDLoc(N), MVT::i32, + DAG.getNode(ISD::SUB, SDLoc(N), MVT::i32, N->getOperand(0), + CSINC.getOperand(0)), + CSINC.getOperand(1), CSINC.getOperand(2), + CSINC.getOperand(3)); +} + /// PerformSUBCombine - Target-specific dag combine xforms for ISD::SUB. /// static SDValue PerformSUBCombine(SDNode *N, @@ -11916,6 +13916,9 @@ static SDValue PerformSUBCombine(SDNode *N, if (SDValue Result = combineSelectAndUse(N, N1, N0, DCI)) return Result; + if (SDValue R = PerformSubCSINCCombine(N, DCI.DAG)) + return R; + if (!Subtarget->hasMVEIntegerOps() || !N->getValueType(0).isVector()) return SDValue(); @@ -11986,18 +13989,86 @@ static SDValue PerformVMULCombine(SDNode *N, DAG.getNode(ISD::MUL, DL, VT, N01, N1)); } +static SDValue PerformMVEVMULLCombine(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + EVT VT = N->getValueType(0); + if (VT != MVT::v2i64) + return SDValue(); + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + auto IsSignExt = [&](SDValue Op) { + if (Op->getOpcode() != ISD::SIGN_EXTEND_INREG) + return SDValue(); + EVT VT = cast<VTSDNode>(Op->getOperand(1))->getVT(); + if (VT.getScalarSizeInBits() == 32) + return Op->getOperand(0); + return SDValue(); + }; + auto IsZeroExt = [&](SDValue Op) { + // Zero extends are a little more awkward. At the point we are matching + // this, we are looking for an AND with a (-1, 0, -1, 0) buildvector mask. + // That might be before of after a bitcast depending on how the and is + // placed. Because this has to look through bitcasts, it is currently only + // supported on LE. 
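// Illustrative aside (not part of the diff; the helper name is made up): on a
// little-endian layout, AND-ing a 64-bit lane with the v4i32 buildvector
// pattern (-1, 0, -1, 0) keeps only the low 32 bits of that lane, which is
// exactly the 32->64-bit zero extension this matcher is looking for.
#include <cstdint>
constexpr bool mask_is_zext(uint64_t lane) {
  return (lane & 0x00000000FFFFFFFFull) ==
         static_cast<uint64_t>(static_cast<uint32_t>(lane));
}
static_assert(mask_is_zext(0x123456789ABCDEF0ull) && mask_is_zext(~0ull));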
+ if (!Subtarget->isLittle()) + return SDValue(); + + SDValue And = Op; + if (And->getOpcode() == ISD::BITCAST) + And = And->getOperand(0); + if (And->getOpcode() != ISD::AND) + return SDValue(); + SDValue Mask = And->getOperand(1); + if (Mask->getOpcode() == ISD::BITCAST) + Mask = Mask->getOperand(0); + + if (Mask->getOpcode() != ISD::BUILD_VECTOR || + Mask.getValueType() != MVT::v4i32) + return SDValue(); + if (isAllOnesConstant(Mask->getOperand(0)) && + isNullConstant(Mask->getOperand(1)) && + isAllOnesConstant(Mask->getOperand(2)) && + isNullConstant(Mask->getOperand(3))) + return And->getOperand(0); + return SDValue(); + }; + + SDLoc dl(N); + if (SDValue Op0 = IsSignExt(N0)) { + if (SDValue Op1 = IsSignExt(N1)) { + SDValue New0a = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, MVT::v4i32, Op0); + SDValue New1a = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, MVT::v4i32, Op1); + return DAG.getNode(ARMISD::VMULLs, dl, VT, New0a, New1a); + } + } + if (SDValue Op0 = IsZeroExt(N0)) { + if (SDValue Op1 = IsZeroExt(N1)) { + SDValue New0a = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, MVT::v4i32, Op0); + SDValue New1a = DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, MVT::v4i32, Op1); + return DAG.getNode(ARMISD::VMULLu, dl, VT, New0a, New1a); + } + } + + return SDValue(); +} + static SDValue PerformMULCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget) { SelectionDAG &DAG = DCI.DAG; + EVT VT = N->getValueType(0); + if (Subtarget->hasMVEIntegerOps() && VT == MVT::v2i64) + return PerformMVEVMULLCombine(N, DAG, Subtarget); + if (Subtarget->isThumb1Only()) return SDValue(); if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer()) return SDValue(); - EVT VT = N->getValueType(0); if (VT.is64BitVector() || VT.is128BitVector()) return PerformVMULCombine(N, DCI, Subtarget); if (VT != MVT::i32) @@ -12182,20 +14253,21 @@ static SDValue PerformANDCombine(SDNode *N, EVT VT = N->getValueType(0); SelectionDAG &DAG = DCI.DAG; - if(!DAG.getTargetLoweringInfo().isTypeLegal(VT)) + if (!DAG.getTargetLoweringInfo().isTypeLegal(VT) || VT == MVT::v2i1 || + VT == MVT::v4i1 || VT == MVT::v8i1 || VT == MVT::v16i1) return SDValue(); APInt SplatBits, SplatUndef; unsigned SplatBitSize; bool HasAnyUndefs; - if (BVN && Subtarget->hasNEON() && + if (BVN && (Subtarget->hasNEON() || Subtarget->hasMVEIntegerOps()) && BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize, HasAnyUndefs)) { - if (SplatBitSize <= 64) { + if (SplatBitSize == 8 || SplatBitSize == 16 || SplatBitSize == 32 || + SplatBitSize == 64) { EVT VbicVT; SDValue Val = isVMOVModifiedImm((~SplatBits).getZExtValue(), SplatUndef.getZExtValue(), SplatBitSize, - DAG, dl, VbicVT, VT.is128BitVector(), - OtherModImm); + DAG, dl, VbicVT, VT, OtherModImm); if (Val.getNode()) { SDValue Input = DAG.getNode(ISD::BITCAST, dl, VbicVT, N->getOperand(0)); @@ -12425,58 +14497,43 @@ static bool isValidMVECond(unsigned CC, bool IsFloat) { }; } -static SDValue PerformORCombine_i1(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, +static ARMCC::CondCodes getVCMPCondCode(SDValue N) { + if (N->getOpcode() == ARMISD::VCMP) + return (ARMCC::CondCodes)N->getConstantOperandVal(2); + else if (N->getOpcode() == ARMISD::VCMPZ) + return (ARMCC::CondCodes)N->getConstantOperandVal(1); + else + llvm_unreachable("Not a VCMP/VCMPZ!"); +} + +static bool CanInvertMVEVCMP(SDValue N) { + ARMCC::CondCodes CC = ARMCC::getOppositeCondition(getVCMPCondCode(N)); + return isValidMVECond(CC, N->getOperand(0).getValueType().isFloatingPoint()); +} + +static SDValue PerformORCombine_i1(SDNode 
*N, SelectionDAG &DAG, const ARMSubtarget *Subtarget) { // Try to invert "or A, B" -> "and ~A, ~B", as the "and" is easier to chain // together with predicates EVT VT = N->getValueType(0); + SDLoc DL(N); SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ARMCC::CondCodes CondCode0 = ARMCC::AL; - ARMCC::CondCodes CondCode1 = ARMCC::AL; - if (N0->getOpcode() == ARMISD::VCMP) - CondCode0 = (ARMCC::CondCodes)cast<const ConstantSDNode>(N0->getOperand(2)) - ->getZExtValue(); - else if (N0->getOpcode() == ARMISD::VCMPZ) - CondCode0 = (ARMCC::CondCodes)cast<const ConstantSDNode>(N0->getOperand(1)) - ->getZExtValue(); - if (N1->getOpcode() == ARMISD::VCMP) - CondCode1 = (ARMCC::CondCodes)cast<const ConstantSDNode>(N1->getOperand(2)) - ->getZExtValue(); - else if (N1->getOpcode() == ARMISD::VCMPZ) - CondCode1 = (ARMCC::CondCodes)cast<const ConstantSDNode>(N1->getOperand(1)) - ->getZExtValue(); - - if (CondCode0 == ARMCC::AL || CondCode1 == ARMCC::AL) - return SDValue(); - - unsigned Opposite0 = ARMCC::getOppositeCondition(CondCode0); - unsigned Opposite1 = ARMCC::getOppositeCondition(CondCode1); - - if (!isValidMVECond(Opposite0, - N0->getOperand(0)->getValueType(0).isFloatingPoint()) || - !isValidMVECond(Opposite1, - N1->getOperand(0)->getValueType(0).isFloatingPoint())) - return SDValue(); - - SmallVector<SDValue, 4> Ops0; - Ops0.push_back(N0->getOperand(0)); - if (N0->getOpcode() == ARMISD::VCMP) - Ops0.push_back(N0->getOperand(1)); - Ops0.push_back(DCI.DAG.getConstant(Opposite0, SDLoc(N0), MVT::i32)); - SmallVector<SDValue, 4> Ops1; - Ops1.push_back(N1->getOperand(0)); - if (N1->getOpcode() == ARMISD::VCMP) - Ops1.push_back(N1->getOperand(1)); - Ops1.push_back(DCI.DAG.getConstant(Opposite1, SDLoc(N1), MVT::i32)); - - SDValue NewN0 = DCI.DAG.getNode(N0->getOpcode(), SDLoc(N0), VT, Ops0); - SDValue NewN1 = DCI.DAG.getNode(N1->getOpcode(), SDLoc(N1), VT, Ops1); - SDValue And = DCI.DAG.getNode(ISD::AND, SDLoc(N), VT, NewN0, NewN1); - return DCI.DAG.getNode(ISD::XOR, SDLoc(N), VT, And, - DCI.DAG.getAllOnesConstant(SDLoc(N), VT)); + auto IsFreelyInvertable = [&](SDValue V) { + if (V->getOpcode() == ARMISD::VCMP || V->getOpcode() == ARMISD::VCMPZ) + return CanInvertMVEVCMP(V); + return false; + }; + + // At least one operand must be freely invertable. 
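// Illustrative aside (not part of the diff; the helper name is made up): the
// rewrite below is plain De Morgan on predicate vectors,
// or(a, b) == ~(~a & ~b), so it is only worthwhile when at least one of the
// inverted compares comes for free (an MVE VCMP/VCMPZ whose condition can
// simply be flipped).
#include <cstdint>
constexpr bool de_morgan(uint32_t a, uint32_t b) { return (a | b) == ~(~a & ~b); }
static_assert(de_morgan(0b1010u, 0b0110u) && de_morgan(0u, ~0u));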
+ if (!(IsFreelyInvertable(N0) || IsFreelyInvertable(N1))) + return SDValue(); + + SDValue NewN0 = DAG.getLogicalNOT(DL, N0, VT); + SDValue NewN1 = DAG.getLogicalNOT(DL, N1, VT); + SDValue And = DAG.getNode(ISD::AND, DL, VT, NewN0, NewN1); + return DAG.getLogicalNOT(DL, And, VT); } /// PerformORCombine - Target-specific dag combine xforms for ISD::OR @@ -12492,17 +14549,21 @@ static SDValue PerformORCombine(SDNode *N, if(!DAG.getTargetLoweringInfo().isTypeLegal(VT)) return SDValue(); + if (Subtarget->hasMVEIntegerOps() && (VT == MVT::v2i1 || VT == MVT::v4i1 || + VT == MVT::v8i1 || VT == MVT::v16i1)) + return PerformORCombine_i1(N, DAG, Subtarget); + APInt SplatBits, SplatUndef; unsigned SplatBitSize; bool HasAnyUndefs; - if (BVN && Subtarget->hasNEON() && + if (BVN && (Subtarget->hasNEON() || Subtarget->hasMVEIntegerOps()) && BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize, HasAnyUndefs)) { - if (SplatBitSize <= 64) { + if (SplatBitSize == 8 || SplatBitSize == 16 || SplatBitSize == 32 || + SplatBitSize == 64) { EVT VorrVT; - SDValue Val = isVMOVModifiedImm(SplatBits.getZExtValue(), - SplatUndef.getZExtValue(), SplatBitSize, - DAG, dl, VorrVT, VT.is128BitVector(), - OtherModImm); + SDValue Val = + isVMOVModifiedImm(SplatBits.getZExtValue(), SplatUndef.getZExtValue(), + SplatBitSize, DAG, dl, VorrVT, VT, OtherModImm); if (Val.getNode()) { SDValue Input = DAG.getNode(ISD::BITCAST, dl, VorrVT, N->getOperand(0)); @@ -12553,7 +14614,7 @@ static SDValue PerformORCombine(SDNode *N, // Canonicalize the vector type to make instruction selection // simpler. EVT CanonicalVT = VT.is128BitVector() ? MVT::v4i32 : MVT::v2i32; - SDValue Result = DAG.getNode(ARMISD::VBSL, dl, CanonicalVT, + SDValue Result = DAG.getNode(ARMISD::VBSP, dl, CanonicalVT, N0->getOperand(1), N0->getOperand(0), N1->getOperand(0)); @@ -12563,10 +14624,6 @@ static SDValue PerformORCombine(SDNode *N, } } - if (Subtarget->hasMVEIntegerOps() && - (VT == MVT::v4i1 || VT == MVT::v8i1 || VT == MVT::v16i1)) - return PerformORCombine_i1(N, DCI, Subtarget); - // Try to use the ARM/Thumb2 BFI (bitfield insert) instruction when // reasonable. if (N0.getOpcode() == ISD::AND && N0.hasOneUse()) { @@ -12598,6 +14655,27 @@ static SDValue PerformXORCombine(SDNode *N, return Result; } + if (Subtarget->hasMVEIntegerOps()) { + // fold (xor(vcmp/z, 1)) into a vcmp with the opposite condition. + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + const TargetLowering *TLI = Subtarget->getTargetLowering(); + if (TLI->isConstTrueVal(N1) && + (N0->getOpcode() == ARMISD::VCMP || N0->getOpcode() == ARMISD::VCMPZ)) { + if (CanInvertMVEVCMP(N0)) { + SDLoc DL(N0); + ARMCC::CondCodes CC = ARMCC::getOppositeCondition(getVCMPCondCode(N0)); + + SmallVector<SDValue, 4> Ops; + Ops.push_back(N0->getOperand(0)); + if (N0->getOpcode() == ARMISD::VCMP) + Ops.push_back(N0->getOperand(1)); + Ops.push_back(DAG.getConstant(CC, DL, MVT::i32)); + return DAG.getNode(N0->getOpcode(), DL, N0->getValueType(0), Ops); + } + } + } + return SDValue(); } @@ -12634,52 +14712,40 @@ static bool BitsProperlyConcatenate(const APInt &A, const APInt &B) { } static SDValue FindBFIToCombineWith(SDNode *N) { - // We have a BFI in N. Follow a possible chain of BFIs and find a BFI it can combine with, - // if one exists. + // We have a BFI in N. Find a BFI it can combine with, if one exists. APInt ToMask, FromMask; SDValue From = ParseBFI(N, ToMask, FromMask); SDValue To = N->getOperand(0); - // Now check for a compatible BFI to merge with. 
We can pass through BFIs that - // aren't compatible, but not if they set the same bit in their destination as - // we do (or that of any BFI we're going to combine with). SDValue V = To; - APInt CombinedToMask = ToMask; - while (V.getOpcode() == ARMISD::BFI) { - APInt NewToMask, NewFromMask; - SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask); - if (NewFrom != From) { - // This BFI has a different base. Keep going. - CombinedToMask |= NewToMask; - V = V.getOperand(0); - continue; - } + if (V.getOpcode() != ARMISD::BFI) + return SDValue(); - // Do the written bits conflict with any we've seen so far? - if ((NewToMask & CombinedToMask).getBoolValue()) - // Conflicting bits - bail out because going further is unsafe. - return SDValue(); + APInt NewToMask, NewFromMask; + SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask); + if (NewFrom != From) + return SDValue(); - // Are the new bits contiguous when combined with the old bits? - if (BitsProperlyConcatenate(ToMask, NewToMask) && - BitsProperlyConcatenate(FromMask, NewFromMask)) - return V; - if (BitsProperlyConcatenate(NewToMask, ToMask) && - BitsProperlyConcatenate(NewFromMask, FromMask)) - return V; + // Do the written bits conflict with any we've seen so far? + if ((NewToMask & ToMask).getBoolValue()) + // Conflicting bits. + return SDValue(); - // We've seen a write to some bits, so track it. - CombinedToMask |= NewToMask; - // Keep going... - V = V.getOperand(0); - } + // Are the new bits contiguous when combined with the old bits? + if (BitsProperlyConcatenate(ToMask, NewToMask) && + BitsProperlyConcatenate(FromMask, NewFromMask)) + return V; + if (BitsProperlyConcatenate(NewToMask, ToMask) && + BitsProperlyConcatenate(NewFromMask, FromMask)) + return V; return SDValue(); } -static SDValue PerformBFICombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { +static SDValue PerformBFICombine(SDNode *N, SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); + if (N1.getOpcode() == ISD::AND) { // (bfi A, (and B, Mask1), Mask2) -> (bfi A, B, Mask2) iff // the bits being cleared by the AND are not demanded by the BFI. @@ -12688,24 +14754,20 @@ static SDValue PerformBFICombine(SDNode *N, return SDValue(); unsigned InvMask = cast<ConstantSDNode>(N->getOperand(2))->getZExtValue(); unsigned LSB = countTrailingZeros(~InvMask); - unsigned Width = (32 - countLeadingZeros(~InvMask)) - LSB; + unsigned Width = llvm::bit_width<unsigned>(~InvMask) - LSB; assert(Width < static_cast<unsigned>(std::numeric_limits<unsigned>::digits) && "undefined behavior"); unsigned Mask = (1u << Width) - 1; unsigned Mask2 = N11C->getZExtValue(); if ((Mask & (~Mask2)) == 0) - return DCI.DAG.getNode(ARMISD::BFI, SDLoc(N), N->getValueType(0), - N->getOperand(0), N1.getOperand(0), - N->getOperand(2)); - } else if (N->getOperand(0).getOpcode() == ARMISD::BFI) { - // We have a BFI of a BFI. Walk up the BFI chain to see how long it goes. - // Keep track of any consecutive bits set that all come from the same base - // value. We can combine these together into a single BFI. - SDValue CombineBFI = FindBFIToCombineWith(N); - if (CombineBFI == SDValue()) - return SDValue(); + return DAG.getNode(ARMISD::BFI, SDLoc(N), N->getValueType(0), + N->getOperand(0), N1.getOperand(0), N->getOperand(2)); + return SDValue(); + } + // Look for another BFI to combine with. + if (SDValue CombineBFI = FindBFIToCombineWith(N)) { // We've found a BFI. 
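// --- Illustrative sketch (editor's addition, not part of the upstream
// change): a scalar model of why two BFIs that insert adjacent bit ranges
// from the same source can be merged into one wider BFI. bfi() mimics the
// ARM BFI semantics; the operand values are hypothetical.
#include <cassert>
#include <cstdint>
static uint32_t bfi(uint32_t base, uint32_t val, unsigned lsb, unsigned width) {
  uint32_t mask = ((width == 32 ? 0u : (1u << width)) - 1u) << lsb;
  return (base & ~mask) | ((val << lsb) & mask);
}
int main() {
  uint32_t base = 0xaabbccdd, from = 0x1234;                  // hypothetical
  uint32_t two = bfi(bfi(base, from, 0, 4), from >> 4, 4, 4); // two 4-bit BFIs
  uint32_t one = bfi(base, from, 0, 8);                       // one 8-bit BFI
  assert(two == one);
  return 0;
}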
APInt ToMask1, FromMask1; SDValue From1 = ParseBFI(N, ToMask1, FromMask1); @@ -12715,9 +14777,7 @@ static SDValue PerformBFICombine(SDNode *N, assert(From1 == From2); (void)From2; - // First, unlink CombineBFI. - DCI.DAG.ReplaceAllUsesWith(CombineBFI, CombineBFI.getOperand(0)); - // Then create a new BFI, combining the two together. + // Create a new BFI, combining the two together. APInt NewFromMask = FromMask1 | FromMask2; APInt NewToMask = ToMask1 | ToMask2; @@ -12725,11 +14785,101 @@ static SDValue PerformBFICombine(SDNode *N, SDLoc dl(N); if (NewFromMask[0] == 0) - From1 = DCI.DAG.getNode( - ISD::SRL, dl, VT, From1, - DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT)); - return DCI.DAG.getNode(ARMISD::BFI, dl, VT, N->getOperand(0), From1, - DCI.DAG.getConstant(~NewToMask, dl, VT)); + From1 = DAG.getNode( + ISD::SRL, dl, VT, From1, + DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT)); + return DAG.getNode(ARMISD::BFI, dl, VT, CombineBFI.getOperand(0), From1, + DAG.getConstant(~NewToMask, dl, VT)); + } + + // Reassociate BFI(BFI (A, B, M1), C, M2) to BFI(BFI (A, C, M2), B, M1) so + // that lower bit insertions are performed first, providing that M1 and M2 + // do no overlap. This can allow multiple BFI instructions to be combined + // together by the other folds above. + if (N->getOperand(0).getOpcode() == ARMISD::BFI) { + APInt ToMask1 = ~N->getConstantOperandAPInt(2); + APInt ToMask2 = ~N0.getConstantOperandAPInt(2); + + if (!N0.hasOneUse() || (ToMask1 & ToMask2) != 0 || + ToMask1.countLeadingZeros() < ToMask2.countLeadingZeros()) + return SDValue(); + + EVT VT = N->getValueType(0); + SDLoc dl(N); + SDValue BFI1 = DAG.getNode(ARMISD::BFI, dl, VT, N0.getOperand(0), + N->getOperand(1), N->getOperand(2)); + return DAG.getNode(ARMISD::BFI, dl, VT, BFI1, N0.getOperand(1), + N0.getOperand(2)); + } + + return SDValue(); +} + +// Check that N is CMPZ(CSINC(0, 0, CC, X)), +// or CMPZ(CMOV(1, 0, CC, $cpsr, X)) +// return X if valid. +static SDValue IsCMPZCSINC(SDNode *Cmp, ARMCC::CondCodes &CC) { + if (Cmp->getOpcode() != ARMISD::CMPZ || !isNullConstant(Cmp->getOperand(1))) + return SDValue(); + SDValue CSInc = Cmp->getOperand(0); + + // Ignore any `And 1` nodes that may not yet have been removed. We are + // looking for a value that produces 1/0, so these have no effect on the + // code. + while (CSInc.getOpcode() == ISD::AND && + isa<ConstantSDNode>(CSInc.getOperand(1)) && + CSInc.getConstantOperandVal(1) == 1 && CSInc->hasOneUse()) + CSInc = CSInc.getOperand(0); + + if (CSInc.getOpcode() == ARMISD::CSINC && + isNullConstant(CSInc.getOperand(0)) && + isNullConstant(CSInc.getOperand(1)) && CSInc->hasOneUse()) { + CC = (ARMCC::CondCodes)CSInc.getConstantOperandVal(2); + return CSInc.getOperand(3); + } + if (CSInc.getOpcode() == ARMISD::CMOV && isOneConstant(CSInc.getOperand(0)) && + isNullConstant(CSInc.getOperand(1)) && CSInc->hasOneUse()) { + CC = (ARMCC::CondCodes)CSInc.getConstantOperandVal(2); + return CSInc.getOperand(4); + } + if (CSInc.getOpcode() == ARMISD::CMOV && isOneConstant(CSInc.getOperand(1)) && + isNullConstant(CSInc.getOperand(0)) && CSInc->hasOneUse()) { + CC = ARMCC::getOppositeCondition( + (ARMCC::CondCodes)CSInc.getConstantOperandVal(2)); + return CSInc.getOperand(4); + } + return SDValue(); +} + +static SDValue PerformCMPZCombine(SDNode *N, SelectionDAG &DAG) { + // Given CMPZ(CSINC(C, 0, 0, EQ), 0), we can just use C directly. 
As in + // t92: glue = ARMISD::CMPZ t74, 0 + // t93: i32 = ARMISD::CSINC 0, 0, 1, t92 + // t96: glue = ARMISD::CMPZ t93, 0 + // t114: i32 = ARMISD::CSINV 0, 0, 0, t96 + ARMCC::CondCodes Cond; + if (SDValue C = IsCMPZCSINC(N, Cond)) + if (Cond == ARMCC::EQ) + return C; + return SDValue(); +} + +static SDValue PerformCSETCombine(SDNode *N, SelectionDAG &DAG) { + // Fold away an unneccessary CMPZ/CSINC + // CSXYZ A, B, C1 (CMPZ (CSINC 0, 0, C2, D), 0) -> + // if C1==EQ -> CSXYZ A, B, C2, D + // if C1==NE -> CSXYZ A, B, NOT(C2), D + ARMCC::CondCodes Cond; + if (SDValue C = IsCMPZCSINC(N->getOperand(3).getNode(), Cond)) { + if (N->getConstantOperandVal(2) == ARMCC::EQ) + return DAG.getNode(N->getOpcode(), SDLoc(N), MVT::i32, N->getOperand(0), + N->getOperand(1), + DAG.getConstant(Cond, SDLoc(N), MVT::i32), C); + if (N->getConstantOperandVal(2) == ARMCC::NE) + return DAG.getNode( + N->getOpcode(), SDLoc(N), MVT::i32, N->getOperand(0), + N->getOperand(1), + DAG.getConstant(ARMCC::getOppositeCondition(Cond), SDLoc(N), MVT::i32), C); } return SDValue(); } @@ -12758,14 +14908,14 @@ static SDValue PerformVMOVRRDCombine(SDNode *N, SDValue BasePtr = LD->getBasePtr(); SDValue NewLD1 = DAG.getLoad(MVT::i32, DL, LD->getChain(), BasePtr, LD->getPointerInfo(), - LD->getAlignment(), LD->getMemOperand()->getFlags()); + LD->getAlign(), LD->getMemOperand()->getFlags()); SDValue OffsetPtr = DAG.getNode(ISD::ADD, DL, MVT::i32, BasePtr, DAG.getConstant(4, DL, MVT::i32)); SDValue NewLD2 = DAG.getLoad(MVT::i32, DL, LD->getChain(), OffsetPtr, LD->getPointerInfo().getWithOffset(4), - std::min(4U, LD->getAlignment()), + commonAlignment(LD->getAlign(), 4), LD->getMemOperand()->getFlags()); DAG.ReplaceAllUsesOfValueWith(SDValue(LD, 1), NewLD2.getValue(1)); @@ -12775,6 +14925,54 @@ static SDValue PerformVMOVRRDCombine(SDNode *N, return Result; } + // VMOVRRD(extract(..(build_vector(a, b, c, d)))) -> a,b or c,d + // VMOVRRD(extract(insert_vector(insert_vector(.., a, l1), b, l2))) -> a,b + if (InDouble.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + isa<ConstantSDNode>(InDouble.getOperand(1))) { + SDValue BV = InDouble.getOperand(0); + // Look up through any nop bitcasts and vector_reg_casts. bitcasts may + // change lane order under big endian. + bool BVSwap = BV.getOpcode() == ISD::BITCAST; + while ( + (BV.getOpcode() == ISD::BITCAST || + BV.getOpcode() == ARMISD::VECTOR_REG_CAST) && + (BV.getValueType() == MVT::v2f64 || BV.getValueType() == MVT::v2i64)) { + BVSwap = BV.getOpcode() == ISD::BITCAST; + BV = BV.getOperand(0); + } + if (BV.getValueType() != MVT::v4i32) + return SDValue(); + + // Handle buildvectors, pulling out the correct lane depending on + // endianness. + unsigned Offset = InDouble.getConstantOperandVal(1) == 1 ? 2 : 0; + if (BV.getOpcode() == ISD::BUILD_VECTOR) { + SDValue Op0 = BV.getOperand(Offset); + SDValue Op1 = BV.getOperand(Offset + 1); + if (!Subtarget->isLittle() && BVSwap) + std::swap(Op0, Op1); + + return DCI.DAG.getMergeValues({Op0, Op1}, SDLoc(N)); + } + + // A chain of insert_vectors, grabbing the correct value of the chain of + // inserts. 
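// --- Illustrative sketch (editor's addition, not part of the upstream
// change): modelling the v4i32 value as a plain array shows why walking a
// chain of INSERT_VECTOR_ELT nodes recovers the two lanes a VMOVRRD would
// read, without materialising the vector. Values and lanes are hypothetical.
#include <array>
#include <cassert>
#include <cstdint>
int main() {
  std::array<uint32_t, 4> vec{};   // stands in for the v4i32 value
  uint32_t a = 0x11111111, b = 0x22222222;
  vec[2] = a;                      // insert_vector_elt ..., a, 2
  vec[3] = b;                      // insert_vector_elt ..., b, 3
  unsigned Offset = 2;             // extracting double-word lane 1
  assert(vec[Offset] == a && vec[Offset + 1] == b); // VMOVRRD yields a, b
  return 0;
}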
+ SDValue Op0, Op1; + while (BV.getOpcode() == ISD::INSERT_VECTOR_ELT) { + if (isa<ConstantSDNode>(BV.getOperand(2))) { + if (BV.getConstantOperandVal(2) == Offset) + Op0 = BV.getOperand(1); + if (BV.getConstantOperandVal(2) == Offset + 1) + Op1 = BV.getOperand(1); + } + BV = BV.getOperand(0); + } + if (!Subtarget->isLittle() && BVSwap) + std::swap(Op0, Op1); + if (Op0 && Op1) + return DCI.DAG.getMergeValues({Op0, Op1}, SDLoc(N)); + } + return SDValue(); } @@ -12796,6 +14994,84 @@ static SDValue PerformVMOVDRRCombine(SDNode *N, SelectionDAG &DAG) { return SDValue(); } +static SDValue PerformVMOVhrCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + SDValue Op0 = N->getOperand(0); + + // VMOVhr (VMOVrh (X)) -> X + if (Op0->getOpcode() == ARMISD::VMOVrh) + return Op0->getOperand(0); + + // FullFP16: half values are passed in S-registers, and we don't + // need any of the bitcast and moves: + // + // t2: f32,ch = CopyFromReg t0, Register:f32 %0 + // t5: i32 = bitcast t2 + // t18: f16 = ARMISD::VMOVhr t5 + if (Op0->getOpcode() == ISD::BITCAST) { + SDValue Copy = Op0->getOperand(0); + if (Copy.getValueType() == MVT::f32 && + Copy->getOpcode() == ISD::CopyFromReg) { + SDValue Ops[] = {Copy->getOperand(0), Copy->getOperand(1)}; + SDValue NewCopy = + DCI.DAG.getNode(ISD::CopyFromReg, SDLoc(N), N->getValueType(0), Ops); + return NewCopy; + } + } + + // fold (VMOVhr (load x)) -> (load (f16*)x) + if (LoadSDNode *LN0 = dyn_cast<LoadSDNode>(Op0)) { + if (LN0->hasOneUse() && LN0->isUnindexed() && + LN0->getMemoryVT() == MVT::i16) { + SDValue Load = + DCI.DAG.getLoad(N->getValueType(0), SDLoc(N), LN0->getChain(), + LN0->getBasePtr(), LN0->getMemOperand()); + DCI.DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Load.getValue(0)); + DCI.DAG.ReplaceAllUsesOfValueWith(Op0.getValue(1), Load.getValue(1)); + return Load; + } + } + + // Only the bottom 16 bits of the source register are used. + APInt DemandedMask = APInt::getLowBitsSet(32, 16); + const TargetLowering &TLI = DCI.DAG.getTargetLoweringInfo(); + if (TLI.SimplifyDemandedBits(Op0, DemandedMask, DCI)) + return SDValue(N, 0); + + return SDValue(); +} + +static SDValue PerformVMOVrhCombine(SDNode *N, SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0); + EVT VT = N->getValueType(0); + + // fold (VMOVrh (fpconst x)) -> const x + if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N0)) { + APFloat V = C->getValueAPF(); + return DAG.getConstant(V.bitcastToAPInt().getZExtValue(), SDLoc(N), VT); + } + + // fold (VMOVrh (load x)) -> (zextload (i16*)x) + if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse()) { + LoadSDNode *LN0 = cast<LoadSDNode>(N0); + + SDValue Load = + DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N), VT, LN0->getChain(), + LN0->getBasePtr(), MVT::i16, LN0->getMemOperand()); + DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Load.getValue(0)); + DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1)); + return Load; + } + + // Fold VMOVrh(extract(x, n)) -> vgetlaneu(x, n) + if (N0->getOpcode() == ISD::EXTRACT_VECTOR_ELT && + isa<ConstantSDNode>(N0->getOperand(1))) + return DAG.getNode(ARMISD::VGETLANEu, SDLoc(N), VT, N0->getOperand(0), + N0->getOperand(1)); + + return SDValue(); +} + /// hasNormalLoadOperand - Check if any of the operands of a BUILD_VECTOR node /// are normal, non-volatile loads. 
If so, it is profitable to bitcast an /// i64 vector to have f64 elements, since the value can then be loaded @@ -12946,15 +15222,55 @@ PerformPREDICATE_CASTCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { // If the valuetypes are the same, we can remove the cast entirely. if (Op->getOperand(0).getValueType() == VT) return Op->getOperand(0); - return DCI.DAG.getNode(ARMISD::PREDICATE_CAST, dl, - Op->getOperand(0).getValueType(), Op->getOperand(0)); + return DCI.DAG.getNode(ARMISD::PREDICATE_CAST, dl, VT, Op->getOperand(0)); } + // Turn pred_cast(xor x, -1) into xor(pred_cast x, -1), in order to produce + // more VPNOT which might get folded as else predicates. + if (Op.getValueType() == MVT::i32 && isBitwiseNot(Op)) { + SDValue X = + DCI.DAG.getNode(ARMISD::PREDICATE_CAST, dl, VT, Op->getOperand(0)); + SDValue C = DCI.DAG.getNode(ARMISD::PREDICATE_CAST, dl, VT, + DCI.DAG.getConstant(65535, dl, MVT::i32)); + return DCI.DAG.getNode(ISD::XOR, dl, VT, X, C); + } + + // Only the bottom 16 bits of the source register are used. + if (Op.getValueType() == MVT::i32) { + APInt DemandedMask = APInt::getLowBitsSet(32, 16); + const TargetLowering &TLI = DCI.DAG.getTargetLoweringInfo(); + if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI)) + return SDValue(N, 0); + } return SDValue(); } -static SDValue PerformVCMPCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, +static SDValue PerformVECTOR_REG_CASTCombine(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *ST) { + EVT VT = N->getValueType(0); + SDValue Op = N->getOperand(0); + SDLoc dl(N); + + // Under Little endian, a VECTOR_REG_CAST is equivalent to a BITCAST + if (ST->isLittle()) + return DAG.getNode(ISD::BITCAST, dl, VT, Op); + + // VECTOR_REG_CAST undef -> undef + if (Op.isUndef()) + return DAG.getUNDEF(VT); + + // VECTOR_REG_CAST(VECTOR_REG_CAST(x)) == VECTOR_REG_CAST(x) + if (Op->getOpcode() == ARMISD::VECTOR_REG_CAST) { + // If the valuetypes are the same, we can remove the cast entirely. + if (Op->getOperand(0).getValueType() == VT) + return Op->getOperand(0); + return DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, VT, Op->getOperand(0)); + } + + return SDValue(); +} + +static SDValue PerformVCMPCombine(SDNode *N, SelectionDAG &DAG, const ARMSubtarget *Subtarget) { if (!Subtarget->hasMVEIntegerOps()) return SDValue(); @@ -12968,19 +15284,18 @@ static SDValue PerformVCMPCombine(SDNode *N, // vcmp X, 0, cc -> vcmpz X, cc if (isZeroVector(Op1)) - return DCI.DAG.getNode(ARMISD::VCMPZ, dl, VT, Op0, - N->getOperand(2)); + return DAG.getNode(ARMISD::VCMPZ, dl, VT, Op0, N->getOperand(2)); unsigned SwappedCond = getSwappedCondition(Cond); if (isValidMVECond(SwappedCond, VT.isFloatingPoint())) { // vcmp 0, X, cc -> vcmpz X, reversed(cc) if (isZeroVector(Op0)) - return DCI.DAG.getNode(ARMISD::VCMPZ, dl, VT, Op1, - DCI.DAG.getConstant(SwappedCond, dl, MVT::i32)); + return DAG.getNode(ARMISD::VCMPZ, dl, VT, Op1, + DAG.getConstant(SwappedCond, dl, MVT::i32)); // vcmp vdup(Y), X, cc -> vcmp X, vdup(Y), reversed(cc) if (Op0->getOpcode() == ARMISD::VDUP && Op1->getOpcode() != ARMISD::VDUP) - return DCI.DAG.getNode(ARMISD::VCMP, dl, VT, Op1, Op0, - DCI.DAG.getConstant(SwappedCond, dl, MVT::i32)); + return DAG.getNode(ARMISD::VCMP, dl, VT, Op1, Op0, + DAG.getConstant(SwappedCond, dl, MVT::i32)); } return SDValue(); @@ -13012,9 +15327,265 @@ static SDValue PerformInsertEltCombine(SDNode *N, return DAG.getNode(ISD::BITCAST, dl, VT, InsElt); } +// Convert a pair of extracts from the same base vector to a VMOVRRD. 
Either +// directly or bitcast to an integer if the original is a float vector. +// extract(x, n); extract(x, n+1) -> VMOVRRD(extract v2f64 x, n/2) +// bitcast(extract(x, n)); bitcast(extract(x, n+1)) -> VMOVRRD(extract x, n/2) +static SDValue +PerformExtractEltToVMOVRRD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { + EVT VT = N->getValueType(0); + SDLoc dl(N); + + if (!DCI.isAfterLegalizeDAG() || VT != MVT::i32 || + !DCI.DAG.getTargetLoweringInfo().isTypeLegal(MVT::f64)) + return SDValue(); + + SDValue Ext = SDValue(N, 0); + if (Ext.getOpcode() == ISD::BITCAST && + Ext.getOperand(0).getValueType() == MVT::f32) + Ext = Ext.getOperand(0); + if (Ext.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + !isa<ConstantSDNode>(Ext.getOperand(1)) || + Ext.getConstantOperandVal(1) % 2 != 0) + return SDValue(); + if (Ext->use_size() == 1 && + (Ext->use_begin()->getOpcode() == ISD::SINT_TO_FP || + Ext->use_begin()->getOpcode() == ISD::UINT_TO_FP)) + return SDValue(); + + SDValue Op0 = Ext.getOperand(0); + EVT VecVT = Op0.getValueType(); + unsigned ResNo = Op0.getResNo(); + unsigned Lane = Ext.getConstantOperandVal(1); + if (VecVT.getVectorNumElements() != 4) + return SDValue(); + + // Find another extract, of Lane + 1 + auto OtherIt = find_if(Op0->uses(), [&](SDNode *V) { + return V->getOpcode() == ISD::EXTRACT_VECTOR_ELT && + isa<ConstantSDNode>(V->getOperand(1)) && + V->getConstantOperandVal(1) == Lane + 1 && + V->getOperand(0).getResNo() == ResNo; + }); + if (OtherIt == Op0->uses().end()) + return SDValue(); + + // For float extracts, we need to be converting to a i32 for both vector + // lanes. + SDValue OtherExt(*OtherIt, 0); + if (OtherExt.getValueType() != MVT::i32) { + if (OtherExt->use_size() != 1 || + OtherExt->use_begin()->getOpcode() != ISD::BITCAST || + OtherExt->use_begin()->getValueType(0) != MVT::i32) + return SDValue(); + OtherExt = SDValue(*OtherExt->use_begin(), 0); + } + + // Convert the type to a f64 and extract with a VMOVRRD. 
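// --- Illustrative sketch (editor's addition, not part of the upstream
// change): on a little-endian layout, element n/2 of the vector viewed as
// v2i64/v2f64 carries i32 lanes n and n+1, which is what lets the two
// extracts be served by one VMOVRRD. A standalone check with hypothetical
// lane values (little-endian host assumed):
#include <cassert>
#include <cstdint>
#include <cstring>
int main() {
  uint32_t v4i32[4] = {0x11111111, 0x22222222, 0x33333333, 0x44444444};
  uint64_t v2i64[2];
  std::memcpy(v2i64, v4i32, sizeof(v4i32)); // models VECTOR_REG_CAST
  assert(static_cast<uint32_t>(v2i64[1]) == v4i32[2]);       // lane 2
  assert(static_cast<uint32_t>(v2i64[1] >> 32) == v4i32[3]); // lane 3
  return 0;
}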
+ SDValue F64 = DCI.DAG.getNode( + ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, + DCI.DAG.getNode(ARMISD::VECTOR_REG_CAST, dl, MVT::v2f64, Op0), + DCI.DAG.getConstant(Ext.getConstantOperandVal(1) / 2, dl, MVT::i32)); + SDValue VMOVRRD = + DCI.DAG.getNode(ARMISD::VMOVRRD, dl, {MVT::i32, MVT::i32}, F64); + + DCI.CombineTo(OtherExt.getNode(), SDValue(VMOVRRD.getNode(), 1)); + return VMOVRRD; +} + +static SDValue PerformExtractEltCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *ST) { + SDValue Op0 = N->getOperand(0); + EVT VT = N->getValueType(0); + SDLoc dl(N); + + // extract (vdup x) -> x + if (Op0->getOpcode() == ARMISD::VDUP) { + SDValue X = Op0->getOperand(0); + if (VT == MVT::f16 && X.getValueType() == MVT::i32) + return DCI.DAG.getNode(ARMISD::VMOVhr, dl, VT, X); + if (VT == MVT::i32 && X.getValueType() == MVT::f16) + return DCI.DAG.getNode(ARMISD::VMOVrh, dl, VT, X); + if (VT == MVT::f32 && X.getValueType() == MVT::i32) + return DCI.DAG.getNode(ISD::BITCAST, dl, VT, X); + + while (X.getValueType() != VT && X->getOpcode() == ISD::BITCAST) + X = X->getOperand(0); + if (X.getValueType() == VT) + return X; + } + + // extract ARM_BUILD_VECTOR -> x + if (Op0->getOpcode() == ARMISD::BUILD_VECTOR && + isa<ConstantSDNode>(N->getOperand(1)) && + N->getConstantOperandVal(1) < Op0.getNumOperands()) { + return Op0.getOperand(N->getConstantOperandVal(1)); + } + + // extract(bitcast(BUILD_VECTOR(VMOVDRR(a, b), ..))) -> a or b + if (Op0.getValueType() == MVT::v4i32 && + isa<ConstantSDNode>(N->getOperand(1)) && + Op0.getOpcode() == ISD::BITCAST && + Op0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR && + Op0.getOperand(0).getValueType() == MVT::v2f64) { + SDValue BV = Op0.getOperand(0); + unsigned Offset = N->getConstantOperandVal(1); + SDValue MOV = BV.getOperand(Offset < 2 ? 0 : 1); + if (MOV.getOpcode() == ARMISD::VMOVDRR) + return MOV.getOperand(ST->isLittle() ? Offset % 2 : 1 - Offset % 2); + } + + // extract x, n; extract x, n+1 -> VMOVRRD x + if (SDValue R = PerformExtractEltToVMOVRRD(N, DCI)) + return R; + + // extract (MVETrunc(x)) -> extract x + if (Op0->getOpcode() == ARMISD::MVETRUNC) { + unsigned Idx = N->getConstantOperandVal(1); + unsigned Vec = + Idx / Op0->getOperand(0).getValueType().getVectorNumElements(); + unsigned SubIdx = + Idx % Op0->getOperand(0).getValueType().getVectorNumElements(); + return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Op0.getOperand(Vec), + DCI.DAG.getConstant(SubIdx, dl, MVT::i32)); + } + + return SDValue(); +} + +static SDValue PerformSignExtendInregCombine(SDNode *N, SelectionDAG &DAG) { + SDValue Op = N->getOperand(0); + EVT VT = N->getValueType(0); + + // sext_inreg(VGETLANEu) -> VGETLANEs + if (Op.getOpcode() == ARMISD::VGETLANEu && + cast<VTSDNode>(N->getOperand(1))->getVT() == + Op.getOperand(0).getValueType().getScalarType()) + return DAG.getNode(ARMISD::VGETLANEs, SDLoc(N), VT, Op.getOperand(0), + Op.getOperand(1)); + + return SDValue(); +} + +// When lowering complex nodes that we recognize, like VQDMULH and MULH, we +// can end up with shuffle(binop(shuffle, shuffle)), that can be simplified to +// binop as the shuffles cancel out. +static SDValue FlattenVectorShuffle(ShuffleVectorSDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + if (!N->getOperand(1).isUndef() || N->getOperand(0).getValueType() != VT) + return SDValue(); + SDValue Op = N->getOperand(0); + + // Looking for binary operators that will have been folded from + // truncates/extends. 
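// --- Illustrative sketch (editor's addition, not part of the upstream
// change): the simplification above relies on the outer shuffle undoing the
// inner ones. With the masks modelled as index arrays (hypothetical 4-lane
// values, '+' standing in for any lanewise binop):
#include <array>
#include <cassert>
int main() {
  std::array<int, 4> inner = {2, 3, 0, 1}; // mask applied to both inputs
  std::array<int, 4> outer = {2, 3, 0, 1}; // mask applied to the binop result
  std::array<int, 4> x = {10, 20, 30, 40}, y = {1, 2, 3, 4};
  std::array<int, 4> direct{}, tmp{}, shuffled{};
  for (int i = 0; i < 4; ++i) direct[i] = x[i] + y[i];            // binop(x, y)
  for (int i = 0; i < 4; ++i) tmp[i] = x[inner[i]] + y[inner[i]]; // binop(shuf, shuf)
  for (int i = 0; i < 4; ++i) shuffled[i] = tmp[outer[i]];        // shuf(binop)
  for (int i = 0; i < 4; ++i) assert(inner[outer[i]] == i);       // masks cancel
  assert(shuffled == direct);
  return 0;
}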
+ switch (Op.getOpcode()) { + case ARMISD::VQDMULH: + case ISD::MULHS: + case ISD::MULHU: + case ISD::ABDS: + case ISD::ABDU: + case ISD::AVGFLOORS: + case ISD::AVGFLOORU: + case ISD::AVGCEILS: + case ISD::AVGCEILU: + break; + default: + return SDValue(); + } + + ShuffleVectorSDNode *Op0 = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0)); + ShuffleVectorSDNode *Op1 = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1)); + if (!Op0 || !Op1 || !Op0->getOperand(1).isUndef() || + !Op1->getOperand(1).isUndef() || Op0->getMask() != Op1->getMask() || + Op0->getOperand(0).getValueType() != VT) + return SDValue(); + + // Check the mask turns into an identity shuffle. + ArrayRef<int> NMask = N->getMask(); + ArrayRef<int> OpMask = Op0->getMask(); + for (int i = 0, e = NMask.size(); i != e; i++) { + if (NMask[i] > 0 && OpMask[NMask[i]] > 0 && OpMask[NMask[i]] != i) + return SDValue(); + } + + return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), + Op0->getOperand(0), Op1->getOperand(0)); +} + +static SDValue +PerformInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { + SDValue Vec = N->getOperand(0); + SDValue SubVec = N->getOperand(1); + uint64_t IdxVal = N->getConstantOperandVal(2); + EVT VecVT = Vec.getValueType(); + EVT SubVT = SubVec.getValueType(); + + // Only do this for legal fixed vector types. + if (!VecVT.isFixedLengthVector() || + !DCI.DAG.getTargetLoweringInfo().isTypeLegal(VecVT) || + !DCI.DAG.getTargetLoweringInfo().isTypeLegal(SubVT)) + return SDValue(); + + // Ignore widening patterns. + if (IdxVal == 0 && Vec.isUndef()) + return SDValue(); + + // Subvector must be half the width and an "aligned" insertion. + unsigned NumSubElts = SubVT.getVectorNumElements(); + if ((SubVT.getSizeInBits() * 2) != VecVT.getSizeInBits() || + (IdxVal != 0 && IdxVal != NumSubElts)) + return SDValue(); + + // Fold insert_subvector -> concat_vectors + // insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi)) + // insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub) + SDLoc DL(N); + SDValue Lo, Hi; + if (IdxVal == 0) { + Lo = SubVec; + Hi = DCI.DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec, + DCI.DAG.getVectorIdxConstant(NumSubElts, DL)); + } else { + Lo = DCI.DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec, + DCI.DAG.getVectorIdxConstant(0, DL)); + Hi = SubVec; + } + return DCI.DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo, Hi); +} + +// shuffle(MVETrunc(x, y)) -> VMOVN(x, y) +static SDValue PerformShuffleVMOVNCombine(ShuffleVectorSDNode *N, + SelectionDAG &DAG) { + SDValue Trunc = N->getOperand(0); + EVT VT = Trunc.getValueType(); + if (Trunc.getOpcode() != ARMISD::MVETRUNC || !N->getOperand(1).isUndef()) + return SDValue(); + + SDLoc DL(Trunc); + if (isVMOVNTruncMask(N->getMask(), VT, false)) + return DAG.getNode( + ARMISD::VMOVN, DL, VT, + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, Trunc.getOperand(0)), + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, Trunc.getOperand(1)), + DAG.getConstant(1, DL, MVT::i32)); + else if (isVMOVNTruncMask(N->getMask(), VT, true)) + return DAG.getNode( + ARMISD::VMOVN, DL, VT, + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, Trunc.getOperand(1)), + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, Trunc.getOperand(0)), + DAG.getConstant(1, DL, MVT::i32)); + return SDValue(); +} + /// PerformVECTOR_SHUFFLECombine - Target-specific dag combine xforms for /// ISD::VECTOR_SHUFFLE. 
static SDValue PerformVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG) { + if (SDValue R = FlattenVectorShuffle(cast<ShuffleVectorSDNode>(N), DAG)) + return R; + if (SDValue R = PerformShuffleVMOVNCombine(cast<ShuffleVectorSDNode>(N), DAG)) + return R; + // The LLVM shufflevector instruction does not require the shuffle mask // length to match the operand vector length, but ISD::VECTOR_SHUFFLE does // have that requirement. When translating to ISD::VECTOR_SHUFFLE, if the @@ -13064,6 +15635,388 @@ static SDValue PerformVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG) { DAG.getUNDEF(VT), NewMask); } +/// Load/store instruction that can be merged with a base address +/// update +struct BaseUpdateTarget { + SDNode *N; + bool isIntrinsic; + bool isStore; + unsigned AddrOpIdx; +}; + +struct BaseUpdateUser { + /// Instruction that updates a pointer + SDNode *N; + /// Pointer increment operand + SDValue Inc; + /// Pointer increment value if it is a constant, or 0 otherwise + unsigned ConstInc; +}; + +static bool TryCombineBaseUpdate(struct BaseUpdateTarget &Target, + struct BaseUpdateUser &User, + bool SimpleConstIncOnly, + TargetLowering::DAGCombinerInfo &DCI) { + SelectionDAG &DAG = DCI.DAG; + SDNode *N = Target.N; + MemSDNode *MemN = cast<MemSDNode>(N); + SDLoc dl(N); + + // Find the new opcode for the updating load/store. + bool isLoadOp = true; + bool isLaneOp = false; + // Workaround for vst1x and vld1x intrinsics which do not have alignment + // as an operand. + bool hasAlignment = true; + unsigned NewOpc = 0; + unsigned NumVecs = 0; + if (Target.isIntrinsic) { + unsigned IntNo = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + switch (IntNo) { + default: + llvm_unreachable("unexpected intrinsic for Neon base update"); + case Intrinsic::arm_neon_vld1: + NewOpc = ARMISD::VLD1_UPD; + NumVecs = 1; + break; + case Intrinsic::arm_neon_vld2: + NewOpc = ARMISD::VLD2_UPD; + NumVecs = 2; + break; + case Intrinsic::arm_neon_vld3: + NewOpc = ARMISD::VLD3_UPD; + NumVecs = 3; + break; + case Intrinsic::arm_neon_vld4: + NewOpc = ARMISD::VLD4_UPD; + NumVecs = 4; + break; + case Intrinsic::arm_neon_vld1x2: + NewOpc = ARMISD::VLD1x2_UPD; + NumVecs = 2; + hasAlignment = false; + break; + case Intrinsic::arm_neon_vld1x3: + NewOpc = ARMISD::VLD1x3_UPD; + NumVecs = 3; + hasAlignment = false; + break; + case Intrinsic::arm_neon_vld1x4: + NewOpc = ARMISD::VLD1x4_UPD; + NumVecs = 4; + hasAlignment = false; + break; + case Intrinsic::arm_neon_vld2dup: + NewOpc = ARMISD::VLD2DUP_UPD; + NumVecs = 2; + break; + case Intrinsic::arm_neon_vld3dup: + NewOpc = ARMISD::VLD3DUP_UPD; + NumVecs = 3; + break; + case Intrinsic::arm_neon_vld4dup: + NewOpc = ARMISD::VLD4DUP_UPD; + NumVecs = 4; + break; + case Intrinsic::arm_neon_vld2lane: + NewOpc = ARMISD::VLD2LN_UPD; + NumVecs = 2; + isLaneOp = true; + break; + case Intrinsic::arm_neon_vld3lane: + NewOpc = ARMISD::VLD3LN_UPD; + NumVecs = 3; + isLaneOp = true; + break; + case Intrinsic::arm_neon_vld4lane: + NewOpc = ARMISD::VLD4LN_UPD; + NumVecs = 4; + isLaneOp = true; + break; + case Intrinsic::arm_neon_vst1: + NewOpc = ARMISD::VST1_UPD; + NumVecs = 1; + isLoadOp = false; + break; + case Intrinsic::arm_neon_vst2: + NewOpc = ARMISD::VST2_UPD; + NumVecs = 2; + isLoadOp = false; + break; + case Intrinsic::arm_neon_vst3: + NewOpc = ARMISD::VST3_UPD; + NumVecs = 3; + isLoadOp = false; + break; + case Intrinsic::arm_neon_vst4: + NewOpc = ARMISD::VST4_UPD; + NumVecs = 4; + isLoadOp = false; + break; + case Intrinsic::arm_neon_vst2lane: + NewOpc = ARMISD::VST2LN_UPD; + 
NumVecs = 2; + isLoadOp = false; + isLaneOp = true; + break; + case Intrinsic::arm_neon_vst3lane: + NewOpc = ARMISD::VST3LN_UPD; + NumVecs = 3; + isLoadOp = false; + isLaneOp = true; + break; + case Intrinsic::arm_neon_vst4lane: + NewOpc = ARMISD::VST4LN_UPD; + NumVecs = 4; + isLoadOp = false; + isLaneOp = true; + break; + case Intrinsic::arm_neon_vst1x2: + NewOpc = ARMISD::VST1x2_UPD; + NumVecs = 2; + isLoadOp = false; + hasAlignment = false; + break; + case Intrinsic::arm_neon_vst1x3: + NewOpc = ARMISD::VST1x3_UPD; + NumVecs = 3; + isLoadOp = false; + hasAlignment = false; + break; + case Intrinsic::arm_neon_vst1x4: + NewOpc = ARMISD::VST1x4_UPD; + NumVecs = 4; + isLoadOp = false; + hasAlignment = false; + break; + } + } else { + isLaneOp = true; + switch (N->getOpcode()) { + default: + llvm_unreachable("unexpected opcode for Neon base update"); + case ARMISD::VLD1DUP: + NewOpc = ARMISD::VLD1DUP_UPD; + NumVecs = 1; + break; + case ARMISD::VLD2DUP: + NewOpc = ARMISD::VLD2DUP_UPD; + NumVecs = 2; + break; + case ARMISD::VLD3DUP: + NewOpc = ARMISD::VLD3DUP_UPD; + NumVecs = 3; + break; + case ARMISD::VLD4DUP: + NewOpc = ARMISD::VLD4DUP_UPD; + NumVecs = 4; + break; + case ISD::LOAD: + NewOpc = ARMISD::VLD1_UPD; + NumVecs = 1; + isLaneOp = false; + break; + case ISD::STORE: + NewOpc = ARMISD::VST1_UPD; + NumVecs = 1; + isLaneOp = false; + isLoadOp = false; + break; + } + } + + // Find the size of memory referenced by the load/store. + EVT VecTy; + if (isLoadOp) { + VecTy = N->getValueType(0); + } else if (Target.isIntrinsic) { + VecTy = N->getOperand(Target.AddrOpIdx + 1).getValueType(); + } else { + assert(Target.isStore && + "Node has to be a load, a store, or an intrinsic!"); + VecTy = N->getOperand(1).getValueType(); + } + + bool isVLDDUPOp = + NewOpc == ARMISD::VLD1DUP_UPD || NewOpc == ARMISD::VLD2DUP_UPD || + NewOpc == ARMISD::VLD3DUP_UPD || NewOpc == ARMISD::VLD4DUP_UPD; + + unsigned NumBytes = NumVecs * VecTy.getSizeInBits() / 8; + if (isLaneOp || isVLDDUPOp) + NumBytes /= VecTy.getVectorNumElements(); + + if (NumBytes >= 3 * 16 && User.ConstInc != NumBytes) { + // VLD3/4 and VST3/4 for 128-bit vectors are implemented with two + // separate instructions that make it harder to use a non-constant update. + return false; + } + + if (SimpleConstIncOnly && User.ConstInc != NumBytes) + return false; + + // OK, we found an ADD we can fold into the base update. + // Now, create a _UPD node, taking care of not breaking alignment. + + EVT AlignedVecTy = VecTy; + Align Alignment = MemN->getAlign(); + + // If this is a less-than-standard-aligned load/store, change the type to + // match the standard alignment. + // The alignment is overlooked when selecting _UPD variants; and it's + // easier to introduce bitcasts here than fix that. + // There are 3 ways to get to this base-update combine: + // - intrinsics: they are assumed to be properly aligned (to the standard + // alignment of the memory type), so we don't need to do anything. + // - ARMISD::VLDx nodes: they are only generated from the aforementioned + // intrinsics, so, likewise, there's nothing to do. + // - generic load/store instructions: the alignment is specified as an + // explicit operand, rather than implicitly as the standard alignment + // of the memory type (like the intrisics). We need to change the + // memory type to match the explicit alignment. That way, we don't + // generate non-standard-aligned ARMISD::VLDx nodes. 
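// --- Illustrative sketch (editor's addition, not part of the upstream
// change): the retyping performed just below, with hypothetical numbers. A
// 16-byte (v4i32) access whose MMO only guarantees 2-byte alignment is
// rewritten over i16 elements so no over-aligned VLD/VST node is created:
#include <cassert>
int main() {
  unsigned NumBytes = 16;                      // one 128-bit vector
  unsigned AlignBytes = 2;                     // under-aligned MMO
  unsigned EltBits = AlignBytes * 8;           // element type becomes i16
  unsigned NumElts = NumBytes / (EltBits / 8); // 16 / 2
  assert(EltBits == 16 && NumElts == 8);       // i.e. v4i32 handled as v8i16
  return 0;
}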
+ if (isa<LSBaseSDNode>(N)) { + if (Alignment.value() < VecTy.getScalarSizeInBits() / 8) { + MVT EltTy = MVT::getIntegerVT(Alignment.value() * 8); + assert(NumVecs == 1 && "Unexpected multi-element generic load/store."); + assert(!isLaneOp && "Unexpected generic load/store lane."); + unsigned NumElts = NumBytes / (EltTy.getSizeInBits() / 8); + AlignedVecTy = MVT::getVectorVT(EltTy, NumElts); + } + // Don't set an explicit alignment on regular load/stores that we want + // to transform to VLD/VST 1_UPD nodes. + // This matches the behavior of regular load/stores, which only get an + // explicit alignment if the MMO alignment is larger than the standard + // alignment of the memory type. + // Intrinsics, however, always get an explicit alignment, set to the + // alignment of the MMO. + Alignment = Align(1); + } + + // Create the new updating load/store node. + // First, create an SDVTList for the new updating node's results. + EVT Tys[6]; + unsigned NumResultVecs = (isLoadOp ? NumVecs : 0); + unsigned n; + for (n = 0; n < NumResultVecs; ++n) + Tys[n] = AlignedVecTy; + Tys[n++] = MVT::i32; + Tys[n] = MVT::Other; + SDVTList SDTys = DAG.getVTList(ArrayRef(Tys, NumResultVecs + 2)); + + // Then, gather the new node's operands. + SmallVector<SDValue, 8> Ops; + Ops.push_back(N->getOperand(0)); // incoming chain + Ops.push_back(N->getOperand(Target.AddrOpIdx)); + Ops.push_back(User.Inc); + + if (StoreSDNode *StN = dyn_cast<StoreSDNode>(N)) { + // Try to match the intrinsic's signature + Ops.push_back(StN->getValue()); + } else { + // Loads (and of course intrinsics) match the intrinsics' signature, + // so just add all but the alignment operand. + unsigned LastOperand = + hasAlignment ? N->getNumOperands() - 1 : N->getNumOperands(); + for (unsigned i = Target.AddrOpIdx + 1; i < LastOperand; ++i) + Ops.push_back(N->getOperand(i)); + } + + // For all node types, the alignment operand is always the last one. + Ops.push_back(DAG.getConstant(Alignment.value(), dl, MVT::i32)); + + // If this is a non-standard-aligned STORE, the penultimate operand is the + // stored value. Bitcast it to the aligned type. + if (AlignedVecTy != VecTy && N->getOpcode() == ISD::STORE) { + SDValue &StVal = Ops[Ops.size() - 2]; + StVal = DAG.getNode(ISD::BITCAST, dl, AlignedVecTy, StVal); + } + + EVT LoadVT = isLaneOp ? VecTy.getVectorElementType() : AlignedVecTy; + SDValue UpdN = DAG.getMemIntrinsicNode(NewOpc, dl, SDTys, Ops, LoadVT, + MemN->getMemOperand()); + + // Update the uses. + SmallVector<SDValue, 5> NewResults; + for (unsigned i = 0; i < NumResultVecs; ++i) + NewResults.push_back(SDValue(UpdN.getNode(), i)); + + // If this is an non-standard-aligned LOAD, the first result is the loaded + // value. Bitcast it to the expected result type. + if (AlignedVecTy != VecTy && N->getOpcode() == ISD::LOAD) { + SDValue &LdVal = NewResults[0]; + LdVal = DAG.getNode(ISD::BITCAST, dl, VecTy, LdVal); + } + + NewResults.push_back(SDValue(UpdN.getNode(), NumResultVecs + 1)); // chain + DCI.CombineTo(N, NewResults); + DCI.CombineTo(User.N, SDValue(UpdN.getNode(), NumResultVecs)); + + return true; +} + +// If (opcode ptr inc) is and ADD-like instruction, return the +// increment value. Otherwise return 0. 
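// --- Illustrative sketch (editor's addition, not part of the upstream
// change): the helper defined next also accepts OR as an increment when the
// pointer and the offset share no set bits, because OR and ADD then compute
// the same value. A standalone check with hypothetical values:
#include <cassert>
#include <cstdint>
int main() {
  uint32_t Ptr = 0x1000; // hypothetical 16-byte aligned base address
  uint32_t Inc = 8;      // offset that only touches the (zero) low bits
  assert((Ptr & Inc) == 0);         // no common bits set
  assert((Ptr | Inc) == Ptr + Inc); // so the OR behaves like an ADD
  return 0;
}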
+static unsigned getPointerConstIncrement(unsigned Opcode, SDValue Ptr, + SDValue Inc, const SelectionDAG &DAG) { + ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode()); + if (!CInc) + return 0; + + switch (Opcode) { + case ARMISD::VLD1_UPD: + case ISD::ADD: + return CInc->getZExtValue(); + case ISD::OR: { + if (DAG.haveNoCommonBitsSet(Ptr, Inc)) { + // (OR ptr inc) is the same as (ADD ptr inc) + return CInc->getZExtValue(); + } + return 0; + } + default: + return 0; + } +} + +static bool findPointerConstIncrement(SDNode *N, SDValue *Ptr, SDValue *CInc) { + switch (N->getOpcode()) { + case ISD::ADD: + case ISD::OR: { + if (isa<ConstantSDNode>(N->getOperand(1))) { + *Ptr = N->getOperand(0); + *CInc = N->getOperand(1); + return true; + } + return false; + } + case ARMISD::VLD1_UPD: { + if (isa<ConstantSDNode>(N->getOperand(2))) { + *Ptr = N->getOperand(1); + *CInc = N->getOperand(2); + return true; + } + return false; + } + default: + return false; + } +} + +static bool isValidBaseUpdate(SDNode *N, SDNode *User) { + // Check that the add is independent of the load/store. + // Otherwise, folding it would create a cycle. Search through Addr + // as well, since the User may not be a direct user of Addr and + // only share a base pointer. + SmallPtrSet<const SDNode *, 32> Visited; + SmallVector<const SDNode *, 16> Worklist; + Worklist.push_back(N); + Worklist.push_back(User); + if (SDNode::hasPredecessorHelper(N, Visited, Worklist) || + SDNode::hasPredecessorHelper(User, Visited, Worklist)) + return false; + return true; +} + /// CombineBaseUpdate - Target-specific DAG combine function for VLDDUP, /// NEON load/store intrinsics, and generic vector load/stores, to merge /// base address updates. @@ -13071,18 +16024,125 @@ static SDValue PerformVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG) { /// The caller is assumed to have checked legality. static SDValue CombineBaseUpdate(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { - SelectionDAG &DAG = DCI.DAG; const bool isIntrinsic = (N->getOpcode() == ISD::INTRINSIC_VOID || N->getOpcode() == ISD::INTRINSIC_W_CHAIN); const bool isStore = N->getOpcode() == ISD::STORE; const unsigned AddrOpIdx = ((isIntrinsic || isStore) ? 2 : 1); + BaseUpdateTarget Target = {N, isIntrinsic, isStore, AddrOpIdx}; + SDValue Addr = N->getOperand(AddrOpIdx); + + SmallVector<BaseUpdateUser, 8> BaseUpdates; + + // Search for a use of the address operand that is an increment. + for (SDNode::use_iterator UI = Addr.getNode()->use_begin(), + UE = Addr.getNode()->use_end(); UI != UE; ++UI) { + SDNode *User = *UI; + if (UI.getUse().getResNo() != Addr.getResNo() || + User->getNumOperands() != 2) + continue; + + SDValue Inc = User->getOperand(UI.getOperandNo() == 1 ? 0 : 1); + unsigned ConstInc = + getPointerConstIncrement(User->getOpcode(), Addr, Inc, DCI.DAG); + + if (ConstInc || User->getOpcode() == ISD::ADD) + BaseUpdates.push_back({User, Inc, ConstInc}); + } + + // If the address is a constant pointer increment itself, find + // another constant increment that has the same base operand + SDValue Base; + SDValue CInc; + if (findPointerConstIncrement(Addr.getNode(), &Base, &CInc)) { + unsigned Offset = + getPointerConstIncrement(Addr->getOpcode(), Base, CInc, DCI.DAG); + for (SDNode::use_iterator UI = Base->use_begin(), UE = Base->use_end(); + UI != UE; ++UI) { + + SDNode *User = *UI; + if (UI.getUse().getResNo() != Base.getResNo() || User == Addr.getNode() || + User->getNumOperands() != 2) + continue; + + SDValue UserInc = User->getOperand(UI.getOperandNo() == 0 ? 
1 : 0); + unsigned UserOffset = + getPointerConstIncrement(User->getOpcode(), Base, UserInc, DCI.DAG); + + if (!UserOffset || UserOffset <= Offset) + continue; + + unsigned NewConstInc = UserOffset - Offset; + SDValue NewInc = DCI.DAG.getConstant(NewConstInc, SDLoc(N), MVT::i32); + BaseUpdates.push_back({User, NewInc, NewConstInc}); + } + } + + // Try to fold the load/store with an update that matches memory + // access size. This should work well for sequential loads. + // + // Filter out invalid updates as well. + unsigned NumValidUpd = BaseUpdates.size(); + for (unsigned I = 0; I < NumValidUpd;) { + BaseUpdateUser &User = BaseUpdates[I]; + if (!isValidBaseUpdate(N, User.N)) { + --NumValidUpd; + std::swap(BaseUpdates[I], BaseUpdates[NumValidUpd]); + continue; + } + + if (TryCombineBaseUpdate(Target, User, /*SimpleConstIncOnly=*/true, DCI)) + return SDValue(); + ++I; + } + BaseUpdates.resize(NumValidUpd); + + // Try to fold with other users. Non-constant updates are considered + // first, and constant updates are sorted to not break a sequence of + // strided accesses (if there is any). + std::stable_sort(BaseUpdates.begin(), BaseUpdates.end(), + [](const BaseUpdateUser &LHS, const BaseUpdateUser &RHS) { + return LHS.ConstInc < RHS.ConstInc; + }); + for (BaseUpdateUser &User : BaseUpdates) { + if (TryCombineBaseUpdate(Target, User, /*SimpleConstIncOnly=*/false, DCI)) + return SDValue(); + } + return SDValue(); +} + +static SDValue PerformVLDCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer()) + return SDValue(); + + return CombineBaseUpdate(N, DCI); +} + +static SDValue PerformMVEVLDCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer()) + return SDValue(); + + SelectionDAG &DAG = DCI.DAG; + SDValue Addr = N->getOperand(2); MemSDNode *MemN = cast<MemSDNode>(N); SDLoc dl(N); + // For the stores, where there are multiple intrinsics we only actually want + // to post-inc the last of the them. + unsigned IntNo = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + if (IntNo == Intrinsic::arm_mve_vst2q && + cast<ConstantSDNode>(N->getOperand(5))->getZExtValue() != 1) + return SDValue(); + if (IntNo == Intrinsic::arm_mve_vst4q && + cast<ConstantSDNode>(N->getOperand(7))->getZExtValue() != 3) + return SDValue(); + // Search for a use of the address operand that is an increment. for (SDNode::use_iterator UI = Addr.getNode()->use_begin(), - UE = Addr.getNode()->use_end(); UI != UE; ++UI) { + UE = Addr.getNode()->use_end(); + UI != UE; ++UI) { SDNode *User = *UI; if (User->getOpcode() != ISD::ADD || UI.getUse().getResNo() != Addr.getResNo()) @@ -13102,126 +16162,46 @@ static SDValue CombineBaseUpdate(SDNode *N, // Find the new opcode for the updating load/store. 
bool isLoadOp = true; - bool isLaneOp = false; unsigned NewOpc = 0; unsigned NumVecs = 0; - if (isIntrinsic) { - unsigned IntNo = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); - switch (IntNo) { - default: llvm_unreachable("unexpected intrinsic for Neon base update"); - case Intrinsic::arm_neon_vld1: NewOpc = ARMISD::VLD1_UPD; - NumVecs = 1; break; - case Intrinsic::arm_neon_vld2: NewOpc = ARMISD::VLD2_UPD; - NumVecs = 2; break; - case Intrinsic::arm_neon_vld3: NewOpc = ARMISD::VLD3_UPD; - NumVecs = 3; break; - case Intrinsic::arm_neon_vld4: NewOpc = ARMISD::VLD4_UPD; - NumVecs = 4; break; - case Intrinsic::arm_neon_vld2dup: - case Intrinsic::arm_neon_vld3dup: - case Intrinsic::arm_neon_vld4dup: - // TODO: Support updating VLDxDUP nodes. For now, we just skip - // combining base updates for such intrinsics. - continue; - case Intrinsic::arm_neon_vld2lane: NewOpc = ARMISD::VLD2LN_UPD; - NumVecs = 2; isLaneOp = true; break; - case Intrinsic::arm_neon_vld3lane: NewOpc = ARMISD::VLD3LN_UPD; - NumVecs = 3; isLaneOp = true; break; - case Intrinsic::arm_neon_vld4lane: NewOpc = ARMISD::VLD4LN_UPD; - NumVecs = 4; isLaneOp = true; break; - case Intrinsic::arm_neon_vst1: NewOpc = ARMISD::VST1_UPD; - NumVecs = 1; isLoadOp = false; break; - case Intrinsic::arm_neon_vst2: NewOpc = ARMISD::VST2_UPD; - NumVecs = 2; isLoadOp = false; break; - case Intrinsic::arm_neon_vst3: NewOpc = ARMISD::VST3_UPD; - NumVecs = 3; isLoadOp = false; break; - case Intrinsic::arm_neon_vst4: NewOpc = ARMISD::VST4_UPD; - NumVecs = 4; isLoadOp = false; break; - case Intrinsic::arm_neon_vst2lane: NewOpc = ARMISD::VST2LN_UPD; - NumVecs = 2; isLoadOp = false; isLaneOp = true; break; - case Intrinsic::arm_neon_vst3lane: NewOpc = ARMISD::VST3LN_UPD; - NumVecs = 3; isLoadOp = false; isLaneOp = true; break; - case Intrinsic::arm_neon_vst4lane: NewOpc = ARMISD::VST4LN_UPD; - NumVecs = 4; isLoadOp = false; isLaneOp = true; break; - } - } else { - isLaneOp = true; - switch (N->getOpcode()) { - default: llvm_unreachable("unexpected opcode for Neon base update"); - case ARMISD::VLD1DUP: NewOpc = ARMISD::VLD1DUP_UPD; NumVecs = 1; break; - case ARMISD::VLD2DUP: NewOpc = ARMISD::VLD2DUP_UPD; NumVecs = 2; break; - case ARMISD::VLD3DUP: NewOpc = ARMISD::VLD3DUP_UPD; NumVecs = 3; break; - case ARMISD::VLD4DUP: NewOpc = ARMISD::VLD4DUP_UPD; NumVecs = 4; break; - case ISD::LOAD: NewOpc = ARMISD::VLD1_UPD; - NumVecs = 1; isLaneOp = false; break; - case ISD::STORE: NewOpc = ARMISD::VST1_UPD; - NumVecs = 1; isLaneOp = false; isLoadOp = false; break; - } + switch (IntNo) { + default: + llvm_unreachable("unexpected intrinsic for MVE VLDn combine"); + case Intrinsic::arm_mve_vld2q: + NewOpc = ARMISD::VLD2_UPD; + NumVecs = 2; + break; + case Intrinsic::arm_mve_vld4q: + NewOpc = ARMISD::VLD4_UPD; + NumVecs = 4; + break; + case Intrinsic::arm_mve_vst2q: + NewOpc = ARMISD::VST2_UPD; + NumVecs = 2; + isLoadOp = false; + break; + case Intrinsic::arm_mve_vst4q: + NewOpc = ARMISD::VST4_UPD; + NumVecs = 4; + isLoadOp = false; + break; } // Find the size of memory referenced by the load/store. 
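// --- Illustrative sketch (editor's addition, not part of the upstream
// change): for the MVE case the fold is only taken when the pointer
// increment equals the whole memory footprint of the intrinsic. With
// hypothetical numbers for arm_mve_vld2q on 128-bit vectors:
#include <cassert>
int main() {
  unsigned NumVecs = 2;                      // vld2q produces two registers
  unsigned VecBits = 128;                    // each result is 128 bits
  unsigned NumBytes = NumVecs * VecBits / 8; // 32 bytes touched in memory
  unsigned ConstInc = 32;                    // hypothetical ADD amount
  assert(ConstInc == NumBytes);              // only then is it a post-inc
  return 0;
}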
EVT VecTy; if (isLoadOp) { VecTy = N->getValueType(0); - } else if (isIntrinsic) { - VecTy = N->getOperand(AddrOpIdx+1).getValueType(); } else { - assert(isStore && "Node has to be a load, a store, or an intrinsic!"); - VecTy = N->getOperand(1).getValueType(); + VecTy = N->getOperand(3).getValueType(); } unsigned NumBytes = NumVecs * VecTy.getSizeInBits() / 8; - if (isLaneOp) - NumBytes /= VecTy.getVectorNumElements(); // If the increment is a constant, it must match the memory ref size. SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0); ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode()); - if (NumBytes >= 3 * 16 && (!CInc || CInc->getZExtValue() != NumBytes)) { - // VLD3/4 and VST3/4 for 128-bit vectors are implemented with two - // separate instructions that make it harder to use a non-constant update. + if (!CInc || CInc->getZExtValue() != NumBytes) continue; - } - - // OK, we found an ADD we can fold into the base update. - // Now, create a _UPD node, taking care of not breaking alignment. - - EVT AlignedVecTy = VecTy; - unsigned Alignment = MemN->getAlignment(); - - // If this is a less-than-standard-aligned load/store, change the type to - // match the standard alignment. - // The alignment is overlooked when selecting _UPD variants; and it's - // easier to introduce bitcasts here than fix that. - // There are 3 ways to get to this base-update combine: - // - intrinsics: they are assumed to be properly aligned (to the standard - // alignment of the memory type), so we don't need to do anything. - // - ARMISD::VLDx nodes: they are only generated from the aforementioned - // intrinsics, so, likewise, there's nothing to do. - // - generic load/store instructions: the alignment is specified as an - // explicit operand, rather than implicitly as the standard alignment - // of the memory type (like the intrisics). We need to change the - // memory type to match the explicit alignment. That way, we don't - // generate non-standard-aligned ARMISD::VLDx nodes. - if (isa<LSBaseSDNode>(N)) { - if (Alignment == 0) - Alignment = 1; - if (Alignment < VecTy.getScalarSizeInBits() / 8) { - MVT EltTy = MVT::getIntegerVT(Alignment * 8); - assert(NumVecs == 1 && "Unexpected multi-element generic load/store."); - assert(!isLaneOp && "Unexpected generic load/store lane."); - unsigned NumElts = NumBytes / (EltTy.getSizeInBits() / 8); - AlignedVecTy = MVT::getVectorVT(EltTy, NumElts); - } - // Don't set an explicit alignment on regular load/stores that we want - // to transform to VLD/VST 1_UPD nodes. - // This matches the behavior of regular load/stores, which only get an - // explicit alignment if the MMO alignment is larger than the standard - // alignment of the memory type. - // Intrinsics, however, always get an explicit alignment, set to the - // alignment of the MMO. - Alignment = 1; - } // Create the new updating load/store node. // First, create an SDVTList for the new updating node's results. @@ -13229,39 +16209,21 @@ static SDValue CombineBaseUpdate(SDNode *N, unsigned NumResultVecs = (isLoadOp ? NumVecs : 0); unsigned n; for (n = 0; n < NumResultVecs; ++n) - Tys[n] = AlignedVecTy; + Tys[n] = VecTy; Tys[n++] = MVT::i32; Tys[n] = MVT::Other; - SDVTList SDTys = DAG.getVTList(makeArrayRef(Tys, NumResultVecs+2)); + SDVTList SDTys = DAG.getVTList(ArrayRef(Tys, NumResultVecs + 2)); // Then, gather the new node's operands. 
SmallVector<SDValue, 8> Ops; Ops.push_back(N->getOperand(0)); // incoming chain - Ops.push_back(N->getOperand(AddrOpIdx)); + Ops.push_back(N->getOperand(2)); // ptr Ops.push_back(Inc); - if (StoreSDNode *StN = dyn_cast<StoreSDNode>(N)) { - // Try to match the intrinsic's signature - Ops.push_back(StN->getValue()); - } else { - // Loads (and of course intrinsics) match the intrinsics' signature, - // so just add all but the alignment operand. - for (unsigned i = AddrOpIdx + 1; i < N->getNumOperands() - 1; ++i) - Ops.push_back(N->getOperand(i)); - } - - // For all node types, the alignment operand is always the last one. - Ops.push_back(DAG.getConstant(Alignment, dl, MVT::i32)); + for (unsigned i = 3; i < N->getNumOperands(); ++i) + Ops.push_back(N->getOperand(i)); - // If this is a non-standard-aligned STORE, the penultimate operand is the - // stored value. Bitcast it to the aligned type. - if (AlignedVecTy != VecTy && N->getOpcode() == ISD::STORE) { - SDValue &StVal = Ops[Ops.size()-2]; - StVal = DAG.getNode(ISD::BITCAST, dl, AlignedVecTy, StVal); - } - - EVT LoadVT = isLaneOp ? VecTy.getVectorElementType() : AlignedVecTy; - SDValue UpdN = DAG.getMemIntrinsicNode(NewOpc, dl, SDTys, Ops, LoadVT, + SDValue UpdN = DAG.getMemIntrinsicNode(NewOpc, dl, SDTys, Ops, VecTy, MemN->getMemOperand()); // Update the uses. @@ -13269,28 +16231,14 @@ static SDValue CombineBaseUpdate(SDNode *N, for (unsigned i = 0; i < NumResultVecs; ++i) NewResults.push_back(SDValue(UpdN.getNode(), i)); - // If this is an non-standard-aligned LOAD, the first result is the loaded - // value. Bitcast it to the expected result type. - if (AlignedVecTy != VecTy && N->getOpcode() == ISD::LOAD) { - SDValue &LdVal = NewResults[0]; - LdVal = DAG.getNode(ISD::BITCAST, dl, VecTy, LdVal); - } - - NewResults.push_back(SDValue(UpdN.getNode(), NumResultVecs+1)); // chain + NewResults.push_back(SDValue(UpdN.getNode(), NumResultVecs + 1)); // chain DCI.CombineTo(N, NewResults); DCI.CombineTo(User, SDValue(UpdN.getNode(), NumResultVecs)); break; } - return SDValue(); -} -static SDValue PerformVLDCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { - if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer()) - return SDValue(); - - return CombineBaseUpdate(N, DCI); + return SDValue(); } /// CombineVLDDUP - For a VDUPLANE node N, check if its source operand is a @@ -13345,7 +16293,7 @@ static bool CombineVLDDUP(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { for (n = 0; n < NumVecs; ++n) Tys[n] = VT; Tys[n] = MVT::Other; - SDVTList SDTys = DAG.getVTList(makeArrayRef(Tys, NumVecs+1)); + SDVTList SDTys = DAG.getVTList(ArrayRef(Tys, NumVecs + 1)); SDValue Ops[] = { VLD->getOperand(0), VLD->getOperand(2) }; MemIntrinsicSDNode *VLDMemInt = cast<MemIntrinsicSDNode>(VLD); SDValue VLDDup = DAG.getMemIntrinsicNode(NewOpc, SDLoc(VLD), SDTys, @@ -13377,8 +16325,21 @@ static bool CombineVLDDUP(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { /// PerformVDUPLANECombine - Target-specific dag combine xforms for /// ARMISD::VDUPLANE. static SDValue PerformVDUPLANECombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { SDValue Op = N->getOperand(0); + EVT VT = N->getValueType(0); + + // On MVE, we just convert the VDUPLANE to a VDUP with an extract. + if (Subtarget->hasMVEIntegerOps()) { + EVT ExtractVT = VT.getVectorElementType(); + // We need to ensure we are creating a legal type. 
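// --- Illustrative sketch (editor's addition, not part of the upstream
// change): the MVE path rewrites VDUPLANE(v, lane) as
// VDUP(EXTRACT_VECTOR_ELT(v, lane)). Modelled on plain arrays with
// hypothetical values, the two forms produce the same vector:
#include <array>
#include <cassert>
int main() {
  std::array<int, 4> v = {5, 6, 7, 8};
  unsigned lane = 2;
  std::array<int, 4> duplane{}, dup_of_extract{};
  for (auto &e : duplane) e = v[lane];  // VDUPLANE v, lane
  int x = v[lane];                      // EXTRACT_VECTOR_ELT v, lane
  for (auto &e : dup_of_extract) e = x; // VDUP x
  assert(duplane == dup_of_extract);
  return 0;
}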
+ if (!DCI.DAG.getTargetLoweringInfo().isTypeLegal(ExtractVT)) + ExtractVT = MVT::i32; + SDValue Extract = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), ExtractVT, + N->getOperand(0), N->getOperand(1)); + return DCI.DAG.getNode(ARMISD::VDUP, SDLoc(N), VT, Extract); + } // If the source is a vldN-lane (N > 1) intrinsic, and all the other uses // of that intrinsic are also VDUPLANEs, combine them to a vldN-dup operation. @@ -13399,7 +16360,6 @@ static SDValue PerformVDUPLANECombine(SDNode *N, unsigned EltBits; if (ARM_AM::decodeVMOVModImm(Imm, EltBits) == 0) EltSize = 8; - EVT VT = N->getValueType(0); if (EltSize > VT.getScalarSizeInBits()) return SDValue(); @@ -13407,11 +16367,21 @@ static SDValue PerformVDUPLANECombine(SDNode *N, } /// PerformVDUPCombine - Target-specific dag combine xforms for ARMISD::VDUP. -static SDValue PerformVDUPCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, +static SDValue PerformVDUPCombine(SDNode *N, SelectionDAG &DAG, const ARMSubtarget *Subtarget) { - SelectionDAG &DAG = DCI.DAG; SDValue Op = N->getOperand(0); + SDLoc dl(N); + + if (Subtarget->hasMVEIntegerOps()) { + // Convert VDUP f32 -> VDUP BITCAST i32 under MVE, as we know the value will + // need to come from a GPR. + if (Op.getValueType() == MVT::f32) + return DAG.getNode(ARMISD::VDUP, dl, N->getValueType(0), + DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op)); + else if (Op.getValueType() == MVT::f16) + return DAG.getNode(ARMISD::VDUP, dl, N->getValueType(0), + DAG.getNode(ARMISD::VMOVrh, dl, MVT::i32, Op)); + } if (!Subtarget->hasNEON()) return SDValue(); @@ -13422,12 +16392,12 @@ static SDValue PerformVDUPCombine(SDNode *N, LoadSDNode *LD = dyn_cast<LoadSDNode>(Op.getNode()); if (LD && Op.hasOneUse() && LD->isUnindexed() && LD->getMemoryVT() == N->getValueType(0).getVectorElementType()) { - SDValue Ops[] = { LD->getOperand(0), LD->getOperand(1), - DAG.getConstant(LD->getAlignment(), SDLoc(N), MVT::i32) }; + SDValue Ops[] = {LD->getOperand(0), LD->getOperand(1), + DAG.getConstant(LD->getAlign().value(), SDLoc(N), MVT::i32)}; SDVTList SDTys = DAG.getVTList(N->getValueType(0), MVT::Other); - SDValue VLDDup = DAG.getMemIntrinsicNode(ARMISD::VLD1DUP, SDLoc(N), SDTys, - Ops, LD->getMemoryVT(), - LD->getMemOperand()); + SDValue VLDDup = + DAG.getMemIntrinsicNode(ARMISD::VLD1DUP, SDLoc(N), SDTys, Ops, + LD->getMemoryVT(), LD->getMemOperand()); DAG.ReplaceAllUsesOfValueWith(SDValue(LD, 1), VLDDup.getValue(1)); return VLDDup; } @@ -13436,11 +16406,12 @@ static SDValue PerformVDUPCombine(SDNode *N, } static SDValue PerformLOADCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { EVT VT = N->getValueType(0); // If this is a legal vector load, try to combine it into a VLD1_UPD. 
- if (ISD::isNormalLoad(N) && VT.isVector() && + if (Subtarget->hasNEON() && ISD::isNormalLoad(N) && VT.isVector() && DCI.DAG.getTargetLoweringInfo().isTypeLegal(VT)) return CombineBaseUpdate(N, DCI); @@ -13524,7 +16495,7 @@ static SDValue PerformTruncatingStoreCombine(StoreSDNode *St, ShuffWide, DAG.getIntPtrConstant(I, DL)); SDValue Ch = DAG.getStore(St->getChain(), DL, SubVec, BasePtr, St->getPointerInfo(), - St->getAlignment(), St->getMemOperand()->getFlags()); + St->getAlign(), St->getMemOperand()->getFlags()); BasePtr = DAG.getNode(ISD::ADD, DL, BasePtr.getValueType(), BasePtr, Increment); Chains.push_back(Ch); @@ -13532,7 +16503,7 @@ static SDValue PerformTruncatingStoreCombine(StoreSDNode *St, return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); } -// Try taking a single vector store from an truncate (which would otherwise turn +// Try taking a single vector store from an fpround (which would otherwise turn // into an expensive buildvector) and splitting it into a series of narrowing // stores. static SDValue PerformSplittingToNarrowingStores(StoreSDNode *St, @@ -13540,7 +16511,7 @@ static SDValue PerformSplittingToNarrowingStores(StoreSDNode *St, if (!St->isSimple() || St->isTruncatingStore() || !St->isUnindexed()) return SDValue(); SDValue Trunc = St->getValue(); - if (Trunc->getOpcode() != ISD::TRUNCATE) + if (Trunc->getOpcode() != ISD::FP_ROUND) return SDValue(); EVT FromVT = Trunc->getOperand(0).getValueType(); EVT ToVT = Trunc.getValueType(); @@ -13550,34 +16521,73 @@ static SDValue PerformSplittingToNarrowingStores(StoreSDNode *St, EVT ToEltVT = ToVT.getVectorElementType(); EVT FromEltVT = FromVT.getVectorElementType(); - unsigned NumElements = 0; - if (FromEltVT == MVT::i32 && (ToEltVT == MVT::i16 || ToEltVT == MVT::i8)) - NumElements = 4; - if (FromEltVT == MVT::i16 && ToEltVT == MVT::i8) - NumElements = 8; - if (NumElements == 0 || FromVT.getVectorNumElements() == NumElements || - FromVT.getVectorNumElements() % NumElements != 0) + if (FromEltVT != MVT::f32 || ToEltVT != MVT::f16) + return SDValue(); + + unsigned NumElements = 4; + if (FromVT.getVectorNumElements() % NumElements != 0) return SDValue(); + // Test if the Trunc will be convertible to a VMOVN with a shuffle, and if so + // use the VMOVN over splitting the store. We are looking for patterns of: + // !rev: 0 N 1 N+1 2 N+2 ... + // rev: N 0 N+1 1 N+2 2 ... + // The shuffle may either be a single source (in which case N = NumElts/2) or + // two inputs extended with concat to the same size (in which case N = + // NumElts). + auto isVMOVNShuffle = [&](ShuffleVectorSDNode *SVN, bool Rev) { + ArrayRef<int> M = SVN->getMask(); + unsigned NumElts = ToVT.getVectorNumElements(); + if (SVN->getOperand(1).isUndef()) + NumElts /= 2; + + unsigned Off0 = Rev ? NumElts : 0; + unsigned Off1 = Rev ? 
0 : NumElts; + + for (unsigned I = 0; I < NumElts; I += 2) { + if (M[I] >= 0 && M[I] != (int)(Off0 + I / 2)) + return false; + if (M[I + 1] >= 0 && M[I + 1] != (int)(Off1 + I / 2)) + return false; + } + + return true; + }; + + if (auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Trunc.getOperand(0))) + if (isVMOVNShuffle(Shuffle, false) || isVMOVNShuffle(Shuffle, true)) + return SDValue(); + + LLVMContext &C = *DAG.getContext(); SDLoc DL(St); // Details about the old store SDValue Ch = St->getChain(); SDValue BasePtr = St->getBasePtr(); - unsigned Alignment = St->getOriginalAlignment(); + Align Alignment = St->getOriginalAlign(); MachineMemOperand::Flags MMOFlags = St->getMemOperand()->getFlags(); AAMDNodes AAInfo = St->getAAInfo(); - EVT NewFromVT = EVT::getVectorVT(*DAG.getContext(), FromEltVT, NumElements); - EVT NewToVT = EVT::getVectorVT(*DAG.getContext(), ToEltVT, NumElements); + // We split the store into slices of NumElements. fp16 trunc stores are vcvt + // and then stored as truncating integer stores. + EVT NewFromVT = EVT::getVectorVT(C, FromEltVT, NumElements); + EVT NewToVT = EVT::getVectorVT( + C, EVT::getIntegerVT(C, ToEltVT.getSizeInBits()), NumElements); SmallVector<SDValue, 4> Stores; for (unsigned i = 0; i < FromVT.getVectorNumElements() / NumElements; i++) { unsigned NewOffset = i * NumElements * ToEltVT.getSizeInBits() / 8; - SDValue NewPtr = DAG.getObjectPtrOffset(DL, BasePtr, NewOffset); + SDValue NewPtr = + DAG.getObjectPtrOffset(DL, BasePtr, TypeSize::Fixed(NewOffset)); SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewFromVT, Trunc.getOperand(0), DAG.getConstant(i * NumElements, DL, MVT::i32)); + + SDValue FPTrunc = + DAG.getNode(ARMISD::VCVTN, DL, MVT::v8f16, DAG.getUNDEF(MVT::v8f16), + Extract, DAG.getConstant(0, DL, MVT::i32)); + Extract = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, MVT::v4i32, FPTrunc); + SDValue Store = DAG.getTruncStore( Ch, DL, Extract, NewPtr, St->getPointerInfo().getWithOffset(NewOffset), NewToVT, Alignment, MMOFlags, AAInfo); @@ -13586,6 +16596,83 @@ static SDValue PerformSplittingToNarrowingStores(StoreSDNode *St, return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Stores); } +// Try taking a single vector store from an MVETRUNC (which would otherwise turn +// into an expensive buildvector) and splitting it into a series of narrowing +// stores. 
+static SDValue PerformSplittingMVETruncToNarrowingStores(StoreSDNode *St, + SelectionDAG &DAG) { + if (!St->isSimple() || St->isTruncatingStore() || !St->isUnindexed()) + return SDValue(); + SDValue Trunc = St->getValue(); + if (Trunc->getOpcode() != ARMISD::MVETRUNC) + return SDValue(); + EVT FromVT = Trunc->getOperand(0).getValueType(); + EVT ToVT = Trunc.getValueType(); + + LLVMContext &C = *DAG.getContext(); + SDLoc DL(St); + // Details about the old store + SDValue Ch = St->getChain(); + SDValue BasePtr = St->getBasePtr(); + Align Alignment = St->getOriginalAlign(); + MachineMemOperand::Flags MMOFlags = St->getMemOperand()->getFlags(); + AAMDNodes AAInfo = St->getAAInfo(); + + EVT NewToVT = EVT::getVectorVT(C, ToVT.getVectorElementType(), + FromVT.getVectorNumElements()); + + SmallVector<SDValue, 4> Stores; + for (unsigned i = 0; i < Trunc.getNumOperands(); i++) { + unsigned NewOffset = + i * FromVT.getVectorNumElements() * ToVT.getScalarSizeInBits() / 8; + SDValue NewPtr = + DAG.getObjectPtrOffset(DL, BasePtr, TypeSize::Fixed(NewOffset)); + + SDValue Extract = Trunc.getOperand(i); + SDValue Store = DAG.getTruncStore( + Ch, DL, Extract, NewPtr, St->getPointerInfo().getWithOffset(NewOffset), + NewToVT, Alignment, MMOFlags, AAInfo); + Stores.push_back(Store); + } + return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Stores); +} + +// Given a floating point store from an extracted vector, with an integer +// VGETLANE that already exists, store the existing VGETLANEu directly. This can +// help reduce fp register pressure, doesn't require the fp extract and allows +// use of more integer post-inc stores not available with vstr. +static SDValue PerformExtractFpToIntStores(StoreSDNode *St, SelectionDAG &DAG) { + if (!St->isSimple() || St->isTruncatingStore() || !St->isUnindexed()) + return SDValue(); + SDValue Extract = St->getValue(); + EVT VT = Extract.getValueType(); + // For now only uses f16. This may be useful for f32 too, but that will + // be bitcast(extract), not the VGETLANEu we currently check here. + if (VT != MVT::f16 || Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + + SDNode *GetLane = + DAG.getNodeIfExists(ARMISD::VGETLANEu, DAG.getVTList(MVT::i32), + {Extract.getOperand(0), Extract.getOperand(1)}); + if (!GetLane) + return SDValue(); + + LLVMContext &C = *DAG.getContext(); + SDLoc DL(St); + // Create a new integer store to replace the existing floating point version. + SDValue Ch = St->getChain(); + SDValue BasePtr = St->getBasePtr(); + Align Alignment = St->getOriginalAlign(); + MachineMemOperand::Flags MMOFlags = St->getMemOperand()->getFlags(); + AAMDNodes AAInfo = St->getAAInfo(); + EVT NewToVT = EVT::getIntegerVT(C, VT.getSizeInBits()); + SDValue Store = DAG.getTruncStore(Ch, DL, SDValue(GetLane, 0), BasePtr, + St->getPointerInfo(), NewToVT, Alignment, + MMOFlags, AAInfo); + + return Store; +} + /// PerformSTORECombine - Target-specific dag combine xforms for /// ISD::STORE. 
static SDValue PerformSTORECombine(SDNode *N, @@ -13601,9 +16688,15 @@ static SDValue PerformSTORECombine(SDNode *N, if (SDValue Store = PerformTruncatingStoreCombine(St, DCI.DAG)) return Store; - if (Subtarget->hasMVEIntegerOps()) + if (Subtarget->hasMVEIntegerOps()) { if (SDValue NewToken = PerformSplittingToNarrowingStores(St, DCI.DAG)) return NewToken; + if (SDValue NewChain = PerformExtractFpToIntStores(St, DCI.DAG)) + return NewChain; + if (SDValue NewToken = + PerformSplittingMVETruncToNarrowingStores(St, DCI.DAG)) + return NewToken; + } if (!ISD::isNormalStore(St)) return SDValue(); @@ -13618,15 +16711,15 @@ static SDValue PerformSTORECombine(SDNode *N, SDValue BasePtr = St->getBasePtr(); SDValue NewST1 = DAG.getStore( St->getChain(), DL, StVal.getNode()->getOperand(isBigEndian ? 1 : 0), - BasePtr, St->getPointerInfo(), St->getAlignment(), + BasePtr, St->getPointerInfo(), St->getOriginalAlign(), St->getMemOperand()->getFlags()); SDValue OffsetPtr = DAG.getNode(ISD::ADD, DL, MVT::i32, BasePtr, DAG.getConstant(4, DL, MVT::i32)); return DAG.getStore(NewST1.getValue(0), DL, StVal.getNode()->getOperand(isBigEndian ? 0 : 1), - OffsetPtr, St->getPointerInfo(), - std::min(4U, St->getAlignment() / 2), + OffsetPtr, St->getPointerInfo().getWithOffset(4), + St->getOriginalAlign(), St->getMemOperand()->getFlags()); } @@ -13650,7 +16743,7 @@ static SDValue PerformSTORECombine(SDNode *N, DCI.AddToWorklist(ExtElt.getNode()); DCI.AddToWorklist(V.getNode()); return DAG.getStore(St->getChain(), dl, V, St->getBasePtr(), - St->getPointerInfo(), St->getAlignment(), + St->getPointerInfo(), St->getAlign(), St->getMemOperand()->getFlags(), St->getAAInfo()); } @@ -13719,6 +16812,49 @@ static SDValue PerformVCVTCombine(SDNode *N, SelectionDAG &DAG, return FixConv; } +static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + if (!Subtarget->hasMVEFloatOps()) + return SDValue(); + + // Turn (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x) + // The second form can be more easily turned into a predicated vadd, and + // possibly combined into a fma to become a predicated vfma. + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // The identity element for a fadd is -0.0 or +0.0 when the nsz flag is set, + // which these VMOV's represent. + auto isIdentitySplat = [&](SDValue Op, bool NSZ) { + if (Op.getOpcode() != ISD::BITCAST || + Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM) + return false; + uint64_t ImmVal = Op.getOperand(0).getConstantOperandVal(0); + if (VT == MVT::v4f32 && (ImmVal == 1664 || (ImmVal == 0 && NSZ))) + return true; + if (VT == MVT::v8f16 && (ImmVal == 2688 || (ImmVal == 0 && NSZ))) + return true; + return false; + }; + + if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT) + std::swap(Op0, Op1); + + if (Op1.getOpcode() != ISD::VSELECT) + return SDValue(); + + SDNodeFlags FaddFlags = N->getFlags(); + bool NSZ = FaddFlags.hasNoSignedZeros(); + if (!isIdentitySplat(Op1.getOperand(2), NSZ)) + return SDValue(); + + SDValue FAdd = + DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), FaddFlags); + return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0, FaddFlags); +} + /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD) /// can replace combinations of VCVT (integer to floating-point) and VDIV /// when the VDIV has a constant operand that is a power of 2. 
@@ -13778,8 +16914,351 @@ static SDValue PerformVDIVCombine(SDNode *N, SelectionDAG &DAG, ConvInput, DAG.getConstant(C, dl, MVT::i32)); } +static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *ST) { + if (!ST->hasMVEIntegerOps()) + return SDValue(); + + assert(N->getOpcode() == ISD::VECREDUCE_ADD); + EVT ResVT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + SDLoc dl(N); + + // Try to turn vecreduce_add(add(x, y)) into vecreduce(x) + vecreduce(y) + if (ResVT == MVT::i32 && N0.getOpcode() == ISD::ADD && + (N0.getValueType() == MVT::v4i32 || N0.getValueType() == MVT::v8i16 || + N0.getValueType() == MVT::v16i8)) { + SDValue Red0 = DAG.getNode(ISD::VECREDUCE_ADD, dl, ResVT, N0.getOperand(0)); + SDValue Red1 = DAG.getNode(ISD::VECREDUCE_ADD, dl, ResVT, N0.getOperand(1)); + return DAG.getNode(ISD::ADD, dl, ResVT, Red0, Red1); + } + + // We are looking for something that will have illegal types if left alone, + // but that we can convert to a single instruction under MVE. For example + // vecreduce_add(sext(A, v8i32)) => VADDV.s16 A + // or + // vecreduce_add(mul(zext(A, v16i32), zext(B, v16i32))) => VMLADAV.u8 A, B + + // The legal cases are: + // VADDV u/s 8/16/32 + // VMLAV u/s 8/16/32 + // VADDLV u/s 32 + // VMLALV u/s 16/32 + + // If the input vector is smaller than legal (v4i8/v4i16 for example) we can + // extend it and use v4i32 instead. + auto ExtTypeMatches = [](SDValue A, ArrayRef<MVT> ExtTypes) { + EVT AVT = A.getValueType(); + return any_of(ExtTypes, [&](MVT Ty) { + return AVT.getVectorNumElements() == Ty.getVectorNumElements() && + AVT.bitsLE(Ty); + }); + }; + auto ExtendIfNeeded = [&](SDValue A, unsigned ExtendCode) { + EVT AVT = A.getValueType(); + if (!AVT.is128BitVector()) + A = DAG.getNode(ExtendCode, dl, + AVT.changeVectorElementType(MVT::getIntegerVT( + 128 / AVT.getVectorMinNumElements())), + A); + return A; + }; + auto IsVADDV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes) { + if (ResVT != RetTy || N0->getOpcode() != ExtendCode) + return SDValue(); + SDValue A = N0->getOperand(0); + if (ExtTypeMatches(A, ExtTypes)) + return ExtendIfNeeded(A, ExtendCode); + return SDValue(); + }; + auto IsPredVADDV = [&](MVT RetTy, unsigned ExtendCode, + ArrayRef<MVT> ExtTypes, SDValue &Mask) { + if (ResVT != RetTy || N0->getOpcode() != ISD::VSELECT || + !ISD::isBuildVectorAllZeros(N0->getOperand(2).getNode())) + return SDValue(); + Mask = N0->getOperand(0); + SDValue Ext = N0->getOperand(1); + if (Ext->getOpcode() != ExtendCode) + return SDValue(); + SDValue A = Ext->getOperand(0); + if (ExtTypeMatches(A, ExtTypes)) + return ExtendIfNeeded(A, ExtendCode); + return SDValue(); + }; + auto IsVMLAV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes, + SDValue &A, SDValue &B) { + // For a vmla we are trying to match a larger pattern: + // ExtA = sext/zext A + // ExtB = sext/zext B + // Mul = mul ExtA, ExtB + // vecreduce.add Mul + // There might also be an extra extend between the mul and the addreduce, so + // long as the bitwidth is high enough to make them equivalent (for example + // original v8i16 might be mul at v8i32 and the reduce happens at v8i64).
+ if (ResVT != RetTy) + return false; + SDValue Mul = N0; + if (Mul->getOpcode() == ExtendCode && + Mul->getOperand(0).getScalarValueSizeInBits() * 2 >= + ResVT.getScalarSizeInBits()) + Mul = Mul->getOperand(0); + if (Mul->getOpcode() != ISD::MUL) + return false; + SDValue ExtA = Mul->getOperand(0); + SDValue ExtB = Mul->getOperand(1); + if (ExtA->getOpcode() != ExtendCode || ExtB->getOpcode() != ExtendCode) + return false; + A = ExtA->getOperand(0); + B = ExtB->getOperand(0); + if (ExtTypeMatches(A, ExtTypes) && ExtTypeMatches(B, ExtTypes)) { + A = ExtendIfNeeded(A, ExtendCode); + B = ExtendIfNeeded(B, ExtendCode); + return true; + } + return false; + }; + auto IsPredVMLAV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes, + SDValue &A, SDValue &B, SDValue &Mask) { + // Same as the pattern above with a select for the zero predicated lanes + // ExtA = sext/zext A + // ExtB = sext/zext B + // Mul = mul ExtA, ExtB + // N0 = select Mask, Mul, 0 + // vecreduce.add N0 + if (ResVT != RetTy || N0->getOpcode() != ISD::VSELECT || + !ISD::isBuildVectorAllZeros(N0->getOperand(2).getNode())) + return false; + Mask = N0->getOperand(0); + SDValue Mul = N0->getOperand(1); + if (Mul->getOpcode() == ExtendCode && + Mul->getOperand(0).getScalarValueSizeInBits() * 2 >= + ResVT.getScalarSizeInBits()) + Mul = Mul->getOperand(0); + if (Mul->getOpcode() != ISD::MUL) + return false; + SDValue ExtA = Mul->getOperand(0); + SDValue ExtB = Mul->getOperand(1); + if (ExtA->getOpcode() != ExtendCode || ExtB->getOpcode() != ExtendCode) + return false; + A = ExtA->getOperand(0); + B = ExtB->getOperand(0); + if (ExtTypeMatches(A, ExtTypes) && ExtTypeMatches(B, ExtTypes)) { + A = ExtendIfNeeded(A, ExtendCode); + B = ExtendIfNeeded(B, ExtendCode); + return true; + } + return false; + }; + auto Create64bitNode = [&](unsigned Opcode, ArrayRef<SDValue> Ops) { + // Split illegal MVT::v16i8->i64 vector reductions into two legal v8i16->i64 + // reductions. The operands are extended with MVEEXT, but as they are + // reductions the lane orders do not matter. MVEEXT may be combined with + // loads to produce two extending loads, or else they will be expanded to + // VREV/VMOVL. + EVT VT = Ops[0].getValueType(); + if (VT == MVT::v16i8) { + assert((Opcode == ARMISD::VMLALVs || Opcode == ARMISD::VMLALVu) && + "Unexpected illegal long reduction opcode"); + bool IsUnsigned = Opcode == ARMISD::VMLALVu; + + SDValue Ext0 = + DAG.getNode(IsUnsigned ? ARMISD::MVEZEXT : ARMISD::MVESEXT, dl, + DAG.getVTList(MVT::v8i16, MVT::v8i16), Ops[0]); + SDValue Ext1 = + DAG.getNode(IsUnsigned ? ARMISD::MVEZEXT : ARMISD::MVESEXT, dl, + DAG.getVTList(MVT::v8i16, MVT::v8i16), Ops[1]); + + SDValue MLA0 = DAG.getNode(Opcode, dl, DAG.getVTList(MVT::i32, MVT::i32), + Ext0, Ext1); + SDValue MLA1 = + DAG.getNode(IsUnsigned ? 
ARMISD::VMLALVAu : ARMISD::VMLALVAs, dl, + DAG.getVTList(MVT::i32, MVT::i32), MLA0, MLA0.getValue(1), + Ext0.getValue(1), Ext1.getValue(1)); + return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, MLA1, MLA1.getValue(1)); + } + SDValue Node = DAG.getNode(Opcode, dl, {MVT::i32, MVT::i32}, Ops); + return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Node, + SDValue(Node.getNode(), 1)); + }; + + SDValue A, B; + SDValue Mask; + if (IsVMLAV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B)) + return DAG.getNode(ARMISD::VMLAVs, dl, ResVT, A, B); + if (IsVMLAV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B)) + return DAG.getNode(ARMISD::VMLAVu, dl, ResVT, A, B); + if (IsVMLAV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v16i8, MVT::v8i16, MVT::v4i32}, + A, B)) + return Create64bitNode(ARMISD::VMLALVs, {A, B}); + if (IsVMLAV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v16i8, MVT::v8i16, MVT::v4i32}, + A, B)) + return Create64bitNode(ARMISD::VMLALVu, {A, B}); + if (IsVMLAV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8}, A, B)) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VMLAVs, dl, MVT::i32, A, B)); + if (IsVMLAV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8}, A, B)) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VMLAVu, dl, MVT::i32, A, B)); + + if (IsPredVMLAV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B, + Mask)) + return DAG.getNode(ARMISD::VMLAVps, dl, ResVT, A, B, Mask); + if (IsPredVMLAV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B, + Mask)) + return DAG.getNode(ARMISD::VMLAVpu, dl, ResVT, A, B, Mask); + if (IsPredVMLAV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B, + Mask)) + return Create64bitNode(ARMISD::VMLALVps, {A, B, Mask}); + if (IsPredVMLAV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B, + Mask)) + return Create64bitNode(ARMISD::VMLALVpu, {A, B, Mask}); + if (IsPredVMLAV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8}, A, B, Mask)) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VMLAVps, dl, MVT::i32, A, B, Mask)); + if (IsPredVMLAV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8}, A, B, Mask)) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VMLAVpu, dl, MVT::i32, A, B, Mask)); + + if (SDValue A = IsVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8})) + return DAG.getNode(ARMISD::VADDVs, dl, ResVT, A); + if (SDValue A = IsVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8})) + return DAG.getNode(ARMISD::VADDVu, dl, ResVT, A); + if (SDValue A = IsVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32})) + return Create64bitNode(ARMISD::VADDLVs, {A}); + if (SDValue A = IsVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32})) + return Create64bitNode(ARMISD::VADDLVu, {A}); + if (SDValue A = IsVADDV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8})) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VADDVs, dl, MVT::i32, A)); + if (SDValue A = IsVADDV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8})) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VADDVu, dl, MVT::i32, A)); + + if (SDValue A = IsPredVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask)) + return DAG.getNode(ARMISD::VADDVps, dl, ResVT, A, Mask); + if (SDValue A = IsPredVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask)) + return DAG.getNode(ARMISD::VADDVpu, dl, ResVT, A, Mask); + if (SDValue A = IsPredVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}, Mask)) + return Create64bitNode(ARMISD::VADDLVps, {A, Mask}); + if (SDValue A = IsPredVADDV(MVT::i64, 
ISD::ZERO_EXTEND, {MVT::v4i32}, Mask)) + return Create64bitNode(ARMISD::VADDLVpu, {A, Mask}); + if (SDValue A = IsPredVADDV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8}, Mask)) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VADDVps, dl, MVT::i32, A, Mask)); + if (SDValue A = IsPredVADDV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8}, Mask)) + return DAG.getNode(ISD::TRUNCATE, dl, ResVT, + DAG.getNode(ARMISD::VADDVpu, dl, MVT::i32, A, Mask)); + + // Some complications. We can get a case where the two inputs of the mul are + // the same, then the output sext will have been helpfully converted to a + // zext. Turn it back. + SDValue Op = N0; + if (Op->getOpcode() == ISD::VSELECT) + Op = Op->getOperand(1); + if (Op->getOpcode() == ISD::ZERO_EXTEND && + Op->getOperand(0)->getOpcode() == ISD::MUL) { + SDValue Mul = Op->getOperand(0); + if (Mul->getOperand(0) == Mul->getOperand(1) && + Mul->getOperand(0)->getOpcode() == ISD::SIGN_EXTEND) { + SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, N0->getValueType(0), Mul); + if (Op != N0) + Ext = DAG.getNode(ISD::VSELECT, dl, N0->getValueType(0), + N0->getOperand(0), Ext, N0->getOperand(2)); + return DAG.getNode(ISD::VECREDUCE_ADD, dl, ResVT, Ext); + } + } + + return SDValue(); +} + +static SDValue PerformVMOVNCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + unsigned IsTop = N->getConstantOperandVal(2); + + // VMOVNT a undef -> a + // VMOVNB a undef -> a + // VMOVNB undef a -> a + if (Op1->isUndef()) + return Op0; + if (Op0->isUndef() && !IsTop) + return Op1; + + // VMOVNt(c, VQMOVNb(a, b)) => VQMOVNt(c, b) + // VMOVNb(c, VQMOVNb(a, b)) => VQMOVNb(c, b) + if ((Op1->getOpcode() == ARMISD::VQMOVNs || + Op1->getOpcode() == ARMISD::VQMOVNu) && + Op1->getConstantOperandVal(2) == 0) + return DCI.DAG.getNode(Op1->getOpcode(), SDLoc(Op1), N->getValueType(0), + Op0, Op1->getOperand(1), N->getOperand(2)); + + // Only the bottom lanes from Qm (Op1) and either the top or bottom lanes from + // Qd (Op0) are demanded from a VMOVN, depending on whether we are inserting + // into the top or bottom lanes. + unsigned NumElts = N->getValueType(0).getVectorNumElements(); + APInt Op1DemandedElts = APInt::getSplat(NumElts, APInt::getLowBitsSet(2, 1)); + APInt Op0DemandedElts = + IsTop ? Op1DemandedElts + : APInt::getSplat(NumElts, APInt::getHighBitsSet(2, 1)); + + const TargetLowering &TLI = DCI.DAG.getTargetLoweringInfo(); + if (TLI.SimplifyDemandedVectorElts(Op0, Op0DemandedElts, DCI)) + return SDValue(N, 0); + if (TLI.SimplifyDemandedVectorElts(Op1, Op1DemandedElts, DCI)) + return SDValue(N, 0); + + return SDValue(); +} + +static SDValue PerformVQMOVNCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + SDValue Op0 = N->getOperand(0); + unsigned IsTop = N->getConstantOperandVal(2); + + unsigned NumElts = N->getValueType(0).getVectorNumElements(); + APInt Op0DemandedElts = + APInt::getSplat(NumElts, IsTop ? APInt::getLowBitsSet(2, 1) + : APInt::getHighBitsSet(2, 1)); + + const TargetLowering &TLI = DCI.DAG.getTargetLoweringInfo(); + if (TLI.SimplifyDemandedVectorElts(Op0, Op0DemandedElts, DCI)) + return SDValue(N, 0); + return SDValue(); +} + +static SDValue PerformLongShiftCombine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + + // Turn X << -C -> X >> C and vice versa. The negative shifts can come up from + // uses of the intrinsics.
+ if (auto C = dyn_cast<ConstantSDNode>(N->getOperand(2))) { + int ShiftAmt = C->getSExtValue(); + if (ShiftAmt == 0) { + SDValue Merge = DAG.getMergeValues({Op0, Op1}, DL); + DAG.ReplaceAllUsesWith(N, Merge.getNode()); + return SDValue(); + } + + if (ShiftAmt >= -32 && ShiftAmt < 0) { + unsigned NewOpcode = + N->getOpcode() == ARMISD::LSLL ? ARMISD::LSRL : ARMISD::LSLL; + SDValue NewShift = DAG.getNode(NewOpcode, DL, N->getVTList(), Op0, Op1, + DAG.getConstant(-ShiftAmt, DL, MVT::i32)); + DAG.ReplaceAllUsesWith(N, NewShift.getNode()); + return NewShift; + } + } + + return SDValue(); +} + /// PerformIntrinsicCombine - ARM-specific DAG combining for intrinsics. -static SDValue PerformIntrinsicCombine(SDNode *N, SelectionDAG &DAG) { +SDValue ARMTargetLowering::PerformIntrinsicCombine(SDNode *N, + DAGCombinerInfo &DCI) const { + SelectionDAG &DAG = DCI.DAG; unsigned IntNo = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue(); switch (IntNo) { default: @@ -13928,6 +17407,72 @@ static SDValue PerformIntrinsicCombine(SDNode *N, SelectionDAG &DAG) { case Intrinsic::arm_neon_vqrshiftu: // No immediate versions of these to check for. break; + + case Intrinsic::arm_mve_vqdmlah: + case Intrinsic::arm_mve_vqdmlash: + case Intrinsic::arm_mve_vqrdmlah: + case Intrinsic::arm_mve_vqrdmlash: + case Intrinsic::arm_mve_vmla_n_predicated: + case Intrinsic::arm_mve_vmlas_n_predicated: + case Intrinsic::arm_mve_vqdmlah_predicated: + case Intrinsic::arm_mve_vqdmlash_predicated: + case Intrinsic::arm_mve_vqrdmlah_predicated: + case Intrinsic::arm_mve_vqrdmlash_predicated: { + // These intrinsics all take an i32 scalar operand which is narrowed to the + // size of a single lane of the vector type they return. So we don't need + // any bits of that operand above that point, which allows us to eliminate + // uxth/sxth. + unsigned BitWidth = N->getValueType(0).getScalarSizeInBits(); + APInt DemandedMask = APInt::getLowBitsSet(32, BitWidth); + if (SimplifyDemandedBits(N->getOperand(3), DemandedMask, DCI)) + return SDValue(); + break; + } + + case Intrinsic::arm_mve_minv: + case Intrinsic::arm_mve_maxv: + case Intrinsic::arm_mve_minav: + case Intrinsic::arm_mve_maxav: + case Intrinsic::arm_mve_minv_predicated: + case Intrinsic::arm_mve_maxv_predicated: + case Intrinsic::arm_mve_minav_predicated: + case Intrinsic::arm_mve_maxav_predicated: { + // These intrinsics all take an i32 scalar operand which is narrowed to the + // size of a single lane of the vector type they take as the other input. + unsigned BitWidth = N->getOperand(2)->getValueType(0).getScalarSizeInBits(); + APInt DemandedMask = APInt::getLowBitsSet(32, BitWidth); + if (SimplifyDemandedBits(N->getOperand(1), DemandedMask, DCI)) + return SDValue(); + break; + } + + case Intrinsic::arm_mve_addv: { + // Turn this intrinsic straight into the appropriate ARMISD::VADDV node, + // which allow PerformADDVecReduce to turn it into VADDLV when possible. + bool Unsigned = cast<ConstantSDNode>(N->getOperand(2))->getZExtValue(); + unsigned Opc = Unsigned ? ARMISD::VADDVu : ARMISD::VADDVs; + return DAG.getNode(Opc, SDLoc(N), N->getVTList(), N->getOperand(1)); + } + + case Intrinsic::arm_mve_addlv: + case Intrinsic::arm_mve_addlv_predicated: { + // Same for these, but ARMISD::VADDLV has to be followed by a BUILD_PAIR + // which recombines the two outputs into an i64 + bool Unsigned = cast<ConstantSDNode>(N->getOperand(2))->getZExtValue(); + unsigned Opc = IntNo == Intrinsic::arm_mve_addlv ? + (Unsigned ? ARMISD::VADDLVu : ARMISD::VADDLVs) : + (Unsigned ? 
ARMISD::VADDLVpu : ARMISD::VADDLVps); + + SmallVector<SDValue, 4> Ops; + for (unsigned i = 1, e = N->getNumOperands(); i < e; i++) + if (i != 2) // skip the unsigned flag + Ops.push_back(N->getOperand(i)); + + SDLoc dl(N); + SDValue val = DAG.getNode(Opc, dl, {MVT::i32, MVT::i32}, Ops); + return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, val.getValue(0), + val.getValue(1)); + } } return SDValue(); @@ -13943,18 +17488,6 @@ static SDValue PerformShiftCombine(SDNode *N, const ARMSubtarget *ST) { SelectionDAG &DAG = DCI.DAG; EVT VT = N->getValueType(0); - if (N->getOpcode() == ISD::SRL && VT == MVT::i32 && ST->hasV6Ops()) { - // Canonicalize (srl (bswap x), 16) to (rotr (bswap x), 16) if the high - // 16-bits of x is zero. This optimizes rev + lsr 16 to rev16. - SDValue N1 = N->getOperand(1); - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) { - SDValue N0 = N->getOperand(0); - if (C->getZExtValue() == 16 && N0.getOpcode() == ISD::BSWAP && - DAG.MaskedValueIsZero(N0.getOperand(0), - APInt::getHighBitsSet(32, 16))) - return DAG.getNode(ISD::ROTR, SDLoc(N), VT, N0, N1); - } - } if (ST->isThumb1Only() && N->getOpcode() == ISD::SHL && VT == MVT::i32 && N->getOperand(0)->getOpcode() == ISD::AND && @@ -13994,7 +17527,7 @@ static SDValue PerformShiftCombine(SDNode *N, const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (!VT.isVector() || !TLI.isTypeLegal(VT)) return SDValue(); - if (ST->hasMVEIntegerOps() && VT == MVT::v2i64) + if (ST->hasMVEIntegerOps()) return SDValue(); int64_t Cnt; @@ -14023,9 +17556,10 @@ static SDValue PerformShiftCombine(SDNode *N, return SDValue(); } -// Look for a sign/zero extend of a larger than legal load. This can be split -// into two extending loads, which are simpler to deal with than an arbitrary -// sign extend. +// Look for a sign/zero/fpextend extend of a larger than legal load. This can be +// split into multiple extending loads, which are simpler to deal with than an +// arbitrary extend. For fp extends we use an integer extending load and a VCVTL +// to convert the type to an f32. static SDValue PerformSplittingToWideningLoad(SDNode *N, SelectionDAG &DAG) { SDValue N0 = N->getOperand(0); if (N0.getOpcode() != ISD::LOAD) @@ -14043,49 +17577,66 @@ static SDValue PerformSplittingToWideningLoad(SDNode *N, SelectionDAG &DAG) { EVT FromEltVT = FromVT.getVectorElementType(); unsigned NumElements = 0; - if (ToEltVT == MVT::i32 && (FromEltVT == MVT::i16 || FromEltVT == MVT::i8)) + if (ToEltVT == MVT::i32 && FromEltVT == MVT::i8) + NumElements = 4; + if (ToEltVT == MVT::f32 && FromEltVT == MVT::f16) NumElements = 4; - if (ToEltVT == MVT::i16 && FromEltVT == MVT::i8) - NumElements = 8; if (NumElements == 0 || - FromVT.getVectorNumElements() == NumElements || + (FromEltVT != MVT::f16 && FromVT.getVectorNumElements() == NumElements) || FromVT.getVectorNumElements() % NumElements != 0 || !isPowerOf2_32(NumElements)) return SDValue(); + LLVMContext &C = *DAG.getContext(); SDLoc DL(LD); // Details about the old load SDValue Ch = LD->getChain(); SDValue BasePtr = LD->getBasePtr(); - unsigned Alignment = LD->getOriginalAlignment(); + Align Alignment = LD->getOriginalAlign(); MachineMemOperand::Flags MMOFlags = LD->getMemOperand()->getFlags(); AAMDNodes AAInfo = LD->getAAInfo(); ISD::LoadExtType NewExtType = N->getOpcode() == ISD::SIGN_EXTEND ? 
ISD::SEXTLOAD : ISD::ZEXTLOAD; SDValue Offset = DAG.getUNDEF(BasePtr.getValueType()); - EVT NewFromVT = FromVT.getHalfNumVectorElementsVT(*DAG.getContext()); - EVT NewToVT = ToVT.getHalfNumVectorElementsVT(*DAG.getContext()); - unsigned NewOffset = NewFromVT.getSizeInBits() / 8; - SDValue NewPtr = DAG.getObjectPtrOffset(DL, BasePtr, NewOffset); - - // Split the load in half, each side of which is extended separately. This - // is good enough, as legalisation will take it from there. They are either - // already legal or they will be split further into something that is - // legal. - SDValue NewLoad1 = - DAG.getLoad(ISD::UNINDEXED, NewExtType, NewToVT, DL, Ch, BasePtr, Offset, - LD->getPointerInfo(), NewFromVT, Alignment, MMOFlags, AAInfo); - SDValue NewLoad2 = - DAG.getLoad(ISD::UNINDEXED, NewExtType, NewToVT, DL, Ch, NewPtr, Offset, - LD->getPointerInfo().getWithOffset(NewOffset), NewFromVT, - Alignment, MMOFlags, AAInfo); - - SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, - SDValue(NewLoad1.getNode(), 1), - SDValue(NewLoad2.getNode(), 1)); + EVT NewFromVT = EVT::getVectorVT( + C, EVT::getIntegerVT(C, FromEltVT.getScalarSizeInBits()), NumElements); + EVT NewToVT = EVT::getVectorVT( + C, EVT::getIntegerVT(C, ToEltVT.getScalarSizeInBits()), NumElements); + + SmallVector<SDValue, 4> Loads; + SmallVector<SDValue, 4> Chains; + for (unsigned i = 0; i < FromVT.getVectorNumElements() / NumElements; i++) { + unsigned NewOffset = (i * NewFromVT.getSizeInBits()) / 8; + SDValue NewPtr = + DAG.getObjectPtrOffset(DL, BasePtr, TypeSize::Fixed(NewOffset)); + + SDValue NewLoad = + DAG.getLoad(ISD::UNINDEXED, NewExtType, NewToVT, DL, Ch, NewPtr, Offset, + LD->getPointerInfo().getWithOffset(NewOffset), NewFromVT, + Alignment, MMOFlags, AAInfo); + Loads.push_back(NewLoad); + Chains.push_back(SDValue(NewLoad.getNode(), 1)); + } + + // Float truncs need to be extended with VCVTB's into their floating point types. + if (FromEltVT == MVT::f16) { + SmallVector<SDValue, 4> Extends; + + for (unsigned i = 0; i < Loads.size(); i++) { + SDValue LoadBC = + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, MVT::v8f16, Loads[i]); + SDValue FPExt = DAG.getNode(ARMISD::VCVTL, DL, MVT::v4f32, LoadBC, + DAG.getConstant(0, DL, MVT::i32)); + Extends.push_back(FPExt); + } + + Loads = Extends; + } + + SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); DAG.ReplaceAllUsesOfValueWith(SDValue(LD, 1), NewChain); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, ToVT, NewLoad1, NewLoad2); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, ToVT, Loads); } /// PerformExtendCombine - Target-specific DAG combining for ISD::SIGN_EXTEND, @@ -14133,6 +17684,164 @@ static SDValue PerformExtendCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static SDValue PerformFPExtendCombine(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *ST) { + if (ST->hasMVEFloatOps()) + if (SDValue NewLoad = PerformSplittingToWideningLoad(N, DAG)) + return NewLoad; + + return SDValue(); +} + +// Lower smin(smax(x, C1), C2) to ssat or usat, if they have saturating +// constant bounds.
+static SDValue PerformMinMaxToSatCombine(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *Subtarget) { + if ((Subtarget->isThumb() || !Subtarget->hasV6Ops()) && + !Subtarget->isThumb2()) + return SDValue(); + + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + + if (VT != MVT::i32 || + (Op0.getOpcode() != ISD::SMIN && Op0.getOpcode() != ISD::SMAX) || + !isa<ConstantSDNode>(Op.getOperand(1)) || + !isa<ConstantSDNode>(Op0.getOperand(1))) + return SDValue(); + + SDValue Min = Op; + SDValue Max = Op0; + SDValue Input = Op0.getOperand(0); + if (Min.getOpcode() == ISD::SMAX) + std::swap(Min, Max); + + APInt MinC = Min.getConstantOperandAPInt(1); + APInt MaxC = Max.getConstantOperandAPInt(1); + + if (Min.getOpcode() != ISD::SMIN || Max.getOpcode() != ISD::SMAX || + !(MinC + 1).isPowerOf2()) + return SDValue(); + + SDLoc DL(Op); + if (MinC == ~MaxC) + return DAG.getNode(ARMISD::SSAT, DL, VT, Input, + DAG.getConstant(MinC.countTrailingOnes(), DL, VT)); + if (MaxC == 0) + return DAG.getNode(ARMISD::USAT, DL, VT, Input, + DAG.getConstant(MinC.countTrailingOnes(), DL, VT)); + + return SDValue(); +} + +/// PerformMinMaxCombine - Target-specific DAG combining for creating truncating +/// saturates. +static SDValue PerformMinMaxCombine(SDNode *N, SelectionDAG &DAG, + const ARMSubtarget *ST) { + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + + if (VT == MVT::i32) + return PerformMinMaxToSatCombine(SDValue(N, 0), DAG, ST); + + if (!ST->hasMVEIntegerOps()) + return SDValue(); + + if (SDValue V = PerformVQDMULHCombine(N, DAG)) + return V; + + if (VT != MVT::v4i32 && VT != MVT::v8i16) + return SDValue(); + + auto IsSignedSaturate = [&](SDNode *Min, SDNode *Max) { + // Check one is a smin and the other is a smax + if (Min->getOpcode() != ISD::SMIN) + std::swap(Min, Max); + if (Min->getOpcode() != ISD::SMIN || Max->getOpcode() != ISD::SMAX) + return false; + + APInt SaturateC; + if (VT == MVT::v4i32) + SaturateC = APInt(32, (1 << 15) - 1, true); + else //if (VT == MVT::v8i16) + SaturateC = APInt(16, (1 << 7) - 1, true); + + APInt MinC, MaxC; + if (!ISD::isConstantSplatVector(Min->getOperand(1).getNode(), MinC) || + MinC != SaturateC) + return false; + if (!ISD::isConstantSplatVector(Max->getOperand(1).getNode(), MaxC) || + MaxC != ~SaturateC) + return false; + return true; + }; + + if (IsSignedSaturate(N, N0.getNode())) { + SDLoc DL(N); + MVT ExtVT, HalfVT; + if (VT == MVT::v4i32) { + HalfVT = MVT::v8i16; + ExtVT = MVT::v4i16; + } else { // if (VT == MVT::v8i16) + HalfVT = MVT::v16i8; + ExtVT = MVT::v8i8; + } + + // Create a VQMOVNB with undef top lanes, then signed extended into the top + // half. That extend will hopefully be removed if only the bottom bits are + // demanded (though a truncating store, for example). 
+ SDValue VQMOVN = + DAG.getNode(ARMISD::VQMOVNs, DL, HalfVT, DAG.getUNDEF(HalfVT), + N0->getOperand(0), DAG.getConstant(0, DL, MVT::i32)); + SDValue Bitcast = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, VQMOVN); + return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Bitcast, + DAG.getValueType(ExtVT)); + } + + auto IsUnsignedSaturate = [&](SDNode *Min) { + // For unsigned, we just need to check for <= 0xffff + if (Min->getOpcode() != ISD::UMIN) + return false; + + APInt SaturateC; + if (VT == MVT::v4i32) + SaturateC = APInt(32, (1 << 16) - 1, true); + else //if (VT == MVT::v8i16) + SaturateC = APInt(16, (1 << 8) - 1, true); + + APInt MinC; + if (!ISD::isConstantSplatVector(Min->getOperand(1).getNode(), MinC) || + MinC != SaturateC) + return false; + return true; + }; + + if (IsUnsignedSaturate(N)) { + SDLoc DL(N); + MVT HalfVT; + unsigned ExtConst; + if (VT == MVT::v4i32) { + HalfVT = MVT::v8i16; + ExtConst = 0x0000FFFF; + } else { //if (VT == MVT::v8i16) + HalfVT = MVT::v16i8; + ExtConst = 0x00FF; + } + + // Create a VQMOVNB with undef top lanes, then ZExt into the top half with + // an AND. That extend will hopefully be removed if only the bottom bits are + // demanded (though a truncating store, for example). + SDValue VQMOVN = + DAG.getNode(ARMISD::VQMOVNu, DL, HalfVT, DAG.getUNDEF(HalfVT), N0, + DAG.getConstant(0, DL, MVT::i32)); + SDValue Bitcast = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, VQMOVN); + return DAG.getNode(ISD::AND, DL, VT, Bitcast, + DAG.getConstant(ExtConst, DL, VT)); + } + + return SDValue(); +} + static const APInt *isPowerOf2Constant(SDValue V) { ConstantSDNode *C = dyn_cast<ConstantSDNode>(V); if (!C) @@ -14254,7 +17963,7 @@ static SDValue SearchLoopIntrinsic(SDValue N, ISD::CondCode &CC, int &Imm, auto *Const = dyn_cast<ConstantSDNode>(N.getOperand(1)); if (!Const) return SDValue(); - if (Const->isNullValue()) + if (Const->isZero()) Imm = 0; else if (Const->isOne()) Imm = 1; @@ -14265,7 +17974,7 @@ static SDValue SearchLoopIntrinsic(SDValue N, ISD::CondCode &CC, int &Imm, } case ISD::INTRINSIC_W_CHAIN: { unsigned IntOp = cast<ConstantSDNode>(N.getOperand(1))->getZExtValue(); - if (IntOp != Intrinsic::test_set_loop_iterations && + if (IntOp != Intrinsic::test_start_loop_iterations && IntOp != Intrinsic::loop_decrement_reg) return SDValue(); return N; @@ -14280,7 +17989,7 @@ static SDValue PerformHWLoopCombine(SDNode *N, // The hwloop intrinsics that we're interested are used for control-flow, // either for entering or exiting the loop: - // - test.set.loop.iterations will test whether its operand is zero. If it + // - test.start.loop.iterations will test whether its operand is zero. If it // is zero, the proceeding branch should not enter the loop. // - loop.decrement.reg also tests whether its operand is zero. 
If it is // zero, the proceeding branch should not branch back to the beginning of @@ -14306,7 +18015,7 @@ static SDValue PerformHWLoopCombine(SDNode *N, Cond = N->getOperand(2); Dest = N->getOperand(4); if (auto *Const = dyn_cast<ConstantSDNode>(N->getOperand(3))) { - if (!Const->isOne() && !Const->isNullValue()) + if (!Const->isOne() && !Const->isZero()) return SDValue(); Imm = Const->getZExtValue(); } else @@ -14355,21 +18064,25 @@ static SDValue PerformHWLoopCombine(SDNode *N, DAG.ReplaceAllUsesOfValueWith(SDValue(Br, 0), NewBr); }; - if (IntOp == Intrinsic::test_set_loop_iterations) { + if (IntOp == Intrinsic::test_start_loop_iterations) { SDValue Res; + SDValue Setup = DAG.getNode(ARMISD::WLSSETUP, dl, MVT::i32, Elements); // We expect this 'instruction' to branch when the counter is zero. if (IsTrueIfZero(CC, Imm)) { - SDValue Ops[] = { Chain, Elements, Dest }; + SDValue Ops[] = {Chain, Setup, Dest}; Res = DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops); } else { // The logic is the reverse of what we need for WLS, so find the other // basic block target: the target of the proceeding br. UpdateUncondBr(Br, Dest, DAG); - SDValue Ops[] = { Chain, Elements, OtherTarget }; + SDValue Ops[] = {Chain, Setup, OtherTarget}; Res = DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops); } - DAG.ReplaceAllUsesOfValueWith(Int.getValue(1), Int.getOperand(0)); + // Update LR count to the new value + DAG.ReplaceAllUsesOfValueWith(Int.getValue(0), Setup); + // Update chain + DAG.ReplaceAllUsesOfValueWith(Int.getValue(2), Int.getOperand(0)); return Res; } else { SDValue Size = DAG.getTargetConstant( @@ -14507,6 +18220,23 @@ ARMTargetLowering::PerformCMOVCombine(SDNode *N, SelectionDAG &DAG) const { if (!VT.isInteger()) return SDValue(); + // Fold away an unneccessary CMPZ/CMOV + // CMOV A, B, C1, $cpsr, (CMPZ (CMOV 1, 0, C2, D), 0) -> + // if C1==EQ -> CMOV A, B, C2, $cpsr, D + // if C1==NE -> CMOV A, B, NOT(C2), $cpsr, D + if (N->getConstantOperandVal(2) == ARMCC::EQ || + N->getConstantOperandVal(2) == ARMCC::NE) { + ARMCC::CondCodes Cond; + if (SDValue C = IsCMPZCSINC(N->getOperand(4).getNode(), Cond)) { + if (N->getConstantOperandVal(2) == ARMCC::NE) + Cond = ARMCC::getOppositeCondition(Cond); + return DAG.getNode(N->getOpcode(), SDLoc(N), MVT::i32, N->getOperand(0), + N->getOperand(1), + DAG.getTargetConstant(Cond, SDLoc(N), MVT::i32), + N->getOperand(3), C); + } + } + // Materialize a boolean comparison for integers so we can avoid branching. if (isNullConstant(FalseVal)) { if (CC == ARMCC::EQ && isOneConstant(TrueVal)) { @@ -14614,10 +18344,325 @@ ARMTargetLowering::PerformCMOVCombine(SDNode *N, SelectionDAG &DAG) const { return Res; } +static SDValue PerformBITCASTCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *ST) { + SelectionDAG &DAG = DCI.DAG; + SDValue Src = N->getOperand(0); + EVT DstVT = N->getValueType(0); + + // Convert v4f32 bitcast (v4i32 vdup (i32)) -> v4f32 vdup (i32) under MVE. + if (ST->hasMVEIntegerOps() && Src.getOpcode() == ARMISD::VDUP) { + EVT SrcVT = Src.getValueType(); + if (SrcVT.getScalarSizeInBits() == DstVT.getScalarSizeInBits()) + return DAG.getNode(ARMISD::VDUP, SDLoc(N), DstVT, Src.getOperand(0)); + } + + // We may have a bitcast of something that has already had this bitcast + // combine performed on it, so skip past any VECTOR_REG_CASTs. 
+ while (Src.getOpcode() == ARMISD::VECTOR_REG_CAST) + Src = Src.getOperand(0); + + // Bitcast from element-wise VMOV or VMVN doesn't need VREV if the VREV that + // would be generated is at least the width of the element type. + EVT SrcVT = Src.getValueType(); + if ((Src.getOpcode() == ARMISD::VMOVIMM || + Src.getOpcode() == ARMISD::VMVNIMM || + Src.getOpcode() == ARMISD::VMOVFPIMM) && + SrcVT.getScalarSizeInBits() <= DstVT.getScalarSizeInBits() && + DAG.getDataLayout().isBigEndian()) + return DAG.getNode(ARMISD::VECTOR_REG_CAST, SDLoc(N), DstVT, Src); + + // bitcast(extract(x, n)); bitcast(extract(x, n+1)) -> VMOVRRD x + if (SDValue R = PerformExtractEltToVMOVRRD(N, DCI)) + return R; + + return SDValue(); +} + +// Some combines for the MVETrunc truncations legalizer helper. Also lowers the +// node into stack operations after legalizeOps. +SDValue ARMTargetLowering::PerformMVETruncCombine( + SDNode *N, TargetLowering::DAGCombinerInfo &DCI) const { + SelectionDAG &DAG = DCI.DAG; + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // MVETrunc(Undef, Undef) -> Undef + if (all_of(N->ops(), [](SDValue Op) { return Op.isUndef(); })) + return DAG.getUNDEF(VT); + + // MVETrunc(MVETrunc a b, MVETrunc c, d) -> MVETrunc + if (N->getNumOperands() == 2 && + N->getOperand(0).getOpcode() == ARMISD::MVETRUNC && + N->getOperand(1).getOpcode() == ARMISD::MVETRUNC) + return DAG.getNode(ARMISD::MVETRUNC, DL, VT, N->getOperand(0).getOperand(0), + N->getOperand(0).getOperand(1), + N->getOperand(1).getOperand(0), + N->getOperand(1).getOperand(1)); + + // MVETrunc(shuffle, shuffle) -> VMOVN + if (N->getNumOperands() == 2 && + N->getOperand(0).getOpcode() == ISD::VECTOR_SHUFFLE && + N->getOperand(1).getOpcode() == ISD::VECTOR_SHUFFLE) { + auto *S0 = cast<ShuffleVectorSDNode>(N->getOperand(0).getNode()); + auto *S1 = cast<ShuffleVectorSDNode>(N->getOperand(1).getNode()); + + if (S0->getOperand(0) == S1->getOperand(0) && + S0->getOperand(1) == S1->getOperand(1)) { + // Construct complete shuffle mask + SmallVector<int, 8> Mask(S0->getMask()); + Mask.append(S1->getMask().begin(), S1->getMask().end()); + + if (isVMOVNTruncMask(Mask, VT, false)) + return DAG.getNode( + ARMISD::VMOVN, DL, VT, + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, S0->getOperand(0)), + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, S0->getOperand(1)), + DAG.getConstant(1, DL, MVT::i32)); + if (isVMOVNTruncMask(Mask, VT, true)) + return DAG.getNode( + ARMISD::VMOVN, DL, VT, + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, S0->getOperand(1)), + DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, S0->getOperand(0)), + DAG.getConstant(1, DL, MVT::i32)); + } + } + + // For MVETrunc of a buildvector or shuffle, it can be beneficial to lower the + // truncate to a buildvector to allow the generic optimisations to kick in. 
+ if (all_of(N->ops(), [](SDValue Op) { + return Op.getOpcode() == ISD::BUILD_VECTOR || + Op.getOpcode() == ISD::VECTOR_SHUFFLE || + (Op.getOpcode() == ISD::BITCAST && + Op.getOperand(0).getOpcode() == ISD::BUILD_VECTOR); + })) { + SmallVector<SDValue, 8> Extracts; + for (unsigned Op = 0; Op < N->getNumOperands(); Op++) { + SDValue O = N->getOperand(Op); + for (unsigned i = 0; i < O.getValueType().getVectorNumElements(); i++) { + SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, O, + DAG.getConstant(i, DL, MVT::i32)); + Extracts.push_back(Ext); + } + } + return DAG.getBuildVector(VT, DL, Extracts); + } + + // If we are late in the legalization process and nothing has optimised + // the trunc to anything better, lower it to a stack store and reload, + // performing the truncation whilst keeping the lanes in the correct order: + // VSTRH.32 a, stack; VSTRH.32 b, stack+8; VLDRW.32 stack; + if (!DCI.isAfterLegalizeDAG()) + return SDValue(); + + SDValue StackPtr = DAG.CreateStackTemporary(TypeSize::Fixed(16), Align(4)); + int SPFI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex(); + int NumIns = N->getNumOperands(); + assert((NumIns == 2 || NumIns == 4) && + "Expected 2 or 4 inputs to an MVETrunc"); + EVT StoreVT = VT.getHalfNumVectorElementsVT(*DAG.getContext()); + if (N->getNumOperands() == 4) + StoreVT = StoreVT.getHalfNumVectorElementsVT(*DAG.getContext()); + + SmallVector<SDValue> Chains; + for (int I = 0; I < NumIns; I++) { + SDValue Ptr = DAG.getNode( + ISD::ADD, DL, StackPtr.getValueType(), StackPtr, + DAG.getConstant(I * 16 / NumIns, DL, StackPtr.getValueType())); + MachinePointerInfo MPI = MachinePointerInfo::getFixedStack( + DAG.getMachineFunction(), SPFI, I * 16 / NumIns); + SDValue Ch = DAG.getTruncStore(DAG.getEntryNode(), DL, N->getOperand(I), + Ptr, MPI, StoreVT, Align(4)); + Chains.push_back(Ch); + } + + SDValue Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); + MachinePointerInfo MPI = + MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SPFI, 0); + return DAG.getLoad(VT, DL, Chain, StackPtr, MPI, Align(4)); +} + +// Take a MVEEXT(load x) and split that into (extload x, extload x+8) +static SDValue PerformSplittingMVEEXTToWideningLoad(SDNode *N, + SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0); + LoadSDNode *LD = dyn_cast<LoadSDNode>(N0.getNode()); + if (!LD || !LD->isSimple() || !N0.hasOneUse() || LD->isIndexed()) + return SDValue(); + + EVT FromVT = LD->getMemoryVT(); + EVT ToVT = N->getValueType(0); + if (!ToVT.isVector()) + return SDValue(); + assert(FromVT.getVectorNumElements() == ToVT.getVectorNumElements() * 2); + EVT ToEltVT = ToVT.getVectorElementType(); + EVT FromEltVT = FromVT.getVectorElementType(); + + unsigned NumElements = 0; + if (ToEltVT == MVT::i32 && (FromEltVT == MVT::i16 || FromEltVT == MVT::i8)) + NumElements = 4; + if (ToEltVT == MVT::i16 && FromEltVT == MVT::i8) + NumElements = 8; + assert(NumElements != 0); + + ISD::LoadExtType NewExtType = + N->getOpcode() == ARMISD::MVESEXT ? 
ISD::SEXTLOAD : ISD::ZEXTLOAD; + if (LD->getExtensionType() != ISD::NON_EXTLOAD && + LD->getExtensionType() != ISD::EXTLOAD && + LD->getExtensionType() != NewExtType) + return SDValue(); + + LLVMContext &C = *DAG.getContext(); + SDLoc DL(LD); + // Details about the old load + SDValue Ch = LD->getChain(); + SDValue BasePtr = LD->getBasePtr(); + Align Alignment = LD->getOriginalAlign(); + MachineMemOperand::Flags MMOFlags = LD->getMemOperand()->getFlags(); + AAMDNodes AAInfo = LD->getAAInfo(); + + SDValue Offset = DAG.getUNDEF(BasePtr.getValueType()); + EVT NewFromVT = EVT::getVectorVT( + C, EVT::getIntegerVT(C, FromEltVT.getScalarSizeInBits()), NumElements); + EVT NewToVT = EVT::getVectorVT( + C, EVT::getIntegerVT(C, ToEltVT.getScalarSizeInBits()), NumElements); + + SmallVector<SDValue, 4> Loads; + SmallVector<SDValue, 4> Chains; + for (unsigned i = 0; i < FromVT.getVectorNumElements() / NumElements; i++) { + unsigned NewOffset = (i * NewFromVT.getSizeInBits()) / 8; + SDValue NewPtr = + DAG.getObjectPtrOffset(DL, BasePtr, TypeSize::Fixed(NewOffset)); + + SDValue NewLoad = + DAG.getLoad(ISD::UNINDEXED, NewExtType, NewToVT, DL, Ch, NewPtr, Offset, + LD->getPointerInfo().getWithOffset(NewOffset), NewFromVT, + Alignment, MMOFlags, AAInfo); + Loads.push_back(NewLoad); + Chains.push_back(SDValue(NewLoad.getNode(), 1)); + } + + SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); + DAG.ReplaceAllUsesOfValueWith(SDValue(LD, 1), NewChain); + return DAG.getMergeValues(Loads, DL); +} + +// Perform combines for MVEEXT. If it has not been optimized to anything better +// before lowering, it gets converted to stack store and extloads performing the +// extend whilst still keeping the same lane ordering. +SDValue ARMTargetLowering::PerformMVEExtCombine( + SDNode *N, TargetLowering::DAGCombinerInfo &DCI) const { + SelectionDAG &DAG = DCI.DAG; + EVT VT = N->getValueType(0); + SDLoc DL(N); + assert(N->getNumValues() == 2 && "Expected MVEEXT with 2 elements"); + assert((VT == MVT::v4i32 || VT == MVT::v8i16) && "Unexpected MVEEXT type"); + + EVT ExtVT = N->getOperand(0).getValueType().getHalfNumVectorElementsVT( + *DAG.getContext()); + auto Extend = [&](SDValue V) { + SDValue VVT = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, VT, V); + return N->getOpcode() == ARMISD::MVESEXT + ? DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, VVT, + DAG.getValueType(ExtVT)) + : DAG.getZeroExtendInReg(VVT, DL, ExtVT); + }; + + // MVEEXT(VDUP) -> SIGN_EXTEND_INREG(VDUP) + if (N->getOperand(0).getOpcode() == ARMISD::VDUP) { + SDValue Ext = Extend(N->getOperand(0)); + return DAG.getMergeValues({Ext, Ext}, DL); + } + + // MVEEXT(shuffle) -> SIGN_EXTEND_INREG/ZERO_EXTEND_INREG + if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0))) { + ArrayRef<int> Mask = SVN->getMask(); + assert(Mask.size() == 2 * VT.getVectorNumElements()); + assert(Mask.size() == SVN->getValueType(0).getVectorNumElements()); + unsigned Rev = VT == MVT::v4i32 ? 
ARMISD::VREV32 : ARMISD::VREV16; + SDValue Op0 = SVN->getOperand(0); + SDValue Op1 = SVN->getOperand(1); + + auto CheckInregMask = [&](int Start, int Offset) { + for (int Idx = 0, E = VT.getVectorNumElements(); Idx < E; ++Idx) + if (Mask[Start + Idx] >= 0 && Mask[Start + Idx] != Idx * 2 + Offset) + return false; + return true; + }; + SDValue V0 = SDValue(N, 0); + SDValue V1 = SDValue(N, 1); + if (CheckInregMask(0, 0)) + V0 = Extend(Op0); + else if (CheckInregMask(0, 1)) + V0 = Extend(DAG.getNode(Rev, DL, SVN->getValueType(0), Op0)); + else if (CheckInregMask(0, Mask.size())) + V0 = Extend(Op1); + else if (CheckInregMask(0, Mask.size() + 1)) + V0 = Extend(DAG.getNode(Rev, DL, SVN->getValueType(0), Op1)); + + if (CheckInregMask(VT.getVectorNumElements(), Mask.size())) + V1 = Extend(Op1); + else if (CheckInregMask(VT.getVectorNumElements(), Mask.size() + 1)) + V1 = Extend(DAG.getNode(Rev, DL, SVN->getValueType(0), Op1)); + else if (CheckInregMask(VT.getVectorNumElements(), 0)) + V1 = Extend(Op0); + else if (CheckInregMask(VT.getVectorNumElements(), 1)) + V1 = Extend(DAG.getNode(Rev, DL, SVN->getValueType(0), Op0)); + + if (V0.getNode() != N || V1.getNode() != N) + return DAG.getMergeValues({V0, V1}, DL); + } + + // MVEEXT(load) -> extload, extload + if (N->getOperand(0)->getOpcode() == ISD::LOAD) + if (SDValue L = PerformSplittingMVEEXTToWideningLoad(N, DAG)) + return L; + + if (!DCI.isAfterLegalizeDAG()) + return SDValue(); + + // Lower to a stack store and reload: + // VSTRW.32 a, stack; VLDRH.32 stack; VLDRH.32 stack+8; + SDValue StackPtr = DAG.CreateStackTemporary(TypeSize::Fixed(16), Align(4)); + int SPFI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex(); + int NumOuts = N->getNumValues(); + assert((NumOuts == 2 || NumOuts == 4) && + "Expected 2 or 4 outputs to an MVEEXT"); + EVT LoadVT = N->getOperand(0).getValueType().getHalfNumVectorElementsVT( + *DAG.getContext()); + if (N->getNumOperands() == 4) + LoadVT = LoadVT.getHalfNumVectorElementsVT(*DAG.getContext()); + + MachinePointerInfo MPI = + MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SPFI, 0); + SDValue Chain = DAG.getStore(DAG.getEntryNode(), DL, N->getOperand(0), + StackPtr, MPI, Align(4)); + + SmallVector<SDValue> Loads; + for (int I = 0; I < NumOuts; I++) { + SDValue Ptr = DAG.getNode( + ISD::ADD, DL, StackPtr.getValueType(), StackPtr, + DAG.getConstant(I * 16 / NumOuts, DL, StackPtr.getValueType())); + MachinePointerInfo MPI = MachinePointerInfo::getFixedStack( + DAG.getMachineFunction(), SPFI, I * 16 / NumOuts); + SDValue Load = DAG.getExtLoad( + N->getOpcode() == ARMISD::MVESEXT ? 
ISD::SEXTLOAD : ISD::ZEXTLOAD, DL, + VT, Chain, Ptr, MPI, LoadVT, Align(4)); + Loads.push_back(Load); + } + + return DAG.getMergeValues(Loads, DL); +} + SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { switch (N->getOpcode()) { default: break; + case ISD::SELECT_CC: + case ISD::SELECT: return PerformSELECTCombine(N, DCI, Subtarget); + case ISD::VSELECT: return PerformVSELECTCombine(N, DCI, Subtarget); + case ISD::SETCC: return PerformVSetCCToVCTPCombine(N, DCI, Subtarget); case ISD::ABS: return PerformABSCombine(N, DCI, Subtarget); case ARMISD::ADDE: return PerformADDECombine(N, DCI, Subtarget); case ARMISD::UMLAL: return PerformUMLALCombine(N, DCI.DAG, Subtarget); @@ -14632,31 +18677,57 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, case ARMISD::ADDC: case ARMISD::SUBC: return PerformAddcSubcCombine(N, DCI, Subtarget); case ARMISD::SUBE: return PerformAddeSubeCombine(N, DCI, Subtarget); - case ARMISD::BFI: return PerformBFICombine(N, DCI); + case ARMISD::BFI: return PerformBFICombine(N, DCI.DAG); case ARMISD::VMOVRRD: return PerformVMOVRRDCombine(N, DCI, Subtarget); case ARMISD::VMOVDRR: return PerformVMOVDRRCombine(N, DCI.DAG); + case ARMISD::VMOVhr: return PerformVMOVhrCombine(N, DCI); + case ARMISD::VMOVrh: return PerformVMOVrhCombine(N, DCI.DAG); case ISD::STORE: return PerformSTORECombine(N, DCI, Subtarget); case ISD::BUILD_VECTOR: return PerformBUILD_VECTORCombine(N, DCI, Subtarget); case ISD::INSERT_VECTOR_ELT: return PerformInsertEltCombine(N, DCI); + case ISD::EXTRACT_VECTOR_ELT: + return PerformExtractEltCombine(N, DCI, Subtarget); + case ISD::SIGN_EXTEND_INREG: return PerformSignExtendInregCombine(N, DCI.DAG); + case ISD::INSERT_SUBVECTOR: return PerformInsertSubvectorCombine(N, DCI); case ISD::VECTOR_SHUFFLE: return PerformVECTOR_SHUFFLECombine(N, DCI.DAG); - case ARMISD::VDUPLANE: return PerformVDUPLANECombine(N, DCI); - case ARMISD::VDUP: return PerformVDUPCombine(N, DCI, Subtarget); + case ARMISD::VDUPLANE: return PerformVDUPLANECombine(N, DCI, Subtarget); + case ARMISD::VDUP: return PerformVDUPCombine(N, DCI.DAG, Subtarget); case ISD::FP_TO_SINT: case ISD::FP_TO_UINT: return PerformVCVTCombine(N, DCI.DAG, Subtarget); + case ISD::FADD: + return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget); case ISD::FDIV: return PerformVDIVCombine(N, DCI.DAG, Subtarget); - case ISD::INTRINSIC_WO_CHAIN: return PerformIntrinsicCombine(N, DCI.DAG); + case ISD::INTRINSIC_WO_CHAIN: + return PerformIntrinsicCombine(N, DCI); case ISD::SHL: case ISD::SRA: case ISD::SRL: return PerformShiftCombine(N, DCI, Subtarget); case ISD::SIGN_EXTEND: case ISD::ZERO_EXTEND: - case ISD::ANY_EXTEND: return PerformExtendCombine(N, DCI.DAG, Subtarget); - case ARMISD::CMOV: return PerformCMOVCombine(N, DCI.DAG); - case ARMISD::BRCOND: return PerformBRCONDCombine(N, DCI.DAG); - case ISD::LOAD: return PerformLOADCombine(N, DCI); + case ISD::ANY_EXTEND: + return PerformExtendCombine(N, DCI.DAG, Subtarget); + case ISD::FP_EXTEND: + return PerformFPExtendCombine(N, DCI.DAG, Subtarget); + case ISD::SMIN: + case ISD::UMIN: + case ISD::SMAX: + case ISD::UMAX: + return PerformMinMaxCombine(N, DCI.DAG, Subtarget); + case ARMISD::CMOV: + return PerformCMOVCombine(N, DCI.DAG); + case ARMISD::BRCOND: + return PerformBRCONDCombine(N, DCI.DAG); + case ARMISD::CMPZ: + return PerformCMPZCombine(N, DCI.DAG); + case ARMISD::CSINC: + case ARMISD::CSINV: + case ARMISD::CSNEG: + return PerformCSETCombine(N, DCI.DAG); + case ISD::LOAD: + return PerformLOADCombine(N, DCI, Subtarget); 
case ARMISD::VLD1DUP: case ARMISD::VLD2DUP: case ARMISD::VLD3DUP: @@ -14664,10 +18735,30 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, return PerformVLDCombine(N, DCI); case ARMISD::BUILD_VECTOR: return PerformARMBUILD_VECTORCombine(N, DCI); + case ISD::BITCAST: + return PerformBITCASTCombine(N, DCI, Subtarget); case ARMISD::PREDICATE_CAST: return PerformPREDICATE_CASTCombine(N, DCI); + case ARMISD::VECTOR_REG_CAST: + return PerformVECTOR_REG_CASTCombine(N, DCI.DAG, Subtarget); + case ARMISD::MVETRUNC: + return PerformMVETruncCombine(N, DCI); + case ARMISD::MVESEXT: + case ARMISD::MVEZEXT: + return PerformMVEExtCombine(N, DCI); case ARMISD::VCMP: - return PerformVCMPCombine(N, DCI, Subtarget); + return PerformVCMPCombine(N, DCI.DAG, Subtarget); + case ISD::VECREDUCE_ADD: + return PerformVECREDUCE_ADDCombine(N, DCI.DAG, Subtarget); + case ARMISD::VMOVN: + return PerformVMOVNCombine(N, DCI); + case ARMISD::VQMOVNs: + case ARMISD::VQMOVNu: + return PerformVQMOVNCombine(N, DCI); + case ARMISD::ASRL: + case ARMISD::LSRL: + case ARMISD::LSLL: + return PerformLongShiftCombine(N, DCI.DAG); case ARMISD::SMULWB: { unsigned BitWidth = N->getValueType(0).getSizeInBits(); APInt DemandedMask = APInt::getLowBitsSet(BitWidth, 16); @@ -14684,7 +18775,9 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, } case ARMISD::SMLALBB: case ARMISD::QADD16b: - case ARMISD::QSUB16b: { + case ARMISD::QSUB16b: + case ARMISD::UQADD16b: + case ARMISD::UQSUB16b: { unsigned BitWidth = N->getValueType(0).getSizeInBits(); APInt DemandedMask = APInt::getLowBitsSet(BitWidth, 16); if ((SimplifyDemandedBits(N->getOperand(0), DemandedMask, DCI)) || @@ -14721,7 +18814,9 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, break; } case ARMISD::QADD8b: - case ARMISD::QSUB8b: { + case ARMISD::QSUB8b: + case ARMISD::UQADD8b: + case ARMISD::UQSUB8b: { unsigned BitWidth = N->getValueType(0).getSizeInBits(); APInt DemandedMask = APInt::getLowBitsSet(BitWidth, 8); if ((SimplifyDemandedBits(N->getOperand(0), DemandedMask, DCI)) || @@ -14756,6 +18851,11 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, case Intrinsic::arm_neon_vst3lane: case Intrinsic::arm_neon_vst4lane: return PerformVLDCombine(N, DCI); + case Intrinsic::arm_mve_vld2q: + case Intrinsic::arm_mve_vld4q: + case Intrinsic::arm_mve_vst2q: + case Intrinsic::arm_mve_vst4q: + return PerformMVEVLDCombine(N, DCI); default: break; } break; @@ -14769,9 +18869,9 @@ bool ARMTargetLowering::isDesirableToTransformToIntegerOp(unsigned Opc, } bool ARMTargetLowering::allowsMisalignedMemoryAccesses(EVT VT, unsigned, - unsigned Alignment, + Align Alignment, MachineMemOperand::Flags, - bool *Fast) const { + unsigned *Fast) const { // Depends what it gets converted into if the type is weird. 
if (!VT.isSimple()) return false; @@ -14795,7 +18895,7 @@ bool ARMTargetLowering::allowsMisalignedMemoryAccesses(EVT VT, unsigned, // A big-endian target may also explicitly support unaligned accesses if (Subtarget->hasNEON() && (AllowsUnaligned || Subtarget->isLittle())) { if (Fast) - *Fast = true; + *Fast = 1; return true; } } @@ -14804,9 +18904,10 @@ bool ARMTargetLowering::allowsMisalignedMemoryAccesses(EVT VT, unsigned, return false; // These are for predicates - if ((Ty == MVT::v16i1 || Ty == MVT::v8i1 || Ty == MVT::v4i1)) { + if ((Ty == MVT::v16i1 || Ty == MVT::v8i1 || Ty == MVT::v4i1 || + Ty == MVT::v2i1)) { if (Fast) - *Fast = true; + *Fast = 1; return true; } @@ -14832,37 +18933,30 @@ bool ARMTargetLowering::allowsMisalignedMemoryAccesses(EVT VT, unsigned, Ty == MVT::v4i32 || Ty == MVT::v4f32 || Ty == MVT::v2i64 || Ty == MVT::v2f64) { if (Fast) - *Fast = true; + *Fast = 1; return true; } return false; } -static bool memOpAlign(unsigned DstAlign, unsigned SrcAlign, - unsigned AlignCheck) { - return ((SrcAlign == 0 || SrcAlign % AlignCheck == 0) && - (DstAlign == 0 || DstAlign % AlignCheck == 0)); -} EVT ARMTargetLowering::getOptimalMemOpType( - uint64_t Size, unsigned DstAlign, unsigned SrcAlign, bool IsMemset, - bool ZeroMemset, bool MemcpyStrSrc, - const AttributeList &FuncAttributes) const { + const MemOp &Op, const AttributeList &FuncAttributes) const { // See if we can use NEON instructions for this... - if ((!IsMemset || ZeroMemset) && Subtarget->hasNEON() && - !FuncAttributes.hasFnAttribute(Attribute::NoImplicitFloat)) { - bool Fast; - if (Size >= 16 && - (memOpAlign(SrcAlign, DstAlign, 16) || - (allowsMisalignedMemoryAccesses(MVT::v2f64, 0, 1, + if ((Op.isMemcpy() || Op.isZeroMemset()) && Subtarget->hasNEON() && + !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat)) { + unsigned Fast; + if (Op.size() >= 16 && + (Op.isAligned(Align(16)) || + (allowsMisalignedMemoryAccesses(MVT::v2f64, 0, Align(1), MachineMemOperand::MONone, &Fast) && Fast))) { return MVT::v2f64; - } else if (Size >= 8 && - (memOpAlign(SrcAlign, DstAlign, 8) || + } else if (Op.size() >= 8 && + (Op.isAligned(Align(8)) || (allowsMisalignedMemoryAccesses( - MVT::f64, 0, 1, MachineMemOperand::MONone, &Fast) && + MVT::f64, 0, Align(1), MachineMemOperand::MONone, &Fast) && Fast))) { return MVT::f64; } @@ -14974,45 +19068,119 @@ bool ARMTargetLowering::shouldSinkOperands(Instruction *I, if (!Subtarget->hasMVEIntegerOps()) return false; - auto IsSinker = [](Instruction *I, int Operand) { + auto IsFMSMul = [&](Instruction *I) { + if (!I->hasOneUse()) + return false; + auto *Sub = cast<Instruction>(*I->users().begin()); + return Sub->getOpcode() == Instruction::FSub && Sub->getOperand(1) == I; + }; + auto IsFMS = [&](Instruction *I) { + if (match(I->getOperand(0), m_FNeg(m_Value())) || + match(I->getOperand(1), m_FNeg(m_Value()))) + return true; + return false; + }; + + auto IsSinker = [&](Instruction *I, int Operand) { switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: + case Instruction::FAdd: case Instruction::ICmp: + case Instruction::FCmp: return true; + case Instruction::FMul: + return !IsFMSMul(I); case Instruction::Sub: + case Instruction::FSub: case Instruction::Shl: case Instruction::LShr: case Instruction::AShr: return Operand == 1; + case Instruction::Call: + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + case Intrinsic::fma: + return !IsFMS(I); + case Intrinsic::sadd_sat: + case Intrinsic::uadd_sat: + case Intrinsic::arm_mve_add_predicated: + case 
Intrinsic::arm_mve_mul_predicated: + case Intrinsic::arm_mve_qadd_predicated: + case Intrinsic::arm_mve_vhadd: + case Intrinsic::arm_mve_hadd_predicated: + case Intrinsic::arm_mve_vqdmull: + case Intrinsic::arm_mve_vqdmull_predicated: + case Intrinsic::arm_mve_vqdmulh: + case Intrinsic::arm_mve_qdmulh_predicated: + case Intrinsic::arm_mve_vqrdmulh: + case Intrinsic::arm_mve_qrdmulh_predicated: + case Intrinsic::arm_mve_fma_predicated: + return true; + case Intrinsic::ssub_sat: + case Intrinsic::usub_sat: + case Intrinsic::arm_mve_sub_predicated: + case Intrinsic::arm_mve_qsub_predicated: + case Intrinsic::arm_mve_hsub_predicated: + case Intrinsic::arm_mve_vhsub: + return Operand == 1; + default: + return false; + } + } + return false; default: return false; } }; - int Op = 0; - if (!isa<ShuffleVectorInst>(I->getOperand(Op))) - Op = 1; - if (!IsSinker(I, Op)) - return false; - if (!match(I->getOperand(Op), - m_ShuffleVector(m_InsertElement(m_Undef(), m_Value(), m_ZeroInt()), - m_Undef(), m_Zero()))) { - return false; - } - Instruction *Shuffle = cast<Instruction>(I->getOperand(Op)); - // All uses of the shuffle should be sunk to avoid duplicating it across gpr - // and vector registers - for (Use &U : Shuffle->uses()) { - Instruction *Insn = cast<Instruction>(U.getUser()); - if (!IsSinker(Insn, U.getOperandNo())) - return false; + for (auto OpIdx : enumerate(I->operands())) { + Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get()); + // Make sure we are not already sinking this operand + if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; })) + continue; + + Instruction *Shuffle = Op; + if (Shuffle->getOpcode() == Instruction::BitCast) + Shuffle = dyn_cast<Instruction>(Shuffle->getOperand(0)); + // We are looking for a splat that can be sunk. + if (!Shuffle || + !match(Shuffle, m_Shuffle( + m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()), + m_Undef(), m_ZeroMask()))) + continue; + if (!IsSinker(I, OpIdx.index())) + continue; + + // All uses of the shuffle should be sunk to avoid duplicating it across gpr + // and vector registers + for (Use &U : Op->uses()) { + Instruction *Insn = cast<Instruction>(U.getUser()); + if (!IsSinker(Insn, U.getOperandNo())) + return false; + } + + Ops.push_back(&Shuffle->getOperandUse(0)); + if (Shuffle != Op) + Ops.push_back(&Op->getOperandUse(0)); + Ops.push_back(&OpIdx.value()); } - Ops.push_back(&Shuffle->getOperandUse(0)); - Ops.push_back(&I->getOperandUse(Op)); return true; } +Type *ARMTargetLowering::shouldConvertSplatType(ShuffleVectorInst *SVI) const { + if (!Subtarget->hasMVEIntegerOps()) + return nullptr; + Type *SVIType = SVI->getType(); + Type *ScalarType = SVIType->getScalarType(); + + if (ScalarType->isFloatTy()) + return Type::getInt32Ty(SVIType->getContext()); + if (ScalarType->isHalfTy()) + return Type::getInt16Ty(SVIType->getContext()); + return nullptr; +} + bool ARMTargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { EVT VT = ExtVal.getValueType(); @@ -15024,6 +19192,9 @@ bool ARMTargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { return false; } + if (Subtarget->hasMVEIntegerOps()) + return true; + // Don't create a loadext if we can fold the extension into a wide/long // instruction. 
// If there's more than one user instruction, the loadext is desirable no @@ -15054,17 +19225,6 @@ bool ARMTargetLowering::allowTruncateForTailCall(Type *Ty1, Type *Ty2) const { return true; } -int ARMTargetLowering::getScalingFactorCost(const DataLayout &DL, - const AddrMode &AM, Type *Ty, - unsigned AS) const { - if (isLegalAddressingMode(DL, AM, Ty, AS)) { - if (Subtarget->hasFPAO()) - return AM.Scale < 0 ? 1 : 0; // positive offsets execute faster - return 0; - } - return -1; -} - /// isFMAFasterThanFMulAndFAdd - Return true if an FMA operation is faster /// than a pair of fmul and fadd instructions. fmuladd intrinsics will be /// expanded to FMAs when this method returns true, otherwise fmuladd is @@ -15361,6 +19521,31 @@ bool ARMTargetLowering::isLegalAddImmediate(int64_t Imm) const { return AbsImm >= 0 && AbsImm <= 255; } +// Return false to prevent folding +// (mul (add r, c0), c1) -> (add (mul r, c1), c0*c1) in DAGCombine, +// if the folding leads to worse code. +bool ARMTargetLowering::isMulAddWithConstProfitable(SDValue AddNode, + SDValue ConstNode) const { + // Let the DAGCombiner decide for vector types and large types. + const EVT VT = AddNode.getValueType(); + if (VT.isVector() || VT.getScalarSizeInBits() > 32) + return true; + + // It is worse if c0 is legal add immediate, while c1*c0 is not + // and has to be composed by at least two instructions. + const ConstantSDNode *C0Node = cast<ConstantSDNode>(AddNode.getOperand(1)); + const ConstantSDNode *C1Node = cast<ConstantSDNode>(ConstNode); + const int64_t C0 = C0Node->getSExtValue(); + APInt CA = C0Node->getAPIntValue() * C1Node->getAPIntValue(); + if (!isLegalAddImmediate(C0) || isLegalAddImmediate(CA.getSExtValue())) + return true; + if (ConstantMaterializationCost((unsigned)CA.getZExtValue(), Subtarget) > 1) + return false; + + // Default to true and let the DAGCombiner decide. + return true; +} + static bool getARMIndexedAddressParts(SDNode *Ptr, EVT VT, bool isSEXTLoad, SDValue &Base, SDValue &Offset, bool &isInc, @@ -15445,7 +19630,7 @@ static bool getT2IndexedAddressParts(SDNode *Ptr, EVT VT, return false; } -static bool getMVEIndexedAddressParts(SDNode *Ptr, EVT VT, unsigned Align, +static bool getMVEIndexedAddressParts(SDNode *Ptr, EVT VT, Align Alignment, bool isSEXTLoad, bool IsMasked, bool isLE, SDValue &Base, SDValue &Offset, bool &isInc, SelectionDAG &DAG) { @@ -15480,16 +19665,16 @@ static bool getMVEIndexedAddressParts(SDNode *Ptr, EVT VT, unsigned Align, // (in BE/masked) type. 
Base = Ptr->getOperand(0); if (VT == MVT::v4i16) { - if (Align >= 2 && IsInRange(RHSC, 0x80, 2)) + if (Alignment >= 2 && IsInRange(RHSC, 0x80, 2)) return true; } else if (VT == MVT::v4i8 || VT == MVT::v8i8) { if (IsInRange(RHSC, 0x80, 1)) return true; - } else if (Align >= 4 && + } else if (Alignment >= 4 && (CanChangeType || VT == MVT::v4i32 || VT == MVT::v4f32) && IsInRange(RHSC, 0x80, 4)) return true; - else if (Align >= 2 && + else if (Alignment >= 2 && (CanChangeType || VT == MVT::v8i16 || VT == MVT::v8f16) && IsInRange(RHSC, 0x80, 2)) return true; @@ -15511,28 +19696,28 @@ ARMTargetLowering::getPreIndexedAddressParts(SDNode *N, SDValue &Base, EVT VT; SDValue Ptr; - unsigned Align; + Align Alignment; bool isSEXTLoad = false; bool IsMasked = false; if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) { Ptr = LD->getBasePtr(); VT = LD->getMemoryVT(); - Align = LD->getAlignment(); + Alignment = LD->getAlign(); isSEXTLoad = LD->getExtensionType() == ISD::SEXTLOAD; } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) { Ptr = ST->getBasePtr(); VT = ST->getMemoryVT(); - Align = ST->getAlignment(); + Alignment = ST->getAlign(); } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) { Ptr = LD->getBasePtr(); VT = LD->getMemoryVT(); - Align = LD->getAlignment(); + Alignment = LD->getAlign(); isSEXTLoad = LD->getExtensionType() == ISD::SEXTLOAD; IsMasked = true; } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) { Ptr = ST->getBasePtr(); VT = ST->getMemoryVT(); - Align = ST->getAlignment(); + Alignment = ST->getAlign(); IsMasked = true; } else return false; @@ -15541,9 +19726,9 @@ ARMTargetLowering::getPreIndexedAddressParts(SDNode *N, SDValue &Base, bool isLegal = false; if (VT.isVector()) isLegal = Subtarget->hasMVEIntegerOps() && - getMVEIndexedAddressParts(Ptr.getNode(), VT, Align, isSEXTLoad, - IsMasked, Subtarget->isLittle(), Base, - Offset, isInc, DAG); + getMVEIndexedAddressParts( + Ptr.getNode(), VT, Alignment, isSEXTLoad, IsMasked, + Subtarget->isLittle(), Base, Offset, isInc, DAG); else { if (Subtarget->isThumb2()) isLegal = getT2IndexedAddressParts(Ptr.getNode(), VT, isSEXTLoad, Base, @@ -15569,31 +19754,31 @@ bool ARMTargetLowering::getPostIndexedAddressParts(SDNode *N, SDNode *Op, SelectionDAG &DAG) const { EVT VT; SDValue Ptr; - unsigned Align; + Align Alignment; bool isSEXTLoad = false, isNonExt; bool IsMasked = false; if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) { VT = LD->getMemoryVT(); Ptr = LD->getBasePtr(); - Align = LD->getAlignment(); + Alignment = LD->getAlign(); isSEXTLoad = LD->getExtensionType() == ISD::SEXTLOAD; isNonExt = LD->getExtensionType() == ISD::NON_EXTLOAD; } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) { VT = ST->getMemoryVT(); Ptr = ST->getBasePtr(); - Align = ST->getAlignment(); + Alignment = ST->getAlign(); isNonExt = !ST->isTruncatingStore(); } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) { VT = LD->getMemoryVT(); Ptr = LD->getBasePtr(); - Align = LD->getAlignment(); + Alignment = LD->getAlign(); isSEXTLoad = LD->getExtensionType() == ISD::SEXTLOAD; isNonExt = LD->getExtensionType() == ISD::NON_EXTLOAD; IsMasked = true; } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) { VT = ST->getMemoryVT(); Ptr = ST->getBasePtr(); - Align = ST->getAlignment(); + Alignment = ST->getAlign(); isNonExt = !ST->isTruncatingStore(); IsMasked = true; } else @@ -15608,6 +19793,8 @@ bool ARMTargetLowering::getPostIndexedAddressParts(SDNode *N, SDNode *Op, auto *RHS = dyn_cast<ConstantSDNode>(Op->getOperand(1)); 
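
The isMulAddWithConstProfitable hook added a little earlier in this diff blocks the (mul (add r, c0), c1) -> (add (mul r, c1), c0*c1) fold when c0 is a cheap add immediate but c0*c1 would need a multi-instruction constant. The profitability question ultimately comes down to ARM's modified-immediate encoding (an 8-bit value rotated right by an even amount). A rough standalone model of that encoding check, offered only as an assumed simplification of what ARM_AM::getSOImmVal tests, not as part of the patch:

    #include <cstdint>
    #include <cstdio>

    // A32 modified immediates are an 8-bit value rotated right by an even
    // amount, so V is encodable iff some even left-rotation of V fits in 8 bits.
    static bool isA32ModifiedImm(uint32_t V) {
      for (unsigned Rot = 0; Rot < 32; Rot += 2) {
        uint32_t Rotated = Rot == 0 ? V : (V << Rot) | (V >> (32 - Rot));
        if (Rotated <= 0xFF)
          return true;
      }
      return false;
    }

    int main() {
      // c0 = 100 is a single-instruction immediate, but c0 * c1 = 100 * 300 =
      // 30000 (0x7530) is not, so folding (mul (add r, 100), 300) would need
      // an extra constant materialization.
      std::printf("%d %d\n", isA32ModifiedImm(100), isA32ModifiedImm(30000));
      return 0;
    }
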
if (!RHS || RHS->getZExtValue() != 4) return false; + if (Alignment < Align(4)) + return false; Offset = Op->getOperand(1); Base = Op->getOperand(0); @@ -15619,7 +19806,7 @@ bool ARMTargetLowering::getPostIndexedAddressParts(SDNode *N, SDNode *Op, bool isLegal = false; if (VT.isVector()) isLegal = Subtarget->hasMVEIntegerOps() && - getMVEIndexedAddressParts(Op, VT, Align, isSEXTLoad, IsMasked, + getMVEIndexedAddressParts(Op, VT, Alignment, isSEXTLoad, IsMasked, Subtarget->isLittle(), Base, Offset, isInc, DAG); else { @@ -15681,8 +19868,7 @@ void ARMTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, return; KnownBits KnownRHS = DAG.computeKnownBits(Op.getOperand(1), Depth+1); - Known.Zero &= KnownRHS.Zero; - Known.One &= KnownRHS.One; + Known = KnownBits::commonBits(Known, KnownRHS); return; } case ISD::INTRINSIC_W_CHAIN: { @@ -15734,18 +19920,45 @@ void ARMTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, if (Op.getOpcode() == ARMISD::VGETLANEs) Known = Known.sext(DstSz); else { - Known = Known.zext(DstSz, true /* extended bits are known zero */); + Known = Known.zext(DstSz); } assert(DstSz == Known.getBitWidth()); break; } + case ARMISD::VMOVrh: { + KnownBits KnownOp = DAG.computeKnownBits(Op->getOperand(0), Depth + 1); + assert(KnownOp.getBitWidth() == 16); + Known = KnownOp.zext(32); + break; + } + case ARMISD::CSINC: + case ARMISD::CSINV: + case ARMISD::CSNEG: { + KnownBits KnownOp0 = DAG.computeKnownBits(Op->getOperand(0), Depth + 1); + KnownBits KnownOp1 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1); + + // The result is either: + // CSINC: KnownOp0 or KnownOp1 + 1 + // CSINV: KnownOp0 or ~KnownOp1 + // CSNEG: KnownOp0 or KnownOp1 * -1 + if (Op.getOpcode() == ARMISD::CSINC) + KnownOp1 = KnownBits::computeForAddSub( + true, false, KnownOp1, KnownBits::makeConstant(APInt(32, 1))); + else if (Op.getOpcode() == ARMISD::CSINV) + std::swap(KnownOp1.Zero, KnownOp1.One); + else if (Op.getOpcode() == ARMISD::CSNEG) + KnownOp1 = KnownBits::mul( + KnownOp1, KnownBits::makeConstant(APInt(32, -1))); + + Known = KnownBits::commonBits(KnownOp0, KnownOp1); + break; + } } } -bool -ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op, - const APInt &DemandedAPInt, - TargetLoweringOpt &TLO) const { +bool ARMTargetLowering::targetShrinkDemandedConstant( + SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, + TargetLoweringOpt &TLO) const { // Delay optimization, so we don't have to deal with illegal types, or block // optimizations. if (!TLO.LegalOps) @@ -15770,7 +19983,7 @@ ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op, unsigned Mask = C->getZExtValue(); - unsigned Demanded = DemandedAPInt.getZExtValue(); + unsigned Demanded = DemandedBits.getZExtValue(); unsigned ShrunkMask = Mask & Demanded; unsigned ExpandedMask = Mask | ~Demanded; @@ -15825,6 +20038,43 @@ ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op, return false; } +bool ARMTargetLowering::SimplifyDemandedBitsForTargetNode( + SDValue Op, const APInt &OriginalDemandedBits, + const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO, + unsigned Depth) const { + unsigned Opc = Op.getOpcode(); + + switch (Opc) { + case ARMISD::ASRL: + case ARMISD::LSRL: { + // If this is result 0 and the other result is unused, see if the demand + // bits allow us to shrink this long shift into a standard small shift in + // the opposite direction. 
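
The demanded-bits case that follows shrinks an ARMISD::LSRL whose low result is only consumed in its top ShAmt bits into a plain SHL of the high input word by 32 - ShAmt. A quick standalone check of the underlying identity, assuming the first two operands are the low and high words of the 64-bit value (illustrative only, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t Lo = 0x12345678, Hi = 0x9ABCDEF0;
      for (unsigned ShAmt = 1; ShAmt < 32; ++ShAmt) {
        uint64_t Wide = ((uint64_t)Hi << 32) | Lo;
        uint32_t ResLo = (uint32_t)(Wide >> ShAmt); // low result of long shift
        uint32_t TopBits = ~0u << (32 - ShAmt);     // only the top ShAmt bits
        // On those bits the low result equals the high word shifted left.
        assert((ResLo & TopBits) == ((Hi << (32 - ShAmt)) & TopBits));
      }
      return 0;
    }
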
+ if (Op.getResNo() == 0 && !Op->hasAnyUseOfValue(1) && + isa<ConstantSDNode>(Op->getOperand(2))) { + unsigned ShAmt = Op->getConstantOperandVal(2); + if (ShAmt < 32 && OriginalDemandedBits.isSubsetOf(APInt::getAllOnes(32) + << (32 - ShAmt))) + return TLO.CombineTo( + Op, TLO.DAG.getNode( + ISD::SHL, SDLoc(Op), MVT::i32, Op.getOperand(1), + TLO.DAG.getConstant(32 - ShAmt, SDLoc(Op), MVT::i32))); + } + break; + } + case ARMISD::VBICIMM: { + SDValue Op0 = Op.getOperand(0); + unsigned ModImm = Op.getConstantOperandVal(1); + unsigned EltBits = 0; + uint64_t Mask = ARM_AM::decodeVMOVModImm(ModImm, EltBits); + if ((OriginalDemandedBits & Mask) == 0) + return TLO.CombineTo(Op, Op0); + } + } + + return TargetLowering::SimplifyDemandedBitsForTargetNode( + Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth); +} //===----------------------------------------------------------------------===// // ARM Inline Assembly Support @@ -15835,7 +20085,7 @@ bool ARMTargetLowering::ExpandInlineAsm(CallInst *CI) const { if (!Subtarget->hasV6Ops()) return false; - InlineAsm *IA = cast<InlineAsm>(CI->getCalledValue()); + InlineAsm *IA = cast<InlineAsm>(CI->getCalledOperand()); std::string AsmStr = IA->getAsmString(); SmallVector<StringRef, 4> AsmPieces; SplitString(AsmStr, AsmPieces, ";\n"); @@ -15843,7 +20093,7 @@ bool ARMTargetLowering::ExpandInlineAsm(CallInst *CI) const { switch (AsmPieces.size()) { default: return false; case 1: - AsmStr = AsmPieces[0]; + AsmStr = std::string(AsmPieces[0]); AsmPieces.clear(); SplitString(AsmStr, AsmPieces, " \t,"); @@ -15967,6 +20217,8 @@ RCPair ARMTargetLowering::getRegForInlineAsmConstraint( case 'w': if (VT == MVT::Other) break; + if (VT == MVT::f16 || VT == MVT::bf16) + return RCPair(0U, &ARM::HPRRegClass); if (VT == MVT::f32) return RCPair(0U, &ARM::SPRRegClass); if (VT.getSizeInBits() == 64) @@ -15987,6 +20239,8 @@ RCPair ARMTargetLowering::getRegForInlineAsmConstraint( case 't': if (VT == MVT::Other) break; + if (VT == MVT::f16 || VT == MVT::bf16) + return RCPair(0U, &ARM::HPRRegClass); if (VT == MVT::f32 || VT == MVT::i32) return RCPair(0U, &ARM::SPRRegClass); if (VT.getSizeInBits() == 64) @@ -16014,7 +20268,7 @@ RCPair ARMTargetLowering::getRegForInlineAsmConstraint( break; } - if (StringRef("{cc}").equals_lower(Constraint)) + if (StringRef("{cc}").equals_insensitive(Constraint)) return std::make_pair(unsigned(ARM::CPSR), &ARM::CCRRegClass); return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT); @@ -16238,9 +20492,22 @@ SDValue ARMTargetLowering::LowerDivRem(SDValue Op, SelectionDAG &DAG) const { "Invalid opcode for Div/Rem lowering"); bool isSigned = (Opcode == ISD::SDIVREM); EVT VT = Op->getValueType(0); - Type *Ty = VT.getTypeForEVT(*DAG.getContext()); SDLoc dl(Op); + if (VT == MVT::i64 && isa<ConstantSDNode>(Op.getOperand(1))) { + SmallVector<SDValue> Result; + if (expandDIVREMByConstant(Op.getNode(), Result, MVT::i32, DAG)) { + SDValue Res0 = + DAG.getNode(ISD::BUILD_PAIR, dl, VT, Result[0], Result[1]); + SDValue Res1 = + DAG.getNode(ISD::BUILD_PAIR, dl, VT, Result[2], Result[3]); + return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), + {Res0, Res1}); + } + } + + Type *Ty = VT.getTypeForEVT(*DAG.getContext()); + // If the target has hardware divide, use divide + multiply + subtract: // div = a / b // rem = a - b * div @@ -16289,11 +20556,20 @@ SDValue ARMTargetLowering::LowerDivRem(SDValue Op, SelectionDAG &DAG) const { // Lowers REM using divmod helpers // see RTABI section 4.2/4.3 SDValue ARMTargetLowering::LowerREM(SDNode 
*N, SelectionDAG &DAG) const { + EVT VT = N->getValueType(0); + + if (VT == MVT::i64 && isa<ConstantSDNode>(N->getOperand(1))) { + SmallVector<SDValue> Result; + if (expandDIVREMByConstant(N, Result, MVT::i32, DAG)) + return DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), N->getValueType(0), + Result[0], Result[1]); + } + // Build return types (div and rem) std::vector<Type*> RetTyParams; Type *RetTyElement; - switch (N->getValueType(0).getSimpleVT().SimpleTy) { + switch (VT.getSimpleVT().SimpleTy) { default: llvm_unreachable("Unexpected request for libcall!"); case MVT::i8: RetTyElement = Type::getInt8Ty(*DAG.getContext()); break; case MVT::i16: RetTyElement = Type::getInt16Ty(*DAG.getContext()); break; @@ -16342,13 +20618,15 @@ ARMTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const if (DAG.getMachineFunction().getFunction().hasFnAttribute( "no-stack-arg-probe")) { - unsigned Align = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue(); + MaybeAlign Align = + cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue(); SDValue SP = DAG.getCopyFromReg(Chain, DL, ARM::SP, MVT::i32); Chain = SP.getValue(1); SP = DAG.getNode(ISD::SUB, DL, MVT::i32, SP, Size); if (Align) - SP = DAG.getNode(ISD::AND, DL, MVT::i32, SP.getValue(0), - DAG.getConstant(-(uint64_t)Align, DL, MVT::i32)); + SP = + DAG.getNode(ISD::AND, DL, MVT::i32, SP.getValue(0), + DAG.getConstant(-(uint64_t)Align->value(), DL, MVT::i32)); Chain = DAG.getCopyToReg(Chain, DL, ARM::SP, SP); SDValue Ops[2] = { SP, Chain }; return DAG.getMergeValues(Ops, DL); @@ -16463,38 +20741,6 @@ SDValue ARMTargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const { return IsStrict ? DAG.getMergeValues({Result, Chain}, Loc) : Result; } -void ARMTargetLowering::lowerABS(SDNode *N, SmallVectorImpl<SDValue> &Results, - SelectionDAG &DAG) const { - assert(N->getValueType(0) == MVT::i64 && "Unexpected type (!= i64) on ABS."); - MVT HalfT = MVT::i32; - SDLoc dl(N); - SDValue Hi, Lo, Tmp; - - if (!isOperationLegalOrCustom(ISD::ADDCARRY, HalfT) || - !isOperationLegalOrCustom(ISD::UADDO, HalfT)) - return ; - - unsigned OpTypeBits = HalfT.getScalarSizeInBits(); - SDVTList VTList = DAG.getVTList(HalfT, MVT::i1); - - Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HalfT, N->getOperand(0), - DAG.getConstant(0, dl, HalfT)); - Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HalfT, N->getOperand(0), - DAG.getConstant(1, dl, HalfT)); - - Tmp = DAG.getNode(ISD::SRA, dl, HalfT, Hi, - DAG.getConstant(OpTypeBits - 1, dl, - getShiftAmountTy(HalfT, DAG.getDataLayout()))); - Lo = DAG.getNode(ISD::UADDO, dl, VTList, Tmp, Lo); - Hi = DAG.getNode(ISD::ADDCARRY, dl, VTList, Tmp, Hi, - SDValue(Lo.getNode(), 1)); - Hi = DAG.getNode(ISD::XOR, dl, HalfT, Tmp, Hi); - Lo = DAG.getNode(ISD::XOR, dl, HalfT, Tmp, Lo); - - Results.push_back(Lo); - Results.push_back(Hi); -} - bool ARMTargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const { // The ARM target isn't yet aware of offsets. 
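
The "no-stack-arg-probe" path of LowerDYNAMIC_STACKALLOC above aligns the decremented SP by ANDing it with the negated alignment. For a power-of-two alignment, -(uint64_t)Align->value() is the same mask as ~(Align - 1), so the AND rounds the address down to the next aligned boundary. A small standalone check of that (illustrative, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t Align = 1; Align <= 4096; Align <<= 1) {
        uint32_t SP = 0x7fff1234;        // arbitrary, possibly unaligned SP
        uint32_t NegMask = 0u - Align;   // two's-complement negation
        uint32_t Down = SP & NegMask;
        assert(NegMask == ~(Align - 1)); // same mask, two spellings
        assert(Down % Align == 0 && SP - Down < Align); // rounded down minimally
      }
      return 0;
    }
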
@@ -16519,6 +20765,9 @@ bool ARMTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, return false; if (VT == MVT::f16 && Subtarget->hasFullFP16()) return ARM_AM::getFP16Imm(Imm) != -1; + if (VT == MVT::f32 && Subtarget->hasFullFP16() && + ARM_AM::getFP32FP16Imm(Imm) != -1) + return true; if (VT == MVT::f32) return ARM_AM::getFP32Imm(Imm) != -1; if (VT == MVT::f64 && Subtarget->hasFP64()) @@ -16551,8 +20800,8 @@ bool ARMTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts); Info.ptrVal = I.getArgOperand(0); Info.offset = 0; - Value *AlignArg = I.getArgOperand(I.getNumArgOperands() - 1); - Info.align = MaybeAlign(cast<ConstantInt>(AlignArg)->getZExtValue()); + Value *AlignArg = I.getArgOperand(I.arg_size() - 1); + Info.align = cast<ConstantInt>(AlignArg)->getMaybeAlignValue(); // volatile loads with NEON intrinsics not supported Info.flags = MachineMemOperand::MOLoad; return true; @@ -16565,7 +20814,7 @@ bool ARMTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, auto &DL = I.getCalledFunction()->getParent()->getDataLayout(); uint64_t NumElts = DL.getTypeSizeInBits(I.getType()) / 64; Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts); - Info.ptrVal = I.getArgOperand(I.getNumArgOperands() - 1); + Info.ptrVal = I.getArgOperand(I.arg_size() - 1); Info.offset = 0; Info.align.reset(); // volatile loads with NEON intrinsics not supported @@ -16583,7 +20832,7 @@ bool ARMTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, // Conservatively set memVT to the entire set of vectors stored. auto &DL = I.getCalledFunction()->getParent()->getDataLayout(); unsigned NumElts = 0; - for (unsigned ArgI = 1, ArgE = I.getNumArgOperands(); ArgI < ArgE; ++ArgI) { + for (unsigned ArgI = 1, ArgE = I.arg_size(); ArgI < ArgE; ++ArgI) { Type *ArgTy = I.getArgOperand(ArgI)->getType(); if (!ArgTy->isVectorTy()) break; @@ -16592,8 +20841,8 @@ bool ARMTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts); Info.ptrVal = I.getArgOperand(0); Info.offset = 0; - Value *AlignArg = I.getArgOperand(I.getNumArgOperands() - 1); - Info.align = MaybeAlign(cast<ConstantInt>(AlignArg)->getZExtValue()); + Value *AlignArg = I.getArgOperand(I.arg_size() - 1); + Info.align = cast<ConstantInt>(AlignArg)->getMaybeAlignValue(); // volatile stores with NEON intrinsics not supported Info.flags = MachineMemOperand::MOStore; return true; @@ -16605,7 +20854,7 @@ bool ARMTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, // Conservatively set memVT to the entire set of vectors stored. auto &DL = I.getCalledFunction()->getParent()->getDataLayout(); unsigned NumElts = 0; - for (unsigned ArgI = 1, ArgE = I.getNumArgOperands(); ArgI < ArgE; ++ArgI) { + for (unsigned ArgI = 1, ArgE = I.arg_size(); ArgI < ArgE; ++ArgI) { Type *ArgTy = I.getArgOperand(ArgI)->getType(); if (!ArgTy->isVectorTy()) break; @@ -16619,27 +20868,115 @@ bool ARMTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.flags = MachineMemOperand::MOStore; return true; } + case Intrinsic::arm_mve_vld2q: + case Intrinsic::arm_mve_vld4q: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + // Conservatively set memVT to the entire set of vectors loaded. + Type *VecTy = cast<StructType>(I.getType())->getElementType(1); + unsigned Factor = Intrinsic == Intrinsic::arm_mve_vld2q ? 
2 : 4; + Info.memVT = EVT::getVectorVT(VecTy->getContext(), MVT::i64, Factor * 2); + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.align = Align(VecTy->getScalarSizeInBits() / 8); + // volatile loads with MVE intrinsics not supported + Info.flags = MachineMemOperand::MOLoad; + return true; + } + case Intrinsic::arm_mve_vst2q: + case Intrinsic::arm_mve_vst4q: { + Info.opc = ISD::INTRINSIC_VOID; + // Conservatively set memVT to the entire set of vectors stored. + Type *VecTy = I.getArgOperand(1)->getType(); + unsigned Factor = Intrinsic == Intrinsic::arm_mve_vst2q ? 2 : 4; + Info.memVT = EVT::getVectorVT(VecTy->getContext(), MVT::i64, Factor * 2); + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.align = Align(VecTy->getScalarSizeInBits() / 8); + // volatile stores with MVE intrinsics not supported + Info.flags = MachineMemOperand::MOStore; + return true; + } + case Intrinsic::arm_mve_vldr_gather_base: + case Intrinsic::arm_mve_vldr_gather_base_predicated: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.ptrVal = nullptr; + Info.memVT = MVT::getVT(I.getType()); + Info.align = Align(1); + Info.flags |= MachineMemOperand::MOLoad; + return true; + } + case Intrinsic::arm_mve_vldr_gather_base_wb: + case Intrinsic::arm_mve_vldr_gather_base_wb_predicated: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.ptrVal = nullptr; + Info.memVT = MVT::getVT(I.getType()->getContainedType(0)); + Info.align = Align(1); + Info.flags |= MachineMemOperand::MOLoad; + return true; + } + case Intrinsic::arm_mve_vldr_gather_offset: + case Intrinsic::arm_mve_vldr_gather_offset_predicated: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.ptrVal = nullptr; + MVT DataVT = MVT::getVT(I.getType()); + unsigned MemSize = cast<ConstantInt>(I.getArgOperand(2))->getZExtValue(); + Info.memVT = MVT::getVectorVT(MVT::getIntegerVT(MemSize), + DataVT.getVectorNumElements()); + Info.align = Align(1); + Info.flags |= MachineMemOperand::MOLoad; + return true; + } + case Intrinsic::arm_mve_vstr_scatter_base: + case Intrinsic::arm_mve_vstr_scatter_base_predicated: { + Info.opc = ISD::INTRINSIC_VOID; + Info.ptrVal = nullptr; + Info.memVT = MVT::getVT(I.getArgOperand(2)->getType()); + Info.align = Align(1); + Info.flags |= MachineMemOperand::MOStore; + return true; + } + case Intrinsic::arm_mve_vstr_scatter_base_wb: + case Intrinsic::arm_mve_vstr_scatter_base_wb_predicated: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.ptrVal = nullptr; + Info.memVT = MVT::getVT(I.getArgOperand(2)->getType()); + Info.align = Align(1); + Info.flags |= MachineMemOperand::MOStore; + return true; + } + case Intrinsic::arm_mve_vstr_scatter_offset: + case Intrinsic::arm_mve_vstr_scatter_offset_predicated: { + Info.opc = ISD::INTRINSIC_VOID; + Info.ptrVal = nullptr; + MVT DataVT = MVT::getVT(I.getArgOperand(2)->getType()); + unsigned MemSize = cast<ConstantInt>(I.getArgOperand(3))->getZExtValue(); + Info.memVT = MVT::getVectorVT(MVT::getIntegerVT(MemSize), + DataVT.getVectorNumElements()); + Info.align = Align(1); + Info.flags |= MachineMemOperand::MOStore; + return true; + } case Intrinsic::arm_ldaex: case Intrinsic::arm_ldrex: { auto &DL = I.getCalledFunction()->getParent()->getDataLayout(); - PointerType *PtrTy = cast<PointerType>(I.getArgOperand(0)->getType()); + Type *ValTy = I.getParamElementType(0); Info.opc = ISD::INTRINSIC_W_CHAIN; - Info.memVT = MVT::getVT(PtrTy->getElementType()); + Info.memVT = MVT::getVT(ValTy); Info.ptrVal = I.getArgOperand(0); Info.offset = 0; - Info.align = 
MaybeAlign(DL.getABITypeAlignment(PtrTy->getElementType())); + Info.align = DL.getABITypeAlign(ValTy); Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile; return true; } case Intrinsic::arm_stlex: case Intrinsic::arm_strex: { auto &DL = I.getCalledFunction()->getParent()->getDataLayout(); - PointerType *PtrTy = cast<PointerType>(I.getArgOperand(1)->getType()); + Type *ValTy = I.getParamElementType(1); Info.opc = ISD::INTRINSIC_W_CHAIN; - Info.memVT = MVT::getVT(PtrTy->getElementType()); + Info.memVT = MVT::getVT(ValTy); Info.ptrVal = I.getArgOperand(1); Info.offset = 0; - Info.align = MaybeAlign(DL.getABITypeAlignment(PtrTy->getElementType())); + Info.align = DL.getABITypeAlign(ValTy); Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile; return true; } @@ -16690,7 +21027,7 @@ bool ARMTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, return (Index == 0 || Index == ResVT.getVectorNumElements()); } -Instruction* ARMTargetLowering::makeDMB(IRBuilder<> &Builder, +Instruction *ARMTargetLowering::makeDMB(IRBuilderBase &Builder, ARM_MB::MemBOpt Domain) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); @@ -16720,7 +21057,7 @@ Instruction* ARMTargetLowering::makeDMB(IRBuilder<> &Builder, } // Based on http://www.cl.cam.ac.uk/~pes20/cpp/cpp0xmappings.html -Instruction *ARMTargetLowering::emitLeadingFence(IRBuilder<> &Builder, +Instruction *ARMTargetLowering::emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const { switch (Ord) { @@ -16733,7 +21070,7 @@ Instruction *ARMTargetLowering::emitLeadingFence(IRBuilder<> &Builder, case AtomicOrdering::SequentiallyConsistent: if (!Inst->hasAtomicStore()) return nullptr; // Nothing to do - LLVM_FALLTHROUGH; + [[fallthrough]]; case AtomicOrdering::Release: case AtomicOrdering::AcquireRelease: if (Subtarget->preferISHSTBarriers()) @@ -16745,7 +21082,7 @@ Instruction *ARMTargetLowering::emitLeadingFence(IRBuilder<> &Builder, llvm_unreachable("Unknown fence ordering in emitLeadingFence"); } -Instruction *ARMTargetLowering::emitTrailingFence(IRBuilder<> &Builder, +Instruction *ARMTargetLowering::emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst, AtomicOrdering Ord) const { switch (Ord) { @@ -16767,9 +21104,19 @@ Instruction *ARMTargetLowering::emitTrailingFence(IRBuilder<> &Builder, // are doomed anyway, so defer to the default libcall and blame the OS when // things go wrong. Cortex M doesn't have ldrexd/strexd though, so don't emit // anything for those. -bool ARMTargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const { +TargetLoweringBase::AtomicExpansionKind +ARMTargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const { + bool has64BitAtomicStore; + if (Subtarget->isMClass()) + has64BitAtomicStore = false; + else if (Subtarget->isThumb()) + has64BitAtomicStore = Subtarget->hasV7Ops(); + else + has64BitAtomicStore = Subtarget->hasV6Ops(); + unsigned Size = SI->getValueOperand()->getType()->getPrimitiveSizeInBits(); - return (Size == 64) && !Subtarget->isMClass(); + return Size == 64 && has64BitAtomicStore ? 
AtomicExpansionKind::Expand + : AtomicExpansionKind::None; } // Loads and stores less than 64-bits are already atomic; ones above that @@ -16781,9 +21128,17 @@ bool ARMTargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const { // sections A8.8.72-74 LDRD) TargetLowering::AtomicExpansionKind ARMTargetLowering::shouldExpandAtomicLoadInIR(LoadInst *LI) const { + bool has64BitAtomicLoad; + if (Subtarget->isMClass()) + has64BitAtomicLoad = false; + else if (Subtarget->isThumb()) + has64BitAtomicLoad = Subtarget->hasV7Ops(); + else + has64BitAtomicLoad = Subtarget->hasV6Ops(); + unsigned Size = LI->getType()->getPrimitiveSizeInBits(); - return ((Size == 64) && !Subtarget->isMClass()) ? AtomicExpansionKind::LLOnly - : AtomicExpansionKind::None; + return (Size == 64 && has64BitAtomicLoad) ? AtomicExpansionKind::LLOnly + : AtomicExpansionKind::None; } // For the real atomic operations, we have ldrex/strex up to 32 bits, @@ -16794,12 +21149,28 @@ ARMTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { return AtomicExpansionKind::CmpXChg; unsigned Size = AI->getType()->getPrimitiveSizeInBits(); - bool hasAtomicRMW = !Subtarget->isThumb() || Subtarget->hasV8MBaselineOps(); - return (Size <= (Subtarget->isMClass() ? 32U : 64U) && hasAtomicRMW) - ? AtomicExpansionKind::LLSC - : AtomicExpansionKind::None; + bool hasAtomicRMW; + if (Subtarget->isMClass()) + hasAtomicRMW = Subtarget->hasV8MBaselineOps(); + else if (Subtarget->isThumb()) + hasAtomicRMW = Subtarget->hasV7Ops(); + else + hasAtomicRMW = Subtarget->hasV6Ops(); + if (Size <= (Subtarget->isMClass() ? 32U : 64U) && hasAtomicRMW) { + // At -O0, fast-regalloc cannot cope with the live vregs necessary to + // implement atomicrmw without spilling. If the target address is also on + // the stack and close enough to the spill slot, this can lead to a + // situation where the monitor always gets cleared and the atomic operation + // can never succeed. So at -O0 lower this operation to a CAS loop. + if (getTargetMachine().getOptLevel() == CodeGenOpt::None) + return AtomicExpansionKind::CmpXChg; + return AtomicExpansionKind::LLSC; + } + return AtomicExpansionKind::None; } +// Similar to shouldExpandAtomicRMWInIR, ldrex/strex can be used up to 32 +// bits, and up to 64 bits on the non-M profiles. TargetLowering::AtomicExpansionKind ARMTargetLowering::shouldExpandAtomicCmpXchgInIR(AtomicCmpXchgInst *AI) const { // At -O0, fast-regalloc cannot cope with the live vregs necessary to @@ -16807,9 +21178,16 @@ ARMTargetLowering::shouldExpandAtomicCmpXchgInIR(AtomicCmpXchgInst *AI) const { // on the stack and close enough to the spill slot, this can lead to a // situation where the monitor always gets cleared and the atomic operation // can never succeed. So at -O0 we need a late-expanded pseudo-inst instead. - bool HasAtomicCmpXchg = - !Subtarget->isThumb() || Subtarget->hasV8MBaselineOps(); - if (getTargetMachine().getOptLevel() != 0 && HasAtomicCmpXchg) + unsigned Size = AI->getOperand(1)->getType()->getPrimitiveSizeInBits(); + bool HasAtomicCmpXchg; + if (Subtarget->isMClass()) + HasAtomicCmpXchg = Subtarget->hasV8MBaselineOps(); + else if (Subtarget->isThumb()) + HasAtomicCmpXchg = Subtarget->hasV7Ops(); + else + HasAtomicCmpXchg = Subtarget->hasV6Ops(); + if (getTargetMachine().getOptLevel() != 0 && HasAtomicCmpXchg && + Size <= (Subtarget->isMClass() ? 
32U : 64U)) return AtomicExpansionKind::LLSC; return AtomicExpansionKind::None; } @@ -16819,9 +21197,11 @@ bool ARMTargetLowering::shouldInsertFencesForAtomic( return InsertFencesForAtomic; } -// This has so far only been implemented for MachO. bool ARMTargetLowering::useLoadStackGuardNode() const { - return Subtarget->isTargetMachO(); + if (Subtarget->getTargetTriple().isOSOpenBSD()) + return false; + // ROPI/RWPI are not supported currently. + return !Subtarget->isROPI() && !Subtarget->isRWPI(); } void ARMTargetLowering::insertSSPDeclarations(Module &M) const { @@ -16837,7 +21217,7 @@ void ARMTargetLowering::insertSSPDeclarations(Module &M) const { "__security_check_cookie", Type::getVoidTy(M.getContext()), Type::getInt8PtrTy(M.getContext())); if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) - F->addAttribute(1, Attribute::AttrKind::InReg); + F->addParamAttr(0, Attribute::AttrKind::InReg); } Value *ARMTargetLowering::getSDagStackGuard(const Module &M) const { @@ -16873,7 +21253,7 @@ bool ARMTargetLowering::canCombineStoreAndExtract(Type *VectorTy, Value *Idx, return false; assert(VectorTy->isVectorTy() && "VectorTy is not a vector type"); - unsigned BitWidth = cast<VectorType>(VectorTy)->getBitWidth(); + unsigned BitWidth = VectorTy->getPrimitiveSizeInBits().getFixedValue(); // We can do a store + vector extract on any vector that fits perfectly in a D // or Q register. if (BitWidth == 64 || BitWidth == 128) { @@ -16883,28 +21263,48 @@ bool ARMTargetLowering::canCombineStoreAndExtract(Type *VectorTy, Value *Idx, return false; } -bool ARMTargetLowering::isCheapToSpeculateCttz() const { +bool ARMTargetLowering::isCheapToSpeculateCttz(Type *Ty) const { return Subtarget->hasV6T2Ops(); } -bool ARMTargetLowering::isCheapToSpeculateCtlz() const { +bool ARMTargetLowering::isCheapToSpeculateCtlz(Type *Ty) const { return Subtarget->hasV6T2Ops(); } -bool ARMTargetLowering::shouldExpandShift(SelectionDAG &DAG, SDNode *N) const { - return !Subtarget->hasMinSize() || Subtarget->isTargetWindows(); +bool ARMTargetLowering::isMaskAndCmp0FoldingBeneficial( + const Instruction &AndI) const { + if (!Subtarget->hasV7Ops()) + return false; + + // Sink the `and` instruction only if the mask would fit into a modified + // immediate operand. + ConstantInt *Mask = dyn_cast<ConstantInt>(AndI.getOperand(1)); + if (!Mask || Mask->getValue().getBitWidth() > 32u) + return false; + auto MaskVal = unsigned(Mask->getValue().getZExtValue()); + return (Subtarget->isThumb2() ? ARM_AM::getT2SOImmVal(MaskVal) + : ARM_AM::getSOImmVal(MaskVal)) != -1; } -Value *ARMTargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr, +TargetLowering::ShiftLegalizationStrategy +ARMTargetLowering::preferredShiftLegalizationStrategy( + SelectionDAG &DAG, SDNode *N, unsigned ExpansionFactor) const { + if (Subtarget->hasMinSize() && !Subtarget->isTargetWindows()) + return ShiftLegalizationStrategy::LowerToLibcall; + return TargetLowering::preferredShiftLegalizationStrategy(DAG, N, + ExpansionFactor); +} + +Value *ARMTargetLowering::emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, + Value *Addr, AtomicOrdering Ord) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); - Type *ValTy = cast<PointerType>(Addr->getType())->getElementType(); bool IsAcquire = isAcquireOrStronger(Ord); // Since i64 isn't legal and intrinsics don't get type-lowered, the ldrexd // intrinsic must return {i32, i32} and we have to recombine them into a // single i64 here. 
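
The comment closing the hunk above describes emitLoadLinked's 64-bit path: ldrexd/ldaexd return two i32 halves, which the code in the next hunk zero-extends, shifts, and ORs back together, swapping the pair first on big-endian targets. The scalar equivalent, as a sketch rather than anything from the patch itself:

    #include <cassert>
    #include <cstdint>
    #include <utility>

    // Recombine the {lo, hi} pair produced by a 64-bit exclusive load.
    static uint64_t combineHalves(uint32_t Lo, uint32_t Hi,
                                  bool IsLittleEndian) {
      if (!IsLittleEndian)
        std::swap(Lo, Hi);               // the halves arrive swapped on BE
      return ((uint64_t)Hi << 32) | (uint64_t)Lo;
    }

    int main() {
      assert(combineHalves(0x89abcdefu, 0x01234567u, true) ==
             0x0123456789abcdefULL);
      assert(combineHalves(0x01234567u, 0x89abcdefu, false) ==
             0x0123456789abcdefULL);
      return 0;
    }
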
- if (ValTy->getPrimitiveSizeInBits() == 64) { + if (ValueTy->getPrimitiveSizeInBits() == 64) { Intrinsic::ID Int = IsAcquire ? Intrinsic::arm_ldaexd : Intrinsic::arm_ldrexd; Function *Ldrex = Intrinsic::getDeclaration(M, Int); @@ -16916,31 +21316,32 @@ Value *ARMTargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr, Value *Hi = Builder.CreateExtractValue(LoHi, 1, "hi"); if (!Subtarget->isLittle()) std::swap (Lo, Hi); - Lo = Builder.CreateZExt(Lo, ValTy, "lo64"); - Hi = Builder.CreateZExt(Hi, ValTy, "hi64"); + Lo = Builder.CreateZExt(Lo, ValueTy, "lo64"); + Hi = Builder.CreateZExt(Hi, ValueTy, "hi64"); return Builder.CreateOr( - Lo, Builder.CreateShl(Hi, ConstantInt::get(ValTy, 32)), "val64"); + Lo, Builder.CreateShl(Hi, ConstantInt::get(ValueTy, 32)), "val64"); } Type *Tys[] = { Addr->getType() }; Intrinsic::ID Int = IsAcquire ? Intrinsic::arm_ldaex : Intrinsic::arm_ldrex; Function *Ldrex = Intrinsic::getDeclaration(M, Int, Tys); + CallInst *CI = Builder.CreateCall(Ldrex, Addr); - return Builder.CreateTruncOrBitCast( - Builder.CreateCall(Ldrex, Addr), - cast<PointerType>(Addr->getType())->getElementType()); + CI->addParamAttr( + 0, Attribute::get(M->getContext(), Attribute::ElementType, ValueTy)); + return Builder.CreateTruncOrBitCast(CI, ValueTy); } void ARMTargetLowering::emitAtomicCmpXchgNoStoreLLBalance( - IRBuilder<> &Builder) const { + IRBuilderBase &Builder) const { if (!Subtarget->hasV7Ops()) return; Module *M = Builder.GetInsertBlock()->getParent()->getParent(); Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::arm_clrex)); } -Value *ARMTargetLowering::emitStoreConditional(IRBuilder<> &Builder, Value *Val, - Value *Addr, +Value *ARMTargetLowering::emitStoreConditional(IRBuilderBase &Builder, + Value *Val, Value *Addr, AtomicOrdering Ord) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); bool IsRelease = isReleaseOrStronger(Ord); @@ -16966,10 +21367,13 @@ Value *ARMTargetLowering::emitStoreConditional(IRBuilder<> &Builder, Value *Val, Type *Tys[] = { Addr->getType() }; Function *Strex = Intrinsic::getDeclaration(M, Int, Tys); - return Builder.CreateCall( + CallInst *CI = Builder.CreateCall( Strex, {Builder.CreateZExtOrBitCast( Val, Strex->getFunctionType()->getParamType(0)), Addr}); + CI->addParamAttr(1, Attribute::get(M->getContext(), Attribute::ElementType, + Val->getType())); + return CI; } @@ -16986,7 +21390,8 @@ ARMTargetLowering::getNumInterleavedAccesses(VectorType *VecTy, } bool ARMTargetLowering::isLegalInterleavedAccessType( - unsigned Factor, VectorType *VecTy, const DataLayout &DL) const { + unsigned Factor, FixedVectorType *VecTy, Align Alignment, + const DataLayout &DL) const { unsigned VecSize = DL.getTypeSizeInBits(VecTy); unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType()); @@ -17009,6 +21414,9 @@ bool ARMTargetLowering::isLegalInterleavedAccessType( // Ensure the element type is legal. if (ElSize != 8 && ElSize != 16 && ElSize != 32) return false; + // And the alignment if high enough under MVE. + if (Subtarget->hasMVEIntegerOps() && Alignment < ElSize / 8) + return false; // Ensure the total vector size is 64 or a multiple of 128. Types larger than // 128 will be split into multiple interleaved accesses. 
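
isLegalInterleavedAccessType in the hunk above gates NEON/MVE vldN/vstN lowering on the element size, on the newly added alignment requirement for MVE, and, per the trailing comment, on the total vector size being 64 bits or a multiple of 128. A standalone restatement of those checks, as a simplified sketch only (the real hook also consults the subtarget and relies on splitting oversized types):

    #include <cstdio>

    // Simplified model of the legality test for one interleaved access leg
    // with NumElts elements of ElSizeBits bits each, given alignment in bytes.
    static bool isLegalInterleavedAccess(unsigned ElSizeBits, unsigned NumElts,
                                         unsigned AlignBytes, bool HasMVE) {
      if (ElSizeBits != 8 && ElSizeBits != 16 && ElSizeBits != 32)
        return false;                              // element type must be legal
      if (HasMVE && AlignBytes < ElSizeBits / 8)
        return false;                              // MVE wants element alignment
      unsigned VecSizeBits = ElSizeBits * NumElts;
      return VecSizeBits == 64 || VecSizeBits % 128 == 0;
    }

    int main() {
      std::printf("%d %d %d\n",
                  isLegalInterleavedAccess(32, 4, 4, true),   // v4i32, aligned -> 1
                  isLegalInterleavedAccess(16, 8, 1, true),   // v8i16, align 1 -> 0
                  isLegalInterleavedAccess(64, 2, 8, false)); // i64 elements   -> 0
      return 0;
    }
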
@@ -17045,15 +21453,16 @@ bool ARMTargetLowering::lowerInterleavedLoad( assert(Shuffles.size() == Indices.size() && "Unmatched number of shufflevectors and indices"); - VectorType *VecTy = Shuffles[0]->getType(); - Type *EltTy = VecTy->getVectorElementType(); + auto *VecTy = cast<FixedVectorType>(Shuffles[0]->getType()); + Type *EltTy = VecTy->getElementType(); const DataLayout &DL = LI->getModule()->getDataLayout(); + Align Alignment = LI->getAlign(); // Skip if we do not have NEON and skip illegal vector types. We can // "legalize" wide vector types into multiple interleaved accesses as long as // the vector types are divisible by 128. - if (!isLegalInterleavedAccessType(Factor, VecTy, DL)) + if (!isLegalInterleavedAccessType(Factor, VecTy, Alignment, DL)) return false; unsigned NumLoads = getNumInterleavedAccesses(VecTy, DL); @@ -17061,8 +21470,7 @@ bool ARMTargetLowering::lowerInterleavedLoad( // A pointer vector can not be the return type of the ldN intrinsics. Need to // load integer vectors first and then convert to pointer vectors. if (EltTy->isPointerTy()) - VecTy = - VectorType::get(DL.getIntPtrType(EltTy), VecTy->getVectorNumElements()); + VecTy = FixedVectorType::get(DL.getIntPtrType(EltTy), VecTy); IRBuilder<> Builder(LI); @@ -17072,15 +21480,15 @@ bool ARMTargetLowering::lowerInterleavedLoad( if (NumLoads > 1) { // If we're going to generate more than one load, reset the sub-vector type // to something legal. - VecTy = VectorType::get(VecTy->getVectorElementType(), - VecTy->getVectorNumElements() / NumLoads); + VecTy = FixedVectorType::get(VecTy->getElementType(), + VecTy->getNumElements() / NumLoads); // We will compute the pointer operand of each load from the original base // address using GEPs. Cast the base address to a pointer to the scalar // element type. BaseAddr = Builder.CreateBitCast( - BaseAddr, VecTy->getVectorElementType()->getPointerTo( - LI->getPointerAddressSpace())); + BaseAddr, + VecTy->getElementType()->getPointerTo(LI->getPointerAddressSpace())); } assert(isTypeLegal(EVT::getEVT(VecTy)) && "Illegal vldN vector type!"); @@ -17097,7 +21505,7 @@ bool ARMTargetLowering::lowerInterleavedLoad( SmallVector<Value *, 2> Ops; Ops.push_back(Builder.CreateBitCast(BaseAddr, Int8Ptr)); - Ops.push_back(Builder.getInt32(LI->getAlignment())); + Ops.push_back(Builder.getInt32(LI->getAlign().value())); return Builder.CreateCall(VldnFunc, Ops, "vldN"); } else { @@ -17105,8 +21513,8 @@ bool ARMTargetLowering::lowerInterleavedLoad( "expected interleave factor of 2 or 4 for MVE"); Intrinsic::ID LoadInts = Factor == 2 ? Intrinsic::arm_mve_vld2q : Intrinsic::arm_mve_vld4q; - Type *VecEltTy = VecTy->getVectorElementType()->getPointerTo( - LI->getPointerAddressSpace()); + Type *VecEltTy = + VecTy->getElementType()->getPointerTo(LI->getPointerAddressSpace()); Type *Tys[] = {VecTy, VecEltTy}; Function *VldnFunc = Intrinsic::getDeclaration(LI->getModule(), LoadInts, Tys); @@ -17126,9 +21534,8 @@ bool ARMTargetLowering::lowerInterleavedLoad( // If we're generating more than one load, compute the base address of // subsequent loads as an offset from the previous. 
if (LoadCount > 0) - BaseAddr = - Builder.CreateConstGEP1_32(VecTy->getVectorElementType(), BaseAddr, - VecTy->getVectorNumElements() * Factor); + BaseAddr = Builder.CreateConstGEP1_32(VecTy->getElementType(), BaseAddr, + VecTy->getNumElements() * Factor); CallInst *VldN = createLoadIntrinsic(BaseAddr); @@ -17143,8 +21550,8 @@ bool ARMTargetLowering::lowerInterleavedLoad( // Convert the integer vector to pointer vector if the element is pointer. if (EltTy->isPointerTy()) SubVec = Builder.CreateIntToPtr( - SubVec, VectorType::get(SV->getType()->getVectorElementType(), - VecTy->getVectorNumElements())); + SubVec, + FixedVectorType::get(SV->getType()->getElementType(), VecTy)); SubVecs[SV].push_back(SubVec); } @@ -17196,20 +21603,20 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && "Invalid interleave factor"); - VectorType *VecTy = SVI->getType(); - assert(VecTy->getVectorNumElements() % Factor == 0 && - "Invalid interleaved store"); + auto *VecTy = cast<FixedVectorType>(SVI->getType()); + assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store"); - unsigned LaneLen = VecTy->getVectorNumElements() / Factor; - Type *EltTy = VecTy->getVectorElementType(); - VectorType *SubVecTy = VectorType::get(EltTy, LaneLen); + unsigned LaneLen = VecTy->getNumElements() / Factor; + Type *EltTy = VecTy->getElementType(); + auto *SubVecTy = FixedVectorType::get(EltTy, LaneLen); const DataLayout &DL = SI->getModule()->getDataLayout(); + Align Alignment = SI->getAlign(); // Skip if we do not have NEON and skip illegal vector types. We can // "legalize" wide vector types into multiple interleaved accesses as long as // the vector types are divisible by 128. - if (!isLegalInterleavedAccessType(Factor, SubVecTy, DL)) + if (!isLegalInterleavedAccessType(Factor, SubVecTy, Alignment, DL)) return false; unsigned NumStores = getNumInterleavedAccesses(SubVecTy, DL); @@ -17224,12 +21631,12 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, Type *IntTy = DL.getIntPtrType(EltTy); // Convert to the corresponding integer vector. - Type *IntVecTy = - VectorType::get(IntTy, Op0->getType()->getVectorNumElements()); + auto *IntVecTy = + FixedVectorType::get(IntTy, cast<FixedVectorType>(Op0->getType())); Op0 = Builder.CreatePtrToInt(Op0, IntVecTy); Op1 = Builder.CreatePtrToInt(Op1, IntVecTy); - SubVecTy = VectorType::get(IntTy, LaneLen); + SubVecTy = FixedVectorType::get(IntTy, LaneLen); } // The base address of the store. @@ -17239,14 +21646,14 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, // If we're going to generate more than one store, reset the lane length // and sub-vector type to something legal. LaneLen /= NumStores; - SubVecTy = VectorType::get(SubVecTy->getVectorElementType(), LaneLen); + SubVecTy = FixedVectorType::get(SubVecTy->getElementType(), LaneLen); // We will compute the pointer operand of each store from the original base // address using GEPs. Cast the base address to a pointer to the scalar // element type. 
BaseAddr = Builder.CreateBitCast( - BaseAddr, SubVecTy->getVectorElementType()->getPointerTo( - SI->getPointerAddressSpace())); + BaseAddr, + SubVecTy->getElementType()->getPointerTo(SI->getPointerAddressSpace())); } assert(isTypeLegal(EVT::getEVT(SubVecTy)) && "Illegal vstN vector type!"); @@ -17267,16 +21674,15 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, SmallVector<Value *, 6> Ops; Ops.push_back(Builder.CreateBitCast(BaseAddr, Int8Ptr)); - for (auto S : Shuffles) - Ops.push_back(S); - Ops.push_back(Builder.getInt32(SI->getAlignment())); + append_range(Ops, Shuffles); + Ops.push_back(Builder.getInt32(SI->getAlign().value())); Builder.CreateCall(VstNFunc, Ops); } else { assert((Factor == 2 || Factor == 4) && "expected interleave factor of 2 or 4 for MVE"); Intrinsic::ID StoreInts = Factor == 2 ? Intrinsic::arm_mve_vst2q : Intrinsic::arm_mve_vst4q; - Type *EltPtrTy = SubVecTy->getVectorElementType()->getPointerTo( + Type *EltPtrTy = SubVecTy->getElementType()->getPointerTo( SI->getPointerAddressSpace()); Type *Tys[] = {EltPtrTy, SubVecTy}; Function *VstNFunc = @@ -17284,8 +21690,7 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, SmallVector<Value *, 6> Ops; Ops.push_back(Builder.CreateBitCast(BaseAddr, EltPtrTy)); - for (auto S : Shuffles) - Ops.push_back(S); + append_range(Ops, Shuffles); for (unsigned F = 0; F < Factor; F++) { Ops.push_back(Builder.getInt32(F)); Builder.CreateCall(VstNFunc, Ops); @@ -17298,7 +21703,7 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, // If we generating more than one store, we compute the base address of // subsequent stores as an offset from the previous. if (StoreCount > 0) - BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getVectorElementType(), + BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getElementType(), BaseAddr, LaneLen * Factor); SmallVector<Value *, 4> Shuffles; @@ -17308,7 +21713,7 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, unsigned IdxI = StoreCount * LaneLen * Factor + i; if (Mask[IdxI] >= 0) { Shuffles.push_back(Builder.CreateShuffleVector( - Op0, Op1, createSequentialMask(Builder, Mask[IdxI], LaneLen, 0))); + Op0, Op1, createSequentialMask(Mask[IdxI], LaneLen, 0))); } else { unsigned StartMask = 0; for (unsigned j = 1; j < LaneLen; j++) { @@ -17325,7 +21730,7 @@ bool ARMTargetLowering::lowerInterleavedStore(StoreInst *SI, // Note: StartMask cannot be negative, it's checked in // isReInterleaveMask Shuffles.push_back(Builder.CreateShuffleVector( - Op0, Op1, createSequentialMask(Builder, StartMask, LaneLen, 0))); + Op0, Op1, createSequentialMask(StartMask, LaneLen, 0))); } } @@ -17373,11 +21778,11 @@ static bool isHomogeneousAggregate(Type *Ty, HABaseType &Base, case HA_DOUBLE: return false; case HA_VECT64: - return VT->getBitWidth() == 64; + return VT->getPrimitiveSizeInBits().getFixedValue() == 64; case HA_VECT128: - return VT->getBitWidth() == 128; + return VT->getPrimitiveSizeInBits().getFixedValue() == 128; case HA_UNKNOWN: - switch (VT->getBitWidth()) { + switch (VT->getPrimitiveSizeInBits().getFixedValue()) { case 64: Base = HA_VECT64; return true; @@ -17394,9 +21799,9 @@ static bool isHomogeneousAggregate(Type *Ty, HABaseType &Base, } /// Return the correct alignment for the current calling convention. 
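
The lowerInterleavedStore changes above drop the IRBuilder argument from createSequentialMask; the helper just produces the run of consecutive lane indices used to slice one lane group out of the interleaved shuffle. A rough standalone equivalent, assuming a (Start, NumInts, NumUndefs) meaning for the arguments (illustrative, not part of the patch):

    #include <cstdio>
    #include <vector>

    // Build {Start, Start+1, ..., Start+NumInts-1} followed by NumUndefs
    // undef (-1) entries.
    static std::vector<int> sequentialMask(unsigned Start, unsigned NumInts,
                                           unsigned NumUndefs) {
      std::vector<int> Mask;
      for (unsigned I = 0; I < NumInts; ++I)
        Mask.push_back((int)(Start + I));
      Mask.insert(Mask.end(), NumUndefs, -1);
      return Mask;
    }

    int main() {
      // One lane group of an interleaved store with LaneLen == 4 starting at
      // element 4: selects lanes {4, 5, 6, 7}.
      for (int Idx : sequentialMask(4, 4, 0))
        std::printf("%d ", Idx);
      std::printf("\n");
      return 0;
    }
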
-Align ARMTargetLowering::getABIAlignmentForCallingConv(Type *ArgTy, - DataLayout DL) const { - const Align ABITypeAlign(DL.getABITypeAlignment(ArgTy)); +Align ARMTargetLowering::getABIAlignmentForCallingConv( + Type *ArgTy, const DataLayout &DL) const { + const Align ABITypeAlign = DL.getABITypeAlign(ArgTy); if (!ArgTy->isVectorTy()) return ABITypeAlign; @@ -17409,7 +21814,8 @@ Align ARMTargetLowering::getABIAlignmentForCallingConv(Type *ArgTy, /// [N x i32] or [N x i64]. This allows front-ends to skip emitting padding when /// passing according to AAPCS rules. bool ARMTargetLowering::functionArgumentNeedsConsecutiveRegisters( - Type *Ty, CallingConv::ID CallConv, bool isVarArg) const { + Type *Ty, CallingConv::ID CallConv, bool isVarArg, + const DataLayout &DL) const { if (getEffectiveCallingConv(CallConv, isVarArg) != CallingConv::ARM_AAPCS_VFP) return false; @@ -17423,18 +21829,18 @@ bool ARMTargetLowering::functionArgumentNeedsConsecutiveRegisters( return IsHA || IsIntArray; } -unsigned ARMTargetLowering::getExceptionPointerRegister( +Register ARMTargetLowering::getExceptionPointerRegister( const Constant *PersonalityFn) const { // Platforms which do not use SjLj EH may return values in these registers // via the personality function. - return Subtarget->useSjLjEH() ? ARM::NoRegister : ARM::R0; + return Subtarget->useSjLjEH() ? Register() : ARM::R0; } -unsigned ARMTargetLowering::getExceptionSelectorRegister( +Register ARMTargetLowering::getExceptionSelectorRegister( const Constant *PersonalityFn) const { // Platforms which do not use SjLj EH may return values in these registers // via the personality function. - return Subtarget->useSjLjEH() ? ARM::NoRegister : ARM::R1; + return Subtarget->useSjLjEH() ? Register() : ARM::R1; } void ARMTargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const { @@ -17488,3 +21894,105 @@ void ARMTargetLowering::finalizeLowering(MachineFunction &MF) const { MF.getFrameInfo().computeMaxCallFrameSize(MF); TargetLoweringBase::finalizeLowering(MF); } + +bool ARMTargetLowering::isComplexDeinterleavingSupported() const { + return Subtarget->hasMVEIntegerOps(); +} + +bool ARMTargetLowering::isComplexDeinterleavingOperationSupported( + ComplexDeinterleavingOperation Operation, Type *Ty) const { + auto *VTy = dyn_cast<FixedVectorType>(Ty); + if (!VTy) + return false; + + auto *ScalarTy = VTy->getScalarType(); + unsigned NumElements = VTy->getNumElements(); + + unsigned VTyWidth = VTy->getScalarSizeInBits() * NumElements; + if (VTyWidth < 128 || !llvm::isPowerOf2_32(VTyWidth)) + return false; + + // Both VCADD and VCMUL/VCMLA support the same types, F16 and F32 + if (ScalarTy->isHalfTy() || ScalarTy->isFloatTy()) + return Subtarget->hasMVEFloatOps(); + + if (Operation != ComplexDeinterleavingOperation::CAdd) + return false; + + return Subtarget->hasMVEIntegerOps() && + (ScalarTy->isIntegerTy(8) || ScalarTy->isIntegerTy(16) || + ScalarTy->isIntegerTy(32)); +} + +Value *ARMTargetLowering::createComplexDeinterleavingIR( + Instruction *I, ComplexDeinterleavingOperation OperationType, + ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, + Value *Accumulator) const { + + FixedVectorType *Ty = cast<FixedVectorType>(InputA->getType()); + + IRBuilder<> B(I); + + unsigned TyWidth = Ty->getScalarSizeInBits() * Ty->getNumElements(); + + assert(TyWidth >= 128 && "Width of vector type must be at least 128 bits"); + + if (TyWidth > 128) { + int Stride = Ty->getNumElements() / 2; + auto SplitSeq = llvm::seq<int>(0, Ty->getNumElements()); + auto 
SplitSeqVec = llvm::to_vector(SplitSeq); + ArrayRef<int> LowerSplitMask(&SplitSeqVec[0], Stride); + ArrayRef<int> UpperSplitMask(&SplitSeqVec[Stride], Stride); + + auto *LowerSplitA = B.CreateShuffleVector(InputA, LowerSplitMask); + auto *LowerSplitB = B.CreateShuffleVector(InputB, LowerSplitMask); + auto *UpperSplitA = B.CreateShuffleVector(InputA, UpperSplitMask); + auto *UpperSplitB = B.CreateShuffleVector(InputB, UpperSplitMask); + Value *LowerSplitAcc = nullptr; + Value *UpperSplitAcc = nullptr; + + if (Accumulator) { + LowerSplitAcc = B.CreateShuffleVector(Accumulator, LowerSplitMask); + UpperSplitAcc = B.CreateShuffleVector(Accumulator, UpperSplitMask); + } + + auto *LowerSplitInt = createComplexDeinterleavingIR( + I, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc); + auto *UpperSplitInt = createComplexDeinterleavingIR( + I, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc); + + ArrayRef<int> JoinMask(&SplitSeqVec[0], Ty->getNumElements()); + return B.CreateShuffleVector(LowerSplitInt, UpperSplitInt, JoinMask); + } + + auto *IntTy = Type::getInt32Ty(B.getContext()); + + ConstantInt *ConstRotation = nullptr; + if (OperationType == ComplexDeinterleavingOperation::CMulPartial) { + ConstRotation = ConstantInt::get(IntTy, (int)Rotation); + + if (Accumulator) + return B.CreateIntrinsic(Intrinsic::arm_mve_vcmlaq, Ty, + {ConstRotation, Accumulator, InputB, InputA}); + return B.CreateIntrinsic(Intrinsic::arm_mve_vcmulq, Ty, + {ConstRotation, InputB, InputA}); + } + + if (OperationType == ComplexDeinterleavingOperation::CAdd) { + // 1 means the value is not halved. + auto *ConstHalving = ConstantInt::get(IntTy, 1); + + if (Rotation == ComplexDeinterleavingRotation::Rotation_90) + ConstRotation = ConstantInt::get(IntTy, 0); + else if (Rotation == ComplexDeinterleavingRotation::Rotation_270) + ConstRotation = ConstantInt::get(IntTy, 1); + + if (!ConstRotation) + return nullptr; // Invalid rotation for arm_mve_vcaddq + + return B.CreateIntrinsic(Intrinsic::arm_mve_vcaddq, Ty, + {ConstHalving, ConstRotation, InputA, InputB}); + } + + return nullptr; +} |
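
createComplexDeinterleavingIR above handles vectors wider than 128 bits by splitting each input in half with two shuffle masks, recursing on the halves, and concatenating the two intermediate results with a join mask of all original indices. The index arithmetic in isolation, as an illustrative sketch outside the patch:

    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
      unsigned NumElts = 8;                  // e.g. a 256-bit v8f32 input
      unsigned Stride = NumElts / 2;

      std::vector<int> Seq(NumElts);
      std::iota(Seq.begin(), Seq.end(), 0);  // {0, 1, ..., NumElts-1}

      std::vector<int> LowerMask(Seq.begin(), Seq.begin() + Stride); // {0..3}
      std::vector<int> UpperMask(Seq.begin() + Stride, Seq.end());   // {4..7}
      // Applied to the two Stride-wide results, indices 0..NumElts-1 take all
      // of the first vector then all of the second, i.e. concatenation.
      std::vector<int> JoinMask = Seq;

      auto Dump = [](const char *Name, const std::vector<int> &M) {
        std::printf("%s:", Name);
        for (int I : M)
          std::printf(" %d", I);
        std::printf("\n");
      };
      Dump("lower", LowerMask);
      Dump("upper", UpperMask);
      Dump("join", JoinMask);
      return 0;
    }
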