author     Pascal Stumpf <pascal@cvs.openbsd.org>    2016-09-03 22:47:00 +0000
committer  Pascal Stumpf <pascal@cvs.openbsd.org>    2016-09-03 22:47:00 +0000
commit     5253aae2877ae65a12d702ed10e3c3bec9436083 (patch)
tree       704ef645b590ea776e3c0cb2b7b4ba215e035f21 /gnu
parent     4bb7a4bc52e27a5e4572928a3a188fcf0f5ea903 (diff)
Use the space freed up by sparc and zaurus to import LLVM.
ok hackroom@
Diffstat (limited to 'gnu')
-rw-r--r--   gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp   13855
1 file changed, 4634 insertions, 9221 deletions
diff --git a/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 6af01423ca1..210aa95b02d 100644 --- a/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/gnu/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -1,4 +1,4 @@ -//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===// +//===-- DAGCombiner.cpp - Implement a DAG node combiner -------------------===// // // The LLVM Compiler Infrastructure // @@ -16,64 +16,28 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/IntervalMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/CodeGen/SelectionDAG.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/MemoryLocation.h" -#include "llvm/CodeGen/DAGCombine.h" -#include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" -#include "llvm/CodeGen/MachineMemOperand.h" -#include "llvm/CodeGen/RuntimeLibcalls.h" -#include "llvm/CodeGen/SelectionDAG.h" -#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h" -#include "llvm/CodeGen/SelectionDAGNodes.h" -#include "llvm/CodeGen/SelectionDAGTargetInfo.h" -#include "llvm/CodeGen/TargetLowering.h" -#include "llvm/CodeGen/TargetRegisterInfo.h" -#include "llvm/CodeGen/TargetSubtargetInfo.h" -#include "llvm/CodeGen/ValueTypes.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/Constant.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Metadata.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/CodeGen.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/KnownBits.h" -#include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetLowering.h" #include "llvm/Target/TargetOptions.h" +#include "llvm/Target/TargetRegisterInfo.h" +#include "llvm/Target/TargetSubtargetInfo.h" #include <algorithm> -#include <cassert> -#include <cstdint> -#include <functional> -#include <iterator> -#include <string> -#include <tuple> -#include <utility> - using namespace llvm; #define DEBUG_TYPE "dagcombine" @@ -84,46 +48,51 @@ STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created"); STATISTIC(OpsNarrowed , "Number of load/op/store narrowed"); STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int"); STATISTIC(SlicedLoads, "Number of load sliced"); -STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops"); -static cl::opt<bool> -CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden, - cl::desc("Enable DAG combiner's use of IR alias analysis")); +namespace { + static cl::opt<bool> + CombinerAA("combiner-alias-analysis", cl::Hidden, + cl::desc("Enable DAG combiner alias-analysis heuristics")); -static cl::opt<bool> -UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true), - cl::desc("Enable 
DAG combiner's use of TBAA")); + static cl::opt<bool> + CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden, + cl::desc("Enable DAG combiner's use of IR alias analysis")); + + static cl::opt<bool> + UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true), + cl::desc("Enable DAG combiner's use of TBAA")); #ifndef NDEBUG -static cl::opt<std::string> -CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden, - cl::desc("Only use DAG-combiner alias analysis in this" - " function")); + static cl::opt<std::string> + CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden, + cl::desc("Only use DAG-combiner alias analysis in this" + " function")); #endif -/// Hidden option to stress test load slicing, i.e., when this option -/// is enabled, load slicing bypasses most of its profitability guards. -static cl::opt<bool> -StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden, - cl::desc("Bypass the profitability model of load slicing"), - cl::init(false)); + /// Hidden option to stress test load slicing, i.e., when this option + /// is enabled, load slicing bypasses most of its profitability guards. + static cl::opt<bool> + StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden, + cl::desc("Bypass the profitability model of load " + "slicing"), + cl::init(false)); -static cl::opt<bool> - MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true), - cl::desc("DAG combiner may split indexing from loads")); + static cl::opt<bool> + MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true), + cl::desc("DAG combiner may split indexing from loads")); -namespace { +//------------------------------ DAGCombiner ---------------------------------// class DAGCombiner { SelectionDAG &DAG; const TargetLowering &TLI; CombineLevel Level; CodeGenOpt::Level OptLevel; - bool LegalOperations = false; - bool LegalTypes = false; + bool LegalOperations; + bool LegalTypes; bool ForCodeSize; - /// Worklist of all of the nodes that need to be simplified. + /// \brief Worklist of all of the nodes that need to be simplified. /// /// This must behave as a stack -- new nodes to process are pushed onto the /// back and when processing we pop off of the back. @@ -132,21 +101,21 @@ namespace { /// due to nodes being deleted from the underlying DAG. SmallVector<SDNode *, 64> Worklist; - /// Mapping from an SDNode to its position on the worklist. + /// \brief Mapping from an SDNode to its position on the worklist. /// /// This is used to find and remove nodes from the worklist (by nulling /// them) when they are deleted from the underlying DAG. It relies on /// stable indices of nodes within the worklist. DenseMap<SDNode *, unsigned> WorklistMap; - /// Set of nodes which have been combined (at least once). + /// \brief Set of nodes which have been combined (at least once). /// /// This is used to allow us to reliably add any operands of a DAG node /// which have not yet been combined to the worklist. - SmallPtrSet<SDNode *, 32> CombinedNodes; + SmallPtrSet<SDNode *, 64> CombinedNodes; // AA - Used for DAG load/store alias analysis. - AliasAnalysis *AA; + AliasAnalysis &AA; /// When an instruction is simplified, add all users of the instruction to /// the work lists because they might get more simplified now. 
@@ -159,25 +128,9 @@ namespace { SDValue visit(SDNode *N); public: - DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL) - : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes), - OptLevel(OL), AA(AA) { - ForCodeSize = DAG.getMachineFunction().getFunction().optForSize(); - - MaximumLegalStoreInBits = 0; - for (MVT VT : MVT::all_valuetypes()) - if (EVT(VT).isSimple() && VT != MVT::Other && - TLI.isTypeLegal(EVT(VT)) && - VT.getSizeInBits() >= MaximumLegalStoreInBits) - MaximumLegalStoreInBits = VT.getSizeInBits(); - } - /// Add to the worklist making sure its instance is at the back (next to be /// processed.) void AddToWorklist(SDNode *N) { - assert(N->getOpcode() != ISD::DELETED_NODE && - "Deleted Node added to Worklist"); - // Skip handle nodes as they can't usefully be combined and confuse the // zero-use deletion strategy. if (N->getOpcode() == ISD::HANDLENODE) @@ -222,41 +175,24 @@ namespace { void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO); private: - unsigned MaximumLegalStoreInBits; /// Check the specified integer node value to see if it can be simplified or /// if things it uses can be simplified by bit propagation. /// If so, return true. bool SimplifyDemandedBits(SDValue Op) { - unsigned BitWidth = Op.getScalarValueSizeInBits(); + unsigned BitWidth = Op.getValueType().getScalarType().getSizeInBits(); APInt Demanded = APInt::getAllOnesValue(BitWidth); return SimplifyDemandedBits(Op, Demanded); } - /// Check the specified vector node value to see if it can be simplified or - /// if things it uses can be simplified as it only uses some of the - /// elements. If so, return true. - bool SimplifyDemandedVectorElts(SDValue Op) { - unsigned NumElts = Op.getValueType().getVectorNumElements(); - APInt Demanded = APInt::getAllOnesValue(NumElts); - return SimplifyDemandedVectorElts(Op, Demanded); - } - bool SimplifyDemandedBits(SDValue Op, const APInt &Demanded); - bool SimplifyDemandedVectorElts(SDValue Op, const APInt &Demanded, - bool AssumeSingleUse = false); bool CombineToPreIndexedLoadStore(SDNode *N); bool CombineToPostIndexedLoadStore(SDNode *N); SDValue SplitIndexingFromLoad(LoadSDNode *LD); bool SliceUpLoad(SDNode *N); - // Scalars have size 0 to distinguish from singleton vectors. - SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD); - bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val); - bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val); - - /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed + /// \brief Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed /// load. /// /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced. @@ -264,9 +200,8 @@ namespace { /// \param EltNo index of the vector element to load. /// \param OriginalLoad load that EVE came from to be replaced. /// \returns EVE on success SDValue() on failure. 
- SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, - SDValue EltNo, - LoadSDNode *OriginalLoad); + SDValue ReplaceExtractVectorEltOfLoadWithNarrowedLoad( + SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad); void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad); SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace); SDValue SExtPromoteOperand(SDValue Op, EVT PVT); @@ -276,6 +211,10 @@ namespace { SDValue PromoteExtend(SDValue Op); bool PromoteLoad(SDValue Op); + void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs, + SDValue Trunc, SDValue ExtLoad, SDLoc DL, + ISD::NodeType ExtType); + /// Call the node-specific routine that knows how to fold each /// particular type of node. If that doesn't do anything, try the /// target-specific DAG combines. @@ -291,26 +230,15 @@ namespace { SDValue visitTokenFactor(SDNode *N); SDValue visitMERGE_VALUES(SDNode *N); SDValue visitADD(SDNode *N); - SDValue visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference); SDValue visitSUB(SDNode *N); - SDValue visitADDSAT(SDNode *N); - SDValue visitSUBSAT(SDNode *N); SDValue visitADDC(SDNode *N); - SDValue visitUADDO(SDNode *N); - SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N); SDValue visitSUBC(SDNode *N); - SDValue visitUSUBO(SDNode *N); SDValue visitADDE(SDNode *N); - SDValue visitADDCARRY(SDNode *N); - SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N); SDValue visitSUBE(SDNode *N); - SDValue visitSUBCARRY(SDNode *N); SDValue visitMUL(SDNode *N); SDValue useDivRem(SDNode *N); SDValue visitSDIV(SDNode *N); - SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N); SDValue visitUDIV(SDNode *N); - SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N); SDValue visitREM(SDNode *N); SDValue visitMULHU(SDNode *N); SDValue visitMULHS(SDNode *N); @@ -320,19 +248,16 @@ namespace { SDValue visitUMULO(SDNode *N); SDValue visitIMINMAX(SDNode *N); SDValue visitAND(SDNode *N); - SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N); + SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *LocReference); SDValue visitOR(SDNode *N); - SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N); + SDValue visitORLike(SDValue N0, SDValue N1, SDNode *LocReference); SDValue visitXOR(SDNode *N); SDValue SimplifyVBinOp(SDNode *N); SDValue visitSHL(SDNode *N); SDValue visitSRA(SDNode *N); SDValue visitSRL(SDNode *N); - SDValue visitFunnelShift(SDNode *N); SDValue visitRotate(SDNode *N); - SDValue visitABS(SDNode *N); SDValue visitBSWAP(SDNode *N); - SDValue visitBITREVERSE(SDNode *N); SDValue visitCTLZ(SDNode *N); SDValue visitCTLZ_ZERO_UNDEF(SDNode *N); SDValue visitCTTZ(SDNode *N); @@ -342,14 +267,12 @@ namespace { SDValue visitVSELECT(SDNode *N); SDValue visitSELECT_CC(SDNode *N); SDValue visitSETCC(SDNode *N); - SDValue visitSETCCCARRY(SDNode *N); + SDValue visitSETCCE(SDNode *N); SDValue visitSIGN_EXTEND(SDNode *N); SDValue visitZERO_EXTEND(SDNode *N); SDValue visitANY_EXTEND(SDNode *N); - SDValue visitAssertExt(SDNode *N); SDValue visitSIGN_EXTEND_INREG(SDNode *N); SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N); - SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N); SDValue visitTRUNCATE(SDNode *N); SDValue visitBITCAST(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); @@ -361,7 +284,6 @@ namespace { SDValue visitFREM(SDNode *N); SDValue visitFSQRT(SDNode *N); SDValue visitFCOPYSIGN(SDNode *N); - SDValue visitFPOW(SDNode *N); SDValue visitSINT_TO_FP(SDNode *N); SDValue visitUINT_TO_FP(SDNode *N); SDValue visitFP_TO_SINT(SDNode *N); @@ 
-376,8 +298,6 @@ namespace { SDValue visitFFLOOR(SDNode *N); SDValue visitFMINNUM(SDNode *N); SDValue visitFMAXNUM(SDNode *N); - SDValue visitFMINIMUM(SDNode *N); - SDValue visitFMAXIMUM(SDNode *N); SDValue visitBRCOND(SDNode *N); SDValue visitBR_CC(SDNode *N); SDValue visitLOAD(SDNode *N); @@ -403,35 +323,21 @@ namespace { SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); - SDValue visitFMULForFMADistributiveCombine(SDNode *N); + SDValue visitFMULForFMACombine(SDNode *N); SDValue XformToShuffleWithZero(SDNode *N); - SDValue ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, - SDValue N1, SDNodeFlags Flags); + SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS); SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt); - SDValue foldSelectOfConstants(SDNode *N); - SDValue foldVSelectOfConstants(SDNode *N); - SDValue foldBinOpIntoSelect(SDNode *BO); bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS); - SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N); - SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2); - SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, - SDValue N2, SDValue N3, ISD::CondCode CC, + SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N); + SDValue SimplifySelect(SDLoc DL, SDValue N0, SDValue N1, SDValue N2); + SDValue SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, SDValue N2, + SDValue N3, ISD::CondCode CC, bool NotExtCompare = false); - SDValue convertSelectOfFPConstantsToLoadOffset( - const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3, - ISD::CondCode CC); - SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1, - SDValue N2, SDValue N3, ISD::CondCode CC); - SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, - const SDLoc &DL); - SDValue unfoldMaskedMerge(SDNode *N); - SDValue unfoldExtremeBitClearingToShifts(SDNode *N); SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond, - const SDLoc &DL, bool foldBooleans); - SDValue rebuildSetCC(SDValue N); + SDLoc DL, bool foldBooleans = true); bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS, SDValue &CC) const; @@ -441,42 +347,32 @@ namespace { unsigned HiOp); SDValue CombineConsecutiveLoads(SDNode *N, EVT VT); SDValue CombineExtLoad(SDNode *N); - SDValue CombineZExtLogicopShiftLoad(SDNode *N); SDValue combineRepeatedFPDivisors(SDNode *N); - SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex); SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT); SDValue BuildSDIV(SDNode *N); SDValue BuildSDIVPow2(SDNode *N); SDValue BuildUDIV(SDNode *N); - SDValue BuildLogBase2(SDValue V, const SDLoc &DL); - SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags); - SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags); - SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags); - SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip); - SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations, - SDNodeFlags Flags, bool Reciprocal); - SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations, - SDNodeFlags Flags, bool Reciprocal); + SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags); + SDValue BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags); + SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations, + SDNodeFlags *Flags); + SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations, + SDNodeFlags *Flags); SDValue 
MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, bool DemandHighBits = true); SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1); SDNode *MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg, SDValue InnerPos, SDValue InnerNeg, unsigned PosOpcode, unsigned NegOpcode, - const SDLoc &DL); - SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL); - SDValue MatchLoadCombine(SDNode *N); + SDLoc DL); + SDNode *MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL); SDValue ReduceLoadWidth(SDNode *N); SDValue ReduceLoadOpStoreWidth(SDNode *N); - SDValue splitMergedValStore(StoreSDNode *ST); SDValue TransformFPLoadStorePair(SDNode *N); - SDValue convertBuildVecZextToZext(SDNode *N); SDValue reduceBuildVecExtToExtBuildVec(SDNode *N); - SDValue reduceBuildVecToShuffle(SDNode *N); - SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N, - ArrayRef<int> VectorMask, SDValue VecIn1, - SDValue VecIn2, unsigned LeftIdx); - SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast); + SDValue reduceBuildVecConvertToConvertBuildVec(SDNode *N); + + SDValue GetDemandedBits(SDValue V, const APInt &Mask); /// Walk up chain skipping non-aliasing memory nodes, /// looking for aliasing nodes and adding them to the Aliases vector. @@ -490,29 +386,22 @@ namespace { /// chain (aliasing node.) SDValue FindBetterChain(SDNode *N, SDValue Chain); - /// Try to replace a store and any possibly adjacent stores on - /// consecutive chains with better chains. Return true only if St is - /// replaced. - /// - /// Notice that other chains may still be replaced even if the function - /// returns false. + /// Do FindBetterChain for a store and any possibly adjacent stores on + /// consecutive chains. bool findBetterNeighborChains(StoreSDNode *St); - // Helper for findBetterNeighborChains. Walk up store chain add additional - // chained stores that do not overlap and can be parallelized. - bool parallelizeChainedStores(StoreSDNode *St); - /// Holds a pointer to an LSBaseSDNode as well as information on where it /// is located in a sequence of memory operations connected by a chain. struct MemOpLink { + MemOpLink (LSBaseSDNode *N, int64_t Offset, unsigned Seq): + MemNode(N), OffsetFromBase(Offset), SequenceNum(Seq) { } // Ptr to the mem node. LSBaseSDNode *MemNode; - // Offset from the base ptr. int64_t OffsetFromBase; - - MemOpLink(LSBaseSDNode *N, int64_t Offset) - : MemNode(N), OffsetFromBase(Offset) {} + // What is the sequence number of this mem node. + // Lowest mem operand in the DAG starts at zero. + unsigned SequenceNum; }; /// This is a helper function for visitMUL to check the profitability @@ -523,64 +412,44 @@ namespace { SDValue &AddNode, SDValue &ConstNode); + /// This is a helper function for MergeStoresOfConstantsOrVecElts. Returns a + /// constant build_vector of the stored constant values in Stores. + SDValue getMergedConstantVectorStore(SelectionDAG &DAG, + SDLoc SL, + ArrayRef<MemOpLink> Stores, + SmallVectorImpl<SDValue> &Chains, + EVT Ty) const; + /// This is a helper function for visitAND and visitZERO_EXTEND. Returns /// true if the (and (load x) c) pattern matches an extload. ExtVT returns - /// the type of the loaded value to be extended. + /// the type of the loaded value to be extended. LoadedVT returns the type + /// of the original loaded value. NarrowLoad returns whether the load would + /// need to be narrowed in order to match. 
bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, - EVT LoadResultTy, EVT &ExtVT); - - /// Helper function to calculate whether the given Load/Store can have its - /// width reduced to ExtVT. - bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType, - EVT &MemVT, unsigned ShAmt = 0); - - /// Used by BackwardsPropagateMask to find suitable loads. - bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads, - SmallPtrSetImpl<SDNode*> &NodesWithConsts, - ConstantSDNode *Mask, SDNode *&NodeToMask); - /// Attempt to propagate a given AND node back to load leaves so that they - /// can be combined into narrow loads. - bool BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG); - - /// Helper function for MergeConsecutiveStores which merges the - /// component store chains. - SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes, - unsigned NumStores); - - /// This is a helper function for MergeConsecutiveStores. When the - /// source elements of the consecutive stores are all constants or - /// all extracted vector elements, try to merge them into one - /// larger store introducing bitcasts if necessary. \return True - /// if a merged store was created. + EVT LoadResultTy, EVT &ExtVT, EVT &LoadedVT, + bool &NarrowLoad); + + /// This is a helper function for MergeConsecutiveStores. When the source + /// elements of the consecutive stores are all constants or all extracted + /// vector elements, try to merge them into one larger store. + /// \return True if a merged store was created. bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores, - bool IsConstantSrc, bool UseVector, - bool UseTrunc); - - /// This is a helper function for MergeConsecutiveStores. Stores - /// that potentially may be merged with St are placed in - /// StoreNodes. RootNode is a chain predecessor to all store - /// candidates. - void getStoreMergeCandidates(StoreSDNode *St, - SmallVectorImpl<MemOpLink> &StoreNodes, - SDNode *&Root); - - /// Helper function for MergeConsecutiveStores. Checks if - /// candidate stores have indirect dependency through their - /// operands. RootNode is the predecessor to all stores calculated - /// by getStoreMergeCandidates and is used to prune the dependency check. - /// \return True if safe to merge. - bool checkMergeStoreCandidatesForDependencies( - SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores, - SDNode *RootNode); + bool IsConstantSrc, bool UseVector); + + /// This is a helper function for MergeConsecutiveStores. + /// Stores that may be merged are placed in StoreNodes. + /// Loads that may alias with those stores are placed in AliasLoadNodes. + void getStoreMergeAndAliasCandidates( + StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes, + SmallVectorImpl<LSBaseSDNode*> &AliasLoadNodes); /// Merge consecutive store operations into a wide store. /// This optimization uses wide integers or vectors when possible. - /// \return number of stores that were merged into a merged store (the - /// affected nodes are stored as a prefix in \p StoreNodes). - bool MergeConsecutiveStores(StoreSDNode *St); + /// \return True if some memory operations were changed. + bool MergeConsecutiveStores(StoreSDNode *N); - /// Try to transform a truncation where C is a constant: + /// \brief Try to transform a truncation where C is a constant: /// (trunc (and X, C)) -> (and (trunc X), (trunc C)) /// /// \p N needs to be a truncation and its first operand an AND. 
Other @@ -588,17 +457,13 @@ namespace { /// single-use) and if missed an empty SDValue is returned. SDValue distributeTruncateThroughAnd(SDNode *N); - /// Helper function to determine whether the target supports operation - /// given by \p Opcode for type \p VT, that is, whether the operation - /// is legal or custom before legalizing operations, and whether is - /// legal (but not custom) after legalization. - bool hasOperation(unsigned Opcode, EVT VT) { - if (LegalOperations) - return TLI.isOperationLegal(Opcode, VT); - return TLI.isOperationLegalOrCustom(Opcode, VT); + public: + DAGCombiner(SelectionDAG &D, AliasAnalysis &A, CodeGenOpt::Level OL) + : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes), + OptLevel(OL), LegalOperations(false), LegalTypes(false), AA(A) { + ForCodeSize = DAG.getMachineFunction().getFunction()->optForSize(); } - public: /// Runs the dag combiner on all nodes in the work list void Run(CombineLevel AtLevel); @@ -608,7 +473,11 @@ namespace { /// legalization these can be huge. EVT getShiftAmountTy(EVT LHSTy) { assert(LHSTy.isInteger() && "Shift amount is not an integer type!"); - return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes); + if (LHSTy.isVector()) + return LHSTy; + auto &DL = DAG.getDataLayout(); + return LegalTypes ? TLI.getScalarShiftAmountTy(DL, LHSTy) + : TLI.getPointerTy(DL); } /// This method returns true if we are running before type legalization or @@ -622,17 +491,15 @@ namespace { EVT getSetCCResultType(EVT VT) const { return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); } - - void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs, - SDValue OrigLoad, SDValue ExtLoad, - ISD::NodeType ExtType); }; +} + +namespace { /// This class is a DAGUpdateListener that removes any deleted /// nodes from the worklist. class WorklistRemover : public SelectionDAG::DAGUpdateListener { DAGCombiner &DC; - public: explicit WorklistRemover(DAGCombiner &dc) : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {} @@ -641,8 +508,7 @@ public: DC.removeFromWorklist(N); } }; - -} // end anonymous namespace +} //===----------------------------------------------------------------------===// // TargetLowering::DAGCombinerInfo implementation @@ -652,6 +518,10 @@ void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) { ((DAGCombiner*)DC)->AddToWorklist(N); } +void TargetLowering::DAGCombinerInfo::RemoveFromWorklist(SDNode *N) { + ((DAGCombiner*)DC)->removeFromWorklist(N); +} + SDValue TargetLowering::DAGCombinerInfo:: CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) { return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo); @@ -662,6 +532,7 @@ CombineTo(SDNode *N, SDValue Res, bool AddTo) { return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo); } + SDValue TargetLowering::DAGCombinerInfo:: CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) { return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo); @@ -701,34 +572,25 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations, // fneg is removable even if it has multiple uses. if (Op.getOpcode() == ISD::FNEG) return 2; - // Don't allow anything with multiple uses unless we know it is free. - EVT VT = Op.getValueType(); - const SDNodeFlags Flags = Op->getFlags(); - if (!Op.hasOneUse()) - if (!(Op.getOpcode() == ISD::FP_EXTEND && - TLI.isFPExtFree(VT, Op.getOperand(0).getValueType()))) - return 0; + // Don't allow anything with multiple uses. + if (!Op.hasOneUse()) return 0; // Don't recurse exponentially. 
if (Depth > 6) return 0; switch (Op.getOpcode()) { default: return false; - case ISD::ConstantFP: { - if (!LegalOperations) - return 1; - - // Don't invert constant FP values after legalization unless the target says - // the negated constant is legal. - return TLI.isOperationLegal(ISD::ConstantFP, VT) || - TLI.isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT); - } + case ISD::ConstantFP: + // Don't invert constant FP values after legalize. The negated constant + // isn't necessarily legal. + return LegalOperations ? 0 : 1; case ISD::FADD: - if (!Options->UnsafeFPMath && !Flags.hasNoSignedZeros()) - return 0; + // FIXME: determine better conditions for this xform. + if (!Options->UnsafeFPMath) return 0; // After operation legalization, it might not be legal to create new FSUBs. - if (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) + if (LegalOperations && + !TLI.isOperationLegalOrCustom(ISD::FSUB, Op.getValueType())) return 0; // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) @@ -740,15 +602,15 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations, Depth + 1); case ISD::FSUB: // We can't turn -(A-B) into B-A when we honor signed zeros. - if (!Options->NoSignedZerosFPMath && - !Flags.hasNoSignedZeros()) - return 0; + if (!Options->UnsafeFPMath) return 0; // fold (fneg (fsub A, B)) -> (fsub B, A) return 1; case ISD::FMUL: case ISD::FDIV: + if (Options->HonorSignDependentRoundingFPMath()) return 0; + // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y)) if (char V = isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI, Options, Depth + 1)) @@ -772,9 +634,12 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, // fneg is removable even if it has multiple uses. if (Op.getOpcode() == ISD::FNEG) return Op.getOperand(0); + // Don't allow anything with multiple uses. + assert(Op.hasOneUse() && "Unknown reuse!"); + assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree"); - const SDNodeFlags Flags = Op.getNode()->getFlags(); + const SDNodeFlags *Flags = Op.getNode()->getFlags(); switch (Op.getOpcode()) { default: llvm_unreachable("Unknown code"); @@ -784,7 +649,8 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, return DAG.getConstantFP(V, SDLoc(Op), Op.getValueType()); } case ISD::FADD: - assert(Options.UnsafeFPMath || Flags.hasNoSignedZeros()); + // FIXME: determine better conditions for this xform. + assert(Options.UnsafeFPMath); // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) if (isNegatibleForFree(Op.getOperand(0), LegalOperations, @@ -799,6 +665,9 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, LegalOperations, Depth+1), Op.getOperand(0), Flags); case ISD::FSUB: + // We can't turn -(A-B) into B-A when we honor signed zeros. + assert(Options.UnsafeFPMath); + // fold (fneg (fsub 0, B)) -> B if (ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(Op.getOperand(0))) if (N0CFP->isZero()) @@ -810,6 +679,8 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, case ISD::FMUL: case ISD::FDIV: + assert(!Options.HonorSignDependentRoundingFPMath()); + // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) if (isNegatibleForFree(Op.getOperand(0), LegalOperations, DAG.getTargetLoweringInfo(), &Options, Depth+1)) @@ -837,15 +708,6 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, } } -// APInts must be the same size for most operations, this helper -// function zero extends the shorter of the pair so that they match. 
-// We provide an Offset so that we can create bitwidths that won't overflow. -static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) { - unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth()); - LHS = LHS.zextOrSelf(Bits); - RHS = RHS.zextOrSelf(Bits); -} - // Return true if this node is a setcc, or is a select_cc // that selects between the target values used for true and false, making it // equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to @@ -885,7 +747,33 @@ bool DAGCombiner::isOneUseSetCC(SDValue N) const { return false; } -// Returns the SDNode if it is a constant float BuildVector +/// Returns true if N is a BUILD_VECTOR node whose +/// elements are all the same constant or undefined. +static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) { + BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N); + if (!C) + return false; + + APInt SplatUndef; + unsigned SplatBitSize; + bool HasAnyUndefs; + EVT EltVT = N->getValueType(0).getVectorElementType(); + return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, + HasAnyUndefs) && + EltVT.getSizeInBits() >= SplatBitSize); +} + +// \brief Returns the SDNode if it is a constant integer BuildVector +// or constant integer. +static SDNode *isConstantIntBuildVectorOrConstantInt(SDValue N) { + if (isa<ConstantSDNode>(N)) + return N.getNode(); + if (ISD::isBuildVectorOfConstantSDNodes(N.getNode())) + return N.getNode(); + return nullptr; +} + +// \brief Returns the SDNode if it is a constant float BuildVector // or constant float. static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) { if (isa<ConstantFPSDNode>(N)) @@ -895,45 +783,50 @@ static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) { return nullptr; } -// Determines if it is a constant integer or a build vector of constant -// integers (and undefs). -// Do not permit build vector implicit truncation. -static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) { - if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N)) - return !(Const->isOpaque() && NoOpaques); - if (N.getOpcode() != ISD::BUILD_VECTOR) - return false; - unsigned BitWidth = N.getScalarValueSizeInBits(); - for (const SDValue &Op : N->op_values()) { - if (Op.isUndef()) - continue; - ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op); - if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth || - (Const->isOpaque() && NoOpaques)) - return false; +// \brief Returns the SDNode if it is a constant splat BuildVector or constant +// int. +static ConstantSDNode *isConstOrConstSplat(SDValue N) { + if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) + return CN; + + if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) { + BitVector UndefElements; + ConstantSDNode *CN = BV->getConstantSplatNode(&UndefElements); + + // BuildVectors can truncate their operands. Ignore that case here. + // FIXME: We blindly ignore splats which include undef which is overly + // pessimistic. + if (CN && UndefElements.none() && + CN->getValueType(0) == N.getValueType().getScalarType()) + return CN; } - return true; -} -// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with -// undef's. 
-static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) { - if (V.getOpcode() != ISD::BUILD_VECTOR) - return false; - return isConstantOrConstantVector(V, NoOpaques) || - ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()); + return nullptr; } -SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, - SDValue N1, SDNodeFlags Flags) { - // Don't reassociate reductions. - if (Flags.hasVectorReduction()) - return SDValue(); +// \brief Returns the SDNode if it is a constant splat BuildVector or constant +// float. +static ConstantFPSDNode *isConstOrConstSplatFP(SDValue N) { + if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N)) + return CN; + + if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) { + BitVector UndefElements; + ConstantFPSDNode *CN = BV->getConstantFPSplatNode(&UndefElements); + + if (CN && UndefElements.none()) + return CN; + } + + return nullptr; +} +SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL, + SDValue N0, SDValue N1) { EVT VT = N0.getValueType(); - if (N0.getOpcode() == Opc && !N0->getFlags().hasVectorReduction()) { - if (SDNode *L = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) { - if (SDNode *R = DAG.isConstantIntBuildVectorOrConstantInt(N1)) { + if (N0.getOpcode() == Opc) { + if (SDNode *L = isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) { + if (SDNode *R = isConstantIntBuildVectorOrConstantInt(N1)) { // reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2)) if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, L, R)) return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode); @@ -951,18 +844,18 @@ SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, } } - if (N1.getOpcode() == Opc && !N1->getFlags().hasVectorReduction()) { - if (SDNode *R = DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) { - if (SDNode *L = DAG.isConstantIntBuildVectorOrConstantInt(N0)) { + if (N1.getOpcode() == Opc) { + if (SDNode *R = isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) { + if (SDNode *L = isConstantIntBuildVectorOrConstantInt(N0)) { // reassoc. (op c2, (op x, c1)) -> (op x, (op c1, c2)) if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, R, L)) return DAG.getNode(Opc, DL, VT, N1.getOperand(0), OpNode); return SDValue(); } if (N1.hasOneUse()) { - // reassoc. (op x, (op y, c1)) -> (op (op x, y), c1) iff x+c1 has one + // reassoc. (op y, (op x, c1)) -> (op (op x, y), c1) iff x+c1 has one // use - SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0, N1.getOperand(0)); + SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N1.getOperand(0), N0); if (!OpNode.getNode()) return SDValue(); AddToWorklist(OpNode.getNode()); @@ -978,9 +871,11 @@ SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo, bool AddTo) { assert(N->getNumValues() == NumTo && "Broken CombineTo call!"); ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: "; - To[0].getNode()->dump(&DAG); - dbgs() << " and " << NumTo - 1 << " other values\n"); + DEBUG(dbgs() << "\nReplacing.1 "; + N->dump(&DAG); + dbgs() << "\nWith: "; + To[0].getNode()->dump(&DAG); + dbgs() << " and " << NumTo-1 << " other values\n"); for (unsigned i = 0, e = NumTo; i != e; ++i) assert((!To[i].getNode() || N->getValueType(i) == To[i].getValueType()) && @@ -1028,32 +923,8 @@ CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) { /// things it uses can be simplified by bit propagation. If so, return true. 
bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) { TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations); - KnownBits Known; - if (!TLI.SimplifyDemandedBits(Op, Demanded, Known, TLO)) - return false; - - // Revisit the node. - AddToWorklist(Op.getNode()); - - // Replace the old value with the new one. - ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); - dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); - dbgs() << '\n'); - - CommitTargetLoweringOpt(TLO); - return true; -} - -/// Check the specified vector node value to see if it can be simplified or -/// if things it uses can be simplified as it only uses some of the elements. -/// If so, return true. -bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op, const APInt &Demanded, - bool AssumeSingleUse) { - TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations); - APInt KnownUndef, KnownZero; - if (!TLI.SimplifyDemandedVectorElts(Op, Demanded, KnownUndef, KnownZero, TLO, - 0, AssumeSingleUse)) + APInt KnownZero, KnownOne; + if (!TLI.SimplifyDemandedBits(Op, Demanded, KnownZero, KnownOne, TLO)) return false; // Revisit the node. @@ -1061,21 +932,26 @@ bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op, const APInt &Demanded, // Replace the old value with the new one. ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); - dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); - dbgs() << '\n'); + DEBUG(dbgs() << "\nReplacing.2 "; + TLO.Old.getNode()->dump(&DAG); + dbgs() << "\nWith: "; + TLO.New.getNode()->dump(&DAG); + dbgs() << '\n'); CommitTargetLoweringOpt(TLO); return true; } void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) { - SDLoc DL(Load); + SDLoc dl(Load); EVT VT = Load->getValueType(0); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0)); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, VT, SDValue(ExtLoad, 0)); - LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: "; - Trunc.getNode()->dump(&DAG); dbgs() << '\n'); + DEBUG(dbgs() << "\nReplacing.9 "; + Load->dump(&DAG); + dbgs() << "\nWith: "; + Trunc.getNode()->dump(&DAG); + dbgs() << '\n'); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc); DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1)); @@ -1085,14 +961,15 @@ void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) { SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) { Replace = false; - SDLoc DL(Op); - if (ISD::isUNINDEXEDLoad(Op.getNode())) { - LoadSDNode *LD = cast<LoadSDNode>(Op); + SDLoc dl(Op); + if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Op)) { EVT MemVT = LD->getMemoryVT(); - ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD - : LD->getExtensionType(); + ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) + ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, PVT, MemVT) ? 
ISD::ZEXTLOAD + : ISD::EXTLOAD) + : LD->getExtensionType(); Replace = true; - return DAG.getExtLoad(ExtType, DL, PVT, + return DAG.getExtLoad(ExtType, dl, PVT, LD->getChain(), LD->getBasePtr(), MemVT, LD->getMemOperand()); } @@ -1101,30 +978,30 @@ SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) { switch (Opc) { default: break; case ISD::AssertSext: - if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT)) - return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1)); - break; + return DAG.getNode(ISD::AssertSext, dl, PVT, + SExtPromoteOperand(Op.getOperand(0), PVT), + Op.getOperand(1)); case ISD::AssertZext: - if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT)) - return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1)); - break; + return DAG.getNode(ISD::AssertZext, dl, PVT, + ZExtPromoteOperand(Op.getOperand(0), PVT), + Op.getOperand(1)); case ISD::Constant: { unsigned ExtOpc = Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; - return DAG.getNode(ExtOpc, DL, PVT, Op); + return DAG.getNode(ExtOpc, dl, PVT, Op); } } if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT)) return SDValue(); - return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op); + return DAG.getNode(ISD::ANY_EXTEND, dl, PVT, Op); } SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) { if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT)) return SDValue(); EVT OldVT = Op.getValueType(); - SDLoc DL(Op); + SDLoc dl(Op); bool Replace = false; SDValue NewOp = PromoteOperand(Op, PVT, Replace); if (!NewOp.getNode()) @@ -1133,13 +1010,13 @@ SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) { if (Replace) ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode()); - return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp, + return DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, NewOp.getValueType(), NewOp, DAG.getValueType(OldVT)); } SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) { EVT OldVT = Op.getValueType(); - SDLoc DL(Op); + SDLoc dl(Op); bool Replace = false; SDValue NewOp = PromoteOperand(Op, PVT, Replace); if (!NewOp.getNode()) @@ -1148,7 +1025,7 @@ SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) { if (Replace) ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode()); - return DAG.getZeroExtendInReg(NewOp, DL, OldVT); + return DAG.getZeroExtendInReg(NewOp, dl, OldVT); } /// Promote the specified integer binary operation if the target indicates it is @@ -1174,44 +1051,37 @@ SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) { if (TLI.IsDesirableToPromoteOp(Op, PVT)) { assert(PVT != VT && "Don't know what type to promote to!"); - LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); - bool Replace0 = false; SDValue N0 = Op.getOperand(0); SDValue NN0 = PromoteOperand(N0, PVT, Replace0); + if (!NN0.getNode()) + return SDValue(); bool Replace1 = false; SDValue N1 = Op.getOperand(1); - SDValue NN1 = PromoteOperand(N1, PVT, Replace1); - SDLoc DL(Op); - - SDValue RV = - DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1)); - - // We are always replacing N0/N1's use in N and only need - // additional replacements if there are additional uses. - Replace0 &= !N0->hasOneUse(); - Replace1 &= (N0 != N1) && !N1->hasOneUse(); - - // Combine Op here so it is preserved past replacements. - CombineTo(Op.getNode(), RV); - - // If operands have a use ordering, make sure we deal with - // predecessor first. 
- if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) { - std::swap(N0, N1); - std::swap(NN0, NN1); + SDValue NN1; + if (N0 == N1) + NN1 = NN0; + else { + NN1 = PromoteOperand(N1, PVT, Replace1); + if (!NN1.getNode()) + return SDValue(); } - if (Replace0) { - AddToWorklist(NN0.getNode()); - ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode()); - } - if (Replace1) { + AddToWorklist(NN0.getNode()); + if (NN1.getNode()) AddToWorklist(NN1.getNode()); + + if (Replace0) + ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode()); + if (Replace1) ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode()); - } - return Op; + + DEBUG(dbgs() << "\nPromoting "; + Op.getNode()->dump(&DAG)); + SDLoc dl(Op); + return DAG.getNode(ISD::TRUNCATE, dl, VT, + DAG.getNode(Opc, dl, PVT, NN0, NN1)); } return SDValue(); } @@ -1239,32 +1109,26 @@ SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) { if (TLI.IsDesirableToPromoteOp(Op, PVT)) { assert(PVT != VT && "Don't know what type to promote to!"); - LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); - bool Replace = false; SDValue N0 = Op.getOperand(0); - SDValue N1 = Op.getOperand(1); if (Opc == ISD::SRA) - N0 = SExtPromoteOperand(N0, PVT); + N0 = SExtPromoteOperand(Op.getOperand(0), PVT); else if (Opc == ISD::SRL) - N0 = ZExtPromoteOperand(N0, PVT); + N0 = ZExtPromoteOperand(Op.getOperand(0), PVT); else N0 = PromoteOperand(N0, PVT, Replace); - if (!N0.getNode()) return SDValue(); - SDLoc DL(Op); - SDValue RV = - DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1)); - AddToWorklist(N0.getNode()); if (Replace) ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode()); - // Deal with Op being deleted. - if (Op && Op.getOpcode() != ISD::DELETED_NODE) - return RV; + DEBUG(dbgs() << "\nPromoting "; + Op.getNode()->dump(&DAG)); + SDLoc dl(Op); + return DAG.getNode(ISD::TRUNCATE, dl, VT, + DAG.getNode(Opc, dl, PVT, N0, Op.getOperand(1))); } return SDValue(); } @@ -1291,7 +1155,8 @@ SDValue DAGCombiner::PromoteExtend(SDValue Op) { // fold (aext (aext x)) -> (aext x) // fold (aext (zext x)) -> (zext x) // fold (aext (sext x)) -> (sext x) - LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); + DEBUG(dbgs() << "\nPromoting "; + Op.getNode()->dump(&DAG)); return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0)); } return SDValue(); @@ -1301,9 +1166,6 @@ bool DAGCombiner::PromoteLoad(SDValue Op) { if (!LegalOperations) return false; - if (!ISD::isUNINDEXEDLoad(Op.getNode())) - return false; - EVT VT = Op.getValueType(); if (VT.isVector() || !VT.isInteger()) return false; @@ -1320,19 +1182,24 @@ bool DAGCombiner::PromoteLoad(SDValue Op) { if (TLI.IsDesirableToPromoteOp(Op, PVT)) { assert(PVT != VT && "Don't know what type to promote to!"); - SDLoc DL(Op); + SDLoc dl(Op); SDNode *N = Op.getNode(); LoadSDNode *LD = cast<LoadSDNode>(N); EVT MemVT = LD->getMemoryVT(); - ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD - : LD->getExtensionType(); - SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT, + ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) + ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, PVT, MemVT) ? 
ISD::ZEXTLOAD + : ISD::EXTLOAD) + : LD->getExtensionType(); + SDValue NewLD = DAG.getExtLoad(ExtType, dl, PVT, LD->getChain(), LD->getBasePtr(), MemVT, LD->getMemOperand()); - SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD); + SDValue Result = DAG.getNode(ISD::TRUNCATE, dl, VT, NewLD); - LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: "; - Result.getNode()->dump(&DAG); dbgs() << '\n'); + DEBUG(dbgs() << "\nPromoting "; + N->dump(&DAG); + dbgs() << "\nTo: "; + Result.getNode()->dump(&DAG); + dbgs() << '\n'); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1)); @@ -1343,7 +1210,7 @@ bool DAGCombiner::PromoteLoad(SDValue Op) { return false; } -/// Recursively delete a node which has no uses and any operands for +/// \brief Recursively delete a node which has no uses and any operands for /// which it is the only use. /// /// Note that this both deletes the nodes and removes them from the worklist. @@ -1392,7 +1259,8 @@ void DAGCombiner::Run(CombineLevel AtLevel) { // changes of the root. HandleSDNode Dummy(DAG.getRoot()); - // While the worklist isn't empty, find a node and try to combine it. + // while the worklist isn't empty, find a node and + // try and combine it. while (!WorklistMap.empty()) { SDNode *N; // The Worklist holds the SDNodes in order, but it may contain null entries. @@ -1427,7 +1295,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) { continue; } - LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG)); + DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG)); // Add any operands of the new node which have not yet been combined to the // worklist as well. Because the worklist uniques things already, this @@ -1452,17 +1320,21 @@ void DAGCombiner::Run(CombineLevel AtLevel) { continue; assert(N->getOpcode() != ISD::DELETED_NODE && - RV.getOpcode() != ISD::DELETED_NODE && + RV.getNode()->getOpcode() != ISD::DELETED_NODE && "Node was deleted but visit returned new node!"); - LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG)); + DEBUG(dbgs() << " ... into: "; + RV.getNode()->dump(&DAG)); + // Transfer debug value. 
+ DAG.TransferDbgValues(SDValue(N, 0), RV); if (N->getNumValues() == RV.getNode()->getNumValues()) DAG.ReplaceAllUsesWith(N, RV.getNode()); else { assert(N->getValueType(0) == RV.getValueType() && N->getNumValues() == 1 && "Type mismatch"); - DAG.ReplaceAllUsesWith(N, &RV); + SDValue OpV = RV; + DAG.ReplaceAllUsesWith(N, &OpV); } // Push the new node and any users onto the worklist @@ -1488,18 +1360,10 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::MERGE_VALUES: return visitMERGE_VALUES(N); case ISD::ADD: return visitADD(N); case ISD::SUB: return visitSUB(N); - case ISD::SADDSAT: - case ISD::UADDSAT: return visitADDSAT(N); - case ISD::SSUBSAT: - case ISD::USUBSAT: return visitSUBSAT(N); case ISD::ADDC: return visitADDC(N); - case ISD::UADDO: return visitUADDO(N); case ISD::SUBC: return visitSUBC(N); - case ISD::USUBO: return visitUSUBO(N); case ISD::ADDE: return visitADDE(N); - case ISD::ADDCARRY: return visitADDCARRY(N); case ISD::SUBE: return visitSUBE(N); - case ISD::SUBCARRY: return visitSUBCARRY(N); case ISD::MUL: return visitMUL(N); case ISD::SDIV: return visitSDIV(N); case ISD::UDIV: return visitUDIV(N); @@ -1523,11 +1387,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::SRL: return visitSRL(N); case ISD::ROTR: case ISD::ROTL: return visitRotate(N); - case ISD::FSHL: - case ISD::FSHR: return visitFunnelShift(N); - case ISD::ABS: return visitABS(N); case ISD::BSWAP: return visitBSWAP(N); - case ISD::BITREVERSE: return visitBITREVERSE(N); case ISD::CTLZ: return visitCTLZ(N); case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N); case ISD::CTTZ: return visitCTTZ(N); @@ -1537,15 +1397,12 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::VSELECT: return visitVSELECT(N); case ISD::SELECT_CC: return visitSELECT_CC(N); case ISD::SETCC: return visitSETCC(N); - case ISD::SETCCCARRY: return visitSETCCCARRY(N); + case ISD::SETCCE: return visitSETCCE(N); case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N); case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N); case ISD::ANY_EXTEND: return visitANY_EXTEND(N); - case ISD::AssertSext: - case ISD::AssertZext: return visitAssertExt(N); case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N); case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N); - case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N); case ISD::TRUNCATE: return visitTRUNCATE(N); case ISD::BITCAST: return visitBITCAST(N); case ISD::BUILD_PAIR: return visitBUILD_PAIR(N); @@ -1557,7 +1414,6 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::FREM: return visitFREM(N); case ISD::FSQRT: return visitFSQRT(N); case ISD::FCOPYSIGN: return visitFCOPYSIGN(N); - case ISD::FPOW: return visitFPOW(N); case ISD::SINT_TO_FP: return visitSINT_TO_FP(N); case ISD::UINT_TO_FP: return visitUINT_TO_FP(N); case ISD::FP_TO_SINT: return visitFP_TO_SINT(N); @@ -1570,8 +1426,6 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::FFLOOR: return visitFFLOOR(N); case ISD::FMINNUM: return visitFMINNUM(N); case ISD::FMAXNUM: return visitFMAXNUM(N); - case ISD::FMINIMUM: return visitFMINIMUM(N); - case ISD::FMAXIMUM: return visitFMAXIMUM(N); case ISD::FCEIL: return visitFCEIL(N); case ISD::FTRUNC: return visitFTRUNC(N); case ISD::BRCOND: return visitBRCOND(N); @@ -1644,15 +1498,15 @@ SDValue DAGCombiner::combine(SDNode *N) { } } - // If N is a commutative binary node, try eliminate it if the commuted - // version is already present in the DAG. 
- if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) && + // If N is a commutative binary node, try commuting it to enable more + // sdisel CSE. + if (!RV.getNode() && SelectionDAG::isCommutativeBinOp(N->getOpcode()) && N->getNumValues() == 1) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); // Constant operands are canonicalized to RHS. - if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) { + if (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1)) { SDValue Ops[] = {N1, N0}; SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops, N->getFlags()); @@ -1689,12 +1543,8 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) { return N->getOperand(1); } - // Don't simplify token factors if optnone. - if (OptLevel == CodeGenOpt::None) - return SDValue(); - SmallVector<SDNode *, 8> TFs; // List of token factors to visit. - SmallVector<SDValue, 8> Ops; // Ops for replacing token factor. + SmallVector<SDValue, 8> Ops; // Ops for replacing token factor. SmallPtrSet<SDNode*, 16> SeenOps; bool Changed = false; // If we should replace this token factor. @@ -1708,6 +1558,7 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) { // Check each of the operands. for (const SDValue &Op : TF->op_values()) { + switch (Op.getOpcode()) { case ISD::EntryToken: // Entry tokens don't need to be added to the list. They are @@ -1716,7 +1567,8 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) { break; case ISD::TokenFactor: - if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) { + if (Op.hasOneUse() && + std::find(TFs.begin(), TFs.end(), Op.getNode()) == TFs.end()) { // Queue up for processing. TFs.push_back(Op.getNode()); // Clean up in case the token factor is removed. @@ -1724,7 +1576,7 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) { Changed = true; break; } - LLVM_FALLTHROUGH; + // Fall thru default: // Only add if it isn't already in the list. @@ -1737,108 +1589,26 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) { } } - // Remove Nodes that are chained to another node in the list. Do so - // by walking up chains breath-first stopping when we've seen - // another operand. In general we must climb to the EntryNode, but we can exit - // early if we find all remaining work is associated with just one operand as - // no further pruning is possible. - - // List of nodes to search through and original Ops from which they originate. - SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist; - SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op. - SmallPtrSet<SDNode *, 16> SeenChains; - bool DidPruneOps = false; - - unsigned NumLeftToConsider = 0; - for (const SDValue &Op : Ops) { - Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++)); - OpWorkCount.push_back(1); - } - - auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) { - // If this is an Op, we can remove the op from the list. Remark any - // search associated with it as from the current OpNumber. 
- if (SeenOps.count(Op) != 0) { - Changed = true; - DidPruneOps = true; - unsigned OrigOpNumber = 0; - while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op) - OrigOpNumber++; - assert((OrigOpNumber != Ops.size()) && - "expected to find TokenFactor Operand"); - // Re-mark worklist from OrigOpNumber to OpNumber - for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) { - if (Worklist[i].second == OrigOpNumber) { - Worklist[i].second = OpNumber; - } - } - OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber]; - OpWorkCount[OrigOpNumber] = 0; - NumLeftToConsider--; - } - // Add if it's a new chain - if (SeenChains.insert(Op).second) { - OpWorkCount[OpNumber]++; - Worklist.push_back(std::make_pair(Op, OpNumber)); - } - }; - - for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) { - // We need at least be consider at least 2 Ops to prune. - if (NumLeftToConsider <= 1) - break; - auto CurNode = Worklist[i].first; - auto CurOpNumber = Worklist[i].second; - assert((OpWorkCount[CurOpNumber] > 0) && - "Node should not appear in worklist"); - switch (CurNode->getOpcode()) { - case ISD::EntryToken: - // Hitting EntryToken is the only way for the search to terminate without - // hitting - // another operand's search. Prevent us from marking this operand - // considered. - NumLeftToConsider++; - break; - case ISD::TokenFactor: - for (const SDValue &Op : CurNode->op_values()) - AddToWorklist(i, Op.getNode(), CurOpNumber); - break; - case ISD::CopyFromReg: - case ISD::CopyToReg: - AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber); - break; - default: - if (auto *MemNode = dyn_cast<MemSDNode>(CurNode)) - AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber); - break; - } - OpWorkCount[CurOpNumber]--; - if (OpWorkCount[CurOpNumber] == 0) - NumLeftToConsider--; - } + SDValue Result; // If we've changed things around then replace token factor. if (Changed) { - SDValue Result; if (Ops.empty()) { // The entry token is the only possible outcome. Result = DAG.getEntryNode(); } else { - if (DidPruneOps) { - SmallVector<SDValue, 8> PrunedOps; - // - for (const SDValue &Op : Ops) { - if (SeenChains.count(Op.getNode()) == 0) - PrunedOps.push_back(Op); - } - Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, PrunedOps); - } else { - Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Ops); - } + // New and improved token factor. + Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Ops); } - return Result; + + // Add users to worklist if AA is enabled, since it may introduce + // a lot of new chained token factors while removing memory deps. + bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA + : DAG.getSubtarget().useAA(); + return CombineTo(N, Result, UseAA /*add to worklist*/); } - return SDValue(); + + return Result; } /// MERGE_VALUES can always be eliminated. @@ -1851,179 +1621,24 @@ SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) { // can be tried again once they have new operands. AddUsersToWorklist(N); do { - // Do as a single replacement to avoid rewalking use lists. - SmallVector<SDValue, 8> Ops; for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) - Ops.push_back(N->getOperand(i)); - DAG.ReplaceAllUsesWith(N, Ops.data()); + DAG.ReplaceAllUsesOfValueWith(SDValue(N, i), N->getOperand(i)); } while (!N->use_empty()); deleteAndRecombine(N); return SDValue(N, 0); // Return N so it doesn't get rechecked! } -/// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a -/// ConstantSDNode pointer else nullptr. 
+/// If \p N is a ConstantSDNode with isOpaque() == false return it cast to a
+/// ConstantSDNode pointer else nullptr.
 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
 }

-SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
- assert(ISD::isBinaryOp(BO) && "Unexpected binary operator");
-
- // Don't do this unless the old select is going away. We want to eliminate the
- // binary operator, not replace a binop with a select.
- // TODO: Handle ISD::SELECT_CC.
- unsigned SelOpNo = 0;
- SDValue Sel = BO->getOperand(0);
- if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
- SelOpNo = 1;
- Sel = BO->getOperand(1);
- }
-
- if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
- return SDValue();
-
- SDValue CT = Sel.getOperand(1);
- if (!isConstantOrConstantVector(CT, true) &&
- !isConstantFPBuildVectorOrConstantFP(CT))
- return SDValue();
-
- SDValue CF = Sel.getOperand(2);
- if (!isConstantOrConstantVector(CF, true) &&
- !isConstantFPBuildVectorOrConstantFP(CF))
- return SDValue();
-
- // Bail out if any constants are opaque because we can't constant fold those.
- // The exception is "and" and "or" with either 0 or -1 in which case we can
- // propagate non-constant operands into select. I.e.:
- // and (select Cond, 0, -1), X --> select Cond, 0, X
- // or X, (select Cond, -1, 0) --> select Cond, -1, X
- auto BinOpcode = BO->getOpcode();
- bool CanFoldNonConst =
- (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
- (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
- (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));
-
- SDValue CBO = BO->getOperand(SelOpNo ^ 1);
- if (!CanFoldNonConst &&
- !isConstantOrConstantVector(CBO, true) &&
- !isConstantFPBuildVectorOrConstantFP(CBO))
- return SDValue();
-
- EVT VT = Sel.getValueType();
-
- // In case of a shift, the value and the shift amount may have different VTs.
- // For instance, on x86 the shift amount is i8 regardless of the LHS type.
- // Bail out if we have swapped operands and the value types do not match.
- // NB: x86 is fine if the operands are not swapped, with the shift amount VT
- // being no bigger than the shifted value. TODO: it is possible to check for
- // a shift operation, correct the VTs and still perform the optimization on
- // x86 if needed.
- if (SelOpNo && VT != CBO.getValueType())
- return SDValue();
-
- // We have a select-of-constants followed by a binary operator with a
- // constant. Eliminate the binop by pulling the constant math into the select.
- // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
- SDLoc DL(Sel);
- SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
- : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
- if (!CanFoldNonConst && !NewCT.isUndef() &&
- !isConstantOrConstantVector(NewCT, true) &&
- !isConstantFPBuildVectorOrConstantFP(NewCT))
- return SDValue();
-
- SDValue NewCF = SelOpNo ?
DAG.getNode(BinOpcode, DL, VT, CBO, CF) - : DAG.getNode(BinOpcode, DL, VT, CF, CBO); - if (!CanFoldNonConst && !NewCF.isUndef() && - !isConstantOrConstantVector(NewCF, true) && - !isConstantFPBuildVectorOrConstantFP(NewCF)) - return SDValue(); - - return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF); -} - -static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) { - assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) && - "Expecting add or sub"); - - // Match a constant operand and a zext operand for the math instruction: - // add Z, C - // sub C, Z - bool IsAdd = N->getOpcode() == ISD::ADD; - SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0); - SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1); - auto *CN = dyn_cast<ConstantSDNode>(C); - if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND) - return SDValue(); - - // Match the zext operand as a setcc of a boolean. - if (Z.getOperand(0).getOpcode() != ISD::SETCC || - Z.getOperand(0).getValueType() != MVT::i1) - return SDValue(); - - // Match the compare as: setcc (X & 1), 0, eq. - SDValue SetCC = Z.getOperand(0); - ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get(); - if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) || - SetCC.getOperand(0).getOpcode() != ISD::AND || - !isOneConstant(SetCC.getOperand(0).getOperand(1))) - return SDValue(); - - // We are adding/subtracting a constant and an inverted low bit. Turn that - // into a subtract/add of the low bit with incremented/decremented constant: - // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1)) - // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1)) - EVT VT = C.getValueType(); - SDLoc DL(N); - SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT); - SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) : - DAG.getConstant(CN->getAPIntValue() - 1, DL, VT); - return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit); -} - -/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into -/// a shift and add with a different constant. -static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) { - assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) && - "Expecting add or sub"); - - // We need a constant operand for the add/sub, and the other operand is a - // logical shift right: add (srl), C or sub C, (srl). - bool IsAdd = N->getOpcode() == ISD::ADD; - SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0); - SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1); - ConstantSDNode *C = isConstOrConstSplat(ConstantOp); - if (!C || ShiftOp.getOpcode() != ISD::SRL) - return SDValue(); - - // The shift must be of a 'not' value. - SDValue Not = ShiftOp.getOperand(0); - if (!Not.hasOneUse() || !isBitwiseNot(Not)) - return SDValue(); - - // The shift must be moving the sign bit to the least-significant-bit. - EVT VT = ShiftOp.getValueType(); - SDValue ShAmt = ShiftOp.getOperand(1); - ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt); - if (!ShAmtC || ShAmtC->getZExtValue() != VT.getScalarSizeInBits() - 1) - return SDValue(); - - // Eliminate the 'not' by adjusting the shift and add/sub constant: - // add (srl (not X), 31), C --> add (sra X, 31), (C + 1) - // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1) - SDLoc DL(N); - auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL; - SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt); - APInt NewC = IsAdd ? 
C->getAPIntValue() + 1 : C->getAPIntValue() - 1; - return DAG.getNode(ISD::ADD, DL, VT, NewShift, DAG.getConstant(NewC, DL, VT)); -} - SDValue DAGCombiner::visitADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); - SDLoc DL(N); // fold vector ops if (VT.isVector()) { @@ -2038,102 +1653,69 @@ SDValue DAGCombiner::visitADD(SDNode *N) { } // fold (add x, undef) -> undef - if (N0.isUndef()) + if (N0.getOpcode() == ISD::UNDEF) return N0; - - if (N1.isUndef()) + if (N1.getOpcode() == ISD::UNDEF) return N1; - - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) { - // canonicalize constant to RHS - if (!DAG.isConstantIntBuildVectorOrConstantInt(N1)) - return DAG.getNode(ISD::ADD, DL, VT, N1, N0); - // fold (add c1, c2) -> c1+c2 - return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N0.getNode(), - N1.getNode()); - } - + // fold (add c1, c2) -> c1+c2 + ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); + ConstantSDNode *N1C = getAsNonOpaqueConstant(N1); + if (N0C && N1C) + return DAG.FoldConstantArithmetic(ISD::ADD, SDLoc(N), VT, N0C, N1C); + // canonicalize constant to RHS + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) + return DAG.getNode(ISD::ADD, SDLoc(N), VT, N1, N0); // fold (add x, 0) -> x if (isNullConstant(N1)) return N0; - - if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) { - // fold ((c1-A)+c2) -> (c1+c2)-A - if (N0.getOpcode() == ISD::SUB && - isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) { - // FIXME: Adding 2 constants should be handled by FoldConstantArithmetic. + // fold (add Sym, c) -> Sym+c + if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0)) + if (!LegalOperations && TLI.isOffsetFoldingLegal(GA) && N1C && + GA->getOpcode() == ISD::GlobalAddress) + return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT, + GA->getOffset() + + (uint64_t)N1C->getSExtValue()); + // fold ((c1-A)+c2) -> (c1+c2)-A + if (N1C && N0.getOpcode() == ISD::SUB) + if (ConstantSDNode *N0C = getAsNonOpaqueConstant(N0.getOperand(0))) { + SDLoc DL(N); return DAG.getNode(ISD::SUB, DL, VT, - DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)), + DAG.getConstant(N1C->getAPIntValue()+ + N0C->getAPIntValue(), DL, VT), N0.getOperand(1)); } - - // add (sext i1 X), 1 -> zext (not i1 X) - // We don't transform this pattern: - // add (zext i1 X), -1 -> sext (not i1 X) - // because most (?) targets generate better code for the zext form. - if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() && - isOneOrOneSplat(N1)) { - SDValue X = N0.getOperand(0); - if ((!LegalOperations || - (TLI.isOperationLegal(ISD::XOR, X.getValueType()) && - TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) && - X.getScalarValueSizeInBits() == 1) { - SDValue Not = DAG.getNOT(DL, X, X.getValueType()); - return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not); - } - } - - // Undo the add -> or combine to merge constant offsets from a frame index. 
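// [Editor's sketch; not part of this diff] The "undo" noted just above (and
// implemented right below) is sound because, whenever haveNoCommonBitsSet()
// holds, OR and ADD compute the same value, so the constant can be
// re-associated. A standalone check with hypothetical values (assumes
// <cstdint>/<cassert> are available):
static bool orEqualsAddWhenDisjoint(uint64_t FI, uint64_t C, uint64_t X) {
  assert((FI & C) == 0 && "operands must share no set bits");
  // With no bits in common there are no carries, so (FI | C) == FI + C.
  return ((FI | C) + X) == (FI + (X + C));
}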
- if (N0.getOpcode() == ISD::OR && - isa<FrameIndexSDNode>(N0.getOperand(0)) && - isa<ConstantSDNode>(N0.getOperand(1)) && - DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) { - SDValue Add0 = DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(1)); - return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0); - } - } - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - // reassociate add - if (SDValue RADD = ReassociateOps(ISD::ADD, DL, N0, N1, N->getFlags())) + if (SDValue RADD = ReassociateOps(ISD::ADD, SDLoc(N), N0, N1)) return RADD; - // fold ((0-A) + B) -> B-A - if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0))) - return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1)); - + if (N0.getOpcode() == ISD::SUB && isNullConstant(N0.getOperand(0))) + return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1, N0.getOperand(1)); // fold (A + (0-B)) -> A-B - if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0))) - return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1)); - + if (N1.getOpcode() == ISD::SUB && isNullConstant(N1.getOperand(0))) + return DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, N1.getOperand(1)); // fold (A+(B-A)) -> B if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1)) return N1.getOperand(0); - // fold ((B-A)+A) -> B if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1)) return N0.getOperand(0); - // fold (A+(B-(A+C))) to (B-C) if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD && N0 == N1.getOperand(1).getOperand(0)) - return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0), + return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1.getOperand(0), N1.getOperand(1).getOperand(1)); - // fold (A+(B-(C+A))) to (B-C) if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD && N0 == N1.getOperand(1).getOperand(1)) - return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0), + return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1.getOperand(0), N1.getOperand(1).getOperand(0)); - // fold (A+((B-A)+or-C)) to (B+or-C) if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) && N1.getOperand(0).getOpcode() == ISD::SUB && N0 == N1.getOperand(0).getOperand(1)) - return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0), - N1.getOperand(1)); + return DAG.getNode(N1.getOpcode(), SDLoc(N), VT, + N1.getOperand(0).getOperand(0), N1.getOperand(1)); // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB) { @@ -2142,148 +1724,52 @@ SDValue DAGCombiner::visitADD(SDNode *N) { SDValue N10 = N1.getOperand(0); SDValue N11 = N1.getOperand(1); - if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10)) - return DAG.getNode(ISD::SUB, DL, VT, + if (isa<ConstantSDNode>(N00) || isa<ConstantSDNode>(N10)) + return DAG.getNode(ISD::SUB, SDLoc(N), VT, DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10), DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11)); } - if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG)) - return V; - - if (SDValue V = foldAddSubOfSignBit(N, DAG)) - return V; - - if (SimplifyDemandedBits(SDValue(N, 0))) + if (!VT.isVector() && SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); // fold (a+b) -> (a|b) iff a and b share no bits. 
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) && - DAG.haveNoCommonBitsSet(N0, N1)) - return DAG.getNode(ISD::OR, DL, VT, N0, N1); - - // fold (add (xor a, -1), 1) -> (sub 0, a) - if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) - return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), - N0.getOperand(0)); - - if (SDValue Combined = visitADDLike(N0, N1, N)) - return Combined; - - if (SDValue Combined = visitADDLike(N1, N0, N)) - return Combined; - - return SDValue(); -} - -SDValue DAGCombiner::visitADDSAT(SDNode *N) { - unsigned Opcode = N->getOpcode(); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N0.getValueType(); - SDLoc DL(N); - - // fold vector ops - if (VT.isVector()) { - // TODO SimplifyVBinOp - - // fold (add_sat x, 0) -> x, vector edition - if (ISD::isBuildVectorAllZeros(N1.getNode())) - return N0; - if (ISD::isBuildVectorAllZeros(N0.getNode())) - return N1; - } - - // fold (add_sat x, undef) -> -1 - if (N0.isUndef() || N1.isUndef()) - return DAG.getAllOnesConstant(DL, VT); - - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) { - // canonicalize constant to RHS - if (!DAG.isConstantIntBuildVectorOrConstantInt(N1)) - return DAG.getNode(Opcode, DL, VT, N1, N0); - // fold (add_sat c1, c2) -> c3 - return DAG.FoldConstantArithmetic(Opcode, DL, VT, N0.getNode(), - N1.getNode()); - } - - // fold (add_sat x, 0) -> x - if (isNullConstant(N1)) - return N0; - - // If it cannot overflow, transform into an add. - if (Opcode == ISD::UADDSAT) - if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never) - return DAG.getNode(ISD::ADD, DL, VT, N0, N1); - - return SDValue(); -} - -static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) { - bool Masked = false; - - // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization. - while (true) { - if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) { - V = V.getOperand(0); - continue; - } - - if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) { - Masked = true; - V = V.getOperand(0); - continue; - } - - break; - } - - // If this is not a carry, return. - if (V.getResNo() != 1) - return SDValue(); - - if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY && - V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO) - return SDValue(); - - // If the result is masked, then no matter what kind of bool it is we can - // return. If it isn't, then we need to make sure the bool type is either 0 or - // 1 and not other values. 
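// [Editor's sketch; not part of this diff] The masked case discussed just
// above is safe for either boolean encoding because AND-ing with 1 collapses
// both the 0/1 and the 0/-1 representation of 'true' to a single bit
// (assumes <cstdint>/<cassert>):
static uint32_t carryBitFromBool(uint32_t BoolVal) {
  uint32_t Bit = BoolVal & 1u; // 1u & 1 == 1, ~0u & 1 == 1, 0u & 1 == 0
  assert(Bit <= 1u && "a masked boolean is always 0 or 1");
  return Bit;
}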
- if (Masked || - TLI.getBooleanContents(V.getValueType()) == - TargetLoweringBase::ZeroOrOneBooleanContent) - return V; - - return SDValue(); -} - -SDValue DAGCombiner::visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference) { - EVT VT = N0.getValueType(); - SDLoc DL(LocReference); + VT.isInteger() && !VT.isVector() && DAG.haveNoCommonBitsSet(N0, N1)) + return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1); // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n)) if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB && - isNullOrNullSplat(N1.getOperand(0).getOperand(0))) - return DAG.getNode(ISD::SUB, DL, VT, N0, - DAG.getNode(ISD::SHL, DL, VT, + isNullConstant(N1.getOperand(0).getOperand(0))) + return DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, + DAG.getNode(ISD::SHL, SDLoc(N), VT, N1.getOperand(0).getOperand(1), N1.getOperand(1))); + if (N0.getOpcode() == ISD::SHL && N0.getOperand(0).getOpcode() == ISD::SUB && + isNullConstant(N0.getOperand(0).getOperand(0))) + return DAG.getNode(ISD::SUB, SDLoc(N), VT, N1, + DAG.getNode(ISD::SHL, SDLoc(N), VT, + N0.getOperand(0).getOperand(1), + N0.getOperand(1))); if (N1.getOpcode() == ISD::AND) { SDValue AndOp0 = N1.getOperand(0); unsigned NumSignBits = DAG.ComputeNumSignBits(AndOp0); - unsigned DestBits = VT.getScalarSizeInBits(); + unsigned DestBits = VT.getScalarType().getSizeInBits(); // (add z, (and (sbbl x, x), 1)) -> (sub z, (sbbl x, x)) // and similar xforms where the inner op is either ~0 or 0. - if (NumSignBits == DestBits && isOneOrOneSplat(N1->getOperand(1))) - return DAG.getNode(ISD::SUB, DL, VT, N0, AndOp0); + if (NumSignBits == DestBits && isOneConstant(N1->getOperand(1))) { + SDLoc DL(N); + return DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0), AndOp0); + } } // add (sext i1), X -> sub X, (zext i1) if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.getOperand(0).getValueType() == MVT::i1 && !TLI.isOperationLegal(ISD::SIGN_EXTEND, MVT::i1)) { + SDLoc DL(N); SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)); return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt); } @@ -2292,25 +1778,13 @@ SDValue DAGCombiner::visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference) if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) { VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1)); if (TN->getVT() == MVT::i1) { + SDLoc DL(N); SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0), DAG.getConstant(1, DL, VT)); return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt); } } - // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry) - if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) && - N1.getResNo() == 0) - return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(), - N0, N1.getOperand(0), N1.getOperand(2)); - - // (add X, Carry) -> (addcarry X, 0, Carry) - if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) - if (SDValue Carry = getAsCarry(TLI, N1)) - return DAG.getNode(ISD::ADDCARRY, DL, - DAG.getVTList(VT, Carry.getValueType()), N0, - DAG.getConstant(0, DL, VT), Carry); - return SDValue(); } @@ -2318,131 +1792,40 @@ SDValue DAGCombiner::visitADDC(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); - SDLoc DL(N); // If the flag result is dead, turn this into an ADD. if (!N->hasAnyUseOfValue(1)) - return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), - DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue)); + return CombineTo(N, DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N1), + DAG.getNode(ISD::CARRY_FALSE, + SDLoc(N), MVT::Glue)); // canonicalize constant to RHS. 
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); if (N0C && !N1C) - return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0); + return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N1, N0); // fold (addc x, 0) -> x + no carry out if (isNullConstant(N1)) return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, - DL, MVT::Glue)); + SDLoc(N), MVT::Glue)); - // If it cannot overflow, transform into an add. - if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never) - return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), - DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue)); + // fold (addc a, b) -> (or a, b), CARRY_FALSE iff a and b share no bits. + APInt LHSZero, LHSOne; + APInt RHSZero, RHSOne; + DAG.computeKnownBits(N0, LHSZero, LHSOne); - return SDValue(); -} + if (LHSZero.getBoolValue()) { + DAG.computeKnownBits(N1, RHSZero, RHSOne); -static SDValue flipBoolean(SDValue V, const SDLoc &DL, EVT VT, - SelectionDAG &DAG, const TargetLowering &TLI) { - SDValue Cst; - switch (TLI.getBooleanContents(VT)) { - case TargetLowering::ZeroOrOneBooleanContent: - case TargetLowering::UndefinedBooleanContent: - Cst = DAG.getConstant(1, DL, VT); - break; - case TargetLowering::ZeroOrNegativeOneBooleanContent: - Cst = DAG.getConstant(-1, DL, VT); - break; + // If all possibly-set bits on the LHS are clear on the RHS, return an OR. + // If all possibly-set bits on the RHS are clear on the LHS, return an OR. + if ((RHSZero & ~LHSZero) == ~LHSZero || (LHSZero & ~RHSZero) == ~RHSZero) + return CombineTo(N, DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1), + DAG.getNode(ISD::CARRY_FALSE, + SDLoc(N), MVT::Glue)); } - return DAG.getNode(ISD::XOR, DL, VT, V, Cst); -} - -static bool isBooleanFlip(SDValue V, EVT VT, const TargetLowering &TLI) { - if (V.getOpcode() != ISD::XOR) return false; - ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V.getOperand(1)); - if (!Const) return false; - - switch(TLI.getBooleanContents(VT)) { - case TargetLowering::ZeroOrOneBooleanContent: - return Const->isOne(); - case TargetLowering::ZeroOrNegativeOneBooleanContent: - return Const->isAllOnesValue(); - case TargetLowering::UndefinedBooleanContent: - return (Const->getAPIntValue() & 0x01) == 1; - } - llvm_unreachable("Unsupported boolean content"); -} - -SDValue DAGCombiner::visitUADDO(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N0.getValueType(); - if (VT.isVector()) - return SDValue(); - - EVT CarryVT = N->getValueType(1); - SDLoc DL(N); - - // If the flag result is dead, turn this into an ADD. - if (!N->hasAnyUseOfValue(1)) - return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), - DAG.getUNDEF(CarryVT)); - - // canonicalize constant to RHS. - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); - if (N0C && !N1C) - return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N1, N0); - - // fold (uaddo x, 0) -> x + no carry out - if (isNullConstant(N1)) - return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT)); - - // If it cannot overflow, transform into an add. - if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never) - return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), - DAG.getConstant(0, DL, CarryVT)); - - // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry. 
- if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) { - SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(), - DAG.getConstant(0, DL, VT), - N0.getOperand(0)); - return CombineTo(N, Sub, - flipBoolean(Sub.getValue(1), DL, CarryVT, DAG, TLI)); - } - - if (SDValue Combined = visitUADDOLike(N0, N1, N)) - return Combined; - - if (SDValue Combined = visitUADDOLike(N1, N0, N)) - return Combined; - - return SDValue(); -} - -SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) { - auto VT = N0.getValueType(); - - // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry) - // If Y + 1 cannot overflow. - if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) { - SDValue Y = N1.getOperand(0); - SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType()); - if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never) - return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y, - N1.getOperand(2)); - } - - // (uaddo X, Carry) -> (addcarry X, 0, Carry) - if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) - if (SDValue Carry = getAsCarry(TLI, N1)) - return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, - DAG.getConstant(0, SDLoc(N), VT), Carry); - return SDValue(); } @@ -2465,108 +1848,11 @@ SDValue DAGCombiner::visitADDE(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitADDCARRY(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDValue CarryIn = N->getOperand(2); - SDLoc DL(N); - - // canonicalize constant to RHS - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); - if (N0C && !N1C) - return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn); - - // fold (addcarry x, y, false) -> (uaddo x, y) - if (isNullConstant(CarryIn)) { - if (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0))) - return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1); - } - - EVT CarryVT = CarryIn.getValueType(); - - // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry. - if (isNullConstant(N0) && isNullConstant(N1)) { - EVT VT = N0.getValueType(); - SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT); - AddToWorklist(CarryExt.getNode()); - return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt, - DAG.getConstant(1, DL, VT)), - DAG.getConstant(0, DL, CarryVT)); - } - - // fold (addcarry (xor a, -1), 0, !b) -> (subcarry 0, a, b) and flip carry. 
- if (isBitwiseNot(N0) && isNullConstant(N1) &&
- isBooleanFlip(CarryIn, CarryVT, TLI)) {
- SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(),
- DAG.getConstant(0, DL, N0.getValueType()),
- N0.getOperand(0), CarryIn.getOperand(0));
- return CombineTo(N, Sub,
- flipBoolean(Sub.getValue(1), DL, CarryVT, DAG, TLI));
- }
-
- if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
- return Combined;
-
- if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
- return Combined;
-
- return SDValue();
-}
-
-SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
- SDNode *N) {
- // Iff the flag result is dead:
- // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
- if ((N0.getOpcode() == ISD::ADD ||
- (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0)) &&
- isNullConstant(N1) && !N->hasAnyUseOfValue(1))
- return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
- N0.getOperand(0), N0.getOperand(1), CarryIn);
-
- /**
- * When one of the addcarry arguments is itself a carry, we may be facing
- * a diamond carry propagation, in which case we try to transform the DAG
- * to ensure linear carry propagation if that is possible.
- *
- * We are trying to get:
- * (addcarry X, 0, (addcarry A, B, Z):Carry)
- */
- if (auto Y = getAsCarry(TLI, N1)) {
- /**
- * (uaddo A, B)
- * / \
- * Carry Sum
- * | \
- * | (addcarry *, 0, Z)
- * | /
- * \ Carry
- * | /
- * (addcarry X, *, *)
- */
- if (Y.getOpcode() == ISD::UADDO &&
- CarryIn.getResNo() == 1 &&
- CarryIn.getOpcode() == ISD::ADDCARRY &&
- isNullConstant(CarryIn.getOperand(1)) &&
- CarryIn.getOperand(0) == Y.getValue(0)) {
- auto NewY = DAG.getNode(ISD::ADDCARRY, SDLoc(N), Y->getVTList(),
- Y.getOperand(0), Y.getOperand(1),
- CarryIn.getOperand(2));
- AddToWorklist(NewY.getNode());
- return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
- DAG.getConstant(0, SDLoc(N), N0.getValueType()),
- NewY.getValue(1));
- }
- }
-
- return SDValue();
-}
-
 // Since it may not be valid to emit a fold to zero for vector initializers,
 // check if we can before folding.
-static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
- SelectionDAG &DAG, bool LegalOperations) {
+static SDValue tryFoldToZero(SDLoc DL, const TargetLowering &TLI, EVT VT,
+ SelectionDAG &DAG,
+ bool LegalOperations, bool LegalTypes) {
 if (!VT.isVector())
 return DAG.getConstant(0, DL, VT);
 if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
@@ -2578,7 +1864,6 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
 SDValue N0 = N->getOperand(0);
 SDValue N1 = N->getOperand(1);
 EVT VT = N0.getValueType();
- SDLoc DL(N);

 // fold vector ops
 if (VT.isVector()) {
@@ -2593,155 +1878,66 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
 // fold (sub x, x) -> 0
 // FIXME: Refactor this and xor and other similar operations together.
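// [Editor's sketch; not part of this diff] For the shift-flip fold a few
// lines below: extracting the sign bit with a logical shift and then negating
// is the same as one arithmetic shift. Hypothetical values; assumes
// <cstdint>/<cassert> and that >> on a negative int is an arithmetic shift,
// as on LLVM's supported host compilers:
static int32_t negatedSignBit(int32_t X) {
  int32_t Logical = (int32_t)((uint32_t)X >> 31); // X >>u 31: 0 or 1
  int32_t Arith = X >> 31;                        // X >>s 31: 0 or -1
  assert(-Logical == Arith && "-(X >>u 31) == (X >>s 31)");
  return Arith;
}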
if (N0 == N1) - return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - DAG.isConstantIntBuildVectorOrConstantInt(N1)) { - // fold (sub c1, c2) -> c1-c2 - return DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(), - N1.getNode()); - } - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - + return tryFoldToZero(SDLoc(N), TLI, VT, DAG, LegalOperations, LegalTypes); + // fold (sub c1, c2) -> c1-c2 + ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); ConstantSDNode *N1C = getAsNonOpaqueConstant(N1); - + if (N0C && N1C) + return DAG.FoldConstantArithmetic(ISD::SUB, SDLoc(N), VT, N0C, N1C); // fold (sub x, c) -> (add x, -c) if (N1C) { + SDLoc DL(N); return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getConstant(-N1C->getAPIntValue(), DL, VT)); } - - if (isNullOrNullSplat(N0)) { - unsigned BitWidth = VT.getScalarSizeInBits(); - // Right-shifting everything out but the sign bit followed by negation is - // the same as flipping arithmetic/logical shift type without the negation: - // -(X >>u 31) -> (X >>s 31) - // -(X >>s 31) -> (X >>u 31) - if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) { - ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1)); - if (ShiftAmt && ShiftAmt->getZExtValue() == BitWidth - 1) { - auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA; - if (!LegalOperations || TLI.isOperationLegal(NewSh, VT)) - return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1)); - } - } - - // 0 - X --> 0 if the sub is NUW. - if (N->getFlags().hasNoUnsignedWrap()) - return N0; - - if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) { - // N1 is either 0 or the minimum signed value. If the sub is NSW, then - // N1 must be 0 because negating the minimum signed value is undefined. - if (N->getFlags().hasNoSignedWrap()) - return N0; - - // 0 - X --> X if X is 0 or the minimum signed value. - return N1; - } - } - // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) - if (isAllOnesOrAllOnesSplat(N0)) - return DAG.getNode(ISD::XOR, DL, VT, N1, N0); - - // fold (A - (0-B)) -> A+B - if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0))) - return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1)); - + if (isAllOnesConstant(N0)) + return DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0); // fold A-(A-B) -> B if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0)) return N1.getOperand(1); - // fold (A+B)-A -> B if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1) return N0.getOperand(1); - // fold (A+B)-B -> A if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1) return N0.getOperand(0); - // fold C2-(A+C1) -> (C2-C1)-A - if (N1.getOpcode() == ISD::ADD) { - SDValue N11 = N1.getOperand(1); - if (isConstantOrConstantVector(N0, /* NoOpaques */ true) && - isConstantOrConstantVector(N11, /* NoOpaques */ true)) { - SDValue NewC = DAG.getNode(ISD::SUB, DL, VT, N0, N11); - return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0)); - } + ConstantSDNode *N1C1 = N1.getOpcode() != ISD::ADD ? 
nullptr : + dyn_cast<ConstantSDNode>(N1.getOperand(1).getNode()); + if (N1.getOpcode() == ISD::ADD && N0C && N1C1) { + SDLoc DL(N); + SDValue NewC = DAG.getConstant(N0C->getAPIntValue() - N1C1->getAPIntValue(), + DL, VT); + return DAG.getNode(ISD::SUB, DL, VT, NewC, + N1.getOperand(0)); } - // fold ((A+(B+or-C))-B) -> A+or-C if (N0.getOpcode() == ISD::ADD && (N0.getOperand(1).getOpcode() == ISD::SUB || N0.getOperand(1).getOpcode() == ISD::ADD) && N0.getOperand(1).getOperand(0) == N1) - return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0), - N0.getOperand(1).getOperand(1)); - + return DAG.getNode(N0.getOperand(1).getOpcode(), SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1).getOperand(1)); // fold ((A+(C+B))-B) -> A+C - if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD && + if (N0.getOpcode() == ISD::ADD && + N0.getOperand(1).getOpcode() == ISD::ADD && N0.getOperand(1).getOperand(1) == N1) - return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), - N0.getOperand(1).getOperand(0)); - + return DAG.getNode(ISD::ADD, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1).getOperand(0)); // fold ((A-(B-C))-C) -> A-B - if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB && + if (N0.getOpcode() == ISD::SUB && + N0.getOperand(1).getOpcode() == ISD::SUB && N0.getOperand(1).getOperand(1) == N1) - return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), - N0.getOperand(1).getOperand(0)); - - // fold (A-(B-C)) -> A+(C-B) - if (N1.getOpcode() == ISD::SUB && N1.hasOneUse()) - return DAG.getNode(ISD::ADD, DL, VT, N0, - DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1), - N1.getOperand(0))); - - // fold (X - (-Y * Z)) -> (X + (Y * Z)) - if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) { - if (N1.getOperand(0).getOpcode() == ISD::SUB && - isNullOrNullSplat(N1.getOperand(0).getOperand(0))) { - SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, - N1.getOperand(0).getOperand(1), - N1.getOperand(1)); - return DAG.getNode(ISD::ADD, DL, VT, N0, Mul); - } - if (N1.getOperand(1).getOpcode() == ISD::SUB && - isNullOrNullSplat(N1.getOperand(1).getOperand(0))) { - SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, - N1.getOperand(0), - N1.getOperand(1).getOperand(1)); - return DAG.getNode(ISD::ADD, DL, VT, N0, Mul); - } - } + return DAG.getNode(ISD::SUB, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1).getOperand(0)); // If either operand of a sub is undef, the result is undef - if (N0.isUndef()) + if (N0.getOpcode() == ISD::UNDEF) return N0; - if (N1.isUndef()) + if (N1.getOpcode() == ISD::UNDEF) return N1; - if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG)) - return V; - - if (SDValue V = foldAddSubOfSignBit(N, DAG)) - return V; - - // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X) - if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) { - if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) { - SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1); - SDValue S0 = N1.getOperand(0); - if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0)) { - unsigned OpSizeInBits = VT.getScalarSizeInBits(); - if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1))) - if (C->getAPIntValue() == (OpSizeInBits - 1)) - return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0); - } - } - } - // If the relocation model supports it, consider symbol offsets. 
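// [Editor's sketch; not part of this diff] The global-address folds below
// work because two addresses formed from the same symbol differ only by
// their constant offsets, wherever the symbol ends up; a check with a
// hypothetical base address (assumes <cstdint>):
static int64_t symOffsetDiff(uint64_t SymBase) {
  uint64_t A = SymBase + 8; // Sym+c1
  uint64_t B = SymBase + 4; // Sym+c2
  return (int64_t)(A - B);  // always c1 - c2 == 4, independent of SymBase
}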
if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0)) if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) { @@ -2749,72 +1945,25 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { if (N1C && GA->getOpcode() == ISD::GlobalAddress) return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT, GA->getOffset() - - (uint64_t)N1C->getSExtValue()); + (uint64_t)N1C->getSExtValue()); // fold (sub Sym+c1, Sym+c2) -> c1-c2 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1)) if (GA->getGlobal() == GB->getGlobal()) return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(), - DL, VT); + SDLoc(N), VT); } // sub X, (sextinreg Y i1) -> add X, (and Y 1) if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) { VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1)); if (TN->getVT() == MVT::i1) { + SDLoc DL(N); SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0), DAG.getConstant(1, DL, VT)); return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt); } } - // Prefer an add for more folding potential and possibly better codegen: - // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1) - if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) { - SDValue ShAmt = N1.getOperand(1); - ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt); - if (ShAmtC && ShAmtC->getZExtValue() == N1.getScalarValueSizeInBits() - 1) { - SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt); - return DAG.getNode(ISD::ADD, DL, VT, N0, SRA); - } - } - - return SDValue(); -} - -SDValue DAGCombiner::visitSUBSAT(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N0.getValueType(); - SDLoc DL(N); - - // fold vector ops - if (VT.isVector()) { - // TODO SimplifyVBinOp - - // fold (sub_sat x, 0) -> x, vector edition - if (ISD::isBuildVectorAllZeros(N1.getNode())) - return N0; - } - - // fold (sub_sat x, undef) -> 0 - if (N0.isUndef() || N1.isUndef()) - return DAG.getConstant(0, DL, VT); - - // fold (sub_sat x, x) -> 0 - if (N0 == N1) - return DAG.getConstant(0, DL, VT); - - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - DAG.isConstantIntBuildVectorOrConstantInt(N1)) { - // fold (sub_sat c1, c2) -> c3 - return DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, N0.getNode(), - N1.getNode()); - } - - // fold (sub_sat x, 0) -> x - if (isNullConstant(N1)) - return N0; - return SDValue(); } @@ -2846,38 +1995,6 @@ SDValue DAGCombiner::visitSUBC(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitUSUBO(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N0.getValueType(); - if (VT.isVector()) - return SDValue(); - - EVT CarryVT = N->getValueType(1); - SDLoc DL(N); - - // If the flag result is dead, turn this into an SUB. - if (!N->hasAnyUseOfValue(1)) - return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1), - DAG.getUNDEF(CarryVT)); - - // fold (usubo x, x) -> 0 + no borrow - if (N0 == N1) - return CombineTo(N, DAG.getConstant(0, DL, VT), - DAG.getConstant(0, DL, CarryVT)); - - // fold (usubo x, 0) -> x + no borrow - if (isNullConstant(N1)) - return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT)); - - // Canonicalize (usubo -1, x) -> ~x, i.e. 
(xor x, -1) + no borrow - if (isAllOnesConstant(N0)) - return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0), - DAG.getConstant(0, DL, CarryVT)); - - return SDValue(); -} - SDValue DAGCombiner::visitSUBE(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -2890,28 +2007,13 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitSUBCARRY(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDValue CarryIn = N->getOperand(2); - - // fold (subcarry x, y, false) -> (usubo x, y) - if (isNullConstant(CarryIn)) { - if (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0))) - return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1); - } - - return SDValue(); -} - SDValue DAGCombiner::visitMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); // fold (mul x, undef) -> 0 - if (N0.isUndef() || N1.isUndef()) + if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF) return DAG.getConstant(0, SDLoc(N), VT); bool N0IsConst = false; @@ -2924,14 +2026,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; - N0IsConst = ISD::isConstantSplatVector(N0.getNode(), ConstValue0); - N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1); - assert((!N0IsConst || - ConstValue0.getBitWidth() == VT.getScalarSizeInBits()) && - "Splat APInt should be element width"); - assert((!N1IsConst || - ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) && - "Splat APInt should be element width"); + N0IsConst = isConstantSplatVector(N0.getNode(), ConstValue0); + N1IsConst = isConstantSplatVector(N1.getNode(), ConstValue1); } else { N0IsConst = isa<ConstantSDNode>(N0); if (N0IsConst) { @@ -2951,19 +2047,19 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { N0.getNode(), N1.getNode()); // canonicalize constant to RHS (vector doesn't have to splat) - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0); // fold (mul x, 0) -> 0 - if (N1IsConst && ConstValue1.isNullValue()) + if (N1IsConst && ConstValue1 == 0) return N1; + // We require a splat of the entire scalar bit width for non-contiguous + // bit patterns. 
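// [Editor's sketch; not part of this diff] The shift rewrites a few lines
// below rest on the identity x * 2^c == x << c in modular arithmetic
// (assumes <cstdint>/<cassert>):
static uint32_t mulByPow2(uint32_t X) {
  uint32_t ViaMul = X * 8u; // mul x, (1 << 3)
  uint32_t ViaShl = X << 3; // shl x, 3
  assert(ViaMul == ViaShl && "x * 2^c == x << c (mod 2^32)");
  return ViaShl;
}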
+ bool IsFullSplat = + ConstValue1.getBitWidth() == VT.getScalarType().getSizeInBits(); // fold (mul x, 1) -> x - if (N1IsConst && ConstValue1.isOneValue()) + if (N1IsConst && ConstValue1 == 1 && IsFullSplat) return N0; - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - // fold (mul x, -1) -> 0-x if (N1IsConst && ConstValue1.isAllOnesValue()) { SDLoc DL(N); @@ -2971,17 +2067,16 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { DAG.getConstant(0, DL, VT), N0); } // fold (mul x, (1 << c)) -> x << c - if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N1) && - (!VT.isVector() || Level <= AfterLegalizeVectorOps)) { + if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isPowerOf2() && + IsFullSplat) { SDLoc DL(N); - SDValue LogBase2 = BuildLogBase2(N1, DL); - EVT ShiftVT = getShiftAmountTy(N0.getValueType()); - SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); - return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc); + return DAG.getNode(ISD::SHL, DL, VT, N0, + DAG.getConstant(ConstValue1.logBase2(), DL, + getShiftAmountTy(N0.getValueType()))); } // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c - if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) { + if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2() && + IsFullSplat) { unsigned Log2Val = (-ConstValue1).logBase2(); SDLoc DL(N); // FIXME: If the input is something that is easily negated (e.g. a @@ -2993,74 +2088,46 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { getShiftAmountTy(N0.getValueType())))); } - // Try to transform multiply-by-(power-of-2 +/- 1) into shift and add/sub. - // mul x, (2^N + 1) --> add (shl x, N), x - // mul x, (2^N - 1) --> sub (shl x, N), x - // Examples: x * 33 --> (x << 5) + x - // x * 15 --> (x << 4) - x - // x * -33 --> -((x << 5) + x) - // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4) - if (N1IsConst && TLI.decomposeMulByConstant(VT, N1)) { - // TODO: We could handle more general decomposition of any constant by - // having the target set a limit on number of ops and making a - // callback to determine that sequence (similar to sqrt expansion). - unsigned MathOp = ISD::DELETED_NODE; - APInt MulC = ConstValue1.abs(); - if ((MulC - 1).isPowerOf2()) - MathOp = ISD::ADD; - else if ((MulC + 1).isPowerOf2()) - MathOp = ISD::SUB; - - if (MathOp != ISD::DELETED_NODE) { - unsigned ShAmt = MathOp == ISD::ADD ? 
(MulC - 1).logBase2() - : (MulC + 1).logBase2(); - assert(ShAmt > 0 && ShAmt < VT.getScalarSizeInBits() && - "Not expecting multiply-by-constant that could have simplified"); - SDLoc DL(N); - SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, N0, - DAG.getConstant(ShAmt, DL, VT)); - SDValue R = DAG.getNode(MathOp, DL, VT, Shl, N0); - if (ConstValue1.isNegative()) - R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R); - return R; - } - } - + APInt Val; // (mul (shl X, c1), c2) -> (mul X, c2 << c1) - if (N0.getOpcode() == ISD::SHL && - isConstantOrConstantVector(N1, /* NoOpaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) { - SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1)); - if (isConstantOrConstantVector(C3)) - return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3); + if (N1IsConst && N0.getOpcode() == ISD::SHL && + (isConstantSplatVector(N0.getOperand(1).getNode(), Val) || + isa<ConstantSDNode>(N0.getOperand(1)))) { + SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, + N1, N0.getOperand(1)); + AddToWorklist(C3.getNode()); + return DAG.getNode(ISD::MUL, SDLoc(N), VT, + N0.getOperand(0), C3); } // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one // use. { - SDValue Sh(nullptr, 0), Y(nullptr, 0); - + SDValue Sh(nullptr,0), Y(nullptr,0); // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)). if (N0.getOpcode() == ISD::SHL && - isConstantOrConstantVector(N0.getOperand(1)) && + (isConstantSplatVector(N0.getOperand(1).getNode(), Val) || + isa<ConstantSDNode>(N0.getOperand(1))) && N0.getNode()->hasOneUse()) { Sh = N0; Y = N1; } else if (N1.getOpcode() == ISD::SHL && - isConstantOrConstantVector(N1.getOperand(1)) && + isa<ConstantSDNode>(N1.getOperand(1)) && N1.getNode()->hasOneUse()) { Sh = N1; Y = N0; } if (Sh.getNode()) { - SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, Sh.getOperand(0), Y); - return DAG.getNode(ISD::SHL, SDLoc(N), VT, Mul, Sh.getOperand(1)); + SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, + Sh.getOperand(0), Y); + return DAG.getNode(ISD::SHL, SDLoc(N), VT, + Mul, Sh.getOperand(1)); } } // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2) - if (DAG.isConstantIntBuildVectorOrConstantInt(N1) && + if (isConstantIntBuildVectorOrConstantInt(N1) && N0.getOpcode() == ISD::ADD && - DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) && + isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) && isMulAddWithConstProfitable(N, N0, N1)) return DAG.getNode(ISD::ADD, SDLoc(N), VT, DAG.getNode(ISD::MUL, SDLoc(N0), VT, @@ -3069,7 +2136,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { N0.getOperand(1), N1)); // reassociate mul - if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags())) + if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1)) return RMUL; return SDValue(); @@ -3079,10 +2146,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned, const TargetLowering &TLI) { RTLIB::Libcall LC; - EVT NodeType = Node->getValueType(0); - if (!NodeType.isSimple()) - return false; - switch (NodeType.getSimpleVT().SimpleTy) { + switch (Node->getSimpleValueType(0).SimpleTy) { default: return false; // No libcall for vector types. case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break; case MVT::i16: LC= isSigned ? 
RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break; @@ -3099,18 +2163,14 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) { if (Node->use_empty()) return SDValue(); // This is a dead node, leave it alone. - unsigned Opcode = Node->getOpcode(); - bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM); - unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM; - - // DivMod lib calls can still work on non-legal types if using lib-calls. EVT VT = Node->getValueType(0); - if (VT.isVector() || !VT.isInteger()) + if (!TLI.isTypeLegal(VT)) return SDValue(); - if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT)) - return SDValue(); + unsigned Opcode = Node->getOpcode(); + bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM); + unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM; // If DIVREM is going to get expanded into a libcall, // but there is no libcall available, then don't combine. if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) && @@ -3135,8 +2195,7 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) { for (SDNode::use_iterator UI = Op0.getNode()->use_begin(), UE = Op0.getNode()->use_end(); UI != UE; ++UI) { SDNode *User = *UI; - if (User == Node || User->getOpcode() == ISD::DELETED_NODE || - User->use_empty()) + if (User == Node || User->use_empty()) continue; // Convert the other matching node(s), too; // otherwise, the DIVREM may get target-legalized into something @@ -3165,57 +2224,10 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) { return combined; } -static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N->getValueType(0); - SDLoc DL(N); - - unsigned Opc = N->getOpcode(); - bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc); - ConstantSDNode *N1C = isConstOrConstSplat(N1); - - // X / undef -> undef - // X % undef -> undef - // X / 0 -> undef - // X % 0 -> undef - // NOTE: This includes vectors where any divisor element is zero/undef. - if (DAG.isUndef(Opc, {N0, N1})) - return DAG.getUNDEF(VT); - - // undef / X -> 0 - // undef % X -> 0 - if (N0.isUndef()) - return DAG.getConstant(0, DL, VT); - - // 0 / X -> 0 - // 0 % X -> 0 - ConstantSDNode *N0C = isConstOrConstSplat(N0); - if (N0C && N0C->isNullValue()) - return N0; - - // X / X -> 1 - // X % X -> 0 - if (N0 == N1) - return DAG.getConstant(IsDiv ? 1 : 0, DL, VT); - - // X / 1 -> X - // X % 1 -> 0 - // If this is a boolean op (single-bit element type), we can't have - // division-by-zero or remainder-by-zero, so assume the divisor is 1. - // TODO: Similarly, if we're zero-extending a boolean divisor, then assume - // it's a 1. - if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1)) - return IsDiv ? 
N0 : DAG.getConstant(0, DL, VT); - - return SDValue(); -} - SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); - EVT CCVT = getSetCCResultType(VT); // fold vector ops if (VT.isVector()) @@ -3229,129 +2241,85 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { ConstantSDNode *N1C = isConstOrConstSplat(N1); if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C); + // fold (sdiv X, 1) -> X + if (N1C && N1C->isOne()) + return N0; // fold (sdiv X, -1) -> 0-X if (N1C && N1C->isAllOnesValue()) - return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); - // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0) - if (N1C && N1C->getAPIntValue().isMinSignedValue()) - return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), - DAG.getConstant(1, DL, VT), - DAG.getConstant(0, DL, VT)); - - if (SDValue V = simplifyDivRem(N, DAG)) - return V; - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; + return DAG.getNode(ISD::SUB, DL, VT, + DAG.getConstant(0, DL, VT), N0); // If we know the sign bits of both operands are zero, strength reduce to a // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2 - if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) - return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1); - - if (SDValue V = visitSDIVLike(N0, N1, N)) { - // If the corresponding remainder node exists, update its users with - // (Dividend - (Quotient * Divisor). - if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(), - { N0, N1 })) { - SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1); - SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); - AddToWorklist(Mul.getNode()); - AddToWorklist(Sub.getNode()); - CombineTo(RemNode, Sub); - } - return V; + if (!VT.isVector()) { + if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) + return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1); } - // sdiv, srem -> sdivrem - // If the divisor is constant, then return DIVREM only if isIntDivCheap() is - // true. Otherwise, we break the simplification logic in visitREM(). - AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); - if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) - if (SDValue DivRem = useDivRem(N)) - return DivRem; - - return SDValue(); -} - -SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { - SDLoc DL(N); - EVT VT = N->getValueType(0); - EVT CCVT = getSetCCResultType(VT); - unsigned BitWidth = VT.getScalarSizeInBits(); - - // Helper for determining whether a value is a power-2 constant scalar or a - // vector of such elements. - auto IsPowerOfTwo = [](ConstantSDNode *C) { - if (C->isNullValue() || C->isOpaque()) - return false; - if (C->getAPIntValue().isPowerOf2()) - return true; - if ((-C->getAPIntValue()).isPowerOf2()) - return true; - return false; - }; - // fold (sdiv X, pow2) -> simple ops after legalize // FIXME: We check for the exact bit here because the generic lowering gives // better results in that case. The target-specific lowering should learn how // to handle exact sdivs efficiently. - if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) { + if (N1C && !N1C->isNullValue() && !N1C->isOpaque() && + !cast<BinaryWithFlagsSDNode>(N)->Flags.hasExact() && + (N1C->getAPIntValue().isPowerOf2() || + (-N1C->getAPIntValue()).isPowerOf2())) { // Target-specific implementation of sdiv x, pow2. 
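// [Editor's sketch; not part of this diff] The generic expansion below biases
// negative dividends before the arithmetic shift so the result truncates
// toward zero like sdiv; shown here for lg2 == 3 (assumes <cstdint>, and that
// >> on a negative int is an arithmetic shift, as on LLVM's host compilers):
static int32_t sdivByEight(int32_t X) {
  int32_t Sgn = X >> 31;                         // splat the sign bit: 0 or -1
  int32_t Bias = (int32_t)((uint32_t)Sgn >> 29); // abs2 - 1 == 7 when X < 0
  return (X + Bias) >> 3;                        // matches X / 8 for all X
}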
if (SDValue Res = BuildSDIVPow2(N)) return Res; - // Create constants that are functions of the shift amount value. - EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType()); - SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy); - SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1); - C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy); - SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1); - if (!isConstantOrConstantVector(Inexact)) - return SDValue(); + unsigned lg2 = N1C->getAPIntValue().countTrailingZeros(); // Splat the sign bit into the register - SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0, - DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy)); - AddToWorklist(Sign.getNode()); + SDValue SGN = + DAG.getNode(ISD::SRA, DL, VT, N0, + DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, + getShiftAmountTy(N0.getValueType()))); + AddToWorklist(SGN.getNode()); // Add (N0 < 0) ? abs2 - 1 : 0; - SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact); - AddToWorklist(Srl.getNode()); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl); - AddToWorklist(Add.getNode()); - SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1); - AddToWorklist(Sra.getNode()); - - // Special case: (sdiv X, 1) -> X - // Special Case: (sdiv X, -1) -> 0-X - SDValue One = DAG.getConstant(1, DL, VT); - SDValue AllOnes = DAG.getAllOnesConstant(DL, VT); - SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ); - SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ); - SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes); - Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra); - - // If dividing by a positive value, we're done. Otherwise, the result must - // be negated. - SDValue Zero = DAG.getConstant(0, DL, VT); - SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra); - - // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding. - SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT); - SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra); - return Res; + SDValue SRL = + DAG.getNode(ISD::SRL, DL, VT, SGN, + DAG.getConstant(VT.getScalarSizeInBits() - lg2, DL, + getShiftAmountTy(SGN.getValueType()))); + SDValue ADD = DAG.getNode(ISD::ADD, DL, VT, N0, SRL); + AddToWorklist(SRL.getNode()); + AddToWorklist(ADD.getNode()); // Divide by pow2 + SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, ADD, + DAG.getConstant(lg2, DL, + getShiftAmountTy(ADD.getValueType()))); + + // If we're dividing by a positive value, we're done. Otherwise, we must + // negate the result. + if (N1C->getAPIntValue().isNonNegative()) + return SRA; + + AddToWorklist(SRA.getNode()); + return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA); } // If integer divide is expensive and we satisfy the requirements, emit an // alternate sequence. Targets may check function attributes for size/speed // trade-offs. - AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); - if (isConstantOrConstantVector(N1) && - !TLI.isIntDivCheap(N->getValueType(0), Attr)) + AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes(); + if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr)) if (SDValue Op = BuildSDIV(N)) return Op; + // sdiv, srem -> sdivrem + // If the divisor is constant, then return DIVREM only if isIntDivCheap() is true. + // Otherwise, we break the simplification logic in visitREM(). 
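// [Editor's sketch; not part of this diff] The divrem pairing pursued here,
// and the X % C == X - (X / C) * C rewrite in visitREM, rely on the division
// identity below (assumes <cassert>; B must be nonzero and A / B must not
// overflow):
static int remFromDiv(int A, int B) {
  int Q = A / B;     // the quotient that both users can share
  int R = A - Q * B; // recovers the remainder without a second division
  assert(R == A % B && "a == (a/b)*b + a%b");
  return R;
}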
+ if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (SDValue DivRem = useDivRem(N)) + return DivRem; + + // undef / X -> 0 + if (N0.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, DL, VT); + // X / undef -> undef + if (N1.getOpcode() == ISD::UNDEF) + return N1; + return SDValue(); } @@ -3359,7 +2327,6 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); - EVT CCVT = getSetCCResultType(VT); // fold vector ops if (VT.isVector()) @@ -3375,83 +2342,48 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, N0C, N1C)) return Folded; - // fold (udiv X, -1) -> select(X == -1, 1, 0) - if (N1C && N1C->getAPIntValue().isAllOnesValue()) - return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), - DAG.getConstant(1, DL, VT), - DAG.getConstant(0, DL, VT)); - - if (SDValue V = simplifyDivRem(N, DAG)) - return V; - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - - if (SDValue V = visitUDIVLike(N0, N1, N)) { - // If the corresponding remainder node exists, update its users with - // (Dividend - (Quotient * Divisor). - if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(), - { N0, N1 })) { - SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1); - SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); - AddToWorklist(Mul.getNode()); - AddToWorklist(Sub.getNode()); - CombineTo(RemNode, Sub); - } - return V; - } - - // sdiv, srem -> sdivrem - // If the divisor is constant, then return DIVREM only if isIntDivCheap() is - // true. Otherwise, we break the simplification logic in visitREM(). - AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); - if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) - if (SDValue DivRem = useDivRem(N)) - return DivRem; - - return SDValue(); -} - -SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) { - SDLoc DL(N); - EVT VT = N->getValueType(0); - // fold (udiv x, (1 << c)) -> x >>u c - if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N1)) { - SDValue LogBase2 = BuildLogBase2(N1, DL); - AddToWorklist(LogBase2.getNode()); - - EVT ShiftVT = getShiftAmountTy(N0.getValueType()); - SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); - AddToWorklist(Trunc.getNode()); - return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc); - } + if (N1C && !N1C->isOpaque() && N1C->getAPIntValue().isPowerOf2()) + return DAG.getNode(ISD::SRL, DL, VT, N0, + DAG.getConstant(N1C->getAPIntValue().logBase2(), DL, + getShiftAmountTy(N0.getValueType()))); // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2 if (N1.getOpcode() == ISD::SHL) { - SDValue N10 = N1.getOperand(0); - if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N10)) { - SDValue LogBase2 = BuildLogBase2(N10, DL); - AddToWorklist(LogBase2.getNode()); - - EVT ADDVT = N1.getOperand(1).getValueType(); - SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT); - AddToWorklist(Trunc.getNode()); - SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc); - AddToWorklist(Add.getNode()); - return DAG.getNode(ISD::SRL, DL, VT, N0, Add); + if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) { + if (SHC->getAPIntValue().isPowerOf2()) { + EVT ADDVT = N1.getOperand(1).getValueType(); + SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, + N1.getOperand(1), + DAG.getConstant(SHC->getAPIntValue() + 
.logBase2(), + DL, ADDVT)); + AddToWorklist(Add.getNode()); + return DAG.getNode(ISD::SRL, DL, VT, N0, Add); + } } } // fold (udiv x, c) -> alternate - AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); - if (isConstantOrConstantVector(N1) && - !TLI.isIntDivCheap(N->getValueType(0), Attr)) + AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes(); + if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr)) if (SDValue Op = BuildUDIV(N)) return Op; + // sdiv, srem -> sdivrem + // If the divisor is constant, then return DIVREM only if isIntDivCheap() is true. + // Otherwise, we break the simplification logic in visitREM(). + if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (SDValue DivRem = useDivRem(N)) + return DivRem; + + // undef / X -> 0 + if (N0.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, DL, VT); + // X / undef -> undef + if (N1.getOpcode() == ISD::UNDEF) + return N1; + return SDValue(); } @@ -3461,8 +2393,6 @@ SDValue DAGCombiner::visitREM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); - EVT CCVT = getSetCCResultType(VT); - bool isSigned = (Opcode == ISD::SREM); SDLoc DL(N); @@ -3472,60 +2402,56 @@ SDValue DAGCombiner::visitREM(SDNode *N) { if (N0C && N1C) if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C)) return Folded; - // fold (urem X, -1) -> select(X == -1, 0, x) - if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue()) - return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), - DAG.getConstant(0, DL, VT), N0); - - if (SDValue V = simplifyDivRem(N, DAG)) - return V; - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; if (isSigned) { // If we know the sign bits of both operands are zero, strength reduce to a // urem instead. 
Handles (X & 0x0FFFFFFF) %s 16 -> X&15 - if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) - return DAG.getNode(ISD::UREM, DL, VT, N0, N1); - } else { - SDValue NegOne = DAG.getAllOnesConstant(DL, VT); - if (DAG.isKnownToBeAPowerOfTwo(N1)) { - // fold (urem x, pow2) -> (and x, pow2-1) - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne); - AddToWorklist(Add.getNode()); - return DAG.getNode(ISD::AND, DL, VT, N0, Add); + if (!VT.isVector()) { + if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) + return DAG.getNode(ISD::UREM, DL, VT, N0, N1); } - if (N1.getOpcode() == ISD::SHL && - DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) { - // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1)) - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne); - AddToWorklist(Add.getNode()); - return DAG.getNode(ISD::AND, DL, VT, N0, Add); + } else { + // fold (urem x, pow2) -> (and x, pow2-1) + if (N1C && !N1C->isNullValue() && !N1C->isOpaque() && + N1C->getAPIntValue().isPowerOf2()) { + return DAG.getNode(ISD::AND, DL, VT, N0, + DAG.getConstant(N1C->getAPIntValue() - 1, DL, VT)); + } + // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1)) + if (N1.getOpcode() == ISD::SHL) { + if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) { + if (SHC->getAPIntValue().isPowerOf2()) { + SDValue Add = + DAG.getNode(ISD::ADD, DL, VT, N1, + DAG.getConstant(APInt::getAllOnesValue(VT.getSizeInBits()), DL, + VT)); + AddToWorklist(Add.getNode()); + return DAG.getNode(ISD::AND, DL, VT, N0, Add); + } + } } } - AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); + AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes(); // If X/C can be simplified by the division-by-constant logic, lower // X%C to the equivalent of X-X/C*C. - // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the - // speculative DIV must not cause a DIVREM conversion. We guard against this - // by skipping the simplification if isIntDivCheap(). When div is not cheap, - // combine will not return a DIVREM. Regardless, checking cheapness here - // makes sense since the simplification results in fatter code. - if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) { - SDValue OptimizedDiv = - isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N); - if (OptimizedDiv.getNode()) { - // If the equivalent Div node also exists, update its users. - unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV; - if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(), - { N0, N1 })) - CombineTo(DivNode, OptimizedDiv); + // To avoid mangling nodes, this simplification requires that the combine() + // call for the speculative DIV must not cause a DIVREM conversion. We guard + // against this by skipping the simplification if isIntDivCheap(). When + // div is not cheap, combine will not return a DIVREM. Regardless, + // checking cheapness here makes sense since the simplification results in + // fatter code. + if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap(VT, Attr)) { + unsigned DivOpcode = isSigned ? 
ISD::SDIV : ISD::UDIV; + SDValue Div = DAG.getNode(DivOpcode, DL, VT, N0, N1); + AddToWorklist(Div.getNode()); + SDValue OptimizedDiv = combine(Div.getNode()); + if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != Div.getNode()) { + assert((OptimizedDiv.getOpcode() != ISD::UDIVREM) && + (OptimizedDiv.getOpcode() != ISD::SDIVREM)); SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1); SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); - AddToWorklist(OptimizedDiv.getNode()); AddToWorklist(Mul.getNode()); return Sub; } @@ -3535,6 +2461,13 @@ SDValue DAGCombiner::visitREM(SDNode *N) { if (SDValue DivRem = useDivRem(N)) return DivRem.getValue(1); + // undef % X -> 0 + if (N0.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, DL, VT); + // X % undef -> undef + if (N1.getOpcode() == ISD::UNDEF) + return N1; + return SDValue(); } @@ -3544,26 +2477,20 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); - if (VT.isVector()) { - // fold (mulhs x, 0) -> 0 - if (ISD::isBuildVectorAllZeros(N1.getNode())) - return N1; - if (ISD::isBuildVectorAllZeros(N0.getNode())) - return N0; - } - // fold (mulhs x, 0) -> 0 if (isNullConstant(N1)) return N1; // fold (mulhs x, 1) -> (sra x, size(x)-1) - if (isOneConstant(N1)) + if (isOneConstant(N1)) { + SDLoc DL(N); return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0, - DAG.getConstant(N0.getValueSizeInBits() - 1, DL, + DAG.getConstant(N0.getValueType().getSizeInBits() - 1, + DL, getShiftAmountTy(N0.getValueType()))); - + } // fold (mulhs x, undef) -> 0 - if (N0.isUndef() || N1.isUndef()) - return DAG.getConstant(0, DL, VT); + if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, SDLoc(N), VT); // If the type twice as wide is legal, transform the mulhs to a wider multiply // plus a shift. @@ -3591,14 +2518,6 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); - if (VT.isVector()) { - // fold (mulhu x, 0) -> 0 - if (ISD::isBuildVectorAllZeros(N1.getNode())) - return N1; - if (ISD::isBuildVectorAllZeros(N0.getNode())) - return N0; - } - // fold (mulhu x, 0) -> 0 if (isNullConstant(N1)) return N1; @@ -3606,22 +2525,9 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) { if (isOneConstant(N1)) return DAG.getConstant(0, DL, N0.getValueType()); // fold (mulhu x, undef) -> 0 - if (N0.isUndef() || N1.isUndef()) + if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF) return DAG.getConstant(0, DL, VT); - // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c) - if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) { - SDLoc DL(N); - unsigned NumEltBits = VT.getScalarSizeInBits(); - SDValue LogBase2 = BuildLogBase2(N1, DL); - SDValue SRLAmt = DAG.getNode( - ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2); - EVT ShiftVT = getShiftAmountTy(N0.getValueType()); - SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT); - return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc); - } - // If the type twice as wide is legal, transform the mulhu to a wider multiply // plus a shift. if (VT.isSimple() && !VT.isVector()) { @@ -3649,16 +2555,18 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp, unsigned HiOp) { // If the high half is not needed, just compute the low half. 
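// A minimal standalone sketch (illustration only, not part of the patch) of
// the identity behind replacing a dead-high-half two-result node with LoOp:
// the low half of a widening multiply is just the ordinary wrapping multiply.
static_assert((unsigned)((unsigned long long)0xDEADBEEFu * 0x12345678u) ==
              0xDEADBEEFu * 0x12345678u,
              "low half of a widening mul == plain wrapping mul");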
bool HiExists = N->hasAnyUseOfValue(1);
- if (!HiExists && (!LegalOperations ||
- TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
+ if (!HiExists &&
+ (!LegalOperations ||
+ TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
 SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
 return CombineTo(N, Res, Res);
 }

 // If the low half is not needed, just compute the high half.
 bool LoExists = N->hasAnyUseOfValue(0);
- if (!LoExists && (!LegalOperations ||
- TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
+ if (!LoExists &&
+ (!LegalOperations ||
+ TLI.isOperationLegal(HiOp, N->getValueType(1)))) {
 SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
 return CombineTo(N, Res, Res);
 }

@@ -3674,7 +2582,7 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
 SDValue LoOpt = combine(Lo.getNode());
 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
 (!LegalOperations ||
- TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
+ TLI.isOperationLegal(LoOpt.getOpcode(), LoOpt.getValueType())))
 return CombineTo(N, LoOpt, LoOpt);
 }

@@ -3684,7 +2592,7 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
 SDValue HiOpt = combine(Hi.getNode());
 if (HiOpt.getNode() && HiOpt != Hi &&
 (!LegalOperations ||
- TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
+ TLI.isOperationLegal(HiOpt.getOpcode(), HiOpt.getValueType())))
 return CombineTo(N, HiOpt, HiOpt);
 }

@@ -3783,142 +2691,97 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
 if (SDValue FoldedVOp = SimplifyVBinOp(N))
 return FoldedVOp;

- // fold operation with constant operands.
+ // fold minmax with constant operands: (minmax c1, c2) -> c1 or c2
 ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
 ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
 if (N0C && N1C)
 return DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, N0C, N1C);

 // canonicalize constant to RHS
- if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
- !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+ if (isConstantIntBuildVectorOrConstantInt(N0) &&
+ !isConstantIntBuildVectorOrConstantInt(N1))
 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);

- // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
- // Only do this if the current op isn't legal and the flipped is.
- unsigned Opcode = N->getOpcode();
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- if (!TLI.isOperationLegal(Opcode, VT) &&
- (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
- (N1.isUndef() || DAG.SignBitIsZero(N1))) {
- unsigned AltOpcode;
- switch (Opcode) {
- case ISD::SMIN: AltOpcode = ISD::UMIN; break;
- case ISD::SMAX: AltOpcode = ISD::UMAX; break;
- case ISD::UMIN: AltOpcode = ISD::SMIN; break;
- case ISD::UMAX: AltOpcode = ISD::SMAX; break;
- default: llvm_unreachable("Unknown MINMAX opcode");
- }
- if (TLI.isOperationLegal(AltOpcode, VT))
- return DAG.getNode(AltOpcode, SDLoc(N), VT, N0, N1);
- }
-
 return SDValue();
 }

-/// If this is a bitwise logic instruction and both operands have the same
-/// opcode, try to sink the other opcode after the logic instruction.
-SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
+/// If this is a binary operator with two operands of the same opcode, try to
+/// simplify it.
+SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) { SDValue N0 = N->getOperand(0), N1 = N->getOperand(1); EVT VT = N0.getValueType(); - unsigned LogicOpcode = N->getOpcode(); - unsigned HandOpcode = N0.getOpcode(); - assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR || - LogicOpcode == ISD::XOR) && "Expected logic opcode"); - assert(HandOpcode == N1.getOpcode() && "Bad input!"); + assert(N0.getOpcode() == N1.getOpcode() && "Bad input!"); // Bail early if none of these transforms apply. - if (N0.getNumOperands() == 0) - return SDValue(); - - // FIXME: We should check number of uses of the operands to not increase - // the instruction count for all transforms. - - // Handle size-changing casts. - SDValue X = N0.getOperand(0); - SDValue Y = N1.getOperand(0); - EVT XVT = X.getValueType(); - SDLoc DL(N); - if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND || - HandOpcode == ISD::SIGN_EXTEND) { - // If both operands have other uses, this transform would create extra - // instructions without eliminating anything. - if (!N0.hasOneUse() && !N1.hasOneUse()) - return SDValue(); - // We need matching integer source types. - if (XVT != Y.getValueType()) - return SDValue(); - // Don't create an illegal op during or after legalization. Don't ever - // create an unsupported vector op. - if ((VT.isVector() || LegalOperations) && - !TLI.isOperationLegalOrCustom(LogicOpcode, XVT)) - return SDValue(); - // Avoid infinite looping with PromoteIntBinOp. - // TODO: Should we apply desirable/legal constraints to all opcodes? - if (HandOpcode == ISD::ANY_EXTEND && LegalTypes && - !TLI.isTypeDesirableForOp(LogicOpcode, XVT)) - return SDValue(); - // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y) - SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); - return DAG.getNode(HandOpcode, DL, VT, Logic); - } - - // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y) - if (HandOpcode == ISD::TRUNCATE) { - // If both operands have other uses, this transform would create extra - // instructions without eliminating anything. - if (!N0.hasOneUse() && !N1.hasOneUse()) - return SDValue(); - // We need matching source types. - if (XVT != Y.getValueType()) - return SDValue(); - // Don't create an illegal op during or after legalization. - if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT)) - return SDValue(); - // Be extra careful sinking truncate. If it's free, there's no benefit in - // widening a binop. Also, don't create a logic op on an illegal type. - if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT)) - return SDValue(); - if (!TLI.isTypeLegal(XVT)) - return SDValue(); - SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); - return DAG.getNode(HandOpcode, DL, VT, Logic); - } - - // For binops SHL/SRL/SRA/AND: - // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z - if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL || - HandOpcode == ISD::SRA || HandOpcode == ISD::AND) && + if (N0.getNode()->getNumOperands() == 0) return SDValue(); + + // For each of OP in AND/OR/XOR: + // fold (OP (zext x), (zext y)) -> (zext (OP x, y)) + // fold (OP (sext x), (sext y)) -> (sext (OP x, y)) + // fold (OP (aext x), (aext y)) -> (aext (OP x, y)) + // fold (OP (bswap x), (bswap y)) -> (bswap (OP x, y)) + // fold (OP (trunc x), (trunc y)) -> (trunc (OP x, y)) (if trunc isn't free) + // + // do not sink logical op inside of a vector extend, since it may combine + // into a vsetcc. 
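// Illustration only, not part of the patch: the hand folds listed above are
// instances of bitwise ops commuting with the cast; e.g. for zext on scalars:
static_assert(((unsigned)(unsigned char)0xA5 & (unsigned)(unsigned char)0x0F) ==
              (unsigned)(unsigned char)(0xA5 & 0x0F),
              "(zext x) & (zext y) == zext (x & y)");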
+ EVT Op0VT = N0.getOperand(0).getValueType(); + if ((N0.getOpcode() == ISD::ZERO_EXTEND || + N0.getOpcode() == ISD::SIGN_EXTEND || + N0.getOpcode() == ISD::BSWAP || + // Avoid infinite looping with PromoteIntBinOp. + (N0.getOpcode() == ISD::ANY_EXTEND && + (!LegalTypes || TLI.isTypeDesirableForOp(N->getOpcode(), Op0VT))) || + (N0.getOpcode() == ISD::TRUNCATE && + (!TLI.isZExtFree(VT, Op0VT) || + !TLI.isTruncateFree(Op0VT, VT)) && + TLI.isTypeLegal(Op0VT))) && + !VT.isVector() && + Op0VT == N1.getOperand(0).getValueType() && + (!LegalOperations || TLI.isOperationLegal(N->getOpcode(), Op0VT))) { + SDValue ORNode = DAG.getNode(N->getOpcode(), SDLoc(N0), + N0.getOperand(0).getValueType(), + N0.getOperand(0), N1.getOperand(0)); + AddToWorklist(ORNode.getNode()); + return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, ORNode); + } + + // For each of OP in SHL/SRL/SRA/AND... + // fold (and (OP x, z), (OP y, z)) -> (OP (and x, y), z) + // fold (or (OP x, z), (OP y, z)) -> (OP (or x, y), z) + // fold (xor (OP x, z), (OP y, z)) -> (OP (xor x, y), z) + if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL || + N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::AND) && N0.getOperand(1) == N1.getOperand(1)) { - // If either operand has other uses, this transform is not an improvement. - if (!N0.hasOneUse() || !N1.hasOneUse()) - return SDValue(); - SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); - return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1)); - } - - // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y) - if (HandOpcode == ISD::BSWAP) { - // If either operand has other uses, this transform is not an improvement. - if (!N0.hasOneUse() || !N1.hasOneUse()) - return SDValue(); - SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); - return DAG.getNode(HandOpcode, DL, VT, Logic); + SDValue ORNode = DAG.getNode(N->getOpcode(), SDLoc(N0), + N0.getOperand(0).getValueType(), + N0.getOperand(0), N1.getOperand(0)); + AddToWorklist(ORNode.getNode()); + return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, + ORNode, N0.getOperand(1)); } // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B)) - // Only perform this optimization up until type legalization, before + // Only perform this optimization after type legalization and before // LegalizeVectorOprs. LegalizeVectorOprs promotes vector operations by // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and // we don't want to undo this promotion. // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper // on scalars. - if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) && - Level <= AfterLegalizeTypes) { - // Input types must be integer and the same. - if (XVT.isInteger() && XVT == Y.getValueType()) { - SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); - return DAG.getNode(HandOpcode, DL, VT, Logic); + if ((N0.getOpcode() == ISD::BITCAST || + N0.getOpcode() == ISD::SCALAR_TO_VECTOR) && + Level == AfterLegalizeTypes) { + SDValue In0 = N0.getOperand(0); + SDValue In1 = N1.getOperand(0); + EVT In0Ty = In0.getValueType(); + EVT In1Ty = In1.getValueType(); + SDLoc DL(N); + // If both incoming values are integers, and the original types are the + // same. 
+ if (In0Ty.isInteger() && In1Ty.isInteger() && In0Ty == In1Ty) { + SDValue Op = DAG.getNode(N->getOpcode(), DL, In0Ty, In0, In1); + SDValue BC = DAG.getNode(N0.getOpcode(), DL, VT, Op); + AddToWorklist(Op.getNode()); + return BC; } } @@ -3934,210 +2797,167 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) { // If both shuffles use the same mask, and both shuffles have the same first // or second operand, then it might still be profitable to move the shuffle // after the xor/and/or operation. - if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) { - auto *SVN0 = cast<ShuffleVectorSDNode>(N0); - auto *SVN1 = cast<ShuffleVectorSDNode>(N1); - assert(X.getValueType() == Y.getValueType() && + if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) { + ShuffleVectorSDNode *SVN0 = cast<ShuffleVectorSDNode>(N0); + ShuffleVectorSDNode *SVN1 = cast<ShuffleVectorSDNode>(N1); + + assert(N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() && "Inputs to shuffles are not the same type"); // Check that both shuffles use the same mask. The masks are known to be of // the same length because the result vector type is the same. // Check also that shuffles have only one use to avoid introducing extra // instructions. - if (!SVN0->hasOneUse() || !SVN1->hasOneUse() || - !SVN0->getMask().equals(SVN1->getMask())) - return SDValue(); - - // Don't try to fold this node if it requires introducing a - // build vector of all zeros that might be illegal at this stage. - SDValue ShOp = N0.getOperand(1); - if (LogicOpcode == ISD::XOR && !ShOp.isUndef()) - ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); - - // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C) - if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) { - SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, - N0.getOperand(0), N1.getOperand(0)); - return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask()); - } - - // Don't try to fold this node if it requires introducing a - // build vector of all zeros that might be illegal at this stage. - ShOp = N0.getOperand(0); - if (LogicOpcode == ISD::XOR && !ShOp.isUndef()) - ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); + if (SVN0->hasOneUse() && SVN1->hasOneUse() && + SVN0->getMask().equals(SVN1->getMask())) { + SDValue ShOp = N0->getOperand(1); + + // Don't try to fold this node if it requires introducing a + // build vector of all zeros that might be illegal at this stage. + if (N->getOpcode() == ISD::XOR && ShOp.getOpcode() != ISD::UNDEF) { + if (!LegalTypes) + ShOp = DAG.getConstant(0, SDLoc(N), VT); + else + ShOp = SDValue(); + } + + // (AND (shuf (A, C), shuf (B, C)) -> shuf (AND (A, B), C) + // (OR (shuf (A, C), shuf (B, C)) -> shuf (OR (A, B), C) + // (XOR (shuf (A, C), shuf (B, C)) -> shuf (XOR (A, B), V_0) + if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) { + SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT, + N0->getOperand(0), N1->getOperand(0)); + AddToWorklist(NewNode.getNode()); + return DAG.getVectorShuffle(VT, SDLoc(N), NewNode, ShOp, + &SVN0->getMask()[0]); + } + + // Don't try to fold this node if it requires introducing a + // build vector of all zeros that might be illegal at this stage. 
+ ShOp = N0->getOperand(0); + if (N->getOpcode() == ISD::XOR && ShOp.getOpcode() != ISD::UNDEF) { + if (!LegalTypes) + ShOp = DAG.getConstant(0, SDLoc(N), VT); + else + ShOp = SDValue(); + } - // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B)) - if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) { - SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1), - N1.getOperand(1)); - return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask()); + // (AND (shuf (C, A), shuf (C, B)) -> shuf (C, AND (A, B)) + // (OR (shuf (C, A), shuf (C, B)) -> shuf (C, OR (A, B)) + // (XOR (shuf (C, A), shuf (C, B)) -> shuf (V_0, XOR (A, B)) + if (N0->getOperand(0) == N1->getOperand(0) && ShOp.getNode()) { + SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT, + N0->getOperand(1), N1->getOperand(1)); + AddToWorklist(NewNode.getNode()); + return DAG.getVectorShuffle(VT, SDLoc(N), ShOp, NewNode, + &SVN0->getMask()[0]); + } } } return SDValue(); } -/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient. -SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, - const SDLoc &DL) { - SDValue LL, LR, RL, RR, N0CC, N1CC; - if (!isSetCCEquivalent(N0, LL, LR, N0CC) || - !isSetCCEquivalent(N1, RL, RR, N1CC)) - return SDValue(); - - assert(N0.getValueType() == N1.getValueType() && - "Unexpected operand types for bitwise logic op"); - assert(LL.getValueType() == LR.getValueType() && - RL.getValueType() == RR.getValueType() && - "Unexpected operand types for setcc"); - - // If we're here post-legalization or the logic op type is not i1, the logic - // op type must match a setcc result type. Also, all folds require new - // operations on the left and right operands, so those types must match. - EVT VT = N0.getValueType(); - EVT OpVT = LL.getValueType(); - if (LegalOperations || VT.getScalarType() != MVT::i1) - if (VT != getSetCCResultType(OpVT)) - return SDValue(); - if (OpVT != RL.getValueType()) - return SDValue(); - - ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get(); - ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get(); - bool IsInteger = OpVT.isInteger(); - if (LR == RR && CC0 == CC1 && IsInteger) { - bool IsZero = isNullOrNullSplat(LR); - bool IsNeg1 = isAllOnesOrAllOnesSplat(LR); - - // All bits clear? - bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero; - // All sign bits clear? - bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1; - // Any bits set? - bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero; - // Any sign bits set? - bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero; - - // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0) - // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1) - // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0) - // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0) - if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) { - SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL); - AddToWorklist(Or.getNode()); - return DAG.getSetCC(DL, VT, Or, LR, CC1); - } - - // All bits set? - bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1; - // All sign bits set? - bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero; - // Any bits clear? - bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1; - // Any sign bits clear? 
- bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1; - - // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1) - // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0) - // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1) - // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1) - if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) { - SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL); - AddToWorklist(And.getNode()); - return DAG.getSetCC(DL, VT, And, LR, CC1); - } - } - - // TODO: What is the 'or' equivalent of this fold? - // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2) - if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 && - IsInteger && CC0 == ISD::SETNE && - ((isNullConstant(LR) && isAllOnesConstant(RR)) || - (isAllOnesConstant(LR) && isNullConstant(RR)))) { - SDValue One = DAG.getConstant(1, DL, OpVT); - SDValue Two = DAG.getConstant(2, DL, OpVT); - SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One); - AddToWorklist(Add.getNode()); - return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE); - } - - // Try more general transforms if the predicates match and the only user of - // the compares is the 'and' or 'or'. - if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 && - N0.hasOneUse() && N1.hasOneUse()) { - // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0 - // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0 - if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) { - SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR); - SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR); - SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR); - SDValue Zero = DAG.getConstant(0, DL, OpVT); - return DAG.getSetCC(DL, VT, Or, Zero, CC1); - } - } - - // Canonicalize equivalent operands to LL == RL. - if (LL == RR && LR == RL) { - CC1 = ISD::getSetCCSwappedOperands(CC1); - std::swap(RL, RR); - } - - // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC) - // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC) - if (LL == RL && LR == RR) { - ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, IsInteger) - : ISD::getSetCCOrOperation(CC0, CC1, IsInteger); - if (NewCC != ISD::SETCC_INVALID && - (!LegalOperations || - (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) && - TLI.isOperationLegal(ISD::SETCC, OpVT)))) - return DAG.getSetCC(DL, VT, LL, LR, NewCC); - } - - return SDValue(); -} - /// This contains all DAGCombine rules which reduce two values combined by /// an And operation to a single value. This makes them reusable in the context /// of visitSELECT(). Rules involving constants are not included as /// visitSELECT() already handles those cases. 
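// Illustration only, not part of the patch: the removed setcc folds above are
// ordinary boolean identities; spot-checking the "all bits clear" case:
static_assert((((0x00u == 0u) && (0x10u == 0u)) == ((0x00u | 0x10u) == 0u)) &&
              (((0x00u == 0u) && (0x00u == 0u)) == ((0x00u | 0x00u) == 0u)),
              "(X == 0) && (Y == 0)  <=>  (X | Y) == 0");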
-SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) { +SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, + SDNode *LocReference) { EVT VT = N1.getValueType(); - SDLoc DL(N); // fold (and x, undef) -> 0 - if (N0.isUndef() || N1.isUndef()) - return DAG.getConstant(0, DL, VT); - - if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL)) - return V; + if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, SDLoc(LocReference), VT); + // fold (and (setcc x), (setcc y)) -> (setcc (and x, y)) + SDValue LL, LR, RL, RR, CC0, CC1; + if (isSetCCEquivalent(N0, LL, LR, CC0) && isSetCCEquivalent(N1, RL, RR, CC1)){ + ISD::CondCode Op0 = cast<CondCodeSDNode>(CC0)->get(); + ISD::CondCode Op1 = cast<CondCodeSDNode>(CC1)->get(); + + if (LR == RR && isa<ConstantSDNode>(LR) && Op0 == Op1 && + LL.getValueType().isInteger()) { + // fold (and (seteq X, 0), (seteq Y, 0)) -> (seteq (or X, Y), 0) + if (isNullConstant(LR) && Op1 == ISD::SETEQ) { + SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(N0), + LR.getValueType(), LL, RL); + AddToWorklist(ORNode.getNode()); + return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1); + } + if (isAllOnesConstant(LR)) { + // fold (and (seteq X, -1), (seteq Y, -1)) -> (seteq (and X, Y), -1) + if (Op1 == ISD::SETEQ) { + SDValue ANDNode = DAG.getNode(ISD::AND, SDLoc(N0), + LR.getValueType(), LL, RL); + AddToWorklist(ANDNode.getNode()); + return DAG.getSetCC(SDLoc(LocReference), VT, ANDNode, LR, Op1); + } + // fold (and (setgt X, -1), (setgt Y, -1)) -> (setgt (or X, Y), -1) + if (Op1 == ISD::SETGT) { + SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(N0), + LR.getValueType(), LL, RL); + AddToWorklist(ORNode.getNode()); + return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1); + } + } + } + // Simplify (and (setne X, 0), (setne X, -1)) -> (setuge (add X, 1), 2) + if (LL == RL && isa<ConstantSDNode>(LR) && isa<ConstantSDNode>(RR) && + Op0 == Op1 && LL.getValueType().isInteger() && + Op0 == ISD::SETNE && ((isNullConstant(LR) && isAllOnesConstant(RR)) || + (isAllOnesConstant(LR) && isNullConstant(RR)))) { + SDLoc DL(N0); + SDValue ADDNode = DAG.getNode(ISD::ADD, DL, LL.getValueType(), + LL, DAG.getConstant(1, DL, + LL.getValueType())); + AddToWorklist(ADDNode.getNode()); + return DAG.getSetCC(SDLoc(LocReference), VT, ADDNode, + DAG.getConstant(2, DL, LL.getValueType()), + ISD::SETUGE); + } + // canonicalize equivalent to ll == rl + if (LL == RR && LR == RL) { + Op1 = ISD::getSetCCSwappedOperands(Op1); + std::swap(RL, RR); + } + if (LL == RL && LR == RR) { + bool isInteger = LL.getValueType().isInteger(); + ISD::CondCode Result = ISD::getSetCCAndOperation(Op0, Op1, isInteger); + if (Result != ISD::SETCC_INVALID && + (!LegalOperations || + (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) && + TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) { + EVT CCVT = getSetCCResultType(LL.getValueType()); + if (N0.getValueType() == CCVT || + (!LegalOperations && N0.getValueType() == MVT::i1)) + return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(), + LL, LR, Result); + } + } + } if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL && VT.getSizeInBits() <= 64) { if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { - if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) { + APInt ADDC = ADDI->getAPIntValue(); + if (!TLI.isLegalAddImmediate(ADDC.getSExtValue())) { // Look for (and (add x, c1), (lshr y, c2)). 
If C1 wasn't a legal // immediate for an add, but it is legal if its top c2 bits are set, // transform the ADD so the immediate doesn't need to be materialized // in a register. - APInt ADDC = ADDI->getAPIntValue(); - APInt SRLC = SRLI->getAPIntValue(); - if (ADDC.getMinSignedBits() <= 64 && - SRLC.ult(VT.getSizeInBits()) && - !TLI.isLegalAddImmediate(ADDC.getSExtValue())) { + if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) { APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(), - SRLC.getZExtValue()); + SRLI->getZExtValue()); if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) { ADDC |= Mask; if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) { - SDLoc DL0(N0); + SDLoc DL(N0); SDValue NewAdd = - DAG.getNode(ISD::ADD, DL0, VT, + DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), DAG.getConstant(ADDC, DL, VT)); CombineTo(N0.getNode(), NewAdd); // Return N so it doesn't get rechecked! - return SDValue(N, 0); + return SDValue(LocReference, 0); } } } @@ -4145,73 +2965,26 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) { } } - // Reduce bit extract of low half of an integer to the narrower type. - // (and (srl i64:x, K), KMask) -> - // (i64 zero_extend (and (srl (i32 (trunc i64:x)), K)), KMask) - if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { - if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) { - if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { - unsigned Size = VT.getSizeInBits(); - const APInt &AndMask = CAnd->getAPIntValue(); - unsigned ShiftBits = CShift->getZExtValue(); - - // Bail out, this node will probably disappear anyway. - if (ShiftBits == 0) - return SDValue(); - - unsigned MaskBits = AndMask.countTrailingOnes(); - EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2); - - if (AndMask.isMask() && - // Required bits must not span the two halves of the integer and - // must fit in the half size type. - (ShiftBits + MaskBits <= Size / 2) && - TLI.isNarrowingProfitable(VT, HalfVT) && - TLI.isTypeDesirableForOp(ISD::AND, HalfVT) && - TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) && - TLI.isTruncateFree(VT, HalfVT) && - TLI.isZExtFree(HalfVT, VT)) { - // The isNarrowingProfitable is to avoid regressions on PPC and - // AArch64 which match a few 64-bit bit insert / bit extract patterns - // on downstream users of this. Those patterns could probably be - // extended to handle extensions mixed in. - - SDValue SL(N0); - assert(MaskBits <= Size); - - // Extracting the highest bit of the low half. 
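// Illustration only, not part of the patch: assuming 32-bit unsigned, the
// narrowed bit extract matches the wide one whenever the shift amount plus
// the mask width stay inside the low half:
static_assert(((0x123456789ABCDEF0ull >> 8) & 0xFFFFu) ==
              (((unsigned)0x123456789ABCDEF0ull >> 8) & 0xFFFFu),
              "low-half bit extract can be done in the narrow type");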
- EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout()); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT, - N0.getOperand(0)); - - SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT); - SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT); - SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK); - SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask); - return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And); - } - } - } - } - return SDValue(); } bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, - EVT LoadResultTy, EVT &ExtVT) { - if (!AndC->getAPIntValue().isMask()) - return false; + EVT LoadResultTy, EVT &ExtVT, EVT &LoadedVT, + bool &NarrowLoad) { + uint32_t ActiveBits = AndC->getAPIntValue().getActiveBits(); - unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes(); + if (ActiveBits == 0 || !APIntOps::isMask(ActiveBits, AndC->getAPIntValue())) + return false; ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits); - EVT LoadedVT = LoadN->getMemoryVT(); + LoadedVT = LoadN->getMemoryVT(); if (ExtVT == LoadedVT && (!LegalOperations || TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) { // ZEXTLOAD will match without needing to change the size of the value being // loaded. + NarrowLoad = false; return true; } @@ -4231,306 +3004,15 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT)) return false; + NarrowLoad = true; return true; } -bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST, - ISD::LoadExtType ExtType, EVT &MemVT, - unsigned ShAmt) { - if (!LDST) - return false; - // Only allow byte offsets. - if (ShAmt % 8) - return false; - - // Do not generate loads of non-round integer types since these can - // be expensive (and would be wrong if the type is not byte sized). - if (!MemVT.isRound()) - return false; - - // Don't change the width of a volatile load. - if (LDST->isVolatile()) - return false; - - // Verify that we are actually reducing a load width here. - if (LDST->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits()) - return false; - - // Ensure that this isn't going to produce an unsupported unaligned access. - if (ShAmt && - !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT, - LDST->getAddressSpace(), ShAmt / 8)) - return false; - - // It's not possible to generate a constant of extended or untyped type. - EVT PtrType = LDST->getBasePtr().getValueType(); - if (PtrType == MVT::Untyped || PtrType.isExtended()) - return false; - - if (isa<LoadSDNode>(LDST)) { - LoadSDNode *Load = cast<LoadSDNode>(LDST); - // Don't transform one with multiple uses, this would require adding a new - // load. - if (!SDValue(Load, 0).hasOneUse()) - return false; - - if (LegalOperations && - !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT)) - return false; - - // For the transform to be legal, the load must produce only two values - // (the value loaded and the chain). Don't transform a pre-increment - // load, for example, which produces an extra value. Otherwise the - // transformation is not equivalent, and the downstream logic to replace - // uses gets things wrong. - if (Load->getNumValues() > 2) - return false; - - // If the load that we're shrinking is an extload and we're not just - // discarding the extension we can't simply shrink the load. Bail. - // TODO: It would be possible to merge the extensions in some cases. 
- if (Load->getExtensionType() != ISD::NON_EXTLOAD && - Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt) - return false; - - if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT)) - return false; - } else { - assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode"); - StoreSDNode *Store = cast<StoreSDNode>(LDST); - // Can't write outside the original store - if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt) - return false; - - if (LegalOperations && - !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT)) - return false; - } - return true; -} - -bool DAGCombiner::SearchForAndLoads(SDNode *N, - SmallVectorImpl<LoadSDNode*> &Loads, - SmallPtrSetImpl<SDNode*> &NodesWithConsts, - ConstantSDNode *Mask, - SDNode *&NodeToMask) { - // Recursively search for the operands, looking for loads which can be - // narrowed. - for (unsigned i = 0, e = N->getNumOperands(); i < e; ++i) { - SDValue Op = N->getOperand(i); - - if (Op.getValueType().isVector()) - return false; - - // Some constants may need fixing up later if they are too large. - if (auto *C = dyn_cast<ConstantSDNode>(Op)) { - if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) && - (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue()) - NodesWithConsts.insert(N); - continue; - } - - if (!Op.hasOneUse()) - return false; - - switch(Op.getOpcode()) { - case ISD::LOAD: { - auto *Load = cast<LoadSDNode>(Op); - EVT ExtVT; - if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) && - isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) { - - // ZEXTLOAD is already small enough. - if (Load->getExtensionType() == ISD::ZEXTLOAD && - ExtVT.bitsGE(Load->getMemoryVT())) - continue; - - // Use LE to convert equal sized loads to zext. - if (ExtVT.bitsLE(Load->getMemoryVT())) - Loads.push_back(Load); - - continue; - } - return false; - } - case ISD::ZERO_EXTEND: - case ISD::AssertZext: { - unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes(); - EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits); - EVT VT = Op.getOpcode() == ISD::AssertZext ? - cast<VTSDNode>(Op.getOperand(1))->getVT() : - Op.getOperand(0).getValueType(); - - // We can accept extending nodes if the mask is wider or an equal - // width to the original type. - if (ExtVT.bitsGE(VT)) - continue; - break; - } - case ISD::OR: - case ISD::XOR: - case ISD::AND: - if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask, - NodeToMask)) - return false; - continue; - } - - // Allow one node which will masked along with any loads found. - if (NodeToMask) - return false; - - // Also ensure that the node to be masked only produces one data result. - NodeToMask = Op.getNode(); - if (NodeToMask->getNumValues() > 1) { - bool HasValue = false; - for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) { - MVT VT = SDValue(NodeToMask, i).getSimpleValueType(); - if (VT != MVT::Glue && VT != MVT::Other) { - if (HasValue) { - NodeToMask = nullptr; - return false; - } - HasValue = true; - } - } - assert(HasValue && "Node to be masked has no data result?"); - } - } - return true; -} - -bool DAGCombiner::BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG) { - auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1)); - if (!Mask) - return false; - - if (!Mask->getAPIntValue().isMask()) - return false; - - // No need to do anything if the and directly uses a load. 
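// Illustration only, not part of the patch: propagating the mask backwards
// through an OR/XOR/AND tree is sound because AND distributes over those ops:
static_assert(((0xF0u ^ 0x3Cu) & 0xFFu) == ((0xF0u & 0xFFu) ^ (0x3Cu & 0xFFu)),
              "m & (a ^ b) == (m & a) ^ (m & b)");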
- if (isa<LoadSDNode>(N->getOperand(0))) - return false; - - SmallVector<LoadSDNode*, 8> Loads; - SmallPtrSet<SDNode*, 2> NodesWithConsts; - SDNode *FixupNode = nullptr; - if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) { - if (Loads.size() == 0) - return false; - - LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump()); - SDValue MaskOp = N->getOperand(1); - - // If it exists, fixup the single node we allow in the tree that needs - // masking. - if (FixupNode) { - LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump()); - SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode), - FixupNode->getValueType(0), - SDValue(FixupNode, 0), MaskOp); - DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And); - if (And.getOpcode() == ISD ::AND) - DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp); - } - - // Narrow any constants that need it. - for (auto *LogicN : NodesWithConsts) { - SDValue Op0 = LogicN->getOperand(0); - SDValue Op1 = LogicN->getOperand(1); - - if (isa<ConstantSDNode>(Op0)) - std::swap(Op0, Op1); - - SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(), - Op1, MaskOp); - - DAG.UpdateNodeOperands(LogicN, Op0, And); - } - - // Create narrow loads. - for (auto *Load : Loads) { - LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump()); - SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0), - SDValue(Load, 0), MaskOp); - DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And); - if (And.getOpcode() == ISD ::AND) - And = SDValue( - DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0); - SDValue NewLoad = ReduceLoadWidth(And.getNode()); - assert(NewLoad && - "Shouldn't be masking the load if it can't be narrowed"); - CombineTo(Load, NewLoad, NewLoad.getValue(1)); - } - DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode()); - return true; - } - return false; -} - -// Unfold -// x & (-1 'logical shift' y) -// To -// (x 'opposite logical shift' y) 'logical shift' y -// if it is better for performance. -SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) { - assert(N->getOpcode() == ISD::AND); - - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - - // Do we actually prefer shifts over mask? - if (!TLI.preferShiftsToClearExtremeBits(N0)) - return SDValue(); - - // Try to match (-1 '[outer] logical shift' y) - unsigned OuterShift; - unsigned InnerShift; // The opposite direction to the OuterShift. - SDValue Y; // Shift amount. 
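// Illustration only, not part of the patch: the unfolding below rests on the
// identity x & (-1 << y) == (x >> y) << y (and its shift-right mirror), e.g.
// with a 16-bit mask:
static_assert((0xABCDu & (0xFFFFu << 4)) == ((0xABCDu >> 4) << 4),
              "x & (-1 << y) == (x >> y) << y");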
- auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool { - if (!M.hasOneUse()) - return false; - OuterShift = M->getOpcode(); - if (OuterShift == ISD::SHL) - InnerShift = ISD::SRL; - else if (OuterShift == ISD::SRL) - InnerShift = ISD::SHL; - else - return false; - if (!isAllOnesConstant(M->getOperand(0))) - return false; - Y = M->getOperand(1); - return true; - }; - - SDValue X; - if (matchMask(N1)) - X = N0; - else if (matchMask(N0)) - X = N1; - else - return SDValue(); - - SDLoc DL(N); - EVT VT = N->getValueType(0); - - // tmp = x 'opposite logical shift' y - SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y); - // ret = tmp 'logical shift' y - SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y); - - return T1; -} - SDValue DAGCombiner::visitAND(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N1.getValueType(); - // x & x --> x - if (N0 == N1) - return N0; - // fold vector ops if (VT.isVector()) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) @@ -4539,12 +3021,16 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // fold (and x, 0) -> 0, vector edition if (ISD::isBuildVectorAllZeros(N0.getNode())) // do not return N0, because undef node may exist in N0 - return DAG.getConstant(APInt::getNullValue(N0.getScalarValueSizeInBits()), - SDLoc(N), N0.getValueType()); + return DAG.getConstant( + APInt::getNullValue( + N0.getValueType().getScalarType().getSizeInBits()), + SDLoc(N), N0.getValueType()); if (ISD::isBuildVectorAllZeros(N1.getNode())) // do not return N1, because undef node may exist in N1 - return DAG.getConstant(APInt::getNullValue(N1.getScalarValueSizeInBits()), - SDLoc(N), N1.getValueType()); + return DAG.getConstant( + APInt::getNullValue( + N1.getValueType().getScalarType().getSizeInBits()), + SDLoc(N), N1.getValueType()); // fold (and x, -1) -> x, vector edition if (ISD::isBuildVectorAllOnes(N0.getNode())) @@ -4555,46 +3041,34 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // fold (and c1, c2) -> c1&c2 ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); - ConstantSDNode *N1C = isConstOrConstSplat(N1); + ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, N0C, N1C); // canonicalize constant to RHS - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0); // fold (and x, -1) -> x if (isAllOnesConstant(N1)) return N0; // if (and x, c) is known to be zero, return 0 - unsigned BitWidth = VT.getScalarSizeInBits(); + unsigned BitWidth = VT.getScalarType().getSizeInBits(); if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnesValue(BitWidth))) return DAG.getConstant(0, SDLoc(N), VT); - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - // reassociate and - if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags())) + if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1)) return RAND; - - // Try to convert a constant mask AND into a shuffle clear mask. 
- if (VT.isVector()) - if (SDValue Shuffle = XformToShuffleWithZero(N)) - return Shuffle; - // fold (and (or x, C), D) -> D if (C & D) == D - auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) { - return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue()); - }; - if (N0.getOpcode() == ISD::OR && - ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset)) - return N1; + if (N1C && N0.getOpcode() == ISD::OR) + if (ConstantSDNode *ORI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) + if ((ORI->getAPIntValue() & N1C->getAPIntValue()) == N1C->getAPIntValue()) + return N1; // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits. if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) { SDValue N0Op0 = N0.getOperand(0); APInt Mask = ~N1C->getAPIntValue(); - Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits()); + Mask = Mask.trunc(N0Op0.getValueSizeInBits()); if (DAG.MaskedValueIsZero(N0Op0, Mask)) { SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N0.getValueType(), N0Op0); @@ -4616,10 +3090,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // the 'X' node here can either be nothing or an extract_vector_elt to catch // more cases. if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT && - N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() && - N0.getOperand(0).getOpcode() == ISD::LOAD && - N0.getOperand(0).getResNo() == 0) || - (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) { + N0.getOperand(0).getOpcode() == ISD::LOAD) || + N0.getOpcode() == ISD::LOAD) { LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0) ); @@ -4645,7 +3117,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // that will apply equally to all members of the vector, so AND all the // lanes of the constant together. EVT VT = Vector->getValueType(0); - unsigned BitWidth = VT.getScalarSizeInBits(); + unsigned BitWidth = VT.getVectorElementType().getSizeInBits(); // If the splat value has been compressed to a bitlength lower // than the size of the vector lane, we need to re-expand it to @@ -4676,7 +3148,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // Resize the constant to the same size as the original memory access before // extension. If it is still the AllOnesValue then this AND is completely // unneeded. - Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits()); + Constant = + Constant.zextOrTrunc(Load->getMemoryVT().getScalarType().getSizeInBits()); bool B; switch (Load->getExtensionType()) { @@ -4690,10 +3163,6 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to // preserve semantics once we get rid of the AND. SDValue NewLoad(Load, 0); - - // Fold the AND away. NewLoad may get replaced immediately. - CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0); - if (Load->getExtensionType() == ISD::EXTLOAD) { NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD, Load->getValueType(0), SDLoc(Load), @@ -4711,6 +3180,10 @@ SDValue DAGCombiner::visitAND(SDNode *N) { } } + // Fold the AND away, taking care not to fold to the old load node if we + // replaced it. + CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0); + return SDValue(N, 0); // Return N so it doesn't get rechecked! 
} } @@ -4718,25 +3191,60 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // fold (and (load x), 255) -> (zextload x, i8) // fold (and (extload x, i16), 255) -> (zextload x, i8) // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8) - if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD || - (N0.getOpcode() == ISD::ANY_EXTEND && - N0.getOperand(0).getOpcode() == ISD::LOAD))) { - if (SDValue Res = ReduceLoadWidth(N)) { - LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND - ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0); - AddToWorklist(N); - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res); - return SDValue(N, 0); - } - } + if (N1C && (N0.getOpcode() == ISD::LOAD || + (N0.getOpcode() == ISD::ANY_EXTEND && + N0.getOperand(0).getOpcode() == ISD::LOAD))) { + bool HasAnyExt = N0.getOpcode() == ISD::ANY_EXTEND; + LoadSDNode *LN0 = HasAnyExt + ? cast<LoadSDNode>(N0.getOperand(0)) + : cast<LoadSDNode>(N0); + if (LN0->getExtensionType() != ISD::SEXTLOAD && + LN0->isUnindexed() && N0.hasOneUse() && SDValue(LN0, 0).hasOneUse()) { + auto NarrowLoad = false; + EVT LoadResultTy = HasAnyExt ? LN0->getValueType(0) : VT; + EVT ExtVT, LoadedVT; + if (isAndLoadExtLoad(N1C, LN0, LoadResultTy, ExtVT, LoadedVT, + NarrowLoad)) { + if (!NarrowLoad) { + SDValue NewLoad = + DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), LoadResultTy, + LN0->getChain(), LN0->getBasePtr(), ExtVT, + LN0->getMemOperand()); + AddToWorklist(N); + CombineTo(LN0, NewLoad, NewLoad.getValue(1)); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } else { + EVT PtrType = LN0->getOperand(1).getValueType(); + + unsigned Alignment = LN0->getAlignment(); + SDValue NewPtr = LN0->getBasePtr(); + + // For big endian targets, we need to add an offset to the pointer + // to load the correct bytes. For little endian systems, we merely + // need to read fewer bytes from the same pointer. + if (DAG.getDataLayout().isBigEndian()) { + unsigned LVTStoreBytes = LoadedVT.getStoreSize(); + unsigned EVTStoreBytes = ExtVT.getStoreSize(); + unsigned PtrOff = LVTStoreBytes - EVTStoreBytes; + SDLoc DL(LN0); + NewPtr = DAG.getNode(ISD::ADD, DL, PtrType, + NewPtr, DAG.getConstant(PtrOff, DL, PtrType)); + Alignment = MinAlign(Alignment, PtrOff); + } - if (Level >= AfterLegalizeTypes) { - // Attempt to propagate the AND back up to the leaves which, if they're - // loads, can be combined to narrow loads and the AND node can be removed. - // Perform after legalization so that extend nodes will already be - // combined into the loads. - if (BackwardsPropagateMask(N, DAG)) { - return SDValue(N, 0); + AddToWorklist(NewPtr.getNode()); + + SDValue Load = + DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), LoadResultTy, + LN0->getChain(), NewPtr, + LN0->getPointerInfo(), + ExtVT, LN0->isVolatile(), LN0->isNonTemporal(), + LN0->isInvariant(), Alignment, LN0->getAAInfo()); + AddToWorklist(N); + CombineTo(LN0, Load, Load.getValue(1)); + return SDValue(N, 0); // Return N so it doesn't get rechecked! 
+ } + } } } @@ -4745,31 +3253,13 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // Simplify: (and (op x...), (op y...)) -> (op (and x, y)) if (N0.getOpcode() == N1.getOpcode()) - if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) - return V; - - // Masking the negated extension of a boolean is just the zero-extended - // boolean: - // and (sub 0, zext(bool X)), 1 --> zext(bool X) - // and (sub 0, sext(bool X)), 1 --> zext(bool X) - // - // Note: the SimplifyDemandedBits fold below can make an information-losing - // transform, and then we have no way to find this better fold. - if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) { - if (isNullOrNullSplat(N0.getOperand(0))) { - SDValue SubRHS = N0.getOperand(1); - if (SubRHS.getOpcode() == ISD::ZERO_EXTEND && - SubRHS.getOperand(0).getScalarValueSizeInBits() == 1) - return SubRHS; - if (SubRHS.getOpcode() == ISD::SIGN_EXTEND && - SubRHS.getOperand(0).getScalarValueSizeInBits() == 1) - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0)); - } - } + if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) + return Tmp; // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1) // fold (and (sra)) -> (and (srl)) when possible. - if (SimplifyDemandedBits(SDValue(N, 0))) + if (!VT.isVector() && + SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); // fold (zext_inreg (extload x)) -> (zextload x) @@ -4778,9 +3268,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) { EVT MemVT = LN0->getMemoryVT(); // If we zero all the possible extended bits, then we can turn this into // a zextload if we are running before legalize or the operation is legal. - unsigned BitWidth = N1.getScalarValueSizeInBits(); + unsigned BitWidth = N1.getValueType().getScalarType().getSizeInBits(); if (DAG.MaskedValueIsZero(N1, APInt::getHighBitsSet(BitWidth, - BitWidth - MemVT.getScalarSizeInBits())) && + BitWidth - MemVT.getScalarType().getSizeInBits())) && ((!LegalOperations && !LN0->isVolatile()) || TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) { SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, @@ -4798,9 +3288,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) { EVT MemVT = LN0->getMemoryVT(); // If we zero all the possible extended bits, then we can turn this into // a zextload if we are running before legalize or the operation is legal. 
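// Illustration only, not part of the patch: masking an extending load down to
// its memory width keeps exactly the bits a zero-extending load produces:
static_assert((0xFFFFFFABu & 0xFFu) == (unsigned)(unsigned char)0xAB,
              "(and (extload i8), 0xff) == (zextload i8)");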
- unsigned BitWidth = N1.getScalarValueSizeInBits(); + unsigned BitWidth = N1.getValueType().getScalarType().getSizeInBits(); if (DAG.MaskedValueIsZero(N1, APInt::getHighBitsSet(BitWidth, - BitWidth - MemVT.getScalarSizeInBits())) && + BitWidth - MemVT.getScalarType().getSizeInBits())) && ((!LegalOperations && !LN0->isVolatile()) || TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) { SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, @@ -4813,14 +3303,12 @@ SDValue DAGCombiner::visitAND(SDNode *N) { } // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const) if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) { - if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0), - N0.getOperand(1), false)) + SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0), + N0.getOperand(1), false); + if (BSwap.getNode()) return BSwap; } - if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N)) - return Shifts; - return SDValue(); } @@ -4833,10 +3321,10 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, EVT VT = N->getValueType(0); if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16) return SDValue(); - if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT)) + if (!TLI.isOperationLegal(ISD::BSWAP, VT)) return SDValue(); - // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff) + // Recognize (and (shl a, 8), 0xff), (and (srl a, 8), 0xff00) bool LookPassAnd0 = false; bool LookPassAnd1 = false; if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL) @@ -4847,10 +3335,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, if (!N0.getNode()->hasOneUse()) return SDValue(); ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); - // Also handle 0xffff since the LHS is guaranteed to have zeros there. - // This is needed for X86. - if (!N01C || (N01C->getZExtValue() != 0xFF00 && - N01C->getZExtValue() != 0xFFFF)) + if (!N01C || N01C->getZExtValue() != 0xFF00) return SDValue(); N0 = N0.getOperand(0); LookPassAnd0 = true; @@ -4870,7 +3355,8 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, std::swap(N0, N1); if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL) return SDValue(); - if (!N0.getNode()->hasOneUse() || !N1.getNode()->hasOneUse()) + if (!N0.getNode()->hasOneUse() || + !N1.getNode()->hasOneUse()) return SDValue(); ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); @@ -4897,10 +3383,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, if (!N10.getNode()->hasOneUse()) return SDValue(); ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1)); - // Also allow 0xFFFF since the bits will be shifted out. This is needed - // for X86. 
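// Illustration only, not part of the patch: the pattern MatchBSwapHWordLow
// reassembles is the halfword byte swap spelled with shifts and masks:
static_assert((((0xABCDu << 8) & 0xFF00u) | ((0xABCDu >> 8) & 0x00FFu)) ==
              0xCDABu,
              "((x << 8) & 0xff00) | ((x >> 8) & 0xff) swaps the bytes");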
- if (!N101C || (N101C->getZExtValue() != 0xFF00 && - N101C->getZExtValue() != 0xFFFF)) + if (!N101C || N101C->getZExtValue() != 0xFF00) return SDValue(); N10 = N10.getOperand(0); LookPassAnd1 = true; @@ -4951,44 +3434,27 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) { if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL) return false; - SDValue N0 = N.getOperand(0); - unsigned Opc0 = N0.getOpcode(); - if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL) - return false; - - ConstantSDNode *N1C = nullptr; - // SHL or SRL: look upstream for AND mask operand - if (Opc == ISD::AND) - N1C = dyn_cast<ConstantSDNode>(N.getOperand(1)); - else if (Opc0 == ISD::AND) - N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); + ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N.getOperand(1)); if (!N1C) return false; - unsigned MaskByteOffset; + unsigned Num; switch (N1C->getZExtValue()) { default: return false; - case 0xFF: MaskByteOffset = 0; break; - case 0xFF00: MaskByteOffset = 1; break; - case 0xFFFF: - // In case demanded bits didn't clear the bits that will be shifted out. - // This is needed for X86. - if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) { - MaskByteOffset = 1; - break; - } - return false; - case 0xFF0000: MaskByteOffset = 2; break; - case 0xFF000000: MaskByteOffset = 3; break; + case 0xFF: Num = 0; break; + case 0xFF00: Num = 1; break; + case 0xFF0000: Num = 2; break; + case 0xFF000000: Num = 3; break; } // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00). + SDValue N0 = N.getOperand(0); if (Opc == ISD::AND) { - if (MaskByteOffset == 0 || MaskByteOffset == 2) { + if (Num == 0 || Num == 2) { // (x >> 8) & 0xff // (x >> 8) & 0xff0000 - if (Opc0 != ISD::SRL) + if (N0.getOpcode() != ISD::SRL) return false; ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); if (!C || C->getZExtValue() != 8) @@ -4996,7 +3462,7 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) { } else { // (x << 8) & 0xff00 // (x << 8) & 0xff000000 - if (Opc0 != ISD::SHL) + if (N0.getOpcode() != ISD::SHL) return false; ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); if (!C || C->getZExtValue() != 8) @@ -5005,7 +3471,7 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) { } else if (Opc == ISD::SHL) { // (x & 0xff) << 8 // (x & 0xff0000) << 8 - if (MaskByteOffset != 0 && MaskByteOffset != 2) + if (Num != 0 && Num != 2) return false; ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1)); if (!C || C->getZExtValue() != 8) @@ -5013,17 +3479,17 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) { } else { // Opc == ISD::SRL // (x & 0xff00) >> 8 // (x & 0xff000000) >> 8 - if (MaskByteOffset != 1 && MaskByteOffset != 3) + if (Num != 1 && Num != 3) return false; ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1)); if (!C || C->getZExtValue() != 8) return false; } - if (Parts[MaskByteOffset]) + if (Parts[Num]) return false; - Parts[MaskByteOffset] = N0.getOperand(0).getNode(); + Parts[Num] = N0.getOperand(0).getNode(); return true; } @@ -5040,7 +3506,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) { EVT VT = N->getValueType(0); if (VT != MVT::i32) return SDValue(); - if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT)) + if (!TLI.isOperationLegal(ISD::BSWAP, VT)) return SDValue(); // Look for either @@ -5055,16 +3521,18 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) { if (N1.getOpcode() 
== ISD::OR && N00.getNumOperands() == 2 && N01.getNumOperands() == 2) { // (or (or (and), (and)), (or (and), (and))) - if (!isBSwapHWordElement(N00, Parts)) + SDValue N000 = N00.getOperand(0); + if (!isBSwapHWordElement(N000, Parts)) return SDValue(); - if (!isBSwapHWordElement(N01, Parts)) + SDValue N001 = N00.getOperand(1); + if (!isBSwapHWordElement(N001, Parts)) return SDValue(); - SDValue N10 = N1.getOperand(0); - if (!isBSwapHWordElement(N10, Parts)) + SDValue N010 = N01.getOperand(0); + if (!isBSwapHWordElement(N010, Parts)) return SDValue(); - SDValue N11 = N1.getOperand(1); - if (!isBSwapHWordElement(N11, Parts)) + SDValue N011 = N01.getOperand(1); + if (!isBSwapHWordElement(N011, Parts)) return SDValue(); } else { // (or (or (or (and), (and)), (and)), (and)) @@ -5104,16 +3572,59 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) { /// This contains all DAGCombine rules which reduce two values combined by /// an Or operation to a single value \see visitANDLike(). -SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) { +SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) { EVT VT = N1.getValueType(); - SDLoc DL(N); - // fold (or x, undef) -> -1 - if (!LegalOperations && (N0.isUndef() || N1.isUndef())) - return DAG.getAllOnesConstant(DL, VT); - - if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL)) - return V; + if (!LegalOperations && + (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)) { + EVT EltVT = VT.isVector() ? VT.getVectorElementType() : VT; + return DAG.getConstant(APInt::getAllOnesValue(EltVT.getSizeInBits()), + SDLoc(LocReference), VT); + } + // fold (or (setcc x), (setcc y)) -> (setcc (or x, y)) + SDValue LL, LR, RL, RR, CC0, CC1; + if (isSetCCEquivalent(N0, LL, LR, CC0) && isSetCCEquivalent(N1, RL, RR, CC1)){ + ISD::CondCode Op0 = cast<CondCodeSDNode>(CC0)->get(); + ISD::CondCode Op1 = cast<CondCodeSDNode>(CC1)->get(); + + if (LR == RR && Op0 == Op1 && LL.getValueType().isInteger()) { + // fold (or (setne X, 0), (setne Y, 0)) -> (setne (or X, Y), 0) + // fold (or (setlt X, 0), (setlt Y, 0)) -> (setne (or X, Y), 0) + if (isNullConstant(LR) && (Op1 == ISD::SETNE || Op1 == ISD::SETLT)) { + SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(LR), + LR.getValueType(), LL, RL); + AddToWorklist(ORNode.getNode()); + return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1); + } + // fold (or (setne X, -1), (setne Y, -1)) -> (setne (and X, Y), -1) + // fold (or (setgt X, -1), (setgt Y -1)) -> (setgt (and X, Y), -1) + if (isAllOnesConstant(LR) && (Op1 == ISD::SETNE || Op1 == ISD::SETGT)) { + SDValue ANDNode = DAG.getNode(ISD::AND, SDLoc(LR), + LR.getValueType(), LL, RL); + AddToWorklist(ANDNode.getNode()); + return DAG.getSetCC(SDLoc(LocReference), VT, ANDNode, LR, Op1); + } + } + // canonicalize equivalent to ll == rl + if (LL == RR && LR == RL) { + Op1 = ISD::getSetCCSwappedOperands(Op1); + std::swap(RL, RR); + } + if (LL == RL && LR == RR) { + bool isInteger = LL.getValueType().isInteger(); + ISD::CondCode Result = ISD::getSetCCOrOperation(Op0, Op1, isInteger); + if (Result != ISD::SETCC_INVALID && + (!LegalOperations || + (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) && + TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) { + EVT CCVT = getSetCCResultType(LL.getValueType()); + if (N0.getValueType() == CCVT || + (!LegalOperations && N0.getValueType() == MVT::i1)) + return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(), + LL, LR, Result); + } + } + } // (or (and X, C1), (and Y, C2)) -> 
(and (or X, Y), C3) if possible.
   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
@@ -5134,6 +3645,7 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
         DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
       SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
                               N0.getOperand(0), N1.getOperand(0));
+      SDLoc DL(LocReference);
       return DAG.getNode(ISD::AND, DL, VT, X,
                          DAG.getConstant(LHSMask | RHSMask, DL, VT));
     }
@@ -5149,7 +3661,7 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
       (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
                             N0.getOperand(1), N1.getOperand(1));
-    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
+    return DAG.getNode(ISD::AND, SDLoc(LocReference), VT, N0.getOperand(0), X);
   }

   return SDValue();
@@ -5160,10 +3672,6 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
   SDValue N1 = N->getOperand(1);
   EVT VT = N1.getValueType();

-  // x | x --> x
-  if (N0 == N1)
-    return N0;
-
   // fold vector ops
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N))
@@ -5178,75 +3686,70 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
     // fold (or x, -1) -> -1, vector edition
     if (ISD::isBuildVectorAllOnes(N0.getNode()))
       // do not return N0, because undef node may exist in N0
-      return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
+      return DAG.getConstant(
+          APInt::getAllOnesValue(
+              N0.getValueType().getScalarType().getSizeInBits()),
+          SDLoc(N), N0.getValueType());
     if (ISD::isBuildVectorAllOnes(N1.getNode()))
       // do not return N1, because undef node may exist in N1
-      return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
+      return DAG.getConstant(
+          APInt::getAllOnesValue(
+              N1.getValueType().getScalarType().getSizeInBits()),
+          SDLoc(N), N1.getValueType());

-    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
+    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask1)
+    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf B, A, Mask2)
     // Do this only if the resulting shuffle is legal.
     if (isa<ShuffleVectorSDNode>(N0) &&
         isa<ShuffleVectorSDNode>(N1) &&
         // Avoid folding a node with illegal type.
-        TLI.isTypeLegal(VT)) {
-      bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
-      bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
-      bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
-      bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
-      // Ensure both shuffles have a zero input.
-      if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
-        assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
-        assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
-        const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
-        const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1);
-        bool CanFold = true;
-        int NumElts = VT.getVectorNumElements();
-        SmallVector<int, 4> Mask(NumElts);
-
-        for (int i = 0; i != NumElts; ++i) {
-          int M0 = SV0->getMaskElt(i);
-          int M1 = SV1->getMaskElt(i);
-
-          // Determine if either index is pointing to a zero vector.
-          bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
-          bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
-
-          // If one element is zero and the other side is undef, keep undef.
-          // This also handles the case that both are undef.
-          if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) {
-            Mask[i] = -1;
-            continue;
-          }
-
-          // Make sure only one of the elements is zero.
- if (M0Zero == M1Zero) { - CanFold = false; - break; - } - - assert((M0 >= 0 || M1 >= 0) && "Undef index!"); - - // We have a zero and non-zero element. If the non-zero came from - // SV0 make the index a LHS index. If it came from SV1, make it - // a RHS index. We need to mod by NumElts because we don't care - // which operand it came from in the original shuffles. - Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts; + TLI.isTypeLegal(VT) && + N0->getOperand(1) == N1->getOperand(1) && + ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode())) { + bool CanFold = true; + unsigned NumElts = VT.getVectorNumElements(); + const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0); + const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1); + // We construct two shuffle masks: + // - Mask1 is a shuffle mask for a shuffle with N0 as the first operand + // and N1 as the second operand. + // - Mask2 is a shuffle mask for a shuffle with N1 as the first operand + // and N0 as the second operand. + // We do this because OR is commutable and therefore there might be + // two ways to fold this node into a shuffle. + SmallVector<int,4> Mask1; + SmallVector<int,4> Mask2; + + for (unsigned i = 0; i != NumElts && CanFold; ++i) { + int M0 = SV0->getMaskElt(i); + int M1 = SV1->getMaskElt(i); + + // Both shuffle indexes are undef. Propagate Undef. + if (M0 < 0 && M1 < 0) { + Mask1.push_back(M0); + Mask2.push_back(M0); + continue; } - if (CanFold) { - SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0); - SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0); + if (M0 < 0 || M1 < 0 || + (M0 < (int)NumElts && M1 < (int)NumElts) || + (M0 >= (int)NumElts && M1 >= (int)NumElts)) { + CanFold = false; + break; + } - bool LegalMask = TLI.isShuffleMaskLegal(Mask, VT); - if (!LegalMask) { - std::swap(NewLHS, NewRHS); - ShuffleVectorSDNode::commuteMask(Mask); - LegalMask = TLI.isShuffleMaskLegal(Mask, VT); - } + Mask1.push_back(M0 < (int)NumElts ? M0 : M1 + NumElts); + Mask2.push_back(M1 < (int)NumElts ? M1 : M0 + NumElts); + } - if (LegalMask) - return DAG.getVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS, Mask); - } + if (CanFold) { + // Fold this sequence only if the resulting shuffle is 'legal'. 
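Both versions of this combine rely on the same per-lane fact: when every lane is known zero on exactly one side, the OR is a blend, i.e. a single shuffle of the two originals. A scalar model of that fact (std::array standing in for a v4i32; illustrative only, not the DAG API):

#include <array>
#include <cassert>
using V4 = std::array<int, 4>;

// OR of two half-zeroed vectors: a supplies lanes 0 and 2, b lanes 1 and 3.
V4 or_of_masked(const V4 &a, const V4 &b) {
  V4 lhs = {a[0], 0, a[2], 0};   // shuffle of a with a zero vector
  V4 rhs = {0, b[1], 0, b[3]};   // shuffle of b with a zero vector
  V4 r{};
  for (int i = 0; i < 4; ++i)
    r[i] = lhs[i] | rhs[i];
  return r;
}

// The fold's result: one shuffle selecting the non-zero lane per element.
V4 blend(const V4 &a, const V4 &b) { return {a[0], b[1], a[2], b[3]}; }

int main() {
  V4 a = {1, 2, 3, 4}, b = {5, 6, 7, 8};
  assert(or_of_masked(a, b) == blend(a, b));
}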
+ if (TLI.isShuffleMaskLegal(Mask1, VT)) + return DAG.getVectorShuffle(VT, SDLoc(N), N0->getOperand(0), + N1->getOperand(0), &Mask1[0]); + if (TLI.isShuffleMaskLegal(Mask2, VT)) + return DAG.getVectorShuffle(VT, SDLoc(N), N1->getOperand(0), + N0->getOperand(0), &Mask2[0]); } } } @@ -5257,8 +3760,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) { if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, N0C, N1C); // canonicalize constant to RHS - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0); // fold (or x, 0) -> x if (isNullConstant(N1)) @@ -5266,10 +3769,6 @@ SDValue DAGCombiner::visitOR(SDNode *N) { // fold (or x, -1) -> -1 if (isAllOnesConstant(N1)) return N1; - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - // fold (or x, c) -> c iff (x & ~c) == 0 if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue())) return N1; @@ -5284,175 +3783,56 @@ SDValue DAGCombiner::visitOR(SDNode *N) { return BSwap; // reassociate or - if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags())) + if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1)) return ROR; - // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2) - // iff (c1 & c2) != 0 or c1/c2 are undef. - auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) { - return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue()); - }; - if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && - ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) { - if (SDValue COR = DAG.FoldConstantArithmetic( - ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) { - SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1); - AddToWorklist(IOR.getNode()); - return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR); + // iff (c1 & c2) == 0. + if (N1C && N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && + isa<ConstantSDNode>(N0.getOperand(1))) { + ConstantSDNode *C1 = cast<ConstantSDNode>(N0.getOperand(1)); + if ((C1->getAPIntValue() & N1C->getAPIntValue()) != 0) { + if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT, + N1C, C1)) + return DAG.getNode( + ISD::AND, SDLoc(N), VT, + DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1), COR); + return SDValue(); } } - // Simplify: (or (op x...), (op y...)) -> (op (or x, y)) if (N0.getOpcode() == N1.getOpcode()) - if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) - return V; + if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) + return Tmp; // See if this is some rotate idiom. if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N))) return SDValue(Rot, 0); - if (SDValue Load = MatchLoadCombine(N)) - return Load; - // Simplify the operands using demanded-bits information. - if (SimplifyDemandedBits(SDValue(N, 0))) + if (!VT.isVector() && + SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); return SDValue(); } -static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) { - if (Op.getOpcode() == ISD::AND && - DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) { - Mask = Op.getOperand(1); - return Op.getOperand(0); +/// Match "(X shl/srl V1) & V2" where V2 may not be present. 
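The shl/srl halves matched here assemble into the classic rotate idiom: for i32 and a constant 0 < c < 32, the OR of the two halves is a left rotate by c. A standalone check (plain C++):

#include <cassert>
#include <cstdint>

// (or (shl x, c), (srl x, 32 - c)) for 0 < c < 32 ...
uint32_t rot_idiom(uint32_t x, unsigned c) {
  return (x << c) | (x >> (32 - c));
}

// ... is a left rotate by c (reference written to be defined for all c).
uint32_t rotl_ref(uint32_t x, unsigned c) {
  return (x << (c & 31)) | (x >> ((32 - c) & 31));
}

int main() {
  for (unsigned c = 1; c < 32; ++c)
    assert(rot_idiom(0xDEADBEEFu, c) == rotl_ref(0xDEADBEEFu, c));
}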
+static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { + if (Op.getOpcode() == ISD::AND) { + if (isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) { + Mask = Op.getOperand(1); + Op = Op.getOperand(0); + } else { + return false; + } } - return Op; -} -/// Match "(X shl/srl V1) & V2" where V2 may not be present. -static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift, - SDValue &Mask) { - Op = stripConstantMask(DAG, Op, Mask); if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) { Shift = Op; return true; } - return false; -} - -/// Helper function for visitOR to extract the needed side of a rotate idiom -/// from a shl/srl/mul/udiv. This is meant to handle cases where -/// InstCombine merged some outside op with one of the shifts from -/// the rotate pattern. -/// \returns An empty \c SDValue if the needed shift couldn't be extracted. -/// Otherwise, returns an expansion of \p ExtractFrom based on the following -/// patterns: -/// -/// (or (mul v c0) (shrl (mul v c1) c2)): -/// expands (mul v c0) -> (shl (mul v c1) c3) -/// -/// (or (udiv v c0) (shl (udiv v c1) c2)): -/// expands (udiv v c0) -> (shrl (udiv v c1) c3) -/// -/// (or (shl v c0) (shrl (shl v c1) c2)): -/// expands (shl v c0) -> (shl (shl v c1) c3) -/// -/// (or (shrl v c0) (shl (shrl v c1) c2)): -/// expands (shrl v c0) -> (shrl (shrl v c1) c3) -/// -/// Such that in all cases, c3+c2==bitwidth(op v c1). -static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift, - SDValue ExtractFrom, SDValue &Mask, - const SDLoc &DL) { - assert(OppShift && ExtractFrom && "Empty SDValue"); - assert( - (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) && - "Existing shift must be valid as a rotate half"); - - ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask); - // Preconditions: - // (or (op0 v c0) (shiftl/r (op0 v c1) c2)) - // - // Find opcode of the needed shift to be extracted from (op0 v c0). - unsigned Opcode = ISD::DELETED_NODE; - bool IsMulOrDiv = false; - // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift - // opcode or its arithmetic (mul or udiv) variant. - auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) { - IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant; - if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift) - return false; - Opcode = NeededShift; - return true; - }; - // op0 must be either the needed shift opcode or the mul/udiv equivalent - // that the needed shift can be extracted from. - if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) && - (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV))) - return SDValue(); - - // op0 must be the same opcode on both sides, have the same LHS argument, - // and produce the same value type. - SDValue OppShiftLHS = OppShift.getOperand(0); - EVT ShiftedVT = OppShiftLHS.getValueType(); - if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() || - OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) || - ShiftedVT != ExtractFrom.getValueType()) - return SDValue(); - - // Amount of the existing shift. - ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1)); - // Constant mul/udiv/shift amount from the RHS of the shift's LHS op. - ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1)); - // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op. 
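The mul and udiv cases reduce to powers of two: a multiply by c0 == c1 << c3 already contains the left shift the rotate needs, and a divide by c0 == c1 << c3 contains the right shift. A small sanity check of both directions (unsigned 32-bit, so the multiply identity holds mod 2^32; the constants are arbitrary):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t v = 0xDEADBEEFu;
  // 24 == 3 * (1 << 3): (mul v, 24) is (shl (mul v, 3), 3).
  assert(v * 24u == ((v * 3u) << 3));
  // 1024 == 64 * (1 << 4): (udiv v, 1024) is (srl (udiv v, 64), 4).
  assert(v / 1024u == ((v / 64u) >> 4));
}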
- ConstantSDNode *ExtractFromCst = - isConstOrConstSplat(ExtractFrom.getOperand(1)); - // TODO: We should be able to handle non-uniform constant vectors for these values - // Check that we have constant values. - if (!OppShiftCst || !OppShiftCst->getAPIntValue() || - !OppLHSCst || !OppLHSCst->getAPIntValue() || - !ExtractFromCst || !ExtractFromCst->getAPIntValue()) - return SDValue(); - - // Compute the shift amount we need to extract to complete the rotate. - const unsigned VTWidth = ShiftedVT.getScalarSizeInBits(); - if (OppShiftCst->getAPIntValue().ugt(VTWidth)) - return SDValue(); - APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue(); - // Normalize the bitwidth of the two mul/udiv/shift constant operands. - APInt ExtractFromAmt = ExtractFromCst->getAPIntValue(); - APInt OppLHSAmt = OppLHSCst->getAPIntValue(); - zeroExtendToMatch(ExtractFromAmt, OppLHSAmt); - - // Now try extract the needed shift from the ExtractFrom op and see if the - // result matches up with the existing shift's LHS op. - if (IsMulOrDiv) { - // Op to extract from is a mul or udiv by a constant. - // Check: - // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0 - // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0 - const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(), - NeededShiftAmt.getZExtValue()); - APInt ResultAmt; - APInt Rem; - APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem); - if (Rem != 0 || ResultAmt != OppLHSAmt) - return SDValue(); - } else { - // Op to extract from is a shift by a constant. - // Check: - // c2 - (bitwidth(op0 v c0) - c1) == c0 - if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc( - ExtractFromAmt.getBitWidth())) - return SDValue(); - } - // Return the expanded shift op that should allow a rotate to be formed. - EVT ShiftVT = OppShift.getOperand(1).getValueType(); - EVT ResVT = ExtractFrom.getValueType(); - SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT); - return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode); + return false; } // Return true if we can prove that, whenever Neg and Pos are both in the @@ -5464,8 +3844,7 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift, // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate // in direction shift1 by Neg. The range [0, EltSize) means that we only need // to consider shift amounts with defined behavior. -static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, - SelectionDAG &DAG) { +static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) { // If EltSize is a power of 2 then: // // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1) @@ -5500,12 +3879,9 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, unsigned MaskLoBits = 0; if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) { if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) { - KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0)); - unsigned Bits = Log2_64(EltSize); - if (NegC->getAPIntValue().getActiveBits() <= Bits && - ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) { + if (NegC->getAPIntValue() == EltSize - 1) { Neg = Neg.getOperand(0); - MaskLoBits = Bits; + MaskLoBits = Log2_64(EltSize); } } } @@ -5520,15 +3896,10 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with // Pos'. The truncation is redundant for the purpose of the equality. 
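Identity (a) above is easy to verify exhaustively for a power-of-two element size; it is what allows Neg to lose its explicit sub when a mask is applied afterwards. A quick check:

#include <cassert>

int main() {
  const unsigned EltSize = 32; // any power of two works
  for (unsigned Pos = 0; Pos < EltSize; ++Pos) {
    unsigned lhs = (Pos == 0) ? 0 : EltSize - Pos;
    unsigned rhs = (EltSize - Pos) & (EltSize - 1);
    assert(lhs == rhs);
  }
}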
- if (MaskLoBits && Pos.getOpcode() == ISD::AND) { - if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) { - KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0)); - if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits && - ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >= - MaskLoBits)) + if (MaskLoBits && Pos.getOpcode() == ISD::AND) + if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) + if (PosC->getAPIntValue() == EltSize - 1) Pos = Pos.getOperand(0); - } - } // The condition we need is now: // @@ -5575,7 +3946,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg, SDValue InnerPos, SDValue InnerNeg, unsigned PosOpcode, - unsigned NegOpcode, const SDLoc &DL) { + unsigned NegOpcode, SDLoc DL) { // fold (or (shl x, (*ext y)), // (srl x, (*ext (sub 32, y)))) -> // (rotl x, y) or (rotr x, (sub 32, y)) @@ -5584,7 +3955,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, // (srl x, (*ext y))) -> // (rotr x, y) or (rotl x, (sub 32, y)) EVT VT = Shifted.getValueType(); - if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG)) { + if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits())) { bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT); return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted, HasPos ? Pos : Neg).getNode(); @@ -5596,63 +3967,26 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, // MatchRotate - Handle an 'or' of two operands. If this is one of the many // idioms for rotate, and if the target supports rotation instructions, generate // a rot[lr]. -SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { +SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) { // Must be a legal type. Expanded 'n promoted things won't work with rotates. EVT VT = LHS.getValueType(); if (!TLI.isTypeLegal(VT)) return nullptr; // The target must have at least one rotate flavor. - bool HasROTL = hasOperation(ISD::ROTL, VT); - bool HasROTR = hasOperation(ISD::ROTR, VT); + bool HasROTL = TLI.isOperationLegalOrCustom(ISD::ROTL, VT); + bool HasROTR = TLI.isOperationLegalOrCustom(ISD::ROTR, VT); if (!HasROTL && !HasROTR) return nullptr; - // Check for truncated rotate. - if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE && - LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) { - assert(LHS.getValueType() == RHS.getValueType()); - if (SDNode *Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) { - return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), - SDValue(Rot, 0)).getNode(); - } - } - // Match "(X shl/srl V1) & V2" where V2 may not be present. SDValue LHSShift; // The shift. SDValue LHSMask; // AND value if any. - matchRotateHalf(DAG, LHS, LHSShift, LHSMask); + if (!MatchRotateHalf(LHS, LHSShift, LHSMask)) + return nullptr; // Not part of a rotate. SDValue RHSShift; // The shift. SDValue RHSMask; // AND value if any. - matchRotateHalf(DAG, RHS, RHSShift, RHSMask); - - // If neither side matched a rotate half, bail - if (!LHSShift && !RHSShift) - return nullptr; - - // InstCombine may have combined a constant shl, srl, mul, or udiv with one - // side of the rotate, so try to handle that here. In all cases we need to - // pass the matched shift from the opposite side to compute the opcode and - // needed shift amount to extract. 
We still want to do this if both sides - // matched a rotate half because one half may be a potential overshift that - // can be broken down (ie if InstCombine merged two shl or srl ops into a - // single one). - - // Have LHS side of the rotate, try to extract the needed shift from the RHS. - if (LHSShift) - if (SDValue NewRHSShift = - extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL)) - RHSShift = NewRHSShift; - // Have RHS side of the rotate, try to extract the needed shift from the LHS. - if (RHSShift) - if (SDValue NewLHSShift = - extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL)) - LHSShift = NewLHSShift; - - // If a side is still missing, nothing else we can do. - if (!RHSShift || !LHSShift) - return nullptr; - - // At this point we've matched or extracted a shift op on each side. + if (!MatchRotateHalf(RHS, RHSShift, RHSMask)) + return nullptr; // Not part of a rotate. if (LHSShift.getOperand(0) != RHSShift.getOperand(0)) return nullptr; // Not shifting the same value. @@ -5675,28 +4009,31 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1) // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2) - auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS, - ConstantSDNode *RHS) { - return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits; - }; - if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { + if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) { + uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue(); + uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue(); + if ((LShVal + RShVal) != EltSizeInBits) + return nullptr; + SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt); // If there is an AND of either shifted operand, apply it to the result. if (LHSMask.getNode() || RHSMask.getNode()) { - SDValue AllOnes = DAG.getAllOnesConstant(DL, VT); - SDValue Mask = AllOnes; + APInt AllBits = APInt::getAllOnesValue(EltSizeInBits); + SDValue Mask = DAG.getConstant(AllBits, DL, VT); if (LHSMask.getNode()) { - SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt); + APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal); Mask = DAG.getNode(ISD::AND, DL, VT, Mask, - DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits)); + DAG.getNode(ISD::OR, DL, VT, LHSMask, + DAG.getConstant(RHSBits, DL, VT))); } if (RHSMask.getNode()) { - SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt); + APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal); Mask = DAG.getNode(ISD::AND, DL, VT, Mask, - DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits)); + DAG.getNode(ISD::OR, DL, VT, RHSMask, + DAG.getConstant(LHSBits, DL, VT))); } Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask); @@ -5738,385 +4075,6 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { return nullptr; } -namespace { - -/// Represents known origin of an individual byte in load combine pattern. The -/// value of the byte is either constant zero or comes from memory. -struct ByteProvider { - // For constant zero providers Load is set to nullptr. For memory providers - // Load represents the node which loads the byte from memory. - // ByteOffset is the offset of the byte in the value produced by the load. 
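The byte-provider analysis introduced here ultimately justifies equivalences like the following, sketched standalone in plain C++ rather than the DAG API. The assert assumes a little-endian host, so the wide load's byte order matches the shift/or assembly:

#include <cassert>
#include <cstdint>
#include <cstring>

// Four adjacent byte loads assembled with shifts and ORs ...
uint32_t load_by_bytes(const uint8_t *a) {
  return (uint32_t)a[0] | ((uint32_t)a[1] << 8) |
         ((uint32_t)a[2] << 16) | ((uint32_t)a[3] << 24);
}

// ... equal one wide load when the byte order lines up.
uint32_t load_wide(const uint8_t *a) {
  uint32_t v;
  std::memcpy(&v, a, sizeof(v)); // the single i32 load the combine emits
  return v;
}

int main() {
  uint8_t buf[4] = {0xDE, 0xAD, 0xBE, 0xEF};
  assert(load_by_bytes(buf) == load_wide(buf)); // little-endian host assumed
}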
- LoadSDNode *Load = nullptr; - unsigned ByteOffset = 0; - - ByteProvider() = default; - - static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) { - return ByteProvider(Load, ByteOffset); - } - - static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); } - - bool isConstantZero() const { return !Load; } - bool isMemory() const { return Load; } - - bool operator==(const ByteProvider &Other) const { - return Other.Load == Load && Other.ByteOffset == ByteOffset; - } - -private: - ByteProvider(LoadSDNode *Load, unsigned ByteOffset) - : Load(Load), ByteOffset(ByteOffset) {} -}; - -} // end anonymous namespace - -/// Recursively traverses the expression calculating the origin of the requested -/// byte of the given value. Returns None if the provider can't be calculated. -/// -/// For all the values except the root of the expression verifies that the value -/// has exactly one use and if it's not true return None. This way if the origin -/// of the byte is returned it's guaranteed that the values which contribute to -/// the byte are not used outside of this expression. -/// -/// Because the parts of the expression are not allowed to have more than one -/// use this function iterates over trees, not DAGs. So it never visits the same -/// node more than once. -static const Optional<ByteProvider> -calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, - bool Root = false) { - // Typical i64 by i8 pattern requires recursion up to 8 calls depth - if (Depth == 10) - return None; - - if (!Root && !Op.hasOneUse()) - return None; - - assert(Op.getValueType().isScalarInteger() && "can't handle other types"); - unsigned BitWidth = Op.getValueSizeInBits(); - if (BitWidth % 8 != 0) - return None; - unsigned ByteWidth = BitWidth / 8; - assert(Index < ByteWidth && "invalid index requested"); - (void) ByteWidth; - - switch (Op.getOpcode()) { - case ISD::OR: { - auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1); - if (!LHS) - return None; - auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1); - if (!RHS) - return None; - - if (LHS->isConstantZero()) - return RHS; - if (RHS->isConstantZero()) - return LHS; - return None; - } - case ISD::SHL: { - auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1)); - if (!ShiftOp) - return None; - - uint64_t BitShift = ShiftOp->getZExtValue(); - if (BitShift % 8 != 0) - return None; - uint64_t ByteShift = BitShift / 8; - - return Index < ByteShift - ? ByteProvider::getConstantZero() - : calculateByteProvider(Op->getOperand(0), Index - ByteShift, - Depth + 1); - } - case ISD::ANY_EXTEND: - case ISD::SIGN_EXTEND: - case ISD::ZERO_EXTEND: { - SDValue NarrowOp = Op->getOperand(0); - unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits(); - if (NarrowBitWidth % 8 != 0) - return None; - uint64_t NarrowByteWidth = NarrowBitWidth / 8; - - if (Index >= NarrowByteWidth) - return Op.getOpcode() == ISD::ZERO_EXTEND - ? 
Optional<ByteProvider>(ByteProvider::getConstantZero()) - : None; - return calculateByteProvider(NarrowOp, Index, Depth + 1); - } - case ISD::BSWAP: - return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1, - Depth + 1); - case ISD::LOAD: { - auto L = cast<LoadSDNode>(Op.getNode()); - if (L->isVolatile() || L->isIndexed()) - return None; - - unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits(); - if (NarrowBitWidth % 8 != 0) - return None; - uint64_t NarrowByteWidth = NarrowBitWidth / 8; - - if (Index >= NarrowByteWidth) - return L->getExtensionType() == ISD::ZEXTLOAD - ? Optional<ByteProvider>(ByteProvider::getConstantZero()) - : None; - return ByteProvider::getMemory(L, Index); - } - } - - return None; -} - -/// Match a pattern where a wide type scalar value is loaded by several narrow -/// loads and combined by shifts and ors. Fold it into a single load or a load -/// and a BSWAP if the targets supports it. -/// -/// Assuming little endian target: -/// i8 *a = ... -/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24) -/// => -/// i32 val = *((i32)a) -/// -/// i8 *a = ... -/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3] -/// => -/// i32 val = BSWAP(*((i32)a)) -/// -/// TODO: This rule matches complex patterns with OR node roots and doesn't -/// interact well with the worklist mechanism. When a part of the pattern is -/// updated (e.g. one of the loads) its direct users are put into the worklist, -/// but the root node of the pattern which triggers the load combine is not -/// necessarily a direct user of the changed node. For example, once the address -/// of t28 load is reassociated load combine won't be triggered: -/// t25: i32 = add t4, Constant:i32<2> -/// t26: i64 = sign_extend t25 -/// t27: i64 = add t2, t26 -/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64 -/// t29: i32 = zero_extend t28 -/// t32: i32 = shl t29, Constant:i8<8> -/// t33: i32 = or t23, t32 -/// As a possible fix visitLoad can check if the load can be a part of a load -/// combine pattern and add corresponding OR roots to the worklist. -SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { - assert(N->getOpcode() == ISD::OR && - "Can only match load combining against OR nodes"); - - // Handles simple types only - EVT VT = N->getValueType(0); - if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64) - return SDValue(); - unsigned ByteWidth = VT.getSizeInBits() / 8; - - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - // Before legalize we can introduce too wide illegal loads which will be later - // split into legal sized loads. This enables us to combine i64 load by i8 - // patterns to a couple of i32 loads on 32 bit targets. - if (LegalOperations && !TLI.isOperationLegal(ISD::LOAD, VT)) - return SDValue(); - - std::function<unsigned(unsigned, unsigned)> LittleEndianByteAt = []( - unsigned BW, unsigned i) { return i; }; - std::function<unsigned(unsigned, unsigned)> BigEndianByteAt = []( - unsigned BW, unsigned i) { return BW - i - 1; }; - - bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian(); - auto MemoryByteOffset = [&] (ByteProvider P) { - assert(P.isMemory() && "Must be a memory byte provider"); - unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits(); - assert(LoadBitWidth % 8 == 0 && - "can only analyze providers for individual bytes not bit"); - unsigned LoadByteWidth = LoadBitWidth / 8; - return IsBigEndianTarget - ? 
BigEndianByteAt(LoadByteWidth, P.ByteOffset) - : LittleEndianByteAt(LoadByteWidth, P.ByteOffset); - }; - - Optional<BaseIndexOffset> Base; - SDValue Chain; - - SmallPtrSet<LoadSDNode *, 8> Loads; - Optional<ByteProvider> FirstByteProvider; - int64_t FirstOffset = INT64_MAX; - - // Check if all the bytes of the OR we are looking at are loaded from the same - // base address. Collect bytes offsets from Base address in ByteOffsets. - SmallVector<int64_t, 4> ByteOffsets(ByteWidth); - for (unsigned i = 0; i < ByteWidth; i++) { - auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true); - if (!P || !P->isMemory()) // All the bytes must be loaded from memory - return SDValue(); - - LoadSDNode *L = P->Load; - assert(L->hasNUsesOfValue(1, 0) && !L->isVolatile() && !L->isIndexed() && - "Must be enforced by calculateByteProvider"); - assert(L->getOffset().isUndef() && "Unindexed load must have undef offset"); - - // All loads must share the same chain - SDValue LChain = L->getChain(); - if (!Chain) - Chain = LChain; - else if (Chain != LChain) - return SDValue(); - - // Loads must share the same base address - BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG); - int64_t ByteOffsetFromBase = 0; - if (!Base) - Base = Ptr; - else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase)) - return SDValue(); - - // Calculate the offset of the current byte from the base address - ByteOffsetFromBase += MemoryByteOffset(*P); - ByteOffsets[i] = ByteOffsetFromBase; - - // Remember the first byte load - if (ByteOffsetFromBase < FirstOffset) { - FirstByteProvider = P; - FirstOffset = ByteOffsetFromBase; - } - - Loads.insert(L); - } - assert(!Loads.empty() && "All the bytes of the value must be loaded from " - "memory, so there must be at least one load which produces the value"); - assert(Base && "Base address of the accessed memory location must be set"); - assert(FirstOffset != INT64_MAX && "First byte offset must be set"); - - // Check if the bytes of the OR we are looking at match with either big or - // little endian value load - bool BigEndian = true, LittleEndian = true; - for (unsigned i = 0; i < ByteWidth; i++) { - int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset; - LittleEndian &= CurrentByteOffset == LittleEndianByteAt(ByteWidth, i); - BigEndian &= CurrentByteOffset == BigEndianByteAt(ByteWidth, i); - if (!BigEndian && !LittleEndian) - return SDValue(); - } - assert((BigEndian != LittleEndian) && "should be either or"); - assert(FirstByteProvider && "must be set"); - - // Ensure that the first byte is loaded from zero offset of the first load. - // So the combined value can be loaded from the first load address. - if (MemoryByteOffset(*FirstByteProvider) != 0) - return SDValue(); - LoadSDNode *FirstLoad = FirstByteProvider->Load; - - // The node we are looking at matches with the pattern, check if we can - // replace it with a single load and bswap if needed. - - // If the load needs byte swap check if the target supports it - bool NeedsBswap = IsBigEndianTarget != BigEndian; - - // Before legalize we can introduce illegal bswaps which will be later - // converted to an explicit bswap sequence. This way we end up with a single - // load and byte shuffling instead of several loads and byte shuffling. 
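When the bytes are assembled in the opposite order from the target's endianness, the combined load needs that BSWAP on top. A companion sketch of the big-endian assembly on a little-endian host (__builtin_bswap32 is the GCC/Clang intrinsic, assumed available):

#include <cassert>
#include <cstdint>
#include <cstring>

// Bytes assembled most-significant first ...
uint32_t load_by_bytes_be(const uint8_t *a) {
  return ((uint32_t)a[0] << 24) | ((uint32_t)a[1] << 16) |
         ((uint32_t)a[2] << 8) | (uint32_t)a[3];
}

int main() {
  uint8_t buf[4] = {0xDE, 0xAD, 0xBE, 0xEF};
  uint32_t v;
  std::memcpy(&v, buf, sizeof(v));
  // ... become one load plus a byte swap on a little-endian host.
  assert(load_by_bytes_be(buf) == __builtin_bswap32(v));
}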
- if (NeedsBswap && LegalOperations && !TLI.isOperationLegal(ISD::BSWAP, VT)) - return SDValue(); - - // Check that a load of the wide type is both allowed and fast on the target - bool Fast = false; - bool Allowed = TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), - VT, FirstLoad->getAddressSpace(), - FirstLoad->getAlignment(), &Fast); - if (!Allowed || !Fast) - return SDValue(); - - SDValue NewLoad = - DAG.getLoad(VT, SDLoc(N), Chain, FirstLoad->getBasePtr(), - FirstLoad->getPointerInfo(), FirstLoad->getAlignment()); - - // Transfer chain users from old loads to the new load. - for (LoadSDNode *L : Loads) - DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1)); - - return NeedsBswap ? DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad) : NewLoad; -} - -// If the target has andn, bsl, or a similar bit-select instruction, -// we want to unfold masked merge, with canonical pattern of: -// | A | |B| -// ((x ^ y) & m) ^ y -// | D | -// Into: -// (x & m) | (y & ~m) -// If y is a constant, and the 'andn' does not work with immediates, -// we unfold into a different pattern: -// ~(~x & m) & (m | y) -// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at -// the very least that breaks andnpd / andnps patterns, and because those -// patterns are simplified in IR and shouldn't be created in the DAG -SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) { - assert(N->getOpcode() == ISD::XOR); - - // Don't touch 'not' (i.e. where y = -1). - if (isAllOnesOrAllOnesSplat(N->getOperand(1))) - return SDValue(); - - EVT VT = N->getValueType(0); - - // There are 3 commutable operators in the pattern, - // so we have to deal with 8 possible variants of the basic pattern. - SDValue X, Y, M; - auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) { - if (And.getOpcode() != ISD::AND || !And.hasOneUse()) - return false; - SDValue Xor = And.getOperand(XorIdx); - if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse()) - return false; - SDValue Xor0 = Xor.getOperand(0); - SDValue Xor1 = Xor.getOperand(1); - // Don't touch 'not' (i.e. where y = -1). - if (isAllOnesOrAllOnesSplat(Xor1)) - return false; - if (Other == Xor0) - std::swap(Xor0, Xor1); - if (Other != Xor1) - return false; - X = Xor0; - Y = Xor1; - M = And.getOperand(XorIdx ? 0 : 1); - return true; - }; - - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) && - !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0)) - return SDValue(); - - // Don't do anything if the mask is constant. This should not be reachable. - // InstCombine should have already unfolded this pattern, and DAGCombiner - // probably shouldn't produce it, too. - if (isa<ConstantSDNode>(M.getNode())) - return SDValue(); - - // We can transform if the target has AndNot - if (!TLI.hasAndNot(M)) - return SDValue(); - - SDLoc DL(N); - - // If Y is a constant, check that 'andn' works with immediates. - if (!TLI.hasAndNot(Y)) { - assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable."); - // If not, we need to do a bit more work to make sure andn is still used. 
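Both unfolded forms are bit-for-bit equal to the original: per mask bit, the expression selects x where m is 1 and y where m is 0. A standalone check of the identities behind unfoldMaskedMerge:

#include <cassert>
#include <cstdint>

uint32_t merge_xor(uint32_t x, uint32_t y, uint32_t m) {
  return ((x ^ y) & m) ^ y;  // the canonical masked merge
}
uint32_t merge_andn(uint32_t x, uint32_t y, uint32_t m) {
  return (x & m) | (y & ~m); // unfolded form for variable y
}
uint32_t merge_andn_imm(uint32_t x, uint32_t y, uint32_t m) {
  return ~(~x & m) & (m | y); // unfolded form used when y is a constant
}

int main() {
  uint32_t x = 0x12345678u, y = 0x9ABCDEF0u, m = 0x0F0F0F0Fu;
  assert(merge_xor(x, y, m) == merge_andn(x, y, m));
  assert(merge_xor(x, y, m) == merge_andn_imm(x, y, m));
}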
- SDValue NotX = DAG.getNOT(DL, X, VT); - SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M); - SDValue NotLHS = DAG.getNOT(DL, LHS, VT); - SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y); - return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS); - } - - SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M); - SDValue NotM = DAG.getNOT(DL, M, VT); - SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM); - - return DAG.getNode(ISD::OR, DL, VT, LHS, RHS); -} - SDValue DAGCombiner::visitXOR(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -6135,138 +4093,112 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { } // fold (xor undef, undef) -> 0. This is a common idiom (misuse). - SDLoc DL(N); - if (N0.isUndef() && N1.isUndef()) - return DAG.getConstant(0, DL, VT); + if (N0.getOpcode() == ISD::UNDEF && N1.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, SDLoc(N), VT); // fold (xor x, undef) -> undef - if (N0.isUndef()) + if (N0.getOpcode() == ISD::UNDEF) return N0; - if (N1.isUndef()) + if (N1.getOpcode() == ISD::UNDEF) return N1; // fold (xor c1, c2) -> c1^c2 ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); ConstantSDNode *N1C = getAsNonOpaqueConstant(N1); if (N0C && N1C) - return DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, N0C, N1C); + return DAG.FoldConstantArithmetic(ISD::XOR, SDLoc(N), VT, N0C, N1C); // canonicalize constant to RHS - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - !DAG.isConstantIntBuildVectorOrConstantInt(N1)) - return DAG.getNode(ISD::XOR, DL, VT, N1, N0); + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) + return DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0); // fold (xor x, 0) -> x if (isNullConstant(N1)) return N0; - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - // reassociate xor - if (SDValue RXOR = ReassociateOps(ISD::XOR, DL, N0, N1, N->getFlags())) + if (SDValue RXOR = ReassociateOps(ISD::XOR, SDLoc(N), N0, N1)) return RXOR; // fold !(x cc y) -> (x !cc y) - unsigned N0Opcode = N0.getOpcode(); SDValue LHS, RHS, CC; if (TLI.isConstTrueVal(N1.getNode()) && isSetCCEquivalent(N0, LHS, RHS, CC)) { + bool isInt = LHS.getValueType().isInteger(); ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(), - LHS.getValueType().isInteger()); + isInt); + if (!LegalOperations || TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) { - switch (N0Opcode) { + switch (N0.getOpcode()) { default: llvm_unreachable("Unhandled SetCC Equivalent!"); case ISD::SETCC: - return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC); + return DAG.getSetCC(SDLoc(N), VT, LHS, RHS, NotCC); case ISD::SELECT_CC: - return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2), + return DAG.getSelectCC(SDLoc(N), LHS, RHS, N0.getOperand(2), N0.getOperand(3), NotCC); } } } // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y))) - if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() && + if (isOneConstant(N1) && N0.getOpcode() == ISD::ZERO_EXTEND && + N0.getNode()->hasOneUse() && isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){ SDValue V = N0.getOperand(0); - SDLoc DL0(N0); - V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V, - DAG.getConstant(1, DL0, V.getValueType())); + SDLoc DL(N0); + V = DAG.getNode(ISD::XOR, DL, V.getValueType(), V, + DAG.getConstant(1, DL, V.getValueType())); AddToWorklist(V.getNode()); - return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V); + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, V); } // fold (not (or x, y)) -> (and (not x), (not y)) iff 
x or y are setcc - if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() && - (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) { + if (isOneConstant(N1) && VT == MVT::i1 && + (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) { SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1); if (isOneUseSetCC(RHS) || isOneUseSetCC(LHS)) { - unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND; + unsigned NewOpcode = N0.getOpcode() == ISD::AND ? ISD::OR : ISD::AND; LHS = DAG.getNode(ISD::XOR, SDLoc(LHS), VT, LHS, N1); // LHS = ~LHS RHS = DAG.getNode(ISD::XOR, SDLoc(RHS), VT, RHS, N1); // RHS = ~RHS AddToWorklist(LHS.getNode()); AddToWorklist(RHS.getNode()); - return DAG.getNode(NewOpcode, DL, VT, LHS, RHS); + return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS); } } // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants - if (isAllOnesConstant(N1) && N0.hasOneUse() && - (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) { + if (isAllOnesConstant(N1) && + (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) { SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1); if (isa<ConstantSDNode>(RHS) || isa<ConstantSDNode>(LHS)) { - unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND; + unsigned NewOpcode = N0.getOpcode() == ISD::AND ? ISD::OR : ISD::AND; LHS = DAG.getNode(ISD::XOR, SDLoc(LHS), VT, LHS, N1); // LHS = ~LHS RHS = DAG.getNode(ISD::XOR, SDLoc(RHS), VT, RHS, N1); // RHS = ~RHS AddToWorklist(LHS.getNode()); AddToWorklist(RHS.getNode()); - return DAG.getNode(NewOpcode, DL, VT, LHS, RHS); + return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS); } } // fold (xor (and x, y), y) -> (and (not x), y) - if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) { - SDValue X = N0.getOperand(0); + if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && + N0->getOperand(1) == N1) { + SDValue X = N0->getOperand(0); SDValue NotX = DAG.getNOT(SDLoc(X), X, VT); AddToWorklist(NotX.getNode()); - return DAG.getNode(ISD::AND, DL, VT, NotX, N1); - } - - if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) { - ConstantSDNode *XorC = isConstOrConstSplat(N1); - ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1)); - unsigned BitWidth = VT.getScalarSizeInBits(); - if (XorC && ShiftC) { - // Don't crash on an oversized shift. We can not guarantee that a bogus - // shift has been simplified to undef. - uint64_t ShiftAmt = ShiftC->getLimitedValue(); - if (ShiftAmt < BitWidth) { - APInt Ones = APInt::getAllOnesValue(BitWidth); - Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt); - if (XorC->getAPIntValue() == Ones) { - // If the xor constant is a shifted -1, do a 'not' before the shift: - // xor (X << ShiftC), XorC --> (not X) << ShiftC - // xor (X >> ShiftC), XorC --> (not X) >> ShiftC - SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT); - return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1)); - } - } - } + return DAG.getNode(ISD::AND, SDLoc(N), VT, NotX, N1); } - - // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X) - if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) { - SDValue A = N0Opcode == ISD::ADD ? N0 : N1; - SDValue S = N0Opcode == ISD::SRA ? 
N0 : N1; - if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) { - SDValue A0 = A.getOperand(0), A1 = A.getOperand(1); - SDValue S0 = S.getOperand(0); - if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0)) { - unsigned OpSizeInBits = VT.getScalarSizeInBits(); - if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1))) - if (C->getAPIntValue() == (OpSizeInBits - 1)) - return DAG.getNode(ISD::ABS, DL, VT, S0); - } + // fold (xor (xor x, c1), c2) -> (xor x, (xor c1, c2)) + if (N1C && N0.getOpcode() == ISD::XOR) { + if (const ConstantSDNode *N00C = getAsNonOpaqueConstant(N0.getOperand(0))) { + SDLoc DL(N); + return DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), + DAG.getConstant(N1C->getAPIntValue() ^ + N00C->getAPIntValue(), DL, VT)); + } + if (const ConstantSDNode *N01C = getAsNonOpaqueConstant(N0.getOperand(1))) { + SDLoc DL(N); + return DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0), + DAG.getConstant(N1C->getAPIntValue() ^ + N01C->getAPIntValue(), DL, VT)); } } - // fold (xor x, x) -> 0 if (N0 == N1) - return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); + return tryFoldToZero(SDLoc(N), TLI, VT, DAG, LegalOperations, LegalTypes); // fold (xor (shl 1, x), -1) -> (rotl ~1, x) // Here is a concrete example of this equivalence: @@ -6286,23 +4218,21 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { // consistent result. // - Pushing the zero left requires shifting one bits in from the right. // A rotate left of ~1 is a nice way of achieving the desired result. - if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL && - isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) { + if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0.getOpcode() == ISD::SHL + && isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) { + SDLoc DL(N); return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT), N0.getOperand(1)); } // Simplify: xor (op x...), (op y...) -> (op (xor x, y)) - if (N0Opcode == N1.getOpcode()) - if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) - return V; - - // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable - if (SDValue MM = unfoldMaskedMerge(N)) - return MM; + if (N0.getOpcode() == N1.getOpcode()) + if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) + return Tmp; // Simplify the expression using non-local knowledge. - if (SimplifyDemandedBits(SDValue(N, 0))) + if (!VT.isVector() && + SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); return SDValue(); @@ -6311,10 +4241,6 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { /// Handle transforms common to the three shifts, when the shift amount is a /// constant. SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) { - // Do not turn a 'not' into a regular xor. - if (isBitwiseNot(N->getOperand(0))) - return SDValue(); - SDNode *LHS = N->getOperand(0).getNode(); if (!LHS->hasOneUse()) return SDValue(); @@ -6344,20 +4270,16 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) { ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS->getOperand(1)); if (!BinOpCst) return SDValue(); - // FIXME: disable this unless the input to the binop is a shift by a constant - // or is copy/select.Enable this in other cases when figure out it's exactly profitable. + // FIXME: disable this unless the input to the binop is a shift by a constant. 
+ // If it is not a shift, it pessimizes some common cases like: + // + // void foo(int *X, int i) { X[i & 1235] = 1; } + // int bar(int *X, int i) { return X[i & 255]; } SDNode *BinOpLHSVal = LHS->getOperand(0).getNode(); - bool isShift = BinOpLHSVal->getOpcode() == ISD::SHL || - BinOpLHSVal->getOpcode() == ISD::SRA || - BinOpLHSVal->getOpcode() == ISD::SRL; - bool isCopyOrSelect = BinOpLHSVal->getOpcode() == ISD::CopyFromReg || - BinOpLHSVal->getOpcode() == ISD::SELECT; - - if ((!isShift || !isa<ConstantSDNode>(BinOpLHSVal->getOperand(1))) && - !isCopyOrSelect) - return SDValue(); - - if (isCopyOrSelect && N->hasOneUse()) + if ((BinOpLHSVal->getOpcode() != ISD::SHL && + BinOpLHSVal->getOpcode() != ISD::SRA && + BinOpLHSVal->getOpcode() != ISD::SRL) || + !isa<ConstantSDNode>(BinOpLHSVal->getOperand(1))) return SDValue(); EVT VT = N->getValueType(0); @@ -6372,7 +4294,7 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) { return SDValue(); } - if (!TLI.isDesirableToCommuteWithShift(N, Level)) + if (!TLI.isDesirableToCommuteWithShift(LHS)) return SDValue(); // Fold the constants, shifting the binop RHS by the shift amount. @@ -6397,15 +4319,19 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) { // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC) if (N->hasOneUse() && N->getOperand(0).hasOneUse()) { SDValue N01 = N->getOperand(0).getOperand(1); - if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) { - SDLoc DL(N); - EVT TruncVT = N->getValueType(0); - SDValue N00 = N->getOperand(0).getOperand(0); - SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00); - SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01); - AddToWorklist(Trunc00.getNode()); - AddToWorklist(Trunc01.getNode()); - return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01); + + if (ConstantSDNode *N01C = isConstOrConstSplat(N01)) { + if (!N01C->isOpaque()) { + EVT TruncVT = N->getValueType(0); + SDValue N00 = N->getOperand(0).getOperand(0); + APInt TruncC = N01C->getAPIntValue(); + TruncC = TruncC.trunc(TruncVT.getScalarSizeInBits()); + SDLoc DL(N); + + return DAG.getNode(ISD::AND, DL, TruncVT, + DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00), + DAG.getConstant(TruncC, DL, TruncVT)); + } } } @@ -6413,58 +4339,13 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) { } SDValue DAGCombiner::visitRotate(SDNode *N) { - SDLoc dl(N); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N->getValueType(0); - unsigned Bitsize = VT.getScalarSizeInBits(); - - // fold (rot x, 0) -> x - if (isNullOrNullSplat(N1)) - return N0; - - // fold (rot x, c) -> x iff (c % BitSize) == 0 - if (isPowerOf2_32(Bitsize) && Bitsize > 1) { - APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1); - if (DAG.MaskedValueIsZero(N1, ModuloMask)) - return N0; - } - - // fold (rot x, c) -> (rot x, c % BitSize) - if (ConstantSDNode *Cst = isConstOrConstSplat(N1)) { - if (Cst->getAPIntValue().uge(Bitsize)) { - uint64_t RotAmt = Cst->getAPIntValue().urem(Bitsize); - return DAG.getNode(N->getOpcode(), dl, VT, N0, - DAG.getConstant(RotAmt, dl, N1.getValueType())); - } - } - // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))). 
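The fold named just above is sound because truncation distributes over AND: dropping high bits commutes with masking. A one-line check (i32 truncated to i16):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t y = 0x12345678u, c = 0x0000001Fu;
  // (trunc (and y, c)) == (and (trunc y), (trunc c))
  assert((uint16_t)(y & c) == (uint16_t)((uint16_t)y & (uint16_t)c));
}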
- if (N1.getOpcode() == ISD::TRUNCATE && - N1.getOperand(0).getOpcode() == ISD::AND) { - if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode())) - return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1); - } - - unsigned NextOp = N0.getOpcode(); - // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize) - if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) { - SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1); - SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)); - if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) { - EVT ShiftVT = C1->getValueType(0); - bool SameSide = (N->getOpcode() == NextOp); - unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB; - if (SDValue CombinedShift = - DAG.FoldConstantArithmetic(CombineOp, dl, ShiftVT, C1, C2)) { - SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT); - SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic( - ISD::SREM, dl, ShiftVT, CombinedShift.getNode(), - BitsizeC.getNode()); - return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0), - CombinedShiftNorm); - } - } + if (N->getOperand(1).getOpcode() == ISD::TRUNCATE && + N->getOperand(1).getOperand(0).getOpcode() == ISD::AND) { + SDValue NewOp1 = distributeTruncateThroughAnd(N->getOperand(1).getNode()); + if (NewOp1.getNode()) + return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), + N->getOperand(0), NewOp1); } return SDValue(); } @@ -6472,13 +4353,11 @@ SDValue DAGCombiner::visitRotate(SDNode *N) { SDValue DAGCombiner::visitSHL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - if (SDValue V = DAG.simplifyShift(N0, N1)) - return V; - EVT VT = N0.getValueType(); unsigned OpSizeInBits = VT.getScalarSizeInBits(); // fold vector ops + ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); if (VT.isVector()) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; @@ -6499,20 +4378,28 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { N01CV, N1CV)) return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C); } + } else { + N1C = isConstOrConstSplat(N1); } } } - ConstantSDNode *N1C = isConstOrConstSplat(N1); - // fold (shl c1, c2) -> c1<<c2 ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, N0C, N1C); - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - + // fold (shl 0, x) -> 0 + if (isNullConstant(N0)) + return N0; + // fold (shl x, c >= size(x)) -> undef + if (N1C && N1C->getAPIntValue().uge(OpSizeInBits)) + return DAG.getUNDEF(VT); + // fold (shl x, 0) -> x + if (N1C && N1C->isNullValue()) + return N0; + // fold (shl undef, x) -> 0 + if (N0.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, SDLoc(N), VT); // if (shl x, c) is known to be zero, return 0 if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnesValue(OpSizeInBits))) @@ -6520,7 +4407,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))). 
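The same truncation identity covers the shl amount here. The hunks that follow also combine nested left shifts; the arithmetic behind that fold is sketched below (plain C++, 32-bit width assumed):

#include <cassert>
#include <cstdint>

// (shl (shl x, c1), c2) -> (shl x, c1 + c2), or 0 once c1 + c2 >= 32.
uint32_t shl_shl(uint32_t x, unsigned c1, unsigned c2) {
  if (c1 + c2 >= 32)
    return 0;
  return x << (c1 + c2);
}

int main() {
  uint32_t x = 0xDEADBEEFu;
  assert(((x << 5) << 7) == shl_shl(x, 5, 7));
  assert(((x << 20) << 12) == shl_shl(x, 20, 12)); // amounts sum to the width
}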
if (N1.getOpcode() == ISD::TRUNCATE && N1.getOperand(0).getOpcode() == ISD::AND) { - if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode())) + SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()); + if (NewOp1.getNode()) return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1); } @@ -6528,29 +4416,15 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { return SDValue(N, 0); // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2)) - if (N0.getOpcode() == ISD::SHL) { - auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS, - ConstantSDNode *RHS) { - APInt c1 = LHS->getAPIntValue(); - APInt c2 = RHS->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - return (c1 + c2).uge(OpSizeInBits); - }; - if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) - return DAG.getConstant(0, SDLoc(N), VT); - - auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS, - ConstantSDNode *RHS) { - APInt c1 = LHS->getAPIntValue(); - APInt c2 = RHS->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - return (c1 + c2).ult(OpSizeInBits); - }; - if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (N1C && N0.getOpcode() == ISD::SHL) { + if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N0C1->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); SDLoc DL(N); - EVT ShiftVT = N1.getValueType(); - SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); - return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum); + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, DL, VT); + return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), + DAG.getConstant(c1 + c2, DL, N1.getValueType())); } } @@ -6565,22 +4439,18 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { N0.getOperand(0).getOpcode() == ISD::SHL) { SDValue N0Op0 = N0.getOperand(0); if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { - APInt c1 = N0Op0C1->getAPIntValue(); - APInt c2 = N1C->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - + uint64_t c1 = N0Op0C1->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); EVT InnerShiftVT = N0Op0.getValueType(); uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits(); - if (c2.uge(OpSizeInBits - InnerShiftSize)) { + if (c2 >= OpSizeInBits - InnerShiftSize) { SDLoc DL(N0); - APInt Sum = c1 + c2; - if (Sum.uge(OpSizeInBits)) + if (c1 + c2 >= OpSizeInBits) return DAG.getConstant(0, DL, VT); - - return DAG.getNode( - ISD::SHL, DL, VT, - DAG.getNode(N0.getOpcode(), DL, VT, N0Op0->getOperand(0)), - DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType())); + return DAG.getNode(ISD::SHL, DL, VT, + DAG.getNode(N0.getOpcode(), DL, VT, + N0Op0->getOperand(0)), + DAG.getConstant(c1 + c2, DL, N1.getValueType())); } } } @@ -6592,8 +4462,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { N0.getOperand(0).getOpcode() == ISD::SRL) { SDValue N0Op0 = N0.getOperand(0); if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { - if (N0Op0C1->getAPIntValue().ult(VT.getScalarSizeInBits())) { - uint64_t c1 = N0Op0C1->getZExtValue(); + uint64_t c1 = N0Op0C1->getZExtValue(); + if (c1 < VT.getScalarSizeInBits()) { uint64_t c2 = N1C->getZExtValue(); if (c1 == c2) { SDValue NewOp0 = N0.getOperand(0); @@ -6612,7 +4482,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 > C2 if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == 
ISD::SRA) && - N0->getFlags().hasExact()) { + cast<BinaryWithFlagsSDNode>(N0)->Flags.hasExact()) { if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { uint64_t C1 = N0C1->getZExtValue(); uint64_t C2 = N1C->getZExtValue(); @@ -6629,8 +4499,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // (and (srl x, (sub c1, c2), MASK) // Only fold this if the inner shift has no other uses -- if it does, folding // this will increase the total number of instructions. - if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() && - TLI.shouldFoldShiftPairToMask(N, Level)) { + if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { uint64_t c1 = N0C1->getZExtValue(); if (c1 < OpSizeInBits) { @@ -6638,12 +4507,12 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1); SDValue Shift; if (c2 > c1) { - Mask <<= c2 - c1; + Mask = Mask.shl(c2 - c1); SDLoc DL(N); Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), DAG.getConstant(c2 - c1, DL, N1.getValueType())); } else { - Mask.lshrInPlace(c1 - c2); + Mask = Mask.lshr(c1 - c2); SDLoc DL(N); Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), DAG.getConstant(c1 - c2, DL, N1.getValueType())); @@ -6654,39 +4523,37 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { } } } - // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1)) - if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) && - isConstantOrConstantVector(N1, /* No Opaques */ true)) { + if (N1C && N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1)) { + unsigned BitSize = VT.getScalarSizeInBits(); SDLoc DL(N); - SDValue AllBits = DAG.getAllOnesConstant(DL, VT); - SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1); - return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask); + SDValue HiBitsMask = + DAG.getConstant(APInt::getHighBitsSet(BitSize, + BitSize - N1C->getZExtValue()), + DL, VT); + return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), + HiBitsMask); } // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2) - // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2) // Variant of version done on multiply, except mul by a power of 2 is turned // into a shift. 
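
The distribution stated in the comment above holds in wrapping (modular) arithmetic, so it is safe even when the inner add overflows; a self-contained check in standard C++:

    #include <cassert>
    #include <cstdint>

    int main() {
      // (x + c1) << c2 == (x << c2) + (c1 << c2) in wrapping arithmetic,
      // so the fold is valid even though x + c1 wraps here.
      uint32_t x = 0xFFFFFFF0u, c1 = 0x1234u;
      unsigned c2 = 5;
      assert(((x + c1) << c2) == ((x << c2) + (c1 << c2)));
      return 0;
    }
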
- if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) && - N0.getNode()->hasOneUse() && - isConstantOrConstantVector(N1, /* No Opaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) && - TLI.isDesirableToCommuteWithShift(N, Level)) { + APInt Val; + if (N1C && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() && + (isa<ConstantSDNode>(N0.getOperand(1)) || + isConstantSplatVector(N0.getOperand(1).getNode(), Val))) { SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1); SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1); - AddToWorklist(Shl0.getNode()); - AddToWorklist(Shl1.getNode()); - return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1); + return DAG.getNode(ISD::ADD, SDLoc(N), VT, Shl0, Shl1); } // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2) - if (N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse() && - isConstantOrConstantVector(N1, /* No Opaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) { - SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1); - if (isConstantOrConstantVector(Shl)) - return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl); + if (N1C && N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse()) { + if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { + if (SDValue Folded = + DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, N0C1, N1C)) + return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Folded); + } } if (N1C && !N1C->isOpaque()) @@ -6699,33 +4566,34 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { SDValue DAGCombiner::visitSRA(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - if (SDValue V = DAG.simplifyShift(N0, N1)) - return V; - EVT VT = N0.getValueType(); - unsigned OpSizeInBits = VT.getScalarSizeInBits(); - - // Arithmetic shifting an all-sign-bit value is a no-op. - // fold (sra 0, x) -> 0 - // fold (sra -1, x) -> -1 - if (DAG.ComputeNumSignBits(N0) == OpSizeInBits) - return N0; + unsigned OpSizeInBits = VT.getScalarType().getSizeInBits(); // fold vector ops - if (VT.isVector()) + ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); + if (VT.isVector()) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; - ConstantSDNode *N1C = isConstOrConstSplat(N1); + N1C = isConstOrConstSplat(N1); + } // fold (sra c1, c2) -> (sra c1, c2) ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C); - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - + // fold (sra 0, x) -> 0 + if (isNullConstant(N0)) + return N0; + // fold (sra -1, x) -> -1 + if (isAllOnesConstant(N0)) + return N0; + // fold (sra x, (setge c, size(x))) -> undef + if (N1C && N1C->getZExtValue() >= OpSizeInBits) + return DAG.getUNDEF(VT); + // fold (sra x, 0) -> x + if (N1C && N1C->isNullValue()) + return N0; // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports // sext_inreg. if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) { @@ -6741,30 +4609,14 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { } // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2)) - // clamp (add c1, c2) to max shift. 
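
Two facts make this fold sound, both visible in a small scalar example (illustrative values only): arithmetic shifts compose by addition, and the result saturates once the sum reaches the bit width.

    #include <cassert>
    #include <cstdint>

    int main() {
      int32_t x = -123456;
      // Arithmetic shifts compose: (x >> c1) >> c2 == x >> (c1 + c2) ...
      assert(((x >> 7) >> 9) == (x >> 16));
      // ... and saturate at BitWidth - 1: once every bit is a copy of the
      // sign bit, further shifts change nothing, hence the clamp to 31.
      assert(((x >> 20) >> 20) == (x >> 31));
      return 0;
    }
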
- if (N0.getOpcode() == ISD::SRA) { - SDLoc DL(N); - EVT ShiftVT = N1.getValueType(); - EVT ShiftSVT = ShiftVT.getScalarType(); - SmallVector<SDValue, 16> ShiftValues; - - auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) { - APInt c1 = LHS->getAPIntValue(); - APInt c2 = RHS->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - APInt Sum = c1 + c2; - unsigned ShiftSum = - Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); - ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT)); - return true; - }; - if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) { - SDValue ShiftValue; - if (VT.isVector()) - ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues); - else - ShiftValue = ShiftValues[0]; - return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue); + if (N1C && N0.getOpcode() == ISD::SRA) { + if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) { + unsigned Sum = N1C->getZExtValue() + C1->getZExtValue(); + if (Sum >= OpSizeInBits) + Sum = OpSizeInBits - 1; + SDLoc DL(N); + return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), + DAG.getConstant(Sum, DL, N1.getValueType())); } } @@ -6785,7 +4637,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements()); // Determine the residual right-shift amount. - int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue(); + signed ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue(); // If the shift is not a no-op (in which case this should be just a sign // extend already), the truncated to type is legal, sign_extend is legal @@ -6795,6 +4647,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) && TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) && TLI.isTruncateFree(VT, TruncVT)) { + SDLoc DL(N); SDValue Amt = DAG.getConstant(ShiftAmt, DL, getShiftAmountTy(N0.getOperand(0).getValueType())); @@ -6811,7 +4664,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))). if (N1.getOpcode() == ISD::TRUNCATE && N1.getOperand(0).getOpcode() == ISD::AND) { - if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode())) + SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()); + if (NewOp1.getNode()) return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1); } @@ -6844,6 +4698,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { if (N1C && SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); + // If the sign bit is known to be zero, switch this to a SRL. 
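
The equivalence behind that switch can be checked exhaustively for one non-negative value; plain C++, nothing LLVM-specific:

    #include <cassert>
    #include <cstdint>

    int main() {
      // With the sign bit known zero, arithmetic and logical right shift
      // agree, so (sra x, c) can be rewritten as (srl x, c).
      int32_t x = 0x7000ABCD;  // non-negative: sign bit clear
      for (unsigned c = 0; c < 32; ++c)
        assert((x >> c) == static_cast<int32_t>(static_cast<uint32_t>(x) >> c));
      return 0;
    }
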
if (DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1); @@ -6858,90 +4713,81 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { SDValue DAGCombiner::visitSRL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - if (SDValue V = DAG.simplifyShift(N0, N1)) - return V; - EVT VT = N0.getValueType(); - unsigned OpSizeInBits = VT.getScalarSizeInBits(); + unsigned OpSizeInBits = VT.getScalarType().getSizeInBits(); // fold vector ops - if (VT.isVector()) + ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); + if (VT.isVector()) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; - ConstantSDNode *N1C = isConstOrConstSplat(N1); + N1C = isConstOrConstSplat(N1); + } // fold (srl c1, c2) -> c1 >>u c2 ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, N0C, N1C); - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - + // fold (srl 0, x) -> 0 + if (isNullConstant(N0)) + return N0; + // fold (srl x, c >= size(x)) -> undef + if (N1C && N1C->getZExtValue() >= OpSizeInBits) + return DAG.getUNDEF(VT); + // fold (srl x, 0) -> x + if (N1C && N1C->isNullValue()) + return N0; // if (srl x, c) is known to be zero, return 0 if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnesValue(OpSizeInBits))) return DAG.getConstant(0, SDLoc(N), VT); // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2)) - if (N0.getOpcode() == ISD::SRL) { - auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS, - ConstantSDNode *RHS) { - APInt c1 = LHS->getAPIntValue(); - APInt c2 = RHS->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - return (c1 + c2).uge(OpSizeInBits); - }; - if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) - return DAG.getConstant(0, SDLoc(N), VT); - - auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS, - ConstantSDNode *RHS) { - APInt c1 = LHS->getAPIntValue(); - APInt c2 = RHS->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - return (c1 + c2).ult(OpSizeInBits); - }; - if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (N1C && N0.getOpcode() == ISD::SRL) { + if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N01C->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); SDLoc DL(N); - EVT ShiftVT = N1.getValueType(); - SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); - return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum); + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, DL, VT); + return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), + DAG.getConstant(c1 + c2, DL, N1.getValueType())); } } // fold (srl (trunc (srl x, c1)), c2) -> 0 or (trunc (srl x, (add c1, c2))) if (N1C && N0.getOpcode() == ISD::TRUNCATE && - N0.getOperand(0).getOpcode() == ISD::SRL) { - if (auto N001C = isConstOrConstSplat(N0.getOperand(0).getOperand(1))) { - uint64_t c1 = N001C->getZExtValue(); - uint64_t c2 = N1C->getZExtValue(); - EVT InnerShiftVT = N0.getOperand(0).getValueType(); - EVT ShiftCountVT = N0.getOperand(0).getOperand(1).getValueType(); - uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits(); - // This is only valid if the OpSizeInBits + c1 = size of inner shift. 
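
For the srl-of-srl case earlier in this hunk, the scalar picture is the same composition law, with an out-of-range sum collapsing to zero; an illustrative check:

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t x = 0xDEADBEEFu;
      // (srl (srl x, c1), c2) -> (srl x, c1 + c2) while the sum is in range...
      assert(((x >> 10) >> 6) == (x >> 16));
      // ...and becomes the constant 0 once c1 + c2 reaches the bit width.
      assert(((x >> 20) >> 12) == 0u);  // 20 + 12 == 32
      return 0;
    }
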
- if (c1 + OpSizeInBits == InnerShiftSize) { - SDLoc DL(N0); - if (c1 + c2 >= InnerShiftSize) - return DAG.getConstant(0, DL, VT); - return DAG.getNode(ISD::TRUNCATE, DL, VT, - DAG.getNode(ISD::SRL, DL, InnerShiftVT, - N0.getOperand(0).getOperand(0), - DAG.getConstant(c1 + c2, DL, - ShiftCountVT))); - } + N0.getOperand(0).getOpcode() == ISD::SRL && + isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) { + uint64_t c1 = + cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); + EVT InnerShiftVT = N0.getOperand(0).getValueType(); + EVT ShiftCountVT = N0.getOperand(0)->getOperand(1).getValueType(); + uint64_t InnerShiftSize = InnerShiftVT.getScalarType().getSizeInBits(); + // This is only valid if the OpSizeInBits + c1 = size of inner shift. + if (c1 + OpSizeInBits == InnerShiftSize) { + SDLoc DL(N0); + if (c1 + c2 >= InnerShiftSize) + return DAG.getConstant(0, DL, VT); + return DAG.getNode(ISD::TRUNCATE, DL, VT, + DAG.getNode(ISD::SRL, DL, InnerShiftVT, + N0.getOperand(0)->getOperand(0), + DAG.getConstant(c1 + c2, DL, + ShiftCountVT))); } } // fold (srl (shl x, c), c) -> (and x, cst2) - if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 && - isConstantOrConstantVector(N1, /* NoOpaques */ true)) { - SDLoc DL(N); - SDValue Mask = - DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1); - AddToWorklist(Mask.getNode()); - return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask); + if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1) { + unsigned BitSize = N0.getScalarValueSizeInBits(); + if (BitSize <= 64) { + uint64_t ShAmt = N1C->getZExtValue() + 64 - BitSize; + SDLoc DL(N); + return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), + DAG.getConstant(~0ULL >> ShAmt, DL, VT)); + } } // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask) @@ -6960,7 +4806,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { DAG.getConstant(ShiftAmt, DL0, getShiftAmountTy(SmallVT))); AddToWorklist(SmallShift.getNode()); - APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt); + APInt Mask = APInt::getAllOnesValue(OpSizeInBits).lshr(ShiftAmt); SDLoc DL(N); return DAG.getNode(ISD::AND, DL, VT, DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift), @@ -6978,19 +4824,20 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit). if (N1C && N0.getOpcode() == ISD::CTLZ && N1C->getAPIntValue() == Log2_32(OpSizeInBits)) { - KnownBits Known = DAG.computeKnownBits(N0.getOperand(0)); + APInt KnownZero, KnownOne; + DAG.computeKnownBits(N0.getOperand(0), KnownZero, KnownOne); // If any of the input bits are KnownOne, then the input couldn't be all // zeros, thus the result of the srl will always be zero. - if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT); + if (KnownOne.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT); // If all of the bits input the to ctlz node are known to be zero, then // the result of the ctlz is "32" and the result of the shift is one. - APInt UnknownBits = ~Known.Zero; + APInt UnknownBits = ~KnownZero; if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT); // Otherwise, check to see if there is exactly one bit input to the ctlz. - if (UnknownBits.isPowerOf2()) { + if ((UnknownBits & (UnknownBits - 1)) == 0) { // Okay, we know that only that the single bit specified by UnknownBits // could be set on input to the CTLZ node. If this bit is set, the SRL // will return 0, if it is clear, it returns 1. 
Change the CTLZ/SRL pair @@ -7064,63 +4911,12 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitFunnelShift(SDNode *N) { - EVT VT = N->getValueType(0); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDValue N2 = N->getOperand(2); - bool IsFSHL = N->getOpcode() == ISD::FSHL; - unsigned BitWidth = VT.getScalarSizeInBits(); - - // fold (fshl N0, N1, 0) -> N0 - // fold (fshr N0, N1, 0) -> N1 - if (isPowerOf2_32(BitWidth)) - if (DAG.MaskedValueIsZero( - N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1))) - return IsFSHL ? N0 : N1; - - // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth) - if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) { - if (Cst->getAPIntValue().uge(BitWidth)) { - uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth); - return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1, - DAG.getConstant(RotAmt, SDLoc(N), N2.getValueType())); - } - } - - // fold (fshl N0, N0, N2) -> (rotl N0, N2) - // fold (fshr N0, N0, N2) -> (rotr N0, N2) - // TODO: Investigate flipping this rotate if only one is legal, if funnel shift - // is legal as well we might be better off avoiding non-constant (BW - N2). - unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR; - if (N0 == N1 && hasOperation(RotOpc, VT)) - return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2); - - return SDValue(); -} - -SDValue DAGCombiner::visitABS(SDNode *N) { - SDValue N0 = N->getOperand(0); - EVT VT = N->getValueType(0); - - // fold (abs c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) - return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0); - // fold (abs (abs x)) -> (abs x) - if (N0.getOpcode() == ISD::ABS) - return N0; - // fold (abs x) -> x iff not-negative - if (DAG.SignBitIsZero(N0)) - return N0; - return SDValue(); -} - SDValue DAGCombiner::visitBSWAP(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); // fold (bswap c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N0); // fold (bswap (bswap x)) -> x if (N0.getOpcode() == ISD::BSWAP) @@ -7128,33 +4924,13 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitBITREVERSE(SDNode *N) { - SDValue N0 = N->getOperand(0); - EVT VT = N->getValueType(0); - - // fold (bitreverse c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) - return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0); - // fold (bitreverse (bitreverse x)) -> x - if (N0.getOpcode() == ISD::BITREVERSE) - return N0.getOperand(0); - return SDValue(); -} - SDValue DAGCombiner::visitCTLZ(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); // fold (ctlz c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0); - - // If the value is known never to be zero, switch to the undef version. 
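
Among the folds touched here, the bswap one is the easiest to sanity-check in isolation; bswap32 below is a hand-rolled helper for illustration, not a call into LLVM:

    #include <cassert>
    #include <cstdint>

    // A 32-bit byte swap; it is an involution, which is exactly the
    // (bswap (bswap x)) -> x fold.
    static uint32_t bswap32(uint32_t x) {
      return (x >> 24) | ((x >> 8) & 0x0000FF00u) |
             ((x << 8) & 0x00FF0000u) | (x << 24);
    }

    int main() {
      uint32_t x = 0x12345678u;
      assert(bswap32(x) == 0x78563412u);
      assert(bswap32(bswap32(x)) == x);
      return 0;
    }
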
- if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) { - if (DAG.isKnownNeverZero(N0)) - return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0); - } - return SDValue(); } @@ -7163,7 +4939,7 @@ SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) { EVT VT = N->getValueType(0); // fold (ctlz_zero_undef c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0); return SDValue(); } @@ -7173,15 +4949,8 @@ SDValue DAGCombiner::visitCTTZ(SDNode *N) { EVT VT = N->getValueType(0); // fold (cttz c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0); - - // If the value is known never to be zero, switch to the undef version. - if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) { - if (DAG.isKnownNeverZero(N0)) - return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0); - } - return SDValue(); } @@ -7190,7 +4959,7 @@ SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) { EVT VT = N->getValueType(0); // fold (cttz_zero_undef c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0); return SDValue(); } @@ -7200,30 +4969,20 @@ SDValue DAGCombiner::visitCTPOP(SDNode *N) { EVT VT = N->getValueType(0); // fold (ctpop c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0); return SDValue(); } -// FIXME: This should be checking for no signed zeros on individual operands, as -// well as no nans. -static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS, SDValue RHS) { - const TargetOptions &Options = DAG.getTarget().Options; - EVT VT = LHS.getValueType(); - - return Options.NoSignedZerosFPMath && VT.isFloatingPoint() && - DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS); -} -/// Generate Min/Max node -static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, - SDValue RHS, SDValue True, SDValue False, +/// \brief Generate Min/Max node +static SDValue combineMinNumMaxNum(SDLoc DL, EVT VT, SDValue LHS, SDValue RHS, + SDValue True, SDValue False, ISD::CondCode CC, const TargetLowering &TLI, SelectionDAG &DAG) { if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True)) return SDValue(); - EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT); switch (CC) { case ISD::SETOLT: case ISD::SETOLE: @@ -7231,15 +4990,8 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, case ISD::SETLE: case ISD::SETULT: case ISD::SETULE: { - // Since it's known never nan to get here already, either fminnum or - // fminnum_ieee are OK. Try the ieee version first, since it's fminnum is - // expanded in terms of it. - unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE; - if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT)) - return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS); - unsigned Opcode = (LHS == True) ? 
ISD::FMINNUM : ISD::FMAXNUM; - if (TLI.isOperationLegalOrCustom(Opcode, TransformVT)) + if (TLI.isOperationLegal(Opcode, VT)) return DAG.getNode(Opcode, DL, VT, LHS, RHS); return SDValue(); } @@ -7249,12 +5001,8 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, case ISD::SETGE: case ISD::SETUGT: case ISD::SETUGE: { - unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE; - if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT)) - return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS); - unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM; - if (TLI.isOperationLegalOrCustom(Opcode, TransformVT)) + if (TLI.isOperationLegal(Opcode, VT)) return DAG.getNode(Opcode, DL, VT, LHS, RHS); return SDValue(); } @@ -7263,76 +5011,25 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, } } -SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) { - SDValue Cond = N->getOperand(0); +SDValue DAGCombiner::visitSELECT(SDNode *N) { + SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); SDValue N2 = N->getOperand(2); EVT VT = N->getValueType(0); - EVT CondVT = Cond.getValueType(); - SDLoc DL(N); - - if (!VT.isInteger()) - return SDValue(); - - auto *C1 = dyn_cast<ConstantSDNode>(N1); - auto *C2 = dyn_cast<ConstantSDNode>(N2); - if (!C1 || !C2) - return SDValue(); - - // Only do this before legalization to avoid conflicting with target-specific - // transforms in the other direction (create a select from a zext/sext). There - // is also a target-independent combine here in DAGCombiner in the other - // direction for (select Cond, -1, 0) when the condition is not i1. - if (CondVT == MVT::i1 && !LegalOperations) { - if (C1->isNullValue() && C2->isOne()) { - // select Cond, 0, 1 --> zext (!Cond) - SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1); - if (VT != MVT::i1) - NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond); - return NotCond; - } - if (C1->isNullValue() && C2->isAllOnesValue()) { - // select Cond, 0, -1 --> sext (!Cond) - SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1); - if (VT != MVT::i1) - NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond); - return NotCond; - } - if (C1->isOne() && C2->isNullValue()) { - // select Cond, 1, 0 --> zext (Cond) - if (VT != MVT::i1) - Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond); - return Cond; - } - if (C1->isAllOnesValue() && C2->isNullValue()) { - // select Cond, -1, 0 --> sext (Cond) - if (VT != MVT::i1) - Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond); - return Cond; - } - - // For any constants that differ by 1, we can transform the select into an - // extend and add. Use a target hook because some targets may prefer to - // transform in the other direction. 
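
A scalar model of the differ-by-one rewrite, assuming only standard C++ (viaSelect and viaAdd are illustrative names):

    #include <cassert>
    #include <cstdint>

    // select Cond, C1, C1-1 --> add (zext Cond), C1-1: when the two arms
    // differ by one, the boolean itself supplies the difference.
    static uint32_t viaSelect(bool cond, uint32_t c1) { return cond ? c1 : c1 - 1; }
    static uint32_t viaAdd(bool cond, uint32_t c1) {
      return static_cast<uint32_t>(cond) + (c1 - 1);  // zext i1, then add
    }

    int main() {
      for (bool cond : {false, true})
        assert(viaSelect(cond, 42u) == viaAdd(cond, 42u));
      return 0;
    }
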
- if (TLI.convertSelectOfConstantsToMath(VT)) { - if (C1->getAPIntValue() - 1 == C2->getAPIntValue()) { - // select Cond, C1, C1-1 --> add (zext Cond), C1-1 - if (VT != MVT::i1) - Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond); - return DAG.getNode(ISD::ADD, DL, VT, Cond, N2); - } - if (C1->getAPIntValue() + 1 == C2->getAPIntValue()) { - // select Cond, C1, C1+1 --> add (sext Cond), C1+1 - if (VT != MVT::i1) - Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond); - return DAG.getNode(ISD::ADD, DL, VT, Cond, N2); - } - } - - return SDValue(); - } + EVT VT0 = N0.getValueType(); - // fold (select Cond, 0, 1) -> (xor Cond, 1) + // fold (select C, X, X) -> X + if (N1 == N2) + return N1; + if (const ConstantSDNode *N0C = dyn_cast<const ConstantSDNode>(N0)) { + // fold (select true, X, Y) -> X + // fold (select false, X, Y) -> Y + return !N0C->isNullValue() ? N1 : N2; + } + // fold (select C, 1, X) -> (or C, X) + if (VT == MVT::i1 && isOneConstant(N1)) + return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N2); + // fold (select C, 0, 1) -> (xor C, 1) // We can't do this reliably if integer based booleans have different contents // to floating point based booleans. This is because we can't tell whether we // have an integer-based boolean or a floating-point-based boolean unless we @@ -7341,92 +5038,85 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) { // undiscoverable (or not reasonably discoverable). For example, it could be // in another basic block or it could require searching a complicated // expression. - if (CondVT.isInteger() && - TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) == - TargetLowering::ZeroOrOneBooleanContent && - TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) == - TargetLowering::ZeroOrOneBooleanContent && - C1->isNullValue() && C2->isOne()) { - SDValue NotCond = - DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT)); - if (VT.bitsEq(CondVT)) - return NotCond; - return DAG.getZExtOrTrunc(NotCond, DL, VT); + if (VT.isInteger() && + (VT0 == MVT::i1 || (VT0.isInteger() && + TLI.getBooleanContents(false, false) == + TLI.getBooleanContents(false, true) && + TLI.getBooleanContents(false, false) == + TargetLowering::ZeroOrOneBooleanContent)) && + isNullConstant(N1) && isOneConstant(N2)) { + SDValue XORNode; + if (VT == VT0) { + SDLoc DL(N); + return DAG.getNode(ISD::XOR, DL, VT0, + N0, DAG.getConstant(1, DL, VT0)); + } + SDLoc DL0(N0); + XORNode = DAG.getNode(ISD::XOR, DL0, VT0, + N0, DAG.getConstant(1, DL0, VT0)); + AddToWorklist(XORNode.getNode()); + if (VT.bitsGT(VT0)) + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, XORNode); + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, XORNode); } - - return SDValue(); -} - -SDValue DAGCombiner::visitSELECT(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDValue N2 = N->getOperand(2); - EVT VT = N->getValueType(0); - EVT VT0 = N0.getValueType(); - SDLoc DL(N); - - if (SDValue V = DAG.simplifySelect(N0, N1, N2)) - return V; - - // fold (select X, X, Y) -> (or X, Y) - // fold (select X, 1, Y) -> (or C, Y) - if (VT == VT0 && VT == MVT::i1 && (N0 == N1 || isOneConstant(N1))) - return DAG.getNode(ISD::OR, DL, VT, N0, N2); - - if (SDValue V = foldSelectOfConstants(N)) - return V; - // fold (select C, 0, X) -> (and (not C), X) if (VT == VT0 && VT == MVT::i1 && isNullConstant(N1)) { SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT); AddToWorklist(NOTNode.getNode()); - return DAG.getNode(ISD::AND, DL, VT, NOTNode, N2); + return DAG.getNode(ISD::AND, SDLoc(N), VT, NOTNode, 
N2); } // fold (select C, X, 1) -> (or (not C), X) if (VT == VT0 && VT == MVT::i1 && isOneConstant(N2)) { SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT); AddToWorklist(NOTNode.getNode()); - return DAG.getNode(ISD::OR, DL, VT, NOTNode, N1); + return DAG.getNode(ISD::OR, SDLoc(N), VT, NOTNode, N1); } + // fold (select C, X, 0) -> (and C, X) + if (VT == MVT::i1 && isNullConstant(N2)) + return DAG.getNode(ISD::AND, SDLoc(N), VT, N0, N1); + // fold (select X, X, Y) -> (or X, Y) + // fold (select X, 1, Y) -> (or X, Y) + if (VT == MVT::i1 && (N0 == N1 || isOneConstant(N1))) + return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N2); // fold (select X, Y, X) -> (and X, Y) // fold (select X, Y, 0) -> (and X, Y) - if (VT == VT0 && VT == MVT::i1 && (N0 == N2 || isNullConstant(N2))) - return DAG.getNode(ISD::AND, DL, VT, N0, N1); + if (VT == MVT::i1 && (N0 == N2 || isNullConstant(N2))) + return DAG.getNode(ISD::AND, SDLoc(N), VT, N0, N1); // If we can fold this based on the true/false value, do so. if (SimplifySelectOps(N, N1, N2)) - return SDValue(N, 0); // Don't revisit N. + return SDValue(N, 0); // Don't revisit N. if (VT0 == MVT::i1) { // The code in this block deals with the following 2 equivalences: // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y)) // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y) - // The target can specify its preferred form with the + // The target can specify its prefered form with the // shouldNormalizeToSelectSequence() callback. However we always transform // to the right anyway if we find the inner select exists in the DAG anyway // and we always transform to the left side if we know that we can further // optimize the combination of the conditions. - bool normalizeToSequence = - TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT); + bool normalizeToSequence + = TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT); // select (and Cond0, Cond1), X, Y // -> select Cond0, (select Cond1, X, Y), Y if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) { SDValue Cond0 = N0->getOperand(0); SDValue Cond1 = N0->getOperand(1); - SDValue InnerSelect = - DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2); + SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N), + N1.getValueType(), Cond1, N1, N2); if (normalizeToSequence || !InnerSelect.use_empty()) - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, + return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0, InnerSelect, N2); } // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y) if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) { SDValue Cond0 = N0->getOperand(0); SDValue Cond1 = N0->getOperand(1); - SDValue InnerSelect = - DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2); + SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N), + N1.getValueType(), Cond1, N1, N2); if (normalizeToSequence || !InnerSelect.use_empty()) - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1, + return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0, N1, InnerSelect); } @@ -7438,13 +5128,15 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) { // Create the actual and node if we can generate good code for it. 
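
The equivalence the normalization relies on is a four-case truth-table fact; a standalone check:

    #include <cassert>

    int main() {
      // select (and Cond0, Cond1), X, Y <=> select Cond0, (select Cond1, X, Y), Y:
      // both forms yield X exactly when both conditions hold.
      const int x = 7, y = 9;
      for (bool c0 : {false, true})
        for (bool c1 : {false, true})
          assert(((c0 && c1) ? x : y) == (c0 ? (c1 ? x : y) : y));
      return 0;
    }
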
if (!normalizeToSequence) { - SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0); - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1, N2); + SDValue And = DAG.getNode(ISD::AND, SDLoc(N), N0.getValueType(), + N0, N1_0); + return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), And, + N1_1, N2); } // Otherwise see if we can optimize the "and" to a better pattern. if (SDValue Combined = visitANDLike(N0, N1_0, N)) - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1, - N2); + return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Combined, + N1_1, N2); } } // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y @@ -7455,72 +5147,49 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) { // Create the actual or node if we can generate good code for it. if (!normalizeToSequence) { - SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0); - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1, N2_2); + SDValue Or = DAG.getNode(ISD::OR, SDLoc(N), N0.getValueType(), + N0, N2_0); + return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Or, + N1, N2_2); } // Otherwise see if we can optimize to a better pattern. if (SDValue Combined = visitORLike(N0, N2_0, N)) - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1, - N2_2); + return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Combined, + N1, N2_2); } } } - if (VT0 == MVT::i1) { - // select (not Cond), N1, N2 -> select Cond, N2, N1 - if (isBitwiseNot(N0)) - return DAG.getNode(ISD::SELECT, DL, VT, N0->getOperand(0), N2, N1); - } - - // Fold selects based on a setcc into other things, such as min/max/abs. + // fold selects based on a setcc into other things, such as min/max/abs if (N0.getOpcode() == ISD::SETCC) { - SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1); - ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); - - // select (fcmp lt x, y), x, y -> fminnum x, y - // select (fcmp gt x, y), x, y -> fmaxnum x, y + // select x, y (fcmp lt x, y) -> fminnum x, y + // select x, y (fcmp gt x, y) -> fmaxnum x, y + // + // This is OK if we don't care about what happens if either operand is a + // NaN. // - // This is OK if we don't care what happens if either operand is a NaN. - if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2)) - if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, - CC, TLI, DAG)) - return FMinMax; - // Use 'unsigned add with overflow' to optimize an unsigned saturating add. - // This is conservatively limited to pre-legal-operations to give targets - // a chance to reverse the transform if they want to do that. Also, it is - // unlikely that the pattern would be formed late, so it's probably not - // worth going through the other checks. 
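
A scalar sketch of the saturating-add pattern this (removed) combine recognizes, with uaddSat as an illustrative stand-in for what UADDO plus the select computes:

    #include <cassert>
    #include <cstdint>

    // x + C wraps exactly when x > ~C, so selecting -1 (all ones) in that
    // case yields an unsigned saturating add.
    static uint32_t uaddSat(uint32_t x, uint32_t c) {
      return x > ~c ? 0xFFFFFFFFu : x + c;
    }

    int main() {
      assert(uaddSat(10u, 20u) == 30u);
      assert(uaddSat(0xFFFFFFF0u, 0x20u) == 0xFFFFFFFFu);  // would wrap: clamp
      return 0;
    }
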
- if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) && - CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) && - N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) { - auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1)); - auto *NotC = dyn_cast<ConstantSDNode>(Cond1); - if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) { - // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) --> - // uaddo Cond0, C; select uaddo.1, -1, uaddo.0 - // - // The IR equivalent of this transform would have this form: - // %a = add %x, C - // %c = icmp ugt %x, ~C - // %r = select %c, -1, %a - // => - // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C) - // %u0 = extractvalue %u, 0 - // %u1 = extractvalue %u, 1 - // %r = select %u1, -1, %u0 - SDVTList VTs = DAG.getVTList(VT, VT0); - SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1)); - return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0)); - } - } + // FIXME: Instead of testing for UnsafeFPMath, this should be checking for + // no signed zeros as well as no nans. + const TargetOptions &Options = DAG.getTarget().Options; + if (Options.UnsafeFPMath && + VT.isFloatingPoint() && N0.hasOneUse() && + DAG.isKnownNeverNaN(N1) && DAG.isKnownNeverNaN(N2)) { + ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); - if (TLI.isOperationLegal(ISD::SELECT_CC, VT) || - (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) - return DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1, N2, - N0.getOperand(2)); + if (SDValue FMinMax = combineMinNumMaxNum(SDLoc(N), VT, N0.getOperand(0), + N0.getOperand(1), N1, N2, CC, + TLI, DAG)) + return FMinMax; + } - return SimplifySelect(DL, N0, N1, N2); + if ((!LegalOperations && + TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT)) || + TLI.isOperationLegal(ISD::SELECT_CC, VT)) + return DAG.getNode(ISD::SELECT_CC, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1), + N1, N2, N0.getOperand(2)); + return SimplifySelect(SDLoc(N), N0, N1, N2); } return SDValue(); @@ -7546,7 +5215,7 @@ std::pair<SDValue, SDValue> SplitVSETCC(const SDNode *N, SelectionDAG &DAG) { // This function assumes all the vselect's arguments are CONCAT_VECTOR // nodes and that the condition is a BV of ConstantSDNodes (or undefs). static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) { - SDLoc DL(N); + SDLoc dl(N); SDValue Cond = N->getOperand(0); SDValue LHS = N->getOperand(1); SDValue RHS = N->getOperand(2); @@ -7568,7 +5237,7 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) { // length of the BV and see if all the non-undef nodes are the same. ConstantSDNode *BottomHalf = nullptr; for (int i = 0; i < NumElems / 2; ++i) { - if (Cond->getOperand(i)->isUndef()) + if (Cond->getOperand(i)->getOpcode() == ISD::UNDEF) continue; if (BottomHalf == nullptr) @@ -7580,7 +5249,7 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) { // Do the same for the second half of the BuildVector ConstantSDNode *TopHalf = nullptr; for (int i = NumElems / 2; i < NumElems; ++i) { - if (Cond->getOperand(i)->isUndef()) + if (Cond->getOperand(i)->getOpcode() == ISD::UNDEF) continue; if (TopHalf == nullptr) @@ -7593,12 +5262,13 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) { "One half of the selector was all UNDEFs and the other was all the " "same value. 
This should have been addressed before this function."); return DAG.getNode( - ISD::CONCAT_VECTORS, DL, VT, + ISD::CONCAT_VECTORS, dl, VT, BottomHalf->isNullValue() ? RHS->getOperand(0) : LHS->getOperand(0), TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1)); } SDValue DAGCombiner::visitMSCATTER(SDNode *N) { + if (Level >= AfterLegalizeTypes) return SDValue(); @@ -7618,7 +5288,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) != TargetLowering::TypeSplitVector) return SDValue(); - SDValue MaskLo, MaskHi; + SDValue MaskLo, MaskHi, Lo, Hi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); EVT LoVT, HiVT; @@ -7635,7 +5305,6 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { SDValue DataLo, DataHi; std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL); - SDValue Scale = MSC->getScale(); SDValue BasePtr = MSC->getBasePtr(); SDValue IndexLo, IndexHi; std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL); @@ -7645,26 +5314,28 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { MachineMemOperand::MOStore, LoMemVT.getStoreSize(), Alignment, MSC->getAAInfo(), MSC->getRanges()); - SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale }; - SDValue Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), - DataLo.getValueType(), DL, OpsLo, MMO); + SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo }; + Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(), + DL, OpsLo, MMO); + + SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi}; + Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), + DL, OpsHi, MMO); + + AddToWorklist(Lo.getNode()); + AddToWorklist(Hi.getNode()); - // The order of the Scatter operation after split is well defined. The "Hi" - // part comes after the "Lo". So these two operations should be chained one - // after another. - SDValue OpsHi[] = { Lo, DataHi, MaskHi, BasePtr, IndexHi, Scale }; - return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), - DL, OpsHi, MMO); + return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi); } SDValue DAGCombiner::visitMSTORE(SDNode *N) { + if (Level >= AfterLegalizeTypes) return SDValue(); MaskedStoreSDNode *MST = dyn_cast<MaskedStoreSDNode>(N); SDValue Mask = MST->getMask(); SDValue Data = MST->getValue(); - EVT VT = Data.getValueType(); SDLoc DL(N); // If the MSTORE data type requires splitting and the mask is provided by a @@ -7672,14 +5343,18 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { // prevents the type legalizer from unrolling SETCC into scalar comparisons // and enables future optimizations (e.g. min/max pattern matching on X86). if (Mask.getOpcode() == ISD::SETCC) { + // Check if any splitting is required. - if (TLI.getTypeAction(*DAG.getContext(), VT) != + if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) != TargetLowering::TypeSplitVector) return SDValue(); SDValue MaskLo, MaskHi, Lo, Hi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); + EVT LoVT, HiVT; + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MST->getValueType(0)); + SDValue Chain = MST->getChain(); SDValue Ptr = MST->getBasePtr(); @@ -7689,7 +5364,8 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { // if Alignment is equal to the vector size, // take the half of it for the second part unsigned SecondHalfAlignment = - (Alignment == VT.getSizeInBits() / 8) ? Alignment / 2 : Alignment; + (Alignment == Data->getValueType(0).getSizeInBits()/8) ? 
+ Alignment/2 : Alignment; EVT LoMemVT, HiMemVT; std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); @@ -7703,21 +5379,20 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { Alignment, MST->getAAInfo(), MST->getRanges()); Lo = DAG.getMaskedStore(Chain, DL, DataLo, Ptr, MaskLo, LoMemVT, MMO, - MST->isTruncatingStore(), - MST->isCompressingStore()); + MST->isTruncatingStore()); - Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, - MST->isCompressingStore()); - unsigned HiOffset = LoMemVT.getStoreSize(); + unsigned IncrementSize = LoMemVT.getSizeInBits()/8; + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + DAG.getConstant(IncrementSize, DL, Ptr.getValueType())); - MMO = DAG.getMachineFunction().getMachineMemOperand( - MST->getPointerInfo().getWithOffset(HiOffset), - MachineMemOperand::MOStore, HiMemVT.getStoreSize(), SecondHalfAlignment, - MST->getAAInfo(), MST->getRanges()); + MMO = DAG.getMachineFunction(). + getMachineMemOperand(MST->getPointerInfo(), + MachineMemOperand::MOStore, HiMemVT.getStoreSize(), + SecondHalfAlignment, MST->getAAInfo(), + MST->getRanges()); Hi = DAG.getMaskedStore(Chain, DL, DataHi, Ptr, MaskHi, HiMemVT, MMO, - MST->isTruncatingStore(), - MST->isCompressingStore()); + MST->isTruncatingStore()); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -7728,10 +5403,11 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { } SDValue DAGCombiner::visitMGATHER(SDNode *N) { + if (Level >= AfterLegalizeTypes) return SDValue(); - MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N); + MaskedGatherSDNode *MGT = dyn_cast<MaskedGatherSDNode>(N); SDValue Mask = MGT->getMask(); SDLoc DL(N); @@ -7753,9 +5429,9 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { SDValue MaskLo, MaskHi, Lo, Hi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); - SDValue PassThru = MGT->getPassThru(); - SDValue PassThruLo, PassThruHi; - std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, DL); + SDValue Src0 = MGT->getValue(); + SDValue Src0Lo, Src0Hi; + std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL); EVT LoVT, HiVT; std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT); @@ -7767,7 +5443,6 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { EVT LoMemVT, HiMemVT; std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); - SDValue Scale = MGT->getScale(); SDValue BasePtr = MGT->getBasePtr(); SDValue Index = MGT->getIndex(); SDValue IndexLo, IndexHi; @@ -7778,13 +5453,13 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsLo[] = { Chain, PassThruLo, MaskLo, BasePtr, IndexLo, Scale }; + SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo }; Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo, - MMO); + MMO); - SDValue OpsHi[] = { Chain, PassThruHi, MaskHi, BasePtr, IndexHi, Scale }; + SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi}; Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi, - MMO); + MMO); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -7805,6 +5480,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { } SDValue DAGCombiner::visitMLOAD(SDNode *N) { + if (Level >= AfterLegalizeTypes) return SDValue(); @@ -7816,6 +5492,7 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { // SETCC, then split both nodes and its operands before legalization. This // prevents the type legalizer from unrolling SETCC into scalar comparisons // and enables future optimizations (e.g. 
min/max pattern matching on X86). + if (Mask.getOpcode() == ISD::SETCC) { EVT VT = N->getValueType(0); @@ -7827,9 +5504,9 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { SDValue MaskLo, MaskHi, Lo, Hi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); - SDValue PassThru = MLD->getPassThru(); - SDValue PassThruLo, PassThruHi; - std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, DL); + SDValue Src0 = MLD->getSrc0(); + SDValue Src0Lo, Src0Hi; + std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL); EVT LoVT, HiVT; std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MLD->getValueType(0)); @@ -7853,20 +5530,20 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MLD->getAAInfo(), MLD->getRanges()); - Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, PassThruLo, LoMemVT, - MMO, ISD::NON_EXTLOAD, MLD->isExpandingLoad()); + Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, Src0Lo, LoMemVT, MMO, + ISD::NON_EXTLOAD); - Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, - MLD->isExpandingLoad()); - unsigned HiOffset = LoMemVT.getStoreSize(); + unsigned IncrementSize = LoMemVT.getSizeInBits()/8; + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + DAG.getConstant(IncrementSize, DL, Ptr.getValueType())); - MMO = DAG.getMachineFunction().getMachineMemOperand( - MLD->getPointerInfo().getWithOffset(HiOffset), - MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), SecondHalfAlignment, - MLD->getAAInfo(), MLD->getRanges()); + MMO = DAG.getMachineFunction(). + getMachineMemOperand(MLD->getPointerInfo(), + MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), + SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges()); - Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, PassThruHi, HiMemVT, - MMO, ISD::NON_EXTLOAD, MLD->isExpandingLoad()); + Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, Src0Hi, HiMemVT, MMO, + ISD::NON_EXTLOAD); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -7888,66 +5565,12 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { return SDValue(); } -/// A vector select of 2 constant vectors can be simplified to math/logic to -/// avoid a variable select instruction and possibly avoid constant loads. -SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) { - SDValue Cond = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDValue N2 = N->getOperand(2); - EVT VT = N->getValueType(0); - if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 || - !TLI.convertSelectOfConstantsToMath(VT) || - !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) || - !ISD::isBuildVectorOfConstantSDNodes(N2.getNode())) - return SDValue(); - - // Check if we can use the condition value to increment/decrement a single - // constant value. This simplifies a select to an add and removes a constant - // load/materialization from the general case. - bool AllAddOne = true; - bool AllSubOne = true; - unsigned Elts = VT.getVectorNumElements(); - for (unsigned i = 0; i != Elts; ++i) { - SDValue N1Elt = N1.getOperand(i); - SDValue N2Elt = N2.getOperand(i); - if (N1Elt.isUndef() || N2Elt.isUndef()) - continue; - - const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue(); - const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue(); - if (C1 != C2 + 1) - AllAddOne = false; - if (C1 != C2 - 1) - AllSubOne = false; - } - - // Further simplifications for the extra-special cases where the constants are - // all 0 or all -1 should be implemented as folds of these patterns. 
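
Modeled one lane at a time in plain C++, the general select-of-constants rewrite this block goes on to describe looks as follows (viaLogic is an illustrative helper):

    #include <cassert>
    #include <cstdint>

    // Per-lane scalar model: sext of an i1 lane is an all-ones or all-zero
    // mask, so select(Cond, C1, C2) becomes xor (and mask, (C1 ^ C2)), C2.
    static uint32_t viaLogic(bool cond, uint32_t c1, uint32_t c2) {
      uint32_t mask = cond ? 0xFFFFFFFFu : 0u;  // sext i1 -> i32
      return (mask & (c1 ^ c2)) ^ c2;
    }

    int main() {
      for (bool cond : {false, true})
        assert(viaLogic(cond, 0x1234u, 0xABCDu) == (cond ? 0x1234u : 0xABCDu));
      return 0;
    }
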
- SDLoc DL(N); - if (AllAddOne || AllSubOne) { - // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C - // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C - auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND; - SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond); - return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2); - } - - // The general case for select-of-constants: - // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2 - // ...but that only makes sense if a vselect is slower than 2 logic ops, so - // leave that to a machine-specific pass. - return SDValue(); -} - SDValue DAGCombiner::visitVSELECT(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); SDValue N2 = N->getOperand(2); SDLoc DL(N); - if (SDValue V = DAG.simplifySelect(N0, N1, N2)) - return V; - // Canonicalize integer abs. // vselect (setg[te] X, 0), X, -X -> // vselect (setgt X, -1), X, -X -> @@ -7969,66 +5592,47 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { if (isAbs) { EVT VT = LHS.getValueType(); - if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) - return DAG.getNode(ISD::ABS, DL, VT, LHS); - SDValue Shift = DAG.getNode( ISD::SRA, DL, VT, LHS, - DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT)); + DAG.getConstant(VT.getScalarType().getSizeInBits() - 1, DL, VT)); SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift); AddToWorklist(Shift.getNode()); AddToWorklist(Add.getNode()); return DAG.getNode(ISD::XOR, DL, VT, Add, Shift); } - - // vselect x, y (fcmp lt x, y) -> fminnum x, y - // vselect x, y (fcmp gt x, y) -> fmaxnum x, y - // - // This is OK if we don't care about what happens if either operand is a - // NaN. - // - EVT VT = N->getValueType(0); - if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N0.getOperand(0), N0.getOperand(1))) { - ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); - if (SDValue FMinMax = combineMinNumMaxNum( - DL, VT, N0.getOperand(0), N0.getOperand(1), N1, N2, CC, TLI, DAG)) - return FMinMax; - } - - // If this select has a condition (setcc) with narrower operands than the - // select, try to widen the compare to match the select width. - // TODO: This should be extended to handle any constant. - // TODO: This could be extended to handle non-loading patterns, but that - // requires thorough testing to avoid regressions. - if (isNullOrNullSplat(RHS)) { - EVT NarrowVT = LHS.getValueType(); - EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger(); - EVT SetCCVT = getSetCCResultType(LHS.getValueType()); - unsigned SetCCWidth = SetCCVT.getScalarSizeInBits(); - unsigned WideWidth = WideVT.getScalarSizeInBits(); - bool IsSigned = isSignedIntSetCC(CC); - auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD; - if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() && - SetCCWidth != 1 && SetCCWidth < WideWidth && - TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) && - TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) { - // Both compare operands can be widened for free. The LHS can use an - // extended load, and the RHS is a constant: - // vselect (ext (setcc load(X), C)), N1, N2 --> - // vselect (setcc extload(X), C'), N1, N2 - auto ExtOpcode = IsSigned ? 
ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; - SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS); - SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS); - EVT WideSetCCVT = getSetCCResultType(WideVT); - SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC); - return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2); - } - } } if (SimplifySelectOps(N, N1, N2)) return SDValue(N, 0); // Don't revisit N. + // If the VSELECT result requires splitting and the mask is provided by a + // SETCC, then split both nodes and its operands before legalization. This + // prevents the type legalizer from unrolling SETCC into scalar comparisons + // and enables future optimizations (e.g. min/max pattern matching on X86). + if (N0.getOpcode() == ISD::SETCC) { + EVT VT = N->getValueType(0); + + // Check if any splitting is required. + if (TLI.getTypeAction(*DAG.getContext(), VT) != + TargetLowering::TypeSplitVector) + return SDValue(); + + SDValue Lo, Hi, CCLo, CCHi, LL, LH, RL, RH; + std::tie(CCLo, CCHi) = SplitVSETCC(N0.getNode(), DAG); + std::tie(LL, LH) = DAG.SplitVectorOperand(N, 1); + std::tie(RL, RH) = DAG.SplitVectorOperand(N, 2); + + Lo = DAG.getNode(N->getOpcode(), DL, LL.getValueType(), CCLo, LL, RL); + Hi = DAG.getNode(N->getOpcode(), DL, LH.getValueType(), CCHi, LH, RH); + + // Add the new VSELECT nodes to the work list in case they need to be split + // again. + AddToWorklist(Lo.getNode()); + AddToWorklist(Hi.getNode()); + + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi); + } + // Fold (vselect (build_vector all_ones), N1, N2) -> N1 if (ISD::isBuildVectorAllOnes(N0.getNode())) return N1; @@ -8046,9 +5650,6 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { return CV; } - if (SDValue V = foldVSelectOfConstants(N)) - return V; - return SDValue(); } @@ -8065,8 +5666,9 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) { return N2; // Determine if the condition we're dealing with is constant - if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1, - CC, SDLoc(N), false)) { + SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), + N0, N1, CC, SDLoc(N), false); + if (SCC.getNode()) { AddToWorklist(SCC.getNode()); if (ConstantSDNode *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) { @@ -8074,7 +5676,7 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) { return N2; // cond always true -> true val else return N3; // cond always false -> false val - } else if (SCC->isUndef()) { + } else if (SCC->getOpcode() == ISD::UNDEF) { // When the condition is UNDEF, just return the first operand. This is // coherent the DAG creation, no setcc node is created in this case return N2; @@ -8095,43 +5697,19 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) { } SDValue DAGCombiner::visitSETCC(SDNode *N) { - // setcc is very commonly used as an argument to brcond. This pattern - // also lend itself to numerous combines and, as a result, it is desired - // we keep the argument to a brcond as a setcc as much as possible. - bool PreferSetCC = - N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND; - - SDValue Combined = SimplifySetCC( - N->getValueType(0), N->getOperand(0), N->getOperand(1), - cast<CondCodeSDNode>(N->getOperand(2))->get(), SDLoc(N), !PreferSetCC); - - if (!Combined) - return SDValue(); - - // If we prefer to have a setcc, and we don't, we'll try our best to - // recreate one using rebuildSetCC. 
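
Earlier in this hunk, visitVSELECT canonicalizes integer abs into an sra/add/xor sequence; the underlying identity, sketched in standalone C++ (absViaShift is illustrative):

    #include <cassert>
    #include <cstdint>

    // Branchless abs: sign = x >> 31 is all ones for negative x, else zero,
    // and abs(x) = (x + sign) ^ sign. Computed in uint32_t so the INT32_MIN
    // case wraps, matching the sra/add/xor expansion.
    static uint32_t absViaShift(int32_t x) {
      uint32_t ux = static_cast<uint32_t>(x);
      uint32_t sign = static_cast<uint32_t>(x >> 31);
      return (ux + sign) ^ sign;
    }

    int main() {
      assert(absViaShift(5) == 5u);
      assert(absViaShift(-5) == 5u);
      assert(absViaShift(INT32_MIN) == 0x80000000u);  // wraps around
      return 0;
    }
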
- if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) { - SDValue NewSetCC = rebuildSetCC(Combined); - - // We don't have anything interesting to combine to. - if (NewSetCC.getNode() == N) - return SDValue(); - - if (NewSetCC) - return NewSetCC; - } - - return Combined; + return SimplifySetCC(N->getValueType(0), N->getOperand(0), N->getOperand(1), + cast<CondCodeSDNode>(N->getOperand(2))->get(), + SDLoc(N)); } -SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) { +SDValue DAGCombiner::visitSETCCE(SDNode *N) { SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); SDValue Carry = N->getOperand(2); SDValue Cond = N->getOperand(3); // If Carry is false, fold to a regular SETCC. - if (isNullConstant(Carry)) + if (Carry.getOpcode() == ISD::CARRY_FALSE) return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond); return SDValue(); @@ -8143,47 +5721,43 @@ SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) { /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND). /// Vector extends are not folded if operations are legal; this is to /// avoid introducing illegal build_vector dag nodes. -static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, - SelectionDAG &DAG, bool LegalTypes) { +static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, + SelectionDAG &DAG, bool LegalTypes, + bool LegalOperations) { unsigned Opcode = N->getOpcode(); SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND || - Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG || - Opcode == ISD::ZERO_EXTEND_VECTOR_INREG) + Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG) && "Expected EXTEND dag node in input!"); // fold (sext c1) -> c1 // fold (zext c1) -> c1 // fold (aext c1) -> c1 if (isa<ConstantSDNode>(N0)) - return DAG.getNode(Opcode, SDLoc(N), VT, N0); + return DAG.getNode(Opcode, SDLoc(N), VT, N0).getNode(); // fold (sext (build_vector AllConstants) -> (build_vector AllConstants) // fold (zext (build_vector AllConstants) -> (build_vector AllConstants) // fold (aext (build_vector AllConstants) -> (build_vector AllConstants) EVT SVT = VT.getScalarType(); - if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) && + if (!(VT.isVector() && + (!LegalTypes || (!LegalOperations && TLI.isTypeLegal(SVT))) && ISD::isBuildVectorOfConstantSDNodes(N0.getNode()))) - return SDValue(); + return nullptr; // We can fold this node into a build_vector. unsigned VTBits = SVT.getSizeInBits(); - unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits(); + unsigned EVTBits = N0->getValueType(0).getScalarType().getSizeInBits(); SmallVector<SDValue, 8> Elts; unsigned NumElts = VT.getVectorNumElements(); SDLoc DL(N); - // For zero-extensions, UNDEF elements still guarantee to have the upper - // bits set to zero. - bool IsZext = - Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG; - - for (unsigned i = 0; i != NumElts; ++i) { - SDValue Op = N0.getOperand(i); - if (Op.isUndef()) { - Elts.push_back(IsZext ? 
DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT)); + for (unsigned i=0; i != NumElts; ++i) { + SDValue Op = N0->getOperand(i); + if (Op->getOpcode() == ISD::UNDEF) { + Elts.push_back(DAG.getUNDEF(SVT)); continue; } @@ -8197,19 +5771,19 @@ static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT)); } - return DAG.getBuildVector(VT, DL, Elts); + return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Elts).getNode(); } // ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this: // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))" // transformation. Returns true if extension are possible and the above // mentioned transformation is profitable. -static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0, +static bool ExtendUsesToFormExtLoad(SDNode *N, SDValue N0, unsigned ExtOpc, SmallVectorImpl<SDNode *> &ExtendNodes, const TargetLowering &TLI) { bool HasCopyToRegUses = false; - bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType()); + bool isTruncFree = TLI.isTruncateFree(N->getValueType(0), N0.getValueType()); for (SDNode::use_iterator UI = N0.getNode()->use_begin(), UE = N0.getNode()->use_end(); UI != UE; ++UI) { @@ -8265,16 +5839,16 @@ static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0, } void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs, - SDValue OrigLoad, SDValue ExtLoad, + SDValue Trunc, SDValue ExtLoad, SDLoc DL, ISD::NodeType ExtType) { // Extend SetCC uses if necessary. - SDLoc DL(ExtLoad); - for (SDNode *SetCC : SetCCs) { + for (unsigned i = 0, e = SetCCs.size(); i != e; ++i) { + SDNode *SetCC = SetCCs[i]; SmallVector<SDValue, 4> Ops; for (unsigned j = 0; j != 2; ++j) { SDValue SOp = SetCC->getOperand(j); - if (SOp == OrigLoad) + if (SOp == Trunc) Ops.push_back(ExtLoad); else Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp)); @@ -8323,7 +5897,7 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) { return SDValue(); SmallVector<SDNode *, 4> SetCCs; - if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI)) + if (!ExtendUsesToFormExtLoad(N, N0, N->getOpcode(), SetCCs, TLI)) return SDValue(); ISD::LoadExtType ExtType = @@ -8354,9 +5928,10 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) { const unsigned Align = MinAlign(LN0->getAlignment(), Offset); SDValue SplitLoad = DAG.getExtLoad( - ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr, - LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align, - LN0->getMemOperand()->getFlags(), LN0->getAAInfo()); + ExtType, DL, SplitDstVT, LN0->getChain(), BasePtr, + LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, + LN0->isVolatile(), LN0->isNonTemporal(), LN0->isInvariant(), + Align, LN0->getAAInfo()); BasePtr = DAG.getNode(ISD::ADD, DL, BasePtr.getValueType(), BasePtr, DAG.getConstant(Stride, DL, BasePtr.getValueType())); @@ -8368,254 +5943,37 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) { SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads); - // Simplify TF. - AddToWorklist(NewChain.getNode()); - CombineTo(N, NewValue); // Replace uses of the original load (before extension) // with a truncate of the concatenated sextloaded vectors. 
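// Replacing the remaining uses of the original (unextended) load with a
// truncate of the concatenated extload result relies on trunc(ext(x)) == x,
// which holds for sign and zero extension alike; a minimal sketch, assuming
// two's complement (widths and constants are illustrative):

#include <cstdint>

static_assert((int8_t)(int32_t)(int8_t)0x90 == (int8_t)0x90,
              "trunc(sext(x)) == x");
static_assert((uint8_t)(uint32_t)(uint8_t)0x90 == (uint8_t)0x90,
              "trunc(zext(x)) == x");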
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue); - ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode()); CombineTo(N0.getNode(), Trunc, NewChain); + ExtendSetCCUses(SetCCs, Trunc, NewValue, DL, + (ISD::NodeType)N->getOpcode()); return SDValue(N, 0); // Return N so it doesn't get rechecked! } -// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) -> -// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst)) -SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) { - assert(N->getOpcode() == ISD::ZERO_EXTEND); - EVT VT = N->getValueType(0); - - // and/or/xor - SDValue N0 = N->getOperand(0); - if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR || - N0.getOpcode() == ISD::XOR) || - N0.getOperand(1).getOpcode() != ISD::Constant || - (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT))) - return SDValue(); - - // shl/shr - SDValue N1 = N0->getOperand(0); - if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) || - N1.getOperand(1).getOpcode() != ISD::Constant || - (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT))) - return SDValue(); - - // load - if (!isa<LoadSDNode>(N1.getOperand(0))) - return SDValue(); - LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0)); - EVT MemVT = Load->getMemoryVT(); - if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) || - Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed()) - return SDValue(); - - - // If the shift op is SHL, the logic op must be AND, otherwise the result - // will be wrong. - if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND) - return SDValue(); - - if (!N0.hasOneUse() || !N1.hasOneUse()) - return SDValue(); - - SmallVector<SDNode*, 4> SetCCs; - if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0), - ISD::ZERO_EXTEND, SetCCs, TLI)) - return SDValue(); - - // Actually do the transformation. - SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT, - Load->getChain(), Load->getBasePtr(), - Load->getMemoryVT(), Load->getMemOperand()); - - SDLoc DL1(N1); - SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad, - N1.getOperand(1)); - - APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); - Mask = Mask.zext(VT.getSizeInBits()); - SDLoc DL0(N0); - SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift, - DAG.getConstant(Mask, DL0, VT)); - - ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND); - CombineTo(N, And); - if (SDValue(Load, 0).hasOneUse()) { - DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1)); - } else { - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load), - Load->getValueType(0), ExtLoad); - CombineTo(Load, Trunc, ExtLoad.getValue(1)); - } - return SDValue(N,0); // Return N so it doesn't get rechecked! -} - -/// If we're narrowing or widening the result of a vector select and the final -/// size is the same size as a setcc (compare) feeding the select, then try to -/// apply the cast operation to the select's operands because matching vector -/// sizes for a select condition and other operands should be more efficient. 
-SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) { - unsigned CastOpcode = Cast->getOpcode(); - assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND || - CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND || - CastOpcode == ISD::FP_ROUND) && - "Unexpected opcode for vector select narrowing/widening"); - - // We only do this transform before legal ops because the pattern may be - // obfuscated by target-specific operations after legalization. Do not create - // an illegal select op, however, because that may be difficult to lower. - EVT VT = Cast->getValueType(0); - if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT)) - return SDValue(); - - SDValue VSel = Cast->getOperand(0); - if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() || - VSel.getOperand(0).getOpcode() != ISD::SETCC) - return SDValue(); - - // Does the setcc have the same vector size as the casted select? - SDValue SetCC = VSel.getOperand(0); - EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType()); - if (SetCCVT.getSizeInBits() != VT.getSizeInBits()) - return SDValue(); - - // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B) - SDValue A = VSel.getOperand(1); - SDValue B = VSel.getOperand(2); - SDValue CastA, CastB; - SDLoc DL(Cast); - if (CastOpcode == ISD::FP_ROUND) { - // FP_ROUND (fptrunc) has an extra flag operand to pass along. - CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1)); - CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1)); - } else { - CastA = DAG.getNode(CastOpcode, DL, VT, A); - CastB = DAG.getNode(CastOpcode, DL, VT, B); - } - return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB); -} - -// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x))) -// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x))) -static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner, - const TargetLowering &TLI, EVT VT, - bool LegalOperations, SDNode *N, - SDValue N0, ISD::LoadExtType ExtLoadType) { - SDNode *N0Node = N0.getNode(); - bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node) - : ISD::isZEXTLoad(N0Node); - if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) || - !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse()) - return {}; - - LoadSDNode *LN0 = cast<LoadSDNode>(N0); - EVT MemVT = LN0->getMemoryVT(); - if ((LegalOperations || LN0->isVolatile() || VT.isVector()) && - !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT)) - return {}; - - SDValue ExtLoad = - DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(), - LN0->getBasePtr(), MemVT, LN0->getMemOperand()); - Combiner.CombineTo(N, ExtLoad); - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - return SDValue(N, 0); // Return N so it doesn't get rechecked! -} - -// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x))) -// Only generate vector extloads when 1) they're legal, and 2) they are -// deemed desirable by the target. 
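// Folding [s|z]ext of a matching extload into a single wider extload is
// just idempotence of extension: widening in two hops equals one hop. A
// scalar sketch with illustrative widths (i8 -> i32 -> i64 vs. i8 -> i64),
// assuming two's complement:

#include <cstdint>

static_assert((int64_t)(int32_t)(int8_t)0x80 == (int64_t)(int8_t)0x80,
              "sext(sext(x)) == sext(x)");
static_assert((uint64_t)(uint32_t)(uint8_t)0x80 == (uint64_t)(uint8_t)0x80,
              "zext(zext(x)) == zext(x)");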
-static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner, - const TargetLowering &TLI, EVT VT, - bool LegalOperations, SDNode *N, SDValue N0, - ISD::LoadExtType ExtLoadType, - ISD::NodeType ExtOpc) { - if (!ISD::isNON_EXTLoad(N0.getNode()) || - !ISD::isUNINDEXEDLoad(N0.getNode()) || - ((LegalOperations || VT.isVector() || - cast<LoadSDNode>(N0)->isVolatile()) && - !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))) - return {}; - - bool DoXform = true; - SmallVector<SDNode *, 4> SetCCs; - if (!N0.hasOneUse()) - DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI); - if (VT.isVector()) - DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0)); - if (!DoXform) - return {}; - - LoadSDNode *LN0 = cast<LoadSDNode>(N0); - SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(), - LN0->getBasePtr(), N0.getValueType(), - LN0->getMemOperand()); - Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc); - // If the load value is used only by N, replace it via CombineTo N. - bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse(); - Combiner.CombineTo(N, ExtLoad); - if (NoReplaceTrunc) { - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - } else { - SDValue Trunc = - DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad); - Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1)); - } - return SDValue(N, 0); // Return N so it doesn't get rechecked! -} - -static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG, - bool LegalOperations) { - assert((N->getOpcode() == ISD::SIGN_EXTEND || - N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext"); - - SDValue SetCC = N->getOperand(0); - if (LegalOperations || SetCC.getOpcode() != ISD::SETCC || - !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1) - return SDValue(); - - SDValue X = SetCC.getOperand(0); - SDValue Ones = SetCC.getOperand(1); - ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); - EVT VT = N->getValueType(0); - EVT XVT = X.getValueType(); - // setge X, C is canonicalized to setgt, so we do not need to match that - // pattern. The setlt sibling is folded in SimplifySelectCC() because it does - // not require the 'not' op. - if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) { - // Invert and smear/shift the sign bit: - // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1) - // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1) - SDLoc DL(N); - SDValue NotX = DAG.getNOT(DL, X, VT); - SDValue ShiftAmount = DAG.getConstant(VT.getSizeInBits() - 1, DL, VT); - auto ShiftOpcode = N->getOpcode() == ISD::SIGN_EXTEND ? 
ISD::SRA : ISD::SRL; - return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount); - } - return SDValue(); -} - SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - SDLoc DL(N); - if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) - return Res; + if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, + LegalOperations)) + return SDValue(Res, 0); // fold (sext (sext x)) -> (sext x) // fold (sext (aext x)) -> (sext x) if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) - return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0)); + return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, + N0.getOperand(0)); if (N0.getOpcode() == ISD::TRUNCATE) { // fold (sext (truncate (load x))) -> (sext (smaller load x)) // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n))) if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { - SDNode *oye = N0.getOperand(0).getNode(); + SDNode* oye = N0.getNode()->getOperand(0).getNode(); if (NarrowLoad.getNode() != N0.getNode()) { CombineTo(N0.getNode(), NarrowLoad); // CombineTo deleted the truncate, if needed, but not what's under it. @@ -8627,9 +5985,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { // See if the value being truncated is already sign extended. If so, just // eliminate the trunc/sext pair. SDValue Op = N0.getOperand(0); - unsigned OpBits = Op.getScalarValueSizeInBits(); - unsigned MidBits = N0.getScalarValueSizeInBits(); - unsigned DestBits = VT.getScalarSizeInBits(); + unsigned OpBits = Op.getValueType().getScalarType().getSizeInBits(); + unsigned MidBits = N0.getValueType().getScalarType().getSizeInBits(); + unsigned DestBits = VT.getScalarType().getSizeInBits(); unsigned NumSignBits = DAG.ComputeNumSignBits(Op); if (OpBits == DestBits) { @@ -8641,12 +5999,12 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign // bits, just sext from i32. if (NumSignBits > OpBits-MidBits) - return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op); + return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, Op); } else { // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign // bits, just truncate to i32. if (NumSignBits > OpBits-MidBits) - return DAG.getNode(ISD::TRUNCATE, DL, VT, Op); + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op); } // fold (sext (truncate x)) -> (sextinreg x). @@ -8656,26 +6014,65 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op); else if (OpBits > DestBits) Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op); - return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op, + return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, Op, DAG.getValueType(N0.getValueType())); } } - // Try to simplify (sext (load x)). - if (SDValue foldedExt = - tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0, - ISD::SEXTLOAD, ISD::SIGN_EXTEND)) - return foldedExt; + // fold (sext (load x)) -> (sext (truncate (sextload x))) + // Only generate vector extloads when 1) they're legal, and 2) they are + // deemed desirable by the target. 
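// Eliminating the trunc/sext pair above is sound whenever the source still
// carries enough sign bits: if x is already sign extended from the middle
// width, truncating and re-extending reproduces it. A sketch, assuming
// two's complement (widths and names are illustrative):

#include <cstdint>

static bool checkTruncSextPair(int8_t Narrow) {
  int32_t Op = Narrow;                     // >= 25 sign bits in an i32
  int64_t ViaTrunc = (int64_t)(int8_t)Op;  // sext(trunc(Op)) to i64
  int64_t Direct = (int64_t)Op;            // sext(Op) to i64
  return ViaTrunc == Direct;               // always true for such Op
}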
+ if (ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && + ((!LegalOperations && !VT.isVector() && + !cast<LoadSDNode>(N0)->isVolatile()) || + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, N0.getValueType()))) { + bool DoXform = true; + SmallVector<SDNode*, 4> SetCCs; + if (!N0.hasOneUse()) + DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::SIGN_EXTEND, SetCCs, TLI); + if (VT.isVector()) + DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0)); + if (DoXform) { + LoadSDNode *LN0 = cast<LoadSDNode>(N0); + SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT, + LN0->getChain(), + LN0->getBasePtr(), N0.getValueType(), + LN0->getMemOperand()); + CombineTo(N, ExtLoad); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), + N0.getValueType(), ExtLoad); + CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1)); + ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N), + ISD::SIGN_EXTEND); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } + } // fold (sext (load x)) to multiple smaller sextloads. // Only on illegal but splittable vectors. if (SDValue ExtLoad = CombineExtLoad(N)) return ExtLoad; - // Try to simplify (sext (sextload x)). - if (SDValue foldedExt = tryToFoldExtOfExtload( - DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD)) - return foldedExt; + // fold (sext (sextload x)) -> (sext (truncate (sextload x))) + // fold (sext ( extload x)) -> (sext (truncate (sextload x))) + if ((ISD::isSEXTLoad(N0.getNode()) || ISD::isEXTLoad(N0.getNode())) && + ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) { + LoadSDNode *LN0 = cast<LoadSDNode>(N0); + EVT MemVT = LN0->getMemoryVT(); + if ((!LegalOperations && !LN0->isVolatile()) || + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT)) { + SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT, + LN0->getChain(), + LN0->getBasePtr(), MemVT, + LN0->getMemOperand()); + CombineTo(N, ExtLoad); + CombineTo(N0.getNode(), + DAG.getNode(ISD::TRUNCATE, SDLoc(N0), + N0.getValueType(), ExtLoad), + ExtLoad.getValue(1)); + return SDValue(N, 0); // Return N so it doesn't get rechecked! 
+ } + } // fold (sext (and/or/xor (load x), cst)) -> // (and/or/xor (sextload x), (sext cst)) @@ -8683,115 +6080,92 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { N0.getOpcode() == ISD::XOR) && isa<LoadSDNode>(N0.getOperand(0)) && N0.getOperand(1).getOpcode() == ISD::Constant && + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, N0.getValueType()) && (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { - LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0)); - EVT MemVT = LN00->getMemoryVT(); - if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) && - LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) { + LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0)); + if (LN0->getExtensionType() != ISD::ZEXTLOAD && LN0->isUnindexed()) { + bool DoXform = true; SmallVector<SDNode*, 4> SetCCs; - bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0), - ISD::SIGN_EXTEND, SetCCs, TLI); + if (!N0.hasOneUse()) + DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), ISD::SIGN_EXTEND, + SetCCs, TLI); if (DoXform) { - SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT, - LN00->getChain(), LN00->getBasePtr(), - LN00->getMemoryVT(), - LN00->getMemOperand()); + SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN0), VT, + LN0->getChain(), LN0->getBasePtr(), + LN0->getMemoryVT(), + LN0->getMemOperand()); APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); Mask = Mask.sext(VT.getSizeInBits()); + SDLoc DL(N); SDValue And = DAG.getNode(N0.getOpcode(), DL, VT, ExtLoad, DAG.getConstant(Mask, DL, VT)); - ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND); - bool NoReplaceTruncAnd = !N0.hasOneUse(); - bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse(); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, + SDLoc(N0.getOperand(0)), + N0.getOperand(0).getValueType(), ExtLoad); CombineTo(N, And); - // If N0 has multiple uses, change other uses as well. - if (NoReplaceTruncAnd) { - SDValue TruncAnd = - DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And); - CombineTo(N0.getNode(), TruncAnd); - } - if (NoReplaceTrunc) { - DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1)); - } else { - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00), - LN00->getValueType(0), ExtLoad); - CombineTo(LN00, Trunc, ExtLoad.getValue(1)); - } - return SDValue(N,0); // Return N so it doesn't get rechecked! + CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1)); + ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, + ISD::SIGN_EXTEND); + return SDValue(N, 0); // Return N so it doesn't get rechecked! } } } - if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations)) - return V; - if (N0.getOpcode() == ISD::SETCC) { - SDValue N00 = N0.getOperand(0); - SDValue N01 = N0.getOperand(1); - ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); - EVT N00VT = N0.getOperand(0).getValueType(); - + EVT N0VT = N0.getOperand(0).getValueType(); // sext(setcc) -> sext_in_reg(vsetcc) for vectors. // Only do this before legalize for now. if (VT.isVector() && !LegalOperations && - TLI.getBooleanContents(N00VT) == + TLI.getBooleanContents(N0VT) == TargetLowering::ZeroOrNegativeOneBooleanContent) { // On some architectures (such as SSE/NEON/etc) the SETCC result type is // of the same size as the compared operands. Only optimize sext(setcc()) // if this is the case. - EVT SVT = getSetCCResultType(N00VT); + EVT SVT = getSetCCResultType(N0VT); - // If we already have the desired type, don't change it. 
- if (SVT != N0.getValueType()) { - // We know that the # elements of the results is the same as the - // # elements of the compare (and the # elements of the compare result - // for that matter). Check to see that they are the same size. If so, - // we know that the element size of the sext'd result matches the - // element size of the compare operands. - if (VT.getSizeInBits() == SVT.getSizeInBits()) - return DAG.getSetCC(DL, VT, N00, N01, CC); - - // If the desired elements are smaller or larger than the source - // elements, we can use a matching integer vector type and then - // truncate/sign extend. - EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger(); - if (SVT == MatchingVecType) { - SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC); - return DAG.getSExtOrTrunc(VsetCC, DL, VT); - } + // We know that the # elements of the results is the same as the + // # elements of the compare (and the # elements of the compare result + // for that matter). Check to see that they are the same size. If so, + // we know that the element size of the sext'd result matches the + // element size of the compare operands. + if (VT.getSizeInBits() == SVT.getSizeInBits()) + return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0), + N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()); + + // If the desired elements are smaller or larger than the source + // elements we can use a matching integer vector type and then + // truncate/sign extend + EVT MatchingVectorType = N0VT.changeVectorElementTypeToInteger(); + if (SVT == MatchingVectorType) { + SDValue VsetCC = DAG.getSetCC(SDLoc(N), MatchingVectorType, + N0.getOperand(0), N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()); + return DAG.getSExtOrTrunc(VsetCC, SDLoc(N), VT); } } - // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0) - // Here, T can be 1 or -1, depending on the type of the setcc and - // getBooleanContents(). - unsigned SetCCWidth = N0.getScalarValueSizeInBits(); - - // To determine the "true" side of the select, we need to know the high bit - // of the value returned by the setcc if it evaluates to true. - // If the type of the setcc is i1, then the true case of the select is just - // sext(i1 1), that is, -1. - // If the type of the setcc is larger (say, i8) then the value of the high - // bit depends on getBooleanContents(), so ask TLI for a real "true" value - // of the appropriate width. - SDValue ExtTrueVal = (SetCCWidth == 1) - ? DAG.getAllOnesConstant(DL, VT) - : DAG.getBoolConstant(true, DL, VT, N00VT); - SDValue Zero = DAG.getConstant(0, DL, VT); - if (SDValue SCC = - SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true)) - return SCC; - - if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) { - EVT SetCCVT = getSetCCResultType(N00VT); - // Don't do this transform for i1 because there's a select transform - // that would reverse it. - // TODO: We should not do this transform at all without a target hook - // because a sext is likely cheaper than a select? 
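// The select form used above encodes a compare as an all-ones / all-zero
// value: for an i1 setcc, the "true" side really is sext(1) == -1. A
// standalone sketch of that equivalence (types are illustrative):

#include <cstdint>

static bool checkSextSetCC(int32_t A, int32_t B) {
  int32_t Sel = (A < B) ? -1 : 0;    // select (setcc A, B, lt), -1, 0
  int32_t Sext = -(int32_t)(A < B);  // sext of the i1 compare result
  return Sel == Sext;
}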
- if (SetCCVT.getScalarSizeInBits() != 1 && - (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) { - SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC); - return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero); + // sext(setcc x, y, cc) -> (select (setcc x, y, cc), -1, 0) + unsigned ElementWidth = VT.getScalarType().getSizeInBits(); + SDLoc DL(N); + SDValue NegOne = + DAG.getConstant(APInt::getAllOnesValue(ElementWidth), DL, VT); + SDValue SCC = + SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), + NegOne, DAG.getConstant(0, DL, VT), + cast<CondCodeSDNode>(N0.getOperand(2))->get(), true); + if (SCC.getNode()) return SCC; + + if (!VT.isVector()) { + EVT SetCCVT = getSetCCResultType(N0.getOperand(0).getValueType()); + if (!LegalOperations || + TLI.isOperationLegal(ISD::SETCC, N0.getOperand(0).getValueType())) { + SDLoc DL(N); + ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); + SDValue SetCC = DAG.getSetCC(DL, SetCCVT, + N0.getOperand(0), N0.getOperand(1), CC); + return DAG.getSelect(DL, VT, SetCC, + NegOne, DAG.getConstant(0, DL, VT)); } } } @@ -8799,53 +6173,54 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { // fold (sext x) -> (zext x) if the sign bit is known zero. if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) && DAG.SignBitIsZero(N0)) - return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0); - - if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) - return NewVSel; + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0); return SDValue(); } // isTruncateOf - If N is a truncate of some other value, return true, record -// the value being truncated in Op and which of Op's bits are zero/one in Known. -// This function computes KnownBits to avoid a duplicated call to +// the value being truncated in Op and which of Op's bits are zero in KnownZero. +// This function computes KnownZero to avoid a duplicated call to // computeKnownBits in the caller. 
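// sext and zext agree whenever the sign bit of the source is known zero,
// which is exactly what the SignBitIsZero check above establishes. A sketch
// for an i8 value with a clear sign bit (constants are illustrative):

#include <cstdint>

static_assert((int32_t)(int8_t)0x35 == (int32_t)(uint8_t)0x35,
              "sign bit known zero: sext(x) == zext(x)");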
static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op, - KnownBits &Known) { + APInt &KnownZero) { + APInt KnownOne; if (N->getOpcode() == ISD::TRUNCATE) { Op = N->getOperand(0); - Known = DAG.computeKnownBits(Op); + DAG.computeKnownBits(Op, KnownZero, KnownOne); return true; } - if (N.getOpcode() != ISD::SETCC || - N.getValueType().getScalarType() != MVT::i1 || - cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE) + if (N->getOpcode() != ISD::SETCC || N->getValueType(0) != MVT::i1 || + cast<CondCodeSDNode>(N->getOperand(2))->get() != ISD::SETNE) return false; SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); assert(Op0.getValueType() == Op1.getValueType()); - if (isNullOrNullSplat(Op0)) + if (isNullConstant(Op0)) Op = Op1; - else if (isNullOrNullSplat(Op1)) + else if (isNullConstant(Op1)) Op = Op0; else return false; - Known = DAG.computeKnownBits(Op); + DAG.computeKnownBits(Op, KnownZero, KnownOne); + + if (!(KnownZero | APInt(Op.getValueSizeInBits(), 1)).isAllOnesValue()) + return false; - return (Known.Zero | 1).isAllOnesValue(); + return true; } SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) - return Res; + if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, + LegalOperations)) + return SDValue(Res, 0); // fold (zext (zext x)) -> (zext x) // fold (zext (aext x)) -> (zext x) @@ -8856,18 +6231,39 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // fold (zext (truncate x)) -> (zext x) or // (zext (truncate x)) -> (truncate x) // This is valid when the truncated bits of x are already zero. + // FIXME: We should extend this to work for vectors too. SDValue Op; - KnownBits Known; - if (isTruncateOf(DAG, N0, Op, Known)) { + APInt KnownZero; + if (!VT.isVector() && isTruncateOf(DAG, N0, Op, KnownZero)) { APInt TruncatedBits = - (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ? - APInt(Op.getScalarValueSizeInBits(), 0) : - APInt::getBitsSet(Op.getScalarValueSizeInBits(), - N0.getScalarValueSizeInBits(), - std::min(Op.getScalarValueSizeInBits(), - VT.getScalarSizeInBits())); - if (TruncatedBits.isSubsetOf(Known.Zero)) - return DAG.getZExtOrTrunc(Op, SDLoc(N), VT); + (Op.getValueSizeInBits() == N0.getValueSizeInBits()) ? + APInt(Op.getValueSizeInBits(), 0) : + APInt::getBitsSet(Op.getValueSizeInBits(), + N0.getValueSizeInBits(), + std::min(Op.getValueSizeInBits(), + VT.getSizeInBits())); + if (TruncatedBits == (KnownZero & TruncatedBits)) { + if (VT.bitsGT(Op.getValueType())) + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, Op); + if (VT.bitsLT(Op.getValueType())) + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op); + + return Op; + } + } + + // fold (zext (truncate (load x))) -> (zext (smaller load x)) + // fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n))) + if (N0.getOpcode() == ISD::TRUNCATE) { + if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { + SDNode* oye = N0.getNode()->getOperand(0).getNode(); + if (NarrowLoad.getNode() != N0.getNode()) { + CombineTo(N0.getNode(), NarrowLoad); + // CombineTo deleted the truncate, if needed, but not what's under it. + AddToWorklist(oye); + } + return SDValue(N, 0); // Return N so it doesn't get rechecked! 
+ } } // fold (zext (truncate x)) -> (and x, mask) @@ -8875,7 +6271,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // fold (zext (truncate (load x))) -> (zext (smaller load x)) // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n))) if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { - SDNode *oye = N0.getOperand(0).getNode(); + SDNode *oye = N0.getNode()->getOperand(0).getNode(); if (NarrowLoad.getNode() != N0.getNode()) { CombineTo(N0.getNode(), NarrowLoad); // CombineTo deleted the truncate, if needed, but not what's under it. @@ -8889,27 +6285,26 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // Try to mask before the extension to avoid having to generate a larger mask, // possibly over several sub-vectors. - if (SrcVT.bitsLT(VT) && VT.isVector()) { + if (SrcVT.bitsLT(VT)) { if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) && TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) { SDValue Op = N0.getOperand(0); Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); AddToWorklist(Op.getNode()); - SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT); - // Transfer the debug info; the new node is equivalent to N0. - DAG.transferDbgValues(N0, ZExtOrTrunc); - return ZExtOrTrunc; + return DAG.getZExtOrTrunc(Op, SDLoc(N), VT); } } if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) { - SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT); - AddToWorklist(Op.getNode()); - SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); - // We may safely transfer the debug info describing the truncate node over - // to the equivalent and operation. - DAG.transferDbgValues(N0, And); - return And; + SDValue Op = N0.getOperand(0); + if (SrcVT.bitsLT(VT)) { + Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, Op); + AddToWorklist(Op.getNode()); + } else if (SrcVT.bitsGT(VT)) { + Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op); + AddToWorklist(Op.getNode()); + } + return DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); } } @@ -8922,7 +6317,11 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { N0.getValueType()) || !TLI.isZExtFree(N0.getValueType(), VT))) { SDValue X = N0.getOperand(0).getOperand(0); - X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT); + if (X.getValueType().bitsLT(VT)) { + X = DAG.getNode(ISD::ANY_EXTEND, SDLoc(X), VT, X); + } else if (X.getValueType().bitsGT(VT)) { + X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X); + } APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); Mask = Mask.zext(VT.getSizeInBits()); SDLoc DL(N); @@ -8930,11 +6329,35 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { X, DAG.getConstant(Mask, DL, VT)); } - // Try to simplify (zext (load x)). - if (SDValue foldedExt = - tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0, - ISD::ZEXTLOAD, ISD::ZERO_EXTEND)) - return foldedExt; + // fold (zext (load x)) -> (zext (truncate (zextload x))) + // Only generate vector extloads when 1) they're legal, and 2) they are + // deemed desirable by the target. 
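// The and-mask form of zext(trunc x) used in this section: dropping a value
// to i8 and zero-extending back to i32 is the same as masking with 0xFF.
// A one-line sketch (widths and constants are illustrative):

#include <cstdint>

static_assert((uint32_t)(uint8_t)0x12345678u == (0x12345678u & 0xFFu),
              "zext(trunc(x)) == and(x, mask)");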
+ if (ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && + ((!LegalOperations && !VT.isVector() && + !cast<LoadSDNode>(N0)->isVolatile()) || + TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, N0.getValueType()))) { + bool DoXform = true; + SmallVector<SDNode*, 4> SetCCs; + if (!N0.hasOneUse()) + DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::ZERO_EXTEND, SetCCs, TLI); + if (VT.isVector()) + DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0)); + if (DoXform) { + LoadSDNode *LN0 = cast<LoadSDNode>(N0); + SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N), VT, + LN0->getChain(), + LN0->getBasePtr(), N0.getValueType(), + LN0->getMemOperand()); + CombineTo(N, ExtLoad); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), + N0.getValueType(), ExtLoad); + CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1)); + + ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N), + ISD::ZERO_EXTEND); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } + } // fold (zext (load x)) to multiple smaller zextloads. // Only on illegal but splittable vectors. @@ -8949,110 +6372,120 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { N0.getOpcode() == ISD::XOR) && isa<LoadSDNode>(N0.getOperand(0)) && N0.getOperand(1).getOpcode() == ISD::Constant && + TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, N0.getValueType()) && (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { - LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0)); - EVT MemVT = LN00->getMemoryVT(); - if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) && - LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) { + LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0)); + if (LN0->getExtensionType() != ISD::SEXTLOAD && LN0->isUnindexed()) { bool DoXform = true; SmallVector<SDNode*, 4> SetCCs; if (!N0.hasOneUse()) { if (N0.getOpcode() == ISD::AND) { auto *AndC = cast<ConstantSDNode>(N0.getOperand(1)); + auto NarrowLoad = false; EVT LoadResultTy = AndC->getValueType(0); - EVT ExtVT; - if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT)) + EVT ExtVT, LoadedVT; + if (isAndLoadExtLoad(AndC, LN0, LoadResultTy, ExtVT, LoadedVT, + NarrowLoad)) DoXform = false; } + if (DoXform) + DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), + ISD::ZERO_EXTEND, SetCCs, TLI); } - if (DoXform) - DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0), - ISD::ZERO_EXTEND, SetCCs, TLI); if (DoXform) { - SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT, - LN00->getChain(), LN00->getBasePtr(), - LN00->getMemoryVT(), - LN00->getMemOperand()); + SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), VT, + LN0->getChain(), LN0->getBasePtr(), + LN0->getMemoryVT(), + LN0->getMemOperand()); APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); Mask = Mask.zext(VT.getSizeInBits()); SDLoc DL(N); SDValue And = DAG.getNode(N0.getOpcode(), DL, VT, ExtLoad, DAG.getConstant(Mask, DL, VT)); - ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND); - bool NoReplaceTruncAnd = !N0.hasOneUse(); - bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse(); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, + SDLoc(N0.getOperand(0)), + N0.getOperand(0).getValueType(), ExtLoad); CombineTo(N, And); - // If N0 has multiple uses, change other uses as well. 
- if (NoReplaceTruncAnd) { - SDValue TruncAnd = - DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And); - CombineTo(N0.getNode(), TruncAnd); - } - if (NoReplaceTrunc) { - DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1)); - } else { - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00), - LN00->getValueType(0), ExtLoad); - CombineTo(LN00, Trunc, ExtLoad.getValue(1)); - } - return SDValue(N,0); // Return N so it doesn't get rechecked! + CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1)); + ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, + ISD::ZERO_EXTEND); + return SDValue(N, 0); // Return N so it doesn't get rechecked! } } } - // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) -> - // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst)) - if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N)) - return ZExtLoad; - - // Try to simplify (zext (zextload x)). - if (SDValue foldedExt = tryToFoldExtOfExtload( - DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD)) - return foldedExt; - - if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations)) - return V; + // fold (zext (zextload x)) -> (zext (truncate (zextload x))) + // fold (zext ( extload x)) -> (zext (truncate (zextload x))) + if ((ISD::isZEXTLoad(N0.getNode()) || ISD::isEXTLoad(N0.getNode())) && + ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) { + LoadSDNode *LN0 = cast<LoadSDNode>(N0); + EVT MemVT = LN0->getMemoryVT(); + if ((!LegalOperations && !LN0->isVolatile()) || + TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) { + SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N), VT, + LN0->getChain(), + LN0->getBasePtr(), MemVT, + LN0->getMemOperand()); + CombineTo(N, ExtLoad); + CombineTo(N0.getNode(), + DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), + ExtLoad), + ExtLoad.getValue(1)); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } + } if (N0.getOpcode() == ISD::SETCC) { - // Only do this before legalize for now. if (!LegalOperations && VT.isVector() && N0.getValueType().getVectorElementType() == MVT::i1) { - EVT N00VT = N0.getOperand(0).getValueType(); - if (getSetCCResultType(N00VT) == N0.getValueType()) + EVT N0VT = N0.getOperand(0).getValueType(); + if (getSetCCResultType(N0VT) == N0.getValueType()) return SDValue(); - // We know that the # elements of the results is the same as the # - // elements of the compare (and the # elements of the compare result for - // that matter). Check to see that they are the same size. If so, we know - // that the element size of the sext'd result matches the element size of - // the compare operands. + // zext(setcc) -> (and (vsetcc), (1, 1, ...) for vectors. + // Only do this before legalize for now. + EVT EltVT = VT.getVectorElementType(); SDLoc DL(N); - SDValue VecOnes = DAG.getConstant(1, DL, VT); - if (VT.getSizeInBits() == N00VT.getSizeInBits()) { - // zext(setcc) -> (and (vsetcc), (1, 1, ...) for vectors. - SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0), - N0.getOperand(1), N0.getOperand(2)); - return DAG.getNode(ISD::AND, DL, VT, VSetCC, VecOnes); - } + SmallVector<SDValue,8> OneOps(VT.getVectorNumElements(), + DAG.getConstant(1, DL, EltVT)); + if (VT.getSizeInBits() == N0VT.getSizeInBits()) + // We know that the # elements of the results is the same as the + // # elements of the compare (and the # elements of the compare result + // for that matter). Check to see that they are the same size. 
If so, + // we know that the element size of the sext'd result matches the + // element size of the compare operands. + return DAG.getNode(ISD::AND, DL, VT, + DAG.getSetCC(DL, VT, N0.getOperand(0), + N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()), + DAG.getNode(ISD::BUILD_VECTOR, DL, VT, + OneOps)); // If the desired elements are smaller or larger than the source // elements we can use a matching integer vector type and then - // truncate/sign extend. - EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger(); + // truncate/sign extend + EVT MatchingElementType = + EVT::getIntegerVT(*DAG.getContext(), + N0VT.getScalarType().getSizeInBits()); + EVT MatchingVectorType = + EVT::getVectorVT(*DAG.getContext(), MatchingElementType, + N0VT.getVectorNumElements()); SDValue VsetCC = - DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0), - N0.getOperand(1), N0.getOperand(2)); - return DAG.getNode(ISD::AND, DL, VT, DAG.getSExtOrTrunc(VsetCC, DL, VT), - VecOnes); + DAG.getSetCC(DL, MatchingVectorType, N0.getOperand(0), + N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()); + return DAG.getNode(ISD::AND, DL, VT, + DAG.getSExtOrTrunc(VsetCC, DL, VT), + DAG.getNode(ISD::BUILD_VECTOR, DL, VT, OneOps)); } // zext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc SDLoc DL(N); - if (SDValue SCC = SimplifySelectCC( - DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT), - DAG.getConstant(0, DL, VT), - cast<CondCodeSDNode>(N0.getOperand(2))->get(), true)) - return SCC; + SDValue SCC = + SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), + DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT), + cast<CondCodeSDNode>(N0.getOperand(2))->get(), true); + if (SCC.getNode()) return SCC; } // (zext (shl (zext x), cst)) -> (shl (zext x), cst) @@ -9066,8 +6499,8 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { SDValue InnerZExt = N0.getOperand(0); // If the original shl may be shifting out bits, do not perform this // transformation. - unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() - - InnerZExt.getOperand(0).getValueSizeInBits(); + unsigned KnownZeroBits = InnerZExt.getValueType().getSizeInBits() - + InnerZExt.getOperand(0).getValueType().getSizeInBits(); if (ShAmtVal > KnownZeroBits) return SDValue(); } @@ -9083,9 +6516,6 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { ShAmt); } - if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) - return NewVSel; - return SDValue(); } @@ -9093,8 +6523,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) - return Res; + if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, + LegalOperations)) + return SDValue(Res, 0); // fold (aext (aext x)) -> (aext x) // fold (aext (zext x)) -> (zext x) @@ -9108,7 +6539,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n))) if (N0.getOpcode() == ISD::TRUNCATE) { if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { - SDNode *oye = N0.getOperand(0).getNode(); + SDNode* oye = N0.getNode()->getOperand(0).getNode(); if (NarrowLoad.getNode() != N0.getNode()) { CombineTo(N0.getNode(), NarrowLoad); // CombineTo deleted the truncate, if needed, but not what's under it. 
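// Where a vsetcc lane is all-ones or zero, zext(setcc) is the mask and'ed
// with 1, which is what the splat of ones built above provides per lane.
// A one-lane sketch of that identity (types are illustrative):

#include <cstdint>

static bool checkZextSetCC(int32_t A, int32_t B) {
  int32_t VSetCCLane = (A < B) ? -1 : 0;           // one lane of the mask
  uint32_t ZextLane = (A < B) ? 1u : 0u;           // one lane of zext(setcc)
  return ZextLane == ((uint32_t)VSetCCLane & 1u);  // and with the splat of 1
}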
@@ -9119,8 +6550,14 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { } // fold (aext (truncate x)) - if (N0.getOpcode() == ISD::TRUNCATE) - return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT); + if (N0.getOpcode() == ISD::TRUNCATE) { + SDValue TruncOp = N0.getOperand(0); + if (TruncOp.getValueType() == VT) + return TruncOp; // x iff x size == zext size. + if (TruncOp.getValueType().bitsGT(VT)) + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, TruncOp); + return DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, TruncOp); + } // Fold (aext (and (trunc x), cst)) -> (and x, cst) // if the trunc is not free. @@ -9129,11 +6566,15 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { N0.getOperand(1).getOpcode() == ISD::Constant && !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(), N0.getValueType())) { - SDLoc DL(N); SDValue X = N0.getOperand(0).getOperand(0); - X = DAG.getAnyExtOrTrunc(X, DL, VT); + if (X.getValueType().bitsLT(VT)) { + X = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, X); + } else if (X.getValueType().bitsGT(VT)) { + X = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X); + } APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); Mask = Mask.zext(VT.getSizeInBits()); + SDLoc DL(N); return DAG.getNode(ISD::AND, DL, VT, X, DAG.getConstant(Mask, DL, VT)); } @@ -9148,34 +6589,29 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { bool DoXform = true; SmallVector<SDNode*, 4> SetCCs; if (!N0.hasOneUse()) - DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, - TLI); + DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::ANY_EXTEND, SetCCs, TLI); if (DoXform) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT, LN0->getChain(), LN0->getBasePtr(), N0.getValueType(), LN0->getMemOperand()); - ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND); - // If the load value is used only by N, replace it via CombineTo N. - bool NoReplaceTrunc = N0.hasOneUse(); CombineTo(N, ExtLoad); - if (NoReplaceTrunc) { - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - } else { - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), - N0.getValueType(), ExtLoad); - CombineTo(LN0, Trunc, ExtLoad.getValue(1)); - } - return SDValue(N, 0); // Return N so it doesn't get rechecked! + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), + N0.getValueType(), ExtLoad); + CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1)); + ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N), + ISD::ANY_EXTEND); + return SDValue(N, 0); // Return N so it doesn't get rechecked! 
} } // fold (aext (zextload x)) -> (aext (truncate (zextload x))) // fold (aext (sextload x)) -> (aext (truncate (sextload x))) // fold (aext ( extload x)) -> (aext (truncate (extload x))) - if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) && - ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) { + if (N0.getOpcode() == ISD::LOAD && + !ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && + N0.hasOneUse()) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); ISD::LoadExtType ExtType = LN0->getExtensionType(); EVT MemVT = LN0->getMemoryVT(); @@ -9184,7 +6620,10 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { VT, LN0->getChain(), LN0->getBasePtr(), MemVT, LN0->getMemOperand()); CombineTo(N, ExtLoad); - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); + CombineTo(N0.getNode(), + DAG.getNode(ISD::TRUNCATE, SDLoc(N0), + N0.getValueType(), ExtLoad), + ExtLoad.getValue(1)); return SDValue(N, 0); // Return N so it doesn't get rechecked! } } @@ -9196,104 +6635,89 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { // aext(setcc) -> aext(vsetcc) // Only do this before legalize for now. if (VT.isVector() && !LegalOperations) { - EVT N00VT = N0.getOperand(0).getValueType(); - if (getSetCCResultType(N00VT) == N0.getValueType()) - return SDValue(); - - // We know that the # elements of the results is the same as the - // # elements of the compare (and the # elements of the compare result - // for that matter). Check to see that they are the same size. If so, - // we know that the element size of the sext'd result matches the - // element size of the compare operands. - if (VT.getSizeInBits() == N00VT.getSizeInBits()) + EVT N0VT = N0.getOperand(0).getValueType(); + // We know that the # elements of the results is the same as the + // # elements of the compare (and the # elements of the compare result + // for that matter). Check to see that they are the same size. If so, + // we know that the element size of the sext'd result matches the + // element size of the compare operands. 
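// The aext(and(trunc x, cst)) -> and(x, cst) fold above relies on masking
// commuting with truncation when the constant fits in the narrow type; in
// the widened result the mask also clears the otherwise-unspecified high
// bits. Sketch for i32 -> i8 -> i32 (constants are illustrative):

#include <cstdint>

static_assert((uint32_t)(uint8_t)((uint8_t)0x1234u & 0x0Fu) ==
                  (0x1234u & 0x0Fu),
              "aext(and(trunc(x), c)) == and(x, c)");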
+ if (VT.getSizeInBits() == N0VT.getSizeInBits()) return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0), N0.getOperand(1), cast<CondCodeSDNode>(N0.getOperand(2))->get()); - // If the desired elements are smaller or larger than the source // elements we can use a matching integer vector type and then // truncate/any extend - EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger(); - SDValue VsetCC = - DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0), - N0.getOperand(1), - cast<CondCodeSDNode>(N0.getOperand(2))->get()); - return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT); + else { + EVT MatchingVectorType = N0VT.changeVectorElementTypeToInteger(); + SDValue VsetCC = + DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0), + N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()); + return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT); + } } // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc SDLoc DL(N); - if (SDValue SCC = SimplifySelectCC( - DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT), - DAG.getConstant(0, DL, VT), - cast<CondCodeSDNode>(N0.getOperand(2))->get(), true)) + SDValue SCC = + SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), + DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT), + cast<CondCodeSDNode>(N0.getOperand(2))->get(), true); + if (SCC.getNode()) return SCC; } return SDValue(); } -SDValue DAGCombiner::visitAssertExt(SDNode *N) { - unsigned Opcode = N->getOpcode(); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT AssertVT = cast<VTSDNode>(N1)->getVT(); - - // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt) - if (N0.getOpcode() == Opcode && - AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT()) - return N0; - - if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && - N0.getOperand(0).getOpcode() == Opcode) { - // We have an assert, truncate, assert sandwich. Make one stronger assert - // by asserting on the smallest asserted type to the larger source type. - // This eliminates the later assert: - // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN - // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN - SDValue BigA = N0.getOperand(0); - EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT(); - assert(BigA_AssertVT.bitsLE(N0.getValueType()) && - "Asserting zero/sign-extended bits to a type larger than the " - "truncated destination does not provide information"); +/// See if the specified operand can be simplified with the knowledge that only +/// the bits specified by Mask are used. If so, return the simpler operand, +/// otherwise return a null SDValue. +SDValue DAGCombiner::GetDemandedBits(SDValue V, const APInt &Mask) { + switch (V.getOpcode()) { + default: break; + case ISD::Constant: { + const ConstantSDNode *CV = cast<ConstantSDNode>(V.getNode()); + assert(CV && "Const value should be ConstSDNode."); + const APInt &CVal = CV->getAPIntValue(); + APInt NewVal = CVal & Mask; + if (NewVal != CVal) + return DAG.getConstant(NewVal, SDLoc(V), V.getValueType()); + break; + } + case ISD::OR: + case ISD::XOR: + // If the LHS or RHS don't contribute bits to the or, drop them. + if (DAG.MaskedValueIsZero(V.getOperand(0), Mask)) + return V.getOperand(1); + if (DAG.MaskedValueIsZero(V.getOperand(1), Mask)) + return V.getOperand(0); + break; + case ISD::SRL: + // Only look at single-use SRLs. 
+ if (!V.getNode()->hasOneUse()) + break; + if (ConstantSDNode *RHSC = getAsNonOpaqueConstant(V.getOperand(1))) { + // See if we can recursively simplify the LHS. + unsigned Amt = RHSC->getZExtValue(); - SDLoc DL(N); - EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT; - SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT); - SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(), - BigA.getOperand(0), MinAssertVTVal); - return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert); - } - - // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller - // than X. Just move the AssertZext in front of the truncate and drop the - // AssertSExt. - if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && - N0.getOperand(0).getOpcode() == ISD::AssertSext && - Opcode == ISD::AssertZext) { - SDValue BigA = N0.getOperand(0); - EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT(); - assert(BigA_AssertVT.bitsLE(N0.getValueType()) && - "Asserting zero/sign-extended bits to a type larger than the " - "truncated destination does not provide information"); - - if (AssertVT.bitsLT(BigA_AssertVT)) { - SDLoc DL(N); - SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(), - BigA.getOperand(0), N1); - return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert); + // Watch out for shift count overflow though. + if (Amt >= Mask.getBitWidth()) break; + APInt NewMask = Mask << Amt; + if (SDValue SimplifyLHS = GetDemandedBits(V.getOperand(0), NewMask)) + return DAG.getNode(ISD::SRL, SDLoc(V), V.getValueType(), + SimplifyLHS, V.getOperand(1)); } } - return SDValue(); } /// If the result of a wider load is shifted to right of N bits and then /// truncated to a narrower type and where N is a multiple of number of bits of /// the narrower type, transform it to a narrower load from address + N / num of -/// bits of new type. Also narrow the load if the result is masked with an AND -/// to effectively produce a smaller type. If the result is to be extended, also -/// fold the extension to form a extending load. +/// bits of new type. If the result is to be extended, also fold the extension +/// to form a extending load. SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { unsigned Opc = N->getOpcode(); @@ -9306,100 +6730,56 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { if (VT.isVector()) return SDValue(); - unsigned ShAmt = 0; - bool HasShiftedOffset = false; // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then // extended to VT. if (Opc == ISD::SIGN_EXTEND_INREG) { ExtType = ISD::SEXTLOAD; ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT(); } else if (Opc == ISD::SRL) { - // Another special-case: SRL is basically zero-extending a narrower value, - // or it maybe shifting a higher subword, half or byte into the lowest - // bits. + // Another special-case: SRL is basically zero-extending a narrower value. 
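// The SRL cases above only ever consult the bits (Mask << Amt) of the
// shifted operand, so that operand may be simplified against the shifted
// mask without changing the demanded result. Sketch (values illustrative):

#include <cstdint>

static bool checkDemandedThroughSrl(uint32_t X) {
  const uint32_t Mask = 0xFFu, Amt = 4;
  uint32_t Full = (X >> Amt) & Mask;
  uint32_t Reduced = ((X & (Mask << Amt)) >> Amt) & Mask;  // drop other bits
  return Full == Reduced;
}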
ExtType = ISD::ZEXTLOAD; N0 = SDValue(N, 0); + ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1)); + if (!N01) return SDValue(); + ExtVT = EVT::getIntegerVT(*DAG.getContext(), + VT.getSizeInBits() - N01->getZExtValue()); + } + if (LegalOperations && !TLI.isLoadExtLegal(ExtType, VT, ExtVT)) + return SDValue(); - auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0)); - auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1)); - if (!N01 || !LN0) - return SDValue(); - - uint64_t ShiftAmt = N01->getZExtValue(); - uint64_t MemoryWidth = LN0->getMemoryVT().getSizeInBits(); - if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt) - ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt); - else - ExtVT = EVT::getIntegerVT(*DAG.getContext(), - VT.getSizeInBits() - ShiftAmt); - } else if (Opc == ISD::AND) { - // An AND with a constant mask is the same as a truncate + zero-extend. - auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1)); - if (!AndC) - return SDValue(); + unsigned EVTBits = ExtVT.getSizeInBits(); - const APInt &Mask = AndC->getAPIntValue(); - unsigned ActiveBits = 0; - if (Mask.isMask()) { - ActiveBits = Mask.countTrailingOnes(); - } else if (Mask.isShiftedMask()) { - ShAmt = Mask.countTrailingZeros(); - APInt ShiftedMask = Mask.lshr(ShAmt); - ActiveBits = ShiftedMask.countTrailingOnes(); - HasShiftedOffset = true; - } else - return SDValue(); - - ExtType = ISD::ZEXTLOAD; - ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits); - } + // Do not generate loads of non-round integer types since these can + // be expensive (and would be wrong if the type is not byte sized). + if (!ExtVT.isRound()) + return SDValue(); + unsigned ShAmt = 0; if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { - SDValue SRL = N0; - if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) { - ShAmt = ConstShift->getZExtValue(); - unsigned EVTBits = ExtVT.getSizeInBits(); + if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { + ShAmt = N01->getZExtValue(); // Is the shift amount a multiple of size of VT? if ((ShAmt & (EVTBits-1)) == 0) { N0 = N0.getOperand(0); // Is the load width a multiple of size of VT? - if ((N0.getValueSizeInBits() & (EVTBits-1)) != 0) + if ((N0.getValueType().getSizeInBits() & (EVTBits-1)) != 0) return SDValue(); } // At this point, we must have a load or else we can't do the transform. if (!isa<LoadSDNode>(N0)) return SDValue(); - auto *LN0 = cast<LoadSDNode>(N0); - // Because a SRL must be assumed to *need* to zero-extend the high bits // (as opposed to anyext the high bits), we can't combine the zextload // lowering of SRL and an sextload. - if (LN0->getExtensionType() == ISD::SEXTLOAD) + if (cast<LoadSDNode>(N0)->getExtensionType() == ISD::SEXTLOAD) return SDValue(); // If the shift amount is larger than the input type then we're not // accessing any of the loaded bytes. If the load was a zextload/extload // then the result of the shift+trunc is zero/undef (handled elsewhere). - if (ShAmt >= LN0->getMemoryVT().getSizeInBits()) + if (ShAmt >= cast<LoadSDNode>(N0)->getMemoryVT().getSizeInBits()) return SDValue(); - - // If the SRL is only used by a masking AND, we may be able to adjust - // the ExtVT to make the AND redundant. 
- SDNode *Mask = *(SRL->use_begin()); - if (Mask->getOpcode() == ISD::AND && - isa<ConstantSDNode>(Mask->getOperand(1))) { - const APInt &ShiftMask = - cast<ConstantSDNode>(Mask->getOperand(1))->getAPIntValue(); - if (ShiftMask.isMask()) { - EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), - ShiftMask.countTrailingOnes()); - // If the mask is smaller, recompute the type. - if ((ExtVT.getSizeInBits() > MaskedVT.getSizeInBits()) && - TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT)) - ExtVT = MaskedVT; - } - } } } @@ -9414,26 +6794,52 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { } } - // If we haven't found a load, we can't narrow it. - if (!isa<LoadSDNode>(N0)) + // If we haven't found a load, we can't narrow it. Don't transform one with + // multiple uses, this would require adding a new load. + if (!isa<LoadSDNode>(N0) || !N0.hasOneUse()) return SDValue(); + // Don't change the width of a volatile load. LoadSDNode *LN0 = cast<LoadSDNode>(N0); - if (!isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt)) + if (LN0->isVolatile()) return SDValue(); - auto AdjustBigEndianShift = [&](unsigned ShAmt) { - unsigned LVTStoreBits = LN0->getMemoryVT().getStoreSizeInBits(); - unsigned EVTStoreBits = ExtVT.getStoreSizeInBits(); - return LVTStoreBits - EVTStoreBits - ShAmt; - }; + // Verify that we are actually reducing a load width here. + if (LN0->getMemoryVT().getSizeInBits() < EVTBits) + return SDValue(); + + // For the transform to be legal, the load must produce only two values + // (the value loaded and the chain). Don't transform a pre-increment + // load, for example, which produces an extra value. Otherwise the + // transformation is not equivalent, and the downstream logic to replace + // uses gets things wrong. + if (LN0->getNumValues() > 2) + return SDValue(); + + // If the load that we're shrinking is an extload and we're not just + // discarding the extension we can't simply shrink the load. Bail. + // TODO: It would be possible to merge the extensions in some cases. + if (LN0->getExtensionType() != ISD::NON_EXTLOAD && + LN0->getMemoryVT().getSizeInBits() < ExtVT.getSizeInBits() + ShAmt) + return SDValue(); + + if (!TLI.shouldReduceLoadWidth(LN0, ExtType, ExtVT)) + return SDValue(); + + EVT PtrType = N0.getOperand(1).getValueType(); + + if (PtrType == MVT::Untyped || PtrType.isExtended()) + // It's not possible to generate a constant of extended or untyped type. + return SDValue(); // For big endian targets, we need to adjust the offset to the pointer to // load the correct bytes. 
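// A worked instance (plain C++) of the big-endian adjustment that follows:
// narrowing a 32-bit load to 16 bits with ShAmt = 16 reads byte offset
// 16/8 = 2 on little-endian, but (32 - 16 - 16)/8 = 0 on big-endian, since
// big-endian targets store the most significant bytes first.
#include <cstdio>

int main() {
  unsigned LVTStoreBits = 32;  // store width of the original load
  unsigned EVTStoreBits = 16;  // store width of the narrowed load
  unsigned ShAmt = 16;         // bits shifted out before the truncate
  std::printf("LE byte offset: %u\n", ShAmt / 8);                    // 2
  std::printf("BE byte offset: %u\n",
              (LVTStoreBits - EVTStoreBits - ShAmt) / 8);            // 0
  return 0;
}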
- if (DAG.getDataLayout().isBigEndian()) - ShAmt = AdjustBigEndianShift(ShAmt); + if (DAG.getDataLayout().isBigEndian()) { + unsigned LVTStoreBits = LN0->getMemoryVT().getStoreSizeInBits(); + unsigned EVTStoreBits = ExtVT.getStoreSizeInBits(); + ShAmt = LVTStoreBits - EVTStoreBits - ShAmt; + } - EVT PtrType = N0.getOperand(1).getValueType(); uint64_t PtrOff = ShAmt / 8; unsigned NewAlign = MinAlign(LN0->getAlignment(), PtrOff); SDLoc DL(LN0); @@ -9443,19 +6849,20 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { SDValue NewPtr = DAG.getNode(ISD::ADD, DL, PtrType, LN0->getBasePtr(), DAG.getConstant(PtrOff, DL, PtrType), - Flags); + &Flags); AddToWorklist(NewPtr.getNode()); SDValue Load; if (ExtType == ISD::NON_EXTLOAD) - Load = DAG.getLoad(VT, SDLoc(N0), LN0->getChain(), NewPtr, - LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign, - LN0->getMemOperand()->getFlags(), LN0->getAAInfo()); + Load = DAG.getLoad(VT, SDLoc(N0), LN0->getChain(), NewPtr, + LN0->getPointerInfo().getWithOffset(PtrOff), + LN0->isVolatile(), LN0->isNonTemporal(), + LN0->isInvariant(), NewAlign, LN0->getAAInfo()); else - Load = DAG.getExtLoad(ExtType, SDLoc(N0), VT, LN0->getChain(), NewPtr, - LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT, - NewAlign, LN0->getMemOperand()->getFlags(), - LN0->getAAInfo()); + Load = DAG.getExtLoad(ExtType, SDLoc(N0), VT, LN0->getChain(),NewPtr, + LN0->getPointerInfo().getWithOffset(PtrOff), + ExtVT, LN0->isVolatile(), LN0->isNonTemporal(), + LN0->isInvariant(), NewAlign, LN0->getAAInfo()); // Replace the old load's chain with the new load's chain. WorklistRemover DeadNodes(*this); @@ -9479,21 +6886,6 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy)); } - if (HasShiftedOffset) { - // Recalculate the shift amount after it has been altered to calculate - // the offset. - if (DAG.getDataLayout().isBigEndian()) - ShAmt = AdjustBigEndianShift(ShAmt); - - // We're using a shifted mask, so the load now has an offset. This means - // that data has been loaded into the lower bytes than it would have been - // before, so we need to shl the loaded data into the correct position in the - // register. - SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT); - Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC); - DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result); - } - // Return the new loaded value. return Result; } @@ -9503,14 +6895,14 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); EVT EVT = cast<VTSDNode>(N1)->getVT(); - unsigned VTBits = VT.getScalarSizeInBits(); - unsigned EVTBits = EVT.getScalarSizeInBits(); + unsigned VTBits = VT.getScalarType().getSizeInBits(); + unsigned EVTBits = EVT.getScalarType().getSizeInBits(); if (N0.isUndef()) return DAG.getUNDEF(VT); // fold (sext_in_reg c1) -> c1 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1); // If the input is already sign extended, just drop the extension. @@ -9525,40 +6917,17 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // fold (sext_in_reg (sext x)) -> (sext x) // fold (sext_in_reg (aext x)) -> (sext x) - // if x is small enough or if we know that x has more than 1 sign bit and the - // sign_extend_inreg is extending from one of them. + // if x is small enough. 
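// For reference, what SIGN_EXTEND_INREG computes, as a standalone plain-C++
// sketch: reinterpret the low EVTBits of the value as signed and sign-extend
// them back to the full width (the usual shl/sra-by-(VTBits-EVTBits) pair).
#include <cassert>
#include <cstdint>

int32_t SextInReg8(int32_t X) {
  return static_cast<int8_t>(X);  // keep the low 8 bits, re-signed
}

int main() {
  assert(SextInReg8(0x000000FF) == -1);                           // 0xFF
  assert(SextInReg8(0x1234567F) == 0x7F);                         // sign clear
  assert(SextInReg8(static_cast<int32_t>(0xDEADBE80u)) == -128);  // 0x80
  return 0;
}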
if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) { SDValue N00 = N0.getOperand(0); - unsigned N00Bits = N00.getScalarValueSizeInBits(); - if ((N00Bits <= EVTBits || - (N00Bits - DAG.ComputeNumSignBits(N00)) < EVTBits) && - (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT))) - return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00); - } - - // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x) - if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG || - N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG || - N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) && - N0.getOperand(0).getScalarValueSizeInBits() == EVTBits) { - if (!LegalOperations || - TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)) - return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, - N0.getOperand(0)); - } - - // fold (sext_in_reg (zext x)) -> (sext x) - // iff we are extending the source sign bit. - if (N0.getOpcode() == ISD::ZERO_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (N00.getScalarValueSizeInBits() == EVTBits && + if (N00.getValueType().getScalarType().getSizeInBits() <= EVTBits && (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT))) return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1); } // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero. - if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, EVTBits - 1))) - return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT.getScalarType()); + if (DAG.MaskedValueIsZero(N0, APInt::getBitsSet(VTBits, EVTBits-1, EVTBits))) + return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT); // fold operands of sext_in_reg based on knowledge that the top bits are not // demanded. @@ -9586,14 +6955,10 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { } // fold (sext_inreg (extload x)) -> (sextload x) - // If sextload is not supported by target, we can only do the combine when - // load has one use. Doing otherwise can block folding the extload with other - // extends that the target does support. 
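// The MaskedValueIsZero fold above relies on a simple fact: when the narrow
// type's sign bit is known zero, sign- and zero-extension of the low bits
// agree. A plain-C++ spot check:
#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0x12345672u;                     // bit 7 of the low byte is 0
  int32_t Sext = static_cast<int8_t>(X);        // sext_in_reg(X, i8)
  uint32_t Zext = X & 0xFFu;                    // zext_in_reg(X, i8)
  assert(static_cast<uint32_t>(Sext) == Zext);  // both 0x72
  return 0;
}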
if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && EVT == cast<LoadSDNode>(N0)->getMemoryVT() && - ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile() && - N0.hasOneUse()) || + ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile()) || TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT, @@ -9623,8 +6988,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16)) if (EVTBits <= 16 && N0.getOpcode() == ISD::OR) { - if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0), - N0.getOperand(1), false)) + SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0), + N0.getOperand(1), false); + if (BSwap.getNode()) return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1); } @@ -9636,30 +7002,12 @@ SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (N0.isUndef()) - return DAG.getUNDEF(VT); - - if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) - return Res; - - if (SimplifyDemandedVectorElts(SDValue(N, 0))) - return SDValue(N, 0); - - return SDValue(); -} - -SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) { - SDValue N0 = N->getOperand(0); - EVT VT = N->getValueType(0); - - if (N0.isUndef()) + if (N0.getOpcode() == ISD::UNDEF) return DAG.getUNDEF(VT); - if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) - return Res; - - if (SimplifyDemandedVectorElts(SDValue(N, 0))) - return SDValue(N, 0); + if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, + LegalOperations)) + return SDValue(Res, 0); return SDValue(); } @@ -9672,37 +7020,28 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { // noop truncate if (N0.getValueType() == N->getValueType(0)) return N0; - + // fold (truncate c1) -> c1 + if (isConstantIntBuildVectorOrConstantInt(N0)) + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0); // fold (truncate (truncate x)) -> (truncate x) if (N0.getOpcode() == ISD::TRUNCATE) return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0)); - - // fold (truncate c1) -> c1 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) { - SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0); - if (C.getNode() != N) - return C; - } - // fold (truncate (ext x)) -> (ext x) or (truncate x) or x if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) { - // if the source is smaller than the dest, we still need an extend. if (N0.getOperand(0).getValueType().bitsLT(VT)) - return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0)); - // if the source is larger than the dest, than we just need the truncate. + // if the source is smaller than the dest, we still need an extend + return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, + N0.getOperand(0)); if (N0.getOperand(0).getValueType().bitsGT(VT)) + // if the source is larger than the dest, than we just need the truncate return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0)); // if the source and dest are the same type, we can drop both the extend // and the truncate. return N0.getOperand(0); } - // If this is anyext(trunc), don't fold it, allow ourselves to be folded. - if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND)) - return SDValue(); - // Fold extract-and-trunc into a narrow extract. 
For example: // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1) // i32 y = TRUNCATE(i64 x) @@ -9715,6 +7054,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { // we need to be more careful about the vector instructions that we generate. if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT && LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) { + EVT VecTy = N0.getOperand(0).getValueType(); EVT ExTy = N0.getValueType(); EVT TrTy = N->getValueType(0); @@ -9731,15 +7071,18 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout()); int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1)); + SDValue V = DAG.getNode(ISD::BITCAST, SDLoc(N), + NVT, N0.getOperand(0)); + SDLoc DL(N); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy, - DAG.getBitcast(NVT, N0.getOperand(0)), + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, + DL, TrTy, V, DAG.getConstant(Index, DL, IndexTy)); } } // trunc (select c, a, b) -> select c, (trunc a), (trunc b) - if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) { + if (N0.getOpcode() == ISD::SELECT) { EVT SrcVT = N0.getValueType(); if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) && TLI.isTruncateFree(SrcVT, VT)) { @@ -9751,26 +7094,6 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { } } - // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits() - if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() && - (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) && - TLI.isTypeDesirableForOp(ISD::SHL, VT)) { - SDValue Amt = N0.getOperand(1); - KnownBits Known = DAG.computeKnownBits(Amt); - unsigned Size = VT.getScalarSizeInBits(); - if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) { - SDLoc SL(N); - EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout()); - - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0)); - if (AmtVT != Amt.getValueType()) { - Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT); - AddToWorklist(Amt.getNode()); - } - return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt); - } - } - // Fold a series of buildvector, bitcast, and truncate if possible. // For example fold // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to @@ -9779,6 +7102,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() && N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR && N0.getOperand(0).hasOneUse()) { + SDValue BuildVect = N0.getOperand(0); EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType(); EVT TruncVecEltTy = VT.getVectorElementType(); @@ -9797,7 +7121,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset) Opnds.push_back(BuildVect.getOperand(i)); - return DAG.getBuildVector(VT, SDLoc(N), Opnds); + return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Opnds); } } @@ -9807,12 +7131,12 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { // Currently we only perform this optimization on scalars because vectors // may have different active low bits. 
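// The fold below asks for just the low VT bits because a truncate ignores
// everything above them; in plain-C++ terms, changes confined to the high
// bits of the source can never alter the truncated result:
#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0xDEAD1234u;
  uint32_t HighNoise = 0xABCD0000u;  // touches only bits 16 and up
  assert(static_cast<uint16_t>(X) ==
         static_cast<uint16_t>(X ^ HighNoise));  // both 0x1234
  return 0;
}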
if (!VT.isVector()) { - APInt Mask = - APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits()); - if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask)) + SDValue Shorter = + GetDemandedBits(N0, APInt::getLowBitsSet(N0.getValueSizeInBits(), + VT.getSizeInBits())); + if (Shorter.getNode()) return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter); } - // fold (truncate (load x)) -> (smaller load x) // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits)) if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) { @@ -9834,7 +7158,6 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { } } } - // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)), // where ... are all 'undef'. if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) { @@ -9845,7 +7168,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) { SDValue X = N0.getOperand(i); - if (!X.isUndef()) { + if (X.getOpcode() != ISD::UNDEF) { V = X; Idx = i; NumDefs++; @@ -9877,88 +7200,11 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { } } - // Fold truncate of a bitcast of a vector to an extract of the low vector - // element. - // - // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx - if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) { - SDValue VecSrc = N0.getOperand(0); - EVT SrcVT = VecSrc.getValueType(); - if (SrcVT.isVector() && SrcVT.getScalarType() == VT && - (!LegalOperations || - TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, SrcVT))) { - SDLoc SL(N); - - EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout()); - unsigned Idx = isLE ? 0 : SrcVT.getVectorNumElements() - 1; - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, - VecSrc, DAG.getConstant(Idx, SL, IdxVT)); - } - } - // Simplify the operands using demanded-bits information. if (!VT.isVector() && SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); - // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry) - // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry) - // When the adde's carry is not used. - if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) && - N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) && - (!LegalOperations || TLI.isOperationLegal(N0.getOpcode(), VT))) { - SDLoc SL(N); - auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0)); - auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1)); - auto VTs = DAG.getVTList(VT, N0->getValueType(1)); - return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2)); - } - - // fold (truncate (extract_subvector(ext x))) -> - // (extract_subvector x) - // TODO: This can be generalized to cover cases where the truncate and extract - // do not fully cancel each other out. - if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::SIGN_EXTEND || - N00.getOpcode() == ISD::ZERO_EXTEND || - N00.getOpcode() == ISD::ANY_EXTEND) { - if (N00.getOperand(0)->getValueType(0).getVectorElementType() == - VT.getVectorElementType()) - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT, - N00.getOperand(0), N0.getOperand(1)); - } - } - - if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) - return NewVSel; - - // Narrow a suitable binary operation with a non-opaque constant operand by - // moving it ahead of the truncate. 
This is limited to pre-legalization - // because targets may prefer a wider type during later combines and invert - // this transform. - switch (N0.getOpcode()) { - case ISD::ADD: - case ISD::SUB: - case ISD::MUL: - case ISD::AND: - case ISD::OR: - case ISD::XOR: - if (!LegalOperations && N0.hasOneUse() && - (isConstantOrConstantVector(N0.getOperand(0), true) || - isConstantOrConstantVector(N0.getOperand(1), true))) { - // TODO: We already restricted this to pre-legalization, but for vectors - // we are extra cautious to not create an unsupported operation. - // Target-specific changes are likely needed to avoid regressions here. - if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) { - SDLoc DL(N); - SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0)); - SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1)); - return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR); - } - } - } - return SDValue(); } @@ -9976,28 +7222,27 @@ SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) { LoadSDNode *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0)); LoadSDNode *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1)); - - // A BUILD_PAIR is always having the least significant part in elt 0 and the - // most significant part in elt 1. So when combining into one large load, we - // need to consider the endianness. - if (DAG.getDataLayout().isBigEndian()) - std::swap(LD1, LD2); - if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !LD1->hasOneUse() || LD1->getAddressSpace() != LD2->getAddressSpace()) return SDValue(); EVT LD1VT = LD1->getValueType(0); - unsigned LD1Bytes = LD1VT.getStoreSize(); - if (ISD::isNON_EXTLoad(LD2) && LD2->hasOneUse() && - DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1)) { + + if (ISD::isNON_EXTLoad(LD2) && + LD2->hasOneUse() && + // If both are volatile this would reduce the number of volatile loads. + // If one is volatile it might be ok, but play conservative and bail out. + !LD1->isVolatile() && + !LD2->isVolatile() && + DAG.isConsecutiveLoad(LD2, LD1, LD1VT.getSizeInBits()/8, 1)) { unsigned Align = LD1->getAlignment(); unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment( VT.getTypeForEVT(*DAG.getContext())); if (NewAlign <= Align && (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT))) - return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(), - LD1->getPointerInfo(), Align); + return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), + LD1->getBasePtr(), LD1->getPointerInfo(), + false, false, false, Align); } return SDValue(); @@ -10009,77 +7254,25 @@ static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) { return DAG.getDataLayout().isBigEndian() ? 1 : 0; } -static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, - const TargetLowering &TLI) { - // If this is not a bitcast to an FP type or if the target doesn't have - // IEEE754-compliant FP logic, we're done. - EVT VT = N->getValueType(0); - if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT)) - return SDValue(); - - // TODO: Handle cases where the integer constant is a different scalar - // bitwidth to the FP. 
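// Why CombineConsecutiveLoads above is value-preserving, sketched in plain
// C++ for a little-endian host (recall the note that BUILD_PAIR keeps the
// least significant part in element 0):
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  unsigned char Mem[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  uint32_t Lo, Hi;
  std::memcpy(&Lo, Mem, 4);      // first narrow load
  std::memcpy(&Hi, Mem + 4, 4);  // consecutive narrow load
  uint64_t Pair = (static_cast<uint64_t>(Hi) << 32) | Lo;  // BUILD_PAIR Lo, Hi
  uint64_t Wide;
  std::memcpy(&Wide, Mem, 8);    // one wide load instead
  assert(Pair == Wide);          // holds on little-endian
  return 0;
}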
- SDValue N0 = N->getOperand(0); - EVT SourceVT = N0.getValueType(); - if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits()) - return SDValue(); - - unsigned FPOpcode; - APInt SignMask; - switch (N0.getOpcode()) { - case ISD::AND: - FPOpcode = ISD::FABS; - SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits()); - break; - case ISD::XOR: - FPOpcode = ISD::FNEG; - SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits()); - break; - case ISD::OR: - FPOpcode = ISD::FABS; - SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits()); - break; - default: - return SDValue(); - } - - // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X - // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X - // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) -> - // fneg (fabs X) - SDValue LogicOp0 = N0.getOperand(0); - ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true); - if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask && - LogicOp0.getOpcode() == ISD::BITCAST && - LogicOp0.getOperand(0).getValueType() == VT) { - SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0)); - NumFPLogicOpsConv++; - if (N0.getOpcode() == ISD::OR) - return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp); - return FPOp; - } - - return SDValue(); -} - SDValue DAGCombiner::visitBITCAST(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (N0.isUndef()) - return DAG.getUNDEF(VT); - // If the input is a BUILD_VECTOR with all constant elements, fold this now. - // Only do this before legalize types, since we might create an illegal - // scalar type. Even if we knew we wouldn't create an illegal scalar type - // we can only do this before legalize ops, since the target maybe - // depending on the bitcast. + // Only do this before legalize, since afterward the target may be depending + // on the bitconvert. // First check to see if this is all constant. if (!LegalTypes && N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() && - VT.isVector() && cast<BuildVectorSDNode>(N0)->isConstant()) - return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(), - VT.getVectorElementType()); + VT.isVector()) { + bool isSimple = cast<BuildVectorSDNode>(N0)->isConstant(); + + EVT DestEltVT = N->getValueType(0).getVectorElementType(); + assert(!DestEltVT.isVector() && + "Element type of vector ValueType must not be vector!"); + if (isSimple) + return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(), DestEltVT); + } // If the input is a constant, let getNode fold it. if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) { @@ -10090,50 +7283,41 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() && TLI.isOperationLegal(ISD::ConstantFP, VT)) || (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() && - TLI.isOperationLegal(ISD::Constant, VT))) { - SDValue C = DAG.getBitcast(VT, N0); - if (C.getNode() != N) - return C; - } + TLI.isOperationLegal(ISD::Constant, VT))) + return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, N0); } // (conv (conv x, t1), t2) -> (conv x, t2) if (N0.getOpcode() == ISD::BITCAST) - return DAG.getBitcast(VT, N0.getOperand(0)); + return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, + N0.getOperand(0)); // fold (conv (load x)) -> (load (conv*)x) // If the resultant load doesn't need a higher alignment than the original! if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() && + // Do not change the width of a volatile load. 
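// The sign-mask table in foldBitcastedFPLogic above encodes standard IEEE
// sign-bit identities. A self-contained plain-C++ check (memcpy plays the
// role of the bitcast; assumes 32-bit IEEE float):
#include <cassert>
#include <cstdint>
#include <cstring>

static uint32_t BitsOf(float F) { uint32_t U; std::memcpy(&U, &F, 4); return U; }
static float FloatOf(uint32_t U) { float F; std::memcpy(&F, &U, 4); return F; }

int main() {
  float X = 1.5f;
  assert(FloatOf(BitsOf(X) ^ 0x80000000u) == -X);  // xor sign  -> fneg
  assert(FloatOf(BitsOf(-X) & 0x7FFFFFFFu) == X);  // and ~sign -> fabs
  assert(FloatOf(BitsOf(X) | 0x80000000u) == -X);  // or sign   -> fneg(fabs)
  return 0;
}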
+ !cast<LoadSDNode>(N0)->isVolatile() && // Do not remove the cast if the types differ in endian layout. TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) == TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) && - // If the load is volatile, we only want to change the load type if the - // resulting load is legal. Otherwise we might increase the number of - // memory accesses. We don't care if the original type was legal or not - // as we assume software couldn't rely on the number of accesses of an - // illegal type. - ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile()) || - TLI.isOperationLegal(ISD::LOAD, VT)) && + (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) && TLI.isLoadBitCastBeneficial(N0.getValueType(), VT)) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); + unsigned Align = DAG.getDataLayout().getABITypeAlignment( + VT.getTypeForEVT(*DAG.getContext())); unsigned OrigAlign = LN0->getAlignment(); - bool Fast = false; - if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT, - LN0->getAddressSpace(), OrigAlign, &Fast) && - Fast) { - SDValue Load = - DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(), - LN0->getPointerInfo(), OrigAlign, - LN0->getMemOperand()->getFlags(), LN0->getAAInfo()); + if (Align <= OrigAlign) { + SDValue Load = DAG.getLoad(VT, SDLoc(N), LN0->getChain(), + LN0->getBasePtr(), LN0->getPointerInfo(), + LN0->isVolatile(), LN0->isNonTemporal(), + LN0->isInvariant(), OrigAlign, + LN0->getAAInfo()); DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1)); return Load; } } - if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI)) - return V; - // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit) // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit)) // @@ -10150,14 +7334,15 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) && N0.getNode()->hasOneUse() && VT.isInteger() && !VT.isVector() && !N0.getValueType().isVector()) { - SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0)); + SDValue NewConv = DAG.getNode(ISD::BITCAST, SDLoc(N0), VT, + N0.getOperand(0)); AddToWorklist(NewConv.getNode()); SDLoc DL(N); if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) { assert(VT.getSizeInBits() == 128); SDValue SignBit = DAG.getConstant( - APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64); + APInt::getSignBit(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64); SDValue FlipBit; if (N0.getOpcode() == ISD::FNEG) { FlipBit = SignBit; @@ -10177,7 +7362,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { AddToWorklist(FlipBits.getNode()); return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits); } - APInt SignBit = APInt::getSignMask(VT.getSizeInBits()); + APInt SignBit = APInt::getSignBit(VT.getSizeInBits()); if (N0.getOpcode() == ISD::FNEG) return DAG.getNode(ISD::XOR, DL, VT, NewConv, DAG.getConstant(SignBit, DL, VT)); @@ -10200,10 +7385,11 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() && isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() && !VT.isVector()) { - unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits(); + unsigned OrigXWidth = N0.getOperand(1).getValueType().getSizeInBits(); EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth); if (isTypeLegal(IntXVT)) { - SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1)); + SDValue X = DAG.getNode(ISD::BITCAST, SDLoc(N0), + IntXVT, N0.getOperand(1)); AddToWorklist(X.getNode()); // 
If X has a different width than the result/lhs, sext it or truncate it. @@ -10225,10 +7411,12 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { } if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) { - APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2); - SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0)); + APInt SignBit = APInt::getSignBit(VT.getSizeInBits() / 2); + SDValue Cst = DAG.getNode(ISD::BITCAST, SDLoc(N0.getOperand(0)), VT, + N0.getOperand(0)); AddToWorklist(Cst.getNode()); - SDValue X = DAG.getBitcast(VT, N0.getOperand(1)); + SDValue X = DAG.getNode(ISD::BITCAST, SDLoc(N0.getOperand(1)), VT, + N0.getOperand(1)); AddToWorklist(X.getNode()); SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X); AddToWorklist(XorResult.getNode()); @@ -10246,12 +7434,13 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { AddToWorklist(FlipBits.getNode()); return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits); } - APInt SignBit = APInt::getSignMask(VT.getSizeInBits()); + APInt SignBit = APInt::getSignBit(VT.getSizeInBits()); X = DAG.getNode(ISD::AND, SDLoc(X), VT, X, DAG.getConstant(SignBit, SDLoc(X), VT)); AddToWorklist(X.getNode()); - SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0)); + SDValue Cst = DAG.getNode(ISD::BITCAST, SDLoc(N0), + VT, N0.getOperand(0)); Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT, Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT)); AddToWorklist(Cst.getNode()); @@ -10270,7 +7459,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // float vectors bitcast to integer vectors) into shuffles. // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1) if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() && - N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() && + N0->getOpcode() == ISD::VECTOR_SHUFFLE && VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() && !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) { ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0); @@ -10281,15 +7470,12 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(0).getValueType() == VT) return SDValue(Op.getOperand(0)); - if (Op.isUndef() || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || + if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) - return DAG.getBitcast(VT, Op); + return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, Op); return SDValue(); }; - // FIXME: If either input vector is bitcast, try to convert the shuffle to - // the result type of this bitcast. This would eliminate at least one - // bitcast. See the transform in InstCombine. SDValue SV0 = PeekThroughBitcast(N0->getOperand(0)); SDValue SV1 = PeekThroughBitcast(N0->getOperand(1)); if (!(SV0 && SV1)) @@ -10336,18 +7522,27 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { // If this is a conversion of N elements of one type to N elements of another // type, convert each element. This handles FP<->INT cases. if (SrcBitSize == DstBitSize) { + EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, + BV->getValueType(0).getVectorNumElements()); + + // Due to the FP element handling below calling this routine recursively, + // we can end up with a scalar-to-vector node here. 
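// The FCOPYSIGN bitcast handling earlier in visitBITCAST splices the sign
// bit of one value onto the magnitude of another with plain integer ops. A
// standalone plain-C++ rendering (assumes 64-bit IEEE double):
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

int main() {
  double Mag = 2.75, Sgn = -0.0;
  uint64_t M, S;
  std::memcpy(&M, &Mag, 8);
  std::memcpy(&S, &Sgn, 8);
  uint64_t R = (M & ~(1ull << 63)) | (S & (1ull << 63));  // magnitude | sign
  double Res;
  std::memcpy(&Res, &R, 8);
  assert(Res == std::copysign(Mag, Sgn));                 // -2.75
  return 0;
}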
+ if (BV->getOpcode() == ISD::SCALAR_TO_VECTOR) + return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(BV), VT, + DAG.getNode(ISD::BITCAST, SDLoc(BV), + DstEltVT, BV->getOperand(0))); + SmallVector<SDValue, 8> Ops; for (SDValue Op : BV->op_values()) { // If the vector element type is not legal, the BUILD_VECTOR operands // are promoted and implicitly truncated. Make that explicit here. if (Op.getValueType() != SrcEltVT) Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op); - Ops.push_back(DAG.getBitcast(DstEltVT, Op)); + Ops.push_back(DAG.getNode(ISD::BITCAST, SDLoc(BV), + DstEltVT, Op)); AddToWorklist(Ops.back().getNode()); } - EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, - BV->getValueType(0).getVectorNumElements()); - return DAG.getBuildVector(VT, SDLoc(BV), Ops); + return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(BV), VT, Ops); } // Otherwise, we're growing or shrinking the elements. To avoid having to @@ -10389,7 +7584,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { // Shift the previously computed bits over. NewBits <<= SrcBitSize; SDValue Op = BV->getOperand(i+ (isLE ? (NumInputsPerOutput-j-1) : j)); - if (Op.isUndef()) continue; + if (Op.getOpcode() == ISD::UNDEF) continue; EltIsUndef = false; NewBits |= cast<ConstantSDNode>(Op)->getAPIntValue(). @@ -10403,7 +7598,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { } EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size()); - return DAG.getBuildVector(VT, DL, Ops); + return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ops); } // Finally, this must be the case where we are shrinking elements: each input @@ -10414,7 +7609,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { SmallVector<SDValue, 8> Ops; for (const SDValue &Op : BV->op_values()) { - if (Op.isUndef()) { + if (Op.getOpcode() == ISD::UNDEF) { Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT)); continue; } @@ -10425,7 +7620,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { for (unsigned j = 0; j != NumOutputsPerInput; ++j) { APInt ThisVal = OpVal.trunc(DstBitSize); Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT)); - OpVal.lshrInPlace(DstBitSize); + OpVal = OpVal.lshr(DstBitSize); } // For big endian targets, swap the order of the pieces of each element. @@ -10433,12 +7628,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { std::reverse(Ops.end()-NumOutputsPerInput, Ops.end()); } - return DAG.getBuildVector(VT, DL, Ops); -} - -static bool isContractable(SDNode *N) { - SDNodeFlags F = N->getFlags(); - return F.hasAllowContract() || F.hasAllowReassociation(); + return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ops); } /// Try to perform FMA combining on a given FADD node. @@ -10449,202 +7639,173 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDLoc SL(N); const TargetOptions &Options = DAG.getTarget().Options; + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); // Floating-point multiply-add with intermediate rounding. bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = - TLI.isFMAFasterThanFMulAndFAdd(VT) && + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. 
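// "Intermediate rounding" above is the whole point of FMA formation: a fused
// multiply-add rounds once, a separate fmul+fadd rounds twice. A plain-C++
// example where the two differ (the volatile keeps the compiler from fusing
// the separate version on its own):
#include <cassert>
#include <cfloat>
#include <cmath>

int main() {
  double X = 1.0 + DBL_EPSILON;         // 1 + 2^-52
  double Y = 1.0 - DBL_EPSILON;         // 1 - 2^-52
  volatile double Product = X * Y;      // 1 - 2^-104 rounds up to 1.0
  double Separate = Product - 1.0;      // 0.0
  double Fused = std::fma(X, Y, -1.0);  // exact: -(2^-104)
  assert(Separate == 0.0 && Fused != 0.0);
  return 0;
}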
if (!HasFMAD && !HasFMA) return SDValue(); - SDNodeFlags Flags = N->getFlags(); - bool CanFuse = Options.UnsafeFPMath || isContractable(N); - bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast || - CanFuse || HasFMAD); - // If the addition is not contractable, do not combine. - if (!AllowFusionGlobally && !isContractable(N)) - return SDValue(); - - const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); - if (STI && STI->generateFMAsInMachineCombiner(OptLevel)) - return SDValue(); - // Always prefer FMAD to FMA for precision. unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); + bool LookThroughFPExt = TLI.isFPExtFree(VT); - // Is the node an FMUL and contractable either due to global flags or - // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) - return false; - return AllowFusionGlobally || isContractable(N.getNode()); - }; // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), // prefer to fold the multiply with fewer uses. - if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) { + if (Aggressive && N0.getOpcode() == ISD::FMUL && + N1.getOpcode() == ISD::FMUL) { if (N0.getNode()->use_size() > N1.getNode()->use_size()) std::swap(N0, N1); } // fold (fadd (fmul x, y), z) -> (fma x, y, z) - if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { + if (N0.getOpcode() == ISD::FMUL && + (Aggressive || N0->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(0), N0.getOperand(1), N1, Flags); + N0.getOperand(0), N0.getOperand(1), N1); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. - if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { + if (N1.getOpcode() == ISD::FMUL && + (Aggressive || N1->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, - N1.getOperand(0), N1.getOperand(1), N0, Flags); + N1.getOperand(0), N1.getOperand(1), N0); } // Look through FP_EXTEND nodes to do more combining. - - // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (isContractableFMUL(N00) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N00.getOperand(1)), N1, Flags); + if (AllowFusion && LookThroughFPExt) { + // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(1)), N1); } - } - // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) - // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { - SDValue N10 = N1.getOperand(0); - if (isContractableFMUL(N10) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N10.getOperand(1)), N0, Flags); + // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) + // Note: Commutes FADD operands. 
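// The fpext folds above cost nothing in accuracy: two float significands
// (24 bits each) multiply exactly within a double's 53, so fusing after the
// extension still rounds only once. A plain-C++ spot check:
#include <cassert>
#include <cmath>

int main() {
  float A = 1.1f, B = 2.3f;
  double Wide = static_cast<double>(A) * static_cast<double>(B);  // exact
  assert(Wide == std::fma(static_cast<double>(A),
                          static_cast<double>(B), 0.0));
  return 0;
}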
+ if (N1.getOpcode() == ISD::FP_EXTEND) { + SDValue N10 = N1.getOperand(0); + if (N10.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(1)), N0); } } // More folding opportunities when target permits. - if (Aggressive) { + if ((AllowFusion || HasFMAD) && Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) - if (CanFuse && - N0.getOpcode() == PreferredFusedOpcode && - N0.getOperand(2).getOpcode() == ISD::FMUL && - N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { + if (N0.getOpcode() == PreferredFusedOpcode && + N0.getOperand(2).getOpcode() == ISD::FMUL) { return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), - N1, Flags), Flags); + N1)); } // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) - if (CanFuse && - N1->getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FMUL && - N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { + if (N1->getOpcode() == PreferredFusedOpcode && + N1.getOperand(2).getOpcode() == ISD::FMUL) { return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), - N0, Flags), Flags); + N0)); } - - // fold (fadd (fma x, y, (fpext (fmul u, v))), z) - // -> (fma x, y, (fma (fpext u), (fpext v), z)) - auto FoldFAddFMAFPExtFMul = [&] ( - SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, - SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), - Z, Flags), Flags); - }; - if (N0.getOpcode() == PreferredFusedOpcode) { - SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { - SDValue N020 = N02.getOperand(0); - if (isContractableFMUL(N020) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) { - return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1), - N020.getOperand(0), N020.getOperand(1), - N1, Flags); + if (AllowFusion && LookThroughFPExt) { + // fold (fadd (fma x, y, (fpext (fmul u, v))), z) + // -> (fma x, y, (fma (fpext u), (fpext v), z)) + auto FoldFAddFMAFPExtFMul = [&] ( + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, U), + DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + Z)); + }; + if (N0.getOpcode() == PreferredFusedOpcode) { + SDValue N02 = N0.getOperand(2); + if (N02.getOpcode() == ISD::FP_EXTEND) { + SDValue N020 = N02.getOperand(0); + if (N020.getOpcode() == ISD::FMUL) + return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1), + N020.getOperand(0), N020.getOperand(1), + N1); } } - } - // fold (fadd (fpext (fma x, y, (fmul u, v))), z) - // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) - // FIXME: This turns two single-precision and one double-precision - // operation into two double-precision operations, which might not be - // interesting for all targets, especially GPUs. 
- auto FoldFAddFPExtFMAFMul = [&] ( - SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, - SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), - Z, Flags), Flags); - }; - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == PreferredFusedOpcode) { - SDValue N002 = N00.getOperand(2); - if (isContractableFMUL(N002) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { - return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1), - N002.getOperand(0), N002.getOperand(1), - N1, Flags); + // fold (fadd (fpext (fma x, y, (fmul u, v))), z) + // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. + auto FoldFAddFPExtFMAFMul = [&] ( + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, X), + DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, U), + DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + Z)); + }; + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == PreferredFusedOpcode) { + SDValue N002 = N00.getOperand(2); + if (N002.getOpcode() == ISD::FMUL) + return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1), + N002.getOperand(0), N002.getOperand(1), + N1); } } - } - // fold (fadd x, (fma y, z, (fpext (fmul u, v))) - // -> (fma y, z, (fma (fpext u), (fpext v), x)) - if (N1.getOpcode() == PreferredFusedOpcode) { - SDValue N12 = N1.getOperand(2); - if (N12.getOpcode() == ISD::FP_EXTEND) { - SDValue N120 = N12.getOperand(0); - if (isContractableFMUL(N120) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N120.getValueType())) { - return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1), - N120.getOperand(0), N120.getOperand(1), - N0, Flags); + // fold (fadd x, (fma y, z, (fpext (fmul u, v))) + // -> (fma y, z, (fma (fpext u), (fpext v), x)) + if (N1.getOpcode() == PreferredFusedOpcode) { + SDValue N12 = N1.getOperand(2); + if (N12.getOpcode() == ISD::FP_EXTEND) { + SDValue N120 = N12.getOperand(0); + if (N120.getOpcode() == ISD::FMUL) + return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1), + N120.getOperand(0), N120.getOperand(1), + N0); } } - } - // fold (fadd x, (fpext (fma y, z, (fmul u, v))) - // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x)) - // FIXME: This turns two single-precision and one double-precision - // operation into two double-precision operations, which might not be - // interesting for all targets, especially GPUs. 
- if (N1.getOpcode() == ISD::FP_EXTEND) { - SDValue N10 = N1.getOperand(0); - if (N10.getOpcode() == PreferredFusedOpcode) { - SDValue N102 = N10.getOperand(2); - if (isContractableFMUL(N102) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) { - return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1), - N102.getOperand(0), N102.getOperand(1), - N0, Flags); + // fold (fadd x, (fpext (fma y, z, (fmul u, v))) + // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x)) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. + if (N1.getOpcode() == ISD::FP_EXTEND) { + SDValue N10 = N1.getOperand(0); + if (N10.getOpcode() == PreferredFusedOpcode) { + SDValue N102 = N10.getOperand(2); + if (N102.getOpcode() == ISD::FMUL) + return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1), + N102.getOperand(0), N102.getOperand(1), + N0); } } } @@ -10661,169 +7822,149 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { SDLoc SL(N); const TargetOptions &Options = DAG.getTarget().Options; + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); + // Floating-point multiply-add with intermediate rounding. bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = - TLI.isFMAFasterThanFMulAndFAdd(VT) && + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) return SDValue(); - const SDNodeFlags Flags = N->getFlags(); - bool CanFuse = Options.UnsafeFPMath || isContractable(N); - bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast || - CanFuse || HasFMAD); - - // If the subtraction is not contractable, do not combine. - if (!AllowFusionGlobally && !isContractable(N)) - return SDValue(); - - const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); - if (STI && STI->generateFMAsInMachineCombiner(OptLevel)) - return SDValue(); - // Always prefer FMAD to FMA for precision. unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); - - // Is the node an FMUL and contractable either due to global flags or - // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) - return false; - return AllowFusionGlobally || isContractable(N.getNode()); - }; + bool LookThroughFPExt = TLI.isFPExtFree(VT); // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) - if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { + if (N0.getOpcode() == ISD::FMUL && + (Aggressive || N0->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(ISD::FNEG, SL, VT, N1), Flags); + DAG.getNode(ISD::FNEG, SL, VT, N1)); } // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x) // Note: Commutes FSUB operands. 
- if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { + if (N1.getOpcode() == ISD::FMUL && + (Aggressive || N1->hasOneUse())) return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), - N1.getOperand(1), N0, Flags); - } + N1.getOperand(1), N0); // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z)) - if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) && + if (N0.getOpcode() == ISD::FNEG && + N0.getOperand(0).getOpcode() == ISD::FMUL && (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) { SDValue N00 = N0.getOperand(0).getOperand(0); SDValue N01 = N0.getOperand(0).getOperand(1); return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N00), N01, - DAG.getNode(ISD::FNEG, SL, VT, N1), Flags); + DAG.getNode(ISD::FNEG, SL, VT, N1)); } // Look through FP_EXTEND nodes to do more combining. - - // fold (fsub (fpext (fmul x, y)), z) - // -> (fma (fpext x), (fpext y), (fneg z)) - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (isContractableFMUL(N00) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N00.getOperand(1)), - DAG.getNode(ISD::FNEG, SL, VT, N1), Flags); - } - } - - // fold (fsub x, (fpext (fmul y, z))) - // -> (fma (fneg (fpext y)), (fpext z), x) - // Note: Commutes FSUB operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { - SDValue N10 = N1.getOperand(0); - if (isContractableFMUL(N10) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N10.getOperand(0))), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N10.getOperand(1)), - N0, Flags); + if (AllowFusion && LookThroughFPExt) { + // fold (fsub (fpext (fmul x, y)), z) + // -> (fma (fpext x), (fpext y), (fneg z)) + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(1)), + DAG.getNode(ISD::FNEG, SL, VT, N1)); } - } - // fold (fsub (fpext (fneg (fmul, x, y))), z) - // -> (fneg (fma (fpext x), (fpext y), z)) - // Note: This could be removed with appropriate canonicalization of the - // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the - // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent - // from implementing the canonicalization in visitFSUB. - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FNEG) { - SDValue N000 = N00.getOperand(0); - if (isContractableFMUL(N000) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N000.getOperand(0)), + // fold (fsub x, (fpext (fmul y, z))) + // -> (fma (fneg (fpext y)), (fpext z), x) + // Note: Commutes FSUB operands. 
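// The FSUB patterns reuse the FADD machinery by folding the subtraction into
// a negated operand: x*y - z becomes fma(x, y, -z) and x - y*z becomes
// fma(-y, z, x). A plain-C++ check of the algebra, exact here because small
// integers incur no rounding:
#include <cassert>
#include <cmath>

int main() {
  double X = 3.0, Y = 5.0, Z = 7.0;
  assert(std::fma(X, Y, -Z) == X * Y - Z);  // fsub (fmul x, y), z
  assert(std::fma(-Y, Z, X) == X - Y * Z);  // fsub x, (fmul y, z)
  return 0;
}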
+ if (N1.getOpcode() == ISD::FP_EXTEND) { + SDValue N10 = N1.getOperand(0); + if (N10.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, - N000.getOperand(1)), - N1, Flags)); + N10.getOperand(0))), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(1)), + N0); + } + + // fold (fsub (fpext (fneg (fmul, x, y))), z) + // -> (fneg (fma (fpext x), (fpext y), z)) + // Note: This could be removed with appropriate canonicalization of the + // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the + // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent + // from implementing the canonicalization in visitFSUB. + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FNEG) { + SDValue N000 = N00.getOperand(0); + if (N000.getOpcode() == ISD::FMUL) { + return DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(1)), + N1)); + } } } - } - // fold (fsub (fneg (fpext (fmul, x, y))), z) - // -> (fneg (fma (fpext x)), (fpext y), z) - // Note: This could be removed with appropriate canonicalization of the - // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the - // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent - // from implementing the canonicalization in visitFSUB. - if (N0.getOpcode() == ISD::FNEG) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FP_EXTEND) { - SDValue N000 = N00.getOperand(0); - if (isContractableFMUL(N000) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N000.getValueType())) { - return DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N000.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N000.getOperand(1)), - N1, Flags)); + // fold (fsub (fneg (fpext (fmul, x, y))), z) + // -> (fneg (fma (fpext x)), (fpext y), z) + // Note: This could be removed with appropriate canonicalization of the + // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the + // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent + // from implementing the canonicalization in visitFSUB. + if (N0.getOpcode() == ISD::FNEG) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FP_EXTEND) { + SDValue N000 = N00.getOperand(0); + if (N000.getOpcode() == ISD::FMUL) { + return DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(1)), + N1)); + } } } + } // More folding opportunities when target permits. 
- if (Aggressive) { + if ((AllowFusion || HasFMAD) && Aggressive) { // fold (fsub (fma x, y, (fmul u, v)), z) // -> (fma x, y (fma u, v, (fneg z))) - if (CanFuse && N0.getOpcode() == PreferredFusedOpcode && - isContractableFMUL(N0.getOperand(2)) && N0->hasOneUse() && - N0.getOperand(2)->hasOneUse()) { + if (N0.getOpcode() == PreferredFusedOpcode && + N0.getOperand(2).getOpcode() == ISD::FMUL) { return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, - N1), Flags), Flags); + N1))); } // fold (fsub x, (fma y, z, (fmul u, v))) // -> (fma (fneg y), z, (fma (fneg u), v, x)) - if (CanFuse && N1.getOpcode() == PreferredFusedOpcode && - isContractableFMUL(N1.getOperand(2))) { + if (N1.getOpcode() == PreferredFusedOpcode && + N1.getOperand(2).getOpcode() == ISD::FMUL) { SDValue N20 = N1.getOperand(2).getOperand(0); SDValue N21 = N1.getOperand(2).getOperand(1); return DAG.getNode(PreferredFusedOpcode, SL, VT, @@ -10832,109 +7973,104 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N1.getOperand(1), DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N20), - N21, N0, Flags), Flags); - } + N21, N0)); + } + + if (AllowFusion && LookThroughFPExt) { + // fold (fsub (fma x, y, (fpext (fmul u, v))), z) + // -> (fma x, y (fma (fpext u), (fpext v), (fneg z))) + if (N0.getOpcode() == PreferredFusedOpcode) { + SDValue N02 = N0.getOperand(2); + if (N02.getOpcode() == ISD::FP_EXTEND) { + SDValue N020 = N02.getOperand(0); + if (N020.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + N0.getOperand(0), N0.getOperand(1), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N020.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N020.getOperand(1)), + DAG.getNode(ISD::FNEG, SL, VT, + N1))); + } + } - // fold (fsub (fma x, y, (fpext (fmul u, v))), z) - // -> (fma x, y (fma (fpext u), (fpext v), (fneg z))) - if (N0.getOpcode() == PreferredFusedOpcode) { - SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { - SDValue N020 = N02.getOperand(0); - if (isContractableFMUL(N020) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(0), N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N020.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N020.getOperand(1)), - DAG.getNode(ISD::FNEG, SL, VT, - N1), Flags), Flags); + // fold (fsub (fpext (fma x, y, (fmul u, v))), z) + // -> (fma (fpext x), (fpext y), + // (fma (fpext u), (fpext v), (fneg z))) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. 
+ if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == PreferredFusedOpcode) { + SDValue N002 = N00.getOperand(2); + if (N002.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(1)), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N002.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N002.getOperand(1)), + DAG.getNode(ISD::FNEG, SL, VT, + N1))); } } - } - // fold (fsub (fpext (fma x, y, (fmul u, v))), z) - // -> (fma (fpext x), (fpext y), - // (fma (fpext u), (fpext v), (fneg z))) - // FIXME: This turns two single-precision and one double-precision - // operation into two double-precision operations, which might not be - // interesting for all targets, especially GPUs. - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == PreferredFusedOpcode) { - SDValue N002 = N00.getOperand(2); - if (isContractableFMUL(N002) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { + // fold (fsub x, (fma y, z, (fpext (fmul u, v)))) + // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x)) + if (N1.getOpcode() == PreferredFusedOpcode && + N1.getOperand(2).getOpcode() == ISD::FP_EXTEND) { + SDValue N120 = N1.getOperand(2).getOperand(0); + if (N120.getOpcode() == ISD::FMUL) { + SDValue N1200 = N120.getOperand(0); + SDValue N1201 = N120.getOperand(1); return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N00.getOperand(1)), + DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), + N1.getOperand(1), DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N002.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N002.getOperand(1)), DAG.getNode(ISD::FNEG, SL, VT, - N1), Flags), Flags); + DAG.getNode(ISD::FP_EXTEND, SL, + VT, N1200)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N1201), + N0)); } } - } - - // fold (fsub x, (fma y, z, (fpext (fmul u, v)))) - // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x)) - if (N1.getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FP_EXTEND) { - SDValue N120 = N1.getOperand(2).getOperand(0); - if (isContractableFMUL(N120) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N120.getValueType())) { - SDValue N1200 = N120.getOperand(0); - SDValue N1201 = N120.getOperand(1); - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), - N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, - VT, N1200)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N1201), - N0, Flags), Flags); - } - } - // fold (fsub x, (fpext (fma y, z, (fmul u, v)))) - // -> (fma (fneg (fpext y)), (fpext z), - // (fma (fneg (fpext u)), (fpext v), x)) - // FIXME: This turns two single-precision and one double-precision - // operation into two double-precision operations, which might not be - // interesting for all targets, especially GPUs. 
- if (N1.getOpcode() == ISD::FP_EXTEND && + // fold (fsub x, (fpext (fma y, z, (fmul u, v)))) + // -> (fma (fneg (fpext y)), (fpext z), + // (fma (fneg (fpext u)), (fpext v), x)) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. + if (N1.getOpcode() == ISD::FP_EXTEND && N1.getOperand(0).getOpcode() == PreferredFusedOpcode) { - SDValue CvtSrc = N1.getOperand(0); - SDValue N100 = CvtSrc.getOperand(0); - SDValue N101 = CvtSrc.getOperand(1); - SDValue N102 = CvtSrc.getOperand(2); - if (isContractableFMUL(N102) && - TLI.isFPExtFoldable(PreferredFusedOpcode, VT, CvtSrc.getValueType())) { - SDValue N1020 = N102.getOperand(0); - SDValue N1021 = N102.getOperand(1); - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N100)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N101), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, - VT, N1020)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, - N1021), - N0, Flags), Flags); + SDValue N100 = N1.getOperand(0).getOperand(0); + SDValue N101 = N1.getOperand(0).getOperand(1); + SDValue N102 = N1.getOperand(0).getOperand(2); + if (N102.getOpcode() == ISD::FMUL) { + SDValue N1020 = N102.getOperand(0); + SDValue N1021 = N102.getOperand(1); + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N100)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, N101), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, + VT, N1020)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N1021), + N0)); + } } } } @@ -10942,36 +8078,27 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { return SDValue(); } -/// Try to perform FMA combining on a given FMUL node based on the distributive -/// law x * (y + 1) = x * y + x and variants thereof (commuted versions, -/// subtraction instead of addition). -SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { +/// Try to perform FMA combining on a given FMUL node. +SDValue DAGCombiner::visitFMULForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); - const SDNodeFlags Flags = N->getFlags(); assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation"); const TargetOptions &Options = DAG.getTarget().Options; + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); - // The transforms below are incorrect when x == 0 and y == inf, because the - // intermediate multiplication produces a nan. - if (!Options.NoInfsFPMath) - return SDValue(); + // Floating-point multiply-add with intermediate rounding. + bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = - (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) && - TLI.isFMAFasterThanFMulAndFAdd(VT) && + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); - // Floating-point multiply-add with intermediate rounding. This can result - // in a less precise result due to the changed rounding order. 
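For orientation in the visitFMUL combine setup below: ISD::FMA rounds once, while ISD::FMAD gives the same result as the separately rounded operations, so FMAD is bit-identical to the unfused fmul+fadd sequence. Schematically:

    fma(a, b, c)  = round(a*b + c)          // single rounding
    fmad(a, b, c) = round(round(a*b) + c)   // product rounded first

That is why plain contraction patterns may use FMAD whenever it is legal, whereas the distributive fmul folds below change the rounding in either form; one side of this hunk accordingly also gates HasFMAD on UnsafeFPMath for this combine.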
- bool HasFMAD = Options.UnsafeFPMath && - (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); - // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) return SDValue(); @@ -10980,58 +8107,54 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); - // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y) - // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y)) - auto FuseFADD = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) { + // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y) + // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y)) + auto FuseFADD = [&](SDValue X, SDValue Y) { if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) { - if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) { - if (C->isExactlyValue(+1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - Y, Flags); - if (C->isExactlyValue(-1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); - } + auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); + if (XC1 && XC1->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); + if (XC1 && XC1->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); } return SDValue(); }; - if (SDValue FMA = FuseFADD(N0, N1, Flags)) + if (SDValue FMA = FuseFADD(N0, N1)) return FMA; - if (SDValue FMA = FuseFADD(N1, N0, Flags)) + if (SDValue FMA = FuseFADD(N1, N0)) return FMA; - // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y) - // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y)) - // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y)) - // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y) - auto FuseFSUB = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) { + // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y) + // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y)) + // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y)) + // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y) + auto FuseFSUB = [&](SDValue X, SDValue Y) { if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) { - if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) { - if (C0->isExactlyValue(+1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, - Y, Flags); - if (C0->isExactlyValue(-1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); - } - if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) { - if (C1->isExactlyValue(+1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); - if (C1->isExactlyValue(-1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - Y, Flags); - } + auto XC0 = isConstOrConstSplatFP(X.getOperand(0)); + if (XC0 && XC0->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + Y); + if (XC0 && XC0->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + + auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); + if (XC1 && XC1->isExactlyValue(+1.0)) + return 
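The FuseFADD/FuseFSUB lambdas implement the distributive law x*(y ± 1) = x*y ± x for the constant-±1.0 operands listed in the comments. Written out:

    (fmul (fadd x, +1.0), y) = (x + 1)*y  = x*y + y   = (fma x, y, y)
    (fmul (fadd x, -1.0), y) = (x - 1)*y  = x*y - y   = (fma x, y, (fneg y))
    (fmul (fsub +1.0, x), y) = (1 - x)*y  = y - x*y   = (fma (fneg x), y, y)
    (fmul (fsub -1.0, x), y) = (-1 - x)*y = -x*y - y  = (fma (fneg x), y, (fneg y))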
DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + if (XC1 && XC1->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); } return SDValue(); }; - if (SDValue FMA = FuseFSUB(N0, N1, Flags)) + if (SDValue FMA = FuseFSUB(N0, N1)) return FMA; - if (SDValue FMA = FuseFSUB(N1, N0, Flags)) + if (SDValue FMA = FuseFSUB(N1, N0)) return FMA; return SDValue(); @@ -11045,7 +8168,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; - const SDNodeFlags Flags = N->getFlags(); + const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags; // fold vector ops if (VT.isVector()) @@ -11060,15 +8183,6 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { if (N0CFP && !N1CFP) return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags); - // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math) - ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true); - if (N1C && N1C->isZero()) - if (N1C->isNegative() || Options.UnsafeFPMath || Flags.hasNoSignedZeros()) - return N0; - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - // fold (fadd A, (fneg B)) -> (fsub A, B) if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && isNegatibleForFree(N1, LegalOperations, TLI, &Options) == 2) @@ -11081,53 +8195,32 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { return DAG.getNode(ISD::FSUB, DL, VT, N1, GetNegatedExpression(N0, DAG, LegalOperations), Flags); - auto isFMulNegTwo = [](SDValue FMul) { - if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL) - return false; - auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true); - return C && C->isExactlyValue(-2.0); - }; + // If 'unsafe math' is enabled, fold lots of things. + if (Options.UnsafeFPMath) { + // No FP constant should be created after legalization as Instruction + // Selection pass has a hard time dealing with FP constants. + bool AllowNewConst = (Level < AfterLegalizeDAG); - // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B) - if (isFMulNegTwo(N0)) { - SDValue B = N0.getOperand(0); - SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags); - return DAG.getNode(ISD::FSUB, DL, VT, N1, Add, Flags); - } - // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B) - if (isFMulNegTwo(N1)) { - SDValue B = N1.getOperand(0); - SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags); - return DAG.getNode(ISD::FSUB, DL, VT, N0, Add, Flags); - } + // fold (fadd A, 0) -> A + if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1)) + if (N1C->isZero()) + return N0; - // No FP constant should be created after legalization as Instruction - // Selection pass has a hard time dealing with FP constants. - bool AllowNewConst = (Level < AfterLegalizeDAG); + // fold (fadd (fadd x, c1), c2) -> (fadd x, (fadd c1, c2)) + if (N1CFP && N0.getOpcode() == ISD::FADD && N0.getNode()->hasOneUse() && + isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) + return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), + DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, + Flags), + Flags); - // If 'unsafe math' or nnan is enabled, fold lots of things. 
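The fmul-by--2.0 rewrite in visitFADD above is exact rather than fast-math: multiplying by -2.0 and doubling via (fadd B, B) are both exact scalings by a power of two, so

    (fadd (fmul B, -2.0), A) = A + (-2*B) = A - (B + B) = (fsub A, (fadd B, B))

holds bit-for-bit (overflow lands on the same signed infinity either way), which is why the fold is guarded only by the one-use check rather than by UnsafeFPMath.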
- if ((Options.UnsafeFPMath || Flags.hasNoNaNs()) && AllowNewConst) { // If allowed, fold (fadd (fneg x), x) -> 0.0 - if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1) + if (AllowNewConst && N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1) return DAG.getConstantFP(0.0, DL, VT); // If allowed, fold (fadd x, (fneg x)) -> 0.0 - if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0) + if (AllowNewConst && N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0) return DAG.getConstantFP(0.0, DL, VT); - } - - // If 'unsafe math' or reassoc and nsz, fold lots of things. - // TODO: break out portions of the transformations below for which Unsafe is - // considered and which do not require both nsz and reassoc - if ((Options.UnsafeFPMath || - (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) && - AllowNewConst) { - // fadd (fadd x, c1), c2 -> fadd x, c1 + c2 - if (N1CFP && N0.getOpcode() == ISD::FADD && - isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) { - SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, Flags); - return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC, Flags); - } // We can fold chains of FADD's of the same value into multiplications. // This transform is not safe in general because we are reducing the number @@ -11175,7 +8268,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } } - if (N0.getOpcode() == ISD::FADD) { + if (N0.getOpcode() == ISD::FADD && AllowNewConst) { bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0)); // (fadd (fadd x, x), x) -> (fmul x, 3.0) if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) && @@ -11185,7 +8278,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } } - if (N1.getOpcode() == ISD::FADD) { + if (N1.getOpcode() == ISD::FADD && AllowNewConst) { bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0)); // (fadd x, (fadd x, x)) -> (fmul x, 3.0) if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) && @@ -11196,7 +8289,8 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0) - if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD && + if (AllowNewConst && + N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD && N0.getOperand(0) == N0.getOperand(1) && N1.getOperand(0) == N1.getOperand(1) && N0.getOperand(0) == N1.getOperand(0)) { @@ -11211,18 +8305,19 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { AddToWorklist(Fused.getNode()); return Fused; } + return SDValue(); } SDValue DAGCombiner::visitFSUB(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true); - ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true); + ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); + ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); EVT VT = N->getValueType(0); - SDLoc DL(N); + SDLoc dl(N); const TargetOptions &Options = DAG.getTarget().Options; - const SDNodeFlags Flags = N->getFlags(); + const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags; // fold vector ops if (VT.isVector()) @@ -11231,51 +8326,44 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { // fold (fsub c1, c2) -> c1-c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FSUB, DL, VT, N0, N1, Flags); + return DAG.getNode(ISD::FSUB, dl, VT, N0, N1, Flags); - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; + // fold (fsub A, (fneg B)) -> (fadd A, B) + if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) + return DAG.getNode(ISD::FADD, dl, VT, N0, + 
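The constant-merging fold (fadd (fadd x, c1), c2) -> (fadd x, c1+c2) sits behind the unsafe/reassociation gates because FP addition is not associative. A minimal demonstration (illustrative, not part of the patch):

    #include <cstdio>

    int main() {
      double x = 1.0;
      double a = (x + 1e16) + -1e16; // x is absorbed: 1.0 + 1e16 rounds to 1e16
      double b = x + (1e16 + -1e16); // constants merged first
      std::printf("%g %g\n", a, b);  // prints: 0 1
    }

The (fadd (fadd x, x), x) -> (fmul x, 3.0) family is much tamer, but it still manufactures an FP constant, hence the AllowNewConst guard after legalization.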
GetNegatedExpression(N1, DAG, LegalOperations), Flags); - // (fsub A, 0) -> A - if (N1CFP && N1CFP->isZero()) { - if (!N1CFP->isNegative() || Options.UnsafeFPMath || - Flags.hasNoSignedZeros()) { + // If 'unsafe math' is enabled, fold lots of things. + if (Options.UnsafeFPMath) { + // (fsub A, 0) -> A + if (N1CFP && N1CFP->isZero()) return N0; - } - } - - if (N0 == N1) { - // (fsub x, x) -> 0.0 - if (Options.UnsafeFPMath || Flags.hasNoNaNs()) - return DAG.getConstantFP(0.0f, DL, VT); - } - // (fsub -0.0, N1) -> -N1 - if (N0CFP && N0CFP->isZero()) { - if (N0CFP->isNegative() || - (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) { + // (fsub 0, B) -> -B + if (N0CFP && N0CFP->isZero()) { if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) return GetNegatedExpression(N1, DAG, LegalOperations); if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT)) - return DAG.getNode(ISD::FNEG, DL, VT, N1, Flags); + return DAG.getNode(ISD::FNEG, dl, VT, N1); } - } - if ((Options.UnsafeFPMath || - (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) - && N1.getOpcode() == ISD::FADD) { - // X - (X + Y) -> -Y - if (N0 == N1->getOperand(0)) - return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1), Flags); - // X - (Y + X) -> -Y - if (N0 == N1->getOperand(1)) - return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0), Flags); - } + // (fsub x, x) -> 0.0 + if (N0 == N1) + return DAG.getConstantFP(0.0f, dl, VT); - // fold (fsub A, (fneg B)) -> (fadd A, B) - if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) - return DAG.getNode(ISD::FADD, DL, VT, N0, - GetNegatedExpression(N1, DAG, LegalOperations), Flags); + // (fsub x, (fadd x, y)) -> (fneg y) + // (fsub x, (fadd y, x)) -> (fneg y) + if (N1.getOpcode() == ISD::FADD) { + SDValue N10 = N1->getOperand(0); + SDValue N11 = N1->getOperand(1); + + if (N10 == N0 && isNegatibleForFree(N11, LegalOperations, TLI, &Options)) + return GetNegatedExpression(N11, DAG, LegalOperations); + + if (N11 == N0 && isNegatibleForFree(N10, LegalOperations, TLI, &Options)) + return GetNegatedExpression(N10, DAG, LegalOperations); + } + } // FSUB -> FMA combines: if (SDValue Fused = visitFSUBForFMACombine(N)) { @@ -11289,12 +8377,12 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { SDValue DAGCombiner::visitFMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true); - ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true); + ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); + ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; - const SDNodeFlags Flags = N->getFlags(); + const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags; // fold vector ops if (VT.isVector()) { @@ -11316,35 +8404,42 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { if (N1CFP && N1CFP->isExactlyValue(1.0)) return N0; - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - - if (Options.UnsafeFPMath || - (Flags.hasNoNaNs() && Flags.hasNoSignedZeros())) { + if (Options.UnsafeFPMath) { // fold (fmul A, 0) -> 0 if (N1CFP && N1CFP->isZero()) return N1; - } - if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) { - // fmul (fmul X, C1), C2 -> fmul X, C1 * C2 - if (isConstantFPBuildVectorOrConstantFP(N1) && - N0.getOpcode() == ISD::FMUL) { + // fold (fmul (fmul x, c1), c2) -> (fmul x, (fmul c1, c2)) + if (N0.getOpcode() == ISD::FMUL) { + // Fold scalars or any vector constants (not 
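The X - (X + Y) -> -Y rewrite above can flip the sign of a zero, which is why it needs the nsz/unsafe conditions. Illustrative check:

    #include <cmath>
    #include <cstdio>

    int main() {
      double x = 1.0, y = 0.0;
      double unfolded = x - (x + y); // +0.0
      double folded = -y;            // -0.0: the fold flips zero's sign
      std::printf("%d %d\n", std::signbit(unfolded), std::signbit(folded)); // 0 1
    }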
just splats). + // This fold is done in general by InstCombine, but extra fmul insts + // may have been generated during lowering. SDValue N00 = N0.getOperand(0); SDValue N01 = N0.getOperand(1); - // Avoid an infinite loop by making sure that N00 is not a constant - // (the inner multiply has not been constant folded yet). - if (isConstantFPBuildVectorOrConstantFP(N01) && - !isConstantFPBuildVectorOrConstantFP(N00)) { - SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags); - return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags); + auto *BV1 = dyn_cast<BuildVectorSDNode>(N1); + auto *BV00 = dyn_cast<BuildVectorSDNode>(N00); + auto *BV01 = dyn_cast<BuildVectorSDNode>(N01); + + // Check 1: Make sure that the first operand of the inner multiply is NOT + // a constant. Otherwise, we may induce infinite looping. + if (!(isConstOrConstSplatFP(N00) || (BV00 && BV00->isConstant()))) { + // Check 2: Make sure that the second operand of the inner multiply and + // the second operand of the outer multiply are constants. + if ((N1CFP && isConstOrConstSplatFP(N01)) || + (BV1 && BV01 && BV1->isConstant() && BV01->isConstant())) { + SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags); + } } } - // Match a special-case: we convert X * 2.0 into fadd. - // fmul (fadd X, X), C -> fmul X, 2.0 * C - if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() && - N0.getOperand(0) == N0.getOperand(1)) { + // fold (fmul (fadd x, x), c) -> (fmul x, (fmul 2.0, c)) + // Undo the fmul 2.0, x -> fadd x, x transformation, since if it occurs + // during an early run of DAGCombiner can prevent folding with fmuls + // inserted during lowering. + if (N0.getOpcode() == ISD::FADD && + (N0.getOperand(0) == N0.getOperand(1)) && + N0.hasOneUse()) { const SDValue Two = DAG.getConstantFP(2.0, DL, VT); SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1, Flags); return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts, Flags); @@ -11373,54 +8468,8 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { } } - // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X)) - // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X) - if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() && - (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) && - TLI.isOperationLegal(ISD::FABS, VT)) { - SDValue Select = N0, X = N1; - if (Select.getOpcode() != ISD::SELECT) - std::swap(Select, X); - - SDValue Cond = Select.getOperand(0); - auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(1)); - auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2)); - - if (TrueOpnd && FalseOpnd && - Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X && - isa<ConstantFPSDNode>(Cond.getOperand(1)) && - cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) { - ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); - switch (CC) { - default: break; - case ISD::SETOLT: - case ISD::SETULT: - case ISD::SETOLE: - case ISD::SETULE: - case ISD::SETLT: - case ISD::SETLE: - std::swap(TrueOpnd, FalseOpnd); - LLVM_FALLTHROUGH; - case ISD::SETOGT: - case ISD::SETUGT: - case ISD::SETOGE: - case ISD::SETUGE: - case ISD::SETGT: - case ISD::SETGE: - if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) && - TLI.isOperationLegal(ISD::FNEG, VT)) - return DAG.getNode(ISD::FNEG, DL, VT, - DAG.getNode(ISD::FABS, DL, VT, X)); - if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0)) - 
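The (fmul (fmul x, c1), c2) -> (fmul x, c1*c2) merge is likewise an unsafe reassociation: the inner product is rounded (and can overflow) before the outer constant is applied. For example:

    #include <cstdio>

    int main() {
      double x = 1e308;
      double a = (x * 10.0) * 0.1; // inner multiply overflows: inf
      double b = x * (10.0 * 0.1); // constants merged first: 1e308
      std::printf("%g %g\n", a, b); // prints: inf 1e+308
    }

The check that the inner multiply's first operand is not itself a constant is what stops the combine from re-folding its own output forever.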
return DAG.getNode(ISD::FABS, DL, VT, X); - - break; - } - } - } - // FMUL -> FMA combines: - if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) { + if (SDValue Fused = visitFMULForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -11435,21 +8484,17 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0); ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1); EVT VT = N->getValueType(0); - SDLoc DL(N); + SDLoc dl(N); const TargetOptions &Options = DAG.getTarget().Options; - // FMA nodes have flags that propagate to the created nodes. - const SDNodeFlags Flags = N->getFlags(); - bool UnsafeFPMath = Options.UnsafeFPMath || isContractable(N); - // Constant fold FMA. if (isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1) && isa<ConstantFPSDNode>(N2)) { - return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2); + return DAG.getNode(ISD::FMA, dl, VT, N0, N1, N2); } - if (UnsafeFPMath) { + if (Options.UnsafeFPMath) { if (N0CFP && N0CFP->isZero()) return N2; if (N1CFP && N1CFP->isZero()) @@ -11466,24 +8511,29 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2); - if (UnsafeFPMath) { + // TODO: FMA nodes should have flags that propagate to the created nodes. + // For now, create a Flags object for use with all unsafe math transforms. + SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); + + if (Options.UnsafeFPMath) { // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) && isConstantFPBuildVectorOrConstantFP(N1) && isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) { - return DAG.getNode(ISD::FMUL, DL, VT, N0, - DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1), - Flags), Flags); + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1), + &Flags), &Flags); } // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) if (N0.getOpcode() == ISD::FMUL && isConstantFPBuildVectorOrConstantFP(N1) && isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) { - return DAG.getNode(ISD::FMA, DL, VT, + return DAG.getNode(ISD::FMA, dl, VT, N0.getOperand(0), - DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1), - Flags), + DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1), + &Flags), N2); } } @@ -11493,40 +8543,32 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { if (N1CFP) { if (N1CFP->isExactlyValue(1.0)) // TODO: The FMA node should have flags that propagate to this node. - return DAG.getNode(ISD::FADD, DL, VT, N0, N2); + return DAG.getNode(ISD::FADD, dl, VT, N0, N2); if (N1CFP->isExactlyValue(-1.0) && (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) { - SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0); + SDValue RHSNeg = DAG.getNode(ISD::FNEG, dl, VT, N0); AddToWorklist(RHSNeg.getNode()); // TODO: The FMA node should have flags that propagate to this node. 
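The select-based fold above maps x * (sign test on x) onto ±fabs(x). It requires nnan and nsz because the two shapes disagree exactly at zero and NaN:

    #include <cmath>
    #include <cstdio>

    int main() {
      double x = 0.0;
      double sel = x * ((x > 0.0) ? -1.0 : 1.0); // +0.0
      double neg = -std::fabs(x);                // -0.0
      std::printf("%d %d\n", std::signbit(sel), std::signbit(neg)); // 0 1
    }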
- return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg); - } - - // fma (fneg x), K, y -> fma x -K, y - if (N0.getOpcode() == ISD::FNEG && - (TLI.isOperationLegal(ISD::ConstantFP, VT) || - (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT)))) { - return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0), - DAG.getNode(ISD::FNEG, DL, VT, N1, Flags), N2); + return DAG.getNode(ISD::FADD, dl, VT, N2, RHSNeg); } } - if (UnsafeFPMath) { + if (Options.UnsafeFPMath) { // (fma x, c, x) -> (fmul x, (c+1)) if (N1CFP && N0 == N2) { - return DAG.getNode(ISD::FMUL, DL, VT, N0, - DAG.getNode(ISD::FADD, DL, VT, N1, - DAG.getConstantFP(1.0, DL, VT), Flags), - Flags); + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(1.0, dl, VT), + &Flags), &Flags); } // (fma x, c, (fneg x)) -> (fmul x, (c-1)) if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) { - return DAG.getNode(ISD::FMUL, DL, VT, N0, - DAG.getNode(ISD::FADD, DL, VT, N1, - DAG.getConstantFP(-1.0, DL, VT), Flags), - Flags); + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(-1.0, dl, VT), + &Flags), &Flags); } } @@ -11536,14 +8578,14 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { // Combine multiple FDIVs with the same divisor into multiple FMULs by the // reciprocal. // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip) -// Notice that this is not always beneficial. One reason is different targets +// Notice that this is not always beneficial. One reason is different target // may have different costs for FDIV and FMUL, so sometimes the cost of two // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL". SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) { bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath; - const SDNodeFlags Flags = N->getFlags(); - if (!UnsafeMath && !Flags.hasAllowReciprocal()) + const SDNodeFlags *Flags = N->getFlags(); + if (!UnsafeMath && !Flags->hasAllowReciprocal()) return SDValue(); // Skip if current node is a reciprocal. @@ -11566,7 +8608,7 @@ SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) { if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) { // This division is eligible for optimization only if global unsafe math // is enabled or if this division allows reciprocal formation. - if (UnsafeMath || U->getFlags().hasAllowReciprocal()) + if (UnsafeMath || U->getFlags()->hasAllowReciprocal()) Users.insert(U); } } @@ -11605,7 +8647,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; - SDNodeFlags Flags = N->getFlags(); + SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags; // fold vector ops if (VT.isVector()) @@ -11616,14 +8658,11 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (N0CFP && N1CFP) return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags); - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; - - if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) { + if (Options.UnsafeFPMath) { // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable. if (N1CFP) { // Compute the reciprocal 1.0 / c2. 
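combineRepeatedFPDivisors trades N divides by the same value for one divide plus N multiplies. At the source level the rewrite looks like this (illustrative):

    #include <cstdio>

    // Before: two FDIVs with the same divisor D.
    double f(double a, double b, double D) { return a / D + b / D; }

    // After (requires unsafe math or the reciprocal flag on each division):
    double g(double a, double b, double D) {
      double recip = 1.0 / D;       // one FDIV
      return a * recip + b * recip; // two FMULs
    }

    int main() { std::printf("%g %g\n", f(1, 2, 4), g(1, 2, 4)); } // 0.75 0.75

As the comment notes, this only pays off when enough users share the divisor to beat the target's FDIV/FMUL cost ratio.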
- const APFloat &N1APF = N1CFP->getValueAPF(); + APFloat N1APF = N1CFP->getValueAPF(); APFloat Recip(N1APF.getSemantics(), 1); // 1.0 APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven); // Only do the transform if the reciprocal is a legal fp immediate that @@ -11632,8 +8671,8 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { (!LegalOperations || // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM // backend)... we should handle this gracefully after Legalize. - // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) || - TLI.isOperationLegal(ISD::ConstantFP, VT) || + // TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT) || + TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) || TLI.isFPImmLegal(Recip, VT))) return DAG.getNode(ISD::FMUL, DL, VT, N0, DAG.getConstantFP(Recip, DL, VT), Flags); @@ -11642,12 +8681,12 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { // If this FDIV is part of a reciprocal square root, it may be folded // into a target-specific square root estimate instruction. if (N1.getOpcode() == ISD::FSQRT) { - if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags)) { + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0), Flags)) { return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } else if (N1.getOpcode() == ISD::FP_EXTEND && N1.getOperand(0).getOpcode() == ISD::FSQRT) { - if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0), + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) { RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV); AddToWorklist(RV.getNode()); @@ -11655,7 +8694,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { } } else if (N1.getOpcode() == ISD::FP_ROUND && N1.getOperand(0).getOpcode() == ISD::FSQRT) { - if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0), + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) { RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1)); AddToWorklist(RV.getNode()); @@ -11676,7 +8715,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (SqrtOp.getNode()) { // We found a FSQRT, so try to make this fold: // x / (y * sqrt(z)) -> x * (rsqrt(z) / y) - if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) { + if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) { RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags); AddToWorklist(RV.getNode()); return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); @@ -11719,26 +8758,41 @@ SDValue DAGCombiner::visitFREM(SDNode *N) { // fold (frem c1, c2) -> fmod(c1,c2) if (N0CFP && N1CFP) - return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, N->getFlags()); - - if (SDValue NewSel = foldBinOpIntoSelect(N)) - return NewSel; + return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, + &cast<BinaryWithFlagsSDNode>(N)->Flags); return SDValue(); } SDValue DAGCombiner::visitFSQRT(SDNode *N) { - SDNodeFlags Flags = N->getFlags(); - if (!DAG.getTarget().Options.UnsafeFPMath && - !Flags.hasApproximateFuncs()) + if (!DAG.getTarget().Options.UnsafeFPMath || TLI.isFsqrtCheap()) return SDValue(); - SDValue N0 = N->getOperand(0); - if (TLI.isFsqrtCheap(N0, DAG)) + // TODO: FSQRT nodes should have flags that propagate to the created nodes. + // For now, create a Flags object for use with all unsafe math transforms. 
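The fdiv-by-constant fold turns x / c into x * (1/c). For a power-of-two divisor the reciprocal is exact and the rewrite is value-preserving:

    x / 2.0 == x * 0.5   // 0.5 is exact; always the same result

For most other constants the reciprocal rounds (APFloat reports opInexact), so x * (1.0/3.0) may differ from x / 3.0 in the last ulp; that, plus the denormal and legal-immediate checks visible above, is why the whole block sits under the unsafe-math gate.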
+ SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); + + // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5) + SDValue RV = BuildRsqrtEstimate(N->getOperand(0), &Flags); + if (!RV) return SDValue(); - // FSQRT nodes have flags that propagate to the created nodes. - return buildSqrtEstimate(N0, Flags); + EVT VT = RV.getValueType(); + SDLoc DL(N); + RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV, &Flags); + AddToWorklist(RV.getNode()); + + // Unfortunately, RV is now NaN if the input was exactly 0. + // Select out this case and force the answer to 0. + SDValue Zero = DAG.getConstantFP(0.0, DL, VT); + EVT CCVT = getSetCCResultType(VT); + SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, N->getOperand(0), Zero, ISD::SETEQ); + AddToWorklist(ZeroCmp.getNode()); + AddToWorklist(RV.getNode()); + + return DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, + ZeroCmp, Zero, RV); } /// copysign(x, fp_extend(y)) -> copysign(x, y) @@ -11752,7 +8806,7 @@ static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) { // value in one SSE register, but instruction selection cannot handle // FCOPYSIGN on SSE registers yet. EVT N1VT = N1->getValueType(0); - EVT N1Op0VT = N1->getOperand(0).getValueType(); + EVT N1Op0VT = N1->getOperand(0)->getValueType(0); return (N1VT == N1Op0VT || N1Op0VT != MVT::f128); } return false; @@ -11761,15 +8815,15 @@ static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) { SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0); - bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1); + ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0); + ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1); EVT VT = N->getValueType(0); - if (N0CFP && N1CFP) // Constant fold + if (N0CFP && N1CFP) // Constant fold return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1); - if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) { - const APFloat &V = N1C->getValueAPF(); + if (N1CFP) { + const APFloat& V = N1CFP->getValueAPF(); // copysign(x, c1) -> fabs(x) iff ispos(c1) // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1) if (!V.isNegative()) { @@ -11787,7 +8841,8 @@ SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { // copysign(copysign(x,z), y) -> copysign(x, y) if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN) - return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1); + return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, + N0.getOperand(0), N1); // copysign(x, abs(y)) -> abs(x) if (N1.getOpcode() == ISD::FABS) @@ -11795,113 +8850,14 @@ SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { // copysign(x, copysign(y,z)) -> copysign(x, z) if (N1.getOpcode() == ISD::FCOPYSIGN) - return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1)); + return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, + N0, N1.getOperand(1)); // copysign(x, fp_extend(y)) -> copysign(x, y) // copysign(x, fp_round(y)) -> copysign(x, y) if (CanCombineFCOPYSIGN_EXTEND_ROUND(N)) - return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0)); - - return SDValue(); -} - -SDValue DAGCombiner::visitFPOW(SDNode *N) { - ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1)); - if (!ExponentC) - return SDValue(); - - // Try to convert x ** (1/3) into cube root. - // TODO: Handle the various flavors of long double. - // TODO: Since we're approximating, we don't need an exact 1/3 exponent. 
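The select against zero exists because the estimate expansion sqrt(x) = x * rsqrt(x) breaks down at x == 0, where it multiplies zero by infinity:

    #include <cmath>
    #include <cstdio>

    int main() {
      double x = 0.0;
      double rsqrt = 1.0 / std::sqrt(x); // +inf
      double est = x * rsqrt;            // 0 * inf = NaN
      std::printf("%g %g\n", est, std::sqrt(x)); // nan 0
    }

Hence the SETEQ-against-zero compare and the (V)SELECT that forces the answer back to 0.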
- // Some range near 1/3 should be fine. - EVT VT = N->getValueType(0); - if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) || - (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) { - // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0. - // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf. - // pow(-val, 1/3) = nan; cbrt(-val) = -num. - // For regular numbers, rounding may cause the results to differ. - // Therefore, we require { nsz ninf nnan afn } for this transform. - // TODO: We could select out the special cases if we don't have nsz/ninf. - SDNodeFlags Flags = N->getFlags(); - if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() || - !Flags.hasApproximateFuncs()) - return SDValue(); - - // Do not create a cbrt() libcall if the target does not have it, and do not - // turn a pow that has lowering support into a cbrt() libcall. - if (!DAG.getLibInfo().has(LibFunc_cbrt) || - (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) && - DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT))) - return SDValue(); - - return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0), Flags); - } - - // Try to convert x ** (1/4) into square roots. - // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case. - // TODO: This could be extended (using a target hook) to handle smaller - // power-of-2 fractional exponents. - if (ExponentC->getValueAPF().isExactlyValue(0.25)) { - // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0. - // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN. - // For regular numbers, rounding may cause the results to differ. - // Therefore, we require { nsz ninf afn } for this transform. - // TODO: We could select out the special cases if we don't have nsz/ninf. - SDNodeFlags Flags = N->getFlags(); - if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || - !Flags.hasApproximateFuncs()) - return SDValue(); - - // Don't double the number of libcalls. We are trying to inline fast code. - if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT)) - return SDValue(); - - // Assume that libcalls are the smallest code. - // TODO: This restriction should probably be lifted for vectors. - if (DAG.getMachineFunction().getFunction().optForSize()) - return SDValue(); - - // pow(X, 0.25) --> sqrt(sqrt(X)) - SDLoc DL(N); - SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0), Flags); - return DAG.getNode(ISD::FSQRT, DL, VT, Sqrt, Flags); - } - - return SDValue(); -} - -static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG, - const TargetLowering &TLI) { - // This optimization is guarded by a function attribute because it may produce - // unexpected results. Ie, programs may be relying on the platform-specific - // undefined behavior when the float-to-int conversion overflows. - const Function &F = DAG.getMachineFunction().getFunction(); - Attribute StrictOverflow = F.getFnAttribute("strict-float-cast-overflow"); - if (StrictOverflow.getValueAsString().equals("false")) - return SDValue(); - - // We only do this if the target has legal ftrunc. Otherwise, we'd likely be - // replacing casts with a libcall. We also must be allowed to ignore -0.0 - // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer - // conversions would return +0.0. - // FIXME: We should be able to use node-level FMF here. - // TODO: If strict math, should we use FABS (+ range check for signed cast)? 
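The pow(x, 0.25) rewrite and the guard flags above can be checked directly: sqrt(sqrt(x)) agrees with pow for positive x but not at the special values called out in the comments (illustrative):

    #include <cmath>
    #include <cstdio>

    int main() {
      std::printf("%g %g\n", std::pow(16.0, 0.25),
                  std::sqrt(std::sqrt(16.0)));   // 2 2
      std::printf("%g %g\n", std::pow(-0.0, 0.25),
                  std::sqrt(std::sqrt(-0.0)));   // 0 -0: needs nsz
    }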
- EVT VT = N->getValueType(0); - if (!TLI.isOperationLegal(ISD::FTRUNC, VT) || - !DAG.getTarget().Options.NoSignedZerosFPMath) - return SDValue(); - - // fptosi/fptoui round towards zero, so converting from FP to integer and - // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X - SDValue N0 = N->getOperand(0); - if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT && - N0.getOperand(0).getValueType() == VT) - return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0)); - - if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT && - N0.getOperand(0).getValueType() == VT) - return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0)); + return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, + N0, N1.getOperand(0)); return SDValue(); } @@ -11912,16 +8868,16 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { EVT OpVT = N0.getValueType(); // fold (sint_to_fp c1) -> c1fp - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && + if (isConstantIntBuildVectorOrConstantInt(N0) && // ...but only if the target supports immediate floating-point values (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) + TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0); // If the input is a legal type, and SINT_TO_FP is not legal on this target, // but UINT_TO_FP is legal on this target, try to convert. - if (!hasOperation(ISD::SINT_TO_FP, OpVT) && - hasOperation(ISD::UINT_TO_FP, OpVT)) { + if (!TLI.isOperationLegalOrCustom(ISD::SINT_TO_FP, OpVT) && + TLI.isOperationLegalOrCustom(ISD::UINT_TO_FP, OpVT)) { // If the sign bit is known to be zero, we can change this to UINT_TO_FP. if (DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0); @@ -11933,7 +8889,7 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 && !VT.isVector() && (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { + TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) { SDLoc DL(N); SDValue Ops[] = { N0.getOperand(0), N0.getOperand(1), @@ -11947,7 +8903,7 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.getOperand(0).getOpcode() == ISD::SETCC &&!VT.isVector() && (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { + TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) { SDLoc DL(N); SDValue Ops[] = { N0.getOperand(0).getOperand(0), N0.getOperand(0).getOperand(1), @@ -11957,9 +8913,6 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { } } - if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI)) - return FTrunc; - return SDValue(); } @@ -11969,16 +8922,16 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { EVT OpVT = N0.getValueType(); // fold (uint_to_fp c1) -> c1fp - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && + if (isConstantIntBuildVectorOrConstantInt(N0) && // ...but only if the target supports immediate floating-point values (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) + TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0); // If the input is a legal type, and UINT_TO_FP is not legal on this target, // but SINT_TO_FP is legal on this target, try to convert. 
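foldFPToIntToFP relies on fptosi/fptoui rounding toward zero, so the integer round-trip is exactly ftrunc whenever the conversion does not overflow; overflow is the platform-specific undefined behavior the strict-float-cast-overflow attribute guards against. Illustrative:

    #include <cmath>
    #include <cstdio>

    int main() {
      double x = -2.75;
      double roundtrip = (double)(long long)x; // fptosi then sitofp
      std::printf("%g %g\n", roundtrip, std::trunc(x)); // -2 -2
    }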
- if (!hasOperation(ISD::UINT_TO_FP, OpVT) && - hasOperation(ISD::SINT_TO_FP, OpVT)) { + if (!TLI.isOperationLegalOrCustom(ISD::UINT_TO_FP, OpVT) && + TLI.isOperationLegalOrCustom(ISD::SINT_TO_FP, OpVT)) { // If the sign bit is known to be zero, we can change this to SINT_TO_FP. if (DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0); @@ -11987,9 +8940,10 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { // The next optimizations are desirable only if SELECT_CC can be lowered. if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) { // fold (uint_to_fp (setcc x, y, cc)) -> (select_cc x, y, -1.0, 0.0,, cc) + if (N0.getOpcode() == ISD::SETCC && !VT.isVector() && (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { + TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) { SDLoc DL(N); SDValue Ops[] = { N0.getOperand(0), N0.getOperand(1), @@ -11999,9 +8953,6 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { } } - if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI)) - return FTrunc; - return SDValue(); } @@ -12042,7 +8993,9 @@ static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) { } if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits()) return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src); - return DAG.getBitcast(VT, Src); + if (SrcVT == VT) + return Src; + return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, Src); } return SDValue(); } @@ -12086,18 +9039,7 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) { // fold (fp_round (fp_round x)) -> (fp_round x) if (N0.getOpcode() == ISD::FP_ROUND) { const bool NIsTrunc = N->getConstantOperandVal(1) == 1; - const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1; - - // Skip this folding if it results in an fp_round from f80 to f16. - // - // f80 to f16 always generates an expensive (and as yet, unimplemented) - // libcall to __truncxfhf2 instead of selecting native f16 conversion - // instructions from f32 or f64. Moreover, the first (value-preserving) - // fp_round from f80 to either f32 or f64 may become a NOP in platforms like - // x86. - if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16) - return SDValue(); - + const bool N0IsTrunc = N0.getNode()->getConstantOperandVal(1) == 1; // If the first fp_round isn't a value preserving truncation, it might // introduce a tie in the second fp_round, that wouldn't occur in the // single-step fp_round we want to fold to. @@ -12119,9 +9061,6 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) { Tmp, N0.getOperand(1)); } - if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) - return NewVSel; - return SDValue(); } @@ -12162,7 +9101,7 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) { // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the // value of X. if (N0.getOpcode() == ISD::FP_ROUND - && N0.getConstantOperandVal(1) == 1) { + && N0.getNode()->getConstantOperandVal(1) == 1) { SDValue In = N0.getOperand(0); if (In.getValueType() == VT) return In; if (VT.bitsLT(In.getValueType())) @@ -12188,9 +9127,6 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) { return SDValue(N, 0); // Return N so it doesn't get rechecked! 
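A decimal analogy for the double-rounding tie the fp_round comment warns about: rounding 1.49 to one digit gives 1.5, and rounding that to an integer gives 2, while rounding 1.49 to an integer directly gives 1. A first rounding that is not value-preserving can create a tie the single-step round would never see, which is why the fold keys off the trunc flag of the inner fp_round.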
} - if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) - return NewVSel; - return SDValue(); } @@ -12213,19 +9149,6 @@ SDValue DAGCombiner::visitFTRUNC(SDNode *N) { if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0); - // fold ftrunc (known rounded int x) -> x - // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is - // likely to be generated to extract integer from a rounded floating value. - switch (N0.getOpcode()) { - default: break; - case ISD::FRINT: - case ISD::FTRUNC: - case ISD::FNEARBYINT: - case ISD::FFLOOR: - case ISD::FCEIL: - return N0; - } - return SDValue(); } @@ -12265,17 +9188,17 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { if (N0.getValueType().isVector()) { // For a vector, get a mask such as 0x80... per scalar element // and splat it. - SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits()); + SignMask = APInt::getSignBit(N0.getValueType().getScalarSizeInBits()); SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask); } else { // For a scalar, just generate 0x80... - SignMask = APInt::getSignMask(IntVT.getSizeInBits()); + SignMask = APInt::getSignBit(IntVT.getSizeInBits()); } SDLoc DL0(N0); Int = DAG.getNode(ISD::XOR, DL0, IntVT, Int, DAG.getConstant(SignMask, DL0, IntVT)); AddToWorklist(Int.getNode()); - return DAG.getBitcast(VT, Int); + return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, Int); } } @@ -12289,18 +9212,17 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { if (Level >= AfterLegalizeDAG && (TLI.isFPImmLegal(CVal, VT) || TLI.isOperationLegal(ISD::ConstantFP, VT))) - return DAG.getNode( - ISD::FMUL, SDLoc(N), VT, N0.getOperand(0), - DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0.getOperand(1)), - N0->getFlags()); + return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0.getOperand(0), + DAG.getNode(ISD::FNEG, SDLoc(N), VT, + N0.getOperand(1)), + &cast<BinaryWithFlagsSDNode>(N0)->Flags); } } return SDValue(); } -static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N, - APFloat (*Op)(const APFloat &, const APFloat &)) { +SDValue DAGCombiner::visitFMINNUM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); @@ -12310,31 +9232,36 @@ static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N, if (N0CFP && N1CFP) { const APFloat &C0 = N0CFP->getValueAPF(); const APFloat &C1 = N1CFP->getValueAPF(); - return DAG.getConstantFP(Op(C0, C1), SDLoc(N), VT); + return DAG.getConstantFP(minnum(C0, C1), SDLoc(N), VT); } // Canonicalize to constant on RHS. if (isConstantFPBuildVectorOrConstantFP(N0) && - !isConstantFPBuildVectorOrConstantFP(N1)) - return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0); + !isConstantFPBuildVectorOrConstantFP(N1)) + return DAG.getNode(ISD::FMINNUM, SDLoc(N), VT, N1, N0); return SDValue(); } -SDValue DAGCombiner::visitFMINNUM(SDNode *N) { - return visitFMinMax(DAG, N, minnum); -} - SDValue DAGCombiner::visitFMAXNUM(SDNode *N) { - return visitFMinMax(DAG, N, maxnum); -} + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); + const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); -SDValue DAGCombiner::visitFMINIMUM(SDNode *N) { - return visitFMinMax(DAG, N, minimum); -} + if (N0CFP && N1CFP) { + const APFloat &C0 = N0CFP->getValueAPF(); + const APFloat &C1 = N1CFP->getValueAPF(); + return DAG.getConstantFP(maxnum(C0, C1), SDLoc(N), VT); + } + + // Canonicalize to constant on RHS. 
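The fneg lowering above is the classic sign-bit trick: bitcast to an integer, XOR in the sign mask, bitcast back. In C++ terms (illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      float x = 3.5f;
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof bits); // the bitcast
      bits ^= 0x80000000u;                 // XOR with the f32 sign mask
      float negated;
      std::memcpy(&negated, &bits, sizeof negated);
      std::printf("%g\n", negated); // -3.5
    }

For vectors the mask is splatted per element, exactly as the APInt::getSplat call does.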
+ if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) + return DAG.getNode(ISD::FMAXNUM, SDLoc(N), VT, N1, N0); -SDValue DAGCombiner::visitFMAXIMUM(SDNode *N) { - return visitFMinMax(DAG, N, maximum); + return SDValue(); } SDValue DAGCombiner::visitFABS(SDNode *N) { @@ -12354,8 +9281,11 @@ SDValue DAGCombiner::visitFABS(SDNode *N) { if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN) return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0)); - // fabs(bitcast(x)) -> bitcast(x & ~sign) to avoid constant pool loads. - if (!TLI.isFAbsFree(VT) && N0.getOpcode() == ISD::BITCAST && N0.hasOneUse()) { + // Transform fabs(bitconvert(x)) -> bitconvert(x & ~sign) to avoid loading + // constant pool values. + if (!TLI.isFAbsFree(VT) && + N0.getOpcode() == ISD::BITCAST && + N0.getNode()->hasOneUse()) { SDValue Int = N0.getOperand(0); EVT IntVT = Int.getValueType(); if (IntVT.isInteger() && !IntVT.isVector()) { @@ -12363,17 +9293,17 @@ SDValue DAGCombiner::visitFABS(SDNode *N) { if (N0.getValueType().isVector()) { // For a vector, get a mask such as 0x7f... per scalar element // and splat it. - SignMask = ~APInt::getSignMask(N0.getScalarValueSizeInBits()); + SignMask = ~APInt::getSignBit(N0.getValueType().getScalarSizeInBits()); SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask); } else { // For a scalar, just generate 0x7f... - SignMask = ~APInt::getSignMask(IntVT.getSizeInBits()); + SignMask = ~APInt::getSignBit(IntVT.getSizeInBits()); } SDLoc DL(N0); Int = DAG.getNode(ISD::AND, DL, IntVT, Int, DAG.getConstant(SignMask, DL, IntVT)); AddToWorklist(Int.getNode()); - return DAG.getBitcast(N->getValueType(0), Int); + return DAG.getNode(ISD::BITCAST, SDLoc(N), N->getValueType(0), Int); } } @@ -12401,22 +9331,16 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) { N1.getOperand(0), N1.getOperand(1), N2); } - if (N1.hasOneUse()) { - if (SDValue NewN1 = rebuildSetCC(N1)) - return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain, NewN1, N2); - } - - return SDValue(); -} - -SDValue DAGCombiner::rebuildSetCC(SDValue N) { - if (N.getOpcode() == ISD::SRL || - (N.getOpcode() == ISD::TRUNCATE && - (N.getOperand(0).hasOneUse() && - N.getOperand(0).getOpcode() == ISD::SRL))) { - // Look pass the truncate. - if (N.getOpcode() == ISD::TRUNCATE) - N = N.getOperand(0); + if ((N1.hasOneUse() && N1.getOpcode() == ISD::SRL) || + ((N1.getOpcode() == ISD::TRUNCATE && N1.hasOneUse()) && + (N1.getOperand(0).hasOneUse() && + N1.getOperand(0).getOpcode() == ISD::SRL))) { + SDNode *Trunc = nullptr; + if (N1.getOpcode() == ISD::TRUNCATE) { + // Look pass the truncate. + Trunc = N1.getNode(); + N1 = N1.getOperand(0); + } // Match this pattern so that we can generate simpler code: // @@ -12435,55 +9359,74 @@ SDValue DAGCombiner::rebuildSetCC(SDValue N) { // This applies only when the AND constant value has one bit set and the // SRL constant is equal to the log2 of the AND constant. The back-end is // smart enough to convert the result into a TEST/JMP sequence. 
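visitFABS is the same bit trick with AND instead of XOR: clearing the sign bit rather than flipping it, i.e. bits & 0x7fffffff for f32 (the complement of the sign mask in general), again splatted per element for vectors. Both forms avoid a constant-pool FP load, which is the point of the !TLI.isFAbsFree gate.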
- SDValue Op0 = N.getOperand(0); - SDValue Op1 = N.getOperand(1); + SDValue Op0 = N1.getOperand(0); + SDValue Op1 = N1.getOperand(1); - if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) { + if (Op0.getOpcode() == ISD::AND && + Op1.getOpcode() == ISD::Constant) { SDValue AndOp1 = Op0.getOperand(1); if (AndOp1.getOpcode() == ISD::Constant) { const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue(); if (AndConst.isPowerOf2() && - cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) { + cast<ConstantSDNode>(Op1)->getAPIntValue()==AndConst.logBase2()) { SDLoc DL(N); - return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()), - Op0, DAG.getConstant(0, DL, Op0.getValueType()), - ISD::SETNE); + SDValue SetCC = + DAG.getSetCC(DL, + getSetCCResultType(Op0.getValueType()), + Op0, DAG.getConstant(0, DL, Op0.getValueType()), + ISD::SETNE); + + SDValue NewBRCond = DAG.getNode(ISD::BRCOND, DL, + MVT::Other, Chain, SetCC, N2); + // Don't add the new BRCond into the worklist or else SimplifySelectCC + // will convert it back to (X & C1) >> C2. + CombineTo(N, NewBRCond, false); + // Truncate is dead. + if (Trunc) + deleteAndRecombine(Trunc); + // Replace the uses of SRL with SETCC + WorklistRemover DeadNodes(*this); + DAG.ReplaceAllUsesOfValueWith(N1, SetCC); + deleteAndRecombine(N1.getNode()); + return SDValue(N, 0); // Return N so it doesn't get rechecked! } } } + + if (Trunc) + // Restore N1 if the above transformation doesn't match. + N1 = N->getOperand(1); } // Transform br(xor(x, y)) -> br(x != y) // Transform br(xor(xor(x,y), 1)) -> br (x == y) - if (N.getOpcode() == ISD::XOR) { - // Because we may call this on a speculatively constructed - // SimplifiedSetCC Node, we need to simplify this node first. - // Ideally this should be folded into SimplifySetCC and not - // here. For now, grab a handle to N so we don't lose it from - // replacements interal to the visit. - HandleSDNode XORHandle(N); - while (N.getOpcode() == ISD::XOR) { - SDValue Tmp = visitXOR(N.getNode()); - // No simplification done. - if (!Tmp.getNode()) - break; - // Returning N is form in-visit replacement that may invalidated - // N. Grab value from Handle. - if (Tmp.getNode() == N.getNode()) - N = XORHandle.getValue(); - else // Node simplified. Try simplifying again. - N = Tmp; - } - - if (N.getOpcode() != ISD::XOR) - return N; - - SDNode *TheXor = N.getNode(); - + if (N1.hasOneUse() && N1.getOpcode() == ISD::XOR) { + SDNode *TheXor = N1.getNode(); SDValue Op0 = TheXor->getOperand(0); SDValue Op1 = TheXor->getOperand(1); + if (Op0.getOpcode() == Op1.getOpcode()) { + // Avoid missing important xor optimizations. + if (SDValue Tmp = visitXOR(TheXor)) { + if (Tmp.getNode() != TheXor) { + DEBUG(dbgs() << "\nReplacing.8 "; + TheXor->dump(&DAG); + dbgs() << "\nWith: "; + Tmp.getNode()->dump(&DAG); + dbgs() << '\n'); + WorklistRemover DeadNodes(*this); + DAG.ReplaceAllUsesOfValueWith(N1, Tmp); + deleteAndRecombine(TheXor); + return DAG.getNode(ISD::BRCOND, SDLoc(N), + MVT::Other, Chain, Tmp, N2); + } + + // visitXOR has changed XOR's operands or replaced the XOR completely, + // bail out. 
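The AND/SRL branch pattern above corresponds to source like this (illustrative):

    #include <cstdio>

    void hit() { std::puts("taken"); }

    int main(int argc, char **) {
      int x = argc;
      if ((x & 8) >> 3)  // brcond(srl(and(x, 8), 3)); 8 is a power of two, 3 == log2(8)
        hit();
      if ((x & 8) != 0)  // the combined form: brcond(setcc(and(x, 8), 0, ne))
        hit();
    }

Both branches are equivalent, and the setcc form is what the backend can turn into a single TEST/JMP pair.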
+ return SDValue(N, 0); + } + } if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) { bool Equal = false; @@ -12493,12 +9436,19 @@ SDValue DAGCombiner::rebuildSetCC(SDValue N) { Equal = true; } - EVT SetCCVT = N.getValueType(); + EVT SetCCVT = N1.getValueType(); if (LegalTypes) SetCCVT = getSetCCResultType(SetCCVT); + SDValue SetCC = DAG.getSetCC(SDLoc(TheXor), + SetCCVT, + Op0, Op1, + Equal ? ISD::SETEQ : ISD::SETNE); // Replace the uses of XOR with SETCC - return DAG.getSetCC(SDLoc(TheXor), SetCCVT, Op0, Op1, - Equal ? ISD::SETEQ : ISD::SETNE); + WorklistRemover DeadNodes(*this); + DAG.ReplaceAllUsesOfValueWith(N1, SetCC); + deleteAndRecombine(N1.getNode()); + return DAG.getNode(ISD::BRCOND, SDLoc(N), + MVT::Other, Chain, SetCC, N2); } } @@ -12657,11 +9607,6 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { return false; } - // Caches for hasPredecessorHelper. - SmallPtrSet<const SDNode *, 32> Visited; - SmallVector<const SDNode *, 16> Worklist; - Worklist.push_back(N); - // If the offset is a constant, there may be other adds of constants that // can be folded with this one. We should do this to avoid having to keep // a copy of the original base pointer. @@ -12676,7 +9621,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { if (Use.getUser() == Ptr.getNode() || Use != BasePtr) continue; - if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist)) + if (Use.getUser()->isPredecessorOf(N)) continue; if (Use.getUser()->getOpcode() != ISD::ADD && @@ -12706,10 +9651,14 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { // Now check for #3 and #4. bool RealUse = false; + // Caches for hasPredecessorHelper + SmallPtrSet<const SDNode *, 32> Visited; + SmallVector<const SDNode *, 16> Worklist; + for (SDNode *Use : Ptr.getNode()->uses()) { if (Use == N) continue; - if (SDNode::hasPredecessorHelper(Use, Visited, Worklist)) + if (N->hasPredecessorHelper(Use, Visited, Worklist)) return false; // If Ptr may be folded in addressing mode of other use, then it's @@ -12730,8 +9679,11 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { BasePtr, Offset, AM); ++PreIndexedNodes; ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: "; - Result.getNode()->dump(&DAG); dbgs() << '\n'); + DEBUG(dbgs() << "\nReplacing.4 "; + N->dump(&DAG); + dbgs() << "\nWith: "; + Result.getNode()->dump(&DAG); + dbgs() << '\n'); WorklistRemover DeadNodes(*this); if (isLoad) { DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); @@ -12760,7 +9712,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { // x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store) // // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the - // indexed load/store and the expression that needs to be re-written. + // indexed load/store and the expresion that needs to be re-written. // // Therefore, we have: // t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1 @@ -12768,7 +9720,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { ConstantSDNode *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx)); int X0, X1, Y0, Y1; - const APInt &Offset0 = CN->getAPIntValue(); + APInt Offset0 = CN->getAPIntValue(); APInt Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue(); X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1; @@ -12799,7 +9751,6 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { // Replace the uses of Ptr with uses of the updated base value. 
DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(isLoad ? 1 : 0)); deleteAndRecombine(Ptr.getNode()); - AddToWorklist(Result.getNode()); return true; } @@ -12887,15 +9838,8 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { if (TryNext) continue; - // Check for #2. - SmallPtrSet<const SDNode *, 32> Visited; - SmallVector<const SDNode *, 8> Worklist; - // Ptr is predecessor to both N and Op. - Visited.insert(Ptr.getNode()); - Worklist.push_back(N); - Worklist.push_back(Op); - if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) && - !SDNode::hasPredecessorHelper(Op, Visited, Worklist)) { + // Check for #2 + if (!Op->isPredecessorOf(N) && !N->isPredecessorOf(Op)) { SDValue Result = isLoad ? DAG.getIndexedLoad(SDValue(N,0), SDLoc(N), BasePtr, Offset, AM) @@ -12903,9 +9847,11 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { BasePtr, Offset, AM); ++PostIndexedNodes; ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); - dbgs() << "\nWith: "; Result.getNode()->dump(&DAG); - dbgs() << '\n'); + DEBUG(dbgs() << "\nReplacing.5 "; + N->dump(&DAG); + dbgs() << "\nWith: "; + Result.getNode()->dump(&DAG); + dbgs() << '\n'); WorklistRemover DeadNodes(*this); if (isLoad) { DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); @@ -12929,7 +9875,7 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { return false; } -/// Return the base-pointer arithmetic from an indexed \p LD. +/// \brief Return the base-pointer arithmetic from an indexed \p LD. SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) { ISD::MemIndexedMode AM = LD->getAddressingMode(); assert(AM != ISD::UNINDEXED); @@ -12953,157 +9899,6 @@ SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) { return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc); } -static inline int numVectorEltsOrZero(EVT T) { - return T.isVector() ? T.getVectorNumElements() : 0; -} - -bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) { - Val = ST->getValue(); - EVT STType = Val.getValueType(); - EVT STMemType = ST->getMemoryVT(); - if (STType == STMemType) - return true; - if (isTypeLegal(STMemType)) - return false; // fail. - if (STType.isFloatingPoint() && STMemType.isFloatingPoint() && - TLI.isOperationLegal(ISD::FTRUNC, STMemType)) { - Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val); - return true; - } - if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) && - STType.isInteger() && STMemType.isInteger()) { - Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val); - return true; - } - if (STType.getSizeInBits() == STMemType.getSizeInBits()) { - Val = DAG.getBitcast(STMemType, Val); - return true; - } - return false; // fail. 
-} - -bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) { - EVT LDMemType = LD->getMemoryVT(); - EVT LDType = LD->getValueType(0); - assert(Val.getValueType() == LDMemType && - "Attempting to extend value of non-matching type"); - if (LDType == LDMemType) - return true; - if (LDMemType.isInteger() && LDType.isInteger()) { - switch (LD->getExtensionType()) { - case ISD::NON_EXTLOAD: - Val = DAG.getBitcast(LDType, Val); - return true; - case ISD::EXTLOAD: - Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val); - return true; - case ISD::SEXTLOAD: - Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val); - return true; - case ISD::ZEXTLOAD: - Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val); - return true; - } - } - return false; -} - -SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) { - if (OptLevel == CodeGenOpt::None || LD->isVolatile()) - return SDValue(); - SDValue Chain = LD->getOperand(0); - StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode()); - if (!ST || ST->isVolatile()) - return SDValue(); - - EVT LDType = LD->getValueType(0); - EVT LDMemType = LD->getMemoryVT(); - EVT STMemType = ST->getMemoryVT(); - EVT STType = ST->getValue().getValueType(); - - BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG); - BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG); - int64_t Offset; - if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset)) - return SDValue(); - - // Normalize for Endianness. After this Offset=0 will denote that the least - // significant bit in the loaded value maps to the least significant bit in - // the stored value). With Offset=n (for n > 0) the loaded value starts at the - // n:th least significant byte of the stored value. - if (DAG.getDataLayout().isBigEndian()) - Offset = (STMemType.getStoreSizeInBits() - - LDMemType.getStoreSizeInBits()) / 8 - Offset; - - // Check that the stored value cover all bits that are loaded. - bool STCoversLD = - (Offset >= 0) && - (Offset * 8 + LDMemType.getSizeInBits() <= STMemType.getSizeInBits()); - - auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue { - if (LD->isIndexed()) { - bool IsSub = (LD->getAddressingMode() == ISD::PRE_DEC || - LD->getAddressingMode() == ISD::POST_DEC); - unsigned Opc = IsSub ? ISD::SUB : ISD::ADD; - SDValue Idx = DAG.getNode(Opc, SDLoc(LD), LD->getOperand(1).getValueType(), - LD->getOperand(1), LD->getOperand(2)); - SDValue Ops[] = {Val, Idx, Chain}; - return CombineTo(LD, Ops, 3); - } - return CombineTo(LD, Val, Chain); - }; - - if (!STCoversLD) - return SDValue(); - - // Memory as copy space (potentially masked). - if (Offset == 0 && LDType == STType && STMemType == LDMemType) { - // Simple case: Direct non-truncating forwarding - if (LDType.getSizeInBits() == LDMemType.getSizeInBits()) - return ReplaceLd(LD, ST->getValue(), Chain); - // Can we model the truncate and extension with an and mask? - if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() && - !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) { - // Mask to size of LDMemType - auto Mask = - DAG.getConstant(APInt::getLowBitsSet(STType.getSizeInBits(), - STMemType.getSizeInBits()), - SDLoc(ST), STType); - auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask); - return ReplaceLd(LD, Val, Chain); - } - } - - // TODO: Deal with nonzero offset. - if (LD->getBasePtr().isUndef() || Offset != 0) - return SDValue(); - // Model necessary truncations / extenstions. 
- SDValue Val; - // Truncate Value To Stored Memory Size. - do { - if (!getTruncatedStoreValue(ST, Val)) - continue; - if (!isTypeLegal(LDMemType)) - continue; - if (STMemType != LDMemType) { - // TODO: Support vectors? This requires extract_subvector/bitcast. - if (!STMemType.isVector() && !LDMemType.isVector() && - STMemType.isInteger() && LDMemType.isInteger()) - Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val); - else - continue; - } - if (!extendLoadedValueToExtension(LD, Val)) - continue; - return ReplaceLd(LD, Val, Chain); - } while (false); - - // On failure, cleanup dead nodes we may have created. - if (Val->use_empty()) - deleteAndRecombine(Val.getNode()); - return SDValue(); -} - SDValue DAGCombiner::visitLOAD(SDNode *N) { LoadSDNode *LD = cast<LoadSDNode>(N); SDValue Chain = LD->getChain(); @@ -13122,12 +9917,14 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // v3 = add v2, c // Now we replace use of chain2 with chain1. This makes the second load // isomorphic to the one we are deleting, and thus makes this load live. - LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG); - dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG); - dbgs() << "\n"); + DEBUG(dbgs() << "\nReplacing.6 "; + N->dump(&DAG); + dbgs() << "\nWith chain: "; + Chain.getNode()->dump(&DAG); + dbgs() << "\n"); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain); - AddUsersToWorklist(Chain.getNode()); + if (N->use_empty()) deleteAndRecombine(N); @@ -13155,9 +9952,11 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { AddUsersToWorklist(N); } else Index = DAG.getUNDEF(N->getValueType(1)); - LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG); - dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG); - dbgs() << " and 2 other values\n"); + DEBUG(dbgs() << "\nReplacing.7 "; + N->dump(&DAG); + dbgs() << "\nWith: "; + Undef.getNode()->dump(&DAG); + dbgs() << " and 2 other values\n"); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index); @@ -13170,25 +9969,42 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // If this load is directly stored, replace the load value with the stored // value. - if (auto V = ForwardStoreValueToDirectLoad(LD)) - return V; + // TODO: Handle store large -> read small portion. + // TODO: Handle TRUNCSTORE/LOADEXT + if (ISD::isNormalLoad(N) && !LD->isVolatile()) { + if (ISD::isNON_TRUNCStore(Chain.getNode())) { + StoreSDNode *PrevST = cast<StoreSDNode>(Chain); + if (PrevST->getBasePtr() == Ptr && + PrevST->getValue().getValueType() == N->getValueType(0)) + return CombineTo(N, Chain.getOperand(1), Chain); + } + } // Try to infer better alignment information than the load already has. 
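[Editor's note: illustrative sketch, not part of the patch. The "load is directly stored" forwarding just above, and the masked variant in the removed ForwardStoreValueToDirectLoad, both rest on the same memory model: a load that reads exactly the bytes a prior store wrote can be replaced by the stored value, and a narrower load at offset 0 is the stored value masked to the loaded width. A standalone demonstration; it assumes a little-endian host:]

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t Stored = 0xDEADBEEF;
  unsigned char Mem[4];
  std::memcpy(Mem, &Stored, 4);         // the store

  uint32_t Full;
  std::memcpy(&Full, Mem, 4);           // same-width load: forward Stored
  assert(Full == Stored);

  uint16_t Narrow;
  std::memcpy(&Narrow, Mem, 2);         // narrower load at offset 0 (LE host)
  assert(Narrow == (Stored & 0xFFFFu)); // the stored value, masked
  return 0;
}

[End editor's note]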
if (OptLevel != CodeGenOpt::None && LD->isUnindexed()) { if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { - if (Align > LD->getAlignment() && LD->getSrcValueOffset() % Align == 0) { - SDValue NewLoad = DAG.getExtLoad( - LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr, - LD->getPointerInfo(), LD->getMemoryVT(), Align, - LD->getMemOperand()->getFlags(), LD->getAAInfo()); - // NewLoad will always be N as we are only refining the alignment - assert(NewLoad.getNode() == N); - (void)NewLoad; + if (Align > LD->getMemOperand()->getBaseAlignment()) { + SDValue NewLoad = + DAG.getExtLoad(LD->getExtensionType(), SDLoc(N), + LD->getValueType(0), + Chain, Ptr, LD->getPointerInfo(), + LD->getMemoryVT(), + LD->isVolatile(), LD->isNonTemporal(), + LD->isInvariant(), Align, LD->getAAInfo()); + if (NewLoad.getNode() != N) + return CombineTo(N, NewLoad, SDValue(NewLoad.getNode(), 1), true); } } } - if (LD->isUnindexed()) { + bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA + : DAG.getSubtarget().useAA(); +#ifndef NDEBUG + if (CombinerAAOnlyFunc.getNumOccurrences() && + CombinerAAOnlyFunc != DAG.getMachineFunction().getName()) + UseAA = false; +#endif + if (UseAA && LD->isUnindexed()) { // Walk up chain skipping non-aliasing memory nodes. SDValue BetterChain = FindBetterChain(N, Chain); @@ -13211,8 +10027,12 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Chain, ReplLoad.getValue(1)); - // Replace uses with load result and token factor - return CombineTo(N, ReplLoad.getValue(0), Token); + // Make sure the new and old chains are cleaned up. + AddToWorklist(Token.getNode()); + + // Replace uses with load result and token factor. Don't add users + // to work list. + return CombineTo(N, ReplLoad.getValue(0), Token, false); } } @@ -13229,37 +10049,38 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { } namespace { - -/// Helper structure used to slice a load in smaller loads. +/// \brief Helper structure used to slice a load in smaller loads. /// Basically a slice is obtained from the following sequence: /// Origin = load Ty1, Base /// Shift = srl Ty1 Origin, CstTy Amount /// Inst = trunc Shift to Ty2 /// -/// Then, it will be rewritten into: +/// Then, it will be rewriten into: /// Slice = load SliceTy, Base + SliceOffset /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2 /// /// SliceTy is deduced from the number of bits that are actually used to /// build Inst. struct LoadedSlice { - /// Helper structure used to compute the cost of a slice. + /// \brief Helper structure used to compute the cost of a slice. struct Cost { /// Are we optimizing for code size. bool ForCodeSize; - /// Various cost. - unsigned Loads = 0; - unsigned Truncates = 0; - unsigned CrossRegisterBanksCopies = 0; - unsigned ZExts = 0; - unsigned Shift = 0; + unsigned Loads; + unsigned Truncates; + unsigned CrossRegisterBanksCopies; + unsigned ZExts; + unsigned Shift; - Cost(bool ForCodeSize = false) : ForCodeSize(ForCodeSize) {} + Cost(bool ForCodeSize = false) + : ForCodeSize(ForCodeSize), Loads(0), Truncates(0), + CrossRegisterBanksCopies(0), ZExts(0), Shift(0) {} - /// Get the cost of one isolated slice. + /// \brief Get the cost of one isolated slice. 
Cost(const LoadedSlice &LS, bool ForCodeSize = false) - : ForCodeSize(ForCodeSize), Loads(1) { + : ForCodeSize(ForCodeSize), Loads(1), Truncates(0), + CrossRegisterBanksCopies(0), ZExts(0), Shift(0) { EVT TruncType = LS.Inst->getValueType(0); EVT LoadedType = LS.getLoadedType(); if (TruncType != LoadedType && @@ -13267,7 +10088,7 @@ struct LoadedSlice { ZExts = 1; } - /// Account for slicing gain in the current cost. + /// \brief Account for slicing gain in the current cost. /// Slicing provide a few gains like removing a shift or a /// truncate. This method allows to grow the cost of the original /// load with the gain from this slice. @@ -13321,17 +10142,13 @@ struct LoadedSlice { bool operator>=(const Cost &RHS) const { return !(*this < RHS); } }; - // The last instruction that represent the slice. This should be a // truncate instruction. SDNode *Inst; - // The original load instruction. LoadSDNode *Origin; - // The right shift amount in bits from the original load. unsigned Shift; - // The DAG from which Origin came from. // This is used to get some contextual information about legal types, etc. SelectionDAG *DAG; @@ -13340,7 +10157,7 @@ struct LoadedSlice { unsigned Shift = 0, SelectionDAG *DAG = nullptr) : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {} - /// Get the bits used in a chunk of bits \p BitWidth large. + /// \brief Get the bits used in a chunk of bits \p BitWidth large. /// \return Result is \p BitWidth and has used bits set to 1 and /// not used bits set to 0. APInt getUsedBits() const { @@ -13360,14 +10177,14 @@ struct LoadedSlice { return UsedBits; } - /// Get the size of the slice to be loaded in bytes. + /// \brief Get the size of the slice to be loaded in bytes. unsigned getLoadedSize() const { unsigned SliceSize = getUsedBits().countPopulation(); assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte."); return SliceSize / 8; } - /// Get the type that will be loaded for this slice. + /// \brief Get the type that will be loaded for this slice. /// Note: This may not be the final type for the slice. EVT getLoadedType() const { assert(DAG && "Missing context"); @@ -13375,7 +10192,7 @@ struct LoadedSlice { return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8); } - /// Get the alignment of the load used for this slice. + /// \brief Get the alignment of the load used for this slice. unsigned getAlignment() const { unsigned Alignment = Origin->getAlignment(); unsigned Offset = getOffsetFromBase(); @@ -13384,14 +10201,14 @@ struct LoadedSlice { return Alignment; } - /// Check if this slice can be rewritten with legal operations. + /// \brief Check if this slice can be rewritten with legal operations. bool isLegal() const { // An invalid slice is not legal. if (!Origin || !Inst || !DAG) return false; // Offsets are for indexed load only, we do not handle that. - if (!Origin->getOffset().isUndef()) + if (Origin->getOffset().getOpcode() != ISD::UNDEF) return false; const TargetLowering &TLI = DAG->getTargetLoweringInfo(); @@ -13428,7 +10245,7 @@ struct LoadedSlice { return true; } - /// Get the offset in bytes of this slice in the original chunk of + /// \brief Get the offset in bytes of this slice in the original chunk of /// bits. /// \pre DAG != nullptr. 
uint64_t getOffsetFromBase() const { @@ -13449,7 +10266,7 @@ struct LoadedSlice { return Offset; } - /// Generate the sequence of instructions to load the slice + /// \brief Generate the sequence of instructions to load the slice /// represented by this object and redirect the uses of this slice to /// this new sequence of instructions. /// \pre this->Inst && this->Origin are valid Instructions and this @@ -13459,7 +10276,7 @@ struct LoadedSlice { assert(Inst && Origin && "Unable to replace a non-existing slice."); const SDValue &OldBaseAddr = Origin->getBasePtr(); SDValue BaseAddr = OldBaseAddr; - // Get the offset in that chunk of bytes w.r.t. the endianness. + // Get the offset in that chunk of bytes w.r.t. the endianess. int64_t Offset = static_cast<int64_t>(getOffsetFromBase()); assert(Offset >= 0 && "Offset too big to fit in int64_t!"); if (Offset) { @@ -13474,10 +10291,10 @@ struct LoadedSlice { EVT SliceType = getLoadedType(); // Create the load for the slice. - SDValue LastInst = - DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr, - Origin->getPointerInfo().getWithOffset(Offset), - getAlignment(), Origin->getMemOperand()->getFlags()); + SDValue LastInst = DAG->getLoad( + SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr, + Origin->getPointerInfo().getWithOffset(Offset), Origin->isVolatile(), + Origin->isNonTemporal(), Origin->isInvariant(), getAlignment()); // If the final type is not the same as the loaded type, this means that // we have to pad with zero. Create a zero extend for that. EVT FinalType = Inst->getValueType(0); @@ -13487,7 +10304,7 @@ struct LoadedSlice { return LastInst; } - /// Check if this slice can be merged with an expensive cross register + /// \brief Check if this slice can be merged with an expensive cross register /// bank copy. E.g., /// i = load i32 /// f = bitcast i32 i to float @@ -13533,10 +10350,9 @@ struct LoadedSlice { return true; } }; +} -} // end anonymous namespace - -/// Check that all bits set in \p UsedBits form a dense region, i.e., +/// \brief Check that all bits set in \p UsedBits form a dense region, i.e., /// \p UsedBits looks like 0..0 1..1 0..0. static bool areUsedBitsDense(const APInt &UsedBits) { // If all the bits are one, this is dense! @@ -13552,7 +10368,7 @@ static bool areUsedBitsDense(const APInt &UsedBits) { return NarrowedUsedBits.isAllOnesValue(); } -/// Check whether or not \p First and \p Second are next to each other +/// \brief Check whether or not \p First and \p Second are next to each other /// in memory. This means that there is no hole between the bits loaded /// by \p First and the bits loaded by \p Second. static bool areSlicesNextToEachOther(const LoadedSlice &First, @@ -13566,7 +10382,7 @@ static bool areSlicesNextToEachOther(const LoadedSlice &First, return areUsedBitsDense(UsedBits); } -/// Adjust the \p GlobalLSCost according to the target +/// \brief Adjust the \p GlobalLSCost according to the target /// paring capabilities and the layout of the slices. /// \pre \p GlobalLSCost should account for at least as many loads as /// there is in the slices in \p LoadedSlices. @@ -13579,7 +10395,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices, // Sort the slices so that elements that are likely to be next to each // other in memory are next to each other in the list. 
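[Editor's note: illustrative sketch, not part of the patch. areUsedBitsDense() above accepts only masks of the shape 0..0 1..1 0..0, i.e. a single contiguous run of ones, because only then can the used bits be fetched by one narrow load slice. A portable uint64_t version of the same test (a zero mask is treated as not dense here, a choice made for this sketch):]

#include <cassert>
#include <cstdint>

static bool isDense(uint64_t UsedBits) {
  if (UsedBits == 0)
    return false;
  while ((UsedBits & 1) == 0)
    UsedBits >>= 1;                        // strip trailing zeros
  return (UsedBits & (UsedBits + 1)) == 0; // remaining bits one solid run?
}

int main() {
  assert(isDense(0x00FF0000));  // one contiguous run: dense
  assert(isDense(0xFFFFFFFF));  // all ones: dense
  assert(!isDense(0x00FF00FF)); // two separate runs: not dense
  return 0;
}

[End editor's note]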
- llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) { + std::sort(LoadedSlices.begin(), LoadedSlices.end(), + [](const LoadedSlice &LHS, const LoadedSlice &RHS) { assert(LHS.Origin == RHS.Origin && "Different bases not implemented."); return LHS.getOffsetFromBase() < RHS.getOffsetFromBase(); }); @@ -13591,6 +10408,7 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices, for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice, // Set the beginning of the pair. First = Second) { + Second = &LoadedSlices[CurrSlice]; // If First is NULL, it means we start a new pair. @@ -13626,7 +10444,7 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices, } } -/// Check the profitability of all involved LoadedSlice. +/// \brief Check the profitability of all involved LoadedSlice. /// Currently, it is considered profitable if there is exactly two /// involved slices (1) which are (2) next to each other in memory, and /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3). @@ -13670,7 +10488,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices, return OrigCost > GlobalSlicingCost; } -/// If the given load, \p LI, is used only by trunc or trunc(lshr) +/// \brief If the given load, \p LI, is used only by trunc or trunc(lshr) /// operations, split it in the various pieces being extracted. /// /// This sort of thing is introduced by SROA. @@ -13706,7 +10524,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) { // Check if this is a trunc(lshr). if (User->getOpcode() == ISD::SRL && User->hasOneUse() && isa<ConstantSDNode>(User->getOperand(1))) { - Shift = User->getConstantOperandVal(1); + Shift = cast<ConstantSDNode>(User->getOperand(1))->getZExtValue(); User = *User->use_begin(); } @@ -13721,7 +10539,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) { // will be across several bytes. We do not support that. unsigned Width = User->getValueSizeInBits(0); if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7)) - return false; + return 0; // Build the slice for this chain of computations. LoadedSlice LS(User, LD, Shift, &DAG); @@ -13758,7 +10576,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) { LSIt != LSItEnd; ++LSIt) { SDValue SliceInst = LSIt->loadSlice(); CombineTo(LSIt->Inst, SliceInst, true); - if (SliceInst.getOpcode() != ISD::LOAD) + if (SliceInst.getNode()->getOpcode() != ISD::LOAD) SliceInst = SliceInst.getOperand(0); assert(SliceInst->getOpcode() == ISD::LOAD && "It takes more than a zext to get to the loaded slice!!"); @@ -13768,7 +10586,6 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) { SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other, ArgChains); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain); - AddToWorklist(Chain.getNode()); return true; } @@ -13789,6 +10606,22 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) { LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0)); if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer. + // The store should be chained directly to the load or be an operand of a + // tokenfactor. + if (LD == Chain.getNode()) + ; // ok. + else if (Chain->getOpcode() != ISD::TokenFactor) + return Result; // Fail. + else { + bool isOk = false; + for (const SDValue &ChainOp : Chain->op_values()) + if (ChainOp.getNode() == LD) { + isOk = true; + break; + } + if (!isOk) return Result; + } + // This only handles simple types. 
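[Editor's note: illustrative sketch, not part of the patch. The CheckForMaskedLoad / ShrinkLoadReplaceStoreWithStore pair in this part of the diff narrows a load/or/store (or load/and, load/xor) sequence: storing (Old | Imm) back to the address Old was loaded from only changes the bytes covered by Imm's set bits, so it suffices to store just those bytes. A standalone check of that equivalence; it assumes a little-endian host:]

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t Old = 0x11223344; // the value previously in memory
  uint32_t Imm = 0x00410000; // immediate whose set bits touch byte 2 only
  unsigned char Wide[4], Narrow[4];

  // Wide store: write all four bytes of (Old | Imm).
  uint32_t New = Old | Imm;
  std::memcpy(Wide, &New, 4);

  // Narrowed store: keep the original bytes, rewrite only byte 2.
  std::memcpy(Narrow, &Old, 4);
  Narrow[2] = (unsigned char)((Old >> 16) | (Imm >> 16));

  assert(std::memcmp(Wide, Narrow, 4) == 0);
  return 0;
}

[End editor's note]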
if (V.getValueType() != MVT::i16 && V.getValueType() != MVT::i32 && @@ -13825,29 +10658,12 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) { // is aligned the same as the access width. if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result; - // For narrowing to be valid, it must be the case that the load the - // immediately preceeding memory operation before the store. - if (LD == Chain.getNode()) - ; // ok. - else if (Chain->getOpcode() == ISD::TokenFactor && - SDValue(LD, 1).hasOneUse()) { - // LD has only 1 chain use so they are no indirect dependencies. - bool isOk = false; - for (const SDValue &ChainOp : Chain->op_values()) - if (ChainOp.getNode() == LD) { - isOk = true; - break; - } - if (!isOk) - return Result; - } else - return Result; // Fail. - Result.first = MaskedBytes; Result.second = NotMaskTZ/8; return Result; } + /// Check to see if IVal is something that provides a value as specified by /// MaskInfo. If so, replace the specified store with a narrower store of /// truncated IVal. @@ -13902,12 +10718,12 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo, IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal); ++OpsNarrowed; - return DAG - .getStore(St->getChain(), SDLoc(St), IVal, Ptr, - St->getPointerInfo().getWithOffset(StOffset), NewAlign) - .getNode(); + return DAG.getStore(St->getChain(), SDLoc(St), IVal, Ptr, + St->getPointerInfo().getWithOffset(StOffset), + false, false, NewAlign).getNode(); } + /// Look for sequence of load / op / store where op is one of 'or', 'xor', and /// 'and' of immediates. If 'op' is only touching some of the loaded bits, try /// narrowing the load and store if it would end up being a win for performance @@ -14010,16 +10826,19 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) { Ptr.getValueType(), Ptr, DAG.getConstant(PtrOff, SDLoc(LD), Ptr.getValueType())); - SDValue NewLD = - DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr, - LD->getPointerInfo().getWithOffset(PtrOff), NewAlign, - LD->getMemOperand()->getFlags(), LD->getAAInfo()); + SDValue NewLD = DAG.getLoad(NewVT, SDLoc(N0), + LD->getChain(), NewPtr, + LD->getPointerInfo().getWithOffset(PtrOff), + LD->isVolatile(), LD->isNonTemporal(), + LD->isInvariant(), NewAlign, + LD->getAAInfo()); SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD, DAG.getConstant(NewImm, SDLoc(Value), NewVT)); - SDValue NewST = - DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr, - ST->getPointerInfo().getWithOffset(PtrOff), NewAlign); + SDValue NewST = DAG.getStore(Chain, SDLoc(N), + NewVal, NewPtr, + ST->getPointerInfo().getWithOffset(PtrOff), + false, false, NewAlign); AddToWorklist(NewPtr.getNode()); AddToWorklist(NewLD.getNode()); @@ -14068,13 +10887,15 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) { if (LDAlign < ABIAlign || STAlign < ABIAlign) return SDValue(); - SDValue NewLD = - DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(), - LD->getPointerInfo(), LDAlign); + SDValue NewLD = DAG.getLoad(IntVT, SDLoc(Value), + LD->getChain(), LD->getBasePtr(), + LD->getPointerInfo(), + false, false, false, LDAlign); - SDValue NewST = - DAG.getStore(NewLD.getValue(1), SDLoc(N), NewLD, ST->getBasePtr(), - ST->getPointerInfo(), STAlign); + SDValue NewST = DAG.getStore(NewLD.getValue(1), SDLoc(N), + NewLD, ST->getBasePtr(), + ST->getPointerInfo(), + false, false, STAlign); AddToWorklist(NewLD.getNode()); AddToWorklist(NewST.getNode()); @@ -14087,6 +10908,96 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) { return 
SDValue(); } +namespace { +/// Helper struct to parse and store a memory address as base + index + offset. +/// We ignore sign extensions when it is safe to do so. +/// The following two expressions are not equivalent. To differentiate we need +/// to store whether there was a sign extension involved in the index +/// computation. +/// (load (i64 add (i64 copyfromreg %c) +/// (i64 signextend (add (i8 load %index) +/// (i8 1)))) +/// vs +/// +/// (load (i64 add (i64 copyfromreg %c) +/// (i64 signextend (i32 add (i32 signextend (i8 load %index)) +/// (i32 1))))) +struct BaseIndexOffset { + SDValue Base; + SDValue Index; + int64_t Offset; + bool IsIndexSignExt; + + BaseIndexOffset() : Offset(0), IsIndexSignExt(false) {} + + BaseIndexOffset(SDValue Base, SDValue Index, int64_t Offset, + bool IsIndexSignExt) : + Base(Base), Index(Index), Offset(Offset), IsIndexSignExt(IsIndexSignExt) {} + + bool equalBaseIndex(const BaseIndexOffset &Other) { + return Other.Base == Base && Other.Index == Index && + Other.IsIndexSignExt == IsIndexSignExt; + } + + /// Parses tree in Ptr for base, index, offset addresses. + static BaseIndexOffset match(SDValue Ptr) { + bool IsIndexSignExt = false; + + // We only can pattern match BASE + INDEX + OFFSET. If Ptr is not an ADD + // instruction, then it could be just the BASE or everything else we don't + // know how to handle. Just use Ptr as BASE and give up. + if (Ptr->getOpcode() != ISD::ADD) + return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt); + + // We know that we have at least an ADD instruction. Try to pattern match + // the simple case of BASE + OFFSET. + if (isa<ConstantSDNode>(Ptr->getOperand(1))) { + int64_t Offset = cast<ConstantSDNode>(Ptr->getOperand(1))->getSExtValue(); + return BaseIndexOffset(Ptr->getOperand(0), SDValue(), Offset, + IsIndexSignExt); + } + + // Inside a loop the current BASE pointer is calculated using an ADD and a + // MUL instruction. In this case Ptr is the actual BASE pointer. + // (i64 add (i64 %array_ptr) + // (i64 mul (i64 %induction_var) + // (i64 %element_size))) + if (Ptr->getOperand(1)->getOpcode() == ISD::MUL) + return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt); + + // Look at Base + Index + Offset cases. + SDValue Base = Ptr->getOperand(0); + SDValue IndexOffset = Ptr->getOperand(1); + + // Skip signextends. + if (IndexOffset->getOpcode() == ISD::SIGN_EXTEND) { + IndexOffset = IndexOffset->getOperand(0); + IsIndexSignExt = true; + } + + // Either the case of Base + Index (no offset) or something else. + if (IndexOffset->getOpcode() != ISD::ADD) + return BaseIndexOffset(Base, IndexOffset, 0, IsIndexSignExt); + + // Now we have the case of Base + Index + offset. + SDValue Index = IndexOffset->getOperand(0); + SDValue Offset = IndexOffset->getOperand(1); + + if (!isa<ConstantSDNode>(Offset)) + return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt); + + // Ignore signextends. + if (Index->getOpcode() == ISD::SIGN_EXTEND) { + Index = Index->getOperand(0); + IsIndexSignExt = true; + } else IsIndexSignExt = false; + + int64_t Off = cast<ConstantSDNode>(Offset)->getSExtValue(); + return BaseIndexOffset(Base, Index, Off, IsIndexSignExt); + } +}; +} // namespace + // This is a helper function for visitMUL to check the profitability // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2). // MulNode is the original multiply, AddNode is (add x, c1), @@ -14111,6 +11022,7 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, // Walk all the users of the constant with which we're multiplying. 
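[Editor's note: illustrative sketch, not part of the patch. The BaseIndexOffset struct above decomposes a pointer expression in a fixed order: bare BASE, then BASE + constant OFFSET, then BASE + INDEX + OFFSET. The toy below mirrors only that decomposition order over a minimal expression type; Expr and Decomp are stand-ins invented for this note, and the real matcher additionally handles MUL-based induction pointers and sign extensions. Requires C++14 for the brace initialization:]

#include <cassert>
#include <cstdint>

struct Expr {                  // hypothetical stand-in for SDValue
  enum Kind { Leaf, Add, Const } K;
  const Expr *L = nullptr, *R = nullptr;
  int64_t Val = 0;             // payload for Const nodes
};

struct Decomp { const Expr *Base; const Expr *Index; int64_t Offset; };

static Decomp match(const Expr *Ptr) {
  if (Ptr->K != Expr::Add)
    return {Ptr, nullptr, 0};                   // just BASE
  if (Ptr->R->K == Expr::Const)
    return {Ptr->L, nullptr, Ptr->R->Val};      // BASE + OFFSET
  if (Ptr->R->K == Expr::Add && Ptr->R->R->K == Expr::Const)
    return {Ptr->L, Ptr->R->L, Ptr->R->R->Val}; // BASE + INDEX + OFFSET
  return {Ptr->L, Ptr->R, 0};                   // BASE + INDEX, no offset
}

int main() {
  Expr B{Expr::Leaf}, I{Expr::Leaf}, C{Expr::Const};
  C.Val = 8;
  Expr Inner{Expr::Add, &I, &C}; // (add %index, 8)
  Expr P{Expr::Add, &B, &Inner}; // (add %base, (add %index, 8))
  Decomp D = match(&P);
  assert(D.Base == &B && D.Index == &I && D.Offset == 8);
  return 0;
}

[End editor's note]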
for (SDNode *Use : ConstNode->uses()) { + if (Use == MulNode) // This use is the one we're on right now. Skip it. continue; @@ -14151,7 +11063,7 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, // multiply (CONST * A) after we also do the same transformation // to the "t2" instruction. if (OtherOp->getOpcode() == ISD::ADD && - DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) && + isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) && OtherOp->getOperand(0).getNode() == MulVar) return true; } @@ -14161,118 +11073,83 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, return false; } -SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes, - unsigned NumStores) { - SmallVector<SDValue, 8> Chains; - SmallPtrSet<const SDNode *, 8> Visited; - SDLoc StoreDL(StoreNodes[0].MemNode); - - for (unsigned i = 0; i < NumStores; ++i) { - Visited.insert(StoreNodes[i].MemNode); - } +SDValue DAGCombiner::getMergedConstantVectorStore(SelectionDAG &DAG, + SDLoc SL, + ArrayRef<MemOpLink> Stores, + SmallVectorImpl<SDValue> &Chains, + EVT Ty) const { + SmallVector<SDValue, 8> BuildVector; - // don't include nodes that are children - for (unsigned i = 0; i < NumStores; ++i) { - if (Visited.count(StoreNodes[i].MemNode->getChain().getNode()) == 0) - Chains.push_back(StoreNodes[i].MemNode->getChain()); + for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) { + StoreSDNode *St = cast<StoreSDNode>(Stores[I].MemNode); + Chains.push_back(St->getChain()); + BuildVector.push_back(St->getValue()); } - assert(Chains.size() > 0 && "Chain should have generated a chain"); - return DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, Chains); + return DAG.getNode(ISD::BUILD_VECTOR, SL, Ty, BuildVector); } bool DAGCombiner::MergeStoresOfConstantsOrVecElts( - SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores, - bool IsConstantSrc, bool UseVector, bool UseTrunc) { + SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, + unsigned NumStores, bool IsConstantSrc, bool UseVector) { // Make sure we have something to merge. if (NumStores < 2) return false; + int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8; + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned LatestNodeUsed = 0; + + for (unsigned i=0; i < NumStores; ++i) { + // Find a chain for the new wide-store operand. Notice that some + // of the store nodes that we found may not be selected for inclusion + // in the wide store. The chain we use needs to be the chain of the + // latest store node which is *used* and replaced by the wide store. + if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum) + LatestNodeUsed = i; + } + + SmallVector<SDValue, 8> Chains; + // The latest Node in the DAG. + LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; SDLoc DL(StoreNodes[0].MemNode); - int64_t ElementSizeBits = MemVT.getStoreSizeInBits(); - unsigned SizeInBits = NumStores * ElementSizeBits; - unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1; - - EVT StoreTy; + SDValue StoredVal; if (UseVector) { - unsigned Elts = NumStores * NumMemElts; + bool IsVec = MemVT.isVector(); + unsigned Elts = NumStores; + if (IsVec) { + // When merging vector stores, get the total number of elements. + Elts *= MemVT.getVectorNumElements(); + } // Get the type for the merged vector store. 
- StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); - } else - StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits); + EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); + assert(TLI.isTypeLegal(Ty) && "Illegal vector store"); - SDValue StoredVal; - if (UseVector) { if (IsConstantSrc) { - SmallVector<SDValue, 8> BuildVector; - for (unsigned I = 0; I != NumStores; ++I) { - StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode); - SDValue Val = St->getValue(); - // If constant is of the wrong type, convert it now. - if (MemVT != Val.getValueType()) { - Val = peekThroughBitcasts(Val); - // Deal with constants of wrong size. - if (ElementSizeBits != Val.getValueSizeInBits()) { - EVT IntMemVT = - EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()); - if (isa<ConstantFPSDNode>(Val)) { - // Not clear how to truncate FP values. - return false; - } else if (auto *C = dyn_cast<ConstantSDNode>(Val)) - Val = DAG.getConstant(C->getAPIntValue() - .zextOrTrunc(Val.getValueSizeInBits()) - .zextOrTrunc(ElementSizeBits), - SDLoc(C), IntMemVT); - } - // Make sure correctly size type is the correct type. - Val = DAG.getBitcast(MemVT, Val); - } - BuildVector.push_back(Val); - } - StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS - : ISD::BUILD_VECTOR, - DL, StoreTy, BuildVector); + StoredVal = getMergedConstantVectorStore(DAG, DL, StoreNodes, Chains, Ty); } else { SmallVector<SDValue, 8> Ops; for (unsigned i = 0; i < NumStores; ++i) { StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue Val = peekThroughBitcasts(St->getValue()); - // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of - // type MemVT. If the underlying value is not the correct - // type, but it is an extraction of an appropriate vector we - // can recast Val to be of the correct type. This may require - // converting between EXTRACT_VECTOR_ELT and - // EXTRACT_SUBVECTOR. - if ((MemVT != Val.getValueType()) && - (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT || - Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) { - EVT MemVTScalarTy = MemVT.getScalarType(); - // We may need to add a bitcast here to get types to line up. - if (MemVTScalarTy != Val.getValueType().getScalarType()) { - Val = DAG.getBitcast(MemVT, Val); - } else { - unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR - : ISD::EXTRACT_VECTOR_ELT; - SDValue Vec = Val.getOperand(0); - SDValue Idx = Val.getOperand(1); - Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx); - } - } + SDValue Val = St->getValue(); + // All operands of BUILD_VECTOR / CONCAT_VECTOR must have the same type. + if (Val.getValueType() != MemVT) + return false; Ops.push_back(Val); + Chains.push_back(St->getChain()); } // Build the extracted vector elements back into a vector. - StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS - : ISD::BUILD_VECTOR, - DL, StoreTy, Ops); - } + StoredVal = DAG.getNode(IsVec ? ISD::CONCAT_VECTORS : ISD::BUILD_VECTOR, + DL, Ty, Ops); } } else { // We should always use a vector store when merging extracted vector // elements, so this path implies a store of constants. assert(IsConstantSrc && "Merged vector elements should use vector store"); + unsigned SizeInBits = NumStores * ElementSizeBytes * 8; APInt StoreInt(SizeInBits, 0); // Construct a single integer constant which is made of the smaller @@ -14281,259 +11158,189 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( for (unsigned i = 0; i < NumStores; ++i) { unsigned Idx = IsLE ? 
(NumStores - 1 - i) : i; StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode); + Chains.push_back(St->getChain()); SDValue Val = St->getValue(); - Val = peekThroughBitcasts(Val); - StoreInt <<= ElementSizeBits; + StoreInt <<= ElementSizeBytes * 8; if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) { - StoreInt |= C->getAPIntValue() - .zextOrTrunc(ElementSizeBits) - .zextOrTrunc(SizeInBits); + StoreInt |= C->getAPIntValue().zext(SizeInBits); } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) { - StoreInt |= C->getValueAPF() - .bitcastToAPInt() - .zextOrTrunc(ElementSizeBits) - .zextOrTrunc(SizeInBits); - // If fp truncation is necessary give up for now. - if (MemVT.getSizeInBits() != ElementSizeBits) - return false; + StoreInt |= C->getValueAPF().bitcastToAPInt().zext(SizeInBits); } else { llvm_unreachable("Invalid constant element type"); } } // Create the new Load and Store operations. + EVT StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits); StoredVal = DAG.getConstant(StoreInt, DL, StoreTy); } - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores); - - // make sure we use trunc store if it's necessary to be legal. - SDValue NewStore; - if (!UseTrunc) { - NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), - FirstInChain->getAlignment()); - } else { // Must be realized as a trunc store - EVT LegalizedStoredValTy = - TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType()); - unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits(); - ConstantSDNode *C = cast<ConstantSDNode>(StoredVal); - SDValue ExtendedStoreVal = - DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL, - LegalizedStoredValTy); - NewStore = DAG.getTruncStore( - NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/, - FirstInChain->getAlignment(), - FirstInChain->getMemOperand()->getFlags()); - } - - // Replace all merged stores with the new store. - for (unsigned i = 0; i < NumStores; ++i) - CombineTo(StoreNodes[i].MemNode, NewStore); - - AddToWorklist(NewChain.getNode()); + assert(!Chains.empty()); + + SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); + SDValue NewStore = DAG.getStore(NewChain, DL, StoredVal, + FirstInChain->getBasePtr(), + FirstInChain->getPointerInfo(), + false, false, + FirstInChain->getAlignment()); + + bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA + : DAG.getSubtarget().useAA(); + if (UseAA) { + // Replace all merged stores with the new store. + for (unsigned i = 0; i < NumStores; ++i) + CombineTo(StoreNodes[i].MemNode, NewStore); + } else { + // Replace the last store with the new store. + CombineTo(LatestOp, NewStore); + // Erase all other stores. + for (unsigned i = 0; i < NumStores; ++i) { + if (StoreNodes[i].MemNode == LatestOp) + continue; + StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); + // ReplaceAllUsesWith will replace all uses that existed when it was + // called, but graph optimizations may cause new ones to appear. For + // example, the case in pr14333 looks like + // + // St's chain -> St -> another store -> X + // + // And the only difference from St to the other store is the chain. + // When we change it's chain to be St's chain they become identical, + // get CSEed and the net result is that X is now a use of St. 
+ // Since we know that St is redundant, just iterate. + while (!St->use_empty()) + DAG.ReplaceAllUsesWith(SDValue(St, 0), St->getChain()); + deleteAndRecombine(St); + } + } + return true; } -void DAGCombiner::getStoreMergeCandidates( - StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes, - SDNode *&RootNode) { +void DAGCombiner::getStoreMergeAndAliasCandidates( + StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes, + SmallVectorImpl<LSBaseSDNode*> &AliasLoadNodes) { // This holds the base pointer, index, and the offset in bytes from the base // pointer. - BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG); - EVT MemVT = St->getMemoryVT(); + BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr()); - SDValue Val = peekThroughBitcasts(St->getValue()); // We must have a base and an offset. - if (!BasePtr.getBase().getNode()) + if (!BasePtr.Base.getNode()) return; // Do not handle stores to undef base pointers. - if (BasePtr.getBase().isUndef()) + if (BasePtr.Base.getOpcode() == ISD::UNDEF) return; - bool IsConstantSrc = isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val); - bool IsExtractVecSrc = (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT || - Val.getOpcode() == ISD::EXTRACT_SUBVECTOR); - bool IsLoadSrc = isa<LoadSDNode>(Val); - BaseIndexOffset LBasePtr; - // Match on loadbaseptr if relevant. - EVT LoadVT; - if (IsLoadSrc) { - auto *Ld = cast<LoadSDNode>(Val); - LBasePtr = BaseIndexOffset::match(Ld, DAG); - LoadVT = Ld->getMemoryVT(); - // Load and store should be the same type. - if (MemVT != LoadVT) - return; - // Loads must only have one use. - if (!Ld->hasNUsesOfValue(1, 0)) - return; - // The memory operands must not be volatile. - if (Ld->isVolatile() || Ld->isIndexed()) - return; - } - auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr, - int64_t &Offset) -> bool { - if (Other->isVolatile() || Other->isIndexed()) - return false; - SDValue Val = peekThroughBitcasts(Other->getValue()); - // Allow merging constants of different types as integers. - bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT()) - : Other->getMemoryVT() != MemVT; - if (IsLoadSrc) { - if (NoTypeMatch) - return false; - // The Load's Base Ptr must also match - if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(Val)) { - auto LPtr = BaseIndexOffset::match(OtherLd, DAG); - if (LoadVT != OtherLd->getMemoryVT()) - return false; - // Loads must only have one use. - if (!OtherLd->hasNUsesOfValue(1, 0)) - return false; - // The memory operands must not be volatile. - if (OtherLd->isVolatile() || OtherLd->isIndexed()) - return false; - if (!(LBasePtr.equalBaseIndex(LPtr, DAG))) - return false; - } else - return false; - } - if (IsConstantSrc) { - if (NoTypeMatch) - return false; - if (!(isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val))) - return false; - } - if (IsExtractVecSrc) { - // Do not merge truncated stores here. - if (Other->isTruncatingStore()) - return false; - if (!MemVT.bitsEq(Val.getValueType())) - return false; - if (Val.getOpcode() != ISD::EXTRACT_VECTOR_ELT && - Val.getOpcode() != ISD::EXTRACT_SUBVECTOR) - return false; + // Walk up the chain and look for nodes with offsets from the same + // base pointer. Stop when reaching an instruction with a different kind + // or instruction which has a different base pointer. + EVT MemVT = St->getMemoryVT(); + unsigned Seq = 0; + StoreSDNode *Index = St; + + + bool UseAA = CombinerAA.getNumOccurrences() > 0 ? 
CombinerAA + : DAG.getSubtarget().useAA(); + + if (UseAA) { + // Look at other users of the same chain. Stores on the same chain do not + // alias. If combiner-aa is enabled, non-aliasing stores are canonicalized + // to be on the same chain, so don't bother looking at adjacent chains. + + SDValue Chain = St->getChain(); + for (auto I = Chain->use_begin(), E = Chain->use_end(); I != E; ++I) { + if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) { + if (I.getOperandNo() != 0) + continue; + + if (OtherST->isVolatile() || OtherST->isIndexed()) + continue; + + if (OtherST->getMemoryVT() != MemVT) + continue; + + BaseIndexOffset Ptr = BaseIndexOffset::match(OtherST->getBasePtr()); + + if (Ptr.equalBaseIndex(BasePtr)) + StoreNodes.push_back(MemOpLink(OtherST, Ptr.Offset, Seq++)); + } } - Ptr = BaseIndexOffset::match(Other, DAG); - return (BasePtr.equalBaseIndex(Ptr, DAG, Offset)); - }; - // We looking for a root node which is an ancestor to all mergable - // stores. We search up through a load, to our root and then down - // through all children. For instance we will find Store{1,2,3} if - // St is Store1, Store2. or Store3 where the root is not a load - // which always true for nonvolatile ops. TODO: Expand - // the search to find all valid candidates through multiple layers of loads. - // - // Root - // |-------|-------| - // Load Load Store3 - // | | - // Store1 Store2 - // - // FIXME: We should be able to climb and - // descend TokenFactors to find candidates as well. - - RootNode = St->getChain().getNode(); - - if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(RootNode)) { - RootNode = Ldn->getChain().getNode(); - for (auto I = RootNode->use_begin(), E = RootNode->use_end(); I != E; ++I) - if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) // walk down chain - for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2) - if (I2.getOperandNo() == 0) - if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I2)) { - BaseIndexOffset Ptr; - int64_t PtrDiff; - if (CandidateMatch(OtherST, Ptr, PtrDiff)) - StoreNodes.push_back(MemOpLink(OtherST, PtrDiff)); - } - } else - for (auto I = RootNode->use_begin(), E = RootNode->use_end(); I != E; ++I) - if (I.getOperandNo() == 0) - if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) { - BaseIndexOffset Ptr; - int64_t PtrDiff; - if (CandidateMatch(OtherST, Ptr, PtrDiff)) - StoreNodes.push_back(MemOpLink(OtherST, PtrDiff)); - } -} + return; + } -// We need to check that merging these stores does not cause a loop in -// the DAG. Any store candidate may depend on another candidate -// indirectly through its operand (we already consider dependencies -// through the chain). Check in parallel by searching up from -// non-chain operands of candidates. -bool DAGCombiner::checkMergeStoreCandidatesForDependencies( - SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores, - SDNode *RootNode) { - // FIXME: We should be able to truncate a full search of - // predecessors by doing a BFS and keeping tabs the originating - // stores from which worklist nodes come from in a similar way to - // TokenFactor simplfication. + while (Index) { + // If the chain has more than one use, then we can't reorder the mem ops. + if (Index != St && !SDValue(Index, 0)->hasOneUse()) + break; - SmallPtrSet<const SDNode *, 32> Visited; - SmallVector<const SDNode *, 8> Worklist; + // Find the base pointer and offset for this memory node. 
+ BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr()); - // RootNode is a predecessor to all candidates so we need not search - // past it. Add RootNode (peeking through TokenFactors). Do not count - // these towards size check. + // Check that the base pointer is the same as the original one. + if (!Ptr.equalBaseIndex(BasePtr)) + break; - Worklist.push_back(RootNode); - while (!Worklist.empty()) { - auto N = Worklist.pop_back_val(); - if (!Visited.insert(N).second) - continue; // Already present in Visited. - if (N->getOpcode() == ISD::TokenFactor) { - for (SDValue Op : N->ops()) - Worklist.push_back(Op.getNode()); - } - } - - // Don't count pruning nodes towards max. - unsigned int Max = 1024 + Visited.size(); - // Search Ops of store candidates. - for (unsigned i = 0; i < NumStores; ++i) { - SDNode *N = StoreNodes[i].MemNode; - // Of the 4 Store Operands: - // * Chain (Op 0) -> We have already considered these - // in candidate selection and can be - // safely ignored - // * Value (Op 1) -> Cycles may happen (e.g. through load chains) - // * Address (Op 2) -> Merged addresses may only vary by a fixed constant, - // but aren't necessarily fromt the same base node, so - // cycles possible (e.g. via indexed store). - // * (Op 3) -> Represents the pre or post-indexing offset (or undef for - // non-indexed stores). Not constant on all targets (e.g. ARM) - // and so can participate in a cycle. - for (unsigned j = 1; j < N->getNumOperands(); ++j) - Worklist.push_back(N->getOperand(j).getNode()); - } - // Search through DAG. We can stop early if we find a store node. - for (unsigned i = 0; i < NumStores; ++i) - if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist, - Max)) - return false; - return true; + // The memory operands must not be volatile. + if (Index->isVolatile() || Index->isIndexed()) + break; + + // No truncation. + if (StoreSDNode *St = dyn_cast<StoreSDNode>(Index)) + if (St->isTruncatingStore()) + break; + + // The stored memory type must be the same. + if (Index->getMemoryVT() != MemVT) + break; + + // We do not allow under-aligned stores in order to prevent + // overriding stores. NOTE: this is a bad hack. Alignment SHOULD + // be irrelevant here; what MATTERS is that we not move memory + // operations that potentially overlap past each-other. + if (Index->getAlignment() < MemVT.getStoreSize()) + break; + + // We found a potential memory operand to merge. + StoreNodes.push_back(MemOpLink(Index, Ptr.Offset, Seq++)); + + // Find the next memory operand in the chain. If the next operand in the + // chain is a store then move up and continue the scan with the next + // memory operand. If the next operand is a load save it and use alias + // information to check if it interferes with anything. + SDNode *NextInChain = Index->getChain().getNode(); + while (1) { + if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) { + // We found a store node. Use it for the next iteration. + Index = STn; + break; + } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) { + if (Ldn->isVolatile()) { + Index = nullptr; + break; + } + + // Save the load node for later. Continue the scan. 
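[Editor's note: illustrative sketch, not part of the patch. Once the scan around this hunk has collected consecutive same-base constant stores, MergeStoresOfConstantsOrVecElts above concatenates their values into one wide integer (via APInt shifts) and emits a single store. The payoff in memory terms, assuming a little-endian host:]

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  unsigned char A[4], B[4];

  // Original: four consecutive one-byte constant stores.
  A[0] = 0x11; A[1] = 0x22; A[2] = 0x33; A[3] = 0x44;

  // Merged: one four-byte store of the concatenated constant.
  // On a little-endian host the first store's byte is the low byte.
  uint32_t Merged = 0x44332211;
  std::memcpy(B, &Merged, 4);

  assert(std::memcmp(A, B, 4) == 0);
  return 0;
}

[End editor's note]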
+ AliasLoadNodes.push_back(Ldn); + NextInChain = Ldn->getChain().getNode(); + continue; + } else { + Index = nullptr; + break; + } + } + } } -bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { +bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (OptLevel == CodeGenOpt::None) return false; EVT MemVT = St->getMemoryVT(); - int64_t ElementSizeBytes = MemVT.getStoreSize(); - unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1; - - if (MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits) - return false; - - bool NoVectors = DAG.getMachineFunction().getFunction().hasFnAttribute( + int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8; + bool NoVectors = DAG.getMachineFunction().getFunction()->hasFnAttribute( Attribute::NoImplicitFloat); // This function cannot currently deal with non-byte-sized memory sizes. @@ -14545,7 +11352,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { // Perform an early exit check. Do not bother looking at stored values that // are not constants, loads, or extracted vector elements. - SDValue StoredVal = peekThroughBitcasts(St->getValue()); + SDValue StoredVal = St->getValue(); bool IsLoadSrc = isa<LoadSDNode>(StoredVal); bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) || isa<ConstantFPSDNode>(StoredVal); @@ -14555,495 +11362,381 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc) return false; + // Don't merge vectors into wider vectors if the source data comes from loads. + // TODO: This restriction can be lifted by using logic similar to the + // ExtractVecSrc case. + if (MemVT.isVector() && IsLoadSrc) + return false; + + // Only look at ends of store sequences. + SDValue Chain = SDValue(St, 0); + if (Chain->hasOneUse() && Chain->use_begin()->getOpcode() == ISD::STORE) + return false; + + // Save the LoadSDNodes that we find in the chain. + // We need to make sure that these nodes do not interfere with + // any of the store nodes. + SmallVector<LSBaseSDNode*, 8> AliasLoadNodes; + + // Save the StoreSDNodes that we find in the chain. SmallVector<MemOpLink, 8> StoreNodes; - SDNode *RootNode; - // Find potential store merge candidates by searching through chain sub-DAG - getStoreMergeCandidates(St, StoreNodes, RootNode); + + getStoreMergeAndAliasCandidates(St, StoreNodes, AliasLoadNodes); // Check if there is anything to merge. if (StoreNodes.size() < 2) return false; // Sort the memory operands according to their distance from the - // base pointer. - llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) { - return LHS.OffsetFromBase < RHS.OffsetFromBase; + // base pointer. As a secondary criteria: make sure stores coming + // later in the code come first in the list. This is important for + // the non-UseAA case, because we're merging stores into the FINAL + // store along a chain which potentially contains aliasing stores. + // Thus, if there are multiple stores to the same address, the last + // one can be considered for merging but not the others. + std::sort(StoreNodes.begin(), StoreNodes.end(), + [](MemOpLink LHS, MemOpLink RHS) { + return LHS.OffsetFromBase < RHS.OffsetFromBase || + (LHS.OffsetFromBase == RHS.OffsetFromBase && + LHS.SequenceNum < RHS.SequenceNum); }); - // Store Merge attempts to merge the lowest stores. This generally - // works out as if successful, as the remaining stores are checked - // after the first collection of stores is merged. 
However, in the - // case that a non-mergeable store is found first, e.g., {p[-2], - // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent - // mergeable cases. To prevent this, we prune such stores from the - // front of StoreNodes here. - - bool RV = false; - while (StoreNodes.size() > 1) { - unsigned StartIdx = 0; - while ((StartIdx + 1 < StoreNodes.size()) && - StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes != - StoreNodes[StartIdx + 1].OffsetFromBase) - ++StartIdx; - - // Bail if we don't have enough candidates to merge. - if (StartIdx + 1 >= StoreNodes.size()) - return RV; - - if (StartIdx) - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx); - - // Scan the memory operations on the chain and find the first - // non-consecutive store memory address. - unsigned NumConsecutiveStores = 1; - int64_t StartAddress = StoreNodes[0].OffsetFromBase; + // Scan the memory operations on the chain and find the first non-consecutive + // store memory address. + unsigned LastConsecutiveStore = 0; + int64_t StartAddress = StoreNodes[0].OffsetFromBase; + for (unsigned i = 0, e = StoreNodes.size(); i < e; ++i) { + // Check that the addresses are consecutive starting from the second // element in the list of stores. - for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) { + if (i > 0) { int64_t CurrAddress = StoreNodes[i].OffsetFromBase; if (CurrAddress - StartAddress != (ElementSizeBytes * i)) break; - NumConsecutiveStores = i + 1; } - if (NumConsecutiveStores < 2) { - StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumConsecutiveStores); - continue; - } - - // The node with the lowest store address. - LLVMContext &Context = *DAG.getContext(); - const DataLayout &DL = DAG.getDataLayout(); - - // Store the constants into memory as one consecutive store. - if (IsConstantSrc) { - while (NumConsecutiveStores >= 2) { - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - unsigned LastLegalType = 1; - unsigned LastLegalVectorType = 1; - bool LastIntegerTrunc = false; - bool NonZero = false; - unsigned FirstZeroAfterNonZero = NumConsecutiveStores; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue StoredVal = ST->getValue(); - bool IsElementZero = false; - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) - IsElementZero = C->isNullValue(); - else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) - IsElementZero = C->getConstantFPValue()->isNullValue(); - if (IsElementZero) { - if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores) - FirstZeroAfterNonZero = i; - } - NonZero |= !IsElementZero; - - // Find a legal type for the constant store. - unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; - EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); - bool IsFast = false; + // Check if this store interferes with any of the loads that we found. + // If we find a load that alias with this store. Stop the sequence. + if (std::any_of(AliasLoadNodes.begin(), AliasLoadNodes.end(), + [&](LSBaseSDNode* Ldn) { + return isAlias(Ldn, StoreNodes[i].MemNode); + })) + break; - // Break early when size is too large to be legal. - if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) - break; + // Mark this node as useful. 
+ LastConsecutiveStore = i; + } - if (TLI.isTypeLegal(StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign, &IsFast) && - IsFast) { - LastIntegerTrunc = false; - LastLegalType = i + 1; - // Or check whether a truncstore is legal. - } else if (TLI.getTypeAction(Context, StoreTy) == - TargetLowering::TypePromoteInteger) { - EVT LegalizedStoredValTy = - TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); - if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign, &IsFast) && - IsFast) { - LastIntegerTrunc = true; - LastLegalType = i + 1; - } - } + // The node with the lowest store address. + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned FirstStoreAS = FirstInChain->getAddressSpace(); + unsigned FirstStoreAlign = FirstInChain->getAlignment(); + LLVMContext &Context = *DAG.getContext(); + const DataLayout &DL = DAG.getDataLayout(); + + // Store the constants into memory as one consecutive store. + if (IsConstantSrc) { + unsigned LastLegalType = 0; + unsigned LastLegalVectorType = 0; + bool NonZero = false; + for (unsigned i=0; i<LastConsecutiveStore+1; ++i) { + StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); + SDValue StoredVal = St->getValue(); + + if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) { + NonZero |= !C->isNullValue(); + } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) { + NonZero |= !C->getConstantFPValue()->isNullValue(); + } else { + // Non-constant. + break; + } - // We only use vectors if the constant is known to be zero or the - // target allows it and the function is not marked with the - // noimplicitfloat attribute. - if ((!NonZero || - TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) && - !NoVectors) { - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); - if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) && - TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && - TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, - FirstStoreAlign, &IsFast) && - IsFast) - LastLegalVectorType = i + 1; - } + // Find a legal type for the constant store. + unsigned SizeInBits = (i+1) * ElementSizeBytes * 8; + EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); + bool IsFast; + if (TLI.isTypeLegal(StoreTy) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFast) && IsFast) { + LastLegalType = i+1; + // Or check whether a truncstore is legal. + } else if (TLI.getTypeAction(Context, StoreTy) == + TargetLowering::TypePromoteInteger) { + EVT LegalizedStoredValueTy = + TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); + if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) && + TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, + FirstStoreAS, FirstStoreAlign, &IsFast) && + IsFast) { + LastLegalType = i + 1; } + } - bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors; - unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType; - - // Check if we found a legal integer type that creates a meaningful - // merge. - if (NumElem < 2) { - // We know that candidate stores are in order and of correct - // shape. 
While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have, is if the alignment has - // improved or we've dropped a non-zero value. Drop as many - // candidates as we can here. - unsigned NumSkip = 1; - while ( - (NumSkip < NumConsecutiveStores) && - (NumSkip < FirstZeroAfterNonZero) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; - - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - NumConsecutiveStores -= NumSkip; - continue; - } + // We only use vectors if the constant is known to be zero or the target + // allows it and the function is not marked with the noimplicitfloat + // attribute. + if ((!NonZero || TLI.storeOfVectorConstantIsCheap(MemVT, i+1, + FirstStoreAS)) && + !NoVectors) { + // Find a legal type for the vector store. + EVT Ty = EVT::getVectorVT(Context, MemVT, i+1); + if (TLI.isTypeLegal(Ty) && + TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, + FirstStoreAlign, &IsFast) && IsFast) + LastLegalVectorType = i + 1; + } + } - // Check that we can merge these candidates without causing a cycle. - if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem, - RootNode)) { - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); - NumConsecutiveStores -= NumElem; - continue; - } + // Check if we found a legal integer type to store. + if (LastLegalType == 0 && LastLegalVectorType == 0) + return false; - RV |= MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, true, - UseVector, LastIntegerTrunc); + bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors; + unsigned NumElem = UseVector ? LastLegalVectorType : LastLegalType; + + return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, + true, UseVector); + } + + // When extracting multiple vector elements, try to store them + // in one vector store rather than a sequence of scalar stores. + if (IsExtractVecSrc) { + unsigned NumStoresToMerge = 0; + bool IsVec = MemVT.isVector(); + for (unsigned i = 0; i < LastConsecutiveStore + 1; ++i) { + StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); + unsigned StoreValOpcode = St->getValue().getOpcode(); + // This restriction could be loosened. + // Bail out if any stored values are not elements extracted from a vector. + // It should be possible to handle mixed sources, but load sources need + // more careful handling (see the block of code below that handles + // consecutive loads). + if (StoreValOpcode != ISD::EXTRACT_VECTOR_ELT && + StoreValOpcode != ISD::EXTRACT_SUBVECTOR) + return false; - // Remove merged stores for next iteration. - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); - NumConsecutiveStores -= NumElem; + // Find a legal type for the vector store. + unsigned Elts = i + 1; + if (IsVec) { + // When merging vector stores, get the total number of elements. + Elts *= MemVT.getVectorNumElements(); } - continue; + EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); + bool IsFast; + if (TLI.isTypeLegal(Ty) && + TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, + FirstStoreAlign, &IsFast) && IsFast) + NumStoresToMerge = i + 1; } - // When extracting multiple vector elements, try to store them - // in one vector store rather than a sequence of scalar stores. - if (IsExtractVecSrc) { - // Loop on Consecutive Stores on success. 
- while (NumConsecutiveStores >= 2) { - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - unsigned NumStoresToMerge = 1; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT Ty = - EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); - bool IsFast; - - // Break early when size is too large to be legal. - if (Ty.getSizeInBits() > MaximumLegalStoreInBits) - break; - - if (TLI.isTypeLegal(Ty) && - TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && - TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, - FirstStoreAlign, &IsFast) && - IsFast) - NumStoresToMerge = i + 1; - } - - // Check if we found a legal integer type creating a meaningful - // merge. - if (NumStoresToMerge < 2) { - // We know that candidate stores are in order and of correct - // shape. While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have, is if the alignment has - // improved. Drop as many candidates as we can here. - unsigned NumSkip = 1; - while ( - (NumSkip < NumConsecutiveStores) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; - - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - NumConsecutiveStores -= NumSkip; - continue; - } + return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStoresToMerge, + false, true); + } - // Check that we can merge these candidates without causing a cycle. - if (!checkMergeStoreCandidatesForDependencies( - StoreNodes, NumStoresToMerge, RootNode)) { - StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumStoresToMerge); - NumConsecutiveStores -= NumStoresToMerge; - continue; - } + // Below we handle the case of multiple consecutive stores that + // come from multiple consecutive loads. We merge them into a single + // wide load and a single wide store. - RV |= MergeStoresOfConstantsOrVecElts( - StoreNodes, MemVT, NumStoresToMerge, false, true, false); + // Look for load nodes which are used by the stored values. + SmallVector<MemOpLink, 8> LoadNodes; - StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumStoresToMerge); - NumConsecutiveStores -= NumStoresToMerge; - } - continue; - } + // Find acceptable loads. Loads need to have the same chain (token factor), + // must not be zext, volatile, indexed, and they must be consecutive. + BaseIndexOffset LdBasePtr; + for (unsigned i=0; i<LastConsecutiveStore+1; ++i) { + StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); + LoadSDNode *Ld = dyn_cast<LoadSDNode>(St->getValue()); + if (!Ld) break; - // Below we handle the case of multiple consecutive stores that - // come from multiple consecutive loads. We merge them into a single - // wide load and a single wide store. + // Loads must only have one use. + if (!Ld->hasNUsesOfValue(1, 0)) + break; - // Look for load nodes which are used by the stored values. - SmallVector<MemOpLink, 8> LoadNodes; + // The memory operands must not be volatile. + if (Ld->isVolatile() || Ld->isIndexed()) + break; - // Find acceptable loads. Loads need to have the same chain (token factor), - // must not be zext, volatile, indexed, and they must be consecutive. - BaseIndexOffset LdBasePtr; + // We do not accept ext loads. 
+ if (Ld->getExtensionType() != ISD::NON_EXTLOAD) + break; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue Val = peekThroughBitcasts(St->getValue()); - LoadSDNode *Ld = cast<LoadSDNode>(Val); - - BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG); - // If this is not the first ptr that we check. - int64_t LdOffset = 0; - if (LdBasePtr.getBase().getNode()) { - // The base ptr must be the same. - if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset)) - break; - } else { - // Check that all other base pointers are the same as this one. - LdBasePtr = LdPtr; - } + // The stored memory type must be the same. + if (Ld->getMemoryVT() != MemVT) + break; - // We found a potential memory operand to merge. - LoadNodes.push_back(MemOpLink(Ld, LdOffset)); + BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld->getBasePtr()); + // If this is not the first ptr that we check. + if (LdBasePtr.Base.getNode()) { + // The base ptr must be the same. + if (!LdPtr.equalBaseIndex(LdBasePtr)) + break; + } else { + // Check that all other base pointers are the same as this one. + LdBasePtr = LdPtr; } - while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) { - // If we have load/store pair instructions and we only have two values, - // don't bother merging. - unsigned RequiredAlignment; - if (LoadNodes.size() == 2 && - TLI.hasPairedLoad(MemVT, RequiredAlignment) && - StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) { - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2); - LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2); - break; - } - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode); - unsigned FirstLoadAS = FirstLoad->getAddressSpace(); - unsigned FirstLoadAlign = FirstLoad->getAlignment(); - - // Scan the memory operations on the chain and find the first - // non-consecutive load memory address. These variables hold the index in - // the store node array. - - unsigned LastConsecutiveLoad = 1; - - // This variable refers to the size and not index in the array. - unsigned LastLegalVectorType = 1; - unsigned LastLegalIntegerType = 1; - bool isDereferenceable = true; - bool DoIntegerTruncate = false; - StartAddress = LoadNodes[0].OffsetFromBase; - SDValue FirstChain = FirstLoad->getChain(); - for (unsigned i = 1; i < LoadNodes.size(); ++i) { - // All loads must share the same chain. - if (LoadNodes[i].MemNode->getChain() != FirstChain) - break; + // We found a potential memory operand to merge. + LoadNodes.push_back(MemOpLink(Ld, LdPtr.Offset, 0)); + } - int64_t CurrAddress = LoadNodes[i].OffsetFromBase; - if (CurrAddress - StartAddress != (ElementSizeBytes * i)) - break; - LastConsecutiveLoad = i; + if (LoadNodes.size() < 2) + return false; - if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable()) - isDereferenceable = false; + // If we have load/store pair instructions and we only have two values, + // don't bother. + unsigned RequiredAlignment; + if (LoadNodes.size() == 2 && TLI.hasPairedLoad(MemVT, RequiredAlignment) && + St->getAlignment() >= RequiredAlignment) + return false; - // Find a legal type for the vector store. 
- unsigned Elts = (i + 1) * NumMemElts; - EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); + LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode); + unsigned FirstLoadAS = FirstLoad->getAddressSpace(); + unsigned FirstLoadAlign = FirstLoad->getAlignment(); + + // Scan the memory operations on the chain and find the first non-consecutive + // load memory address. These variables hold the index in the store node + // array. + unsigned LastConsecutiveLoad = 0; + // This variable refers to the size and not index in the array. + unsigned LastLegalVectorType = 0; + unsigned LastLegalIntegerType = 0; + StartAddress = LoadNodes[0].OffsetFromBase; + SDValue FirstChain = FirstLoad->getChain(); + for (unsigned i = 1; i < LoadNodes.size(); ++i) { + // All loads must share the same chain. + if (LoadNodes[i].MemNode->getChain() != FirstChain) + break; - // Break early when size is too large to be legal. - if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) - break; + int64_t CurrAddress = LoadNodes[i].OffsetFromBase; + if (CurrAddress - StartAddress != (ElementSizeBytes * i)) + break; + LastConsecutiveLoad = i; + // Find a legal type for the vector store. + EVT StoreTy = EVT::getVectorVT(Context, MemVT, i+1); + bool IsFastSt, IsFastLd; + if (TLI.isTypeLegal(StoreTy) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFastSt) && IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, + FirstLoadAlign, &IsFastLd) && IsFastLd) { + LastLegalVectorType = i + 1; + } + + // Find a legal type for the integer store. + unsigned SizeInBits = (i+1) * ElementSizeBytes * 8; + StoreTy = EVT::getIntegerVT(Context, SizeInBits); + if (TLI.isTypeLegal(StoreTy) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFastSt) && IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, + FirstLoadAlign, &IsFastLd) && IsFastLd) + LastLegalIntegerType = i + 1; + // Or check whether a truncstore and extload is legal. + else if (TLI.getTypeAction(Context, StoreTy) == + TargetLowering::TypePromoteInteger) { + EVT LegalizedStoredValueTy = + TLI.getTypeToTransformTo(Context, StoreTy); + if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) && + TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValueTy, StoreTy) && + TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValueTy, StoreTy) && + TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValueTy, StoreTy) && + TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, + FirstStoreAS, FirstStoreAlign, &IsFastSt) && + IsFastSt && + TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, + FirstLoadAS, FirstLoadAlign, &IsFastLd) && + IsFastLd) + LastLegalIntegerType = i+1; + } + } + + // Only use vector types if the vector type is larger than the integer type. + // If they are the same, use integers. + bool UseVectorTy = LastLegalVectorType > LastLegalIntegerType && !NoVectors; + unsigned LastLegalType = std::max(LastLegalVectorType, LastLegalIntegerType); + + // We add +1 here because the LastXXX variables refer to location while + // the NumElem refers to array/index size. 
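+ // [Editor's aside: a worked example with hypothetical numbers. If the
+ // first four stores are consecutive (LastConsecutiveStore == 3, an index)
+ // but only the first three loads are (LastConsecutiveLoad == 2, an index),
+ // and LastLegalType == 4 (a size), then min(3, 2) + 1 == 3 and
+ // min(4, 3) == 3: three load/store pairs get merged.]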
+ unsigned NumElem = std::min(LastConsecutiveStore, LastConsecutiveLoad) + 1;
+ NumElem = std::min(LastLegalType, NumElem);
+
+ if (NumElem < 2)
+ return false;
- bool IsFastSt, IsFastLd;
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFastSt) &&
- IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
- FirstLoadAlign, &IsFastLd) &&
- IsFastLd) {
- LastLegalVectorType = i + 1;
- }
+ // Collect the chains from all merged stores.
+ SmallVector<SDValue, 8> MergeStoreChains;
+ MergeStoreChains.push_back(StoreNodes[0].MemNode->getChain());
- // Find a legal type for the integer store.
- unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
- StoreTy = EVT::getIntegerVT(Context, SizeInBits);
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFastSt) &&
- IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
- FirstLoadAlign, &IsFastLd) &&
- IsFastLd) {
- LastLegalIntegerType = i + 1;
- DoIntegerTruncate = false;
- // Or check whether a truncstore and extload is legal.
- } else if (TLI.getTypeAction(Context, StoreTy) ==
- TargetLowering::TypePromoteInteger) {
- EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
- if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
- TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
- TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy,
- StoreTy) &&
- TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy,
- StoreTy) &&
- TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFastSt) &&
- IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
- FirstLoadAlign, &IsFastLd) &&
- IsFastLd) {
- LastLegalIntegerType = i + 1;
- DoIntegerTruncate = true;
- }
- }
- }
+ // The latest Node in the DAG.
+ unsigned LatestNodeUsed = 0;
+ for (unsigned i=1; i<NumElem; ++i) {
+ // Find a chain for the new wide-store operand. Notice that some
+ // of the store nodes that we found may not be selected for inclusion
+ // in the wide store. The chain we use needs to be the chain of the
+ // latest store node which is *used* and replaced by the wide store.
+ if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum)
+ LatestNodeUsed = i;
- // Only use vector types if the vector type is larger than the integer
- // type. If they are the same, use integers.
- bool UseVectorTy =
- LastLegalVectorType > LastLegalIntegerType && !NoVectors;
- unsigned LastLegalType =
- std::max(LastLegalVectorType, LastLegalIntegerType);
-
- // We add +1 here because the LastXXX variables refer to location while
- // the NumElem refers to array/index size.
- unsigned NumElem =
- std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
- NumElem = std::min(LastLegalType, NumElem);
-
- if (NumElem < 2) {
- // We know that candidate stores are in order and of correct
- // shape. While there is no mergeable sequence from the
- // beginning one may start later in the sequence. The only
- // reason a merge of size N could have failed where another of
- // the same size would not have is if the alignment of either
- // the load or store has improved. Drop as many candidates as we
- // can here.
- unsigned NumSkip = 1; - while ((NumSkip < LoadNodes.size()) && - (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip); - NumConsecutiveStores -= NumSkip; - continue; - } + MergeStoreChains.push_back(StoreNodes[i].MemNode->getChain()); + } - // Check that we can merge these candidates without causing a cycle. - if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem, - RootNode)) { - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); - LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem); - NumConsecutiveStores -= NumElem; - continue; - } + LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; - // Find if it is better to use vectors or integers to load and store - // to memory. - EVT JointMemOpVT; - if (UseVectorTy) { - // Find a legal type for the vector store. - unsigned Elts = NumElem * NumMemElts; - JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); - } else { - unsigned SizeInBits = NumElem * ElementSizeBytes * 8; - JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits); - } + // Find if it is better to use vectors or integers to load and store + // to memory. + EVT JointMemOpVT; + if (UseVectorTy) { + JointMemOpVT = EVT::getVectorVT(Context, MemVT, NumElem); + } else { + unsigned SizeInBits = NumElem * ElementSizeBytes * 8; + JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits); + } - SDLoc LoadDL(LoadNodes[0].MemNode); - SDLoc StoreDL(StoreNodes[0].MemNode); - - // The merged loads are required to have the same incoming chain, so - // using the first's chain is acceptable. - - SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem); - AddToWorklist(NewStoreChain.getNode()); - - MachineMemOperand::Flags MMOFlags = - isDereferenceable ? MachineMemOperand::MODereferenceable - : MachineMemOperand::MONone; - - SDValue NewLoad, NewStore; - if (UseVectorTy || !DoIntegerTruncate) { - NewLoad = - DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(), - FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(), - FirstLoadAlign, MMOFlags); - NewStore = DAG.getStore( - NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), FirstStoreAlign); - } else { // This must be the truncstore/extload case - EVT ExtendedTy = - TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT); - NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy, - FirstLoad->getChain(), FirstLoad->getBasePtr(), - FirstLoad->getPointerInfo(), JointMemOpVT, - FirstLoadAlign, MMOFlags); - NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad, - FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), - JointMemOpVT, FirstInChain->getAlignment(), - FirstInChain->getMemOperand()->getFlags()); - } + SDLoc LoadDL(LoadNodes[0].MemNode); + SDLoc StoreDL(StoreNodes[0].MemNode); - // Transfer chain users from old loads to the new load. - for (unsigned i = 0; i < NumElem; ++i) { - LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), - SDValue(NewLoad.getNode(), 1)); - } + // The merged loads are required to have the same incoming chain, so + // using the first's chain is acceptable. 
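+ // [Editor's aside: the load scan above already rejected any load whose
+ // chain differs from FirstChain ("All loads must share the same chain"),
+ // which is what makes reusing FirstLoad->getChain() for the wide load
+ // safe here.]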
+ SDValue NewLoad = DAG.getLoad(
+ JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
+ FirstLoad->getPointerInfo(), false, false, false, FirstLoadAlign);
- // Replace all of the stores with the new store. Recursively remove
- // the corresponding value if it is no longer used.
- for (unsigned i = 0; i < NumElem; ++i) {
- SDValue Val = StoreNodes[i].MemNode->getOperand(1);
- CombineTo(StoreNodes[i].MemNode, NewStore);
- if (Val.getNode()->use_empty())
- recursivelyDeleteUnusedNodes(Val.getNode());
- }
+ SDValue NewStoreChain =
+ DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, MergeStoreChains);
+
+ SDValue NewStore = DAG.getStore(
+ NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(), false, false, FirstStoreAlign);
- RV = true;
- StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
- LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
- NumConsecutiveStores -= NumElem;
+ // Transfer chain users from old loads to the new load.
+ for (unsigned i = 0; i < NumElem; ++i) {
+ LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
+ SDValue(NewLoad.getNode(), 1));
+ }
+
+ bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
+ : DAG.getSubtarget().useAA();
+ if (UseAA) {
+ // Replace all of the stores with the new store.
+ for (unsigned i = 0; i < NumElem; ++i)
+ CombineTo(StoreNodes[i].MemNode, NewStore);
+ } else {
+ // Replace the last store with the new store.
+ CombineTo(LatestOp, NewStore);
+ // Erase all other stores.
+ for (unsigned i = 0; i < NumElem; ++i) {
+ // Remove all Store nodes.
+ if (StoreNodes[i].MemNode == LatestOp)
+ continue;
+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(St, 0), St->getChain());
+ deleteAndRecombine(St);
}
}
- return RV;
+
+ return true;
}

SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
@@ -15131,17 +11824,21 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
std::swap(Lo, Hi);

unsigned Alignment = ST->getAlignment();
- MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
+ bool isVolatile = ST->isVolatile();
+ bool isNonTemporal = ST->isNonTemporal();
AAMDNodes AAInfo = ST->getAAInfo();
- SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
- ST->getAlignment(), MMOFlags, AAInfo);
+ SDValue St0 = DAG.getStore(Chain, DL, Lo,
+ Ptr, ST->getPointerInfo(),
+ isVolatile, isNonTemporal,
+ ST->getAlignment(), AAInfo);
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
DAG.getConstant(4, DL, Ptr.getValueType()));
Alignment = MinAlign(Alignment, 4U);
- SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
- ST->getPointerInfo().getWithOffset(4),
- Alignment, MMOFlags, AAInfo);
+ SDValue St1 = DAG.getStore(Chain, DL, Hi,
+ Ptr, ST->getPointerInfo().getWithOffset(4),
+ isVolatile, isNonTemporal,
+ Alignment, AAInfo);
return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, St0, St1);
}

@@ -15160,42 +11857,34 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
// resultant store does not need a higher alignment than the original.
if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
ST->isUnindexed()) {
+ unsigned OrigAlign = ST->getAlignment();
EVT SVT = Value.getOperand(0).getValueType();
- // If the store is volatile, we only want to change the store type if the
- // resulting store is legal. Otherwise we might increase the number of
- // memory accesses.
We don't care if the original type was legal or not - // as we assume software couldn't rely on the number of accesses of an - // illegal type. - if (((!LegalOperations && !ST->isVolatile()) || - TLI.isOperationLegal(ISD::STORE, SVT)) && - TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT)) { - unsigned OrigAlign = ST->getAlignment(); - bool Fast = false; - if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), SVT, - ST->getAddressSpace(), OrigAlign, &Fast) && - Fast) { - return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr, - ST->getPointerInfo(), OrigAlign, - ST->getMemOperand()->getFlags(), ST->getAAInfo()); - } - } + unsigned Align = DAG.getDataLayout().getABITypeAlignment( + SVT.getTypeForEVT(*DAG.getContext())); + if (Align <= OrigAlign && + ((!LegalOperations && !ST->isVolatile()) || + TLI.isOperationLegalOrCustom(ISD::STORE, SVT))) + return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), + Ptr, ST->getPointerInfo(), ST->isVolatile(), + ST->isNonTemporal(), OrigAlign, + ST->getAAInfo()); } // Turn 'store undef, Ptr' -> nothing. - if (Value.isUndef() && ST->isUnindexed()) + if (Value.getOpcode() == ISD::UNDEF && ST->isUnindexed()) return Chain; // Try to infer better alignment information than the store already has. if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) { if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { - if (Align > ST->getAlignment() && ST->getSrcValueOffset() % Align == 0) { + if (Align > ST->getAlignment()) { SDValue NewStore = - DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(), - ST->getMemoryVT(), Align, - ST->getMemOperand()->getFlags(), ST->getAAInfo()); - // NewStore will always be N as we are only refining the alignment - assert(NewStore.getNode() == N); - (void)NewStore; + DAG.getTruncStore(Chain, SDLoc(N), Value, + Ptr, ST->getPointerInfo(), ST->getMemoryVT(), + ST->isVolatile(), ST->isNonTemporal(), Align, + ST->getAAInfo()); + if (NewStore.getNode() != N) + return CombineTo(ST, NewStore, true); } } } @@ -15205,7 +11894,19 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { if (SDValue NewST = TransformFPLoadStorePair(N)) return NewST; - if (ST->isUnindexed()) { + bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA + : DAG.getSubtarget().useAA(); +#ifndef NDEBUG + if (CombinerAAOnlyFunc.getNumOccurrences() && + CombinerAAOnlyFunc != DAG.getMachineFunction().getName()) + UseAA = false; +#endif + if (UseAA && ST->isUnindexed()) { + // FIXME: We should do this even without AA enabled. AA will just allow + // FindBetterChain to work in more situations. The problem with this is that + // any combine that expects memory operations to be on consecutive chains + // first needs to be updated to look for users of the same chain. + // Walk up chain skipping non-aliasing memory nodes, on this store and any // adjacent stores. if (findBetterNeighborChains(ST)) { @@ -15213,20 +11914,23 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { // manipulation. Return the original node to not do anything else. return SDValue(ST, 0); } - Chain = ST->getChain(); } + // Try transforming N to an indexed store. + if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N)) + return SDValue(N, 0); + // FIXME: is there such a thing as a truncating indexed store? 
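// [Editor's aside, not part of the upstream diff: a plain-integer sketch of
// the demanded-bits fold performed by the truncstore block below, i.e.
// "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8". The helper
// name is hypothetical.]
#include <cstdint>
// An i8 truncating store keeps only the low 8 bits of the stored value,
// and (X << 8) contributes nothing to those bits.
uint8_t truncStoredByte(uint32_t X, uint32_t Y) {
  uint32_t Merged = (X << 8) | Y; // (or (shl x, 8), y)
  return (uint8_t)Merged;         // == (uint8_t)Y for every X and Y
}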
if (ST->isTruncatingStore() && ST->isUnindexed() &&
- Value.getValueType().isInteger() &&
- (!isa<ConstantSDNode>(Value) ||
- !cast<ConstantSDNode>(Value)->isOpaque())) {
+ Value.getValueType().isInteger()) {
// See if we can simplify the input to this truncstore with knowledge that
// only the low bits are being used. For example:
// "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
- SDValue Shorter = DAG.GetDemandedBits(
- Value, APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
- ST->getMemoryVT().getScalarSizeInBits()));
+ SDValue Shorter =
+ GetDemandedBits(Value,
+ APInt::getLowBitsSet(
+ Value.getValueType().getScalarType().getSizeInBits(),
+ ST->getMemoryVT().getScalarType().getSizeInBits()));
AddToWorklist(Value.getNode());
if (Shorter.getNode())
return DAG.getTruncStore(Chain, SDLoc(N), Shorter,
@@ -15234,18 +11938,11 @@
// Otherwise, see if we can simplify the operation with
// SimplifyDemandedBits, which only works if the value has a single use.
- if (SimplifyDemandedBits(
- Value,
- APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
- ST->getMemoryVT().getScalarSizeInBits()))) {
- // Re-visit the store if anything changed and the store hasn't been merged
- // with another node (N is deleted). SimplifyDemandedBits will add Value's
- // node back to the worklist if necessary, but we also need to re-visit
- // the Store node itself.
- if (N->getOpcode() != ISD::DELETED_NODE)
- AddToWorklist(N);
+ if (SimplifyDemandedBits(Value,
+ APInt::getLowBitsSet(
+ Value.getValueType().getScalarType().getSizeInBits(),
+ ST->getMemoryVT().getScalarType().getSizeInBits())))
return SDValue(N, 0);
- }
}

// If this is a load followed by a store to the same location, then the store
@@ -15261,28 +11958,14 @@
}

+ // If this is a store followed by a store with the same value to the same
+ // location, then the store is dead/noop.
if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
- if (ST->isUnindexed() && !ST->isVolatile() && ST1->isUnindexed() &&
- !ST1->isVolatile() && ST1->getBasePtr() == Ptr &&
- ST->getMemoryVT() == ST1->getMemoryVT()) {
- // If this is a store followed by a store with the same value to the same
- // location, then the store is dead/noop.
- if (ST1->getValue() == Value) {
- // The store is dead, remove it.
- return Chain;
- }
-
- // If this is a store whose preceding store is to the same location
- // and no other node is chained to that store, we can effectively
- // drop the store. Do not remove stores to undef as they may be used as
- // data sinks.
- if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
- !ST1->getBasePtr().isUndef()) {
- // ST1 is fully overwritten and can be elided. Combine with its chain
- // value.
- CombineTo(ST1, ST1->getChain());
- return SDValue();
- }
+ if (ST1->getBasePtr() == Ptr && ST->getMemoryVT() == ST1->getMemoryVT() &&
+ ST1->getValue() == Value && ST->isUnindexed() && !ST->isVolatile() &&
+ ST1->isUnindexed() && !ST1->isVolatile()) {
+ // The store is dead, remove it.
+ return Chain;
}
}

@@ -15296,237 +11979,57 @@
Ptr, ST->getMemoryVT(), ST->getMemOperand());
}

- // Always perform this optimization before types are legal. If the target
- // prefers, also try this after legalization to catch stores that were created
- // by intrinsics or other nodes.
- if (!LegalTypes || (TLI.mergeStoresAfterLegalization())) {
- while (true) {
+ // Only perform this optimization before the types are legal, because we
+ // don't want to perform this optimization on every DAGCombine invocation.
+ if (!LegalTypes) {
+ bool EverChanged = false;
+
+ do {
// There can be multiple store sequences on the same chain.
// Keep trying to merge store sequences until we are unable to do so
// or until we merge the last store on the chain.
bool Changed = MergeConsecutiveStores(ST);
+ EverChanged |= Changed;
if (!Changed) break;
- // Return N as merge only uses CombineTo and no worklist clean
- // up is necessary.
- if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
- return SDValue(N, 0);
- }
- }
+ } while (ST->getOpcode() != ISD::DELETED_NODE);
- // Try transforming N to an indexed store.
- if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
- return SDValue(N, 0);
+ if (EverChanged)
+ return SDValue(N, 0);
+ }

// Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
//
// Make sure to do this only after attempting to merge stores in order to
// avoid changing the types of some subset of stores due to visit order,
// preventing their merging.
- if (isa<ConstantFPSDNode>(ST->getValue())) {
+ if (isa<ConstantFPSDNode>(Value)) {
if (SDValue NewSt = replaceStoreOfFPConstant(ST))
return NewSt;
}

- if (SDValue NewSt = splitMergedValStore(ST))
- return NewSt;
-
return ReduceLoadOpStoreWidth(N);
}

-/// For the instruction sequence of store below, F and I values
-/// are bundled together as an i64 value before being stored into memory.
-/// Sometimes it is more efficient to generate separate stores for F and I,
-/// which can remove the bitwise instructions or sink them to colder places.
-///
-/// (store (or (zext (bitcast F to i32) to i64),
-/// (shl (zext I to i64), 32)), addr) -->
-/// (store F, addr) and (store I, addr+4)
-///
-/// Similarly, splitting for other merged stores can also be beneficial, like:
-/// For pair of {i32, i32}, i64 store --> two i32 stores.
-/// For pair of {i32, i16}, i64 store --> two i32 stores.
-/// For pair of {i16, i16}, i32 store --> two i16 stores.
-/// For pair of {i16, i8}, i32 store --> two i16 stores.
-/// For pair of {i8, i8}, i16 store --> two i8 stores.
-///
-/// We allow each target to determine specifically which kind of splitting is
-/// supported.
-///
-/// The store patterns are commonly seen from the simple code snippet below
-/// if only std::make_pair(...) is SROA-transformed before being inlined into hoo.
-/// void goo(const std::pair<int, float> &);
-/// hoo() {
-/// ...
-/// goo(std::make_pair(tmp, ftmp));
-/// ...
-/// }
-///
-SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
- if (OptLevel == CodeGenOpt::None)
- return SDValue();
-
- SDValue Val = ST->getValue();
- SDLoc DL(ST);
-
- // Match OR operand.
- if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
- return SDValue();
-
- // Match SHL operand and get Lower and Higher parts of Val.
- SDValue Op1 = Val.getOperand(0);
- SDValue Op2 = Val.getOperand(1);
- SDValue Lo, Hi;
- if (Op1.getOpcode() != ISD::SHL) {
- std::swap(Op1, Op2);
- if (Op1.getOpcode() != ISD::SHL)
- return SDValue();
- }
- Lo = Op2;
- Hi = Op1.getOperand(0);
- if (!Op1.hasOneUse())
- return SDValue();
-
- // Match shift amount to HalfValBitSize.
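// [Editor's aside, not part of the upstream diff: what splitMergedValStore
// ultimately stores, sketched with plain integers. The helper is
// hypothetical; the real transform stays in SelectionDAG nodes.]
#include <cstdint>
#include <utility>
// Splitting the merged i64 at HalfValBitSize == 32: Lo goes to addr and
// Hi to addr + 4, so the OR/SHL/ZEXT bundle above becomes dead.
std::pair<uint32_t, uint32_t> splitMergedVal(uint64_t Val) {
  uint32_t Lo = (uint32_t)Val;
  uint32_t Hi = (uint32_t)(Val >> 32);
  return {Lo, Hi};
}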
- unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
- ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
- if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
- return SDValue();
-
- // Lo and Hi are zero-extended from integers with size less than or equal
- // to 32 bits to i64.
- if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
- !Lo.getOperand(0).getValueType().isScalarInteger() ||
- Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
- Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
- !Hi.getOperand(0).getValueType().isScalarInteger() ||
- Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
- return SDValue();
-
- // Use the EVT of low and high parts before bitcast as the input
- // of the target query.
- EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
- ? Lo.getOperand(0).getValueType()
- : Lo.getValueType();
- EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
- ? Hi.getOperand(0).getValueType()
- : Hi.getValueType();
- if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
- return SDValue();
-
- // Start to split store.
- unsigned Alignment = ST->getAlignment();
- MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
- AAMDNodes AAInfo = ST->getAAInfo();
-
- // Change the sizes of Lo and Hi's value types to HalfValBitSize.
- EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
- Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
- Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
-
- SDValue Chain = ST->getChain();
- SDValue Ptr = ST->getBasePtr();
- // Lower value store.
- SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
- ST->getAlignment(), MMOFlags, AAInfo);
- Ptr =
- DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
- DAG.getConstant(HalfValBitSize / 8, DL, Ptr.getValueType()));
- // Higher value store.
- SDValue St1 =
- DAG.getStore(St0, DL, Hi, Ptr,
- ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
- Alignment / 2, MMOFlags, AAInfo);
- return St1;
-}
-
-/// Convert a disguised subvector insertion into a shuffle:
-/// insert_vector_elt V, (bitcast X from vector type), IdxC -->
-/// bitcast(shuffle (bitcast V), (extended X), Mask)
-/// Note: We do not use an insert_subvector node because that requires a legal
-/// subvector type.
-SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
- SDValue InsertVal = N->getOperand(1);
- if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
- !InsertVal.getOperand(0).getValueType().isVector())
- return SDValue();
-
- SDValue SubVec = InsertVal.getOperand(0);
- SDValue DestVec = N->getOperand(0);
- EVT SubVecVT = SubVec.getValueType();
- EVT VT = DestVec.getValueType();
- unsigned NumSrcElts = SubVecVT.getVectorNumElements();
- unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
- unsigned NumMaskVals = ExtendRatio * NumSrcElts;
-
- // Step 1: Create a shuffle mask that implements this insert operation. The
- // vector that we are inserting into will be operand 0 of the shuffle, so
- // those elements are just 'i'. The inserted subvector is in the first
- // positions of operand 1 of the shuffle.
Example: - // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7} - SmallVector<int, 16> Mask(NumMaskVals); - for (unsigned i = 0; i != NumMaskVals; ++i) { - if (i / NumSrcElts == InsIndex) - Mask[i] = (i % NumSrcElts) + NumMaskVals; - else - Mask[i] = i; - } - - // Bail out if the target can not handle the shuffle we want to create. - EVT SubVecEltVT = SubVecVT.getVectorElementType(); - EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals); - if (!TLI.isShuffleMaskLegal(Mask, ShufVT)) - return SDValue(); - - // Step 2: Create a wide vector from the inserted source vector by appending - // undefined elements. This is the same size as our destination vector. - SDLoc DL(N); - SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT)); - ConcatOps[0] = SubVec; - SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps); - - // Step 3: Shuffle in the padded subvector. - SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec); - SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask); - AddToWorklist(PaddedSubV.getNode()); - AddToWorklist(DestVecBC.getNode()); - AddToWorklist(Shuf.getNode()); - return DAG.getBitcast(VT, Shuf); -} - SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { SDValue InVec = N->getOperand(0); SDValue InVal = N->getOperand(1); SDValue EltNo = N->getOperand(2); - SDLoc DL(N); + SDLoc dl(N); // If the inserted element is an UNDEF, just use the input vector. - if (InVal.isUndef()) + if (InVal.getOpcode() == ISD::UNDEF) return InVec; EVT VT = InVec.getValueType(); - unsigned NumElts = VT.getVectorNumElements(); - - // Remove redundant insertions: - // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x - if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT && - InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1)) - return InVec; - auto *IndexC = dyn_cast<ConstantSDNode>(EltNo); - if (!IndexC) { - // If this is variable insert to undef vector, it might be better to splat: - // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... > - if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) { - SmallVector<SDValue, 8> Ops(NumElts, InVal); - return DAG.getBuildVector(VT, DL, Ops); - } + // If we can't generate a legal BUILD_VECTOR, exit + if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) return SDValue(); - } - // We must know which element is being inserted for folds below here. - unsigned Elt = IndexC->getZExtValue(); - if (SDValue Shuf = combineInsertEltToShuffle(N, Elt)) - return Shuf; + // Check that we know which element is being inserted + if (!isa<ConstantSDNode>(EltNo)) + return SDValue(); + unsigned Elt = cast<ConstantSDNode>(EltNo)->getZExtValue(); // Canonicalize insert_vector_elt dag nodes. // Example: @@ -15537,10 +12040,11 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { // do this only if indices are both constants and Idx1 < Idx0. if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse() && isa<ConstantSDNode>(InVec.getOperand(2))) { - unsigned OtherElt = InVec.getConstantOperandVal(2); + unsigned OtherElt = + cast<ConstantSDNode>(InVec.getOperand(2))->getZExtValue(); if (Elt < OtherElt) { // Swap nodes. 
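- // [Editor's aside: the swap is safe because the two inserts carry distinct
- // constant indices and therefore write disjoint lanes, e.g.
- //   (insert_vector_elt (insert_vector_elt V, a, 3), b, 1)
- //     --> (insert_vector_elt (insert_vector_elt V, b, 1), a, 3)
- // canonicalizing the smaller index innermost.]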
- SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, + SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(N), VT, InVec.getOperand(0), InVal, EltNo); AddToWorklist(NewOp.getNode()); return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()), @@ -15548,10 +12052,6 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { } } - // If we can't generate a legal BUILD_VECTOR, exit - if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) - return SDValue(); - // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially // be converted to a BUILD_VECTOR). Fill in the Ops vector with the // vector elements. @@ -15561,30 +12061,31 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) { Ops.append(InVec.getNode()->op_begin(), InVec.getNode()->op_end()); - } else if (InVec.isUndef()) { - Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType())); + } else if (InVec.getOpcode() == ISD::UNDEF) { + unsigned NElts = VT.getVectorNumElements(); + Ops.append(NElts, DAG.getUNDEF(InVal.getValueType())); } else { return SDValue(); } - assert(Ops.size() == NumElts && "Unexpected vector size"); // Insert the element if (Elt < Ops.size()) { // All the operands of BUILD_VECTOR must have the same type; // we enforce that here. EVT OpVT = Ops[0].getValueType(); - Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal; + if (InVal.getValueType() != OpVT) + InVal = OpVT.bitsGT(InVal.getValueType()) ? + DAG.getNode(ISD::ANY_EXTEND, dl, OpVT, InVal) : + DAG.getNode(ISD::TRUNCATE, dl, OpVT, InVal); + Ops[Elt] = InVal; } // Return the new vector - return DAG.getBuildVector(VT, DL, Ops); + return DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ops); } -SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, - SDValue EltNo, - LoadSDNode *OriginalLoad) { - assert(!OriginalLoad->isVolatile()); - +SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad( + SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad) { EVT ResultVT = EVE->getValueType(0); EVT VecEltVT = InVecVT.getVectorElementType(); unsigned Align = OriginalLoad->getAlignment(); @@ -15594,11 +12095,6 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT)) return SDValue(); - ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ? - ISD::NON_EXTLOAD : ISD::EXTLOAD; - if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT)) - return SDValue(); - Align = NewAlign; SDValue NewPtr = OriginalLoad->getBasePtr(); @@ -15635,20 +12131,21 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, VecEltVT) ? 
ISD::ZEXTLOAD : ISD::EXTLOAD; - Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT, - OriginalLoad->getChain(), NewPtr, MPI, VecEltVT, - Align, OriginalLoad->getMemOperand()->getFlags(), - OriginalLoad->getAAInfo()); + Load = DAG.getExtLoad( + ExtType, SDLoc(EVE), ResultVT, OriginalLoad->getChain(), NewPtr, MPI, + VecEltVT, OriginalLoad->isVolatile(), OriginalLoad->isNonTemporal(), + OriginalLoad->isInvariant(), Align, OriginalLoad->getAAInfo()); Chain = Load.getValue(1); } else { - Load = DAG.getLoad(VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, - MPI, Align, OriginalLoad->getMemOperand()->getFlags(), - OriginalLoad->getAAInfo()); + Load = DAG.getLoad( + VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, MPI, + OriginalLoad->isVolatile(), OriginalLoad->isNonTemporal(), + OriginalLoad->isInvariant(), Align, OriginalLoad->getAAInfo()); Chain = Load.getValue(1); if (ResultVT.bitsLT(VecEltVT)) Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load); else - Load = DAG.getBitcast(ResultVT, Load); + Load = DAG.getNode(ISD::BITCAST, SDLoc(EVE), ResultVT, Load); } WorklistRemover DeadNodes(*this); SDValue From[] = { SDValue(EVE, 0), SDValue(OriginalLoad, 1) }; @@ -15664,162 +12161,74 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, return SDValue(EVE, 0); } -/// Transform a vector binary operation into a scalar binary operation by moving -/// the math/logic after an extract element of a vector. -static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG, - bool LegalOperations) { - SDValue Vec = ExtElt->getOperand(0); - SDValue Index = ExtElt->getOperand(1); - auto *IndexC = dyn_cast<ConstantSDNode>(Index); - if (!IndexC || !ISD::isBinaryOp(Vec.getNode()) || !Vec.hasOneUse()) - return SDValue(); - - // Targets may want to avoid this to prevent an expensive register transfer. - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!TLI.shouldScalarizeBinop(Vec)) - return SDValue(); - - // Extracting an element of a vector constant is constant-folded, so this - // transform is just replacing a vector op with a scalar op while moving the - // extract. - SDValue Op0 = Vec.getOperand(0); - SDValue Op1 = Vec.getOperand(1); - if (isAnyConstantBuildVector(Op0, true) || - isAnyConstantBuildVector(Op1, true)) { - // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C' - // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC) - SDLoc DL(ExtElt); - EVT VT = ExtElt->getValueType(0); - SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index); - SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index); - return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1); - } - - return SDValue(); -} - SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { - SDValue VecOp = N->getOperand(0); - SDValue Index = N->getOperand(1); - EVT ScalarVT = N->getValueType(0); - EVT VecVT = VecOp.getValueType(); - if (VecOp.isUndef()) - return DAG.getUNDEF(ScalarVT); - - // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val - // - // This only really matters if the index is non-constant since other combines - // on the constant elements already work. - SDLoc DL(N); - if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT && - Index == VecOp.getOperand(2)) { - SDValue Elt = VecOp.getOperand(1); - return VecVT.isInteger() ? 
DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
- }
- // (vextract (scalar_to_vector val, 0) -> val
- if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+ SDValue InVec = N->getOperand(0);
+ EVT VT = InVec.getValueType();
+ EVT NVT = N->getValueType(0);
+
+ if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR) {
// Check if the result type doesn't match the inserted element type. A
// SCALAR_TO_VECTOR may truncate the inserted element and the
// EXTRACT_VECTOR_ELT may widen the extracted vector.
- SDValue InOp = VecOp.getOperand(0);
- if (InOp.getValueType() != ScalarVT) {
- assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
- return DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
+ SDValue InOp = InVec.getOperand(0);
+ if (InOp.getValueType() != NVT) {
+ assert(InOp.getValueType().isInteger() && NVT.isInteger());
+ return DAG.getSExtOrTrunc(InOp, SDLoc(InVec), NVT);
}
return InOp;
}

- // extract_vector_elt of out-of-bounds element -> UNDEF
- auto *IndexC = dyn_cast<ConstantSDNode>(Index);
- unsigned NumElts = VecVT.getVectorNumElements();
- if (IndexC && IndexC->getAPIntValue().uge(NumElts))
- return DAG.getUNDEF(ScalarVT);
+ SDValue EltNo = N->getOperand(1);
+ ConstantSDNode *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo);

// extract_vector_elt (build_vector x, y), 1 -> y
- if (IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR &&
- TLI.isTypeLegal(VecVT) &&
- (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
- SDValue Elt = VecOp.getOperand(IndexC->getZExtValue());
+ if (ConstEltNo &&
+ InVec.getOpcode() == ISD::BUILD_VECTOR &&
+ TLI.isTypeLegal(VT) &&
+ (InVec.hasOneUse() ||
+ TLI.aggressivelyPreferBuildVectorSources(VT))) {
+ SDValue Elt = InVec.getOperand(ConstEltNo->getZExtValue());
EVT InEltVT = Elt.getValueType();

// Sometimes build_vector's scalar input types do not match result type.
- if (ScalarVT == InEltVT)
+ if (NVT == InEltVT)
return Elt;

// TODO: It may be useful to truncate if free if the build_vector implicitly
// converts.
}

- // TODO: These transforms should not require the 'hasOneUse' restriction, but
- // there are regressions on multiple targets without it. We can end up with a
- // mess of scalar and vector code if we reduce only part of the DAG to scalar.
- if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
- VecOp.hasOneUse()) {
- // The vector index of the LSBs of the source depends on the endianness.
- bool IsLE = DAG.getDataLayout().isLittleEndian();
- unsigned ExtractIndex = IndexC->getZExtValue();
- // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
- unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
- SDValue BCSrc = VecOp.getOperand(0);
- if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
- return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);
-
- if (LegalTypes && BCSrc.getValueType().isInteger() &&
- BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
- // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
- // trunc i64 X to i32
- SDValue X = BCSrc.getOperand(0);
- assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
- "Extract element and scalar to vector can't change element type "
- "from FP to integer.");
- unsigned XBitWidth = X.getValueSizeInBits();
- unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
- BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
-
- // An extract element return value type can be wider than its vector
- // operand element type.
In that case, the high bits are undefined, so - // it's possible that we may need to extend rather than truncate. - if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) { - assert(XBitWidth % VecEltBitWidth == 0 && - "Scalar bitwidth must be a multiple of vector element bitwidth"); - return DAG.getAnyExtOrTrunc(X, DL, ScalarVT); - } - } - } - - if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations)) - return BO; - // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT. // We only perform this optimization before the op legalization phase because // we may introduce new vector instructions which are not backed by TD // patterns. For example on AVX, extracting elements from a wide vector // without using extract_subvector. However, if we can find an underlying // scalar value, then we can always use that. - if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) { - auto *Shuf = cast<ShuffleVectorSDNode>(VecOp); + if (ConstEltNo && InVec.getOpcode() == ISD::VECTOR_SHUFFLE) { + int NumElem = VT.getVectorNumElements(); + ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(InVec); // Find the new index to extract from. - int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue()); + int OrigElt = SVOp->getMaskElt(ConstEltNo->getZExtValue()); // Extracting an undef index is undef. if (OrigElt == -1) - return DAG.getUNDEF(ScalarVT); + return DAG.getUNDEF(NVT); // Select the right vector half to extract from. SDValue SVInVec; - if (OrigElt < (int)NumElts) { - SVInVec = VecOp.getOperand(0); + if (OrigElt < NumElem) { + SVInVec = InVec->getOperand(0); } else { - SVInVec = VecOp.getOperand(1); - OrigElt -= NumElts; + SVInVec = InVec->getOperand(1); + OrigElt -= NumElem; } if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) { SDValue InOp = SVInVec.getOperand(OrigElt); - if (InOp.getValueType() != ScalarVT) { - assert(InOp.getValueType().isInteger() && ScalarVT.isInteger()); - InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT); + if (InOp.getValueType() != NVT) { + assert(InOp.getValueType().isInteger() && NVT.isInteger()); + InOp = DAG.getSExtOrTrunc(InOp, SDLoc(SVInVec), NVT); } return InOp; @@ -15828,133 +12237,115 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // FIXME: We should handle recursing on other vector shuffles and // scalar_to_vector here as well. - if (!LegalOperations || - // FIXME: Should really be just isOperationLegalOrCustom. - TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) || - TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) { + if (!LegalOperations) { EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec, - DAG.getConstant(OrigElt, DL, IndexTy)); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), NVT, SVInVec, + DAG.getConstant(OrigElt, SDLoc(SVOp), IndexTy)); } } - // If only EXTRACT_VECTOR_ELT nodes use the source vector we can - // simplify it based on the (valid) extraction indices. - if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) { - return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT && - Use->getOperand(0) == VecOp && - isa<ConstantSDNode>(Use->getOperand(1)); - })) { - APInt DemandedElts = APInt::getNullValue(NumElts); - for (SDNode *Use : VecOp->uses()) { - auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1)); - if (CstElt->getAPIntValue().ult(NumElts)) - DemandedElts.setBit(CstElt->getZExtValue()); - } - if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) { - // We simplified the vector operand of this extract element. 
If this - // extract is not dead, visit it again so it is folded properly. - if (N->getOpcode() != ISD::DELETED_NODE) - AddToWorklist(N); - return SDValue(N, 0); - } - } + bool BCNumEltsChanged = false; + EVT ExtVT = VT.getVectorElementType(); + EVT LVT = ExtVT; - // Everything under here is trying to match an extract of a loaded value. // If the result of load has to be truncated, then it's not necessarily // profitable. - bool BCNumEltsChanged = false; - EVT ExtVT = VecVT.getVectorElementType(); - EVT LVT = ExtVT; - if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT)) + if (NVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, NVT)) return SDValue(); - if (VecOp.getOpcode() == ISD::BITCAST) { + if (InVec.getOpcode() == ISD::BITCAST) { // Don't duplicate a load with other uses. - if (!VecOp.hasOneUse()) + if (!InVec.hasOneUse()) return SDValue(); - EVT BCVT = VecOp.getOperand(0).getValueType(); + EVT BCVT = InVec.getOperand(0).getValueType(); if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType())) return SDValue(); - if (NumElts != BCVT.getVectorNumElements()) + if (VT.getVectorNumElements() != BCVT.getVectorNumElements()) BCNumEltsChanged = true; - VecOp = VecOp.getOperand(0); + InVec = InVec.getOperand(0); ExtVT = BCVT.getVectorElementType(); } - // extract (vector load $addr), i --> load $addr + i * size - if (!LegalOperations && !IndexC && VecOp.hasOneUse() && - ISD::isNormalLoad(VecOp.getNode()) && - !Index->hasPredecessor(VecOp.getNode())) { - auto *VecLoad = dyn_cast<LoadSDNode>(VecOp); - if (VecLoad && !VecLoad->isVolatile()) - return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad); + // (vextract (vN[if]M load $addr), i) -> ([if]M load $addr + i * size) + if (!LegalOperations && !ConstEltNo && InVec.hasOneUse() && + ISD::isNormalLoad(InVec.getNode()) && + !N->getOperand(1)->hasPredecessor(InVec.getNode())) { + SDValue Index = N->getOperand(1); + if (LoadSDNode *OrigLoad = dyn_cast<LoadSDNode>(InVec)) + return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, Index, + OrigLoad); } // Perform only after legalization to ensure build_vector / vector_shuffle // optimizations have already been done. - if (!LegalOperations || !IndexC) - return SDValue(); + if (!LegalOperations) return SDValue(); // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size) // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size) // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr) - int Elt = IndexC->getZExtValue(); - LoadSDNode *LN0 = nullptr; - if (ISD::isNormalLoad(VecOp.getNode())) { - LN0 = cast<LoadSDNode>(VecOp); - } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR && - VecOp.getOperand(0).getValueType() == ExtVT && - ISD::isNormalLoad(VecOp.getOperand(0).getNode())) { - // Don't duplicate a load with other uses. - if (!VecOp.hasOneUse()) - return SDValue(); - LN0 = cast<LoadSDNode>(VecOp.getOperand(0)); - } - if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) { - // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1) - // => - // (load $addr+1*size) - - // Don't duplicate a load with other uses. - if (!VecOp.hasOneUse()) - return SDValue(); + if (ConstEltNo) { + int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue(); - // If the bit convert changed the number of elements, it is unsafe - // to examine the mask. 
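- // [Editor's aside: e.g. a v8i16 operand bitcast to v4i32 halves the
- // element count, so a mask index written against one type does not name a
- // whole element of the other; the code below therefore gives up when
- // BCNumEltsChanged is set.]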
- if (BCNumEltsChanged) - return SDValue(); + LoadSDNode *LN0 = nullptr; + const ShuffleVectorSDNode *SVN = nullptr; + if (ISD::isNormalLoad(InVec.getNode())) { + LN0 = cast<LoadSDNode>(InVec); + } else if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR && + InVec.getOperand(0).getValueType() == ExtVT && + ISD::isNormalLoad(InVec.getOperand(0).getNode())) { + // Don't duplicate a load with other uses. + if (!InVec.hasOneUse()) + return SDValue(); - // Select the input vector, guarding against out of range extract vector. - int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt); - VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1); + LN0 = cast<LoadSDNode>(InVec.getOperand(0)); + } else if ((SVN = dyn_cast<ShuffleVectorSDNode>(InVec))) { + // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1) + // => + // (load $addr+1*size) - if (VecOp.getOpcode() == ISD::BITCAST) { // Don't duplicate a load with other uses. - if (!VecOp.hasOneUse()) + if (!InVec.hasOneUse()) return SDValue(); - VecOp = VecOp.getOperand(0); - } - if (ISD::isNormalLoad(VecOp.getNode())) { - LN0 = cast<LoadSDNode>(VecOp); - Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts; - Index = DAG.getConstant(Elt, DL, Index.getValueType()); + // If the bit convert changed the number of elements, it is unsafe + // to examine the mask. + if (BCNumEltsChanged) + return SDValue(); + + // Select the input vector, guarding against out of range extract vector. + unsigned NumElems = VT.getVectorNumElements(); + int Idx = (Elt > (int)NumElems) ? -1 : SVN->getMaskElt(Elt); + InVec = (Idx < (int)NumElems) ? InVec.getOperand(0) : InVec.getOperand(1); + + if (InVec.getOpcode() == ISD::BITCAST) { + // Don't duplicate a load with other uses. + if (!InVec.hasOneUse()) + return SDValue(); + + InVec = InVec.getOperand(0); + } + if (ISD::isNormalLoad(InVec.getNode())) { + LN0 = cast<LoadSDNode>(InVec); + Elt = (Idx < (int)NumElems) ? Idx : Idx - (int)NumElems; + EltNo = DAG.getConstant(Elt, SDLoc(EltNo), EltNo.getValueType()); + } } - } - // Make sure we found a non-volatile load and the extractelement is - // the only use. - if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile()) - return SDValue(); + // Make sure we found a non-volatile load and the extractelement is + // the only use. + if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile()) + return SDValue(); - // If Idx was -1 above, Elt is going to be -1, so just return undef. - if (Elt == -1) - return DAG.getUNDEF(LVT); + // If Idx was -1 above, Elt is going to be -1, so just return undef. + if (Elt == -1) + return DAG.getUNDEF(LVT); - return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0); + return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, EltNo, LN0); + } + + return SDValue(); } // Simplify (build_vec (ext )) to (bitcast (build_vec )) @@ -15969,7 +12360,7 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { return SDValue(); unsigned NumInScalars = N->getNumOperands(); - SDLoc DL(N); + SDLoc dl(N); EVT VT = N->getValueType(0); // Check to see if this is a BUILD_VECTOR of a bunch of values @@ -15983,7 +12374,7 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { for (unsigned i = 0; i != NumInScalars; ++i) { SDValue In = N->getOperand(i); // Ignore undef inputs. 
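    // An undef element places no constraint on the common source type; it is
    // skipped here and filled in with an undef/zero filler in the rebuild
    // loop below.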
- if (In.isUndef()) continue; + if (In.getOpcode() == ISD::UNDEF) continue; bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND; bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND; @@ -16028,7 +12419,7 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits(); assert(ElemRatio > 1 && "Invalid element size ratio"); SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType): - DAG.getConstant(0, DL, SourceType); + DAG.getConstant(0, SDLoc(N), SourceType); unsigned NewBVElems = ElemRatio * VT.getVectorNumElements(); SmallVector<SDValue, 8> Ops(NewBVElems, Filler); @@ -16038,9 +12429,9 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { SDValue Cast = N->getOperand(i); assert((Cast.getOpcode() == ISD::ANY_EXTEND || Cast.getOpcode() == ISD::ZERO_EXTEND || - Cast.isUndef()) && "Invalid cast opcode"); + Cast.getOpcode() == ISD::UNDEF) && "Invalid cast opcode"); SDValue In; - if (Cast.isUndef()) + if (Cast.getOpcode() == ISD::UNDEF) In = DAG.getUNDEF(SourceType); else In = Cast->getOperand(0); @@ -16056,550 +12447,259 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { assert(VecVT.getSizeInBits() == VT.getSizeInBits() && "Invalid vector size"); // Check if the new vector type is legal. - if (!isTypeLegal(VecVT) || - (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) && - TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))) - return SDValue(); + if (!isTypeLegal(VecVT)) return SDValue(); // Make the new BUILD_VECTOR. - SDValue BV = DAG.getBuildVector(VecVT, DL, Ops); + SDValue BV = DAG.getNode(ISD::BUILD_VECTOR, dl, VecVT, Ops); // The new BUILD_VECTOR node has the potential to be further optimized. AddToWorklist(BV.getNode()); // Bitcast to the desired type. - return DAG.getBitcast(VT, BV); + return DAG.getNode(ISD::BITCAST, dl, VT, BV); } -SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N, - ArrayRef<int> VectorMask, - SDValue VecIn1, SDValue VecIn2, - unsigned LeftIdx) { - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - SDValue ZeroIdx = DAG.getConstant(0, DL, IdxTy); - +SDValue DAGCombiner::reduceBuildVecConvertToConvertBuildVec(SDNode *N) { EVT VT = N->getValueType(0); - EVT InVT1 = VecIn1.getValueType(); - EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1; - - unsigned Vec2Offset = 0; - unsigned NumElems = VT.getVectorNumElements(); - unsigned ShuffleNumElems = NumElems; - - // In case both the input vectors are extracted from same base - // vector we do not need extra addend (Vec2Offset) while - // computing shuffle mask. - if (!VecIn2 || !(VecIn1.getOpcode() == ISD::EXTRACT_SUBVECTOR) || - !(VecIn2.getOpcode() == ISD::EXTRACT_SUBVECTOR) || - !(VecIn1.getOperand(0) == VecIn2.getOperand(0))) - Vec2Offset = InVT1.getVectorNumElements(); - - // We can't generate a shuffle node with mismatched input and output types. - // Try to make the types match the type of the output. - if (InVT1 != VT || InVT2 != VT) { - if ((VT.getSizeInBits() % InVT1.getSizeInBits() == 0) && InVT1 == InVT2) { - // If the output vector length is a multiple of both input lengths, - // we can concatenate them and pad the rest with undefs. - unsigned NumConcats = VT.getSizeInBits() / InVT1.getSizeInBits(); - assert(NumConcats >= 2 && "Concat needs at least two inputs!"); - SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1)); - ConcatOps[0] = VecIn1; - ConcatOps[1] = VecIn2 ? 
VecIn2 : DAG.getUNDEF(InVT1); - VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps); - VecIn2 = SDValue(); - } else if (InVT1.getSizeInBits() == VT.getSizeInBits() * 2) { - if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems)) - return SDValue(); - if (!VecIn2.getNode()) { - // If we only have one input vector, and it's twice the size of the - // output, split it in two. - VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, - DAG.getConstant(NumElems, DL, IdxTy)); - VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx); - // Since we now have shorter input vectors, adjust the offset of the - // second vector's start. - Vec2Offset = NumElems; - } else if (InVT2.getSizeInBits() <= InVT1.getSizeInBits()) { - // VecIn1 is wider than the output, and we have another, possibly - // smaller input. Pad the smaller input with undefs, shuffle at the - // input vector width, and extract the output. - // The shuffle type is different than VT, so check legality again. - if (LegalOperations && - !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1)) - return SDValue(); + unsigned NumInScalars = N->getNumOperands(); + SDLoc dl(N); - // Legalizing INSERT_SUBVECTOR is tricky - you basically have to - // lower it back into a BUILD_VECTOR. So if the inserted type is - // illegal, don't even try. - if (InVT1 != InVT2) { - if (!TLI.isTypeLegal(InVT2)) - return SDValue(); - VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1, - DAG.getUNDEF(InVT1), VecIn2, ZeroIdx); - } - ShuffleNumElems = NumElems * 2; - } else { - // Both VecIn1 and VecIn2 are wider than the output, and VecIn2 is wider - // than VecIn1. We can't handle this for now - this case will disappear - // when we start sorting the vectors by type. - return SDValue(); - } - } else if (InVT2.getSizeInBits() * 2 == VT.getSizeInBits() && - InVT1.getSizeInBits() == VT.getSizeInBits()) { - SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2)); - ConcatOps[0] = VecIn2; - VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps); - } else { - // TODO: Support cases where the length mismatch isn't exactly by a - // factor of 2. - // TODO: Move this check upwards, so that if we have bad type - // mismatches, we don't create any DAG nodes. - return SDValue(); - } - } + EVT SrcVT = MVT::Other; + unsigned Opcode = ISD::DELETED_NODE; + unsigned NumDefs = 0; - // Initialize mask to undef. - SmallVector<int, 8> Mask(ShuffleNumElems, -1); + for (unsigned i = 0; i != NumInScalars; ++i) { + SDValue In = N->getOperand(i); + unsigned Opc = In.getOpcode(); - // Only need to run up to the number of elements actually used, not the - // total number of elements in the shuffle - if we are shuffling a wider - // vector, the high lanes should be set to undef. - for (unsigned i = 0; i != NumElems; ++i) { - if (VectorMask[i] <= 0) + if (Opc == ISD::UNDEF) continue; - unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1); - if (VectorMask[i] == (int)LeftIdx) { - Mask[i] = ExtIndex; - } else if (VectorMask[i] == (int)LeftIdx + 1) { - Mask[i] = Vec2Offset + ExtIndex; + // If all scalar values are floats and converted from integers. + if (Opcode == ISD::DELETED_NODE && + (Opc == ISD::UINT_TO_FP || Opc == ISD::SINT_TO_FP)) { + Opcode = Opc; } - } - - // The type the input vectors may have changed above. - InVT1 = VecIn1.getValueType(); - - // If we already have a VecIn2, it should have the same type as VecIn1. - // If we don't, get an undef/zero vector of the appropriate type. - VecIn2 = VecIn2.getNode() ? 
VecIn2 : DAG.getUNDEF(InVT1); - assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type."); - SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask); - if (ShuffleNumElems > NumElems) - Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx); - - return Shuffle; -} + if (Opc != Opcode) + return SDValue(); -static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) { - assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector"); + EVT InVT = In.getOperand(0).getValueType(); - // First, determine where the build vector is not undef. - // TODO: We could extend this to handle zero elements as well as undefs. - int NumBVOps = BV->getNumOperands(); - int ZextElt = -1; - for (int i = 0; i != NumBVOps; ++i) { - SDValue Op = BV->getOperand(i); - if (Op.isUndef()) - continue; - if (ZextElt == -1) - ZextElt = i; - else + // If all scalar values are typed differently, bail out. It's chosen to + // simplify BUILD_VECTOR of integer types. + if (SrcVT == MVT::Other) + SrcVT = InVT; + if (SrcVT != InVT) return SDValue(); + NumDefs++; } - // Bail out if there's no non-undef element. - if (ZextElt == -1) + + // If the vector has just one element defined, it's not worth to fold it into + // a vectorized one. + if (NumDefs < 2) return SDValue(); - // The build vector contains some number of undef elements and exactly - // one other element. That other element must be a zero-extended scalar - // extracted from a vector at a constant index to turn this into a shuffle. - // Also, require that the build vector does not implicitly truncate/extend - // its elements. - // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND. - EVT VT = BV->getValueType(0); - SDValue Zext = BV->getOperand(ZextElt); - if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() || - Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT || - !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) || - Zext.getValueSizeInBits() != VT.getScalarSizeInBits()) + assert((Opcode == ISD::UINT_TO_FP || Opcode == ISD::SINT_TO_FP) + && "Should only handle conversion from integer to float."); + assert(SrcVT != MVT::Other && "Cannot determine source type!"); + + EVT NVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumInScalars); + + if (!TLI.isOperationLegalOrCustom(Opcode, NVT)) return SDValue(); - // The zero-extend must be a multiple of the source size, and we must be - // building a vector of the same size as the source of the extract element. - SDValue Extract = Zext.getOperand(0); - unsigned DestSize = Zext.getValueSizeInBits(); - unsigned SrcSize = Extract.getValueSizeInBits(); - if (DestSize % SrcSize != 0 || - Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits()) + // Just because the floating-point vector type is legal does not necessarily + // mean that the corresponding integer vector type is. + if (!isTypeLegal(NVT)) return SDValue(); - // Create a shuffle mask that will combine the extracted element with zeros - // and undefs. - int ZextRatio = DestSize / SrcSize; - int NumMaskElts = NumBVOps * ZextRatio; - SmallVector<int, 32> ShufMask(NumMaskElts, -1); - for (int i = 0; i != NumMaskElts; ++i) { - if (i / ZextRatio == ZextElt) { - // The low bits of the (potentially translated) extracted element map to - // the source vector. The high bits map to zero. We will use a zero vector - // as the 2nd source operand of the shuffle, so use the 1st element of - // that vector (mask value is number-of-elements) for the high bits. 
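-      // e.g. for (v2i64 build_vector (zext X[C] to i64), undef) over v4i32 X,
-      // ZextRatio is 2, so the mask is <C, 4, -1, -1>: element C of X in the
-      // low lane, element 0 of the zero vector in the high lane.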
- if (i % ZextRatio == 0) - ShufMask[i] = Extract.getConstantOperandVal(1); - else - ShufMask[i] = NumMaskElts; - } + SmallVector<SDValue, 8> Opnds; + for (unsigned i = 0; i != NumInScalars; ++i) { + SDValue In = N->getOperand(i); - // Undef elements of the build vector remain undef because we initialize - // the shuffle mask with -1. + if (In.getOpcode() == ISD::UNDEF) + Opnds.push_back(DAG.getUNDEF(SrcVT)); + else + Opnds.push_back(In.getOperand(0)); } + SDValue BV = DAG.getNode(ISD::BUILD_VECTOR, dl, NVT, Opnds); + AddToWorklist(BV.getNode()); - // Turn this into a shuffle with zero if that's legal. - EVT VecVT = Extract.getOperand(0).getValueType(); - if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(ShufMask, VecVT)) - return SDValue(); - - // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... --> - // bitcast (shuffle V, ZeroVec, VectorMask) - SDLoc DL(BV); - SDValue ZeroVec = DAG.getConstant(0, DL, VecVT); - SDValue Shuf = DAG.getVectorShuffle(VecVT, DL, Extract.getOperand(0), ZeroVec, - ShufMask); - return DAG.getBitcast(VT, Shuf); + return DAG.getNode(Opcode, dl, VT, BV); } -// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT -// operations. If the types of the vectors we're extracting from allow it, -// turn this into a vector_shuffle node. -SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) { - SDLoc DL(N); +SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { + unsigned NumInScalars = N->getNumOperands(); + SDLoc dl(N); EVT VT = N->getValueType(0); + // A vector built entirely of undefs is undef. + if (ISD::allOperandsUndef(N)) + return DAG.getUNDEF(VT); + + if (SDValue V = reduceBuildVecExtToExtBuildVec(N)) + return V; + + if (SDValue V = reduceBuildVecConvertToConvertBuildVec(N)) + return V; + + // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT + // operations. If so, and if the EXTRACT_VECTOR_ELT vector inputs come from + // at most two distinct vectors, turn this into a shuffle node. + // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes. if (!isTypeLegal(VT)) return SDValue(); - if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG)) - return V; - // May only combine to shuffle after legalize if shuffle is legal. if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT)) return SDValue(); + SDValue VecIn1, VecIn2; bool UsesZeroVector = false; - unsigned NumElems = N->getNumOperands(); - - // Record, for each element of the newly built vector, which input vector - // that element comes from. -1 stands for undef, 0 for the zero vector, - // and positive values for the input vectors. - // VectorMask maps each element to its vector number, and VecIn maps vector - // numbers to their initial SDValues. - - SmallVector<int, 8> VectorMask(NumElems, -1); - SmallVector<SDValue, 8> VecIn; - VecIn.push_back(SDValue()); - - for (unsigned i = 0; i != NumElems; ++i) { + for (unsigned i = 0; i != NumInScalars; ++i) { SDValue Op = N->getOperand(i); + // Ignore undef inputs. + if (Op.getOpcode() == ISD::UNDEF) continue; - if (Op.isUndef()) - continue; - - // See if we can use a blend with a zero vector. - // TODO: Should we generalize this to a blend with an arbitrary constant - // vector? - if (isNullConstant(Op) || isNullFPConstant(Op)) { + // See if we can combine this build_vector into a blend with a zero vector. + if (!VecIn2.getNode() && (isNullConstant(Op) || isNullFPConstant(Op))) { UsesZeroVector = true; - VectorMask[i] = 0; continue; } - // Not an undef or zero. 
If the input is something other than an - // EXTRACT_VECTOR_ELT with an in-range constant index, bail out. + // If this input is something other than a EXTRACT_VECTOR_ELT with a + // constant index, bail out. if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT || - !isa<ConstantSDNode>(Op.getOperand(1))) - return SDValue(); - SDValue ExtractedFromVec = Op.getOperand(0); - - APInt ExtractIdx = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue(); - if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements())) - return SDValue(); - - // All inputs must have the same element type as the output. - if (VT.getVectorElementType() != - ExtractedFromVec.getValueType().getVectorElementType()) - return SDValue(); + !isa<ConstantSDNode>(Op.getOperand(1))) { + VecIn1 = VecIn2 = SDValue(nullptr, 0); + break; + } - // Have we seen this input vector before? - // The vectors are expected to be tiny (usually 1 or 2 elements), so using - // a map back from SDValues to numbers isn't worth it. - unsigned Idx = std::distance( - VecIn.begin(), std::find(VecIn.begin(), VecIn.end(), ExtractedFromVec)); - if (Idx == VecIn.size()) - VecIn.push_back(ExtractedFromVec); + // We allow up to two distinct input vectors. + SDValue ExtractedFromVec = Op.getOperand(0); + if (ExtractedFromVec == VecIn1 || ExtractedFromVec == VecIn2) + continue; - VectorMask[i] = Idx; + if (!VecIn1.getNode()) { + VecIn1 = ExtractedFromVec; + } else if (!VecIn2.getNode() && !UsesZeroVector) { + VecIn2 = ExtractedFromVec; + } else { + // Too many inputs. + VecIn1 = VecIn2 = SDValue(nullptr, 0); + break; + } } - // If we didn't find at least one input vector, bail out. - if (VecIn.size() < 2) - return SDValue(); - - // If all the Operands of BUILD_VECTOR extract from same - // vector, then split the vector efficiently based on the maximum - // vector access index and adjust the VectorMask and - // VecIn accordingly. - if (VecIn.size() == 2) { - unsigned MaxIndex = 0; - unsigned NearestPow2 = 0; - SDValue Vec = VecIn.back(); - EVT InVT = Vec.getValueType(); - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - SmallVector<unsigned, 8> IndexVec(NumElems, 0); - - for (unsigned i = 0; i < NumElems; i++) { - if (VectorMask[i] <= 0) + // If everything is good, we can make a shuffle operation. + if (VecIn1.getNode()) { + unsigned InNumElements = VecIn1.getValueType().getVectorNumElements(); + SmallVector<int, 8> Mask; + for (unsigned i = 0; i != NumInScalars; ++i) { + unsigned Opcode = N->getOperand(i).getOpcode(); + if (Opcode == ISD::UNDEF) { + Mask.push_back(-1); continue; - unsigned Index = N->getOperand(i).getConstantOperandVal(1); - IndexVec[i] = Index; - MaxIndex = std::max(MaxIndex, Index); - } - - NearestPow2 = PowerOf2Ceil(MaxIndex); - if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 && - NumElems * 2 < NearestPow2) { - unsigned SplitSize = NearestPow2 / 2; - EVT SplitVT = EVT::getVectorVT(*DAG.getContext(), - InVT.getVectorElementType(), SplitSize); - if (TLI.isTypeLegal(SplitVT)) { - SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec, - DAG.getConstant(SplitSize, DL, IdxTy)); - SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec, - DAG.getConstant(0, DL, IdxTy)); - VecIn.pop_back(); - VecIn.push_back(VecIn1); - VecIn.push_back(VecIn2); - - for (unsigned i = 0; i < NumElems; i++) { - if (VectorMask[i] <= 0) - continue; - VectorMask[i] = (IndexVec[i] < SplitSize) ? 
1 : 2; - } } - } - } - - // TODO: We want to sort the vectors by descending length, so that adjacent - // pairs have similar length, and the longer vector is always first in the - // pair. - - // TODO: Should this fire if some of the input vectors has illegal type (like - // it does now), or should we let legalization run its course first? - - // Shuffle phase: - // Take pairs of vectors, and shuffle them so that the result has elements - // from these vectors in the correct places. - // For example, given: - // t10: i32 = extract_vector_elt t1, Constant:i64<0> - // t11: i32 = extract_vector_elt t2, Constant:i64<0> - // t12: i32 = extract_vector_elt t3, Constant:i64<0> - // t13: i32 = extract_vector_elt t1, Constant:i64<1> - // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13 - // We will generate: - // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2 - // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef - SmallVector<SDValue, 4> Shuffles; - for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) { - unsigned LeftIdx = 2 * In + 1; - SDValue VecLeft = VecIn[LeftIdx]; - SDValue VecRight = - (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue(); - - if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft, - VecRight, LeftIdx)) - Shuffles.push_back(Shuffle); - else - return SDValue(); - } - // If we need the zero vector as an "ingredient" in the blend tree, add it - // to the list of shuffles. - if (UsesZeroVector) - Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT) - : DAG.getConstantFP(0.0, DL, VT)); - - // If we only have one shuffle, we're done. - if (Shuffles.size() == 1) - return Shuffles[0]; + // Operands can also be zero. + if (Opcode != ISD::EXTRACT_VECTOR_ELT) { + assert(UsesZeroVector && + (Opcode == ISD::Constant || Opcode == ISD::ConstantFP) && + "Unexpected node found!"); + Mask.push_back(NumInScalars+i); + continue; + } - // Update the vector mask to point to the post-shuffle vectors. - for (int &Vec : VectorMask) - if (Vec == 0) - Vec = Shuffles.size() - 1; - else - Vec = (Vec - 1) / 2; - - // More than one shuffle. Generate a binary tree of blends, e.g. if from - // the previous step we got the set of shuffles t10, t11, t12, t13, we will - // generate: - // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2 - // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4 - // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6 - // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8 - // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11 - // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13 - // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21 - - // Make sure the initial size of the shuffle list is even. - if (Shuffles.size() % 2) - Shuffles.push_back(DAG.getUNDEF(VT)); - - for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) { - if (CurSize % 2) { - Shuffles[CurSize] = DAG.getUNDEF(VT); - CurSize++; - } - for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) { - int Left = 2 * In; - int Right = 2 * In + 1; - SmallVector<int, 8> Mask(NumElems, -1); - for (unsigned i = 0; i != NumElems; ++i) { - if (VectorMask[i] == Left) { - Mask[i] = i; - VectorMask[i] = In; - } else if (VectorMask[i] == Right) { - Mask[i] = i + NumElems; - VectorMask[i] = In; - } + // If extracting from the first vector, just use the index directly. 
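+      // A mask index in [0, InNumElements) selects a lane of VecIn1; an index
+      // in [InNumElements, 2 * InNumElements) selects a lane of VecIn2.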
+ SDValue Extract = N->getOperand(i); + SDValue ExtVal = Extract.getOperand(1); + unsigned ExtIndex = cast<ConstantSDNode>(ExtVal)->getZExtValue(); + if (Extract.getOperand(0) == VecIn1) { + Mask.push_back(ExtIndex); + continue; } - Shuffles[In] = - DAG.getVectorShuffle(VT, DL, Shuffles[Left], Shuffles[Right], Mask); + // Otherwise, use InIdx + InputVecSize + Mask.push_back(InNumElements + ExtIndex); } - } - return Shuffles[0]; -} -// Try to turn a build vector of zero extends of extract vector elts into a -// a vector zero extend and possibly an extract subvector. -// TODO: Support sign extend or any extend? -// TODO: Allow undef elements? -// TODO: Don't require the extracts to start at element 0. -SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) { - if (LegalOperations) - return SDValue(); - - EVT VT = N->getValueType(0); - - SDValue Op0 = N->getOperand(0); - auto checkElem = [&](SDValue Op) -> int64_t { - if (Op.getOpcode() == ISD::ZERO_EXTEND && - Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT && - Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0)) - if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1))) - return C->getZExtValue(); - return -1; - }; - - // Make sure the first element matches - // (zext (extract_vector_elt X, C)) - int64_t Offset = checkElem(Op0); - if (Offset < 0) - return SDValue(); - - unsigned NumElems = N->getNumOperands(); - SDValue In = Op0.getOperand(0).getOperand(0); - EVT InSVT = In.getValueType().getScalarType(); - EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems); - - // Don't create an illegal input type after type legalization. - if (LegalTypes && !TLI.isTypeLegal(InVT)) - return SDValue(); - - // Ensure all the elements come from the same vector and are adjacent. - for (unsigned i = 1; i != NumElems; ++i) { - if ((Offset + i) != checkElem(N->getOperand(i))) + // Avoid introducing illegal shuffles with zero. + if (UsesZeroVector && !TLI.isVectorClearMaskLegal(Mask, VT)) return SDValue(); - } - SDLoc DL(N); - In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In, - Op0.getOperand(0).getOperand(1)); - return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, In); -} + // We can't generate a shuffle node with mismatched input and output types. + // Attempt to transform a single input vector to the correct type. + if ((VT != VecIn1.getValueType())) { + // If the input vector type has a different base type to the output + // vector type, bail out. + EVT VTElemType = VT.getVectorElementType(); + if ((VecIn1.getValueType().getVectorElementType() != VTElemType) || + (VecIn2.getNode() && + (VecIn2.getValueType().getVectorElementType() != VTElemType))) + return SDValue(); -SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { - EVT VT = N->getValueType(0); + // If the input vector is too small, widen it. + // We only support widening of vectors which are half the size of the + // output registers. For example XMM->YMM widening on X86 with AVX. + EVT VecInT = VecIn1.getValueType(); + if (VecInT.getSizeInBits() * 2 == VT.getSizeInBits()) { + // If we only have one small input, widen it by adding undef values. + if (!VecIn2.getNode()) + VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, VecIn1, + DAG.getUNDEF(VecIn1.getValueType())); + else if (VecIn1.getValueType() == VecIn2.getValueType()) { + // If we have two small inputs of the same type, try to concat them. 
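+          // The mask already addresses VecIn2's lanes at offset
+          // InNumElements, which is exactly where they land after the
+          // concatenation.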
+ VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, VecIn1, VecIn2); + VecIn2 = SDValue(nullptr, 0); + } else + return SDValue(); + } else if (VecInT.getSizeInBits() == VT.getSizeInBits() * 2) { + // If the input vector is too large, try to split it. + // We don't support having two input vectors that are too large. + // If the zero vector was used, we can not split the vector, + // since we'd need 3 inputs. + if (UsesZeroVector || VecIn2.getNode()) + return SDValue(); - // A vector built entirely of undefs is undef. - if (ISD::allOperandsUndef(N)) - return DAG.getUNDEF(VT); + if (!TLI.isExtractSubvectorCheap(VT, VT.getVectorNumElements())) + return SDValue(); - // If this is a splat of a bitcast from another vector, change to a - // concat_vector. - // For example: - // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) -> - // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X)))) - // - // If X is a build_vector itself, the concat can become a larger build_vector. - // TODO: Maybe this is useful for non-splat too? - if (!LegalOperations) { - if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) { - Splat = peekThroughBitcasts(Splat); - EVT SrcVT = Splat.getValueType(); - if (SrcVT.isVector()) { - unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements(); - EVT NewVT = EVT::getVectorVT(*DAG.getContext(), - SrcVT.getVectorElementType(), NumElts); - if (!LegalTypes || TLI.isTypeLegal(NewVT)) { - SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat); - SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), - NewVT, Ops); - return DAG.getBitcast(VT, Concat); - } - } + // Try to replace VecIn1 with two extract_subvectors + // No need to update the masks, they should still be correct. + VecIn2 = DAG.getNode( + ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1, + DAG.getConstant(VT.getVectorNumElements(), dl, + TLI.getVectorIdxTy(DAG.getDataLayout()))); + VecIn1 = DAG.getNode( + ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1, + DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout()))); + } else + return SDValue(); } - } - // Check if we can express BUILD VECTOR via subvector extract. - if (!LegalTypes && (N->getNumOperands() > 1)) { - SDValue Op0 = N->getOperand(0); - auto checkElem = [&](SDValue Op) -> uint64_t { - if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) && - (Op0.getOperand(0) == Op.getOperand(0))) - if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1))) - return CNode->getZExtValue(); - return -1; - }; + if (UsesZeroVector) + VecIn2 = VT.isInteger() ? DAG.getConstant(0, dl, VT) : + DAG.getConstantFP(0.0, dl, VT); + else + // If VecIn2 is unused then change it to undef. + VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(VT); - int Offset = checkElem(Op0); - for (unsigned i = 0; i < N->getNumOperands(); ++i) { - if (Offset + i != checkElem(N->getOperand(i))) { - Offset = -1; - break; - } - } + // Check that we were able to transform all incoming values to the same + // type. + if (VecIn2.getValueType() != VecIn1.getValueType() || + VecIn1.getValueType() != VT) + return SDValue(); - if ((Offset == 0) && - (Op0.getOperand(0).getValueType() == N->getValueType(0))) - return Op0.getOperand(0); - if ((Offset != -1) && - ((Offset % N->getValueType(0).getVectorNumElements()) == - 0)) // IDX must be multiple of output size. - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0), - Op0.getOperand(0), Op0.getOperand(1)); + // Return the new VECTOR_SHUFFLE node. 
+ SDValue Ops[2]; + Ops[0] = VecIn1; + Ops[1] = VecIn2; + return DAG.getVectorShuffle(VT, dl, Ops[0], Ops[1], &Mask[0]); } - if (SDValue V = convertBuildVecZextToZext(N)) - return V; - - if (SDValue V = reduceBuildVecExtToExtBuildVec(N)) - return V; - - if (SDValue V = reduceBuildVecToShuffle(N)) - return V; - return SDValue(); } @@ -16651,17 +12751,18 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) { for (SDValue &Op : Ops) { if (Op.getValueType() == SVT) continue; - if (Op.isUndef()) + if (Op.getOpcode() == ISD::UNDEF) Op = ScalarUndef; else - Op = DAG.getBitcast(SVT, Op); + Op = DAG.getNode(ISD::BITCAST, DL, SVT, Op); } } } EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT, VT.getSizeInBits() / SVT.getSizeInBits()); - return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops)); + return DAG.getNode(ISD::BITCAST, DL, VT, + DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Ops)); } // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR @@ -16678,10 +12779,12 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { SmallVector<int, 8> Mask; for (SDValue Op : N->ops()) { - Op = peekThroughBitcasts(Op); + // Peek through any bitcast. + while (Op.getOpcode() == ISD::BITCAST) + Op = Op.getOperand(0); // UNDEF nodes convert to UNDEF shuffle mask values. - if (Op.isUndef()) { + if (Op.getOpcode() == ISD::UNDEF) { Mask.append((unsigned)NumOpElts, -1); continue; } @@ -16695,17 +12798,20 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { // We want the EVT of the original extraction to correctly scale the // extraction index. EVT ExtVT = ExtVec.getValueType(); - ExtVec = peekThroughBitcasts(ExtVec); + + // Peek through any bitcast. + while (ExtVec.getOpcode() == ISD::BITCAST) + ExtVec = ExtVec.getOperand(0); // UNDEF nodes convert to UNDEF shuffle mask values. - if (ExtVec.isUndef()) { + if (ExtVec.getOpcode() == ISD::UNDEF) { Mask.append((unsigned)NumOpElts, -1); continue; } if (!isa<ConstantSDNode>(Op.getOperand(1))) return SDValue(); - int ExtIdx = Op.getConstantOperandVal(1); + int ExtIdx = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue(); // Ensure that we are extracting a subvector from a vector the same // size as the result. @@ -16722,11 +12828,11 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { return SDValue(); // At most we can reference 2 inputs in the final shuffle. - if (SV0.isUndef() || SV0 == ExtVec) { + if (SV0.getOpcode() == ISD::UNDEF || SV0 == ExtVec) { SV0 = ExtVec; for (int i = 0; i != NumOpElts; ++i) Mask.push_back(i + ExtIdx); - } else if (SV1.isUndef() || SV1 == ExtVec) { + } else if (SV1.getOpcode() == ISD::UNDEF || SV1 == ExtVec) { SV1 = ExtVec; for (int i = 0; i != NumOpElts; ++i) Mask.push_back(i + ExtIdx + NumElts); @@ -16754,24 +12860,16 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // Optimize concat_vectors where all but the first of the vectors are undef. if (std::all_of(std::next(N->op_begin()), N->op_end(), [](const SDValue &Op) { - return Op.isUndef(); + return Op.getOpcode() == ISD::UNDEF; })) { SDValue In = N->getOperand(0); assert(In.getValueType().isVector() && "Must concat vectors"); - SDValue Scalar = peekThroughOneUseBitcasts(In); + // Transform: concat_vectors(scalar, undef) -> scalar_to_vector(sclr). 
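+    // e.g. concat_vectors ((v2i32 bitcast (i64 X)), undef) becomes
+    // (v2i64 scalar_to_vector (i64 X)), bitcast back to the v4i32 result.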
+ if (In->getOpcode() == ISD::BITCAST && + !In->getOperand(0)->getValueType(0).isVector()) { + SDValue Scalar = In->getOperand(0); - // concat_vectors(scalar_to_vector(scalar), undef) -> - // scalar_to_vector(scalar) - if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR && - Scalar.hasOneUse()) { - EVT SVT = Scalar.getValueType().getVectorElementType(); - if (SVT == Scalar.getOperand(0).getValueType()) - Scalar = Scalar.getOperand(0); - } - - // concat_vectors(scalar, undef) -> scalar_to_vector(scalar) - if (!Scalar.getValueType().isVector()) { // If the bitcast type isn't legal, it might be a trunc of a legal type; // look through the trunc so we can still do the transform: // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar) @@ -16780,25 +12878,19 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { TLI.isTypeLegal(Scalar->getOperand(0).getValueType())) Scalar = Scalar->getOperand(0); - EVT SclTy = Scalar.getValueType(); + EVT SclTy = Scalar->getValueType(0); if (!SclTy.isFloatingPoint() && !SclTy.isInteger()) return SDValue(); - // Bail out if the vector size is not a multiple of the scalar size. - if (VT.getSizeInBits() % SclTy.getSizeInBits()) - return SDValue(); - - unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits(); - if (VNTNumElms < 2) - return SDValue(); - - EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms); + EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, + VT.getSizeInBits() / SclTy.getSizeInBits()); if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType())) return SDValue(); - SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar); - return DAG.getBitcast(VT, Res); + SDLoc dl = SDLoc(N); + SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, NVT, Scalar); + return DAG.getNode(ISD::BITCAST, dl, VT, Res); } } @@ -16809,7 +12901,9 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { auto IsBuildVectorOrUndef = [](const SDValue &Op) { return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode(); }; - if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) { + bool AllBuildVectorsOrUndefs = + std::all_of(N->op_begin(), N->op_end(), IsBuildVectorOrUndef); + if (AllBuildVectorsOrUndefs) { SmallVector<SDValue, 8> Opnds; EVT SVT = VT.getScalarType(); @@ -16820,7 +12914,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { bool FoundMinVT = false; for (const SDValue &Op : N->ops()) if (ISD::BUILD_VECTOR == Op.getOpcode()) { - EVT OpSVT = Op.getOperand(0).getValueType(); + EVT OpSVT = Op.getOperand(0)->getValueType(0); MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT; FoundMinVT = true; } @@ -16848,7 +12942,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { assert(VT.getVectorNumElements() == Opnds.size() && "Concat vector type mismatch"); - return DAG.getBuildVector(VT, SDLoc(N), Opnds); + return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Opnds); } // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR. @@ -16870,7 +12964,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) { SDValue Op = N->getOperand(i); - if (Op.isUndef()) + if (Op.getOpcode() == ISD::UNDEF) continue; // Check if this is the identity extract: @@ -16908,177 +13002,19 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return SDValue(); } -/// If we are extracting a subvector produced by a wide binary operator try -/// to use a narrow binary operator and/or avoid concatenation and extraction. 
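-/// e.g. (v4i32 (extract_subvector (v8i32 (add X, Y)), 4)) can become
-/// (v4i32 (add (extract_subvector X, 4), (extract_subvector Y, 4))) when the
-/// narrow add is legal and the subvector extraction is cheap.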
-static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) { - // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share - // some of these bailouts with other transforms. - - // The extract index must be a constant, so we can map it to a concat operand. - auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1)); - if (!ExtractIndexC) - return SDValue(); - - // We are looking for an optionally bitcasted wide vector binary operator - // feeding an extract subvector. - SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0)); - if (!ISD::isBinaryOp(BinOp.getNode())) - return SDValue(); - - // The binop must be a vector type, so we can extract some fraction of it. - EVT WideBVT = BinOp.getValueType(); - if (!WideBVT.isVector()) - return SDValue(); - - EVT VT = Extract->getValueType(0); - unsigned ExtractIndex = ExtractIndexC->getZExtValue(); - assert(ExtractIndex % VT.getVectorNumElements() == 0 && - "Extract index is not a multiple of the vector length."); - - // Bail out if this is not a proper multiple width extraction. - unsigned WideWidth = WideBVT.getSizeInBits(); - unsigned NarrowWidth = VT.getSizeInBits(); - if (WideWidth % NarrowWidth != 0) - return SDValue(); - - // Bail out if we are extracting a fraction of a single operation. This can - // occur because we potentially looked through a bitcast of the binop. - unsigned NarrowingRatio = WideWidth / NarrowWidth; - unsigned WideNumElts = WideBVT.getVectorNumElements(); - if (WideNumElts % NarrowingRatio != 0) - return SDValue(); - - // Bail out if the target does not support a narrower version of the binop. - EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(), - WideNumElts / NarrowingRatio); - unsigned BOpcode = BinOp.getOpcode(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT)) - return SDValue(); - - // If extraction is cheap, we don't need to look at the binop operands - // for concat ops. The narrow binop alone makes this transform profitable. - // We can't just reuse the original extract index operand because we may have - // bitcasted. - unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements(); - unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements(); - EVT ExtBOIdxVT = Extract->getOperand(1).getValueType(); - if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) && - BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) { - // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N) - SDLoc DL(Extract); - SDValue NewExtIndex = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT); - SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, - BinOp.getOperand(0), NewExtIndex); - SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, - BinOp.getOperand(1), NewExtIndex); - SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, - BinOp.getNode()->getFlags()); - return DAG.getBitcast(VT, NarrowBinOp); - } - - // Only handle the case where we are doubling and then halving. A larger ratio - // may require more than two narrow binops to replace the wide binop. - if (NarrowingRatio != 2) - return SDValue(); - - // TODO: The motivating case for this transform is an x86 AVX1 target. That - // target has temptingly almost legal versions of bitwise logic ops in 256-bit - // flavors, but no other 256-bit integer support. 
This could be extended to - // handle any binop, but that may require fixing/adding other folds to avoid - // codegen regressions. - if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR) - return SDValue(); - - // We need at least one concatenation operation of a binop operand to make - // this transform worthwhile. The concat must double the input vector sizes. - // TODO: Should we also handle INSERT_SUBVECTOR patterns? - SDValue LHS = peekThroughBitcasts(BinOp.getOperand(0)); - SDValue RHS = peekThroughBitcasts(BinOp.getOperand(1)); - bool ConcatL = - LHS.getOpcode() == ISD::CONCAT_VECTORS && LHS.getNumOperands() == 2; - bool ConcatR = - RHS.getOpcode() == ISD::CONCAT_VECTORS && RHS.getNumOperands() == 2; - if (!ConcatL && !ConcatR) - return SDValue(); - - // If one of the binop operands was not the result of a concat, we must - // extract a half-sized operand for our new narrow binop. - SDLoc DL(Extract); - - // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN - // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, N) - // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, N), YN - SDValue X = ConcatL ? DAG.getBitcast(NarrowBVT, LHS.getOperand(ConcatOpNum)) - : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, - BinOp.getOperand(0), - DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT)); - - SDValue Y = ConcatR ? DAG.getBitcast(NarrowBVT, RHS.getOperand(ConcatOpNum)) - : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, - BinOp.getOperand(1), - DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT)); - - SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y); - return DAG.getBitcast(VT, NarrowBinOp); -} - -/// If we are extracting a subvector from a wide vector load, convert to a -/// narrow load to eliminate the extraction: -/// (extract_subvector (load wide vector)) --> (load narrow vector) -static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) { - // TODO: Add support for big-endian. The offset calculation must be adjusted. - if (DAG.getDataLayout().isBigEndian()) - return SDValue(); - - auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0)); - auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1)); - if (!Ld || Ld->getExtensionType() || Ld->isVolatile() || !ExtIdx) - return SDValue(); - - // Allow targets to opt-out. - EVT VT = Extract->getValueType(0); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT)) - return SDValue(); - - // The narrow load will be offset from the base address of the old load if - // we are extracting from something besides index 0 (little-endian). - SDLoc DL(Extract); - SDValue BaseAddr = Ld->getOperand(1); - unsigned Offset = ExtIdx->getZExtValue() * VT.getScalarType().getStoreSize(); - - // TODO: Use "BaseIndexOffset" to make this more effective. - SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL); - MachineFunction &MF = DAG.getMachineFunction(); - MachineMemOperand *MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset, - VT.getStoreSize()); - SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO); - DAG.makeEquivalentMemoryOrdering(Ld, NewLd); - return NewLd; -} - SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { EVT NVT = N->getValueType(0); SDValue V = N->getOperand(0); - // Extract from UNDEF is UNDEF. 
- if (V.isUndef()) - return DAG.getUNDEF(NVT); - - if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT)) - if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG)) - return NarrowLoad; - - // Combine: - // (extract_subvec (concat V1, V2, ...), i) - // Into: - // Vi if possible - // Only operand 0 is checked as 'concat' assumes all inputs of the same - // type. - if (V.getOpcode() == ISD::CONCAT_VECTORS && - isa<ConstantSDNode>(N->getOperand(1)) && - V.getOperand(0).getValueType() == NVT) { + if (V->getOpcode() == ISD::CONCAT_VECTORS) { + // Combine: + // (extract_subvec (concat V1, V2, ...), i) + // Into: + // Vi if possible + // Only operand 0 is checked as 'concat' assumes all inputs of the same + // type. + if (V->getOperand(0).getValueType() != NVT) + return SDValue(); unsigned Idx = N->getConstantOperandVal(1); unsigned NumElems = NVT.getVectorNumElements(); assert((Idx % NumElems) == 0 && @@ -17086,78 +13022,128 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { return V->getOperand(Idx / NumElems); } - V = peekThroughBitcasts(V); - - // If the input is a build vector. Try to make a smaller build vector. - if (V.getOpcode() == ISD::BUILD_VECTOR) { - if (auto *Idx = dyn_cast<ConstantSDNode>(N->getOperand(1))) { - EVT InVT = V.getValueType(); - unsigned ExtractSize = NVT.getSizeInBits(); - unsigned EltSize = InVT.getScalarSizeInBits(); - // Only do this if we won't split any elements. - if (ExtractSize % EltSize == 0) { - unsigned NumElems = ExtractSize / EltSize; - EVT EltVT = InVT.getVectorElementType(); - EVT ExtractVT = NumElems == 1 ? EltVT : - EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems); - if ((Level < AfterLegalizeDAG || - (NumElems == 1 || - TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) && - (!LegalTypes || TLI.isTypeLegal(ExtractVT))) { - unsigned IdxVal = (Idx->getZExtValue() * NVT.getScalarSizeInBits()) / - EltSize; - if (NumElems == 1) { - SDValue Src = V->getOperand(IdxVal); - if (EltVT != Src.getValueType()) - Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src); - - return DAG.getBitcast(NVT, Src); - } - - // Extract the pieces from the original build_vector. - SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N), - makeArrayRef(V->op_begin() + IdxVal, - NumElems)); - return DAG.getBitcast(NVT, BuildVec); - } - } - } - } + // Skip bitcasting + if (V->getOpcode() == ISD::BITCAST) + V = V.getOperand(0); - if (V.getOpcode() == ISD::INSERT_SUBVECTOR) { + if (V->getOpcode() == ISD::INSERT_SUBVECTOR) { + SDLoc dl(N); // Handle only simple case where vector being inserted and vector - // being extracted are of same size. - EVT SmallVT = V.getOperand(1).getValueType(); - if (!NVT.bitsEq(SmallVT)) + // being extracted are of same type, and are half size of larger vectors. + EVT BigVT = V->getOperand(0).getValueType(); + EVT SmallVT = V->getOperand(1).getValueType(); + if (!NVT.bitsEq(SmallVT) || NVT.getSizeInBits()*2 != BigVT.getSizeInBits()) return SDValue(); - // Only handle cases where both indexes are constants. - auto *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1)); - auto *InsIdx = dyn_cast<ConstantSDNode>(V.getOperand(2)); + // Only handle cases where both indexes are constants with the same type. 
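+    // Both index constants must fit in 64 bits for the getZExtValue() calls
+    // below to be safe.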
+ ConstantSDNode *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1)); + ConstantSDNode *InsIdx = dyn_cast<ConstantSDNode>(V->getOperand(2)); - if (InsIdx && ExtIdx) { + if (InsIdx && ExtIdx && + InsIdx->getValueType(0).getSizeInBits() <= 64 && + ExtIdx->getValueType(0).getSizeInBits() <= 64) { // Combine: // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx) // Into: // indices are equal or bit offsets are equal => V1 // otherwise => (extract_subvec V1, ExtIdx) - if (InsIdx->getZExtValue() * SmallVT.getScalarSizeInBits() == - ExtIdx->getZExtValue() * NVT.getScalarSizeInBits()) - return DAG.getBitcast(NVT, V.getOperand(1)); - return DAG.getNode( - ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, - DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)), - N->getOperand(1)); + if (InsIdx->getZExtValue() * SmallVT.getScalarType().getSizeInBits() == + ExtIdx->getZExtValue() * NVT.getScalarType().getSizeInBits()) + return DAG.getNode(ISD::BITCAST, dl, NVT, V->getOperand(1)); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, + DAG.getNode(ISD::BITCAST, dl, + N->getOperand(0).getValueType(), + V->getOperand(0)), N->getOperand(1)); } } - if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG)) - return NarrowBOp; + return SDValue(); +} - if (SimplifyDemandedVectorElts(SDValue(N, 0))) - return SDValue(N, 0); +static SDValue simplifyShuffleOperandRecursively(SmallBitVector &UsedElements, + SDValue V, SelectionDAG &DAG) { + SDLoc DL(V); + EVT VT = V.getValueType(); - return SDValue(); + switch (V.getOpcode()) { + default: + return V; + + case ISD::CONCAT_VECTORS: { + EVT OpVT = V->getOperand(0).getValueType(); + int OpSize = OpVT.getVectorNumElements(); + SmallBitVector OpUsedElements(OpSize, false); + bool FoundSimplification = false; + SmallVector<SDValue, 4> NewOps; + NewOps.reserve(V->getNumOperands()); + for (int i = 0, NumOps = V->getNumOperands(); i < NumOps; ++i) { + SDValue Op = V->getOperand(i); + bool OpUsed = false; + for (int j = 0; j < OpSize; ++j) + if (UsedElements[i * OpSize + j]) { + OpUsedElements[j] = true; + OpUsed = true; + } + NewOps.push_back( + OpUsed ? simplifyShuffleOperandRecursively(OpUsedElements, Op, DAG) + : DAG.getUNDEF(OpVT)); + FoundSimplification |= Op == NewOps.back(); + OpUsedElements.reset(); + } + if (FoundSimplification) + V = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, NewOps); + return V; + } + + case ISD::INSERT_SUBVECTOR: { + SDValue BaseV = V->getOperand(0); + SDValue SubV = V->getOperand(1); + auto *IdxN = dyn_cast<ConstantSDNode>(V->getOperand(2)); + if (!IdxN) + return V; + + int SubSize = SubV.getValueType().getVectorNumElements(); + int Idx = IdxN->getZExtValue(); + bool SubVectorUsed = false; + SmallBitVector SubUsedElements(SubSize, false); + for (int i = 0; i < SubSize; ++i) + if (UsedElements[i + Idx]) { + SubVectorUsed = true; + SubUsedElements[i] = true; + UsedElements[i + Idx] = false; + } + + // Now recurse on both the base and sub vectors. + SDValue SimplifiedSubV = + SubVectorUsed + ? 
simplifyShuffleOperandRecursively(SubUsedElements, SubV, DAG) + : DAG.getUNDEF(SubV.getValueType()); + SDValue SimplifiedBaseV = simplifyShuffleOperandRecursively(UsedElements, BaseV, DAG); + if (SimplifiedSubV != SubV || SimplifiedBaseV != BaseV) + V = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, + SimplifiedBaseV, SimplifiedSubV, V->getOperand(2)); + return V; + } + } +} + +static SDValue simplifyShuffleOperands(ShuffleVectorSDNode *SVN, SDValue N0, + SDValue N1, SelectionDAG &DAG) { + EVT VT = SVN->getValueType(0); + int NumElts = VT.getVectorNumElements(); + SmallBitVector N0UsedElements(NumElts, false), N1UsedElements(NumElts, false); + for (int M : SVN->getMask()) + if (M >= 0 && M < NumElts) + N0UsedElements[M] = true; + else if (M >= NumElts) + N1UsedElements[M - NumElts] = true; + + SDValue S0 = simplifyShuffleOperandRecursively(N0UsedElements, N0, DAG); + SDValue S1 = simplifyShuffleOperandRecursively(N1UsedElements, N1, DAG); + if (S0 == N0 && S1 == N1) + return SDValue(); + + return DAG.getVectorShuffle(VT, SDLoc(SVN), S0, S1, SVN->getMask()); } // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat, @@ -17178,7 +13164,7 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) { // Special case: shuffle(concat(A,B)) can be more efficiently represented // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high // half vector elements. - if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() && + if (NumElemsPerConcat * 2 == NumElts && N1.getOpcode() == ISD::UNDEF && std::all_of(SVN->getMask().begin() + NumElemsPerConcat, SVN->getMask().end(), [](int i) { return i == -1; })) { N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0), N0.getOperand(1), @@ -17224,343 +13210,6 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops); } -// Attempt to combine a shuffle of 2 inputs of 'scalar sources' - -// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR. -// -// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always -// a simplification in some sense, but it isn't appropriate in general: some -// BUILD_VECTORs are substantially cheaper than others. The general case -// of a BUILD_VECTOR requires inserting each element individually (or -// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of -// all constants is a single constant pool load. A BUILD_VECTOR where each -// element is identical is a splat. A BUILD_VECTOR where most of the operands -// are undef lowers to a small number of element insertions. -// -// To deal with this, we currently use a bunch of mostly arbitrary heuristics. -// We don't fold shuffles where one side is a non-zero constant, and we don't -// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate -// non-constant operands. This seems to work out reasonably well in practice. -static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN, - SelectionDAG &DAG, - const TargetLowering &TLI) { - EVT VT = SVN->getValueType(0); - unsigned NumElts = VT.getVectorNumElements(); - SDValue N0 = SVN->getOperand(0); - SDValue N1 = SVN->getOperand(1); - - if (!N0->hasOneUse()) - return SDValue(); - - // If only one of N1,N2 is constant, bail out if it is not ALL_ZEROS as - // discussed above. 
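-  // A BUILD_VECTOR of all constants is a single constant-pool load; folding a
-  // non-zero constant operand into a variable shuffle would trade that cheap
-  // load for per-element insertions.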
- if (!N1.isUndef()) { - if (!N1->hasOneUse()) - return SDValue(); - - bool N0AnyConst = isAnyConstantBuildVector(N0); - bool N1AnyConst = isAnyConstantBuildVector(N1); - if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode())) - return SDValue(); - if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode())) - return SDValue(); - } - - // If both inputs are splats of the same value then we can safely merge this - // to a single BUILD_VECTOR with undef elements based on the shuffle mask. - bool IsSplat = false; - auto *BV0 = dyn_cast<BuildVectorSDNode>(N0); - auto *BV1 = dyn_cast<BuildVectorSDNode>(N1); - if (BV0 && BV1) - if (SDValue Splat0 = BV0->getSplatValue()) - IsSplat = (Splat0 == BV1->getSplatValue()); - - SmallVector<SDValue, 8> Ops; - SmallSet<SDValue, 16> DuplicateOps; - for (int M : SVN->getMask()) { - SDValue Op = DAG.getUNDEF(VT.getScalarType()); - if (M >= 0) { - int Idx = M < (int)NumElts ? M : M - NumElts; - SDValue &S = (M < (int)NumElts ? N0 : N1); - if (S.getOpcode() == ISD::BUILD_VECTOR) { - Op = S.getOperand(Idx); - } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) { - assert(Idx == 0 && "Unexpected SCALAR_TO_VECTOR operand index."); - Op = S.getOperand(0); - } else { - // Operand can't be combined - bail out. - return SDValue(); - } - } - - // Don't duplicate a non-constant BUILD_VECTOR operand unless we're - // generating a splat; semantically, this is fine, but it's likely to - // generate low-quality code if the target can't reconstruct an appropriate - // shuffle. - if (!Op.isUndef() && !isa<ConstantSDNode>(Op) && !isa<ConstantFPSDNode>(Op)) - if (!IsSplat && !DuplicateOps.insert(Op).second) - return SDValue(); - - Ops.push_back(Op); - } - - // BUILD_VECTOR requires all inputs to be of the same type, find the - // maximum type and extend them all. - EVT SVT = VT.getScalarType(); - if (SVT.isInteger()) - for (SDValue &Op : Ops) - SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT); - if (SVT != VT.getScalarType()) - for (SDValue &Op : Ops) - Op = TLI.isZExtFree(Op.getValueType(), SVT) - ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT) - : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT); - return DAG.getBuildVector(VT, SDLoc(SVN), Ops); -} - -// Match shuffles that can be converted to any_vector_extend_in_reg. -// This is often generated during legalization. -// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)) -// TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case. -static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN, - SelectionDAG &DAG, - const TargetLowering &TLI, - bool LegalOperations) { - EVT VT = SVN->getValueType(0); - bool IsBigEndian = DAG.getDataLayout().isBigEndian(); - - // TODO Add support for big-endian when we have a test case. - if (!VT.isInteger() || IsBigEndian) - return SDValue(); - - unsigned NumElts = VT.getVectorNumElements(); - unsigned EltSizeInBits = VT.getScalarSizeInBits(); - ArrayRef<int> Mask = SVN->getMask(); - SDValue N0 = SVN->getOperand(0); - - // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32)) - auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) { - for (unsigned i = 0; i != NumElts; ++i) { - if (Mask[i] < 0) - continue; - if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale)) - continue; - return false; - } - return true; - }; - - // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for - // power-of-2 extensions as they are the most likely. 
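-  // e.g. with Scale == 4, v8i16 <0,u,u,u,1,u,u,u> matches as
-  // (v2i64 anyextend_vector_inreg(v8i16 src)).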
- for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) { - // Check for non power of 2 vector sizes - if (NumElts % Scale != 0) - continue; - if (!isAnyExtend(Scale)) - continue; - - EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale); - EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale); - // Never create an illegal type. Only create unsupported operations if we - // are pre-legalization. - if (TLI.isTypeLegal(OutVT)) - if (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT)) - return DAG.getBitcast(VT, - DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG, - SDLoc(SVN), OutVT, N0)); - } - - return SDValue(); -} - -// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of -// each source element of a large type into the lowest elements of a smaller -// destination type. This is often generated during legalization. -// If the source node itself was a '*_extend_vector_inreg' node then we should -// then be able to remove it. -static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN, - SelectionDAG &DAG) { - EVT VT = SVN->getValueType(0); - bool IsBigEndian = DAG.getDataLayout().isBigEndian(); - - // TODO Add support for big-endian when we have a test case. - if (!VT.isInteger() || IsBigEndian) - return SDValue(); - - SDValue N0 = peekThroughBitcasts(SVN->getOperand(0)); - - unsigned Opcode = N0.getOpcode(); - if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG && - Opcode != ISD::SIGN_EXTEND_VECTOR_INREG && - Opcode != ISD::ZERO_EXTEND_VECTOR_INREG) - return SDValue(); - - SDValue N00 = N0.getOperand(0); - ArrayRef<int> Mask = SVN->getMask(); - unsigned NumElts = VT.getVectorNumElements(); - unsigned EltSizeInBits = VT.getScalarSizeInBits(); - unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits(); - unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits(); - - if (ExtDstSizeInBits % ExtSrcSizeInBits != 0) - return SDValue(); - unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits; - - // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2-1,-1> - // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1> - // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1> - auto isTruncate = [&Mask, &NumElts](unsigned Scale) { - for (unsigned i = 0; i != NumElts; ++i) { - if (Mask[i] < 0) - continue; - if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale)) - continue; - return false; - } - return true; - }; - - // At the moment we just handle the case where we've truncated back to the - // same size as before the extension. - // TODO: handle more extension/truncation cases as cases arise. - if (EltSizeInBits != ExtSrcSizeInBits) - return SDValue(); - - // We can remove *extend_vector_inreg only if the truncation happens at - // the same scale as the extension. - if (isTruncate(ExtScale)) - return DAG.getBitcast(VT, N00); - - return SDValue(); -} - -// Combine shuffles of splat-shuffles of the form: -// shuffle (shuffle V, undef, splat-mask), undef, M -// If splat-mask contains undef elements, we need to be careful about -// introducing undef's in the folded mask which are not the result of composing -// the masks of the shuffles. -static SDValue combineShuffleOfSplat(ArrayRef<int> UserMask, - ShuffleVectorSDNode *Splat, - SelectionDAG &DAG) { - ArrayRef<int> SplatMask = Splat->getMask(); - assert(UserMask.size() == SplatMask.size() && "Mask length mismatch"); - - // Prefer simplifying to the splat-shuffle, if possible. 
This is legal if - // every undef mask element in the splat-shuffle has a corresponding undef - // element in the user-shuffle's mask or if the composition of mask elements - // would result in undef. - // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask): - // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u] - // In this case it is not legal to simplify to the splat-shuffle because we - // may be exposing the users of the shuffle an undef element at index 1 - // which was not there before the combine. - // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u] - // In this case the composition of masks yields SplatMask, so it's ok to - // simplify to the splat-shuffle. - // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u] - // In this case the composed mask includes all undef elements of SplatMask - // and in addition sets element zero to undef. It is safe to simplify to - // the splat-shuffle. - auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask, - ArrayRef<int> SplatMask) { - for (unsigned i = 0, e = UserMask.size(); i != e; ++i) - if (UserMask[i] != -1 && SplatMask[i] == -1 && - SplatMask[UserMask[i]] != -1) - return false; - return true; - }; - if (CanSimplifyToExistingSplat(UserMask, SplatMask)) - return SDValue(Splat, 0); - - // Create a new shuffle with a mask that is composed of the two shuffles' - // masks. - SmallVector<int, 32> NewMask; - for (int Idx : UserMask) - NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]); - - return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat), - Splat->getOperand(0), Splat->getOperand(1), - NewMask); -} - -/// If the shuffle mask is taking exactly one element from the first vector -/// operand and passing through all other elements from the second vector -/// operand, return the index of the mask element that is choosing an element -/// from the first operand. Otherwise, return -1. -static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) { - int MaskSize = Mask.size(); - int EltFromOp0 = -1; - // TODO: This does not match if there are undef elements in the shuffle mask. - // Should we ignore undefs in the shuffle mask instead? The trade-off is - // removing an instruction (a shuffle), but losing the knowledge that some - // vector lanes are not needed. - for (int i = 0; i != MaskSize; ++i) { - if (Mask[i] >= 0 && Mask[i] < MaskSize) { - // We're looking for a shuffle of exactly one element from operand 0. - if (EltFromOp0 != -1) - return -1; - EltFromOp0 = i; - } else if (Mask[i] != i + MaskSize) { - // Nothing from operand 1 can change lanes. - return -1; - } - } - return EltFromOp0; -} - -/// If a shuffle inserts exactly one element from a source vector operand into -/// another vector operand and we can access the specified element as a scalar, -/// then we can eliminate the shuffle. -static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf, - SelectionDAG &DAG) { - // First, check if we are taking one element of a vector and shuffling that - // element into another vector. - ArrayRef<int> Mask = Shuf->getMask(); - SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end()); - SDValue Op0 = Shuf->getOperand(0); - SDValue Op1 = Shuf->getOperand(1); - int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask); - if (ShufOp0Index == -1) { - // Commute mask and check again. 
- ShuffleVectorSDNode::commuteMask(CommutedMask); - ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask); - if (ShufOp0Index == -1) - return SDValue(); - // Commute operands to match the commuted shuffle mask. - std::swap(Op0, Op1); - Mask = CommutedMask; - } - - // The shuffle inserts exactly one element from operand 0 into operand 1. - // Now see if we can access that element as a scalar via a real insert element - // instruction. - // TODO: We can try harder to locate the element as a scalar. Examples: it - // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant. - assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() && - "Shuffle mask value must be from operand 0"); - if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT) - return SDValue(); - - auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2)); - if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index]) - return SDValue(); - - // There's an existing insertelement with constant insertion index, so we - // don't need to check the legality/profitability of a replacement operation - // that differs at most in the constant value. The target should be able to - // lower any of those in a similar way. If not, legalization will expand this - // to a scalar-to-vector plus shuffle. - // - // Note that the shuffle may move the scalar from the position that the insert - // element used. Therefore, our new insert element occurs at the shuffle's - // mask index value, not the insert's index value. - // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C' - SDValue NewInsIndex = DAG.getConstant(ShufOp0Index, SDLoc(Shuf), - Op0.getOperand(2).getValueType()); - return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(), - Op1, Op0.getOperand(1), NewInsIndex); -} - SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { EVT VT = N->getValueType(0); unsigned NumElts = VT.getVectorNumElements(); @@ -17571,7 +13220,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG"); // Canonicalize shuffle undef, undef -> undef - if (N0.isUndef() && N1.isUndef()) + if (N0.getOpcode() == ISD::UNDEF && N1.getOpcode() == ISD::UNDEF) return DAG.getUNDEF(VT); ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N); @@ -17584,15 +13233,29 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { if (Idx >= (int)NumElts) Idx -= NumElts; NewMask.push_back(Idx); } - return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT), NewMask); + return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT), + &NewMask[0]); } // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask. 
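// [editor's sketch] What the expanded mask rewrite below computes, as a
// standalone helper (illustrative, not LLVM API): with an undef LHS, any mask
// entry pointing at operand 1 is rebased onto operand 0, and any entry
// pointing at the undef operand becomes "don't care".
#include <cassert>
#include <vector>

static std::vector<int> dropUndefLHS(const std::vector<int> &Mask, int NumElts) {
  std::vector<int> NewMask;
  for (int Idx : Mask)
    NewMask.push_back(Idx >= NumElts ? Idx - NumElts : -1);
  return NewMask;
}

int main() {
  // shuffle(undef, V) with mask <4,1,6,3> becomes shuffle(V, undef) <0,u,2,u>.
  std::vector<int> M = dropUndefLHS({4, 1, 6, 3}, 4);
  assert(M == (std::vector<int>{0, -1, 2, -1}));
}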
- if (N0.isUndef()) - return DAG.getCommutedVectorShuffle(*SVN); + if (N0.getOpcode() == ISD::UNDEF) { + SmallVector<int, 8> NewMask; + for (unsigned i = 0; i != NumElts; ++i) { + int Idx = SVN->getMaskElt(i); + if (Idx >= 0) { + if (Idx >= (int)NumElts) + Idx -= NumElts; + else + Idx = -1; // remove reference to lhs + } + NewMask.push_back(Idx); + } + return DAG.getVectorShuffle(VT, SDLoc(N), N1, DAG.getUNDEF(VT), + &NewMask[0]); + } // Remove references to rhs if it is undef - if (N1.isUndef()) { + if (N1.getOpcode() == ISD::UNDEF) { bool Changed = false; SmallVector<int, 8> NewMask; for (unsigned i = 0; i != NumElts; ++i) { @@ -17604,17 +13267,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { NewMask.push_back(Idx); } if (Changed) - return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask); + return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, &NewMask[0]); } - if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG)) - return InsElt; - - // A shuffle of a single vector that is a splat can always be folded. - if (auto *N0Shuf = dyn_cast<ShuffleVectorSDNode>(N0)) - if (N1->isUndef() && N0Shuf->isSplat()) - return combineShuffleOfSplat(SVN->getMask(), N0Shuf, DAG); - // If it is a splat, check if the argument vector is another splat or a // build_vector. if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) { @@ -17636,7 +13291,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { SDValue Base; bool AllSame = true; for (unsigned i = 0; i != NumElts; ++i) { - if (!V->getOperand(i).isUndef()) { + if (V->getOperand(i).getOpcode() != ISD::UNDEF) { Base = V->getOperand(i); break; } @@ -17657,49 +13312,86 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { // Canonicalize any other splat as a build_vector. const SDValue &Splatted = V->getOperand(SVN->getSplatIndex()); SmallVector<SDValue, 8> Ops(NumElts, Splatted); - SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops); + SDValue NewBV = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), + V->getValueType(0), Ops); // We may have jumped through bitcasts, so the type of the // BUILD_VECTOR may not match the type of the shuffle. if (V->getValueType(0) != VT) - NewBV = DAG.getBitcast(VT, NewBV); + NewBV = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, NewBV); return NewBV; } } - // Simplify source operands based on shuffle mask. - if (SimplifyDemandedVectorElts(SDValue(N, 0))) - return SDValue(N, 0); - - // Match shuffles that can be converted to any_vector_extend_in_reg. - if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations)) - return V; - - // Combine "truncate_vector_in_reg" style shuffles. - if (SDValue V = combineTruncationShuffle(SVN, DAG)) - return V; + // There are various patterns used to build up a vector from smaller vectors, + // subvectors, or elements. Scan chains of these and replace unused insertions + // or components with undef. + if (SDValue S = simplifyShuffleOperands(SVN, N0, N1, DAG)) + return S; if (N0.getOpcode() == ISD::CONCAT_VECTORS && Level < AfterLegalizeVectorOps && - (N1.isUndef() || + (N1.getOpcode() == ISD::UNDEF || (N1.getOpcode() == ISD::CONCAT_VECTORS && N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) { - if (SDValue V = partitionShuffleOfConcats(N, DAG)) + SDValue V = partitionShuffleOfConcats(N, DAG); + + if (V.getNode()) return V; } // Attempt to combine a shuffle of 2 inputs of 'scalar sources' - // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR. 
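// [editor's sketch] The scalar-source fold named above, in miniature (plain
// STL stand-ins for BUILD_VECTOR operands; illustrative only): the shuffle
// mask simply selects which original scalar lands in each output lane, so
// shuffle(BV(a,b,c,d), BV(e,f,g,h), <0,4,1,5>) == BV(a,e,b,f).
#include <cassert>
#include <string>
#include <vector>

static std::vector<std::string> foldShuffleOfScalars(
    const std::vector<std::string> &V0, const std::vector<std::string> &V1,
    const std::vector<int> &Mask) {
  int NumElts = (int)V0.size();
  std::vector<std::string> Out;
  for (int M : Mask)
    Out.push_back(M < 0 ? "undef" : (M < NumElts ? V0[M] : V1[M - NumElts]));
  return Out;
}

int main() {
  auto R = foldShuffleOfScalars({"a", "b", "c", "d"}, {"e", "f", "g", "h"},
                                {0, 4, 1, 5});
  assert(R == (std::vector<std::string>{"a", "e", "b", "f"}));
}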
- if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) - if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI)) - return Res; + if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) { + SmallVector<SDValue, 8> Ops; + for (int M : SVN->getMask()) { + SDValue Op = DAG.getUNDEF(VT.getScalarType()); + if (M >= 0) { + int Idx = M % NumElts; + SDValue &S = (M < (int)NumElts ? N0 : N1); + if (S.getOpcode() == ISD::BUILD_VECTOR && S.hasOneUse()) { + Op = S.getOperand(Idx); + } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR && S.hasOneUse()) { + if (Idx == 0) + Op = S.getOperand(0); + } else { + // Operand can't be combined - bail out. + break; + } + } + Ops.push_back(Op); + } + if (Ops.size() == VT.getVectorNumElements()) { + // BUILD_VECTOR requires all inputs to be of the same type, find the + // maximum type and extend them all. + EVT SVT = VT.getScalarType(); + if (SVT.isInteger()) + for (SDValue &Op : Ops) + SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT); + if (SVT != VT.getScalarType()) + for (SDValue &Op : Ops) + Op = TLI.isZExtFree(Op.getValueType(), SVT) + ? DAG.getZExtOrTrunc(Op, SDLoc(N), SVT) + : DAG.getSExtOrTrunc(Op, SDLoc(N), SVT); + return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Ops); + } + } // If this shuffle only has a single input that is a bitcasted shuffle, // attempt to merge the 2 shuffles and suitably bitcast the inputs/output // back to their original types. if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() && - N1.isUndef() && Level < AfterLegalizeVectorOps && + N1.getOpcode() == ISD::UNDEF && Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) { + + // Peek through the bitcast only if there is one user. + SDValue BC0 = N0; + while (BC0.getOpcode() == ISD::BITCAST) { + if (!BC0.hasOneUse()) + break; + BC0 = BC0.getOperand(0); + } + auto ScaleShuffleMask = [](ArrayRef<int> Mask, int Scale) { if (Scale == 1) return SmallVector<int, 8>(Mask.begin(), Mask.end()); @@ -17710,8 +13402,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { NewMask.push_back(M < 0 ? -1 : Scale * M + s); return NewMask; }; - - SDValue BC0 = peekThroughOneUseBitcasts(N0); + if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) { EVT SVT = VT.getScalarType(); EVT InnerVT = BC0->getValueType(0); @@ -17724,6 +13415,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { if (TLI.isTypeLegal(ScaleVT) && 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) && 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) { + int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits(); int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits(); @@ -17750,10 +13442,11 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { } if (LegalMask) { - SV0 = DAG.getBitcast(ScaleVT, SV0); - SV1 = DAG.getBitcast(ScaleVT, SV1); - return DAG.getBitcast( - VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask)); + SV0 = DAG.getNode(ISD::BITCAST, SDLoc(N), ScaleVT, SV0); + SV1 = DAG.getNode(ISD::BITCAST, SDLoc(N), ScaleVT, SV1); + return DAG.getNode( + ISD::BITCAST, SDLoc(N), VT, + DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask)); } } } @@ -17774,7 +13467,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { SDValue SV0 = N1->getOperand(0); SDValue SV1 = N1->getOperand(1); bool HasSameOp0 = N0 == SV0; - bool IsSV1Undef = SV1.isUndef(); + bool IsSV1Undef = SV1.getOpcode() == ISD::UNDEF; if (HasSameOp0 || IsSV1Undef || N0 == SV1) // Commute the operands of this shuffle so that next rule // will trigger. 
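// [editor's sketch] The mask composition behind the
// shuffle(shuffle(A,B,M0), C, M1) folds below, standalone and illustrative:
// an outer mask entry that selects from the inner shuffle is replaced by the
// inner mask entry it ultimately reads. (The real combine also tracks which
// of A/B each composed entry references; this sketch covers the simple case
// where the inner entries stay within A and the rest read C directly.)
#include <cassert>
#include <vector>

static std::vector<int> composeMasks(const std::vector<int> &InnerM0,
                                     const std::vector<int> &OuterM1) {
  int NumElts = (int)InnerM0.size();
  std::vector<int> Composed;
  for (int M : OuterM1) {
    if (M < 0)
      Composed.push_back(-1);         // undef stays undef
    else if (M < NumElts)
      Composed.push_back(InnerM0[M]); // read through the inner shuffle
    else
      Composed.push_back(M);          // direct read of operand C
  }
  return Composed;
}

int main() {
  // M0=<2,3,0,1> then M1=<0,1,4,5> composes to <2,3,4,5>: A[2],A[3],C[0],C[1].
  auto M2 = composeMasks({2, 3, 0, 1}, {0, 1, 4, 5});
  assert(M2 == (std::vector<int>{2, 3, 4, 5}));
}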
@@ -17791,11 +13484,6 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) { ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0); - // Don't try to fold splats; they're likely to simplify somehow, or they - // might be free. - if (OtherSV->isSplat()) - return SDValue(); - // The incoming shuffle must be of the same type as the result of the // current shuffle. assert(OtherSV->getOperand(0).getValueType() == VT && @@ -17832,7 +13520,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { } // Simple case where 'CurrentVec' is UNDEF. - if (CurrentVec.isUndef()) { + if (CurrentVec.getOpcode() == ISD::UNDEF) { Mask.push_back(-1); continue; } @@ -17887,7 +13575,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2) // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2) // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2) - return DAG.getVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask); + return DAG.getVectorShuffle(VT, SDLoc(N), SV0, SV1, &Mask[0]); } return SDValue(); @@ -17898,46 +13586,23 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) { EVT VT = N->getValueType(0); // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern - // with a VECTOR_SHUFFLE and possible truncate. + // with a VECTOR_SHUFFLE. if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT) { SDValue InVec = InVal->getOperand(0); SDValue EltNo = InVal->getOperand(1); - auto InVecT = InVec.getValueType(); - if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) { - SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1); + + // FIXME: We could support implicit truncation if the shuffle can be + // scaled to a smaller vector scalar type. + ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo); + if (C0 && VT == InVec.getValueType() && + VT.getScalarType() == InVal.getValueType()) { + SmallVector<int, 8> NewMask(VT.getVectorNumElements(), -1); int Elt = C0->getZExtValue(); NewMask[0] = Elt; - SDValue Val; - // If we have an implict truncate do truncate here as long as it's legal. - // if it's not legal, this should - if (VT.getScalarType() != InVal.getValueType() && - InVal.getValueType().isScalarInteger() && - isTypeLegal(VT.getScalarType())) { - Val = - DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal); - return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val); - } - if (VT.getScalarType() == InVecT.getScalarType() && - VT.getVectorNumElements() <= InVecT.getVectorNumElements() && - TLI.isShuffleMaskLegal(NewMask, VT)) { - Val = DAG.getVectorShuffle(InVecT, SDLoc(N), InVec, - DAG.getUNDEF(InVecT), NewMask); - // If the initial vector is the correct size this shuffle is a - // valid result. - if (VT == InVecT) - return Val; - // If not we must truncate the vector. 
- if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) { - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy); - EVT SubVT = - EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(), - VT.getVectorNumElements()); - Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, Val, - ZeroIdx); - return Val; - } - } + + if (TLI.isShuffleMaskLegal(NewMask, VT)) + return DAG.getVectorShuffle(VT, SDLoc(N), InVec, DAG.getUNDEF(VT), + NewMask); } } @@ -17945,109 +13610,29 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) { } SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { - EVT VT = N->getValueType(0); SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); SDValue N2 = N->getOperand(2); - // If inserting an UNDEF, just return the original vector. - if (N1.isUndef()) - return N0; - - // If this is an insert of an extracted vector into an undef vector, we can - // just use the input to the extract. - if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && - N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT) - return N1.getOperand(0); - - // If we are inserting a bitcast value into an undef, with the same - // number of elements, just use the bitcast input of the extract. - // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 -> - // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2) - if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST && - N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR && - N1.getOperand(0).getOperand(1) == N2 && - N1.getOperand(0).getOperand(0).getValueType().getVectorNumElements() == - VT.getVectorNumElements() && - N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() == - VT.getSizeInBits()) { - return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0)); - } - - // If both N1 and N2 are bitcast values on which insert_subvector - // would makes sense, pull the bitcast through. - // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 -> - // BITCAST (INSERT_SUBVECTOR N0 N1 N2) - if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) { - SDValue CN0 = N0.getOperand(0); - SDValue CN1 = N1.getOperand(0); - EVT CN0VT = CN0.getValueType(); - EVT CN1VT = CN1.getValueType(); - if (CN0VT.isVector() && CN1VT.isVector() && - CN0VT.getVectorElementType() == CN1VT.getVectorElementType() && - CN0VT.getVectorNumElements() == VT.getVectorNumElements()) { - SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), - CN0.getValueType(), CN0, CN1, N2); - return DAG.getBitcast(VT, NewINSERT); - } - } - - // Combine INSERT_SUBVECTORs where we are inserting to the same index. - // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx ) - // --> INSERT_SUBVECTOR( Vec, SubNew, Idx ) - if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && - N0.getOperand(1).getValueType() == N1.getValueType() && - N0.getOperand(2) == N2) - return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0), - N1, N2); - - // Eliminate an intermediate insert into an undef vector: - // insert_subvector undef, (insert_subvector undef, X, 0), N2 --> - // insert_subvector undef, X, N2 - if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR && - N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2))) - return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0, - N1.getOperand(1), N2); - - if (!isa<ConstantSDNode>(N2)) - return SDValue(); - - unsigned InsIdx = cast<ConstantSDNode>(N2)->getZExtValue(); - - // Canonicalize insert_subvector dag nodes. 
- // Example: - // (insert_subvector (insert_subvector A, Idx0), Idx1) - // -> (insert_subvector (insert_subvector A, Idx1), Idx0) - if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() && - N1.getValueType() == N0.getOperand(1).getValueType() && - isa<ConstantSDNode>(N0.getOperand(2))) { - unsigned OtherIdx = N0.getConstantOperandVal(2); - if (InsIdx < OtherIdx) { - // Swap nodes. - SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, - N0.getOperand(0), N1, N2); - AddToWorklist(NewOp.getNode()); - return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()), - VT, NewOp, N0.getOperand(1), N0.getOperand(2)); - } - } - // If the input vector is a concatenation, and the insert replaces - // one of the pieces, we can optimize into a single concat_vectors. - if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() && - N0.getOperand(0).getValueType() == N1.getValueType()) { - unsigned Factor = N1.getValueType().getVectorNumElements(); + // one of the halves, we can optimize into a single concat_vectors. + if (N0.getOpcode() == ISD::CONCAT_VECTORS && + N0->getNumOperands() == 2 && N2.getOpcode() == ISD::Constant) { + APInt InsIdx = cast<ConstantSDNode>(N2)->getAPIntValue(); + EVT VT = N->getValueType(0); - SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end()); - Ops[cast<ConstantSDNode>(N2)->getZExtValue() / Factor] = N1; + // Lower half: fold (insert_subvector (concat_vectors X, Y), Z) -> + // (concat_vectors Z, Y) + if (InsIdx == 0) + return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, + N->getOperand(1), N0.getOperand(1)); - return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops); + // Upper half: fold (insert_subvector (concat_vectors X, Y), Z) -> + // (concat_vectors X, Z) + if (InsIdx == VT.getVectorNumElements()/2) + return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, + N0.getOperand(0), N->getOperand(1)); } - // Simplify source operands based on insertion. - if (SimplifyDemandedVectorElts(SDValue(N, 0))) - return SDValue(N, 0); - return SDValue(); } @@ -18081,18 +13666,22 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) { /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==> /// vector_shuffle V, Zero, <0, 4, 2, 4> SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { - assert(N->getOpcode() == ISD::AND && "Unexpected opcode!"); - EVT VT = N->getValueType(0); SDValue LHS = N->getOperand(0); - SDValue RHS = peekThroughBitcasts(N->getOperand(1)); - SDLoc DL(N); + SDValue RHS = N->getOperand(1); + SDLoc dl(N); // Make sure we're not running after operation legalization where it // may have custom lowered the vector shuffles. if (LegalOperations) return SDValue(); + if (N->getOpcode() != ISD::AND) + return SDValue(); + + if (RHS.getOpcode() == ISD::BITCAST) + RHS = RHS.getOperand(0); + if (RHS.getOpcode() != ISD::BUILD_VECTOR) return SDValue(); @@ -18111,7 +13700,7 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { int EltIdx = i / Split; int SubIdx = i % Split; SDValue Elt = RHS.getOperand(EltIdx); - if (Elt.isUndef()) { + if (Elt.getOpcode() == ISD::UNDEF) { Indices.push_back(-1); continue; } @@ -18126,9 +13715,9 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { // Extract the sub element from the constant bit mask. 
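// [editor's sketch] The sub-element probe implemented just below, modeled on
// uint64_t (illustrative, little-endian case only): split one wide AND-mask
// element into Split narrow pieces and classify each piece as all-ones (keep
// the lane) or all-zero (clear the lane); a mixed piece blocks the transform
// at this split level.
#include <cassert>
#include <cstdint>

// Returns +1 for an all-ones piece, 0 for all-zero, -1 for a mixed piece.
static int classifyPiece(uint64_t EltBits, unsigned SubIdx,
                         unsigned NumSubBits) {
  uint64_t AllOnes = (NumSubBits == 64) ? ~0ull : ((1ull << NumSubBits) - 1);
  uint64_t Piece = (EltBits >> (SubIdx * NumSubBits)) & AllOnes;
  if (Piece == AllOnes) return 1;
  if (Piece == 0) return 0;
  return -1;
}

int main() {
  // 0x00000000FFFFFFFF split into two 32-bit pieces: low keeps, high clears.
  assert(classifyPiece(0x00000000FFFFFFFFull, 0, 32) == 1);
  assert(classifyPiece(0x00000000FFFFFFFFull, 1, 32) == 0);
}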
if (DAG.getDataLayout().isBigEndian()) { - Bits.lshrInPlace((Split - SubIdx - 1) * NumSubBits); + Bits = Bits.lshr((Split - SubIdx - 1) * NumSubBits); } else { - Bits.lshrInPlace(SubIdx * NumSubBits); + Bits = Bits.lshr(SubIdx * NumSubBits); } if (Split > 1) @@ -18148,10 +13737,10 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { if (!TLI.isVectorClearMaskLegal(Indices, ClearVT)) return SDValue(); - SDValue Zero = DAG.getConstant(0, DL, ClearVT); - return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL, + SDValue Zero = DAG.getConstant(0, dl, ClearVT); + return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, dl, DAG.getBitcast(ClearVT, LHS), - Zero, Indices)); + Zero, &Indices[0])); }; // Determine maximum split level (byte level masking). @@ -18181,13 +13770,17 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { N->getOpcode(), SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags())) return Fold; + // Try to convert a constant mask AND into a shuffle clear mask. + if (SDValue Shuffle = XformToShuffleWithZero(N)) + return Shuffle; + // Type legalization might introduce new shuffles in the DAG. // Fold (VBinOp (shuffle (A, Undef, Mask)), (shuffle (B, Undef, Mask))) // -> (shuffle (VBinOp (A, B)), Undef, Mask). if (LegalTypes && isa<ShuffleVectorSDNode>(LHS) && isa<ShuffleVectorSDNode>(RHS) && LHS.hasOneUse() && RHS.hasOneUse() && - LHS.getOperand(1).isUndef() && - RHS.getOperand(1).isUndef()) { + LHS.getOperand(1).getOpcode() == ISD::UNDEF && + RHS.getOperand(1).getOpcode() == ISD::UNDEF) { ShuffleVectorSDNode *SVN0 = cast<ShuffleVectorSDNode>(LHS); ShuffleVectorSDNode *SVN1 = cast<ShuffleVectorSDNode>(RHS); @@ -18199,15 +13792,15 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { N->getFlags()); AddUsersToWorklist(N); return DAG.getVectorShuffle(VT, SDLoc(N), NewBinOp, UndefVector, - SVN0->getMask()); + &SVN0->getMask()[0]); } } return SDValue(); } -SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, - SDValue N2) { +SDValue DAGCombiner::SimplifySelect(SDLoc DL, SDValue N0, + SDValue N1, SDValue N2){ assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!"); SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2, @@ -18241,33 +13834,34 @@ SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, /// the DAG combiner loop to avoid it being looked at. bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, SDValue RHS) { - // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x)) - // The select + setcc is redundant, because fsqrt returns NaN for X < 0. + + // fold (select (setcc x, -0.0, *lt), NaN, (fsqrt x)) + // The select + setcc is redundant, because fsqrt returns NaN for X < -0. 
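// [editor's sketch] Why the select is redundant in the fold handled here:
// IEEE sqrt already returns a NaN for any input less than -0.0, so
// select(x < -0.0, NaN, sqrt(x)) and plain sqrt(x) agree everywhere.
// A quick standalone check (illustrative):
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  float X = -4.0f;
  float Folded = std::sqrt(X); // sqrt of a negative number is NaN
  float Selected = (X < -0.0f) ? std::numeric_limits<float>::quiet_NaN()
                               : std::sqrt(X);
  assert(std::isnan(Folded) && std::isnan(Selected));
}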
if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) { if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) { // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?)) SDValue Sqrt = RHS; ISD::CondCode CC; SDValue CmpLHS; - const ConstantFPSDNode *Zero = nullptr; + const ConstantFPSDNode *NegZero = nullptr; if (TheSelect->getOpcode() == ISD::SELECT_CC) { - CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get(); + CC = dyn_cast<CondCodeSDNode>(TheSelect->getOperand(4))->get(); CmpLHS = TheSelect->getOperand(0); - Zero = isConstOrConstSplatFP(TheSelect->getOperand(1)); + NegZero = isConstOrConstSplatFP(TheSelect->getOperand(1)); } else { // SELECT or VSELECT SDValue Cmp = TheSelect->getOperand(0); if (Cmp.getOpcode() == ISD::SETCC) { - CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get(); + CC = dyn_cast<CondCodeSDNode>(Cmp.getOperand(2))->get(); CmpLHS = Cmp.getOperand(0); - Zero = isConstOrConstSplatFP(Cmp.getOperand(1)); + NegZero = isConstOrConstSplatFP(Cmp.getOperand(1)); } } - if (Zero && Zero->isZero() && + if (NegZero && NegZero->isNegative() && NegZero->isZero() && Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT || CC == ISD::SETULT || CC == ISD::SETLT)) { - // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x)) + // We have: (select (setcc x, -0.0, *lt), NaN, (fsqrt x)) CombineTo(TheSelect, Sqrt); return true; } @@ -18315,64 +13909,31 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, LLD->getBasePtr().getValueType())) return false; - // The loads must not depend on one another. - if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD)) - return false; - // Check that the select condition doesn't reach either load. If so, // folding this will induce a cycle into the DAG. If not, this is safe to // xform, so create a select of the addresses. - - SmallPtrSet<const SDNode *, 32> Visited; - SmallVector<const SDNode *, 16> Worklist; - - // Always fail if LLD and RLD are not independent. TheSelect is a - // predecessor to all Nodes in question so we need not search past it. - - Visited.insert(TheSelect); - Worklist.push_back(LLD); - Worklist.push_back(RLD); - - if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) || - SDNode::hasPredecessorHelper(RLD, Visited, Worklist)) - return false; - SDValue Addr; if (TheSelect->getOpcode() == ISD::SELECT) { - // We cannot do this optimization if any pair of {RLD, LLD} is a - // predecessor to {RLD, LLD, CondNode}. As we've already compared the - // Loads, we only need to check if CondNode is a successor to one of the - // loads. We can further avoid this if there's no use of their chain - // value. SDNode *CondNode = TheSelect->getOperand(0).getNode(); - Worklist.push_back(CondNode); - - if ((LLD->hasAnyUseOfValue(1) && - SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) || - (RLD->hasAnyUseOfValue(1) && - SDNode::hasPredecessorHelper(RLD, Visited, Worklist))) + if ((LLD->hasAnyUseOfValue(1) && LLD->isPredecessorOf(CondNode)) || + (RLD->hasAnyUseOfValue(1) && RLD->isPredecessorOf(CondNode))) + return false; + // The loads must not depend on one another. + if (LLD->isPredecessorOf(RLD) || + RLD->isPredecessorOf(LLD)) return false; - Addr = DAG.getSelect(SDLoc(TheSelect), LLD->getBasePtr().getValueType(), TheSelect->getOperand(0), LLD->getBasePtr(), RLD->getBasePtr()); } else { // Otherwise SELECT_CC - // We cannot do this optimization if any pair of {RLD, LLD} is a - // predecessor to {RLD, LLD, CondLHS, CondRHS}. 
As we've already compared - // the Loads, we only need to check if CondLHS/CondRHS is a successor to - // one of the loads. We can further avoid this if there's no use of their - // chain value. - SDNode *CondLHS = TheSelect->getOperand(0).getNode(); SDNode *CondRHS = TheSelect->getOperand(1).getNode(); - Worklist.push_back(CondLHS); - Worklist.push_back(CondRHS); if ((LLD->hasAnyUseOfValue(1) && - SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) || + (LLD->isPredecessorOf(CondLHS) || LLD->isPredecessorOf(CondRHS))) || (RLD->hasAnyUseOfValue(1) && - SDNode::hasPredecessorHelper(RLD, Visited, Worklist))) + (RLD->isPredecessorOf(CondLHS) || RLD->isPredecessorOf(CondRHS)))) return false; Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect), @@ -18387,24 +13948,24 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, // It is safe to replace the two loads if they have different alignments, // but the new load must be the minimum (most restrictive) alignment of the // inputs. + bool isInvariant = LLD->isInvariant() & RLD->isInvariant(); unsigned Alignment = std::min(LLD->getAlignment(), RLD->getAlignment()); - MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags(); - if (!RLD->isInvariant()) - MMOFlags &= ~MachineMemOperand::MOInvariant; - if (!RLD->isDereferenceable()) - MMOFlags &= ~MachineMemOperand::MODereferenceable; if (LLD->getExtensionType() == ISD::NON_EXTLOAD) { - // FIXME: Discards pointer and AA info. - Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect), - LLD->getChain(), Addr, MachinePointerInfo(), Alignment, - MMOFlags); + Load = DAG.getLoad(TheSelect->getValueType(0), + SDLoc(TheSelect), + // FIXME: Discards pointer and AA info. + LLD->getChain(), Addr, MachinePointerInfo(), + LLD->isVolatile(), LLD->isNonTemporal(), + isInvariant, Alignment); } else { - // FIXME: Discards pointer and AA info. - Load = DAG.getExtLoad( - LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType() - : LLD->getExtensionType(), - SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr, - MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags); + Load = DAG.getExtLoad(LLD->getExtensionType() == ISD::EXTLOAD ? + RLD->getExtensionType() : LLD->getExtensionType(), + SDLoc(TheSelect), + TheSelect->getValueType(0), + // FIXME: Discards pointer and AA info. + LLD->getChain(), Addr, MachinePointerInfo(), + LLD->getMemoryVT(), LLD->isVolatile(), + LLD->isNonTemporal(), isInvariant, Alignment); } // Users of the select now use the result of the load. @@ -18420,161 +13981,144 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, return false; } -/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and -/// bitwise 'and'. -SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, - SDValue N1, SDValue N2, SDValue N3, - ISD::CondCode CC) { - // If this is a select where the false operand is zero and the compare is a - // check of the sign bit, see if we can perform the "gzip trick": - // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A - // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A - EVT XType = N0.getValueType(); - EVT AType = N2.getValueType(); - if (!isNullConstant(N3) || !XType.bitsGE(AType)) - return SDValue(); - - // If the comparison is testing for a positive value, we have to invert - // the sign bit mask, so only do that transform if the target has a bitwise - // 'and not' instruction (the invert is free). 
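// [editor's sketch] The "gzip trick" above in scalar form (illustrative):
// for 32-bit x, (x < 0) ? a : 0 is the same as ((x >> 31) & a), because the
// arithmetic right shift smears the sign bit into an all-ones or all-zero
// mask. (Right-shifting a negative signed value is implementation-defined
// before C++20 but arithmetic on all mainstream compilers.)
#include <cassert>
#include <cstdint>

static int32_t selectViaSignMask(int32_t X, int32_t A) {
  int32_t Mask = X >> 31; // 0xFFFFFFFF if X < 0, else 0
  return Mask & A;
}

int main() {
  assert(selectViaSignMask(-7, 42) == 42);
  assert(selectViaSignMask(7, 42) == 0);
  assert(selectViaSignMask(0, 42) == 0);
}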
- if (CC == ISD::SETGT && TLI.hasAndNot(N2)) { - // (X > -1) ? A : 0 - // (X > 0) ? X : 0 <-- This is canonical signed max. - if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2))) - return SDValue(); - } else if (CC == ISD::SETLT) { - // (X < 0) ? A : 0 - // (X < 1) ? X : 0 <-- This is un-canonicalized signed min. - if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2))) - return SDValue(); - } else { - return SDValue(); - } - - // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit - // constant. - EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType()); - auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode()); - if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) { - unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1; - SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy); - SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt); - AddToWorklist(Shift.getNode()); - - if (XType.bitsGT(AType)) { - Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift); - AddToWorklist(Shift.getNode()); - } - - if (CC == ISD::SETGT) - Shift = DAG.getNOT(DL, Shift, AType); - - return DAG.getNode(ISD::AND, DL, AType, Shift, N2); - } - - SDValue ShiftAmt = DAG.getConstant(XType.getSizeInBits() - 1, DL, ShiftAmtTy); - SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt); - AddToWorklist(Shift.getNode()); - - if (XType.bitsGT(AType)) { - Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift); - AddToWorklist(Shift.getNode()); - } - - if (CC == ISD::SETGT) - Shift = DAG.getNOT(DL, Shift, AType); - - return DAG.getNode(ISD::AND, DL, AType, Shift, N2); -} - -/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)" -/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0 -/// in it. This may be a win when the constant is not otherwise available -/// because it replaces two constant pool loads with one. -SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset( - const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3, - ISD::CondCode CC) { - if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType().isFloatingPoint())) - return SDValue(); - - // If we are before legalize types, we want the other legalization to happen - // first (for example, to avoid messing with soft float). - auto *TV = dyn_cast<ConstantFPSDNode>(N2); - auto *FV = dyn_cast<ConstantFPSDNode>(N3); - EVT VT = N2.getValueType(); - if (!TV || !FV || !TLI.isTypeLegal(VT)) - return SDValue(); - - // If a constant can be materialized without loads, this does not make sense. - if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal || - TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0)) || - TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0))) - return SDValue(); - - // If both constants have multiple uses, then we won't need to do an extra - // load. The values are likely around in registers for other users. - if (!TV->hasOneUse() && !FV->hasOneUse()) - return SDValue(); - - Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()), - const_cast<ConstantFP*>(TV->getConstantFPValue()) }; - Type *FPTy = Elts[0]->getType(); - const DataLayout &TD = DAG.getDataLayout(); - - // Create a ConstantArray of the two constants. 
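// [editor's sketch] The effect of the constant-pool transform above,
// expressed at the source level (illustrative): instead of materializing two
// FP constants and selecting between them, keep both in one two-element
// table and load from a condition-scaled offset. Element 0 holds the false
// value and element 1 the true value, matching the {FV, TV} order built here.
#include <cassert>

static float selectFP(bool Cond) {
  static const float Table[2] = {2.0f, 1.0f}; // {false value, true value}
  return Table[Cond ? 1 : 0];                 // one load, offset 0 or 4 bytes
}

int main() {
  assert(selectFP(true) == 1.0f && selectFP(false) == 2.0f);
}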
- Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts); - SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()), - TD.getPrefTypeAlignment(FPTy)); - unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment(); - - // Get offsets to the 0 and 1 elements of the array, so we can select between - // them. - SDValue Zero = DAG.getIntPtrConstant(0, DL); - unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType()); - SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV)); - SDValue Cond = - DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC); - AddToWorklist(Cond.getNode()); - SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero); - AddToWorklist(CstOffset.getNode()); - CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset); - AddToWorklist(CPIdx.getNode()); - return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx, - MachinePointerInfo::getConstantPool( - DAG.getMachineFunction()), Alignment); -} - /// Simplify an expression of the form (N0 cond N1) ? N2 : N3 /// where 'cond' is the comparison specified by CC. -SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, - SDValue N2, SDValue N3, ISD::CondCode CC, - bool NotExtCompare) { +SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, + SDValue N2, SDValue N3, + ISD::CondCode CC, bool NotExtCompare) { // (x ? y : y) -> y. if (N2 == N3) return N2; - EVT CmpOpVT = N0.getValueType(); EVT VT = N2.getValueType(); - auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); - auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode()); - auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode()); + ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); + ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2.getNode()); - // Determine if the condition we're dealing with is constant. - SDValue SCC = SimplifySetCC(getSetCCResultType(CmpOpVT), N0, N1, CC, DL, - false); + // Determine if the condition we're dealing with is constant + SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), + N0, N1, CC, DL, false); if (SCC.getNode()) AddToWorklist(SCC.getNode()); - if (auto *SCCC = dyn_cast_or_null<ConstantSDNode>(SCC.getNode())) { + if (ConstantSDNode *SCCC = dyn_cast_or_null<ConstantSDNode>(SCC.getNode())) { // fold select_cc true, x, y -> x // fold select_cc false, x, y -> y return !SCCC->isNullValue() ? N2 : N3; } - if (SDValue V = - convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC)) - return V; + // Check to see if we can simplify the select into an fabs node + if (ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(N1)) { + // Allow either -0.0 or 0.0 + if (CFP->isZero()) { + // select (setg[te] X, +/-0.0), X, fneg(X) -> fabs + if ((CC == ISD::SETGE || CC == ISD::SETGT) && + N0 == N2 && N3.getOpcode() == ISD::FNEG && + N2 == N3.getOperand(0)) + return DAG.getNode(ISD::FABS, DL, VT, N0); + + // select (setl[te] X, +/-0.0), fneg(X), X -> fabs + if ((CC == ISD::SETLT || CC == ISD::SETLE) && + N0 == N3 && N2.getOpcode() == ISD::FNEG && + N2.getOperand(0) == N3) + return DAG.getNode(ISD::FABS, DL, VT, N3); + } + } + + // Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)" + // where "tmp" is a constant pool entry containing an array with 1.0 and 2.0 + // in it. This is a win when the constant is not otherwise available because + // it replaces two constant pool loads with one. 
We only do this if the FP + // type is known to be legal, because if it isn't, then we are before legalize + // types an we want the other legalization to happen first (e.g. to avoid + // messing with soft float) and if the ConstantFP is not legal, because if + // it is legal, we may not need to store the FP constant in a constant pool. + if (ConstantFPSDNode *TV = dyn_cast<ConstantFPSDNode>(N2)) + if (ConstantFPSDNode *FV = dyn_cast<ConstantFPSDNode>(N3)) { + if (TLI.isTypeLegal(N2.getValueType()) && + (TLI.getOperationAction(ISD::ConstantFP, N2.getValueType()) != + TargetLowering::Legal && + !TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0)) && + !TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0))) && + // If both constants have multiple uses, then we won't need to do an + // extra load, they are likely around in registers for other users. + (TV->hasOneUse() || FV->hasOneUse())) { + Constant *Elts[] = { + const_cast<ConstantFP*>(FV->getConstantFPValue()), + const_cast<ConstantFP*>(TV->getConstantFPValue()) + }; + Type *FPTy = Elts[0]->getType(); + const DataLayout &TD = DAG.getDataLayout(); + + // Create a ConstantArray of the two constants. + Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts); + SDValue CPIdx = + DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()), + TD.getPrefTypeAlignment(FPTy)); + unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment(); + + // Get the offsets to the 0 and 1 element of the array so that we can + // select between them. + SDValue Zero = DAG.getIntPtrConstant(0, DL); + unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType()); + SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV)); + + SDValue Cond = DAG.getSetCC(DL, + getSetCCResultType(N0.getValueType()), + N0, N1, CC); + AddToWorklist(Cond.getNode()); + SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), + Cond, One, Zero); + AddToWorklist(CstOffset.getNode()); + CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, + CstOffset); + AddToWorklist(CPIdx.getNode()); + return DAG.getLoad( + TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx, + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), + false, false, false, Alignment); + } + } + + // Check to see if we can perform the "gzip trick", transforming + // (select_cc setlt X, 0, A, 0) -> (and (sra X, (sub size(X), 1), A) + if (isNullConstant(N3) && CC == ISD::SETLT && + (isNullConstant(N1) || // (a < 0) ? b : 0 + (isOneConstant(N1) && N0 == N2))) { // (a < 1) ? a : 0 + EVT XType = N0.getValueType(); + EVT AType = N2.getValueType(); + if (XType.bitsGE(AType)) { + // and (sra X, size(X)-1, A) -> "and (srl X, C2), A" iff A is a + // single-bit constant. 
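// [editor's sketch] The single-bit-constant refinement just described, in
// scalar form (illustrative): when A is a power of two (say 1 << k), the
// sign bit can be shifted straight onto bit k with a logical shift, so
// (x < 0) ? A : 0 becomes ((unsigned)x >> (31 - k)) & A instead of the full
// sign-smear-and-mask sequence.
#include <cassert>
#include <cstdint>

static uint32_t selectPow2OnNegative(int32_t X, unsigned K) {
  uint32_t A = 1u << K;
  return ((uint32_t)X >> (31 - K)) & A; // sign bit lands exactly on bit K
}

int main() {
  assert(selectPow2OnNegative(-1, 4) == 16); // (x < 0) ? 16 : 0
  assert(selectPow2OnNegative(1, 4) == 0);
}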
+ if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) { + unsigned ShCtV = N2C->getAPIntValue().logBase2(); + ShCtV = XType.getSizeInBits() - ShCtV - 1; + SDValue ShCt = DAG.getConstant(ShCtV, SDLoc(N0), + getShiftAmountTy(N0.getValueType())); + SDValue Shift = DAG.getNode(ISD::SRL, SDLoc(N0), + XType, N0, ShCt); + AddToWorklist(Shift.getNode()); + + if (XType.bitsGT(AType)) { + Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift); + AddToWorklist(Shift.getNode()); + } - if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC)) - return V; + return DAG.getNode(ISD::AND, DL, AType, Shift, N2); + } + + SDValue Shift = DAG.getNode(ISD::SRA, SDLoc(N0), + XType, N0, + DAG.getConstant(XType.getSizeInBits() - 1, + SDLoc(N0), + getShiftAmountTy(N0.getValueType()))); + AddToWorklist(Shift.getNode()); + + if (XType.bitsGT(AType)) { + Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift); + AddToWorklist(Shift.getNode()); + } + + return DAG.getNode(ISD::AND, DL, AType, Shift, N2); + } + } // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (shr (shl x)) A) // where y is has a single bit set. @@ -18585,10 +14129,10 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND && N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) { SDValue AndLHS = N0->getOperand(0); - auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1)); + ConstantSDNode *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1)); if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) { // Shift the tested bit over the sign bit. - const APInt &AndMask = ConstAndRHS->getAPIntValue(); + APInt AndMask = ConstAndRHS->getAPIntValue(); SDValue ShlAmt = DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS), getShiftAmountTy(AndLHS.getValueType())); @@ -18606,48 +14150,48 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, } // fold select C, 16, 0 -> shl C, 4 - bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2(); - bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2(); - - if ((Fold || Swap) && - TLI.getBooleanContents(CmpOpVT) == - TargetLowering::ZeroOrOneBooleanContent && - (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) { - - if (Swap) { - CC = ISD::getSetCCInverse(CC, CmpOpVT.isInteger()); - std::swap(N2C, N3C); - } + if (N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2() && + TLI.getBooleanContents(N0.getValueType()) == + TargetLowering::ZeroOrOneBooleanContent) { // If the caller doesn't want us to simplify this into a zext of a compare, // don't do it. if (NotExtCompare && N2C->isOne()) return SDValue(); - SDValue Temp, SCC; - // zext (setcc n0, n1) - if (LegalTypes) { - SCC = DAG.getSetCC(DL, getSetCCResultType(CmpOpVT), N0, N1, CC); - if (VT.bitsLT(SCC.getValueType())) - Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT); - else - Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC); - } else { - SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC); - Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC); - } + // Get a SetCC of the condition + // NOTE: Don't create a SETCC if it's not legal on this target. 
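// [editor's sketch] The "fold select C, 16, 0 -> shl C, 4" comment above in
// scalar form (illustrative): with zero-or-one booleans, the select is just
// the zero-extended compare shifted left by log2 of the power-of-two constant.
#include <cassert>

static int selectPow2(int A, int B) {
  int Cond = (A < B); // 0 or 1
  return Cond << 4;   // select C, 16, 0  ->  shl C, 4
}

int main() {
  assert(selectPow2(1, 2) == 16);
  assert(selectPow2(2, 1) == 0);
}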
+ if (!LegalOperations || + TLI.isOperationLegal(ISD::SETCC, N0.getValueType())) { + SDValue Temp, SCC; + // cast from setcc result type to select result type + if (LegalTypes) { + SCC = DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), + N0, N1, CC); + if (N2.getValueType().bitsLT(SCC.getValueType())) + Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), + N2.getValueType()); + else + Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), + N2.getValueType(), SCC); + } else { + SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC); + Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), + N2.getValueType(), SCC); + } - AddToWorklist(SCC.getNode()); - AddToWorklist(Temp.getNode()); + AddToWorklist(SCC.getNode()); + AddToWorklist(Temp.getNode()); - if (N2C->isOne()) - return Temp; + if (N2C->isOne()) + return Temp; - // shl setcc result by log2 n2c - return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp, - DAG.getConstant(N2C->getAPIntValue().logBase2(), - SDLoc(Temp), - getShiftAmountTy(Temp.getValueType()))); + // shl setcc result by log2 n2c + return DAG.getNode( + ISD::SHL, DL, N2.getValueType(), Temp, + DAG.getConstant(N2C->getAPIntValue().logBase2(), SDLoc(Temp), + getShiftAmountTy(Temp.getValueType()))); + } } // Check to see if this is an integer abs. @@ -18667,51 +14211,18 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, N0 == N3 && N2.getOpcode() == ISD::SUB && N0 == N2.getOperand(1)) SubC = dyn_cast<ConstantSDNode>(N2.getOperand(0)); - if (SubC && SubC->isNullValue() && CmpOpVT.isInteger()) { + EVT XType = N0.getValueType(); + if (SubC && SubC->isNullValue() && XType.isInteger()) { SDLoc DL(N0); - SDValue Shift = DAG.getNode(ISD::SRA, DL, CmpOpVT, N0, - DAG.getConstant(CmpOpVT.getSizeInBits() - 1, - DL, - getShiftAmountTy(CmpOpVT))); - SDValue Add = DAG.getNode(ISD::ADD, DL, CmpOpVT, N0, Shift); + SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, + N0, + DAG.getConstant(XType.getSizeInBits() - 1, DL, + getShiftAmountTy(N0.getValueType()))); + SDValue Add = DAG.getNode(ISD::ADD, DL, + XType, N0, Shift); AddToWorklist(Shift.getNode()); AddToWorklist(Add.getNode()); - return DAG.getNode(ISD::XOR, DL, CmpOpVT, Add, Shift); - } - } - - // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X) - // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X) - // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X) - // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X) - // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X) - // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X) - // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X) - // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X) - if (N1C && N1C->isNullValue() && (CC == ISD::SETEQ || CC == ISD::SETNE)) { - SDValue ValueOnZero = N2; - SDValue Count = N3; - // If the condition is NE instead of E, swap the operands. - if (CC == ISD::SETNE) - std::swap(ValueOnZero, Count); - // Check if the value on zero is a constant equal to the bits in the type. - if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) { - if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) { - // If the other operand is cttz/cttz_zero_undef of N0, and cttz is - // legal, combine to just cttz. 
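// [editor's sketch] Why select_cc seteq X, 0, 32, cttz(X) collapses to a
// single cttz for i32, per the comment above: ISD::CTTZ is defined to return
// the bit width for a zero input, so the select adds nothing. Standalone
// model (illustrative; __builtin_ctz is a GCC/Clang builtin and is itself
// undefined at 0, hence the explicit guard in the model):
#include <cassert>

static int cttz32(unsigned X) {
  return X ? __builtin_ctz(X) : 32; // models ISD::CTTZ's defined-at-zero rule
}

int main() {
  unsigned X = 0;
  assert(((X == 0) ? 32 : cttz32(X)) == cttz32(X)); // select is redundant
  X = 40;                                           // nonzero case too
  assert(((X == 0) ? 32 : cttz32(X)) == cttz32(X));
}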
- if ((Count.getOpcode() == ISD::CTTZ || - Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) && - N0 == Count.getOperand(0) && - (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT))) - return DAG.getNode(ISD::CTTZ, DL, VT, N0); - // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is - // legal, combine to just ctlz. - if ((Count.getOpcode() == ISD::CTLZ || - Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) && - N0 == Count.getOperand(0) && - (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT))) - return DAG.getNode(ISD::CTLZ, DL, VT, N0); - } + return DAG.getNode(ISD::XOR, DL, XType, Add, Shift); } } @@ -18719,9 +14230,9 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, } /// This is a stub for TargetLowering::SimplifySetCC. -SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, - ISD::CondCode Cond, const SDLoc &DL, - bool foldBooleans) { +SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, + SDValue N1, ISD::CondCode Cond, + SDLoc DL, bool foldBooleans) { TargetLowering::DAGCombinerInfo DagCombineInfo(DAG, Level, false, this); return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL); @@ -18732,19 +14243,21 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, /// by a magic number. /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide". SDValue DAGCombiner::BuildSDIV(SDNode *N) { - // when optimising for minimum size, we don't want to expand a div to a mul - // and a shift. - if (DAG.getMachineFunction().getFunction().optForMinSize()) + ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); + if (!C) return SDValue(); - SmallVector<SDNode *, 8> Built; - if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) { - for (SDNode *N : Built) - AddToWorklist(N); - return S; - } + // Avoid division by zero. + if (C->isNullValue()) + return SDValue(); - return SDValue(); + std::vector<SDNode*> Built; + SDValue S = + TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); + + for (SDNode *N : Built) + AddToWorklist(N); + return S; } /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a @@ -18758,14 +14271,12 @@ SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) { if (C->isNullValue()) return SDValue(); - SmallVector<SDNode *, 8> Built; - if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) { - for (SDNode *N : Built) - AddToWorklist(N); - return S; - } + std::vector<SDNode *> Built; + SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, &Built); - return SDValue(); + for (SDNode *N : Built) + AddToWorklist(N); + return S; } /// Given an ISD::UDIV node expressing a divide by constant, return a DAG @@ -18773,66 +14284,47 @@ SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) { /// number. /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide". SDValue DAGCombiner::BuildUDIV(SDNode *N) { - // when optimising for minimum size, we don't want to expand a div to a mul - // and a shift. - if (DAG.getMachineFunction().getFunction().optForMinSize()) + ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); + if (!C) return SDValue(); - SmallVector<SDNode *, 8> Built; - if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) { - for (SDNode *N : Built) - AddToWorklist(N); - return S; - } + // Avoid division by zero. 
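// [editor's sketch] The "Hacker's Delight" idea behind BuildUDIV above
// (illustrative): replace division by a constant with a multiply by a
// precomputed magic reciprocal plus shifts. For unsigned 32-bit division by
// 3, the magic constant is ceil(2^33 / 3) = 0xAAAAAAAB with a shift of 33.
#include <cassert>
#include <cstdint>

static uint32_t udiv3(uint32_t X) {
  return (uint32_t)(((uint64_t)X * 0xAAAAAAABull) >> 33); // exact for all X
}

int main() {
  assert(udiv3(0) == 0 && udiv3(9) == 3 && udiv3(10) == 3);
  assert(udiv3(0xFFFFFFFFu) == 0xFFFFFFFFu / 3);
}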
+ if (C->isNullValue()) + return SDValue(); - return SDValue(); -} + std::vector<SDNode*> Built; + SDValue S = + TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); -/// Determines the LogBase2 value for a non-null input value using the -/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). -SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) { - EVT VT = V.getValueType(); - unsigned EltBits = VT.getScalarSizeInBits(); - SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V); - SDValue Base = DAG.getConstant(EltBits - 1, DL, VT); - SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz); - return LogBase2; + for (SDNode *N : Built) + AddToWorklist(N); + return S; } -/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i) -/// For the reciprocal, we need to find the zero of the function: -/// F(X) = A X - 1 [which has a zero at X = 1/A] -/// => -/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form -/// does not require additional intermediate precision] -SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags) { +SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) { if (Level >= AfterLegalizeDAG) return SDValue(); - // TODO: Handle half and/or extended types? - EVT VT = Op.getValueType(); - if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64) - return SDValue(); - - // If estimates are explicitly disabled for this function, we're done. - MachineFunction &MF = DAG.getMachineFunction(); - int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF); - if (Enabled == TLI.ReciprocalEstimate::Disabled) - return SDValue(); - - // Estimates may be explicitly enabled for this type with a custom number of - // refinement steps. - int Iterations = TLI.getDivRefinementSteps(VT, MF); - if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) { - AddToWorklist(Est.getNode()); + // Expose the DAG combiner to the target combiner implementations. + TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this); + unsigned Iterations = 0; + if (SDValue Est = TLI.getRecipEstimate(Op, DCI, Iterations)) { if (Iterations) { + // Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i) + // For the reciprocal, we need to find the zero of the function: + // F(X) = A X - 1 [which has a zero at X = 1/A] + // => + // X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form + // does not require additional intermediate precision] EVT VT = Op.getValueType(); SDLoc DL(Op); SDValue FPOne = DAG.getConstantFP(1.0, DL, VT); + AddToWorklist(Est.getNode()); + // Newton iterations: Est = Est + Est (1 - Arg * Est) - for (int i = 0; i < Iterations; ++i) { + for (unsigned i = 0; i < Iterations; ++i) { SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est, Flags); AddToWorklist(NewEst.getNode()); @@ -18858,9 +14350,9 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags) { /// => /// X_{i+1} = X_i (1.5 - A X_i^2 / 2) /// As a result, we precompute A/2 prior to the iteration loop. 
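// [editor's sketch] The one-constant Newton-Raphson rsqrt step documented
// above, in scalar form (illustrative): X_{i+1} = X_i * (1.5 - (A/2) X_i^2),
// with A/2 hoisted out of the loop exactly as the comment describes.
#include <cassert>
#include <cmath>

static float rsqrtRefine(float A, float Est, unsigned Iterations) {
  float HalfA = 0.5f * A; // precomputed once, reused every iteration
  for (unsigned i = 0; i < Iterations; ++i)
    Est = Est * (1.5f - HalfA * Est * Est);
  return Est;
}

int main() {
  // Start from a deliberately rough guess for 1/sqrt(4) and refine.
  float Est = rsqrtRefine(4.0f, 0.4f, 4);
  assert(std::fabs(Est - 0.5f) < 1e-6f);
}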
-SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
-                                         unsigned Iterations,
-                                         SDNodeFlags Flags, bool Reciprocal) {
+SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est,
+                                          unsigned Iterations,
+                                          SDNodeFlags *Flags) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
@@ -18887,13 +14379,6 @@ SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
     AddToWorklist(Est.getNode());
   }
-
-  // If non-reciprocal square root is requested, multiply the result by Arg.
-  if (!Reciprocal) {
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
-    AddToWorklist(Est.getNode());
-  }
-
   return Est;
 }

@@ -18902,114 +14387,48 @@ SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
 ///     =>
 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
-SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
-                                         unsigned Iterations,
-                                         SDNodeFlags Flags, bool Reciprocal) {
+SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est,
+                                          unsigned Iterations,
+                                          SDNodeFlags *Flags) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);

-  // This routine must enter the loop below to work correctly
-  // when (Reciprocal == false).
-  assert(Iterations > 0);
-
-  // Newton iterations for reciprocal square root:
-  //   E = (E * -0.5) * ((A * E) * E + -3.0)
+  // Newton iterations: Est = -0.5 * Est * (-3.0 + Arg * Est * Est)
   for (unsigned i = 0; i < Iterations; ++i) {
-    SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
-    AddToWorklist(AE.getNode());
-
-    SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
-    AddToWorklist(AEE.getNode());
-
-    SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
-    AddToWorklist(RHS.getNode());
-
-    // When calculating a square root at the last iteration build:
-    //   S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
-    // (notice a common subexpression)
-    SDValue LHS;
-    if (Reciprocal || (i + 1) < Iterations) {
-      // RSQRT: LHS = (E * -0.5)
-      LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
-    } else {
-      // SQRT: LHS = (A * E) * -0.5
-      LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
-    }
-    AddToWorklist(LHS.getNode());
+    SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
+    AddToWorklist(HalfEst.getNode());

-    Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
+    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
+    AddToWorklist(Est.getNode());
+
+    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
+    AddToWorklist(Est.getNode());
+
+    Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree, Flags);
     AddToWorklist(Est.getNode());
-  }

+    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst, Flags);
+    AddToWorklist(Est.getNode());
+  }
   return Est;
 }

-/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
-/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
-/// Op can be zero.
-SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
-                                           bool Reciprocal) {
+SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
   if (Level >= AfterLegalizeDAG)
     return SDValue();

-  // TODO: Handle half and/or extended types?
-  EVT VT = Op.getValueType();
-  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
-    return SDValue();
-
-  // If estimates are explicitly disabled for this function, we're done.
-  MachineFunction &MF = DAG.getMachineFunction();
-  int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
-  if (Enabled == TLI.ReciprocalEstimate::Disabled)
-    return SDValue();
-
-  // Estimates may be explicitly enabled for this type with a custom number of
-  // refinement steps.
-  int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
-
+  // Expose the DAG combiner to the target combiner implementations.
+  TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
+  unsigned Iterations = 0;
   bool UseOneConstNR = false;
-  if (SDValue Est =
-      TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
-                          Reciprocal)) {
+  if (SDValue Est = TLI.getRsqrtEstimate(Op, DCI, Iterations, UseOneConstNR)) {
     AddToWorklist(Est.getNode());
-
     if (Iterations) {
-      Est = UseOneConstNR
-            ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
-            : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
-
-      if (!Reciprocal) {
-        // The estimate is now completely wrong if the input was exactly 0.0 or
-        // possibly a denormal. Force the answer to 0.0 for those cases.
-        EVT VT = Op.getValueType();
-        SDLoc DL(Op);
-        EVT CCVT = getSetCCResultType(VT);
-        ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT;
-        const Function &F = DAG.getMachineFunction().getFunction();
-        Attribute Denorms = F.getFnAttribute("denormal-fp-math");
-        if (Denorms.getValueAsString().equals("ieee")) {
-          // fabs(X) < SmallestNormal ? 0.0 : Est
-          const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
-          APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
-          SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
-          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
-          SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
-          SDValue IsDenorm = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
-          Est = DAG.getNode(SelOpcode, DL, VT, IsDenorm, FPZero, Est);
-          AddToWorklist(Fabs.getNode());
-          AddToWorklist(IsDenorm.getNode());
-          AddToWorklist(Est.getNode());
-        } else {
-          // X == 0.0 ? 0.0 : Est
-          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
-          SDValue IsZero = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
-          Est = DAG.getNode(SelOpcode, DL, VT, IsZero, FPZero, Est);
-          AddToWorklist(IsZero.getNode());
-          AddToWorklist(Est.getNode());
-        }
-      }
+      Est = UseOneConstNR ?
+        BuildRsqrtNROneConst(Op, Est, Iterations, Flags) :
+        BuildRsqrtNRTwoConst(Op, Est, Iterations, Flags);
     }
     return Est;
   }
@@ -19017,12 +14436,41 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
   return SDValue();
 }

-SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
-  return buildSqrtEstimateImpl(Op, Flags, true);
-}
+/// Return true if base is a frame index, which is known not to alias with
+/// anything but itself. Provides base object and offset as results.
+static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
+                           const GlobalValue *&GV, const void *&CV) {
+  // Assume it is a primitive operation.
+  Base = Ptr; Offset = 0; GV = nullptr; CV = nullptr;
+
+  // If it's adding a simple constant then integrate the offset.
+  if (Base.getOpcode() == ISD::ADD) {
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Base.getOperand(1))) {
+      Base = Base.getOperand(0);
+      Offset += C->getZExtValue();
+    }
+  }
+
+  // Return the underlying GlobalValue, and update the Offset. Return false
+  // for GlobalAddressSDNode since the same GlobalAddress may be represented
+  // by multiple nodes with different offsets.
+  if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Base)) {
+    GV = G->getGlobal();
+    Offset += G->getOffset();
+    return false;
+  }

-SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
-  return buildSqrtEstimateImpl(Op, Flags, false);
+  // Return the underlying Constant value, and update the Offset. Return false
+  // for ConstantSDNodes since the same constant pool entry may be represented
+  // by multiple nodes with different offsets.
+  if (ConstantPoolSDNode *C = dyn_cast<ConstantPoolSDNode>(Base)) {
+    CV = C->isMachineConstantPoolEntry() ? (const void *)C->getMachineCPVal()
+                                         : (const void *)C->getConstVal();
+    Offset += C->getOffset();
+    return false;
+  }
+  // If it's any of the following then it can't alias with anything but itself.
+  return isa<FrameIndexSDNode>(Base);
 }

 /// Return true if there is any possibility that the two addresses overlap.
@@ -19042,63 +14490,54 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const {
   if (Op1->isInvariant() && Op0->writeMem())
     return false;

-  unsigned NumBytes0 = Op0->getMemoryVT().getStoreSize();
-  unsigned NumBytes1 = Op1->getMemoryVT().getStoreSize();
-
-  // Check for BaseIndexOffset matching.
-  BaseIndexOffset BasePtr0 = BaseIndexOffset::match(Op0, DAG);
-  BaseIndexOffset BasePtr1 = BaseIndexOffset::match(Op1, DAG);
-  int64_t PtrDiff;
-  if (BasePtr0.getBase().getNode() && BasePtr1.getBase().getNode()) {
-    if (BasePtr0.equalBaseIndex(BasePtr1, DAG, PtrDiff))
-      return !((NumBytes0 <= PtrDiff) || (PtrDiff + NumBytes1 <= 0));
-
-    // If both BasePtr0 and BasePtr1 are FrameIndexes, we will not be
-    // able to calculate their relative offset if at least one arises
-    // from an alloca. However, these allocas cannot overlap and we
-    // can infer there is no alias.
-    if (auto *A = dyn_cast<FrameIndexSDNode>(BasePtr0.getBase()))
-      if (auto *B = dyn_cast<FrameIndexSDNode>(BasePtr1.getBase())) {
-        MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
-        // If the bases are the same frame index but we couldn't find a
-        // constant offset (the indices are different), be conservative.
-        if (A != B && (!MFI.isFixedObjectIndex(A->getIndex()) ||
-                       !MFI.isFixedObjectIndex(B->getIndex())))
-          return false;
-      }
-
-    bool IsFI0 = isa<FrameIndexSDNode>(BasePtr0.getBase());
-    bool IsFI1 = isa<FrameIndexSDNode>(BasePtr1.getBase());
-    bool IsGV0 = isa<GlobalAddressSDNode>(BasePtr0.getBase());
-    bool IsGV1 = isa<GlobalAddressSDNode>(BasePtr1.getBase());
-    bool IsCV0 = isa<ConstantPoolSDNode>(BasePtr0.getBase());
-    bool IsCV1 = isa<ConstantPoolSDNode>(BasePtr1.getBase());
-
-    // With mismatched base types or checkable indices, we can check that
-    // they do not alias.
-    if ((BasePtr0.getIndex() == BasePtr1.getIndex() || (IsFI0 != IsFI1) ||
-         (IsGV0 != IsGV1) || (IsCV0 != IsCV1)) &&
-        (IsFI0 || IsGV0 || IsCV0) && (IsFI1 || IsGV1 || IsCV1))
-      return false;
-  }
+  // Gather base node and offset information.
+  SDValue Base1, Base2;
+  int64_t Offset1, Offset2;
+  const GlobalValue *GV1, *GV2;
+  const void *CV1, *CV2;
+  bool isFrameIndex1 = FindBaseOffset(Op0->getBasePtr(),
+                                      Base1, Offset1, GV1, CV1);
+  bool isFrameIndex2 = FindBaseOffset(Op1->getBasePtr(),
+                                      Base2, Offset2, GV2, CV2);
+
+  // If they have the same base address then check to see if they overlap.
+  if (Base1 == Base2 || (GV1 && (GV1 == GV2)) || (CV1 && (CV1 == CV2)))
+    return !((Offset1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= Offset2 ||
+             (Offset2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= Offset1);
+
+  // It is possible for different frame indices to alias each other, mostly
+  // when tail call optimization reuses return address slots for arguments.
+  // To catch this case, look up the actual index of frame indices to compute
+  // the real alias relationship.
+  if (isFrameIndex1 && isFrameIndex2) {
+    MachineFrameInfo *MFI = DAG.getMachineFunction().getFrameInfo();
+    Offset1 += MFI->getObjectOffset(cast<FrameIndexSDNode>(Base1)->getIndex());
+    Offset2 += MFI->getObjectOffset(cast<FrameIndexSDNode>(Base2)->getIndex());
+    return !((Offset1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= Offset2 ||
+             (Offset2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= Offset1);
+  }
+
+  // Otherwise, if we know what the bases are, and they aren't identical, then
+  // we know they cannot alias.
+  if ((isFrameIndex1 || CV1 || GV1) && (isFrameIndex2 || CV2 || GV2))
+    return false;

-  // If we know required SrcValue1 and SrcValue2 have relatively large
-  // alignment compared to the size and offset of the access, we may be able
-  // to prove they do not alias. This check is conservative for now to catch
-  // cases created by splitting vector types.
-  int64_t SrcValOffset0 = Op0->getSrcValueOffset();
-  int64_t SrcValOffset1 = Op1->getSrcValueOffset();
-  unsigned OrigAlignment0 = Op0->getOriginalAlignment();
-  unsigned OrigAlignment1 = Op1->getOriginalAlignment();
-  if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
-      NumBytes0 == NumBytes1 && OrigAlignment0 > NumBytes0) {
-    int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0;
-    int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1;
-
-    // There is no overlap between these relatively aligned accesses of
-    // similar size. Return no alias.
-    if ((OffAlign0 + NumBytes0) <= OffAlign1 ||
-        (OffAlign1 + NumBytes1) <= OffAlign0)
+  // If we know required SrcValue1 and SrcValue2 have relatively large alignment
+  // compared to the size and offset of the access, we may be able to prove they
+  // do not alias. This check is conservative for now to catch cases created by
+  // splitting vector types.
+  if ((Op0->getOriginalAlignment() == Op1->getOriginalAlignment()) &&
+      (Op0->getSrcValueOffset() != Op1->getSrcValueOffset()) &&
+      (Op0->getMemoryVT().getSizeInBits() >> 3 ==
+       Op1->getMemoryVT().getSizeInBits() >> 3) &&
+      (Op0->getOriginalAlignment() >
+       (Op0->getMemoryVT().getSizeInBits() >> 3))) {
+    int64_t OffAlign1 = Op0->getSrcValueOffset() % Op0->getOriginalAlignment();
+    int64_t OffAlign2 = Op1->getSrcValueOffset() % Op1->getOriginalAlignment();
+
+    // There is no overlap between these relatively aligned accesses of similar
+    // size; return no alias.
+    if ((OffAlign1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= OffAlign2 ||
+        (OffAlign2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= OffAlign1)
       return false;
   }
@@ -19110,18 +14549,20 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const {
       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
     UseAA = false;
 #endif
-
-  if (UseAA && AA &&
+  if (UseAA &&
       Op0->getMemOperand()->getValue() && Op1->getMemOperand()->getValue()) {
     // Use alias analysis information.
-    int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
-    int64_t Overlap0 = NumBytes0 + SrcValOffset0 - MinOffset;
-    int64_t Overlap1 = NumBytes1 + SrcValOffset1 - MinOffset;
+    int64_t MinOffset = std::min(Op0->getSrcValueOffset(),
+                                 Op1->getSrcValueOffset());
+    int64_t Overlap1 = (Op0->getMemoryVT().getSizeInBits() >> 3) +
+        Op0->getSrcValueOffset() - MinOffset;
+    int64_t Overlap2 = (Op1->getMemoryVT().getSizeInBits() >> 3) +
+        Op1->getSrcValueOffset() - MinOffset;
     AliasResult AAResult =
-        AA->alias(MemoryLocation(Op0->getMemOperand()->getValue(), Overlap0,
-                                 UseTBAA ? Op0->getAAInfo() : AAMDNodes()),
-                  MemoryLocation(Op1->getMemOperand()->getValue(), Overlap1,
-                                 UseTBAA ? Op1->getAAInfo() : AAMDNodes()) );
+        AA.alias(MemoryLocation(Op0->getMemOperand()->getValue(), Overlap1,
+                                UseTBAA ? Op0->getAAInfo() : AAMDNodes()),
+                 MemoryLocation(Op1->getMemOperand()->getValue(), Overlap2,
+                                UseTBAA ? Op1->getAAInfo() : AAMDNodes()));
     if (AAResult == NoAlias)
       return false;
   }
@@ -19203,28 +14644,75 @@ void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
       ++Depth;
       break;

-    case ISD::CopyFromReg:
-      // Forward past CopyFromReg.
-      Chains.push_back(Chain.getOperand(0));
-      ++Depth;
-      break;
-
     default:
       // For all other instructions we will just have to take what we can get.
      Aliases.push_back(Chain);
       break;
     }
   }
+
+  // We need to be careful here to also search for aliases through the
+  // value operand of a store, etc. Consider the following situation:
+  //   Token1 = ...
+  //   L1 = load Token1, %52
+  //   S1 = store Token1, L1, %51
+  //   L2 = load Token1, %52+8
+  //   S2 = store Token1, L2, %51+8
+  //   Token2 = Token(S1, S2)
+  //   L3 = load Token2, %53
+  //   S3 = store Token2, L3, %52
+  //   L4 = load Token2, %53+8
+  //   S4 = store Token2, L4, %52+8
+  // If we search for aliases of S3 (which loads address %52), and we look
+  // only through the chain, then we'll miss the trivial dependence on L1
+  // (which also loads from %52). We then might change all loads and
+  // stores to use Token1 as their chain operand, which could result in
+  // copying %53 into %52 before copying %52 into %51 (which should
+  // happen first).
+  //
+  // The problem is, however, that searching for such data dependencies
+  // can become expensive, and the cost is not directly related to the
+  // chain depth. Instead, we'll rule out such configurations here by
+  // insisting that we've visited all chain users (except for users
+  // of the original chain, which is not necessary). When doing this,
+  // we need to look through nodes we don't care about (otherwise, things
+  // like register copies will interfere with trivial cases).
+
+  SmallVector<const SDNode *, 16> Worklist;
+  for (const SDNode *N : Visited)
+    if (N != OriginalChain.getNode())
+      Worklist.push_back(N);
+
+  while (!Worklist.empty()) {
+    const SDNode *M = Worklist.pop_back_val();
+
+    // We have already visited M, and want to make sure we've visited any uses
+    // of M that we care about. For uses that we've not visited, and don't
+    // care about, queue them to the worklist.
+
+    for (SDNode::use_iterator UI = M->use_begin(),
+         UIE = M->use_end(); UI != UIE; ++UI)
+      if (UI.getUse().getValueType() == MVT::Other &&
+          Visited.insert(*UI).second) {
+        if (isa<MemSDNode>(*UI)) {
+          // We've not visited this use, and we care about it (it could have an
+          // ordering dependency with the original node).
+          Aliases.clear();
+          Aliases.push_back(OriginalChain);
+          return;
+        }
+
+        // We've not visited this use, but we don't care about it. Mark it as
+        // visited and enqueue it to the worklist.
+        Worklist.push_back(*UI);
+      }
+  }
 }

 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
 /// (aliasing node).
 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
-  if (OptLevel == CodeGenOpt::None)
-    return OldChain;
-
-  // Ops for replacing token factor.
-  SmallVector<SDValue, 8> Aliases;
+  SmallVector<SDValue, 8> Aliases;  // Ops for replacing token factor.

   // Accumulate all the aliases to this node.
   GatherAllAliases(N, OldChain, Aliases);
@@ -19241,160 +14729,85 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
   return DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Aliases);
 }

-// TODO: Replace with std::monostate when we move to C++17.
-struct UnitT { } Unit;
-bool operator==(const UnitT &, const UnitT &) { return true; }
-bool operator!=(const UnitT &, const UnitT &) { return false; }
-
-// This function tries to collect a bunch of potentially interesting
-// nodes to improve the chains of, all at once. This might seem
-// redundant, as this function gets called when visiting every store
-// node, so why not let the work be done on each store as it's visited?
-//
-// I believe this is mainly important because MergeConsecutiveStores
-// is unable to deal with merging stores of different sizes, so unless
-// we improve the chains of all the potential candidates up-front
-// before running MergeConsecutiveStores, it might only see some of
-// the nodes that will eventually be candidates, and then not be able
-// to go from a partially-merged state to the desired final
-// fully-merged state.
-
-bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
-  SmallVector<StoreSDNode *, 8> ChainedStores;
-  StoreSDNode *STChain = St;
-  // Intervals records which offsets from BaseIndex have been covered. In
-  // the common case, every store writes to the immediately preceding
-  // addresses and is thus merged with the previous interval at insertion time.
-
-  using IMap =
-      llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
-  IMap::Allocator A;
-  IMap Intervals(A);
-
+bool DAGCombiner::findBetterNeighborChains(StoreSDNode* St) {
   // This holds the base pointer, index, and the offset in bytes from the base
   // pointer.
-  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
+  BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr());

   // We must have a base and an offset.
-  if (!BasePtr.getBase().getNode())
+  if (!BasePtr.Base.getNode())
     return false;

   // Do not handle stores to undef base pointers.
-  if (BasePtr.getBase().isUndef())
+  if (BasePtr.Base.getOpcode() == ISD::UNDEF)
     return false;

-  // Add ST's interval.
-  Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);
+  SmallVector<StoreSDNode *, 8> ChainedStores;
+  ChainedStores.push_back(St);

-  while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
+  // Walk up the chain and look for nodes with offsets from the same
+  // base pointer. Stop when reaching an instruction of a different kind,
+  // or one which has a different base pointer.
+  StoreSDNode *Index = St;
+  while (Index) {
     // If the chain has more than one use, then we can't reorder the mem ops.
-    if (!SDValue(Chain, 0)->hasOneUse())
+    if (Index != St && !SDValue(Index, 0)->hasOneUse())
       break;
-    if (Chain->isVolatile() || Chain->isIndexed())
+
+    if (Index->isVolatile() || Index->isIndexed())
       break;

     // Find the base pointer and offset for this memory node.
-    const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
+    BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr());
+
     // Check that the base pointer is the same as the original one.
-    int64_t Offset;
-    if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
-      break;
-    int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
-    // Make sure we don't overlap with other intervals by checking the ones to
-    // the left or right before inserting.
-    auto I = Intervals.find(Offset);
-    // If there's a next interval, we should end before it.
-    if (I != Intervals.end() && I.start() < (Offset + Length))
-      break;
-    // If there's a previous interval, we should start after it.
-    if (I != Intervals.begin() && (--I).stop() <= Offset)
+    if (!Ptr.equalBaseIndex(BasePtr))
       break;
-    Intervals.insert(Offset, Offset + Length, Unit);
-    ChainedStores.push_back(Chain);
-    STChain = Chain;
+    // Find the next memory operand in the chain. If the next operand in the
+    // chain is a store then move up and continue the scan with the next
+    // memory operand. If the next operand is a load, save it and use alias
+    // information to check whether it interferes with anything.
+    SDNode *NextInChain = Index->getChain().getNode();
+    while (true) {
+      if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) {
+        // We found a store node. Use it for the next iteration.
+        ChainedStores.push_back(STn);
+        Index = STn;
+        break;
+      } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) {
+        NextInChain = Ldn->getChain().getNode();
+        continue;
+      } else {
+        Index = nullptr;
+        break;
+      }
+    }
   }

-  // If we didn't find a chained store, exit.
-  if (ChainedStores.size() == 0)
-    return false;
-
-  // Improve all chained stores (St and ChainedStores members) starting from
-  // where the store chain ended and return a single TokenFactor.
-  SDValue NewChain = STChain->getChain();
-  SmallVector<SDValue, 8> TFOps;
-  for (unsigned I = ChainedStores.size(); I;) {
-    StoreSDNode *S = ChainedStores[--I];
-    SDValue BetterChain = FindBetterChain(S, NewChain);
-    S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
-        S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
-    TFOps.push_back(SDValue(S, 0));
-    ChainedStores[I] = S;
-  }
-
-  // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
-  SDValue BetterChain = FindBetterChain(St, NewChain);
-  SDValue NewST;
-  if (St->isTruncatingStore())
-    NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
-                              St->getBasePtr(), St->getMemoryVT(),
-                              St->getMemOperand());
-  else
-    NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
-                         St->getBasePtr(), St->getMemOperand());
-
-  TFOps.push_back(NewST);
-
-  // If we improved every element of TFOps, then we've lost the dependence on
-  // NewChain to successors of St and we need to add it back to TFOps. Do so at
-  // the beginning to keep relative order consistent with FindBetterChains.
-  auto hasImprovedChain = [&](SDValue ST) -> bool {
-    return ST->getOperand(0) != NewChain;
-  };
-  bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
-  if (AddNewChain)
-    TFOps.insert(TFOps.begin(), NewChain);
-
-  SDValue TF = DAG.getNode(ISD::TokenFactor, SDLoc(STChain), MVT::Other, TFOps);
-  CombineTo(St, TF);
-
-  AddToWorklist(STChain);
-  // Add TF operands to the worklist in reverse order.
-  for (auto I = TF->getNumOperands(); I;)
-    AddToWorklist(TF->getOperand(--I).getNode());
-  AddToWorklist(TF.getNode());
-  return true;
-}
-
-bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
-  if (OptLevel == CodeGenOpt::None)
-    return false;
-
-  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
+  bool MadeChange = false;
+  SmallVector<std::pair<StoreSDNode *, SDValue>, 8> BetterChains;

-  // We must have a base and an offset.
-  if (!BasePtr.getBase().getNode())
-    return false;
+  for (StoreSDNode *ChainedStore : ChainedStores) {
+    SDValue Chain = ChainedStore->getChain();
+    SDValue BetterChain = FindBetterChain(ChainedStore, Chain);

-  // Do not handle stores to undef base pointers.
-  if (BasePtr.getBase().isUndef())
-    return false;
+    if (Chain != BetterChain) {
+      MadeChange = true;
+      BetterChains.push_back(std::make_pair(ChainedStore, BetterChain));
+    }
+  }

-  // Directly improve a chain of disjoint stores starting at St.
-  if (parallelizeChainedStores(St))
-    return true;
+  // Do all replacements after finding all of them, to avoid making the chains
+  // more complicated by introducing new TokenFactors.
+  for (auto Replacement : BetterChains)
+    replaceStoreChain(Replacement.first, Replacement.second);

-  // Improve St's chain.
-  SDValue BetterChain = FindBetterChain(St, St->getChain());
-  if (St->getChain() != BetterChain) {
-    replaceStoreChain(St, BetterChain);
-    return true;
-  }
-  return false;
+  return MadeChange;
 }

 /// This is the entry point for the file.
-void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
+void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis &AA,
                            CodeGenOpt::Level OptLevel) {
   /// This is the main entry point to this class.
   DAGCombiner(*this, AA, OptLevel).Run(Level);
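A few of the lowering ideas touched by the hunks above can be sanity-checked outside the DAG. The following sketches are illustrative stand-ins, not the code LLVM emits.

BuildUDIV asks TLI.BuildUDIV to replace a divide by a constant with a multiply by a magic number plus shifts, per the "Hacker's Delight" reference in the doc comment. A minimal sketch of the idea, assuming a 32-bit unsigned divide by 7; the magic constant 0x24924925 and the add/shift fixup are the textbook values for this divisor, worked out by hand rather than computed the way TLI.BuildUDIV does:

#include <cassert>
#include <cstdint>

// floor(N / 7) without a divide instruction. The exact reciprocal scale
// for d = 7 is the 33-bit value ceil(2^35 / 7); 0x24924925 is that value
// minus 2^32, and the add/shift sequence below restores the missing bit.
static uint32_t UDiv7(uint32_t N) {
  uint32_t T = (uint32_t)(((uint64_t)N * 0x24924925u) >> 32); // mulhu part
  return (T + ((N - T) >> 1)) >> 2; // fixup for the 33rd bit, then shift
}

int main() {
  for (uint32_t N = 0; N < 100000; ++N)
    assert(UDiv7(N) == N / 7);
  assert(UDiv7(0xFFFFFFFFu) == 0xFFFFFFFFu / 7); // boundary case
  return 0;
}

Signed division (BuildSDIV) follows the same shape with a sign-corrected magic constant.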
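The removed BuildLogBase2 helper encodes the log2 of a power-of-two value as (EltBits - 1) - ctlz(V). The same identity in scalar form, using __builtin_clz (a GCC/Clang builtin standing in for ISD::CTLZ) and assuming a non-zero 32-bit input:

#include <cassert>
#include <cstdint>

// LogBase2(V) = (EltBits - 1) - ctlz(V) for 32-bit values; V must be
// non-zero, since __builtin_clz(0) is undefined (as is CTLZ without
// the _ZERO_UNDEF guarantee).
static unsigned LogBase2(uint32_t V) {
  return 31u - (unsigned)__builtin_clz(V);
}

int main() {
  assert(LogBase2(1) == 0);
  assert(LogBase2(8) == 3);
  assert(LogBase2(1u << 20) == 20);
  return 0;
}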
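The comment in the BuildReciprocalEstimate hunk derives the Newton step X_{i+1} = X_i + X_i(1 - A * X_i), chosen because it needs no extra intermediate precision. A scalar model of that refinement loop; the seed 0.3 is a made-up stand-in for whatever estimate instruction TLI.getRecipEstimate would select:

#include <cstdio>

// Refine a rough estimate Est ~ 1/A; each step roughly doubles the
// number of correct bits, mirroring the FMUL/FSUB/FADD sequence the
// combiner builds per iteration.
static float RefineRecip(float A, float Est, unsigned Iterations) {
  for (unsigned i = 0; i < Iterations; ++i) {
    float NewEst = 1.0f - A * Est; // (1 - Arg * Est)
    Est = Est + Est * NewEst;      // Est + Est * (1 - Arg * Est)
  }
  return Est;
}

int main() {
  for (unsigned It = 0; It <= 3; ++It)
    std::printf("%u iterations: %.9f\n", It, RefineRecip(3.0f, 0.3f, It));
  return 0; // converges toward the float value of 1/3
}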
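Likewise, BuildRsqrtNRTwoConst refines E ~ 1/sqrt(A) with E' = (E * -0.5) * (A * E * E + -3.0), using only the two constants -0.5 and -3.0. A scalar model with the same operation order as the FMUL/FADD chain in the hunk (hand-picked seed, float only):

#include <cstdio>

// Two-constant Newton-Raphson step for reciprocal square root:
//   E' = (E * -0.5) * (A * E * E - 3.0)
static float RefineRsqrt(float A, float Est, unsigned Iterations) {
  for (unsigned i = 0; i < Iterations; ++i) {
    float HalfEst = Est * -0.5f; // (Est * MinusHalf)
    Est = Est * Est;             // Est * Est
    Est = Est * A;               // Arg * Est * Est
    Est = Est + -3.0f;           // ... + MinusThree
    Est = Est * HalfEst;         // combine both factors
  }
  return Est;
}

int main() {
  // Seed 0.7 approximates 1/sqrt(2) ~ 0.70710678.
  for (unsigned It = 0; It <= 3; ++It)
    std::printf("%u iterations: %.9f\n", It, RefineRsqrt(2.0f, 0.7f, It));
  return 0;
}

The one-constant variant (BuildRsqrtNROneConst) trades these two constants for a single 1.5 plus a precomputed A/2, as its doc comment explains.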
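Once FindBaseOffset has reduced two pointers to (base, offset) pairs with the same base, isAlias decides overlap with a pure interval test; the sizes come from getMemoryVT().getSizeInBits() >> 3, i.e. bytes. The predicate in isolation (MayOverlap is a hypothetical name for the inlined expression):

#include <cassert>
#include <cstdint>

// Two accesses at byte offsets Off1/Off2 with byte sizes Size1/Size2
// from the same base can only alias if the half-open intervals
// [Off, Off + Size) intersect.
static bool MayOverlap(int64_t Off1, int64_t Size1,
                       int64_t Off2, int64_t Size2) {
  return !((Off1 + Size1 <= Off2) || (Off2 + Size2 <= Off1));
}

int main() {
  assert(!MayOverlap(0, 4, 4, 4)); // adjacent i32 stores: no alias
  assert(MayOverlap(0, 8, 4, 4));  // an i64 store covers the second i32
  assert(MayOverlap(0, 4, 0, 4));  // identical locations alias
  return 0;
}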
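Finally, the parallelizeChainedStores code on the removed side tracks covered byte ranges in an llvm::IntervalMap and keeps walking the chain only while each new store is provably disjoint from everything already collected. A simplified model of that bookkeeping with std::map; note the real code is more conservative and can also stop in configurations this sketch would accept:

#include <cstdint>
#include <iterator>
#include <map>

// Track covered [Offset, Offset + Length) byte ranges and refuse to add
// a store whose range overlaps one already seen, so only provably
// disjoint stores get their chains parallelized.
class CoveredIntervals {
  std::map<int64_t, int64_t> Ends; // start -> end (half-open)

public:
  bool tryInsert(int64_t Offset, int64_t Length) {
    int64_t End = Offset + Length;
    auto Next = Ends.lower_bound(Offset);
    // If there's a next interval, we must end before it starts.
    if (Next != Ends.end() && Next->first < End)
      return false;
    // If there's a previous interval, we must start after it ends.
    if (Next != Ends.begin() && std::prev(Next)->second > Offset)
      return false;
    Ends.emplace(Offset, End);
    return true;
  }
};

int main() {
  CoveredIntervals CI;
  bool A = CI.tryInsert(0, 4);  // store to [0,4): accepted
  bool B = CI.tryInsert(4, 8);  // disjoint [4,12): accepted
  bool C = CI.tryInsert(10, 2); // overlaps [4,12): rejected
  return (A && B && !C) ? 0 : 1;
}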